Add expirevar support for lmdb

This commit is contained in:
Martin Vierula 2023-10-10 10:31:52 -07:00
parent 118e1b3a44
commit 34809d8064
No known key found for this signature in database
GPG Key ID: F2FC4E45883BCBA4
6 changed files with 363 additions and 24 deletions

View File

@ -35,6 +35,105 @@ void CollectionData::setExpiry(int32_t seconds_until_expiry) {
m_hasExpiryTime = true;
}
std::string CollectionData::getSerialized() const {
std::string serialized;
if (hasValue()) {
serialized.reserve(30 + 10 + getValue().size());
} else {
serialized.reserve(16+10);
}
serialized.assign("{");
if (hasExpiry()) {
serialized.append("\"__expire_\":");
uint64_t expiryEpochSeconds = std::chrono::duration_cast<std::chrono::seconds>(m_expiryTime.time_since_epoch()).count();
serialized.append(std::to_string(expiryEpochSeconds));
if (hasValue()) {
serialized.append(",");
}
}
if (hasValue()) {
serialized.append("\"__value_\":\"");
serialized.append(getValue());
serialized.append("\"");
}
serialized.append("}");
return serialized;
}
void CollectionData::setFromSerialized(const char* serializedData, size_t length) {
const static std::string expiryPrefix("\"__expire_\":");
const static std::string valuePrefix("\"__value_\":\"");
m_hasValue = false;
m_hasExpiryTime = false;
std::string serializedString(serializedData, length);
if ((serializedString.find("{") == 0) && (serializedString.substr(serializedString.length()-1) == "}")) {
size_t currentPos = 1;
uint64_t expiryEpochSeconds = 0;
bool invalidSerializedFormat = false;
bool doneParsing = false;
// Extract the expiry time, if it exists
if (serializedString.find(expiryPrefix, currentPos) == currentPos) {
currentPos += expiryPrefix.length();
std::string expiryDigits = serializedString.substr(currentPos, 10);
if (expiryDigits.find_first_not_of("0123456789") == std::string::npos) {
expiryEpochSeconds = strtoll(expiryDigits.c_str(), NULL, 10);
} else {
invalidSerializedFormat = true;
}
currentPos += 10;
}
if ((!invalidSerializedFormat) && (expiryEpochSeconds > 0)) {
if (serializedString.find(",", currentPos) == currentPos) {
currentPos++;
} else if (currentPos == serializedString.length()-1) {
doneParsing = true;
} else {
invalidSerializedFormat = true;
}
}
if ((!invalidSerializedFormat) && (!doneParsing)) {
// Extract the value
if ((serializedString.find(valuePrefix, currentPos) == currentPos)) {
currentPos += valuePrefix.length();
size_t expectedCloseQuotePos = serializedString.length() - 2;
if ((serializedString.substr(expectedCloseQuotePos, 1) == "\"") && (expectedCloseQuotePos >= currentPos)) {
m_value = serializedString.substr(currentPos);
m_value.resize(m_value.length()-2);
m_hasValue = true;
} else {
invalidSerializedFormat = true;
}
} else {
invalidSerializedFormat = true;
}
}
// Set the object's expiry time, if we found one
if ((!invalidSerializedFormat) && (expiryEpochSeconds > 0)) {
std::chrono::seconds expiryDuration(expiryEpochSeconds);
std::chrono::system_clock::time_point expiryTimePoint(expiryDuration);
m_expiryTime = expiryTimePoint;
m_hasExpiryTime = true;
}
if (!invalidSerializedFormat) {
return;
}
}
// this is the residual case; the entire string is a simple value (not JSON-ish encoded)
// the foreseen case here is lmdb content from prior to the serialization support
m_value.assign(serializedData, length);
m_hasValue = true;
return;
}
} // namespace backend
} // namespace collection

View File

@ -16,13 +16,10 @@
#ifdef __cplusplus
#include <string>
#include <memory>
#include <chrono>
#endif
#include "modsecurity/collection/collection.h"
#ifndef SRC_COLLECTION_DATA_H_
#define SRC_COLLECTION_DATA_H_
@ -53,6 +50,9 @@ public:
bool hasExpiry() const { return m_hasExpiryTime;}
bool isExpired() const;
std::string getSerialized() const;
void setFromSerialized(const char* serializedData, size_t length);
private:
bool m_hasValue;
bool m_hasExpiryTime;

View File

@ -15,6 +15,7 @@
#include "src/collection/backend/lmdb.h"
#include "src/collection/backend/collection_data.h"
#include <sys/types.h>
#include <unistd.h>
@ -158,6 +159,7 @@ std::unique_ptr<std::string> LMDB::resolveFirst(const std::string& var) {
MDB_val mdb_value_ret;
std::unique_ptr<std::string> ret = NULL;
MDB_txn *txn = NULL;
CollectionData collectionData;
string2val(var, &mdb_key);
@ -172,17 +174,125 @@ std::unique_ptr<std::string> LMDB::resolveFirst(const std::string& var) {
goto end_get;
}
ret = std::unique_ptr<std::string>(new std::string(
reinterpret_cast<char *>(mdb_value_ret.mv_data),
mdb_value_ret.mv_size));
collectionData.setFromSerialized(reinterpret_cast<char *>(mdb_value_ret.mv_data), mdb_value_ret.mv_size);
if ((!collectionData.isExpired()) && (collectionData.hasValue())) {
ret = std::unique_ptr<std::string>(new std::string(collectionData.getValue()));
}
end_get:
mdb_txn_abort(txn);
end_txn:
// The read-only transaction is complete. Now we can do a delete if the item was expired.
if (collectionData.isExpired()) {
delIfExpired(var);
}
return ret;
}
void LMDB::setExpiry(const std::string &key, int32_t expiry_seconds) {
int rc;
MDB_txn *txn;
MDB_val mdb_key;
MDB_val mdb_value;
MDB_val mdb_value_ret;
CollectionData previous_data;
CollectionData new_data;
std::string serializedData;
string2val(key, &mdb_key);
rc = txn_begin(0, &txn);
lmdb_debug(rc, "txn", "setExpiry");
if (rc != 0) {
goto end_txn;
}
rc = mdb_get(txn, m_dbi, &mdb_key, &mdb_value_ret);
lmdb_debug(rc, "get", "setExpiry");
if (rc == 0) {
previous_data.setFromSerialized(reinterpret_cast<char *>(mdb_value_ret.mv_data), mdb_value_ret.mv_size);
rc = mdb_del(txn, m_dbi, &mdb_key, &mdb_value_ret);
lmdb_debug(rc, "del", "setExpiry");
if (rc != 0) {
goto end_del;
}
}
if (previous_data.hasValue()) {
new_data = previous_data;
};
new_data.setExpiry(expiry_seconds);
serializedData = new_data.getSerialized();
string2val(serializedData, &mdb_value);
rc = mdb_put(txn, m_dbi, &mdb_key, &mdb_value, 0);
lmdb_debug(rc, "put", "setExpiry");
if (rc != 0) {
goto end_put;
}
rc = mdb_txn_commit(txn);
lmdb_debug(rc, "commit", "setExpiry");
if (rc != 0) {
goto end_commit;
}
end_put:
end_del:
if (rc != 0) {
mdb_txn_abort(txn);
}
end_commit:
end_txn:
return;
}
void LMDB::delIfExpired(const std::string& key) {
MDB_txn *txn;
MDB_val mdb_key;
MDB_val mdb_value_ret;
CollectionData collectionData;
int rc = txn_begin(0, &txn);
lmdb_debug(rc, "txn", "del");
if (rc != 0) {
goto end_txn;
}
string2val(key, &mdb_key);
rc = mdb_get(txn, m_dbi, &mdb_key, &mdb_value_ret);
lmdb_debug(rc, "get", "del");
if (rc != 0) {
goto end_get;
}
collectionData.setFromSerialized(reinterpret_cast<char *>(mdb_value_ret.mv_data), mdb_value_ret.mv_size);
if (collectionData.isExpired()) {
rc = mdb_del(txn, m_dbi, &mdb_key, &mdb_value_ret);
lmdb_debug(rc, "del", "del");
if (rc != 0) {
goto end_del;
}
}
rc = mdb_txn_commit(txn);
lmdb_debug(rc, "commit", "del");
if (rc != 0) {
goto end_commit;
}
end_del:
end_get:
if (rc != 0) {
mdb_txn_abort(txn);
}
end_commit:
end_txn:
return;
}
bool LMDB::storeOrUpdateFirst(const std::string &key,
const std::string &value) {
int rc;
@ -190,9 +300,11 @@ bool LMDB::storeOrUpdateFirst(const std::string &key,
MDB_val mdb_key;
MDB_val mdb_value;
MDB_val mdb_value_ret;
CollectionData previous_data;
CollectionData new_data;
std::string serializedData;
string2val(key, &mdb_key);
string2val(value, &mdb_value);
rc = txn_begin(0, &txn);
lmdb_debug(rc, "txn", "storeOrUpdateFirst");
@ -203,6 +315,7 @@ bool LMDB::storeOrUpdateFirst(const std::string &key,
rc = mdb_get(txn, m_dbi, &mdb_key, &mdb_value_ret);
lmdb_debug(rc, "get", "storeOrUpdateFirst");
if (rc == 0) {
previous_data.setFromSerialized(reinterpret_cast<char *>(mdb_value_ret.mv_data), mdb_value_ret.mv_size);
rc = mdb_del(txn, m_dbi, &mdb_key, &mdb_value_ret);
lmdb_debug(rc, "del", "storeOrUpdateFirst");
if (rc != 0) {
@ -210,6 +323,13 @@ bool LMDB::storeOrUpdateFirst(const std::string &key,
}
}
if (previous_data.hasExpiry()) {
new_data = previous_data;
};
new_data.setValue(value);
serializedData = new_data.getSerialized();
string2val(serializedData, &mdb_value);
rc = mdb_put(txn, m_dbi, &mdb_key, &mdb_value, 0);
lmdb_debug(rc, "put", "storeOrUpdateFirst");
if (rc != 0) {
@ -241,6 +361,8 @@ void LMDB::resolveSingleMatch(const std::string& var,
MDB_val mdb_value;
MDB_val mdb_value_ret;
MDB_cursor *cursor;
CollectionData collectionData;
std::list<std::string> expiredVars;
rc = txn_begin(MDB_RDONLY, &txn);
lmdb_debug(rc, "txn", "resolveSingleMatch");
@ -253,14 +375,24 @@ void LMDB::resolveSingleMatch(const std::string& var,
mdb_cursor_open(txn, m_dbi, &cursor);
while ((rc = mdb_cursor_get(cursor, &mdb_key,
&mdb_value_ret, MDB_NEXT_DUP)) == 0) {
std::string a(reinterpret_cast<char *>(mdb_value_ret.mv_data), mdb_value_ret.mv_size);
VariableValue *v = new VariableValue(&var, &a);
collectionData.setFromSerialized(reinterpret_cast<char *>(mdb_value_ret.mv_data), mdb_value_ret.mv_size);
if (collectionData.isExpired()) {
expiredVars.push_back(std::string(reinterpret_cast<char *>(mdb_key.mv_data), mdb_key.mv_size));
continue;
}
if (!collectionData.hasValue()) {
continue;
}
VariableValue *v = new VariableValue(&var, &collectionData.getValue());
l->push_back(v);
}
mdb_cursor_close(cursor);
mdb_txn_abort(txn);
end_txn:
for (const auto& expiredVar : expiredVars) {
delIfExpired(expiredVar);
}
return;
}
@ -309,6 +441,9 @@ bool LMDB::updateFirst(const std::string &key,
MDB_val mdb_key;
MDB_val mdb_value;
MDB_val mdb_value_ret;
CollectionData previous_data;
CollectionData new_data;
std::string serializedData;
rc = txn_begin(0, &txn);
lmdb_debug(rc, "txn", "updateFirst");
@ -317,7 +452,6 @@ bool LMDB::updateFirst(const std::string &key,
}
string2val(key, &mdb_key);
string2val(value, &mdb_value);
rc = mdb_get(txn, m_dbi, &mdb_key, &mdb_value_ret);
lmdb_debug(rc, "get", "updateFirst");
@ -325,12 +459,20 @@ bool LMDB::updateFirst(const std::string &key,
goto end_get;
}
previous_data.setFromSerialized(reinterpret_cast<char *>(mdb_value_ret.mv_data), mdb_value_ret.mv_size);
rc = mdb_del(txn, m_dbi, &mdb_key, &mdb_value_ret);
lmdb_debug(rc, "del", "updateFirst");
if (rc != 0) {
goto end_del;
}
if (previous_data.hasExpiry()) {
new_data = previous_data;
};
new_data.setValue(value);
serializedData = new_data.getSerialized();
string2val(serializedData, &mdb_value);
rc = mdb_put(txn, m_dbi, &mdb_key, &mdb_value, 0);
lmdb_debug(rc, "put", "updateFirst");
if (rc != 0) {
@ -400,10 +542,6 @@ end_txn:
return;
}
void LMDB::setExpiry(const std::string& key, int32_t expiry_seconds) {
// TODO: add implementation
}
void LMDB::resolveMultiMatches(const std::string& var,
std::vector<const VariableValue *> *l,
variables::KeyExclusions &ke) {
@ -413,6 +551,8 @@ void LMDB::resolveMultiMatches(const std::string& var,
MDB_stat mst;
size_t keySize = var.size();
MDB_cursor *cursor;
CollectionData collectionData;
std::list<std::string> expiredVars;
rc = txn_begin(MDB_RDONLY, &txn);
lmdb_debug(rc, "txn", "resolveMultiMatches");
@ -428,18 +568,36 @@ void LMDB::resolveMultiMatches(const std::string& var,
if (keySize == 0) {
while ((rc = mdb_cursor_get(cursor, &key, &data, MDB_NEXT)) == 0) {
collectionData.setFromSerialized(reinterpret_cast<char *>(data.mv_data), data.mv_size);
if (collectionData.isExpired()) {
expiredVars.push_back(std::string(reinterpret_cast<char *>(key.mv_data), key.mv_size));
continue;
}
if (!collectionData.hasValue()) {
continue;
}
std::string key_to_insert(reinterpret_cast<char *>(key.mv_data), key.mv_size);
std::string value_to_insert(reinterpret_cast<char *>(data.mv_data), data.mv_size);
l->insert(l->begin(), new VariableValue(
&m_name, &key_to_insert, &value_to_insert));
&m_name, &key_to_insert, &collectionData.getValue()));
}
} else {
while ((rc = mdb_cursor_get(cursor, &key, &data, MDB_NEXT)) == 0) {
collectionData.setFromSerialized(reinterpret_cast<char *>(data.mv_data), data.mv_size);
if (collectionData.isExpired()) {
expiredVars.push_back(std::string(reinterpret_cast<char *>(key.mv_data), key.mv_size));
continue;
}
if (!collectionData.hasValue()) {
continue;
}
char *a = reinterpret_cast<char *>(key.mv_data);
if (strncmp(var.c_str(), a, keySize) == 0) {
std::string key_to_insert(reinterpret_cast<char *>(key.mv_data), key.mv_size);
std::string value_to_insert(reinterpret_cast<char *>(data.mv_data), data.mv_size);
l->insert(l->begin(), new VariableValue(&m_name, &key_to_insert, &value_to_insert));
l->insert(l->begin(), new VariableValue(&m_name, &key_to_insert, &collectionData.getValue()));
}
}
}
@ -448,6 +606,9 @@ void LMDB::resolveMultiMatches(const std::string& var,
end_cursor_open:
mdb_txn_abort(txn);
end_txn:
for (const auto& expiredVar : expiredVars) {
delIfExpired(expiredVar);
}
return;
}
@ -460,6 +621,8 @@ void LMDB::resolveRegularExpression(const std::string& var,
int rc;
MDB_stat mst;
MDB_cursor *cursor;
CollectionData collectionData;
std::list<std::string> expiredVars;
Utils::Regex r(var, true);
@ -476,6 +639,16 @@ void LMDB::resolveRegularExpression(const std::string& var,
}
while ((rc = mdb_cursor_get(cursor, &key, &data, MDB_NEXT)) == 0) {
collectionData.setFromSerialized(reinterpret_cast<char *>(data.mv_data), data.mv_size);
if (collectionData.isExpired()) {
expiredVars.push_back(std::string(reinterpret_cast<char *>(key.mv_data), key.mv_size));
continue;
}
if (!collectionData.hasValue()) {
continue;
}
std::string key_to_insert(reinterpret_cast<char *>(key.mv_data), key.mv_size);
int ret = Utils::regex_search(key_to_insert, r);
if (ret <= 0) {
@ -485,8 +658,7 @@ void LMDB::resolveRegularExpression(const std::string& var,
continue;
}
std::string value_to_insert(reinterpret_cast<char *>(data.mv_data), data.mv_size);
VariableValue *v = new VariableValue(&key_to_insert, &value_to_insert);
VariableValue *v = new VariableValue(&key_to_insert, &collectionData.getValue());
l->insert(l->begin(), v);
}
@ -494,6 +666,9 @@ void LMDB::resolveRegularExpression(const std::string& var,
end_cursor_open:
mdb_txn_abort(txn);
end_txn:
for (const auto& expiredVar : expiredVars) {
delIfExpired(expiredVar);
}
return;
}

View File

@ -126,6 +126,8 @@ class LMDB :
void string2val(const std::string& str, MDB_val *val);
void inline lmdb_debug(int rc, const std::string &op, const std::string &scope);
void delIfExpired(const std::string& key);
MDB_env *m_env;
MDB_dbi m_dbi;
bool isOpen;

View File

@ -79,6 +79,7 @@ unreadVariable:src/operators/rx.cc
unreadVariable:src/operators/rx_global.cc
noExplicitConstructor:src/collection/backend/collection_data.h
stlIfStrFind:src/collection/backend/collection_data.cc
unusedFunction
missingIncludeSystem

View File

@ -2,7 +2,7 @@
{
"enabled":1,
"version_min":300000,
"title":"Testing expirevar action (1/x)",
"title":"Testing expirevar action (1/x) - ip, expire later",
"expected":{
"debug_log": "Saving msg: mycount1 is 100"
},
@ -27,13 +27,44 @@
"SecRuleEngine On",
"SecAction \"initcol:ip='127.0.0.1',id:5000,phase:1\"",
"SecRule ARGS \"@rx value\" \"id:'5001',phase:2,setvar:ip.mycount1=100,expirevar:ip.mycount1=60,pass\"",
"SecRule &IP:mycount1 \"@eq 1\" \"id:'5002',phase:2,pass,log,msg:'mycount1 is %{ip.mycount1}\""
"SecRule &IP:mycount1 \"@eq 1\" \"id:'5002',phase:2,pass,log,msg:'mycount1 is %{ip.mycount1}'\""
]
},
{
"enabled":1,
"version_min":300000,
"title":"Testing expirevar action (2/x)",
"title":"Testing expirevar action (2/x) - ip, expire immediately",
"expected":{
"debug_log": "Saving msg: mycount1 is "
},
"client":{
"ip":"200.249.12.31",
"port":123
},
"request":{
"headers":{
"Host":"localhost",
"User-Agent":"curl/7.38.0",
"Accept":"*/*"
},
"uri":"/?key=value",
"method":"GET"
},
"server":{
"ip":"200.249.12.31",
"port":80
},
"rules":[
"SecRuleEngine On",
"SecAction \"initcol:ip='127.0.0.1',id:5010,phase:1\"",
"SecRule ARGS \"@rx value\" \"id:'5011',phase:2,setvar:ip.mycount1=100,expirevar:ip.mycount1=0,pass\"",
"SecRule &IP:mycount1 \"@eq 0\" \"id:'5012',phase:2,pass,log,msg:'mycount1 is %{ip.mycount1}'\""
]
},
{
"enabled":1,
"version_min":300000,
"title":"Testing expirevar action (3/x) session, expire later",
"expected":{
"debug_log": "Saving msg: mycount1 is 12"
},
@ -58,7 +89,38 @@
"SecRuleEngine On",
"SecRule ARGS \"@rx .\" \"id:5150,phase:2,pass,setsid:sess1234\"",
"SecRule ARGS \"@rx value\" \"id:5151,phase:2,pass,setvar:session.mycount1=12,expirevar:session.mycount1=30\"",
"SecRule &SESSION:mycount1 \"@eq 1\" \"id:'5152',phase:2,pass,log,msg:'mycount1 is %{session.mycount1}\""
"SecRule &SESSION:mycount1 \"@eq 1\" \"id:'5152',phase:2,pass,log,msg:'mycount1 is %{session.mycount1}'\""
]
},
{
"enabled":1,
"version_min":300000,
"title":"Testing expirevar action (4/x) session, expire immediately",
"expected":{
"debug_log": "Saving msg: mycount1 is"
},
"client":{
"ip":"200.249.12.31",
"port":123
},
"request":{
"headers":{
"Host":"localhost",
"User-Agent":"curl/7.38.0",
"Accept":"*/*"
},
"uri":"/?key=value",
"method":"GET"
},
"server":{
"ip":"200.249.12.31",
"port":80
},
"rules":[
"SecRuleEngine On",
"SecRule ARGS \"@rx .\" \"id:5150,phase:2,pass,setsid:sess1234\"",
"SecRule ARGS \"@rx value\" \"id:5151,phase:2,pass,setvar:session.mycount1=12,expirevar:session.mycount1=0\"",
"SecRule &SESSION:mycount1 \"@eq 0\" \"id:'5152',phase:2,pass,log,msg:'mycount1 is %{session.mycount1}'\""
]
}
]