Added support to run unit tests in a multithreaded context

- This is controlled by specifying the 'mtstress' argument when running
  `unit_test`.
- The goal is to detect if the operator/transformation  fails in this
  context.
- In this mode, the test will be executed 5'000 times in 50 threads
  concurrently.
- Allocation & initialization of the operator/transformation is
  performed once in the main thread, while the evaluation is executed in
  the threads.
  - This is consistent with the library's support for multithreading,
    where initialization and loading of rules is expected to run once.
    See issue #3215.
This commit is contained in:
Eduardo Arias
2024-08-09 06:54:35 -07:00
parent 7bdc3c825c
commit ee5f95eb04
5 changed files with 165 additions and 69 deletions

View File

@@ -15,7 +15,9 @@
#include <string.h>
#include <cstring>
#include <cassert>
#include <thread>
#include <array>
#include <iostream>
#include <ctime>
#include <string>
@@ -38,6 +40,7 @@
using modsecurity_test::UnitTest;
using modsecurity_test::UnitTestResult;
using modsecurity_test::ModSecurityTest;
using modsecurity_test::ModSecurityTestResults;
using modsecurity::actions::transformations::Transformation;
@@ -53,64 +56,149 @@ void print_help() {
}
void perform_unit_test(ModSecurityTest<UnitTest> *test, UnitTest *t,
ModSecurityTestResults<UnitTest>* res) {
std::string error;
struct OperatorTest {
using ItemType = Operator;
static ItemType* init(const UnitTest &t) {
auto op = Operator::instantiate(t.name, t.param);
assert(op != nullptr);
std::string error;
op->init(t.filename, &error);
return op;
}
static UnitTestResult eval(ItemType &op, const UnitTest &t) {
return {op.evaluate(nullptr, nullptr, t.input, nullptr), {}};
}
static bool check(const UnitTestResult &result, const UnitTest &t) {
return result.ret != t.ret;
}
};
struct TransformationTest {
using ItemType = Transformation;
static ItemType* init(const UnitTest &t) {
auto tfn = Transformation::instantiate("t:" + t.name);
assert(tfn != nullptr);
return tfn;
}
static UnitTestResult eval(ItemType &tfn, const UnitTest &t) {
return {1, tfn.evaluate(t.input, nullptr)};
}
static bool check(const UnitTestResult &result, const UnitTest &t) {
return result.output != t.output;
}
};
template<typename TestType>
UnitTestResult perform_unit_test_once(const UnitTest &t) {
std::unique_ptr<typename TestType::ItemType> item(TestType::init(t));
assert(item.get() != nullptr);
return TestType::eval(*item.get(), t);
}
template<typename TestType>
UnitTestResult perform_unit_test_multithreaded(const UnitTest &t) {
constexpr auto NUM_THREADS = 50;
constexpr auto ITERATIONS = 5'000;
std::array<std::thread, NUM_THREADS> threads;
std::array<UnitTestResult, NUM_THREADS> results;
std::unique_ptr<typename TestType::ItemType> item(TestType::init(t));
assert(item.get() != nullptr);
for (auto i = 0; i != threads.size(); ++i)
{
auto &result = results[i];
threads[i] = std::thread(
[&item, &t, &result]()
{
for (auto j = 0; j != ITERATIONS; ++j)
result = TestType::eval(*item.get(), t);
});
}
UnitTestResult ret;
for (auto i = 0; i != threads.size(); ++i)
{
threads[i].join();
if (TestType::check(results[i], t))
ret = results[i]; // error value, keep iterating to join all threads
else if(i == 0)
ret = results[i]; // initial value
}
return ret; // cppcheck-suppress uninitvar ; false positive, ret assigned at least once in previous loop
}
template<typename TestType>
void perform_unit_test_helper(const ModSecurityTest<UnitTest> &test, UnitTest &t,
ModSecurityTestResults<UnitTest> &res) {
if (!test.m_test_multithreaded)
t.result = perform_unit_test_once<TestType>(t);
else
t.result = perform_unit_test_multithreaded<TestType>(t);
if (TestType::check(t.result, t)) {
res.push_back(&t);
if (test.m_automake_output) {
std::cout << "FAIL ";
}
} else if (test.m_automake_output) {
std::cout << "PASS ";
}
}
void perform_unit_test(const ModSecurityTest<UnitTest> &test, UnitTest &t,
ModSecurityTestResults<UnitTest> &res) {
bool found = true;
if (test->m_automake_output) {
if (test.m_automake_output) {
std::cout << ":test-result: ";
}
if (t->resource.empty() == false) {
found = (std::find(resources.begin(), resources.end(), t->resource)
!= resources.end());
if (t.resource.empty() == false) {
found = std::find(resources.begin(), resources.end(), t.resource)
!= resources.end();
}
if (!found) {
t->skipped = true;
res->push_back(t);
if (test->m_automake_output) {
t.skipped = true;
res.push_back(&t);
if (test.m_automake_output) {
std::cout << "SKIP ";
}
}
if (t->type == "op") {
Operator *op = Operator::instantiate(t->name, t->param);
op->init(t->filename, &error);
int ret = op->evaluate(NULL, NULL, t->input, NULL);
t->obtained = ret;
if (ret != t->ret) {
res->push_back(t);
if (test->m_automake_output) {
std::cout << "FAIL ";
}
} else if (test->m_automake_output) {
std::cout << "PASS ";
}
delete op;
} else if (t->type == "tfn") {
Transformation *tfn = Transformation::instantiate("t:" + t->name);
std::string ret = tfn->evaluate(t->input, NULL);
t->obtained = 1;
t->obtainedOutput = ret;
if (ret != t->output) {
res->push_back(t);
if (test->m_automake_output) {
std::cout << "FAIL ";
}
} else if (test->m_automake_output) {
std::cout << "PASS ";
}
delete tfn;
if (t.type == "op") {
perform_unit_test_helper<OperatorTest>(test, t, res);
} else if (t.type == "tfn") {
perform_unit_test_helper<TransformationTest>(test, t, res);
} else {
std::cerr << "Failed. Test type is unknown: << " << t->type;
std::cerr << "Failed. Test type is unknown: << " << t.type;
std::cerr << std::endl;
}
if (test->m_automake_output) {
std::cout << t->name << " "
<< modsecurity::utils::string::toHexIfNeeded(t->input)
if (test.m_automake_output) {
std::cout << t.name << " "
<< modsecurity::utils::string::toHexIfNeeded(t.input)
<< std::endl;
}
}
@@ -151,17 +239,15 @@ int main(int argc, char **argv) {
test.load_tests("test-cases/secrules-language-tests/transformations");
}
for (std::pair<std::string, std::vector<UnitTest *> *> a : test) {
std::vector<UnitTest *> *tests = a.second;
for (auto& [filename, tests] : test) {
total += tests->size();
for (UnitTest *t : *tests) {
for (auto t : *tests) {
ModSecurityTestResults<UnitTest> r;
if (!test.m_automake_output) {
std::cout << " " << a.first << "...\t";
std::cout << " " << filename << "...\t";
}
perform_unit_test(&test, t, &r);
perform_unit_test(test, *t, r);
if (!test.m_automake_output) {
int skp = 0;
@@ -191,7 +277,7 @@ int main(int argc, char **argv) {
std::cout << "Total >> " << total << std::endl;
}
for (UnitTest *t : results) {
for (const auto t : results) {
std::cout << t->print() << std::endl;
}
@@ -216,8 +302,8 @@ int main(int argc, char **argv) {
}
for (auto a : test) {
auto *vec = a.second;
for(auto *t : *vec)
auto vec = a.second;
for(auto t : *vec)
delete t;
delete vec;
}