diff --git a/java/ModSecurityJNI.sln b/java/ModSecurityJNI.sln index 56adec80..934e73d9 100644 --- a/java/ModSecurityJNI.sln +++ b/java/ModSecurityJNI.sln @@ -1,7 +1,7 @@  Microsoft Visual Studio Solution File, Format Version 11.00 # Visual Studio 2010 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "standalone", "standalone.vcxproj", "{20EC871F-B6A0-4398-9B67-A33598A796E8}" +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ModSecurityJNI", "ModSecurityJNI.vcxproj", "{20EC871F-B6A0-4398-9B67-A33598A796E8}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution diff --git a/java/ModSecurityTestApp/src/java/org/modsecurity/ModSecurity.java b/java/ModSecurityTestApp/src/java/org/modsecurity/ModSecurity.java index 3564e062..e5039f91 100644 --- a/java/ModSecurityTestApp/src/java/org/modsecurity/ModSecurity.java +++ b/java/ModSecurityTestApp/src/java/org/modsecurity/ModSecurity.java @@ -4,16 +4,13 @@ import java.io.File; import java.net.Inet6Address; import java.net.InetAddress; import java.net.UnknownHostException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; import javax.servlet.FilterConfig; import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; +/** + * + * @author Mihai Pitu + */ public final class ModSecurity { public static final int DONE = -2; @@ -54,34 +51,7 @@ public final class ModSecurity { public native int onRequest(String config, MsHttpTransaction httpTran, boolean reloadConfig); - public native int onResponse(ServletResponse response, HttpServletResponse htttpResponse, String requestID); - - public static String[][] getHttpRequestHeaders(HttpServletRequest req) { - ArrayList aList = Collections.list(req.getHeaderNames()); - String[][] result = new String[aList.size()][2]; - - for (int i = 0; i < aList.size(); i++) { - result[i][0] = aList.get(i); - result[i][1] = req.getHeader(aList.get(i)); - } - - return result; - } - - public static String[][] getHttpResponseHeaders(HttpServletResponse resp) { - - Collection headerNames = resp.getHeaderNames(); - String[][] result = new String[headerNames.size()][2]; - - int i = 0; - for (String headerName : headerNames) { - result[i][0] = headerName; - result[i][1] = resp.getHeader(headerName); - i++; - } - - return result; - } + public native int onResponse(MsHttpTransaction httpTran); public static boolean isIPv6(String addr) { try { diff --git a/java/ModSecurityTestApp/src/java/org/modsecurity/ModSecurityFilter.java b/java/ModSecurityTestApp/src/java/org/modsecurity/ModSecurityFilter.java index d5b41901..7a8fcbb3 100644 --- a/java/ModSecurityTestApp/src/java/org/modsecurity/ModSecurityFilter.java +++ b/java/ModSecurityTestApp/src/java/org/modsecurity/ModSecurityFilter.java @@ -1,8 +1,6 @@ package org.modsecurity; -import java.io.BufferedInputStream; import java.io.IOException; -import java.util.UUID; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; @@ -14,7 +12,7 @@ import javax.servlet.http.HttpServletResponse; /** * - * Docs: http://docs.oracle.com/javaee/6/tutorial/doc/bnagb.html + * @author Mihai Pitu */ public class ModSecurityFilter implements Filter { @@ -36,18 +34,19 @@ public class ModSecurityFilter implements Filter { public void doFilter(ServletRequest request, ServletResponse response, FilterChain fc) throws IOException, ServletException { HttpServletRequest httpReq = (HttpServletRequest) request; HttpServletResponse httpResp = (HttpServletResponse) response; - MsHttpTransaction httpTran = new MsHttpTransaction(httpReq, httpResp); + MsHttpTransaction httpTran = new MsHttpTransaction(httpReq, httpResp); //transaction object used by native code try { - int status = modsecurity.onRequest(modsecurity.getConfFilename(), httpTran, modsecurity.checkModifiedConfig()); + int status = modsecurity.onRequest(modsecurity.getConfFilename(), httpTran, modsecurity.checkModifiedConfig()); //modsecurity reloads only if primary config file is modified if (status != ModSecurity.DECLINED) { return; } - //BufferedInputStream buf = new BufferedInputStream(httpReqWrapper.getInputStream()); + //process request fc.doFilter(httpTran.getMsHttpRequest(), httpTran.getMsHttpResponse()); - //status = modsecurity.onResponse(response, httpResp, requestID); + + status = modsecurity.onResponse(httpTran); } finally { httpTran.destroy(); diff --git a/java/ModSecurityTestApp/src/java/org/modsecurity/MsHttpServletRequest.java b/java/ModSecurityTestApp/src/java/org/modsecurity/MsHttpServletRequest.java index aee1094b..36b2e05d 100644 --- a/java/ModSecurityTestApp/src/java/org/modsecurity/MsHttpServletRequest.java +++ b/java/ModSecurityTestApp/src/java/org/modsecurity/MsHttpServletRequest.java @@ -16,6 +16,8 @@ import java.io.OutputStream; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import java.util.Enumeration; import java.util.HashMap; import java.util.Hashtable; @@ -26,12 +28,12 @@ import javax.servlet.ServletException; import javax.servlet.ServletInputStream; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; +import javax.servlet.http.HttpServletResponse; import org.apache.commons.fileupload.DefaultFileItem; import org.apache.commons.fileupload.DiskFileUpload; import org.apache.commons.fileupload.FileItem; import org.apache.commons.fileupload.FileUploadException; - public class MsHttpServletRequest extends HttpServletRequestWrapper { public final static int BODY_NOTYETREAD = 0; @@ -79,7 +81,23 @@ public class MsHttpServletRequest extends HttpServletRequestWrapper { bodyFile.delete(); } } + + public static String[][] getHttpRequestHeaders(HttpServletRequest req) { + ArrayList aList = Collections.list(req.getHeaderNames()); + String[][] result = new String[aList.size()][2]; + + try { + for (int i = 0; i < aList.size(); i++) { + result[i][0] = aList.get(i); + result[i][1] = req.getHeader(aList.get(i)); + } + } catch (Exception ex) { + } + + return result; + } + public String getTmpPath() { return tmpPath; } @@ -113,7 +131,9 @@ public class MsHttpServletRequest extends HttpServletRequestWrapper { } public void readBody(int maxContentLength) throws IOException, ServletException { + String contentType = req.getContentType(); + if ((contentType != null) && (contentType.startsWith("multipart/form-data"))) { readBodyMultipart(maxContentLength); } else { @@ -202,6 +222,7 @@ public class MsHttpServletRequest extends HttpServletRequestWrapper { } } + /** * Parses the given URL-encoded string and adds the parameters to the * request parameter list. diff --git a/java/ModSecurityTestApp/src/java/org/modsecurity/MsHttpServletResponse.java b/java/ModSecurityTestApp/src/java/org/modsecurity/MsHttpServletResponse.java index f876551b..46729c18 100644 --- a/java/ModSecurityTestApp/src/java/org/modsecurity/MsHttpServletResponse.java +++ b/java/ModSecurityTestApp/src/java/org/modsecurity/MsHttpServletResponse.java @@ -1,5 +1,6 @@ package org.modsecurity; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.PrintWriter; import java.text.DateFormat; @@ -114,6 +115,22 @@ public class MsHttpServletResponse extends HttpServletResponseWrapper { destroyed = true; } + public static String[][] getHttpResponseHeaders(HttpServletResponse resp) { + + Collection headerNames = resp.getHeaderNames(); + String[][] result = new String[headerNames.size()][2]; + try { + int i = 0; + for (String headerName : headerNames) { + result[i][0] = headerName; + result[i][1] = resp.getHeader(headerName); + i++; + } + } catch (Exception ex) { + } + return result; + } + public String getBody() { if (msWriter != null) { return msWriter.toString(); @@ -178,6 +195,19 @@ public class MsHttpServletResponse extends HttpServletResponseWrapper { } } + public ByteArrayInputStream getByteArrayStream() throws Exception { + ByteArrayInputStream stream = null; + if (msOutputStream == null) { + MsWriter writer = ((MsWriter) this.getWriter()); + stream = new ByteArrayInputStream(new String(writer.toCharArray()).getBytes()); + } else if (msWriter == null) { + stream = new ByteArrayInputStream(((MsOutputStream) this.getOutputStream()).toByteArray()); + } else { + + } + return stream; + } + @Override public void setCharacterEncoding(String charset) { if (interceptMode != INTERCEPT_ON) { diff --git a/java/ModSecurityTestApp/src/java/org/modsecurity/MsOutputStream.java b/java/ModSecurityTestApp/src/java/org/modsecurity/MsOutputStream.java index b2c3273e..cdc4a559 100644 --- a/java/ModSecurityTestApp/src/java/org/modsecurity/MsOutputStream.java +++ b/java/ModSecurityTestApp/src/java/org/modsecurity/MsOutputStream.java @@ -1,5 +1,6 @@ package org.modsecurity; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.ByteArrayOutputStream; import java.io.UnsupportedEncodingException; @@ -37,6 +38,10 @@ public class MsOutputStream extends ServletOutputStream { public byte[] toByteArray() { return buffer.toByteArray(); } + + public ByteArrayInputStream getByteArrayStream() { + return new ByteArrayInputStream(buffer.toByteArray()); + } public void reset() { buffer.reset(); diff --git a/java/org_modsecurity_ModSecurity.cpp b/java/org_modsecurity_ModSecurity.cpp index a45b05f6..84040576 100644 --- a/java/org_modsecurity_ModSecurity.cpp +++ b/java/org_modsecurity_ModSecurity.cpp @@ -28,6 +28,9 @@ #define HTTPTRANSACTION_HTTPREQUEST_SIG "()Ljavax/servlet/http/HttpServletRequest;" #define HTTPTRANSACTION_MSHTTPREQUEST_MET "getMsHttpRequest" #define HTTPTRANSACTION_MSHTTPREQUEST_SIG "()Lorg/modsecurity/MsHttpServletRequest;" +#define HTTPTRANSACTION_MSHTTPRESPONSE_MET "getMsHttpResponse" +#define HTTPTRANSACTION_MSHTTPRESPONSE_SIG "()Lorg/modsecurity/MsHttpServletResponse;" + #define HTTPTRANSACTION_TRANSACTIONID_MET "getTransactionID" @@ -55,6 +58,11 @@ #define SERVLETRESPONSE_CONTENTTYPE_MET "getContentType" #define SERVLETRESPONSE_CHARENCODING_MET "getCharacterEncoding" +#define MSSERVLETRESPONSE_OUTPUTSTREAM_MET "getByteArrayStream" +#define MSSERVLETRESPONSE_OUTPUTSTREAM_SIG "()Ljava/io/ByteArrayInputStream;" + +#define MSSERVLETRESPONSE_RESET_MET "reset" +#define MSSERVLETRESPONSE_RESET_SIG "()V" //typedef struct { @@ -69,27 +77,29 @@ apr_table_t *requests; apr_pool_t *requestsPool; -#define JAVASERVLET_INSTREAM "RequestInStream" +#define JAVASERVLET_INSTREAM "ReqBStr" +#define JAVASERVLET_OUTSTREAM "ResBStr" -void storeJavaServletContext(request_rec *r, jobject obj) + +void storeJavaServletContext(request_rec *r, const char *key, jobject obj) { - apr_table_setn(r->notes, JAVASERVLET_INSTREAM, (const char *)obj); + apr_table_setn(r->notes, key, (const char *)obj); } -jobject getJavaServletContext(request_rec *r) +jobject getJavaServletContext(request_rec *r, const char *key) { jobject obj = NULL; request_rec *rx = NULL; /* Look in the current request first. */ - obj = (jobject)apr_table_get(r->notes, JAVASERVLET_INSTREAM); + obj = (jobject)apr_table_get(r->notes, key); if (obj != NULL) return obj; /* If this is a subrequest then look in the main request. */ if (r->main != NULL) { - obj = (jobject)apr_table_get(r->main->notes, JAVASERVLET_INSTREAM); + obj = (jobject)apr_table_get(r->main->notes, key); if (obj != NULL) { return obj; @@ -100,7 +110,7 @@ jobject getJavaServletContext(request_rec *r) rx = r->prev; while(rx != NULL) { - obj = (jobject)apr_table_get(rx->notes, JAVASERVLET_INSTREAM); + obj = (jobject)apr_table_get(rx->notes, key); if (obj != NULL) { return obj; @@ -169,7 +179,7 @@ inline char* fromJString(JNIEnv *env, jstring jStr, apr_pool_t *pool) str = (char*) apr_palloc(pool, len + 1); memcpy(str, jCStr, len); str[len] = '\0'; //null terminate - (env)->ReleaseStringUTFChars(jStr, jCStr); //release java heap memory + (env)->ReleaseStringUTFChars(jStr, jCStr); //release java memory } else str = ""; @@ -205,7 +215,7 @@ void logSec(void *obj, int level, char *str) apr_status_t ReadBodyCallback(request_rec *r, char *buf, unsigned int length, unsigned int *readcnt, int *is_eos) { - jobject inputStream = getJavaServletContext(r); //servlet request input stream + jobject inputStream = getJavaServletContext(r, JAVASERVLET_INSTREAM); //servlet request input stream JNIEnv *env; *readcnt = 0; @@ -253,6 +263,43 @@ apr_status_t WriteBodyCallback(request_rec *r, char *buf, unsigned int length) apr_status_t ReadResponseCallback(request_rec *r, char *buf, unsigned int length, unsigned int *readcnt, int *is_eos) { + jobject inputStream = getJavaServletContext(r, JAVASERVLET_OUTSTREAM); + JNIEnv *env; + + *readcnt = 0; + + if(inputStream == NULL) + { + *is_eos = 1; + return APR_SUCCESS; + } + + if (!(jvm)->AttachCurrentThread((void **)&env, NULL)) + { + jclass inputStreamClass = env->GetObjectClass(inputStream); + jmethodID read = (env)->GetMethodID(inputStreamClass, INPUTSTREAM_READ_MET, INPUTSTREAM_READ_SIG); + + jbyteArray byteArrayBuf = (env)->NewByteArray(length); + + jint count = (env)->CallIntMethod(inputStream, read, byteArrayBuf, 0, length); + jbyte* bufferPtr = (env)->GetByteArrayElements(byteArrayBuf, NULL); + + if (count == -1 || count > length || env->ExceptionCheck() == JNI_TRUE) //end of stream + { + *is_eos = 1; + } + else + { + *readcnt = count; + + memcpy(buf, bufferPtr, *readcnt); + } + (env)->ReleaseByteArrayElements(byteArrayBuf, bufferPtr, NULL); + (env)->DeleteLocalRef(byteArrayBuf); + + (jvm)->DetachCurrentThread(); + } + return APR_SUCCESS; } @@ -302,11 +349,11 @@ JNIEXPORT jint JNICALL Java_org_modsecurity_ModSecurity_destroy(JNIEnv *env, job } -inline void setHeaders(JNIEnv *env, jclass modSecurityClass, jobject httpServletR, apr_table_t *reqHeaders, apr_pool_t *pool, const char *headersMet, const char *headersSig) +inline void setHeaders(JNIEnv *env, jclass httpServletRequestClass, jobject httpServletR, apr_table_t *reqHeaders, apr_pool_t *pool, const char *headersMet, const char *headersSig) { //All headers are returned in a table by a static method from ModSecurity class - jmethodID getHttpHeaders = (env)->GetStaticMethodID(modSecurityClass, headersMet, headersSig); - jobjectArray headersTable = (jobjectArray) (env)->CallStaticObjectMethod(modSecurityClass, getHttpHeaders, httpServletR); + jmethodID getHttpHeaders = (env)->GetStaticMethodID(httpServletRequestClass, headersMet, headersSig); + jobjectArray headersTable = (jobjectArray) (env)->CallStaticObjectMethod(httpServletRequestClass, getHttpHeaders, httpServletR); jsize size = (env)->GetArrayLength(headersTable); for (int i = 0; i < size; i++) @@ -335,23 +382,20 @@ inline void setHeaders(JNIEnv *env, jclass modSecurityClass, jobject httpServlet JNIEXPORT jint JNICALL Java_org_modsecurity_ModSecurity_onRequest(JNIEnv *env, jobject obj, jstring configPath, jobject httpTransaction, jboolean reloadConfig) { - //critical section ? conn_rec *c; request_rec *r; - const char *path = (env)->GetStringUTFChars(configPath, NULL); //path to modsecurity.conf - - if (config == NULL || reloadConfig) { config = modsecGetDefaultConfig(); + const char *path = fromJString(env, configPath, config->mp); //path to modsecurity.conf + const char *err = modsecProcessConfig(config, path, NULL); if(err != NULL) { logSec(NULL, 0, (char*)err); - (env)->ReleaseStringUTFChars(configPath, path); return DONE; } } @@ -359,26 +403,27 @@ JNIEXPORT jint JNICALL Java_org_modsecurity_ModSecurity_onRequest(JNIEnv *env, j c = modsecNewConnection(); modsecProcessConnection(c); r = modsecNewRequest(c, config); - jclass httpTransactionClass = env->GetObjectClass(httpTransaction); jmethodID getHttpRequest = env->GetMethodID(httpTransactionClass, HTTPTRANSACTION_MSHTTPREQUEST_MET, HTTPTRANSACTION_MSHTTPREQUEST_SIG); - + jobject httpServletRequest = env->CallObjectMethod(httpTransaction, getHttpRequest); - jobject servletRequest = httpServletRequest; //test it - - jclass httpServletRequestClass = env->GetObjectClass(httpServletRequest); //HttpServletRequest interface + jobject servletRequest = httpServletRequest; //superclass of HttpServletRequest + + jclass httpServletRequestClass = env->GetObjectClass(httpServletRequest); //MsHttpServletRequest interface jclass servletRequestClass = env->GetObjectClass(servletRequest); //ServletRequest interface jclass modSecurityClass = env->GetObjectClass(obj); //ModSecurity class + //readBody method reads all bytes from the inputStream or a maximum of 'config->reqbody_limit' bytes jmethodID readBody = env->GetMethodID(httpServletRequestClass, MSHTTPSERVLETREQUEST_READBODY_MET, MSHTTPSERVLETREQUEST_READBODY_SIG); - env->CallIntMethod(httpServletRequest, readBody, config->reqbody_limit); + env->CallVoidMethod(httpServletRequest, readBody, config->reqbody_limit); + if (env->ExceptionCheck() == JNI_TRUE) //read body raised an Exception, return to JVM { modsecFinishRequest(r); return DONE; } - + jmethodID getTransactionID = env->GetMethodID(httpTransactionClass, HTTPTRANSACTION_TRANSACTIONID_MET, STRINGRETURN_SIG); const char *reqID = fromJStringMethod(env, getTransactionID, httpTransaction, r->pool); //fromJString(env, requestID, r->pool); //unique ID of this request apr_table_setn(requests, reqID, (const char*) r); //store this request for response processing @@ -387,7 +432,7 @@ JNIEXPORT jint JNICALL Java_org_modsecurity_ModSecurity_onRequest(JNIEnv *env, j jmethodID getInputStream = (env)->GetMethodID(httpServletRequestClass, SERVLETREQUEST_INPUTSTREAM_MET, SERVLETREQUEST_INPUTSTREAM_SIG); jobject inputStream = (env)->CallObjectMethod(httpServletRequest, getInputStream); //Request body input stream used in the read body callback - storeJavaServletContext(r, inputStream); + storeJavaServletContext(r, JAVASERVLET_INSTREAM, inputStream); jmethodID getServerName = (env)->GetMethodID(servletRequestClass, SERVLETREQUEST_SERVERNAME_MET, STRINGRETURN_SIG); r->hostname = fromJStringMethod(env, getServerName, servletRequest, r->pool); @@ -407,7 +452,7 @@ JNIEXPORT jint JNICALL Java_org_modsecurity_ModSecurity_onRequest(JNIEnv *env, j r->args = fromJStringMethod(env, getQueryString, httpServletRequest, r->pool); - setHeaders(env, modSecurityClass, httpServletRequest, r->headers_in, r->pool, MODSECURITY__HTTPREQHEADERS_MET, MODSECURITY__HTTPREQHEADERS_SIG); + setHeaders(env, httpServletRequestClass, httpServletRequest, r->headers_in, r->pool, MODSECURITY__HTTPREQHEADERS_MET, MODSECURITY__HTTPREQHEADERS_SIG); jmethodID getCharacterEncoding = (env)->GetMethodID(servletRequestClass, SERVLETREQUEST_CHARENCODING_MET, STRINGRETURN_SIG); @@ -517,41 +562,63 @@ JNIEXPORT jint JNICALL Java_org_modsecurity_ModSecurity_onRequest(JNIEnv *env, j int status = modsecProcessRequest(r); - (env)->ReleaseStringUTFChars(configPath, path); - (env)->DeleteLocalRef(inputStream); - return status; } -JNIEXPORT jint JNICALL Java_org_modsecurity_ModSecurity_onResponse(JNIEnv *env, jobject obj, jobject servletResponse, jobject httpServletResponse, jstring requestID) +JNIEXPORT jint JNICALL Java_org_modsecurity_ModSecurity_onResponse(JNIEnv *env, jobject obj, jobject httpTransaction) { - const char *reqID = env->GetStringUTFChars(requestID, NULL); + jclass httpTransactionClass = env->GetObjectClass(httpTransaction); + + jmethodID getTransactionID = env->GetMethodID(httpTransactionClass, HTTPTRANSACTION_TRANSACTIONID_MET, STRINGRETURN_SIG); + jstring reqIDjStr = (jstring) env->CallObjectMethod(httpTransaction, getTransactionID); + const char *reqID = env->GetStringUTFChars(reqIDjStr, NULL); + request_rec *r = (request_rec*) apr_table_get(requests, reqID); + apr_table_unset(requests, reqID); //remove this request from the requests table + env->ReleaseStringUTFChars(reqIDjStr, reqID); if (r == NULL) { - env->ReleaseStringUTFChars(requestID, reqID); return DONE; } - jclass httpServletResponseClass = env->GetObjectClass(httpServletResponse); //HttpServletResponse interface - jclass servletResponseClass = env->GetObjectClass(servletResponse); //ServletResponse interface - jclass modSecurityClass = env->GetObjectClass(obj); //ModSecurity class + jmethodID getHttpResponse = env->GetMethodID(httpTransactionClass, HTTPTRANSACTION_MSHTTPRESPONSE_MET, HTTPTRANSACTION_MSHTTPRESPONSE_SIG); + jobject httpServletResponse = env->CallObjectMethod(httpTransaction, getHttpResponse); - jmethodID getContentType = (env)->GetMethodID(servletResponseClass, SERVLETRESPONSE_CONTENTTYPE_MET, STRINGRETURN_SIG); - char *ct = fromJStringMethod(env, getContentType, servletResponse, r->pool); + jclass httpServletResponseClass = env->GetObjectClass(httpServletResponse); //MsHttpServletResponse class + //jclass modSecurityClass = env->GetObjectClass(obj); //ModSecurity class + + jmethodID getOutputStream = (env)->GetMethodID(httpServletResponseClass, MSSERVLETRESPONSE_OUTPUTSTREAM_MET, MSSERVLETRESPONSE_OUTPUTSTREAM_SIG); + jobject responseStream = (env)->CallObjectMethod(httpServletResponse, getOutputStream); //Response output stream used in the read response callback + + if (env->ExceptionCheck() == JNI_TRUE) + { + modsecFinishRequest(r); + return DONE; + } + + //jclass msOutputStreamClass = env->GetObjectClass(msOutputStream); + + //jmethodID getByteArrayStream = env->GetMethodID(msOutputStreamClass, MSOUTPUTSTREAM_INPUTSTREAM_MET, MSOUTPUTSTREAM_INPUTSTREAM_SIG); + //jobject responseStream = env->CallObjectMethod(msOutputStream, getByteArrayStream); + + storeJavaServletContext(r, JAVASERVLET_OUTSTREAM, responseStream); + + + jmethodID getContentType = (env)->GetMethodID(httpServletResponseClass, SERVLETRESPONSE_CONTENTTYPE_MET, STRINGRETURN_SIG); + char *ct = fromJStringMethod(env, getContentType, httpServletResponse, r->pool); if(strcmp(ct, "") == 0) ct = "text/html"; r->content_type = ct; - jmethodID getCharEncoding = (env)->GetMethodID(servletResponseClass, SERVLETRESPONSE_CHARENCODING_MET, STRINGRETURN_SIG); - r->content_encoding = fromJStringMethod(env, getCharEncoding, servletResponse, r->pool); + jmethodID getCharEncoding = (env)->GetMethodID(httpServletResponseClass, SERVLETRESPONSE_CHARENCODING_MET, STRINGRETURN_SIG); + r->content_encoding = fromJStringMethod(env, getCharEncoding, httpServletResponse, r->pool); - setHeaders(env, modSecurityClass, httpServletResponse, r->headers_out, r->pool, MODSECURITY__HTTPRESHEADERS_MET, MODSECURITY__HTTPRESHEADERS_SIG); + setHeaders(env, httpServletResponseClass, httpServletResponse, r->headers_out, r->pool, MODSECURITY__HTTPRESHEADERS_MET, MODSECURITY__HTTPRESHEADERS_SIG); const char *lng = apr_table_get(r->headers_out, "Content-Languages"); if(lng != NULL) @@ -560,12 +627,18 @@ JNIEXPORT jint JNICALL Java_org_modsecurity_ModSecurity_onResponse(JNIEnv *env, *(const char **)apr_array_push(r->content_languages) = lng; } - //modsecProcessResponse(r); + int status = modsecProcessResponse(r); - apr_table_unset(requests, reqID); //remove this request from the requests table modsecFinishRequest(r); - env->ReleaseStringUTFChars(requestID, reqID); + // the logic here is temporary, needs clarification + if(status != 0 && status != -1) + { + //reset the stream, clear the response + jmethodID reset = (env)->GetMethodID(httpServletResponseClass, MSSERVLETRESPONSE_RESET_MET, MSSERVLETRESPONSE_RESET_SIG); + env->CallVoidMethod(httpServletResponse, reset); - return DONE; -} \ No newline at end of file + } + + return status; +} diff --git a/java/org_modsecurity_ModSecurity.h b/java/org_modsecurity_ModSecurity.h index a8791e72..ec976b47 100644 --- a/java/org_modsecurity_ModSecurity.h +++ b/java/org_modsecurity_ModSecurity.h @@ -40,10 +40,10 @@ JNIEXPORT jint JNICALL Java_org_modsecurity_ModSecurity_onRequest /* * Class: org_modsecurity_ModSecurity * Method: onResponse - * Signature: (Ljavax/servlet/ServletResponse;Ljavax/servlet/http/HttpServletResponse;Ljava/lang/String;)I + * Signature: (Lorg/modsecurity/MsHttpTransaction;)I */ JNIEXPORT jint JNICALL Java_org_modsecurity_ModSecurity_onResponse - (JNIEnv *, jobject, jobject, jobject, jstring); + (JNIEnv *, jobject, jobject); #ifdef __cplusplus }