Added support to run unit tests in a multithreaded context

- This is controlled by specifying the 'mtstress' argument when running
  `unit_test`.
- The goal is to detect if the operator/transformation  fails in this
  context.
- In this mode, the test will be executed 5'000 times in 50 threads
  concurrently.
- Allocation & initialization of the operator/transformation is
  performed once in the main thread, while the evaluation is executed in
  the threads.
  - This is consistent with the library's support for multithreading,
    where initialization and loading of rules is expected to run once.
    See issue #3215.
This commit is contained in:
Eduardo Arias 2024-08-09 06:54:35 -07:00
parent 7bdc3c825c
commit ee5f95eb04
5 changed files with 165 additions and 69 deletions

View File

@ -93,13 +93,13 @@ bool ModSecurityTest<T>::load_test_json(const std::string &file) {
template <class T> template <class T>
std::pair<std::string, std::vector<T *>>* void
ModSecurityTest<T>::load_tests(const std::string &path) { ModSecurityTest<T>::load_tests(const std::string &path) {
DIR *dir; DIR *dir;
struct dirent *ent; struct dirent *ent;
struct stat buffer; struct stat buffer;
if ((dir = opendir(path.c_str())) == NULL) { if ((dir = opendir(path.c_str())) == nullptr) {
/* if target is a file, use it as a single test. */ /* if target is a file, use it as a single test. */
if (stat(path.c_str(), &buffer) == 0) { if (stat(path.c_str(), &buffer) == 0) {
if (load_test_json(path) == false) { if (load_test_json(path) == false) {
@ -107,10 +107,10 @@ ModSecurityTest<T>::load_tests(const std::string &path) {
std::cout << std::endl; std::cout << std::endl;
} }
} }
return NULL; return;
} }
while ((ent = readdir(dir)) != NULL) { while ((ent = readdir(dir)) != nullptr) {
std::string filename = ent->d_name; std::string filename = ent->d_name;
std::string json = ".json"; std::string json = ".json";
if (filename.size() < json.size() if (filename.size() < json.size()
@ -123,16 +123,15 @@ ModSecurityTest<T>::load_tests(const std::string &path) {
} }
} }
closedir(dir); closedir(dir);
return NULL;
} }
template <class T> template <class T>
std::pair<std::string, std::vector<T *>>* ModSecurityTest<T>::load_tests() { void ModSecurityTest<T>::load_tests() {
return load_tests(this->target); load_tests(this->target);
} }
template <class T> template <class T>
void ModSecurityTest<T>::cmd_options(int argc, char **argv) { void ModSecurityTest<T>::cmd_options(int argc, char **argv) {
int i = 1; int i = 1;
@ -144,6 +143,10 @@ void ModSecurityTest<T>::cmd_options(int argc, char **argv) {
i++; i++;
m_count_all = true; m_count_all = true;
} }
if (argc > i && strcmp(argv[i], "mtstress") == 0) {
i++;
m_test_multithreaded = true;
}
if (std::getenv("AUTOMAKE_TESTS")) { if (std::getenv("AUTOMAKE_TESTS")) {
m_automake_output = true; m_automake_output = true;
} }

View File

@ -34,12 +34,13 @@ template <class T> class ModSecurityTest :
ModSecurityTest() ModSecurityTest()
: m_test_number(0), : m_test_number(0),
m_automake_output(false), m_automake_output(false),
m_count_all(false) { } m_count_all(false),
m_test_multithreaded(false) { }
std::string header(); std::string header();
void cmd_options(int, char **); void cmd_options(int, char **);
std::pair<std::string, std::vector<T *>>* load_tests(); void load_tests();
std::pair<std::string, std::vector<T *>>* load_tests(const std::string &path); void load_tests(const std::string &path);
bool load_test_json(const std::string &file); bool load_test_json(const std::string &file);
std::string target; std::string target;
@ -48,6 +49,7 @@ template <class T> class ModSecurityTest :
int m_test_number; int m_test_number;
bool m_automake_output; bool m_automake_output;
bool m_count_all; bool m_count_all;
bool m_test_multithreaded;
}; };
} // namespace modsecurity_test } // namespace modsecurity_test

View File

@ -15,7 +15,9 @@
#include <string.h> #include <string.h>
#include <cstring> #include <cstring>
#include <cassert>
#include <thread>
#include <array>
#include <iostream> #include <iostream>
#include <ctime> #include <ctime>
#include <string> #include <string>
@ -38,6 +40,7 @@
using modsecurity_test::UnitTest; using modsecurity_test::UnitTest;
using modsecurity_test::UnitTestResult;
using modsecurity_test::ModSecurityTest; using modsecurity_test::ModSecurityTest;
using modsecurity_test::ModSecurityTestResults; using modsecurity_test::ModSecurityTestResults;
using modsecurity::actions::transformations::Transformation; using modsecurity::actions::transformations::Transformation;
@ -53,64 +56,149 @@ void print_help() {
} }
void perform_unit_test(ModSecurityTest<UnitTest> *test, UnitTest *t, struct OperatorTest {
ModSecurityTestResults<UnitTest>* res) { using ItemType = Operator;
std::string error;
static ItemType* init(const UnitTest &t) {
auto op = Operator::instantiate(t.name, t.param);
assert(op != nullptr);
std::string error;
op->init(t.filename, &error);
return op;
}
static UnitTestResult eval(ItemType &op, const UnitTest &t) {
return {op.evaluate(nullptr, nullptr, t.input, nullptr), {}};
}
static bool check(const UnitTestResult &result, const UnitTest &t) {
return result.ret != t.ret;
}
};
struct TransformationTest {
using ItemType = Transformation;
static ItemType* init(const UnitTest &t) {
auto tfn = Transformation::instantiate("t:" + t.name);
assert(tfn != nullptr);
return tfn;
}
static UnitTestResult eval(ItemType &tfn, const UnitTest &t) {
return {1, tfn.evaluate(t.input, nullptr)};
}
static bool check(const UnitTestResult &result, const UnitTest &t) {
return result.output != t.output;
}
};
template<typename TestType>
UnitTestResult perform_unit_test_once(const UnitTest &t) {
std::unique_ptr<typename TestType::ItemType> item(TestType::init(t));
assert(item.get() != nullptr);
return TestType::eval(*item.get(), t);
}
template<typename TestType>
UnitTestResult perform_unit_test_multithreaded(const UnitTest &t) {
constexpr auto NUM_THREADS = 50;
constexpr auto ITERATIONS = 5'000;
std::array<std::thread, NUM_THREADS> threads;
std::array<UnitTestResult, NUM_THREADS> results;
std::unique_ptr<typename TestType::ItemType> item(TestType::init(t));
assert(item.get() != nullptr);
for (auto i = 0; i != threads.size(); ++i)
{
auto &result = results[i];
threads[i] = std::thread(
[&item, &t, &result]()
{
for (auto j = 0; j != ITERATIONS; ++j)
result = TestType::eval(*item.get(), t);
});
}
UnitTestResult ret;
for (auto i = 0; i != threads.size(); ++i)
{
threads[i].join();
if (TestType::check(results[i], t))
ret = results[i]; // error value, keep iterating to join all threads
else if(i == 0)
ret = results[i]; // initial value
}
return ret; // cppcheck-suppress uninitvar ; false positive, ret assigned at least once in previous loop
}
template<typename TestType>
void perform_unit_test_helper(const ModSecurityTest<UnitTest> &test, UnitTest &t,
ModSecurityTestResults<UnitTest> &res) {
if (!test.m_test_multithreaded)
t.result = perform_unit_test_once<TestType>(t);
else
t.result = perform_unit_test_multithreaded<TestType>(t);
if (TestType::check(t.result, t)) {
res.push_back(&t);
if (test.m_automake_output) {
std::cout << "FAIL ";
}
} else if (test.m_automake_output) {
std::cout << "PASS ";
}
}
void perform_unit_test(const ModSecurityTest<UnitTest> &test, UnitTest &t,
ModSecurityTestResults<UnitTest> &res) {
bool found = true; bool found = true;
if (test->m_automake_output) { if (test.m_automake_output) {
std::cout << ":test-result: "; std::cout << ":test-result: ";
} }
if (t->resource.empty() == false) { if (t.resource.empty() == false) {
found = (std::find(resources.begin(), resources.end(), t->resource) found = std::find(resources.begin(), resources.end(), t.resource)
!= resources.end()); != resources.end();
} }
if (!found) { if (!found) {
t->skipped = true; t.skipped = true;
res->push_back(t); res.push_back(&t);
if (test->m_automake_output) { if (test.m_automake_output) {
std::cout << "SKIP "; std::cout << "SKIP ";
} }
} }
if (t->type == "op") { if (t.type == "op") {
Operator *op = Operator::instantiate(t->name, t->param); perform_unit_test_helper<OperatorTest>(test, t, res);
op->init(t->filename, &error); } else if (t.type == "tfn") {
int ret = op->evaluate(NULL, NULL, t->input, NULL); perform_unit_test_helper<TransformationTest>(test, t, res);
t->obtained = ret;
if (ret != t->ret) {
res->push_back(t);
if (test->m_automake_output) {
std::cout << "FAIL ";
}
} else if (test->m_automake_output) {
std::cout << "PASS ";
}
delete op;
} else if (t->type == "tfn") {
Transformation *tfn = Transformation::instantiate("t:" + t->name);
std::string ret = tfn->evaluate(t->input, NULL);
t->obtained = 1;
t->obtainedOutput = ret;
if (ret != t->output) {
res->push_back(t);
if (test->m_automake_output) {
std::cout << "FAIL ";
}
} else if (test->m_automake_output) {
std::cout << "PASS ";
}
delete tfn;
} else { } else {
std::cerr << "Failed. Test type is unknown: << " << t->type; std::cerr << "Failed. Test type is unknown: << " << t.type;
std::cerr << std::endl; std::cerr << std::endl;
} }
if (test->m_automake_output) { if (test.m_automake_output) {
std::cout << t->name << " " std::cout << t.name << " "
<< modsecurity::utils::string::toHexIfNeeded(t->input) << modsecurity::utils::string::toHexIfNeeded(t.input)
<< std::endl; << std::endl;
} }
} }
@ -151,17 +239,15 @@ int main(int argc, char **argv) {
test.load_tests("test-cases/secrules-language-tests/transformations"); test.load_tests("test-cases/secrules-language-tests/transformations");
} }
for (std::pair<std::string, std::vector<UnitTest *> *> a : test) { for (auto& [filename, tests] : test) {
std::vector<UnitTest *> *tests = a.second;
total += tests->size(); total += tests->size();
for (UnitTest *t : *tests) { for (auto t : *tests) {
ModSecurityTestResults<UnitTest> r; ModSecurityTestResults<UnitTest> r;
if (!test.m_automake_output) { if (!test.m_automake_output) {
std::cout << " " << a.first << "...\t"; std::cout << " " << filename << "...\t";
} }
perform_unit_test(&test, t, &r); perform_unit_test(test, *t, r);
if (!test.m_automake_output) { if (!test.m_automake_output) {
int skp = 0; int skp = 0;
@ -191,7 +277,7 @@ int main(int argc, char **argv) {
std::cout << "Total >> " << total << std::endl; std::cout << "Total >> " << total << std::endl;
} }
for (UnitTest *t : results) { for (const auto t : results) {
std::cout << t->print() << std::endl; std::cout << t->print() << std::endl;
} }
@ -216,8 +302,8 @@ int main(int argc, char **argv) {
} }
for (auto a : test) { for (auto a : test) {
auto *vec = a.second; auto vec = a.second;
for(auto *t : *vec) for(auto t : *vec)
delete t; delete t;
delete vec; delete vec;
} }

View File

@ -102,15 +102,15 @@ std::string UnitTest::print() {
i << " \"param\": \"" << this->param << "\"" << std::endl; i << " \"param\": \"" << this->param << "\"" << std::endl;
i << " \"output\": \"" << this->output << "\"" << std::endl; i << " \"output\": \"" << this->output << "\"" << std::endl;
i << "}" << std::endl; i << "}" << std::endl;
if (this->ret != this->obtained) { if (this->ret != this->result.ret) {
i << "Expecting: \"" << this->ret << "\" - returned: \""; i << "Expecting: \"" << this->ret << "\" - returned: \"";
i << this->obtained << "\"" << std::endl; i << this->result.ret << "\"" << std::endl;
} }
if (this->output != this->obtainedOutput) { if (this->output != this->result.output) {
i << "Expecting: \""; i << "Expecting: \"";
i << modsecurity::utils::string::toHexIfNeeded(this->output); i << modsecurity::utils::string::toHexIfNeeded(this->output);
i << "\" - returned: \""; i << "\" - returned: \"";
i << modsecurity::utils::string::toHexIfNeeded(this->obtainedOutput); i << modsecurity::utils::string::toHexIfNeeded(this->result.output);
i << "\""; i << "\"";
i << std::endl; i << std::endl;
} }

View File

@ -25,6 +25,12 @@
namespace modsecurity_test { namespace modsecurity_test {
class UnitTestResult {
public:
int ret;
std::string output;
};
class UnitTest { class UnitTest {
public: public:
static UnitTest *from_yajl_node(const yajl_val &); static UnitTest *from_yajl_node(const yajl_val &);
@ -39,9 +45,8 @@ class UnitTest {
std::string filename; std::string filename;
std::string output; std::string output;
int ret; int ret;
int obtained;
int skipped; int skipped;
std::string obtainedOutput; UnitTestResult result;
}; };
} // namespace modsecurity_test } // namespace modsecurity_test