Servlet request improvements

This commit is contained in:
Mihai Pitu 2013-08-26 22:41:16 +03:00 committed by Felipe Zimmerle
parent c8e31c42f5
commit 3c61493169
5 changed files with 278 additions and 612 deletions

View File

@ -0,0 +1,170 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.io.output;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;
/**
* A Proxy stream which acts as expected, that is it passes the method
* calls on to the proxied stream and doesn't change which methods are
* being called. It is an alternative base class to FilterOutputStream
* to increase reusability.
* <p>
* See the protected methods for ways in which a subclass can easily decorate
* a stream with custom pre-, post- or error processing functionality.
*
* @author Stephen Colebourne
* @version $Id: ProxyOutputStream.java 1003340 2010-10-01 00:31:53Z sebb $
*/
public class ProxyOutputStream extends FilterOutputStream {
/**
* Constructs a new ProxyOutputStream.
*
* @param proxy the OutputStream to delegate to
*/
public ProxyOutputStream(OutputStream proxy) {
super(proxy);
// the proxy is stored in a protected superclass variable named 'out'
}
/**
* Invokes the delegate's <code>write(int)</code> method.
* @param idx the byte to write
* @throws IOException if an I/O error occurs
*/
@Override
public void write(int idx) throws IOException {
try {
beforeWrite(1);
out.write(idx);
afterWrite(1);
} catch (IOException e) {
handleIOException(e);
}
}
/**
* Invokes the delegate's <code>write(byte[])</code> method.
* @param bts the bytes to write
* @throws IOException if an I/O error occurs
*/
@Override
public void write(byte[] bts) throws IOException {
try {
int len = bts != null ? bts.length : 0;
beforeWrite(len);
out.write(bts);
afterWrite(len);
} catch (IOException e) {
handleIOException(e);
}
}
/**
* Invokes the delegate's <code>write(byte[])</code> method.
* @param bts the bytes to write
* @param st The start offset
* @param end The number of bytes to write
* @throws IOException if an I/O error occurs
*/
@Override
public void write(byte[] bts, int st, int end) throws IOException {
try {
beforeWrite(end);
out.write(bts, st, end);
afterWrite(end);
} catch (IOException e) {
handleIOException(e);
}
}
/**
* Invokes the delegate's <code>flush()</code> method.
* @throws IOException if an I/O error occurs
*/
@Override
public void flush() throws IOException {
try {
out.flush();
} catch (IOException e) {
handleIOException(e);
}
}
/**
* Invokes the delegate's <code>close()</code> method.
* @throws IOException if an I/O error occurs
*/
@Override
public void close() throws IOException {
try {
out.close();
} catch (IOException e) {
handleIOException(e);
}
}
/**
* Invoked by the write methods before the call is proxied. The number
* of bytes to be written (1 for the {@link #write(int)} method, buffer
* length for {@link #write(byte[])}, etc.) is given as an argument.
* <p>
* Subclasses can override this method to add common pre-processing
* functionality without having to override all the write methods.
* The default implementation does nothing.
*
* @since Commons IO 2.0
* @param n number of bytes to be written
* @throws IOException if the pre-processing fails
*/
protected void beforeWrite(int n) throws IOException {
}
/**
* Invoked by the write methods after the proxied call has returned
* successfully. The number of bytes written (1 for the
* {@link #write(int)} method, buffer length for {@link #write(byte[])},
* etc.) is given as an argument.
* <p>
* Subclasses can override this method to add common post-processing
* functionality without having to override all the write methods.
* The default implementation does nothing.
*
* @since Commons IO 2.0
* @param n number of bytes written
* @throws IOException if the post-processing fails
*/
protected void afterWrite(int n) throws IOException {
}
/**
* Handle any IOExceptions thrown.
* <p>
* This method provides a point to implement custom exception
* handling. The default behaviour is to re-throw the exception.
* @param e The IOException thrown
* @throws IOException if an I/O error occurs
* @since Commons IO 2.0
*/
protected void handleIOException(IOException e) throws IOException {
throw e;
}
}

View File

@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.io.output;
import java.io.IOException;
import java.io.OutputStream;
/**
* Classic splitter of OutputStream. Named after the unix 'tee'
* command. It allows a stream to be branched off so there
* are now two streams.
*
* @version $Id: TeeOutputStream.java 659817 2008-05-24 13:23:10Z niallp $
*/
public class TeeOutputStream extends ProxyOutputStream {
/** the second OutputStream to write to */
protected OutputStream branch;
/**
* Constructs a TeeOutputStream.
* @param out the main OutputStream
* @param branch the second OutputStream
*/
public TeeOutputStream( OutputStream out, OutputStream branch ) {
super(out);
this.branch = branch;
}
/**
* Write the bytes to both streams.
* @param b the bytes to write
* @throws IOException if an I/O error occurs
*/
@Override
public synchronized void write(byte[] b) throws IOException {
super.write(b);
this.branch.write(b);
}
/**
* Write the specified bytes to both streams.
* @param b the bytes to write
* @param off The start offset
* @param len The number of bytes to write
* @throws IOException if an I/O error occurs
*/
@Override
public synchronized void write(byte[] b, int off, int len) throws IOException {
super.write(b, off, len);
this.branch.write(b, off, len);
}
/**
* Write a byte to both streams.
* @param b the byte to write
* @throws IOException if an I/O error occurs
*/
@Override
public synchronized void write(int b) throws IOException {
super.write(b);
this.branch.write(b);
}
/**
* Flushes both streams.
* @throws IOException if an I/O error occurs
*/
@Override
public void flush() throws IOException {
super.flush();
this.branch.flush();
}
/**
* Closes both streams.
* @throws IOException if an I/O error occurs
*/
@Override
public void close() throws IOException {
super.close();
this.branch.close();
}
}

View File

@ -42,7 +42,7 @@ public class ModSecurityFilter implements Filter {
HttpServletResponse httpResp = (HttpServletResponse) response;
MsHttpTransaction httpTran = new MsHttpTransaction(httpReq, httpResp); //transaction object used by native code
try {
try {
int status = modsecurity.onRequest(modsecurity.getConfFilename(), httpTran, modsecurity.checkModifiedConfig()); //modsecurity reloads only if primary config file is modified
if (status != ModSecurity.DECLINED) {
@ -53,9 +53,7 @@ public class ModSecurityFilter implements Filter {
//process request
fc.doFilter(httpTran.getMsHttpRequest(), httpTran.getMsHttpResponse());
status = modsecurity.onResponse(httpTran);
if (status != ModSecurity.OK && status != ModSecurity.DECLINED) {
httpTran.getMsHttpResponse().reset();
httpTran.getMsHttpResponse().setStatus(status);

View File

@ -3,67 +3,20 @@ package org.modsecurity;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Locale;
import java.util.TimeZone;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
public class MsHttpServletResponse extends HttpServletResponseWrapper {
private static final int INTERCEPT_OFF = 0;
private static final int INTERCEPT_ON = 1;
private static final int INTERCEPT_OBSERVE_ONLY = 2;
public static final String DEFAULT_CHARACTER_ENCODING = "ISO-8859-1";
private int interceptMode = INTERCEPT_ON;
private ArrayList<Object> headers = new ArrayList<Object>();
private ArrayList<Object> cookies = new ArrayList<Object>();
private int status = -1;
private boolean committed = false;
private boolean suspended = false;
private boolean destroyed = false;
private String statusMessage;
private String contentType;
private String characterEncoding;
private int contentLength = -1;
private Locale locale;
private MsOutputStream msOutputStream;
private MsWriter msWriter;
protected SimpleDateFormat formats[] = {
new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US),
new SimpleDateFormat("EEEEEE, dd-MMM-yy HH:mm:ss zzz", Locale.US),
new SimpleDateFormat("EEE MMMM d HH:mm:ss yyyy", Locale.US)
};
private class Header {
String name;
String value;
Header(String name, String value) {
this.name = name;
this.value = value;
}
}
private boolean destroyed = false;
private boolean suspended;
public MsHttpServletResponse(HttpServletResponse response) {
super(response);
characterEncoding = DEFAULT_CHARACTER_ENCODING;
TimeZone GMT_ZONE = TimeZone.getTimeZone("GMT");
formats[0].setTimeZone(GMT_ZONE);
formats[1].setTimeZone(GMT_ZONE);
formats[2].setTimeZone(GMT_ZONE);
locale = Locale.getDefault();
}
public void destroy() throws IOException {
@ -71,40 +24,6 @@ public class MsHttpServletResponse extends HttpServletResponseWrapper {
return;
}
if (interceptMode == INTERCEPT_ON) {
if (status != -1) {
super.setStatus(status);
}
if (contentType != null) {
super.setContentType(contentType);
}
if (characterEncoding != null) {
super.setCharacterEncoding(characterEncoding);
}
if (contentLength != -1) {
super.setContentLength(contentLength);
}
if (locale != null) {
super.setLocale(locale);
}
// send cookies
for (int i = 0, n = cookies.size(); i < n; i++) {
super.addCookie((Cookie) cookies.get(i));
}
// send headers
for (int i = 0, n = headers.size(); i < n; i++) {
Header h = (Header) headers.get(i);
// TODO don't send our cookie headers because
// they are not well implemented yet. Cookies
// are sent directly
if (h.name.compareTo("Set-Cookie") != 0) {
super.addHeader(h.name, h.value);
}
}
}
if (msWriter != null) {
msWriter.commit();
}
@ -151,48 +70,26 @@ public class MsHttpServletResponse extends HttpServletResponseWrapper {
}
}
@Override
public String getContentType() {
if (interceptMode != INTERCEPT_OFF) {
return contentType;
}
return super.getContentType();
}
@Override
public ServletOutputStream getOutputStream() throws IllegalStateException, IOException {
if (interceptMode != INTERCEPT_OFF) {
if (msWriter != null) {
throw new IllegalStateException();
}
if (msOutputStream == null) {
msOutputStream = new MsOutputStream(super.getOutputStream());
}
if (interceptMode == INTERCEPT_ON) {
msOutputStream.setBuffering(true);
}
return msOutputStream;
} else {
return super.getOutputStream();
if (msWriter != null) {
throw new IllegalStateException();
}
if (msOutputStream == null) {
msOutputStream = new MsOutputStream(super.getOutputStream());
}
return msOutputStream;
}
@Override
public PrintWriter getWriter() throws IllegalStateException, IOException {
if (interceptMode != INTERCEPT_OFF) {
if (msOutputStream != null) {
throw new IllegalStateException();
}
if (msWriter == null) {
msWriter = new MsWriter(super.getWriter());
}
if (interceptMode == INTERCEPT_ON) {
msWriter.setBuffering(true);
}
return msWriter;
} else {
return super.getWriter();
}
}
public ByteArrayInputStream getByteArrayStream() throws Exception {
@ -207,85 +104,6 @@ public class MsHttpServletResponse extends HttpServletResponseWrapper {
return stream;
}
@Override
public void setCharacterEncoding(String charset) {
if (interceptMode != INTERCEPT_ON) {
super.setCharacterEncoding(charset);
}
if (interceptMode != INTERCEPT_OFF) {
characterEncoding = charset;
}
}
@Override
public void setContentLength(int contentLength) {
if (interceptMode != INTERCEPT_ON) {
super.setContentLength(contentLength);
}
if (interceptMode != INTERCEPT_OFF) {
this.contentLength = contentLength;
headers.add(new Header("Content-Length", Integer.toString(contentLength)));
}
}
@Override
public void setContentType(String contentType) {
if (interceptMode != INTERCEPT_ON) {
super.setContentType(contentType);
}
if (interceptMode != INTERCEPT_OFF) {
this.contentType = contentType;
headers.add(new Header("Content-Type", contentType));
}
}
@Override
public void setBufferSize(int size) throws IllegalStateException {
super.setBufferSize(size);
}
@Override
public int getBufferSize() {
return super.getBufferSize();
}
@Override
public void flushBuffer() throws IOException {
if (interceptMode != INTERCEPT_ON) {
super.flushBuffer();
}
if (interceptMode != INTERCEPT_OFF) {
committed = true;
}
}
@Override
public void resetBuffer() {
if (interceptMode != INTERCEPT_ON) {
super.resetBuffer();
}
if (interceptMode != INTERCEPT_OFF) {
if (committed) {
throw new IllegalStateException();
}
if (msWriter != null) {
msWriter.reset();
}
if (msOutputStream != null) {
msOutputStream.reset();
}
}
}
@Override
public boolean isCommitted() {
if (interceptMode != INTERCEPT_OFF) {
return committed;
}
return super.isCommitted();
}
public void setBodyBytes(byte[] bytes) throws IOException {
if (msOutputStream == null) {
msWriter.reset();
@ -296,423 +114,4 @@ public class MsHttpServletResponse extends HttpServletResponseWrapper {
}
}
@Override
public void reset() throws IllegalStateException {
if (interceptMode != INTERCEPT_ON) {
super.reset();
}
if (interceptMode != INTERCEPT_OFF) {
if (committed) {
throw new IllegalStateException();
}
status = 200;
characterEncoding = DEFAULT_CHARACTER_ENCODING;
contentType = null;
contentLength = -1;
headers.clear();
status = 200;
statusMessage = null;
if (msWriter != null) {
msWriter.reset();
}
if (msOutputStream != null) {
msOutputStream.reset();
}
}
}
@Override
public void setLocale(Locale locale) {
if (interceptMode != INTERCEPT_ON) {
super.setLocale(locale);
}
if (interceptMode != INTERCEPT_OFF) {
this.locale = locale;
}
}
@Override
public Locale getLocale() {
if (interceptMode != INTERCEPT_OFF) {
return locale;
}
return super.getLocale();
}
@Override
public void addCookie(Cookie cookie) {
if (interceptMode != INTERCEPT_ON) {
super.addCookie(cookie);
}
if (interceptMode != INTERCEPT_OFF) {
cookies.add(cookie);
// TODO improve; these headers will not be
// sent to the client
StringBuilder sb = new StringBuilder();
sb.append(cookie.getName());
sb.append("=");
if (cookie.getDomain() != null) {
sb.append("; domain=").append(cookie.getDomain());
}
if (cookie.getPath() != null) {
sb.append("; path=").append(cookie.getPath());
}
if (cookie.getSecure()) {
sb.append("; secure");
}
headers.add(new Header("Set-Cookie", sb.toString()));
}
}
@Override
public void addDateHeader(String name, long value) {
SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US);
format.setTimeZone(TimeZone.getTimeZone("GMT"));
this.addHeader(name, FastHttpDateFormat.formatDate(value, format));
}
@Override
public void addHeader(String name, String value) {
if (interceptMode != INTERCEPT_ON) {
super.addHeader(name, value);
}
if (interceptMode != INTERCEPT_OFF) {
headers.add(new Header(name, value));
}
}
@Override
public void addIntHeader(String name, int value) {
this.addHeader(name, Integer.toString(value));
}
@Override
public boolean containsHeader(String name) {
if (interceptMode == INTERCEPT_OFF) {
return super.containsHeader(name);
} else {
for (int i = 0, n = headers.size(); i < n; i++) {
Header h = (Header) headers.get(i);
if (h.name.compareTo(name) == 0) {
return true;
}
}
}
return false;
}
@Override
public void setDateHeader(String name, long value) {
SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US);
format.setTimeZone(TimeZone.getTimeZone("GMT"));
this.setHeader(name, FastHttpDateFormat.formatDate(value, format));
}
@Override
public void setHeader(String name, String value) {
if (interceptMode != INTERCEPT_ON) {
super.setHeader(name, value);
}
if (interceptMode != INTERCEPT_OFF) {
for (int i = 0, n = headers.size(); i < n; i++) {
Header h = (Header) headers.get(i);
if (h.name.compareTo(name) == 0) {
headers.remove(i);
i--;
}
}
headers.add(new Header(name, value));
}
}
@Override
public void setIntHeader(String name, int value) {
this.setHeader(name, Integer.toString(value));
}
@Override
public void setStatus(int status) {
if (interceptMode != INTERCEPT_ON) {
super.setStatus(status);
}
if (interceptMode != INTERCEPT_OFF) {
this.status = status;
}
}
@Override
public void setStatus(int status, String message) {
if (interceptMode != INTERCEPT_ON) {
super.setStatus(status);
}
if (interceptMode != INTERCEPT_OFF) {
this.status = status;
this.statusMessage = message;
}
}
@Override
public void sendError(int status) throws IOException {
if (interceptMode != INTERCEPT_ON) {
super.sendError(status);
}
if (interceptMode != INTERCEPT_OFF) {
this.status = status;
this.setSuspended(true);
}
}
@Override
public void sendError(int status, String message) throws IOException {
if (interceptMode != INTERCEPT_ON) {
super.sendError(status);
}
if (interceptMode != INTERCEPT_OFF) {
this.status = status;
this.statusMessage = message;
this.setSuspended(true);
}
}
/* -- Inspection methods ---------------------------------------------- */
// TODO throw exception when interceptMode set to OFF
public Cookie[] getCookies() {
return (Cookie[]) cookies.toArray(new Cookie[cookies.size()]);
}
@Override
public int getStatus() {
return status;
}
public int getContentLength() {
return contentLength;
}
public long getDateHeader(String name) throws IllegalArgumentException {
String value = this.getHeader(name);
if (value == null) {
return -1;
}
long result = FastHttpDateFormat.parseDate(value, formats);
if (result == -1) {
throw new IllegalArgumentException(value);
}
return result;
}
@Override
public String getHeader(String name) {
for (int i = 0, n = headers.size(); i < n; i++) {
Header h = (Header) headers.get(i);
if (h.name.compareTo(name) == 0) {
return h.value;
}
}
return null;
}
@Override
public Collection<String> getHeaderNames() {
Collection<String> headerNames = new LinkedList<String>();
for (int i = 0, n = headers.size(); i < n; i++) {
Header h = (Header) headers.get(i);
headerNames.add(h.value);
}
return headerNames;
}
public int getIntHeader(String name) throws NumberFormatException {
String value = this.getHeader(name);
if (value == null) {
return -1;
}
return Integer.parseInt(value);
}
@Override
public Collection<String> getHeaders(String name) {
Collection<String> headerValues = new LinkedList<String>();
for (int i = 0, n = headers.size(); i < n; i++) {
Header h = (Header) headers.get(i);
if (h.name.compareTo(name) == 0) {
headerValues.add(h.value);
}
}
return headerValues;
}
}
/**
* Utility class to generate HTTP dates.
*
* @author Remy Maucherat
*/
final class FastHttpDateFormat {
// -------------------------------------------------------------- Variables
/**
* HTTP date format.
*/
protected static final SimpleDateFormat format =
new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US);
/**
* The set of SimpleDateFormat formats to use in getDateHeader().
*/
protected static final SimpleDateFormat formats[] = {
new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US),
new SimpleDateFormat("EEEEEE, dd-MMM-yy HH:mm:ss zzz", Locale.US),
new SimpleDateFormat("EEE MMMM d HH:mm:ss yyyy", Locale.US)
};
protected final static TimeZone gmtZone = TimeZone.getTimeZone("GMT");
/**
* GMT timezone - all HTTP dates are on GMT
*/
static {
format.setTimeZone(gmtZone);
formats[0].setTimeZone(gmtZone);
formats[1].setTimeZone(gmtZone);
formats[2].setTimeZone(gmtZone);
}
/**
* Instant on which the currentDate object was generated.
*/
protected static long currentDateGenerated = 0L;
/**
* Current formatted date.
*/
protected static String currentDate = null;
/**
* Formatter cache.
*/
protected static final HashMap<Object, Object> formatCache = new HashMap<Object, Object>();
/**
* Parser cache.
*/
protected static final HashMap<Object, Object> parseCache = new HashMap<Object, Object>();
// --------------------------------------------------------- Public Methods
/**
* Get the current date in HTTP format.
*/
public static String getCurrentDate() {
long now = System.currentTimeMillis();
if ((now - currentDateGenerated) > 1000) {
synchronized (format) {
if ((now - currentDateGenerated) > 1000) {
currentDateGenerated = now;
currentDate = format.format(new Date(now));
}
}
}
return currentDate;
}
/**
* Get the HTTP format of the specified date.
*/
public static String formatDate(long value, DateFormat threadLocalformat) {
String cachedDate = null;
Long longValue = new Long(value);
try {
cachedDate = (String) formatCache.get(longValue);
} catch (Exception e) {
}
if (cachedDate != null) {
return cachedDate;
}
String newDate;
Date dateValue = new Date(value);
if (threadLocalformat != null) {
newDate = threadLocalformat.format(dateValue);
synchronized (formatCache) {
updateCache(formatCache, longValue, newDate);
}
} else {
synchronized (formatCache) {
newDate = format.format(dateValue);
updateCache(formatCache, longValue, newDate);
}
}
return newDate;
}
/**
* Try to parse the given date as a HTTP date.
*/
public static long parseDate(String value,
DateFormat[] threadLocalformats) {
Long cachedDate = null;
try {
cachedDate = (Long) parseCache.get(value);
} catch (Exception e) {
}
if (cachedDate != null) {
return cachedDate.longValue();
}
Long date;
if (threadLocalformats != null) {
date = internalParseDate(value, threadLocalformats);
synchronized (parseCache) {
updateCache(parseCache, value, date);
}
} else {
synchronized (parseCache) {
date = internalParseDate(value, formats);
updateCache(parseCache, value, date);
}
}
if (date == null) {
return (-1L);
} else {
return date.longValue();
}
}
/**
* Parse date with given formatters.
*/
private static Long internalParseDate(String value, DateFormat[] formats) {
Date date = null;
for (int i = 0; (date == null) && (i < formats.length); i++) {
try {
date = formats[i].parse(value);
} catch (ParseException e) {
}
}
if (date == null) {
return null;
}
return new Long(date.getTime());
}
/**
* Update cache.
*/
private static void updateCache(HashMap<Object, Object> cache, Object key,
Object value) {
if (value == null) {
return;
}
if (cache.size() > 1000) {
cache.clear();
}
cache.put(key, value);
}
}