view src/cachingfilter/CachingResponseWrapper.java @ 66:3fbe9cb2e325 default tip

security
author Franklin Schmidt <fschmidt@gmail.com>
date Wed, 18 Sep 2024 03:51:47 -0600
parents 7ecd1a4ef557
children
line wrap: on
line source

package cachingfilter;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.net.URLEncoder;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
 * Response wrapper that lets {@link CachingFilter} serve pages from, and store pages into,
 * an on-disk cache.
 *
 * <p>Flow, as implemented below:
 * <ul>
 * <li>The constructor computes a cache key ({@link #fullName}) from the request URL plus the
 * accepted encodings the filter knows about. It then locks the matching cache file. If an entry
 * exists, it copies the entry's stored Last-Modified / ETag values into
 * {@code request.headerMap} under {@code IF_MODIFIED_SINCE} / {@code IF_NONE_MATCH}.</li>
 * <li>If the downstream servlet then answers 304 Not Modified, the cached headers and body are
 * replayed to the client with status 200 and {@code Via: cache-yes}.</li>
 * <li>Otherwise, a cacheable response is written to a new cache file. Cacheable means: not
 * committed, status 200, and a Last-Modified or ETag was set. In {@link #finish(boolean)} the
 * new file is replayed to the client with {@code Via: cache-write}. Uncacheable output goes
 * straight to the wrapped response with {@code Via: cache-no}.</li>
 * </ul>
 *
 * <p>Cache file layout, as written in {@link #newOutputStream()}:
 * <ol>
 * <li>{@code writeUTF(fullName)}</li>
 * <li>serialized {@code Long} lastModified</li>
 * <li>serialized {@code String} etag</li>
 * <li>serialized {@code List<ResponseAction>}</li>
 * <li>the response body as written by the servlet</li>
 * </ol>
 *
 * <p>NOTE(review): instances hold unsynchronized per-request state. They are presumably used
 * only by the thread handling that request.
 */
final class CachingResponseWrapper extends HttpServletResponseWrapper {
	private static final Logger logger = LoggerFactory.getLogger(CachingResponseWrapper.class);

	/** Lower-cased names of the headers that are recorded into, and replayed from, the cache. */
	private static final Set<String> cacheableHeaders = new HashSet<String>();
	static {
		for( String header : new String[]{
			"Last-Modified",
			"Etag",
			"Content-Type",
			"Cache-Control",
			"Content-Encoding",
		} ) {
			cacheableHeaders.add( header.toLowerCase(Locale.ROOT) );
		}
	}

	private final CachingFilter cachingFilter;
	private final CachingRequestWrapper request;
	/**
	 * Cache key: the URL-encoded request URL, including the query string. It is followed by
	 * "~" and the accepted known encodings when there are any.
	 */
	private final String fullName;
	/** True while the servlet's output is being written to a new cache file. */
	private boolean isCachingToFile;
	/** Cache entry for {@link #fullName}; its lock is acquired while the constructor runs. */
	private CachedPage cachingFile;
	private ServletOutputStream outputStream;
	private PrintWriter writer;
	/** Status most recently set on this wrapper (also updated by sendError/sendRedirect). */
	private int status = SC_OK;
	/** Last-Modified recorded by setDateHeader (truncated to whole seconds), or null. */
	private Long lastModified = null;
	/** ETag recorded by setHeader, or null. */
	private String etag = null;
	/** Handle on the cache file currently being read, or null. */
	private FileHandler fileHandler = null;
	/** Object stream over {@link #fileHandler}'s input, or null until a cache file is opened. */
	private ObjectInputStream ois = null;
	/** Cacheable header operations, in call order, stored with the page and replayed later. */
	private final List<ResponseAction> actions = new ArrayList<ResponseAction>();
	private boolean isLocked = false;  // for debugging
	/** Per-request trace; dumped by finish() if the lock is unexpectedly still held. */
	private final StringBuilder log = new StringBuilder();

	/** Appends one line to the per-request debugging trace. */
	private void log(String msg) {
		log.append(msg).append('\n');
	}

	/**
	 * Wraps {@code response} and looks up the cache entry for {@code request}.
	 *
	 * <p>On normal return the lock on this request's cache file is held. If a matching entry was
	 * found, its stored validators have been placed in {@code request.headerMap}, and the file is
	 * left open in {@link #fileHandler} / {@link #ois} for a possible replay.
	 *
	 * @throws RuntimeException wrapping any {@link IOException} or
	 *         {@link ClassNotFoundException} raised while building the key or reading the entry.
	 *         In that case the lock and the opened file are released first.
	 */
	CachingResponseWrapper( CachingFilter cachingFilter, CachingRequestWrapper request, HttpServletResponse response ) {
		super(response);
		try {
			this.cachingFilter = cachingFilter;
			this.request = request;
			this.fullName = getFullName();
			boolean ok = false;
			try {
				if( !openObjectInputStream() ) {  // calls lock()
					ok = true;
					return;
				}
				Long cachedLastModified = (Long)ois.readObject();
				if( cachedLastModified==null )
					cachedLastModified = -1L;
				request.headerMap.put( CachingRequestWrapper.IF_MODIFIED_SINCE, cachedLastModified );
				String cachedEtag = (String)ois.readObject();
				request.headerMap.put( CachingRequestWrapper.IF_NONE_MATCH, cachedEtag );
				ok = true;
			} finally {
				if( !ok ) {
					// The constructor is failing, so no finish() call can ever release these.
					closeQuietly(fileHandler);
					if( cachingFile != null )
						unlock();
				}
			}
		} catch(ClassNotFoundException e) {
			throw new RuntimeException(e);
		} catch(IOException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Builds the cache key for this request.
	 *
	 * <p>The key is the URL-encoded request URL, including the query string. If the client sent
	 * Accept-Encoding, the listed encodings that also appear in
	 * {@code cachingFilter.getEncodings()} are appended as {@code "~enc1,enc2"}. The same filtered
	 * list, or null when none match, is stored in {@code request.headerMap} under
	 * {@code ACCEPT_ENCODING}.
	 */
	private String getFullName()
		throws IOException
	{
		StringBuffer url = request.getRequestURL();
		String queryString = request.getQueryString();
		if( queryString != null ) {
			url.append( '?' );
			url.append( queryString );
		}
		log("url = "+url);
		String fullUrl = URLEncoder.encode(url.toString(),"UTF-8");
		String acceptEncoding = request.getHeader("Accept-Encoding");
		logger.trace("acceptEncoding = "+acceptEncoding);
		if( acceptEncoding == null )
			return fullUrl;
		Set<String> knownEncodings = cachingFilter.getEncodings();
		List<String> list = new ArrayList<String>();
		for( String encoding : acceptEncoding.split(",") ) {
			encoding = encoding.trim();
			if( knownEncodings.contains(encoding) )
				list.add(encoding);
		}
		if( list.isEmpty() ) {
			request.headerMap.put( CachingRequestWrapper.ACCEPT_ENCODING, null );
			return fullUrl;
		}
		String encodings = join( list, "," );
		request.headerMap.put( CachingRequestWrapper.ACCEPT_ENCODING, encodings );
		return fullUrl + '~' + encodings;
	}

	/** Acquires the lock for {@link #cachingFile} through {@code CachingFilter.locker}. */
	private void lock() {
		CachingFilter.locker.lock(cachingFile.name());
		log("lock");
		isLocked = true;
	}

	/**
	 * Releases the lock for {@link #cachingFile} through {@code CachingFilter.locker}.
	 *
	 * <p>An error is logged if the locker's answer disagrees with this wrapper's own
	 * {@link #isLocked} bookkeeping.
	 *
	 * @return the locker's result, which this class treats as "the lock was held"
	 */
	private boolean unlock() {
		boolean wasLocked = CachingFilter.locker.unlock(cachingFile.name());
		log("unlock "+wasLocked);
		if( isLocked != wasLocked )
			logger.error("isLocked="+isLocked+" wasLocked="+wasLocked,new Exception());
		isLocked = false;
		return wasLocked;
	}

	/** Joins the string forms of {@code col}'s elements with {@code separator}; "" if empty. */
	private static String join(Collection<?> col,String separator) {
		if( col.isEmpty() )
			return "";
		StringBuilder sb = new StringBuilder();
		Iterator<?> iter = col.iterator();
		sb.append( iter.next() );
		while( iter.hasNext() ) {
			sb.append( separator ).append( iter.next() );
		}
		return sb.toString();
	}

	/**
	 * 64-bit variant of {@link String#hashCode()}: the same base-31 polynomial, accumulated in a
	 * long. Used to derive cache file names.
	 */
	private static long hashCode(String s) {
		final int len = s.length();
		long h = 0;
		for( int i = 0; i < len; i++ ) {
			h = 31*h + s.charAt(i);
		}
		return h;
	}

	/**
	 * Case-insensitive check for whether a header is recorded into the cache. Uses
	 * {@link Locale#ROOT} so that names such as "LAST-MODIFIED" lower-case correctly whatever the
	 * default locale is.
	 */
	private static boolean isCacheableHeader(String name) {
		return cacheableHeaders.contains( name.toLowerCase(Locale.ROOT) );
	}

	/**
	 * Closes {@code fh} if it is non-null, logging instead of propagating any failure.
	 *
	 * <p>Intended for cleanup paths, where a close failure must not mask the exception that
	 * triggered the cleanup. {@code Exception} is caught so this compiles whatever
	 * {@code FileHandler.close()} declares.
	 */
	private static void closeQuietly(FileHandler fh) {
		if( fh == null )
			return;
		try {
			fh.close();
		} catch(Exception e) {
			logger.error("couldn't close fileHandler",e);
		}
	}

	/**
	 * Locates and locks the cache file for {@link #fullName}.
	 *
	 * <p>Candidate file names are the hex form of {@code hashCode(fullName)}, {@code +1},
	 * {@code +2}, and so on. Each candidate is locked and its stored name is compared with
	 * {@link #fullName}. On a mismatch (a hash collision) the candidate is closed and unlocked,
	 * and the next one is tried.
	 *
	 * @return true if a matching entry was opened. Then {@link #fileHandler} and {@link #ois}
	 *         are set, positioned just after the stored name. Returns false if there is no entry
	 *         or the entry was unreadable; an unreadable entry is deleted. In every case the lock
	 *         on {@link #cachingFile} is held on return.
	 */
	private boolean openObjectInputStream() {
		for( long hash = hashCode(fullName); true; hash++ ) {
			String s = Long.toHexString(hash);
			cachingFile = cachingFilter.newCachedPage(s);
			lock();
			if( !cachingFile.exists() ) {
				logger.trace("couldn't find "+cachingFile);
				return false;
			}
			FileHandler candidate = null;
			try {
				candidate = FileHandler.factory.newInstance(cachingFile.lastFile());
				ObjectInputStream in = new ObjectInputStream(candidate.getInputStream());
				if( in.readUTF().equals(fullName) ) {
					logger.trace("found file = "+cachingFile);
					this.fileHandler = candidate;
					this.ois = in;
					return true;
				}
				FileHandler collided = candidate;
				candidate = null;  // so the catch below does not close it a second time
				collided.close();
			} catch(IOException e) {
				// Release the handle first, so the unreadable entry can actually be deleted.
				closeQuietly(candidate);
				logger.error("couldn't read "+cachingFile+" length="+cachingFile.lastFile().length(),e);
				if( !cachingFile.delete() )
					logger.error("couldn't delete "+cachingFile);
				return false;
			}
			unlock();
		}
	}

	/** Forwards to the wrapped response and always records the Content-Type for replay. */
	public void setContentType(String ct)
	{
		logger.trace("setContentType "+ct);
		super.setContentType(ct);
		actions.add( new ResponseAction.SetHeader("Content-Type",ct) );
	}

	/** Forwards to the wrapped response and remembers {@code sc}. */
	public void setStatus(int sc, String sm)
	{
		logger.trace("setStatus2");
		super.setStatus(sc,sm);
		this.status = sc;
	}

	/** Forwards to the wrapped response and remembers {@code sc}. */
	public void setStatus(int sc)
	{
		logger.trace("setStatus "+sc);
		super.setStatus(sc);
		this.status = sc;
	}

	/**
	 * Forwards to the wrapped response and, for cacheable headers, records the call for replay.
	 *
	 * <p>An ETag value is remembered as this response's validator. Last-Modified can only be
	 * cleared (set to null) here; a non-null value is logged as unsupported, because only
	 * {@link #setDateHeader(String, long)} records it.
	 */
	public void setHeader(String name, String value) {
		logger.trace("setHeader "+name+" = "+value);
		super.setHeader(name,value);
		if( "Etag".equalsIgnoreCase(name) )
			etag = value;
		if( "Last-Modified".equalsIgnoreCase(name) ) {
			if( value==null )
				lastModified = null;
			else
				logger.error("unsupported",new Exception());
		}
		if( isCacheableHeader(name) )
			actions.add( new ResponseAction.SetHeader(name,value) );
	}

	/** Forwards to the wrapped response and, for cacheable headers, records the call for replay. */
	public void addHeader(String name, String value) {
		logger.trace("addHeader "+name+" = "+value);
		super.addHeader(name,value);
		if( isCacheableHeader(name) )
			actions.add( new ResponseAction.AddHeader(name,value) );
	}

	/** Forwards to the wrapped response and, for cacheable headers, records the call for replay. */
	public void setIntHeader(String name, int value) {
		logger.trace("setIntHeader "+name);
		super.setIntHeader(name,value);
		if( isCacheableHeader(name) )
			actions.add( new ResponseAction.SetIntHeader(name,value) );
	}

	/** Forwards to the wrapped response and, for cacheable headers, records the call for replay. */
	public void addIntHeader(String name, int value) {
		logger.trace("addIntHeader "+name);
		super.addIntHeader(name,value);
		if( isCacheableHeader(name) )
			actions.add( new ResponseAction.AddIntHeader(name,value) );
	}

	/**
	 * Forwards to the wrapped response and, for cacheable headers, records the call for replay.
	 *
	 * <p>The wrapped response receives {@code value} unchanged. The value that is remembered as
	 * Last-Modified and recorded for replay is truncated to whole seconds.
	 */
	public void setDateHeader(String name, long value) {
		logger.trace("setDateHeader "+name);
		super.setDateHeader(name,value);
		value = value / 1000 * 1000;  // round to seconds
		if( "Last-Modified".equalsIgnoreCase(name) )
			lastModified = value;
		if( isCacheableHeader(name) )
			actions.add( new ResponseAction.SetDateHeader(name,value) );
	}

	/**
	 * Forwards to the wrapped response's {@code addDateHeader} and, for cacheable headers,
	 * records the call for replay.
	 */
	public void addDateHeader(String name, long value) {
		logger.trace("addDateHeader "+name);
		// Must be addDateHeader (not setDateHeader), so that the live response matches the
		// AddDateHeader action replayed from the cache.
		super.addDateHeader(name,value);
		if( isCacheableHeader(name) )
			actions.add( new ResponseAction.AddDateHeader(name,value) );
	}

	/** Resets the wrapped response plus all state recorded here: output, status, validators, actions. */
	public void reset()
	{
		logger.trace("reset");
		super.reset();
		resetOutput();
		status = SC_OK;
		lastModified = null;
		etag = null;
		actions.clear();
	}

	/** Resets the wrapped response's buffer and discards any output captured here. */
	public void resetBuffer()
	{
		logger.trace("resetBuffer");
		super.resetBuffer();
		resetOutput();
	}

	/**
	 * Drops the current output stream and writer. If output was going to a new cache file, that
	 * stream is closed and the new file is deleted.
	 */
	private void resetOutput() {
		if( isCachingToFile ) {
			try {
				outputStream.close();
			} catch(IOException e) {
				logger.error("resetOutput",e);
			}
			cachingFile.deleteNewFile();
			isCachingToFile = false;
		}
		outputStream = null;
		writer = null;
	}

	/**
	 * Records the status and clears the buffer. Then either replays the cached page (a 304
	 * against an open cache entry; see {@link #shouldSendFile()}) or forwards the error.
	 */
	public void sendError(int sc, String msg) throws IOException
	{
		logger.trace("sendError2");
		this.status = sc;
		resetBuffer();
		if( shouldSendFile() ) {
			sendFile();
		} else {
			super.sendError(sc,msg);
		}
	}

	/**
	 * Records the status and clears the buffer. Then either replays the cached page (a 304
	 * against an open cache entry; see {@link #shouldSendFile()}) or forwards the error.
	 */
	public void sendError(int sc) throws IOException
	{
		logger.trace("sendError");
		this.status = sc;
		resetBuffer();
		if( shouldSendFile() ) {
			sendFile();
		} else {
			super.sendError(sc);
		}
	}

	/** Records a 302 status, clears the buffer, and forwards the redirect. */
	public void sendRedirect(String location) throws IOException
	{
		logger.trace("sendRedirect");
		this.status = SC_MOVED_TEMPORARILY;
		resetBuffer();
		super.sendRedirect(location);
	}

	/**
	 * Flushes any writer and output stream created here.
	 *
	 * <p>If no output stream exists, this replays the cached page when {@link #shouldSendFile()}
	 * allows it; otherwise it flushes the wrapped response.
	 */
	public void flushBuffer() throws IOException
	{
		logger.trace("flushBuffer "+isCommitted());
		if( writer != null )
			writer.flush();
		if( outputStream != null )
			outputStream.flush();
		else if( shouldSendFile() )
			sendFile();
		else
			getResponse().flushBuffer();
	}

	/**
	 * Decides whether to replay the cached page.
	 *
	 * <p>True when the downstream answered 304 Not Modified, a cache entry was opened, and
	 * {@code request.isCacheable()} is false. NOTE(review): when {@code request.isCacheable()} is
	 * true, the 304 is presumably valid for the client as-is; confirm against
	 * CachingRequestWrapper.
	 */
	private boolean shouldSendFile() {
		if( !(status==SC_NOT_MODIFIED && ois!=null) )
			return false;
		if( request.isCacheable() )
			return false;  // no need
		return true;
	}

	/** Replays the cached page, marking the response {@code Via: cache-yes}. */
	private void sendFile() throws IOException {
		CachingResponseWrapper.super.setHeader("Via","cache-yes");
		sendFile2();
	}

	/**
	 * Replays the open cache entry to the client.
	 *
	 * <p>Steps: release the lock, set status 200, apply the recorded header actions to the wrapped
	 * response, then let {@link #fileHandler} write the rest of the file (the body) to the
	 * client. Expects {@link #ois} to be positioned just after the stored etag.
	 */
	private void sendFile2() throws IOException {
		logger.trace("sendFile");
		unlock();
		setStatus(SC_OK);
		HttpServletResponse response = (HttpServletResponse)getResponse();
		try {
			@SuppressWarnings("unchecked")
			List<ResponseAction> cachedActions = (List<ResponseAction>)ois.readObject();
			for( ResponseAction cachedAction : cachedActions ) {
				cachedAction.apply(response);
			}
		} catch(ClassNotFoundException e) {
			throw new RuntimeException(e);
		}
		ServletOutputStream out = response.getOutputStream();
		fileHandler.writeTo(out);
	}

	/**
	 * Returns this wrapper's output stream, creating it on first use.
	 *
	 * @throws IllegalStateException if {@link #getWriter()} was already called
	 */
	public ServletOutputStream getOutputStream()
	{
		logger.trace("getOutputStream");
		if (outputStream==null) {
			 newOutputStream();
		} else if (writer!=null)
			throw new IllegalStateException("getWriter() called");

		return outputStream;
	}

	/**
	 * Returns a writer over this wrapper's output stream, using the response character encoding
	 * when one is set.
	 *
	 * @throws IllegalStateException if {@link #getOutputStream()} was already called
	 */
	public PrintWriter getWriter() throws IOException
	{
		logger.trace("getWriter");
		if (writer==null)
		{
			if (outputStream!=null)
				throw new IllegalStateException("getOutputStream() called");

			newOutputStream();
			String encoding = getCharacterEncoding();
			writer = encoding==null ? new PrintWriter(outputStream)
				: new PrintWriter(new OutputStreamWriter(outputStream,encoding));
		}
		return writer;
	}

	/**
	 * Whether the current response may be written to the cache: it is not committed, the status
	 * is 200, and a Last-Modified or ETag has been recorded.
	 */
	private boolean isCacheable() {
		if( getResponse().isCommitted() ) {
			logger.trace("!isCacheable - isCommitted");
			return false;
		}
		if( status != SC_OK ) {
			logger.trace("!isCacheable - status="+status);
			return false;
		}
		if( lastModified==null && etag==null ) {
			logger.trace("!isCacheable - no lastModified,etag");
			return false;
		}
		return true;
	}

	/**
	 * Installs a {@link ProxyServletOutputStream} whose target stream is chosen by the
	 * {@code newOutputStream()} callback, presumably on first use; see ProxyServletOutputStream.
	 *
	 * <p>The callback marks the response {@code Via: cache-no}. An uncacheable response releases
	 * the lock and writes straight to the wrapped response. A cacheable one writes the cache
	 * header (name, lastModified, etag, actions) to a new cache file, and the body follows it
	 * there.
	 */
	private void newOutputStream() {
		outputStream = new ProxyServletOutputStream() {
			protected OutputStream newOutputStream()
				throws IOException
			{
				CachingResponseWrapper.super.setHeader("Via","cache-no");
				ServletOutputStream out = getResponse().getOutputStream();
				if( !isCacheable() ) {
					unlock();
					logger.trace("return getResponse().getOutputStream() "+out.getClass());
					return out;
				}
				try {
					File newFile = cachingFile.newFile();
					logger.trace("write to cache");
					isCachingToFile = true;
					OutputStream outFile = new BufferedOutputStream(new FileOutputStream(newFile));
					ObjectOutputStream oos = new ObjectOutputStream(outFile);
					oos.writeUTF(fullName);
					oos.writeObject(lastModified);
					oos.writeObject(etag);
					oos.writeObject(actions);
					oos.flush();
					return outFile;
				} catch(IOException e) {
					throw new RuntimeException(e);
				}
			}
		};
	}

	/**
	 * Completes the response and releases the cache lock and any open cache file.
	 *
	 * <p>If output was going to a new cache file and {@code isDone} is still true after flushing
	 * and closing it, the new entry is re-opened and replayed to the client with
	 * {@code Via: cache-write}. Otherwise the new cache file is deleted. On exit, an error (with
	 * the debugging trace) is logged if the lock was still held.
	 *
	 * @param isDone presumably true when the filter chain completed normally (TODO confirm against
	 *        CachingFilter). It is treated as false if flushing or closing the output fails.
	 * @throws IOException if replaying the new cache entry to the client fails
	 */
	void finish(boolean isDone) throws IOException {
		try {
			log("finish a");
			if( fileHandler != null ) {
				fileHandler.close();
				logger.trace("closed fileHandler");
			}
			if( outputStream == null || (!isDone && !isCachingToFile) )
				unlock();
			if( isDone ) {
				if( writer != null )
					writer.flush();
				if( outputStream != null ) {
					try {
						outputStream.flush();
					} catch(IOException e) {
						logger.trace("",e);
						isDone = false;
					}
				}
			}
			if( isCachingToFile ) {
				log("finish b");
				try {
					outputStream.close();
				} catch(IOException e) {
					logger.trace("",e);
					isDone = false;
				}
				if( isDone ) {
					log("finish c");
					FileHandler newHandler = null;
					try {
						log("file size = "+cachingFile.lastFile().length());
						newHandler = FileHandler.factory.newInstance(cachingFile.lastFile());
						fileHandler = newHandler;
						ois = new ObjectInputStream(newHandler.getInputStream());
						ois.readUTF();  // full name
						ois.readObject();  // lastModified
						ois.readObject();  // etag
					} catch(ClassNotFoundException e) {
						closeQuietly(newHandler);
						throw new RuntimeException(e);
					} catch(IOException e) {
						closeQuietly(newHandler);
						cachingFile.deleteNewFile();
						throw new RuntimeException(e);
					}
					CachingResponseWrapper.super.setHeader("Via","cache-write");
					try {
						log("finish d");
						sendFile2();
						log("finish e");
					} catch(IOException e) {
						unlock();
						closeQuietly(fileHandler);
						throw e;
					} catch(RuntimeException e) {
						closeQuietly(fileHandler);
						throw e;
					}
					fileHandler.close();
				} else {
					cachingFile.deleteNewFile();
					unlock();
				}
			}
			log("finish z");
		} catch(RuntimeException e) {
			log("finish RuntimeException");
			logger.error("RuntimeException in finish()",e);
			throw e;
		} catch(Error e) {
			log("finish Error");
			logger.error("Error in finish()",e);
			throw e;
		} finally {
			if( unlock() ) {
				logger.error("still locked isDone="+isDone+" isCachingToFile="+isCachingToFile+" outputStream="+(outputStream!=null));
				logger.error("log:\n"+log);
			}
		}
	}

}