view src/cachingfilter/CachingRequestWrapper.java @ 66:3fbe9cb2e325

security
author Franklin Schmidt <fschmidt@gmail.com>
date Wed, 18 Sep 2024 03:51:47 -0600
parents 7ecd1a4ef557
children
line wrap: on
line source

package cachingfilter;

import java.io.IOException;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.Enumeration;
import java.util.TimeZone;
import java.util.Date;
import java.util.Locale;
import java.util.Map;
import java.util.HashMap;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.eclipse.jetty.server.AbstractHttpConnection;


final class CachingRequestWrapper extends HttpServletRequestWrapper {
	private static final Logger logger = LoggerFactory.getLogger(CachingRequestWrapper.class);

	static final String IF_MODIFIED_SINCE = "If-Modified-Since".toLowerCase();
	static final String IF_NONE_MATCH = "If-None-Match".toLowerCase();
	static final String ACCEPT_ENCODING = "Accept-Encoding".toLowerCase();

	final Map<String,Object> headerMap = new HashMap<String,Object>();

	CachingRequestWrapper(HttpServletRequest request) throws IOException {
		super(request);
		checkRequest(request);
	}

	private static void checkRequest(HttpServletRequest request) throws IOException {
		AbstractHttpConnection c = ((org.eclipse.jetty.server.Request) request).getConnection();
		if (c.getRequestFields().containsKey("Host")) {
			for (Enumeration<String> e2 = c.getRequestFields().getValues("Host"); e2.hasMoreElements();) {
				if (e2.nextElement().trim().endsWith(":")) {
					throw new IOException("Bad 'Host' request header (ends with colon)");
				}
			}
		}
	}

	public String getHeader(String name) {
/*
		String v = getHeader2(name);
		logger.trace("getHeader "+name+" = "+v);
		return v;
	}
	public String getHeader2(String name) {
*/
		String key = name.toLowerCase();
		if( headerMap.containsKey(key) ) {
			Object val = headerMap.get(key);
			if( val instanceof Long ) {
				long date = (Long)val;
				if( date == -1 )
					return null;
				return formatDate(date);
			}
			return (String)val;
		}
		return super.getHeader(name);
	}

	public long getDateHeader(String name) {
/*
		long v = getDateHeader2(name);
		logger.trace("getDateHeader "+name+" = "+v);
		return v;
	}
	public long getDateHeader2(String name) {
*/
		String key = name.toLowerCase();
		if( headerMap.containsKey(key) )
			return (Long)headerMap.get(key);
		return super.getDateHeader(name);
	}

	boolean isCacheable() {
		boolean isCacheable = false;
		Long val = (Long)headerMap.get(IF_MODIFIED_SINCE);
		if( val!=null ) {
			try {
				long cachedLastModified = val;
				if( cachedLastModified != -1 ) {
					long ifModifiedSince = super.getDateHeader(IF_MODIFIED_SINCE);
					//logger.trace("cachedLastModified = "+cachedLastModified+"  ifModifiedSince = "+ifModifiedSince);
					if( cachedLastModified > ifModifiedSince )
						return false;
					isCacheable = true;
				}
			} catch(IllegalArgumentException e) {
				logger.warn("bad date, user-agent="+getHeader("user-agent"),e);
			}
		}
		String etag = (String)headerMap.get(IF_NONE_MATCH);
		if( etag != null ) {
			String ifNoneMatch = super.getHeader(IF_NONE_MATCH);
			if( !etag.equals(ifNoneMatch) )
				return false;
			isCacheable = true;
		}
		return isCacheable;
	}


	private static ThreadLocal<DateFormat> dateFormat = new ThreadLocal<DateFormat>() {
		protected synchronized DateFormat initialValue() {
			DateFormat httpDateFormat = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
			httpDateFormat.setTimeZone(TimeZone.getTimeZone("GMT"));
			return httpDateFormat;
		}
	};

	private String formatDate(long date) {
		return dateFormat.get().format(new Date(date));
	}
}