view src/nabble/view/lib/JtpContextServlet.java @ 0:7ecd1a4ef557

add content
author Franklin Schmidt <fschmidt@gmail.com>
date Thu, 21 Mar 2019 19:15:52 -0600
parents
children
line wrap: on
line source

package nabble.view.lib;

import fschmidt.util.servlet.*;
import fschmidt.util.java.ProcUtils;
import fschmidt.util.java.SimpleClassLoader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;


/**
 * Front-controller servlet that dispatches each request to a delegate {@link HttpServlet}.
 *
 * <p>The delegate is chosen either by the configured {@link UrlMapper} or, when that returns
 * null, by turning the request path into a class name under {@code base}: the extension is
 * dropped and '/' becomes '.', so {@code /foo/Bar.x} resolves to {@code base + ".foo.Bar"}.
 * Delegates are instantiated lazily, initialized with this servlet as their ServletConfig, and
 * cached.
 *
 * <p>Around dispatching, this class does the following:
 * <ul>
 * <li>rejects methods other than GET/POST/HEAD;</li>
 * <li>answers conditional requests from etags produced by {@link #setEtag} (when caching is on);</li>
 * <li>authorizes {@link AuthorizingServlet} delegates;</li>
 * <li>redirects {@link CanonicalUrl} delegates to their canonical URL;</li>
 * <li>logs requests that exceed a time limit;</li>
 * <li>enriches exceptions with request details.</li>
 * </ul>
 *
 * <p>Init parameters read by {@link #init()}: base, reload, recompile, servlet (rejected),
 * cache, characterEncoding, timeLimit, errorCacheSize, ipListSize.
 */
public final class JtpContextServlet extends HttpServlet implements JtpContext {
	private static final Logger logger = LoggerFactory.getLogger(JtpContextServlet.class);

	/** Request methods that are dispatched; any other method is answered with 405. */
	private static final Set<String> allowedMethods = new HashSet<String>(Arrays.asList(
		"GET", "POST", "HEAD"
	));
	/** Package prefix that request paths are resolved against (see getServlet). */
	private String base;
	/** When true, classes under base are loaded through a replaceable SimpleClassLoader. */
	private boolean reload = false;
	/** When true (implies reload), changed .jtp/.java sources are recompiled before reloading. */
	private boolean recompile = false;
	/** Decides which classes the reloadable class loader loads itself; set only in reload mode. */
	private SimpleClassLoader.Filter filter = null;
	/** Current reloadable class loader; null unless reload is enabled. */
	private ClassLoader cl = null;
	/** Instantiated delegate servlets keyed by class name; accessed under {@link #lock}. */
	private Map<String,HttpServlet> map = new HashMap<String,HttpServlet>();
	/** Creation time of the current class loader; files modified after it count as changed. */
	private long clTime;
	private final Object lock = new Object();
	private HttpCache httpCache;
	/** Set from the "cache" init parameter; enables If-None-Match handling in service2. */
	private boolean isCaching;
	private String characterEncoding;
	/** Headers set on every dispatched response. */
	private final Map<String, String> customHeaders = new HashMap<String, String>();
	/** The default mapper maps nothing, so every path is resolved to a class under base. */
	private UrlMapper urlMapper = new UrlMapper() {
		public UrlMapping getUrlMapping(HttpServletRequest request) {
			return null;
		}
	};
	/** Bounded, insertion-ordered set sized by "errorCacheSize"; null when not configured. */
	private Set<String> errorCache = null;
	/**
	 * Bounded list of client IPs seen by handleException, sized by "ipListSize"; null when not
	 * configured. A repeat IP makes handleException throw a ServletException rather than a
	 * RuntimeException.
	 */
	private Collection<String> ipList = null;
	private static final String authKeyAttr = "authKey";
	/** Placeholder event list used by setEtag when no modifying events are given. */
	private static final String[] noModifyingEvents = new String[]{"_"};

	/** Replaces the mapper consulted before falling back to path-to-class resolution. */
	public void setUrlMapper(UrlMapper urlMapper) {
		this.urlMapper = urlMapper;
	}

	public HttpCache getHttpCache() {
		return httpCache;
	}

	/** Sets the cache used for etag validation; required before init() if "cache" is true. */
	public void setHttpCache(HttpCache httpCache) {
		this.httpCache = httpCache;
	}

	/** Registers a header that is set on every dispatched response. */
	public void addCustomHeader(String key, String value) {
		this.customHeaders.put(key, value);
	}

	/**
	 * Discards all cached delegate servlets and creates a fresh class loader, so classes under
	 * base are loaded again on next use. The discarded instances are not destroyed.
	 *
	 * @throws UnsupportedOperationException if reload mode is off
	 */
	public void unloadServlets() {
		if( !reload )
			throw new UnsupportedOperationException("'reload' must be set");
		synchronized(lock) {
			cl = new SimpleClassLoader(filter);
			map = new HashMap<String,HttpServlet>();
			clTime = System.currentTimeMillis();
		}
	}

	/**
	 * Sets the package prefix for path-to-class resolution.
	 *
	 * @throws NullPointerException if base is null
	 */
	public void setBase(String base) {
		if( base==null )
			throw new NullPointerException();
		this.base = base;
	}

	/**
	 * Reads the init parameters, sets up reload mode and the optional bounded caches, and
	 * publishes this object as a servlet-context attribute under JtpContext.attrName.
	 */
	public void init()
		throws ServletException
	{
		ServletContext context = getServletContext();
		String newBase = getInitParameter("base");
		if( newBase != null )
			setBase(newBase);
		recompile = Boolean.valueOf(getInitParameter("recompile"));
		reload = recompile || Boolean.valueOf(getInitParameter("reload"));
		if( reload ) {
			// only classes under base go through the replaceable loader
			filter = new SimpleClassLoader.Filter(){
				final String s = base + ".";
				public boolean load(String className)  {
					return className.startsWith(s);
				}
			};
			unloadServlets();
		}
		context.setAttribute(JtpContext.attrName,this);
		String servletS = getInitParameter("servlet");
		if( servletS != null ) {
			throw new RuntimeException("the 'servlet' init parameter is no longer supported");
		}
		isCaching = "true".equalsIgnoreCase(getInitParameter("cache"));
		if( isCaching ) {
			if( httpCache==null ) {
				logger.error("can't set init parameter 'cache' to true without httpCache");
				// NOTE(review): this stops the whole JVM, not just this servlet; throwing a
				// ServletException is the servlet-API way to fail init. Kept as-is in case
				// deployment relies on the process exiting.
				System.exit(-1);
			}
			logger.info("cache");
		}
		characterEncoding = getInitParameter("characterEncoding");
		{
			String s = getInitParameter("timeLimit");
			if( s != null )
				timeLimit = Long.parseLong(s);
		}
		{
			String s = getInitParameter("errorCacheSize");
			if( s != null ) {
				final int errorCacheSize = Integer.parseInt(s);
				// insertion-ordered map that evicts its oldest entry once it outgrows the limit
				errorCache = Collections.synchronizedSet(Collections.newSetFromMap(new LinkedHashMap<String,Boolean>(){
					protected boolean removeEldestEntry(Map.Entry<String,Boolean> eldest) {
						return size() > errorCacheSize;
					}
				}));
			}
		}
		{
			String s = getInitParameter("ipListSize");
			if( s != null ) {
				final int ipListSize = Integer.parseInt(s);
				// add() rejects duplicates and drops the oldest entry once the limit is exceeded;
				// synchronizedList makes the contains/add/removeFirst sequence atomic
				ipList = Collections.synchronizedList(new LinkedList<String>() {
					public boolean add(String s) {
						if( contains(s) )
							return false;
						super.add(s);
						if( size() > ipListSize )
							removeFirst();
						return true;
					}
				});
			}
		}
	}

	/**
	 * Returns true when the error cache is disabled or already contains s; otherwise records s
	 * and returns false. Not called from within this class.
	 */
	private boolean isInErrorCache(String s) {
		return errorCache==null || !errorCache.add(s);
	}

	/** Returns true when IP tracking is enabled and ip was already listed; otherwise records ip. */
	private boolean isInIpList(String ip) {
		return ipList!=null && !ipList.add(ip);
	}

	/** Callback invoked from {@link #destroy()}. */
	public static interface DestroyListener {
		public void destroyed();
	}

	private DestroyListener destroyListener = null;

	/**
	 * Registers the single destroy listener.
	 *
	 * @throws RuntimeException if a listener is already registered
	 */
	public void addDestroyListener(DestroyListener dl) {
		synchronized(lock) {
			if( destroyListener!=null )
				throw new RuntimeException("only one DestroyListener allowed");
			destroyListener = dl;
		}
	}

	/** Notifies the registered destroy listener, if any. */
	public void destroy() {
		synchronized(lock) {
			if( destroyListener != null )
				destroyListener.destroyed();
		}
	}

	/** Immutable pair returned by {@link CustomWrappers#wrap}. */
	public static final class RequestAndResponse {
		public final HttpServletRequest request;
		public final HttpServletResponse response;

		public RequestAndResponse(HttpServletRequest request,HttpServletResponse response) {
			this.request = request;
			this.response = response;
		}
	}

	/** Hook that may wrap the request/response before they reach a delegate servlet. */
	public static interface CustomWrappers {
		public RequestAndResponse wrap(HttpServletRequest request, HttpServletResponse response);
	}

	private CustomWrappers customWrappers;

	public void setCustomWrappers(CustomWrappers customWrappers) {
		this.customWrappers = customWrappers;
	}

	private static String hideNull(String s) {
		return s==null ? "" : s;
	}

	/** Servlet path plus path info (if any). */
	private String getServletPath(HttpServletRequest request) {
		return request.getServletPath() + hideNull(request.getPathInfo());
	}

	/**
	 * Entry point for every request.
	 *
	 * <p>Wraps the response so that time spent in output calls and sendError accumulates in the
	 * request's TimeLimit.ioTime, which checkTimeLimit subtracts. The wrapper also makes
	 * sendRedirect clear caching and content headers first. The request is then dispatched via
	 * service2, and checkTimeLimit runs if service2 returns normally.
	 */
	protected void service(HttpServletRequest request,HttpServletResponse response)
		throws ServletException, IOException
	{
		final TimeLimit tl = startTimeLimit(request);
		response = new HttpServletResponseWrapper(response) {
			PrintWriter writer = null;
			ServletOutputStream out = null;

			public PrintWriter getWriter()
				throws java.io.IOException
			{
				if( writer==null ) {
					writer = new PrintWriter(super.getWriter()) {
						public void write(String s,int off,int len) {
							long t = System.currentTimeMillis();
							super.write(s,off,len);
							tl.ioTime += System.currentTimeMillis() - t;
						}
						public void write(char[] buf,int off,int len) {
							long t = System.currentTimeMillis();
							super.write(buf,off,len);
							tl.ioTime += System.currentTimeMillis() - t;
						}
						public void write(int c) {
							long t = System.currentTimeMillis();
							super.write(c);
							tl.ioTime += System.currentTimeMillis() - t;
						}
						public void flush() {
							long t = System.currentTimeMillis();
							super.flush();
							tl.ioTime += System.currentTimeMillis() - t;
						}
						// PrintWriter.println() writes the line separator without going through write()
						public void println() {
							long t = System.currentTimeMillis();
							super.println();
							tl.ioTime += System.currentTimeMillis() - t;
						}
					};
				}
				return writer;
			}

			public ServletOutputStream getOutputStream()
				throws java.io.IOException
			{
				if( out==null ) {
					final ServletOutputStream sos = super.getOutputStream();
					out = new ServletOutputStream() {
						public void write(byte[] b,int off,int len) throws IOException {
							long t = System.currentTimeMillis();
							sos.write(b,off,len);
							tl.ioTime += System.currentTimeMillis() - t;
						}
						public void write(byte[] b) throws IOException {
							long t = System.currentTimeMillis();
							sos.write(b);
							tl.ioTime += System.currentTimeMillis() - t;
						}
						public void write(int c) throws IOException {
							long t = System.currentTimeMillis();
							sos.write(c);
							tl.ioTime += System.currentTimeMillis() - t;
						}
						public void flush() throws IOException {
							long t = System.currentTimeMillis();
							sos.flush();
							tl.ioTime += System.currentTimeMillis() - t;
						}
						// OutputStream.close() is a no-op, so without this override a delegate
						// closing its output stream would never close the real one
						public void close() throws IOException {
							long t = System.currentTimeMillis();
							sos.close();
							tl.ioTime += System.currentTimeMillis() - t;
						}
					};
				}
				return out;
			}

			public void sendError(int sc) throws IOException {
				long t = System.currentTimeMillis();
				super.sendError(sc);
				tl.ioTime += System.currentTimeMillis() - t;
			}

			public void sendRedirect(String location) throws IOException {
				if( containsHeader("Expires") )
					setHeader("Expires",null);
				if( containsHeader("Last-Modified") )
					setHeader("Last-Modified",null);
				if( containsHeader("Etag") )
					setHeader("Etag",null);
				if( containsHeader("Cache-Control") )
					setHeader("Cache-Control",null);
				if( containsHeader("Content-Type") )
					setHeader("Content-Type",null);
				if( containsHeader("Content-Length") )
					setContentLength(0);
				super.sendRedirect(location);
			}
		};
		service2(request,response);
		checkTimeLimit(request);
	}

	/**
	 * Dispatches one request, in this order:
	 * <ol>
	 * <li>method check;</li>
	 * <li>default content type and character encoding;</li>
	 * <li>delegate lookup (404 if no class is found for the path);</li>
	 * <li>custom headers;</li>
	 * <li>conditional-request handling (304);</li>
	 * <li>authorization;</li>
	 * <li>canonical-URL redirect (301);</li>
	 * <li>the delegate's service().</li>
	 * </ol>
	 */
	private void service2(HttpServletRequest request, HttpServletResponse response)
		throws ServletException, IOException
	{
		if( !allowedMethods.contains(request.getMethod()) ) {
			response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
			return;
		}

		// First we set the character encoding because any manipulation
		// of request parameters will break without this.
		response.setHeader("Content-Type","text/html; charset=utf-8");  // default, servlet can override
		if( characterEncoding != null ) {
			response.setCharacterEncoding(characterEncoding);
			request.setCharacterEncoding(characterEncoding);
		}

		HttpServlet servlet;
		String path = getServletPath(request);

		UrlMapping urlMapping = urlMapper.getUrlMapping(request);
		if( urlMapping != null ) {
			try {
				servlet = getServletFromClass(urlMapping.servletClass.getName());
			} catch(ClassNotFoundException e) {
				throw new RuntimeException(e);
			}
			// the mapping supplies the request parameters
			final Map params = urlMapping.parameterMap;
			request = new BetterRequestWrapper(request) {
				public Map getParameterMap() {
					return params;
				}
			};
		} else {
			try {
				servlet = getServlet(path);
			} catch(ClassNotFoundException e) {
				response.sendError(HttpServletResponse.SC_NOT_FOUND);
				String agent = request.getHeader("user-agent");
				String referer = request.getHeader("referer");
				String remote = getClientIpAddr(request);
				String msg = request.getRequestURL()+" referer="+referer+" user-agent="+agent+" remote="+remote;
				// requests without a referer are logged at a lower level
				if( referer==null ) {
					logger.info(msg,e);
				} else {
					logger.warn(msg,e);
				}
				return;
			}
		}

		// Custom headers
		addCustomHeaders(response);

		AuthorizingServlet auth = servlet instanceof AuthorizingServlet ? (AuthorizingServlet)servlet : null;
		if( isCaching ) {
			// Each etag produced by setEtag has the form [authKey]event1~event2~...
			// If the newest of its events is not after If-Modified-Since, answer 304,
			// re-authorizing first when the etag carries an auth key and the delegate authorizes.
			String etagS = request.getHeader("If-None-Match");
			if( etagS != null ) {
				String prevEtag = null;
				for( String etag : etagS.split("\\s*,\\s*") ) {
					if( etag.equals(prevEtag) )
						continue;  // skip consecutive duplicates
					prevEtag = etag;
					if( etag.length()>=2 && etag.charAt(0)=='"' && etag.charAt(etag.length()-1)=='"' )
						etag = etag.substring(1,etag.length()-1);
					String authKey = null;
					if( etag.length()>=2 && etag.charAt(0)=='[' ) {
						int i = etag.indexOf(']');
						if( i > 0 ) {
							if( auth != null )
								authKey = etag.substring(1,i);
							etag = etag.substring(i+1);
						}
					}
					String[] events = etag.split("~");
					if( events.length == 0 )
						continue;  // e.g. "~": split() drops trailing empty strings, leaving no events to look up
					long lastModified = getLastModified(events);
					try {
						if( lastModified <= request.getDateHeader("If-Modified-Since") ) {
							// a failed authorization also ends the request here
							if( authKey==null || authorize(auth,authKey,request,response) )
								response.sendError(HttpServletResponse.SC_NOT_MODIFIED);
							return;
						}
					} catch(RuntimeException e) {
						handleException(request,e);
					}
				}
			}
		}
		String authKey = auth==null ? null : getAuthorizationKey(auth,request);
		if( authKey != null ) {
			if( !authorize(auth,authKey,request,response) )
				return;
			request.setAttribute(authKeyAttr,authKey);  // read back by setEtag
		}

		if( servlet instanceof CanonicalUrl ) {
			CanonicalUrl srv = (CanonicalUrl)servlet;
			StringBuffer currentUrl = request.getRequestURL();
			int i = currentUrl.indexOf(";");
			if( i != -1 )
				currentUrl.setLength(i);  // drop path parameters such as ;jsessionid
			String query = request.getQueryString();
			if( query != null )
				currentUrl.append( '?' ).append( query );
			try {
				String canonicalUrl = srv.getCanonicalUrl(request);
				// compared without the scheme, so http vs https alone does not trigger a redirect
				if( canonicalUrl != null && !stripScheme(currentUrl.toString()).equals(stripScheme(canonicalUrl)) ) {
					response.setHeader("Location",canonicalUrl);
					response.sendError( HttpServletResponse.SC_MOVED_PERMANENTLY );
					return;
				}
			} catch(RuntimeException e) {
				logger.warn("couldn't get canonical url",e);
			}
		}

		request.setAttribute("servlet",servlet);

		try {
			if (customWrappers != null) {
				RequestAndResponse rr = customWrappers.wrap(request, response);
				request = rr.request;
				response = rr.response;
			}
			servlet.service(request,response);
		} catch(RuntimeException e) {
			handleException(request,e);
		} catch(ServletException e) {
			handleException(request,e);
		}
	}

	/**
	 * Returns url from its first ':' onward.
	 * Throws StringIndexOutOfBoundsException when url contains no ':'.
	 */
	private static String stripScheme(String url) {
		return url.substring(url.indexOf(':'));
	}

	/**
	 * Sets validator headers for the current response.
	 *
	 * <p>The ETag is "[authKey]" (when the request was authorized) followed by the events
	 * joined with '~'; service2 parses this same format from If-None-Match. Last-Modified is the
	 * newest time httpCache reports for the events. Cache-Control is set to max-age=0.
	 *
	 * @param modifyingEvents event names; an empty list is replaced by a single placeholder
	 */
	public void setEtag( HttpServletRequest request, HttpServletResponse response, String... modifyingEvents ) {
		if( modifyingEvents.length == 0 )
			modifyingEvents = noModifyingEvents;
		StringBuilder buf = new StringBuilder();
		String authKey = (String)request.getAttribute(authKeyAttr);
		if( authKey != null )
			buf.append( '[' ).append( authKey).append( ']' );
		buf.append( modifyingEvents[0] );
		for( int i=1; i<modifyingEvents.length; i++ ) {
			buf.append( '~' ).append( modifyingEvents[i] );
		}
		response.setHeader("Etag",buf.toString());
		long lastModified = getLastModified(modifyingEvents);
		response.setDateHeader("Last-Modified",lastModified);
		response.setHeader("Cache-Control","max-age=0");
	}

	/**
	 * Runs the delegate's authorize() (through customWrappers, if set).
	 * Exceptions are rethrown via handleException, so the method never falls through.
	 */
	private boolean authorize(AuthorizingServlet auth,String authKey,HttpServletRequest request,HttpServletResponse response)
		throws IOException, ServletException
	{
		try {
			if (customWrappers != null) {
				RequestAndResponse rr = customWrappers.wrap(request, response);
				request = rr.request;
				response = rr.response;
			}
			return auth.authorize(authKey,request,response);
		} catch(RuntimeException e) {
			handleException(request,e);
		} catch(ServletException e) {
			handleException(request,e);
		}
		throw new RuntimeException("never");
	}

	/** Asks the delegate for its authorization key; exceptions are rethrown via handleException. */
	private String getAuthorizationKey(AuthorizingServlet auth,HttpServletRequest request)
		throws ServletException
	{
		try {
			return auth.getAuthorizationKey(request);
		} catch(RuntimeException e) {
			handleException(request,e);
		} catch(ServletException e) {
			handleException(request,e);
		}
		return null;  // never gets here
	}

	/**
	 * Returns the largest value httpCache reports for the given events.
	 * Assumes lastModifieds returns a non-empty array for non-empty input (TODO confirm).
	 */
	private long getLastModified(String[] modifyingEvents) {
		long[] lastModifieds = httpCache.lastModifieds(modifyingEvents);
		long lastModified = lastModifieds[0];
		for( int i=1; i<lastModifieds.length; i++ ) {
			long lm = lastModifieds[i];
			if( lastModified < lm )
				lastModified = lm;
		}
		return lastModified;
	}

	/** Adds all custom headers to the response object. */
	private void addCustomHeaders(HttpServletResponse response) {
		Set<Map.Entry<String, String>> entries = this.customHeaders.entrySet();
		for (Map.Entry<String, String> entry : entries) {
			response.setHeader(entry.getKey(), entry.getValue());
		}
	}

	/**
	 * Always throws: rewraps e with the URL, method, user-agent, referer, remote IP and
	 * If-None-Match header.
	 *
	 * <p>It throws a ServletException when the user-agent is missing or listed in badAgents,
	 * when there is no referer, or when the client IP was already seen. Otherwise it throws a
	 * RuntimeException.
	 */
	private void handleException(HttpServletRequest request,RuntimeException e)
		throws ServletException
	{
		JtpRuntimeException rte;
		try {
			String agent = request.getHeader("user-agent");
			if( agent == null )
				throw new JtpServletException(request,"null agent",e);
			if (!isValidAgent(agent))
				throw new JtpServletException(request, "bad agent " + agent, e);
			String remote = getClientIpAddr(request);
			String referer = request.getHeader("referer");
			StringBuilder buf = new StringBuilder()
				.append( "method=" ).append( request.getMethod() )
				.append( " user-agent=" ).append( agent )
				.append( " referer=" ).append( referer )
				.append( " remote=" ).append( remote )
			;
			String etag = request.getHeader("If-None-Match");
			if( etag != null )
				buf.append( " etag=[" ).append( etag ).append( "]" );
			if( referer==null || isInIpList(remote) )
				throw new JtpServletException(request,buf.toString(),e);
			rte = new JtpRuntimeException(request,buf.toString(),e);
		} catch(RuntimeException e2) {
			// building the report itself failed; log the original before propagating
			logger.error("failed to handle",e);
			throw e2;
		}
		throw rte;
	}

	/** Always throws: rewraps e with the URL, user-agent, method and referer. */
	private static void handleException(HttpServletRequest request,ServletException e)
		throws ServletException
	{
		String agent = request.getHeader("user-agent");
		throw new JtpServletException(request,"user-agent="+agent+" method="+request.getMethod()+" referer="+request.getHeader("referer"),e);
	}

	private static class JtpRuntimeException extends RuntimeException {
		private JtpRuntimeException(HttpServletRequest request,String msg,RuntimeException e) {
			super("url="+getCurrentURL(request)+"  "+msg,e);
		}
	}

	private static class JtpServletException extends ServletException {
		private JtpServletException(HttpServletRequest request,String msg,Exception e) {
			super("url="+getCurrentURL(request)+"  "+msg,e);
		}
	}

	// work-around jetty bug
	private static String getCurrentURL(HttpServletRequest request) {
		try {
			return ServletUtils.getCurrentURL(request);
		} catch(RuntimeException e) {
			logger.warn("jetty screwed up",e);
			return "[failed]";
		}
	}

	/** False when agent is null or contains any entry of badAgents. */
	private static boolean isValidAgent(String agent) {
		if (agent == null)
			return false;
		for (String badAgent : badAgents) {
			if (agent.indexOf(badAgent) >= 0)
				return false;
		}
		return true;
	}

	private static final String[] badAgents = new String[]{
		"MJ12bot",
		"WISEnutbot",
		"Win98",  // not worth handling these
		"Windows 98",
		"Windows 95",
		"RixBot",
		"User-Agent",  // from corrupt header
		"Firefox/0",  // ancient version of Firefox
		"Firefox/2.",  // ancient version of Firefox
		"Firefox/3.",  // ancient version of Firefox
		"Opera 7.",  // ancient version of Opera
		"Opera/7.",
		"Opera 8.",
		"Opera/8.",
		"Opera/9.",
		"TwitterFeed 3",
		"NAVER Blog Rssbot",
		"AOL 9.0",
		"rssreader@newstin.com",
		"PHPCrawl",
		"MSIE 2.",
		"MSIE 4.",
		"MSIE 5.",
		"MSIE 6.",
		"MSIE 7.0",
		"Mozilla/0.",
		"Mozilla/2.0",
		"Mozilla/3.0",
		"Mozilla/4.6",
		"Mozilla/4.7",
		"RSSIncludeBot/1.0", // cause exceptions in xml feeds
		"Powermarks",
		"GenwiFeeder",
		"Akregator",
		"ia_archiver",
		"Atomic_Email_Hunter",
		"Yahoo! Slurp",
		"Python-urllib",
		"BlackBerry",
		"SimplePie", // Feeds parser
		"www.webintegration.at", // crazy bot
		"www.run4dom.com", // crazy bot
		"zia-httpmirror",
		"POE-Component-Client-HTTP",
		"anonymous",
		"Sosospider",
		"Java/1.6",
		"Shareaza",
		"Jakarta Commons-HttpClient",
		"Apache-HttpClient",
		"Baiduspider",
		"bingbot",
		"MLBot", // www.metadatalabs.com/mlbot
		"www.vbseo.com",
		"yacybot", // yacy.net/bot.html
		"SearchBot"
	};

	/** True when agent contains any entry of bots. Not called from within this class. */
	private static boolean isBot(String agent) {
		if (agent == null)
			return false;
		for (String bot : bots) {
			if (agent.indexOf(bot) >= 0)
				return true;
		}
		return false;
	}

	private static final String[] bots = new String[]{
		"Googlebot"
	};

	/**
	 * Resolves a path such as /foo/Bar.x to the delegate for class base + ".foo.Bar".
	 *
	 * @throws ClassNotFoundException if the path has no '.' or no such class exists
	 */
	private HttpServlet getServlet(String path)
		throws ServletException, ClassNotFoundException, IOException
	{
		int i = path.lastIndexOf('.');
		if( i == -1 )
			throw new ClassNotFoundException(path);
		return getServletFromClass(
			base + path.substring(0,i).replace('/','.')
		);
	}

	/**
	 * Returns the cached delegate for cls, creating and initializing it on first use.
	 *
	 * <p>In reload mode, a change to the class file discards every cached delegate first.
	 * Instantiation and init failures are rethrown as RuntimeException.
	 */
	private HttpServlet getServletFromClass(String cls)
		throws ClassNotFoundException
	{
		synchronized(lock) {
			if( reload && hasChanged(cls) ) {
				unloadServlets();
			}
			HttpServlet srv = map.get(cls);
			if( srv==null ) {
				try {
					Class<?> clas = reload ? cl.loadClass(cls) : Class.forName(cls);
					srv = (HttpServlet)clas.newInstance();
				} catch(IllegalAccessException e) {
					throw new RuntimeException(e);
				} catch(InstantiationException e) {
					throw new RuntimeException(e);
				}
				try {
					srv.init(this);
				} catch(ServletException e) {
					throw new RuntimeException(e);
				}
				map.put(cls,srv);
			}
			return srv;
		}
	}

	/**
	 * Reports whether the class file for cls is missing or newer than the current class loader.
	 *
	 * <p>In recompile mode it first regenerates the .java from a newer .jtp (fschmidt.tools.Jtp)
	 * and compiles a newer .java with javac, both in the class file's directory.
	 */
	private boolean hasChanged(String cls) {
		try {
			URL url = cl.getResource( SimpleClassLoader.classToResource(cls) );
			if( url==null )
				return true;
			File file = toFile(url);
			if( recompile ) {
				String path = file.toString();
				if( !path.endsWith(".class") )
					throw new RuntimeException("not a class file: "+path);
				File dir = file.getParentFile();
				String stem = path.substring(0,path.length()-".class".length());
				File source = new File( stem + ".jtp" );
				if( source.lastModified() > clTime ) {
					Process proc = Runtime.getRuntime().exec(new String[]{
						"java", "fschmidt.tools.Jtp", source.getName()
					},null,dir);
					ProcUtils.checkProc(proc);
				}
				source = new File( stem + ".java" );
				if( source.lastModified() > clTime ) {
					Process proc = Runtime.getRuntime().exec(new String[]{
						"javac", "-g", source.getName()
					},null,dir);
					ProcUtils.checkProc(proc);
				}
			}
			return file.lastModified() > clTime;
		} catch(IOException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Converts a class-loader resource URL to a File.
	 *
	 * <p>file: URLs are decoded via {@link URL#toURI()}, so escapes such as %20 map back to the
	 * real file name; url.getPath() would keep them and point at a non-existent file. URLs that
	 * cannot be converted that way fall back to the raw path, as before.
	 */
	private static File toFile(URL url) {
		if( "file".equals(url.getProtocol()) ) {
			try {
				return new File(url.toURI());
			} catch(URISyntaxException e) {
				// e.g. the URL contains characters illegal in a URI, such as an unescaped space
			} catch(IllegalArgumentException e) {
				// e.g. the URI has an authority, query or fragment component
			}
		}
		return new File(url.getPath());
	}


	/** Default time limit in ms for new requests; 0 disables the check. */
	private long timeLimit = 0;
	private static final String timeLimitAttr = "time-limit";

	/** Per-request timing state, stored as a request attribute. */
	private static class TimeLimit {
		long timeLimit;
		final long startTime = System.currentTimeMillis();
		/** Milliseconds spent in output calls; excluded from the time-limit check. */
		long ioTime = 0L;

		TimeLimit(long timeLimit) {
			this.timeLimit = timeLimit;
		}
	}

	public long getTimeLimit() {
		return timeLimit;
	}

	public void setTimeLimit(long timeLimit) {
		this.timeLimit = timeLimit;
	}

	private TimeLimit startTimeLimit(HttpServletRequest request) {
		TimeLimit tl = new TimeLimit(timeLimit);
		request.setAttribute( timeLimitAttr, tl );
		return tl;
	}

	/** Overrides the time limit for the current request only; 0 disables the check. */
	public void setTimeLimit(HttpServletRequest request,long timeLimit) {
		TimeLimit tl = (TimeLimit)request.getAttribute(timeLimitAttr);
		tl.timeLimit = timeLimit;
	}

	/**
	 * Logs an error when elapsed time minus output time exceeds the request's limit.
	 * The log entry includes the JVM heap-usage percentage.
	 */
	private void checkTimeLimit(HttpServletRequest request) {
		TimeLimit tl = (TimeLimit)request.getAttribute(timeLimitAttr);
		if( tl.timeLimit == 0L )
			return;
		long time = System.currentTimeMillis() - tl.startTime - tl.ioTime;
		if( time > tl.timeLimit ) {
			float free = Runtime.getRuntime().freeMemory();
			float total = Runtime.getRuntime().totalMemory();
			float used = (total - free) * 100f;
			logger.error(ServletUtils.getCurrentURL(request,100) + " took " + time + " ms | " + String.format("%.1f",used/total) + '%');
/*
			Scheduler scheduler = TheScheduler.get();
			if( scheduler instanceof ProfilingScheduler ) {
				ProfilingScheduler profilingScheduler = (ProfilingScheduler)scheduler;
				if( profilingScheduler.getMode()==ProfilingScheduler.Mode.FOREGROUND ) {
					profilingScheduler.captureCPUSnapshot();
				}
			}
*/
		}
	}

	/**
	 * Returns the X-Real-IP header if present, otherwise the socket's remote address.
	 * NOTE(review): X-Real-IP is client-controlled unless a trusted proxy always sets it.
	 */
	public static String getClientIpAddr(HttpServletRequest request) {
		String ip = request.getHeader("X-Real-IP");
		if( ip == null )
			ip = request.getRemoteAddr();
		return ip;
	}

}