From 34809d8064608963a9e41aaf21f869e1ae28f7d0 Mon Sep 17 00:00:00 2001 From: Martin Vierula Date: Tue, 10 Oct 2023 10:31:52 -0700 Subject: [PATCH] Add expirevar support for lmdb --- src/collection/backend/collection_data.cc | 99 +++++++++ src/collection/backend/collection_data.h | 6 +- src/collection/backend/lmdb.cc | 209 ++++++++++++++++-- src/collection/backend/lmdb.h | 2 + test/cppcheck_suppressions.txt | 1 + .../regression/action-expirevar.json | 70 +++++- 6 files changed, 363 insertions(+), 24 deletions(-) diff --git a/src/collection/backend/collection_data.cc b/src/collection/backend/collection_data.cc index e9ce29ba..95499257 100644 --- a/src/collection/backend/collection_data.cc +++ b/src/collection/backend/collection_data.cc @@ -35,6 +35,105 @@ void CollectionData::setExpiry(int32_t seconds_until_expiry) { m_hasExpiryTime = true; } +std::string CollectionData::getSerialized() const { + std::string serialized; + if (hasValue()) { + serialized.reserve(30 + 10 + getValue().size()); + } else { + serialized.reserve(16+10); + } + + serialized.assign("{"); + + if (hasExpiry()) { + serialized.append("\"__expire_\":"); + uint64_t expiryEpochSeconds = std::chrono::duration_cast(m_expiryTime.time_since_epoch()).count(); + serialized.append(std::to_string(expiryEpochSeconds)); + if (hasValue()) { + serialized.append(","); + } + } + if (hasValue()) { + serialized.append("\"__value_\":\""); + serialized.append(getValue()); + serialized.append("\""); + } + + serialized.append("}"); + + return serialized; +} + +void CollectionData::setFromSerialized(const char* serializedData, size_t length) { + const static std::string expiryPrefix("\"__expire_\":"); + const static std::string valuePrefix("\"__value_\":\""); + m_hasValue = false; + m_hasExpiryTime = false; + + std::string serializedString(serializedData, length); + if ((serializedString.find("{") == 0) && (serializedString.substr(serializedString.length()-1) == "}")) { + size_t currentPos = 1; + uint64_t 
expiryEpochSeconds = 0; + bool invalidSerializedFormat = false; + bool doneParsing = false; + + // Extract the expiry time, if it exists + if (serializedString.find(expiryPrefix, currentPos) == currentPos) { + currentPos += expiryPrefix.length(); + std::string expiryDigits = serializedString.substr(currentPos, 10); + if (expiryDigits.find_first_not_of("0123456789") == std::string::npos) { + expiryEpochSeconds = strtoll(expiryDigits.c_str(), NULL, 10); + } else { + invalidSerializedFormat = true; + } + currentPos += 10; + } + + if ((!invalidSerializedFormat) && (expiryEpochSeconds > 0)) { + if (serializedString.find(",", currentPos) == currentPos) { + currentPos++; + } else if (currentPos == serializedString.length()-1) { + doneParsing = true; + } else { + invalidSerializedFormat = true; + } + } + + if ((!invalidSerializedFormat) && (!doneParsing)) { + // Extract the value + if ((serializedString.find(valuePrefix, currentPos) == currentPos)) { + currentPos += valuePrefix.length(); + size_t expectedCloseQuotePos = serializedString.length() - 2; + if ((serializedString.substr(expectedCloseQuotePos, 1) == "\"") && (expectedCloseQuotePos >= currentPos)) { + m_value = serializedString.substr(currentPos); + m_value.resize(m_value.length()-2); + m_hasValue = true; + } else { + invalidSerializedFormat = true; + } + } else { + invalidSerializedFormat = true; + } + } + + // Set the object's expiry time, if we found one + if ((!invalidSerializedFormat) && (expiryEpochSeconds > 0)) { + std::chrono::seconds expiryDuration(expiryEpochSeconds); + std::chrono::system_clock::time_point expiryTimePoint(expiryDuration); + m_expiryTime = expiryTimePoint; + m_hasExpiryTime = true; + } + if (!invalidSerializedFormat) { + return; + } + } + + // this is the residual case; the entire string is a simple value (not JSON-ish encoded) + // the foreseen case here is lmdb content from prior to the serialization support + m_value.assign(serializedData, length); + m_hasValue = true; + return; +} } 
// namespace backend } // namespace collection diff --git a/src/collection/backend/collection_data.h b/src/collection/backend/collection_data.h index 691f206d..8b1b25d8 100644 --- a/src/collection/backend/collection_data.h +++ b/src/collection/backend/collection_data.h @@ -16,13 +16,10 @@ #ifdef __cplusplus #include -#include #include #endif -#include "modsecurity/collection/collection.h" - #ifndef SRC_COLLECTION_DATA_H_ #define SRC_COLLECTION_DATA_H_ @@ -53,6 +50,9 @@ public: bool hasExpiry() const { return m_hasExpiryTime;} bool isExpired() const; + std::string getSerialized() const; + void setFromSerialized(const char* serializedData, size_t length); + private: bool m_hasValue; bool m_hasExpiryTime; diff --git a/src/collection/backend/lmdb.cc b/src/collection/backend/lmdb.cc index 7e0b63a3..0fb92964 100644 --- a/src/collection/backend/lmdb.cc +++ b/src/collection/backend/lmdb.cc @@ -15,6 +15,7 @@ #include "src/collection/backend/lmdb.h" +#include "src/collection/backend/collection_data.h" #include #include @@ -158,6 +159,7 @@ std::unique_ptr LMDB::resolveFirst(const std::string& var) { MDB_val mdb_value_ret; std::unique_ptr ret = NULL; MDB_txn *txn = NULL; + CollectionData collectionData; string2val(var, &mdb_key); @@ -172,17 +174,125 @@ std::unique_ptr LMDB::resolveFirst(const std::string& var) { goto end_get; } - ret = std::unique_ptr(new std::string( - reinterpret_cast(mdb_value_ret.mv_data), - mdb_value_ret.mv_size)); + collectionData.setFromSerialized(reinterpret_cast(mdb_value_ret.mv_data), mdb_value_ret.mv_size); + if ((!collectionData.isExpired()) && (collectionData.hasValue())) { + ret = std::unique_ptr(new std::string(collectionData.getValue())); + } end_get: mdb_txn_abort(txn); end_txn: + // The read-only transaction is complete. Now we can do a delete if the item was expired. 
+ if (collectionData.isExpired()) { + delIfExpired(var); + } return ret; } +void LMDB::setExpiry(const std::string &key, int32_t expiry_seconds) { + int rc; + MDB_txn *txn; + MDB_val mdb_key; + MDB_val mdb_value; + MDB_val mdb_value_ret; + CollectionData previous_data; + CollectionData new_data; + std::string serializedData; + + string2val(key, &mdb_key); + + rc = txn_begin(0, &txn); + lmdb_debug(rc, "txn", "setExpiry"); + if (rc != 0) { + goto end_txn; + } + + rc = mdb_get(txn, m_dbi, &mdb_key, &mdb_value_ret); + lmdb_debug(rc, "get", "setExpiry"); + if (rc == 0) { + previous_data.setFromSerialized(reinterpret_cast(mdb_value_ret.mv_data), mdb_value_ret.mv_size); + rc = mdb_del(txn, m_dbi, &mdb_key, &mdb_value_ret); + lmdb_debug(rc, "del", "setExpiry"); + if (rc != 0) { + goto end_del; + } + } + + if (previous_data.hasValue()) { + new_data = previous_data; + }; + new_data.setExpiry(expiry_seconds); + serializedData = new_data.getSerialized(); + string2val(serializedData, &mdb_value); + + rc = mdb_put(txn, m_dbi, &mdb_key, &mdb_value, 0); + lmdb_debug(rc, "put", "setExpiry"); + if (rc != 0) { + goto end_put; + } + + rc = mdb_txn_commit(txn); + lmdb_debug(rc, "commit", "setExpiry"); + if (rc != 0) { + goto end_commit; + } + +end_put: +end_del: + if (rc != 0) { + mdb_txn_abort(txn); + } +end_commit: +end_txn: + return; +} + +void LMDB::delIfExpired(const std::string& key) { + MDB_txn *txn; + MDB_val mdb_key; + MDB_val mdb_value_ret; + CollectionData collectionData; + + int rc = txn_begin(0, &txn); + lmdb_debug(rc, "txn", "del"); + if (rc != 0) { + goto end_txn; + } + + string2val(key, &mdb_key); + + rc = mdb_get(txn, m_dbi, &mdb_key, &mdb_value_ret); + lmdb_debug(rc, "get", "del"); + if (rc != 0) { + goto end_get; + } + + collectionData.setFromSerialized(reinterpret_cast(mdb_value_ret.mv_data), mdb_value_ret.mv_size); + if (collectionData.isExpired()) { + rc = mdb_del(txn, m_dbi, &mdb_key, &mdb_value_ret); + lmdb_debug(rc, "del", "del"); + if (rc != 0) { + goto 
end_del; + } + } + + rc = mdb_txn_commit(txn); + lmdb_debug(rc, "commit", "del"); + if (rc != 0) { + goto end_commit; + } + +end_del: +end_get: + if (rc != 0) { + mdb_txn_abort(txn); + } +end_commit: +end_txn: + return; +} + bool LMDB::storeOrUpdateFirst(const std::string &key, const std::string &value) { int rc; @@ -190,9 +300,11 @@ bool LMDB::storeOrUpdateFirst(const std::string &key, MDB_val mdb_key; MDB_val mdb_value; MDB_val mdb_value_ret; + CollectionData previous_data; + CollectionData new_data; + std::string serializedData; string2val(key, &mdb_key); - string2val(value, &mdb_value); rc = txn_begin(0, &txn); lmdb_debug(rc, "txn", "storeOrUpdateFirst"); @@ -203,6 +315,7 @@ bool LMDB::storeOrUpdateFirst(const std::string &key, rc = mdb_get(txn, m_dbi, &mdb_key, &mdb_value_ret); lmdb_debug(rc, "get", "storeOrUpdateFirst"); if (rc == 0) { + previous_data.setFromSerialized(reinterpret_cast(mdb_value_ret.mv_data), mdb_value_ret.mv_size); rc = mdb_del(txn, m_dbi, &mdb_key, &mdb_value_ret); lmdb_debug(rc, "del", "storeOrUpdateFirst"); if (rc != 0) { @@ -210,6 +323,13 @@ bool LMDB::storeOrUpdateFirst(const std::string &key, } } + if (previous_data.hasExpiry()) { + new_data = previous_data; + }; + new_data.setValue(value); + serializedData = new_data.getSerialized(); + string2val(serializedData, &mdb_value); + rc = mdb_put(txn, m_dbi, &mdb_key, &mdb_value, 0); lmdb_debug(rc, "put", "storeOrUpdateFirst"); if (rc != 0) { @@ -241,6 +361,8 @@ void LMDB::resolveSingleMatch(const std::string& var, MDB_val mdb_value; MDB_val mdb_value_ret; MDB_cursor *cursor; + CollectionData collectionData; + std::list expiredVars; rc = txn_begin(MDB_RDONLY, &txn); lmdb_debug(rc, "txn", "resolveSingleMatch"); @@ -253,14 +375,24 @@ void LMDB::resolveSingleMatch(const std::string& var, mdb_cursor_open(txn, m_dbi, &cursor); while ((rc = mdb_cursor_get(cursor, &mdb_key, &mdb_value_ret, MDB_NEXT_DUP)) == 0) { - std::string a(reinterpret_cast(mdb_value_ret.mv_data), mdb_value_ret.mv_size); - 
VariableValue *v = new VariableValue(&var, &a); + collectionData.setFromSerialized(reinterpret_cast(mdb_value_ret.mv_data), mdb_value_ret.mv_size); + if (collectionData.isExpired()) { + expiredVars.push_back(std::string(reinterpret_cast(mdb_key.mv_data), mdb_key.mv_size)); + continue; + } + if (!collectionData.hasValue()) { + continue; + } + VariableValue *v = new VariableValue(&var, &collectionData.getValue()); l->push_back(v); } mdb_cursor_close(cursor); mdb_txn_abort(txn); end_txn: + for (const auto& expiredVar : expiredVars) { + delIfExpired(expiredVar); + } return; } @@ -309,6 +441,9 @@ bool LMDB::updateFirst(const std::string &key, MDB_val mdb_key; MDB_val mdb_value; MDB_val mdb_value_ret; + CollectionData previous_data; + CollectionData new_data; + std::string serializedData; rc = txn_begin(0, &txn); lmdb_debug(rc, "txn", "updateFirst"); @@ -317,7 +452,6 @@ bool LMDB::updateFirst(const std::string &key, } string2val(key, &mdb_key); - string2val(value, &mdb_value); rc = mdb_get(txn, m_dbi, &mdb_key, &mdb_value_ret); lmdb_debug(rc, "get", "updateFirst"); @@ -325,12 +459,20 @@ bool LMDB::updateFirst(const std::string &key, goto end_get; } + previous_data.setFromSerialized(reinterpret_cast(mdb_value_ret.mv_data), mdb_value_ret.mv_size); rc = mdb_del(txn, m_dbi, &mdb_key, &mdb_value_ret); lmdb_debug(rc, "del", "updateFirst"); if (rc != 0) { goto end_del; } + if (previous_data.hasExpiry()) { + new_data = previous_data; + }; + new_data.setValue(value); + serializedData = new_data.getSerialized(); + string2val(serializedData, &mdb_value); + rc = mdb_put(txn, m_dbi, &mdb_key, &mdb_value, 0); lmdb_debug(rc, "put", "updateFirst"); if (rc != 0) { @@ -400,10 +542,6 @@ end_txn: return; } -void LMDB::setExpiry(const std::string& key, int32_t expiry_seconds) { - // TODO: add implementation -} - void LMDB::resolveMultiMatches(const std::string& var, std::vector *l, variables::KeyExclusions &ke) { @@ -413,6 +551,8 @@ void LMDB::resolveMultiMatches(const std::string& var, 
MDB_stat mst; size_t keySize = var.size(); MDB_cursor *cursor; + CollectionData collectionData; + std::list expiredVars; rc = txn_begin(MDB_RDONLY, &txn); lmdb_debug(rc, "txn", "resolveMultiMatches"); @@ -428,18 +568,36 @@ void LMDB::resolveMultiMatches(const std::string& var, if (keySize == 0) { while ((rc = mdb_cursor_get(cursor, &key, &data, MDB_NEXT)) == 0) { + collectionData.setFromSerialized(reinterpret_cast(data.mv_data), data.mv_size); + + if (collectionData.isExpired()) { + expiredVars.push_back(std::string(reinterpret_cast(key.mv_data), key.mv_size)); + continue; + } + if (!collectionData.hasValue()) { + continue; + } + std::string key_to_insert(reinterpret_cast(key.mv_data), key.mv_size); - std::string value_to_insert(reinterpret_cast(data.mv_data), data.mv_size); l->insert(l->begin(), new VariableValue( - &m_name, &key_to_insert, &value_to_insert)); + &m_name, &key_to_insert, &collectionData.getValue())); } } else { while ((rc = mdb_cursor_get(cursor, &key, &data, MDB_NEXT)) == 0) { + collectionData.setFromSerialized(reinterpret_cast(data.mv_data), data.mv_size); + + if (collectionData.isExpired()) { + expiredVars.push_back(std::string(reinterpret_cast(key.mv_data), key.mv_size)); + continue; + } + if (!collectionData.hasValue()) { + continue; + } + char *a = reinterpret_cast(key.mv_data); if (strncmp(var.c_str(), a, keySize) == 0) { std::string key_to_insert(reinterpret_cast(key.mv_data), key.mv_size); - std::string value_to_insert(reinterpret_cast(data.mv_data), data.mv_size); - l->insert(l->begin(), new VariableValue(&m_name, &key_to_insert, &value_to_insert)); + l->insert(l->begin(), new VariableValue(&m_name, &key_to_insert, &collectionData.getValue())); } } } @@ -448,6 +606,9 @@ void LMDB::resolveMultiMatches(const std::string& var, end_cursor_open: mdb_txn_abort(txn); end_txn: + for (const auto& expiredVar : expiredVars) { + delIfExpired(expiredVar); + } return; } @@ -460,6 +621,8 @@ void LMDB::resolveRegularExpression(const std::string& var, int 
rc; MDB_stat mst; MDB_cursor *cursor; + CollectionData collectionData; + std::list expiredVars; Utils::Regex r(var, true); @@ -476,6 +639,16 @@ void LMDB::resolveRegularExpression(const std::string& var, } while ((rc = mdb_cursor_get(cursor, &key, &data, MDB_NEXT)) == 0) { + collectionData.setFromSerialized(reinterpret_cast(data.mv_data), data.mv_size); + + if (collectionData.isExpired()) { + expiredVars.push_back(std::string(reinterpret_cast(key.mv_data), key.mv_size)); + continue; + } + if (!collectionData.hasValue()) { + continue; + } + std::string key_to_insert(reinterpret_cast(key.mv_data), key.mv_size); int ret = Utils::regex_search(key_to_insert, r); if (ret <= 0) { @@ -485,8 +658,7 @@ void LMDB::resolveRegularExpression(const std::string& var, continue; } - std::string value_to_insert(reinterpret_cast(data.mv_data), data.mv_size); - VariableValue *v = new VariableValue(&key_to_insert, &value_to_insert); + VariableValue *v = new VariableValue(&key_to_insert, &collectionData.getValue()); l->insert(l->begin(), v); } @@ -494,6 +666,9 @@ void LMDB::resolveRegularExpression(const std::string& var, end_cursor_open: mdb_txn_abort(txn); end_txn: + for (const auto& expiredVar : expiredVars) { + delIfExpired(expiredVar); + } return; } diff --git a/src/collection/backend/lmdb.h b/src/collection/backend/lmdb.h index 8c5f6480..11b4760a 100644 --- a/src/collection/backend/lmdb.h +++ b/src/collection/backend/lmdb.h @@ -126,6 +126,8 @@ class LMDB : void string2val(const std::string& str, MDB_val *val); void inline lmdb_debug(int rc, const std::string &op, const std::string &scope); + void delIfExpired(const std::string& key); + MDB_env *m_env; MDB_dbi m_dbi; bool isOpen; diff --git a/test/cppcheck_suppressions.txt b/test/cppcheck_suppressions.txt index 41e15b6d..87df1639 100644 --- a/test/cppcheck_suppressions.txt +++ b/test/cppcheck_suppressions.txt @@ -79,6 +79,7 @@ unreadVariable:src/operators/rx.cc unreadVariable:src/operators/rx_global.cc 
noExplicitConstructor:src/collection/backend/collection_data.h +stlIfStrFind:src/collection/backend/collection_data.cc unusedFunction missingIncludeSystem diff --git a/test/test-cases/regression/action-expirevar.json b/test/test-cases/regression/action-expirevar.json index f063ba1e..5c9d4ddf 100644 --- a/test/test-cases/regression/action-expirevar.json +++ b/test/test-cases/regression/action-expirevar.json @@ -2,7 +2,7 @@ { "enabled":1, "version_min":300000, - "title":"Testing expirevar action (1/x)", + "title":"Testing expirevar action (1/x) - ip, expire later", "expected":{ "debug_log": "Saving msg: mycount1 is 100" }, @@ -27,13 +27,44 @@ "SecRuleEngine On", "SecAction \"initcol:ip='127.0.0.1',id:5000,phase:1\"", "SecRule ARGS \"@rx value\" \"id:'5001',phase:2,setvar:ip.mycount1=100,expirevar:ip.mycount1=60,pass\"", - "SecRule &IP:mycount1 \"@eq 1\" \"id:'5002',phase:2,pass,log,msg:'mycount1 is %{ip.mycount1}\"" + "SecRule &IP:mycount1 \"@eq 1\" \"id:'5002',phase:2,pass,log,msg:'mycount1 is %{ip.mycount1}'\"" ] }, { "enabled":1, "version_min":300000, - "title":"Testing expirevar action (2/x)", + "title":"Testing expirevar action (2/x) - ip, expire immediately", + "expected":{ + "debug_log": "Saving msg: mycount1 is " + }, + "client":{ + "ip":"200.249.12.31", + "port":123 + }, + "request":{ + "headers":{ + "Host":"localhost", + "User-Agent":"curl/7.38.0", + "Accept":"*/*" + }, + "uri":"/?key=value", + "method":"GET" + }, + "server":{ + "ip":"200.249.12.31", + "port":80 + }, + "rules":[ + "SecRuleEngine On", + "SecAction \"initcol:ip='127.0.0.1',id:5010,phase:1\"", + "SecRule ARGS \"@rx value\" \"id:'5011',phase:2,setvar:ip.mycount1=100,expirevar:ip.mycount1=0,pass\"", + "SecRule &IP:mycount1 \"@eq 0\" \"id:'5012',phase:2,pass,log,msg:'mycount1 is %{ip.mycount1}'\"" + ] + }, + { + "enabled":1, + "version_min":300000, + "title":"Testing expirevar action (3/x) - session, expire later", "expected":{ "debug_log": "Saving msg: mycount1 is 12" }, @@ -58,7 +89,38 @@ 
"SecRuleEngine On", "SecRule ARGS \"@rx .\" \"id:5150,phase:2,pass,setsid:sess1234\"", "SecRule ARGS \"@rx value\" \"id:5151,phase:2,pass,setvar:session.mycount1=12,expirevar:session.mycount1=30\"", - "SecRule &SESSION:mycount1 \"@eq 1\" \"id:'5152',phase:2,pass,log,msg:'mycount1 is %{session.mycount1}\"" + "SecRule &SESSION:mycount1 \"@eq 1\" \"id:'5152',phase:2,pass,log,msg:'mycount1 is %{session.mycount1}'\"" + ] + }, + { + "enabled":1, + "version_min":300000, + "title":"Testing expirevar action (4/x) - session, expire immediately", + "expected":{ + "debug_log": "Saving msg: mycount1 is" + }, + "client":{ + "ip":"200.249.12.31", + "port":123 + }, + "request":{ + "headers":{ + "Host":"localhost", + "User-Agent":"curl/7.38.0", + "Accept":"*/*" + }, + "uri":"/?key=value", + "method":"GET" + }, + "server":{ + "ip":"200.249.12.31", + "port":80 + }, + "rules":[ + "SecRuleEngine On", + "SecRule ARGS \"@rx .\" \"id:5150,phase:2,pass,setsid:sess1234\"", + "SecRule ARGS \"@rx value\" \"id:5151,phase:2,pass,setvar:session.mycount1=12,expirevar:session.mycount1=0\"", + "SecRule &SESSION:mycount1 \"@eq 0\" \"id:'5152',phase:2,pass,log,msg:'mycount1 is %{session.mycount1}'\"" ] } ]