sync code

This commit is contained in:
Ned Wright
2026-01-03 18:59:01 +00:00
parent c1058db57d
commit 2105628f05
188 changed files with 8272 additions and 2723 deletions

View File

@@ -150,6 +150,42 @@ makeDirRecursive(const string &path, mode_t permission)
return true;
}
bool
createFileWithContent(const string &dest, const string &content, bool overide_if_exists, mode_t permission)
{
dbgFlow(D_INFRA_UTILS)
<< "Trying to create file with content. Destination: "
<< dest
<< ", Content size: "
<< content.size()
<< ", Should override: "
<< (overide_if_exists? "true" : "false")
<< ", permission: "
<< to_string(permission);
if (exists(dest) && !overide_if_exists) {
dbgDebug(D_INFRA_UTILS) << "Failed to create file. Error: destination file already exists";
return false;
}
ofstream file(dest, ios::out | ios::trunc);
if (!file.is_open()) {
dbgDebug(D_INFRA_UTILS) << "Failed to create file. Error: could not open destination file";
return false;
}
file << content;
file.close();
if (chmod(dest.c_str(), permission) != 0) {
dbgWarning(D_INFRA_UTILS) << "Failed to set file permissions. Path: " << dest;
// Don't return false here as the file was created successfully
}
dbgTrace(D_INFRA_UTILS) << "Successfully created file with content. Path: " << dest;
return true;
}
bool
copyFile(const string &src, const string &dest, bool overide_if_exists, mode_t permission)
{
@@ -546,6 +582,13 @@ toLower(string str)
return str;
}
bool startsWith(const std::string& str, const std::string& prefix) {
if (prefix.size() > str.size()) {
return false;
}
return std::equal(prefix.begin(), prefix.end(), str.begin());
}
} // namespace Strings

View File

@@ -254,3 +254,226 @@ TEST_F(AgentCoreUtilUT, regexReplaceTest)
EXPECT_EQ(replaced, testCase.expected);
}
}
TEST_F(AgentCoreUtilUT, createFileWithContentTest)
{
// Test basic file creation with content
string test_file_path = cptestFnameInExeDir("test_create_file.txt");
string content = "Hello, World!\nThis is test content.";
unsigned int permissions = S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP | S_IROTH;
EXPECT_TRUE(NGEN::Filesystem::createFileWithContent(test_file_path, content, false, permissions));
EXPECT_TRUE(NGEN::Filesystem::exists(test_file_path));
// Verify content was written correctly
ifstream file_stream(test_file_path);
ASSERT_TRUE(file_stream.good());
stringstream buffer;
buffer << file_stream.rdbuf();
string read_content = buffer.str();
file_stream.close();
EXPECT_EQ(read_content, content);
EXPECT_TRUE(NGEN::Filesystem::deleteFile(test_file_path));
}
TEST_F(AgentCoreUtilUT, createFileWithContentEmptyContentTest)
{
// Test creating file with empty content
string test_file_path = cptestFnameInExeDir("test_empty_file.txt");
string empty_content = "";
unsigned int permissions = S_IRUSR | S_IWUSR;
EXPECT_TRUE(NGEN::Filesystem::createFileWithContent(test_file_path, empty_content, false, permissions));
EXPECT_TRUE(NGEN::Filesystem::exists(test_file_path));
// Verify file is empty
ifstream file_stream(test_file_path);
ASSERT_TRUE(file_stream.good());
stringstream buffer;
buffer << file_stream.rdbuf();
string read_content = buffer.str();
file_stream.close();
EXPECT_EQ(read_content, "");
EXPECT_TRUE(NGEN::Filesystem::deleteFile(test_file_path));
}
TEST_F(AgentCoreUtilUT, createFileWithContentLargeContentTest)
{
// Test creating file with large content
string test_file_path = cptestFnameInExeDir("test_large_file.txt");
string large_content;
// Create large content (10KB)
for (int i = 0; i < 1000; ++i) {
large_content += "0123456789";
}
unsigned int permissions = S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH;
EXPECT_TRUE(NGEN::Filesystem::createFileWithContent(test_file_path, large_content, false, permissions));
EXPECT_TRUE(NGEN::Filesystem::exists(test_file_path));
// Verify content was written correctly
ifstream file_stream(test_file_path);
ASSERT_TRUE(file_stream.good());
stringstream buffer;
buffer << file_stream.rdbuf();
string read_content = buffer.str();
file_stream.close();
EXPECT_EQ(read_content.size(), large_content.size());
EXPECT_EQ(read_content, large_content);
EXPECT_TRUE(NGEN::Filesystem::deleteFile(test_file_path));
}
TEST_F(AgentCoreUtilUT, createFileWithContentOverwriteFalseTest)
{
// Test that overwrite=false prevents overwriting existing file
string test_file_path = cptestFnameInExeDir("test_no_overwrite.txt");
string original_content = "Original content";
string new_content = "New content that should not be written";
unsigned int permissions = S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP;
// Create initial file
EXPECT_TRUE(NGEN::Filesystem::createFileWithContent(test_file_path, original_content, false, permissions));
EXPECT_TRUE(NGEN::Filesystem::exists(test_file_path));
// Try to overwrite with overwrite=false (should fail)
EXPECT_FALSE(NGEN::Filesystem::createFileWithContent(test_file_path, new_content, false, permissions));
// Verify original content is preserved
ifstream file_stream(test_file_path);
ASSERT_TRUE(file_stream.good());
stringstream buffer;
buffer << file_stream.rdbuf();
string read_content = buffer.str();
file_stream.close();
EXPECT_EQ(read_content, original_content);
EXPECT_TRUE(NGEN::Filesystem::deleteFile(test_file_path));
}
TEST_F(AgentCoreUtilUT, createFileWithContentOverwriteTrueTest)
{
// Test that overwrite=true allows overwriting existing file
string test_file_path = cptestFnameInExeDir("test_with_overwrite.txt");
string original_content = "Original content";
string new_content = "New content that should overwrite";
unsigned int permissions = S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP;
// Create initial file
EXPECT_TRUE(NGEN::Filesystem::createFileWithContent(test_file_path, original_content, false, permissions));
EXPECT_TRUE(NGEN::Filesystem::exists(test_file_path));
// Overwrite with overwrite=true (should succeed)
EXPECT_TRUE(NGEN::Filesystem::createFileWithContent(test_file_path, new_content, true, permissions));
// Verify new content was written
ifstream file_stream(test_file_path);
ASSERT_TRUE(file_stream.good());
stringstream buffer;
buffer << file_stream.rdbuf();
string read_content = buffer.str();
file_stream.close();
EXPECT_EQ(read_content, new_content);
EXPECT_TRUE(NGEN::Filesystem::deleteFile(test_file_path));
}
TEST_F(AgentCoreUtilUT, createFileWithContentDifferentPermissionsTest)
{
// Test creating files with different permission sets
string test_file_path = cptestFnameInExeDir("test_permissions.txt");
string content = "Content for permission test";
// Test read-only permissions
unsigned int read_only_permissions = S_IRUSR | S_IRGRP | S_IROTH;
EXPECT_TRUE(NGEN::Filesystem::createFileWithContent(test_file_path, content, false, read_only_permissions));
EXPECT_TRUE(NGEN::Filesystem::exists(test_file_path));
EXPECT_TRUE(NGEN::Filesystem::deleteFile(test_file_path));
// Test read-write permissions
unsigned int read_write_permissions = S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP;
EXPECT_TRUE(NGEN::Filesystem::createFileWithContent(test_file_path, content, false, read_write_permissions));
EXPECT_TRUE(NGEN::Filesystem::exists(test_file_path));
EXPECT_TRUE(NGEN::Filesystem::deleteFile(test_file_path));
// Test full permissions
unsigned int full_permissions;
full_permissions = S_IRUSR | S_IWUSR | S_IXUSR | S_IRGRP | S_IWGRP | S_IXGRP | S_IROTH | S_IWOTH | S_IXOTH;
EXPECT_TRUE(NGEN::Filesystem::createFileWithContent(test_file_path, content, false, full_permissions));
EXPECT_TRUE(NGEN::Filesystem::exists(test_file_path));
EXPECT_TRUE(NGEN::Filesystem::deleteFile(test_file_path));
}
TEST_F(AgentCoreUtilUT, createFileWithContentSpecialCharactersTest)
{
// Test creating file with special characters in content
string test_file_path = cptestFnameInExeDir("test_special_chars.txt");
string content = "Special chars: !@#$%^&*()[]{}|\\:;\"'<>?,./\n\t\r\n";
content += "Unicode: \u00A9 \u00AE \u2122\n";
content += "Null byte in middle: before\0after\n";
unsigned int permissions = S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP;
EXPECT_TRUE(NGEN::Filesystem::createFileWithContent(test_file_path, content, false, permissions));
EXPECT_TRUE(NGEN::Filesystem::exists(test_file_path));
// Verify content (note: null byte will terminate string reading early)
ifstream file_stream(test_file_path, ios::binary);
ASSERT_TRUE(file_stream.good());
stringstream buffer;
buffer << file_stream.rdbuf();
string read_content = buffer.str();
file_stream.close();
// Content should match up to the special characters
EXPECT_THAT(read_content, HasSubstr("Special chars: !@#$%^&*()[]{}|\\:;\"'<>?,./"));
EXPECT_TRUE(NGEN::Filesystem::deleteFile(test_file_path));
}
TEST_F(AgentCoreUtilUT, createFileWithContentInvalidPathTest)
{
// Test creating file with invalid/inaccessible path
string invalid_path = "/root/inaccessible/path/test_file.txt";
string content = "This should fail";
unsigned int permissions = S_IRUSR | S_IWUSR;
// This should fail due to invalid/inaccessible path
EXPECT_FALSE(NGEN::Filesystem::createFileWithContent(invalid_path, content, false, permissions));
EXPECT_FALSE(NGEN::Filesystem::exists(invalid_path));
}
TEST_F(AgentCoreUtilUT, createFileWithContentCreateDirectoryTest)
{
// Test creating file in a subdirectory (should create intermediate directories)
string test_dir = cptestFnameInExeDir("test_subdir");
string test_file_path = test_dir + "/nested_file.txt";
string content = "Content in nested directory";
unsigned int permissions = S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP;
// Ensure directory doesn't exist initially
EXPECT_FALSE(NGEN::Filesystem::exists(test_dir));
// Create directory first
EXPECT_TRUE(NGEN::Filesystem::makeDir(test_dir));
// Now create file in the directory
EXPECT_TRUE(NGEN::Filesystem::createFileWithContent(test_file_path, content, false, permissions));
EXPECT_TRUE(NGEN::Filesystem::exists(test_file_path));
// Verify content
ifstream file_stream(test_file_path);
ASSERT_TRUE(file_stream.good());
stringstream buffer;
buffer << file_stream.rdbuf();
string read_content = buffer.str();
file_stream.close();
EXPECT_EQ(read_content, content);
// Cleanup
EXPECT_TRUE(NGEN::Filesystem::deleteFile(test_file_path));
EXPECT_TRUE(NGEN::Filesystem::deleteDirectory(test_dir));
}

View File

@@ -19,6 +19,7 @@
#include <sys/stat.h>
#include <boost/regex.hpp>
#include <boost/algorithm/string.hpp>
#include "i_messaging.h"
#include "config.h"
#include "debug.h"
@@ -64,19 +65,37 @@ AgentDetails::init()
writeAgentDetails();
}
previous_proxy = proxy;
registerConfigLoadCb(
[&]()
{
auto proxy_config = getProfileAgentSetting<string>("agent.config.message.proxy");
if (proxy_config.ok()) {
is_proxy_configured_via_settings = true;
setProxy(*proxy_config);
writeAgentDetails();
} else if (is_proxy_configured_via_settings) {
is_proxy_configured_via_settings = false;
setProxy(string(""));
writeAgentDetails();
auto load_env_proxy = loadProxy();
auto proxy_config = getProxy();
if (proxy != previous_proxy) {
dbgInfo(D_ORCHESTRATOR)
<< "Proxy configuration changed from '"
<< previous_proxy
<< "' to '"
<< proxy
<< "'";
auto messaging = Singleton::Consume<I_Messaging>::by<AgentDetails>();
messaging->clearConnections();
}
if (!proxy_config.ok() || proxy_config.unpack() == "none") {
auto proxy_config = getProfileAgentSetting<string>("agent.config.message.proxy");
if (proxy_config.ok()) {
is_proxy_configured_via_settings = true;
setProxy(*proxy_config);
writeAgentDetails();
} else if (is_proxy_configured_via_settings) {
is_proxy_configured_via_settings = false;
setProxy(string(""));
writeAgentDetails();
}
}
previous_proxy = proxy;
}
);

View File

@@ -5,6 +5,7 @@
#include "mock/mock_encryptor.h"
#include "mock/mock_shell_cmd.h"
#include "mock/mock_messaging.h"
#include "mock/mock_mainloop.h"
#include "cptest.h"
#include "config.h"
@@ -25,6 +26,7 @@ public:
::Environment env;
ConfigComponent conf;
StrictMock<MockMessaging> mock_messaging;
StrictMock<MockEncryptor> mock_encryptor;
StrictMock<MockShellCmd> mock_shell_cmd;
Config::I_Config *config = nullptr;

View File

@@ -179,6 +179,12 @@ AgentDetailsReporter::Impl::addAttr(const string &key, const string &val, bool a
}
}
if (val.empty()) {
deleteAttr(key);
dbgDebug(D_AGENT_DETAILS) << "Attribute " << key << " was empty, deleting";
return true;
}
if (persistant_attributes[key] == val) {
dbgDebug(D_AGENT_DETAILS) << "Attribute " << key << " did not change. Value: " << val;
return true;

View File

@@ -106,7 +106,7 @@ HttpAttachmentConfiguration::save(cereal::JSONOutputArchive &archive) const
"waiting_for_verdict_thread_timeout_msec",
getNumericalValue("waiting_for_verdict_thread_timeout_msec")
),
cereal::make_nvp("nginx_inspection_mode", getNumericalValue("inspection_mode")),
cereal::make_nvp("nginx_inspection_mode", getNumericalValue("nginx_inspection_mode")),
cereal::make_nvp("num_of_nginx_ipc_elements", getNumericalValue("num_of_nginx_ipc_elements")),
cereal::make_nvp("keep_alive_interval_msec", getNumericalValue("keep_alive_interval_msec")),
cereal::make_nvp("min_retries_for_verdict", getNumericalValue("min_retries_for_verdict")),
@@ -114,7 +114,12 @@ HttpAttachmentConfiguration::save(cereal::JSONOutputArchive &archive) const
cereal::make_nvp("hold_verdict_retries", getNumericalValue("hold_verdict_retries")),
cereal::make_nvp("hold_verdict_polling_time", getNumericalValue("hold_verdict_polling_time")),
cereal::make_nvp("body_size_trigger", getNumericalValue("body_size_trigger")),
cereal::make_nvp("remove_server_header", getNumericalValue("remove_server_header"))
cereal::make_nvp("remove_server_header", getNumericalValue("remove_server_header")),
cereal::make_nvp("decompression_pool_size", getNumericalValue("decompression_pool_size")),
cereal::make_nvp("recompression_pool_size", getNumericalValue("recompression_pool_size")),
cereal::make_nvp("is_paired_affinity_enabled", getNumericalValue("is_paired_affinity_enabled")),
cereal::make_nvp("is_async_mode_enabled", getNumericalValue("is_async_mode_enabled")),
cereal::make_nvp("is_brotli_inspection_enabled", getNumericalValue("is_brotli_inspection_enabled"))
);
}
@@ -173,6 +178,21 @@ HttpAttachmentConfiguration::load(cereal::JSONInputArchive &archive)
loadNumericalValue(archive, "hold_verdict_polling_time", 1);
loadNumericalValue(archive, "body_size_trigger", 200000);
loadNumericalValue(archive, "remove_server_header", 0);
loadNumericalValue(archive, "decompression_pool_size", 262144);
loadNumericalValue(archive, "recompression_pool_size", 16384);
loadNumericalValue(archive, "is_paired_affinity_enabled", 0);
loadNumericalValue(archive, "is_brotli_inspection_enabled", 0);
int g_env_async_mode = 1;
char *env_async_mode = getenv("CP_ASYNC_MODE");
if (env_async_mode != NULL) {
if (strcmp(env_async_mode, "true") == 0 || strcmp(env_async_mode, "1") == 0) {
g_env_async_mode = 1;
} else {
g_env_async_mode = 0;
}
}
loadNumericalValue(archive, "is_async_mode_enabled", g_env_async_mode);
}
bool

View File

@@ -64,7 +64,12 @@ TEST_F(HttpAttachmentUtilTest, GetValidAttachmentConfiguration)
"\"req_header_thread_timeout_msec\": 10,\n"
"\"ip_ranges\": " + createIPRangesString(ip_ranges) + ",\n"
"\"static_resources_path\": \"" + static_resources_path + "\",\n"
"\"remove_server_header\": 0"
"\"remove_server_header\": 0,\n"
"\"decompression_pool_size\": 524288,\n"
"\"recompression_pool_size\": 32768,\n"
"\"is_paired_affinity_enabled\": 0,\n"
"\"is_async_mode_enabled\": 0,\n"
"\"is_brotli_inspection_enabled\": 1\n"
"}\n";
ofstream valid_configuration_file(attachment_configuration_file_name);
valid_configuration_file << valid_configuration;
@@ -91,6 +96,11 @@ TEST_F(HttpAttachmentUtilTest, GetValidAttachmentConfiguration)
EXPECT_EQ(conf_data_out.getNumericalValue("waiting_for_verdict_thread_timeout_msec"), 60u);
EXPECT_EQ(conf_data_out.getNumericalValue("nginx_inspection_mode"), 1u);
EXPECT_EQ(conf_data_out.getNumericalValue("remove_server_header"), 0u);
EXPECT_EQ(conf_data_out.getNumericalValue("decompression_pool_size"), 524288u);
EXPECT_EQ(conf_data_out.getNumericalValue("recompression_pool_size"), 32768u);
EXPECT_EQ(conf_data_out.getNumericalValue("is_paired_affinity_enabled"), 0u);
EXPECT_EQ(conf_data_out.getNumericalValue("is_async_mode_enabled"), 0u);
EXPECT_EQ(conf_data_out.getNumericalValue("is_brotli_inspection_enabled"), 1u);
}
TEST_F(HttpAttachmentUtilTest, GetMalformedAttachmentConfiguration)

View File

@@ -22,6 +22,8 @@
#include <strings.h>
#include <string.h>
#include <zlib.h>
#include <brotli/encode.h>
#include <brotli/decode.h>
using namespace std;
@@ -29,6 +31,10 @@ using DebugFunction = void(*)(const char *);
static const int max_debug_level = static_cast<int>(CompressionUtilsDebugLevel::COMPRESSION_DBG_LEVEL_ASSERTION);
static const int max_retries = 3;
static const size_t default_brotli_buffer_size = 16384;
static const size_t brotli_decompression_probe_size = 64;
static void
defaultPrint(const char *debug_message)
{
@@ -104,12 +110,23 @@ static const int zlib_no_flush = Z_NO_FLUSH;
struct CompressionStream
{
CompressionStream() { bzero(&stream, sizeof(z_stream)); }
CompressionStream()
:
br_encoder_state(nullptr),
br_decoder_state(nullptr)
{
bzero(&stream, sizeof(z_stream));
}
~CompressionStream() { fini(); }
tuple<basic_string<unsigned char>, bool>
decompress(const unsigned char *data, uint32_t size)
{
if (state == TYPE::UNINITIALIZED && size > 0 && isBrotli(data, size)) return decompressBrotli(data, size);
if (state == TYPE::DECOMPRESS_BROTLI) return decompressBrotli(data, size);
initInflate();
if (state != TYPE::DECOMPRESS) throw runtime_error("Could not start decompression");
@@ -138,7 +155,7 @@ struct CompressionStream
res.append(work_space.data(), stream.total_out - old_total_out);
} else {
++retries;
if (retries > 3) {
if (retries > max_retries) {
fini();
throw runtime_error("No results from inflate more than three times");
}
@@ -156,6 +173,7 @@ struct CompressionStream
basic_string<unsigned char>
compress(CompressionType type, const unsigned char *data, uint32_t size, int is_last_chunk)
{
if (type == CompressionType::BROTLI) return compressBrotli(data, size, is_last_chunk);
initDeflate(type);
if (state != TYPE::COMPRESS) throw runtime_error("Could not start compression");
@@ -183,7 +201,7 @@ struct CompressionStream
res.append(work_space.data(), stream.total_out - old_total_out);
} else {
++retries;
if (retries > 3) {
if (retries > max_retries) {
fini();
throw runtime_error("No results from deflate more than three times");
}
@@ -201,7 +219,7 @@ private:
void
initInflate()
{
if (state != TYPE::UNINITIALIZAED) return;
if (state != TYPE::UNINITIALIZED) return;
auto init_status = inflateInit2(&stream, default_num_window_bits + 32);
if (init_status != zlib_ok_return_value) {
@@ -216,7 +234,7 @@ private:
void
initDeflate(CompressionType type)
{
if (state != TYPE::UNINITIALIZAED) return;
if (state != TYPE::UNINITIALIZED) return;
int num_history_window_bits;
switch (type) {
@@ -228,6 +246,10 @@ private:
num_history_window_bits = default_num_window_bits;
break;
}
case CompressionType::BROTLI: {
zlibDbgAssertion << "Brotli compression should use compressBrotli()";
return;
}
default: {
zlibDbgAssertion
<< "Invalid compression type value: "
@@ -253,6 +275,190 @@ private:
state = TYPE::COMPRESS;
}
basic_string<unsigned char>
compressBrotli(const unsigned char *data, uint32_t size, int is_last_chunk)
{
if (state == TYPE::UNINITIALIZED) {
br_encoder_state = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr);
if (!br_encoder_state) throw runtime_error("Failed to create Brotli encoder state");
BrotliEncoderSetParameter(br_encoder_state, BROTLI_PARAM_QUALITY, BROTLI_DEFAULT_QUALITY);
BrotliEncoderSetParameter(br_encoder_state, BROTLI_PARAM_LGWIN, BROTLI_DEFAULT_WINDOW);
state = TYPE::COMPRESS_BROTLI;
} else if (state != TYPE::COMPRESS_BROTLI) {
throw runtime_error("Compression stream in inconsistent state for Brotli compression");
}
basic_string<unsigned char> output;
vector<uint8_t> buffer(16384);
int retries = 0;
const uint8_t* next_in = data;
size_t available_in = size;
while (available_in > 0 || is_last_chunk) {
size_t available_out = buffer.size();
uint8_t* next_out = buffer.data();
BrotliEncoderOperation op = is_last_chunk ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS;
auto brotli_success = BrotliEncoderCompressStream(
br_encoder_state,
op,
&available_in,
&next_in,
&available_out,
&next_out,
nullptr
);
if (brotli_success == BROTLI_FALSE) {
fini();
throw runtime_error("Brotli compression error");
}
size_t bytes_written = buffer.size() - available_out;
if (bytes_written > 0) {
output.append(buffer.data(), bytes_written);
retries = 0;
} else {
retries++;
if (retries > max_retries) {
fini();
throw runtime_error("Brotli compression error: Exceeded retry limit.");
}
}
if (BrotliEncoderIsFinished(br_encoder_state)) break;
if (available_in == 0 && !is_last_chunk) break;
}
if (is_last_chunk) fini();
return output;
}
tuple<basic_string<unsigned char>, bool>
decompressBrotli(const unsigned char *data, uint32_t size)
{
if (state != TYPE::DECOMPRESS_BROTLI) {
br_decoder_state = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr);
if (!br_decoder_state) throw runtime_error("Failed to create Brotli decoder state");
BrotliDecoderSetParameter(br_decoder_state, BROTLI_DECODER_PARAM_LARGE_WINDOW, 1u);
state = TYPE::DECOMPRESS_BROTLI;
}
basic_string<unsigned char> output;
const uint8_t* next_in = data;
size_t available_in = size;
size_t buffer_size = max<size_t>(size * 4, default_brotli_buffer_size);
vector<uint8_t> buffer(buffer_size);
// Use a constant ratio for max buffer size relative to input size
const size_t max_buffer_size = 256 * 1024 * 1024; // 256 MB max buffer size
while (true) {
size_t available_out = buffer.size();
uint8_t* next_out = buffer.data();
BrotliDecoderResult result = BrotliDecoderDecompressStream(
br_decoder_state,
&available_in,
&next_in,
&available_out,
&next_out,
nullptr
);
if (result == BROTLI_DECODER_RESULT_ERROR) {
fini();
auto error_msg = string(BrotliDecoderErrorString(BrotliDecoderGetErrorCode(br_decoder_state)));
throw runtime_error("Brotli decompression error: " + error_msg);
}
// Handle any produced output
size_t bytes_produced = buffer.size() - available_out;
if (bytes_produced > 0) {
output.append(buffer.data(), bytes_produced);
}
if (result == BROTLI_DECODER_RESULT_SUCCESS) {
bool is_finished = BrotliDecoderIsFinished(br_decoder_state);
if (is_finished) fini();
return make_tuple(output, is_finished);
}
if (result == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) {
// Check if we've exceeded the maximum buffer size limit
if (buffer.size() >= max_buffer_size) {
fini();
throw runtime_error("Brotli decompression buffer size limit exceeded - possibly corrupted data");
}
// Resize buffer to accommodate more output
size_t new_size = min(buffer.size() * 2, max_buffer_size);
buffer.resize(new_size);
continue; // Continue with the same input, new buffer
}
// If we reach here, we need more input but have no more to provide
if (available_in == 0) {
// No more input data available, return what we have so far
return make_tuple(output, false);
}
}
return make_tuple(output, false);
}
bool
isBrotli(const unsigned char *data, uint32_t size)
{
if (size < 4) return false;
BrotliDecoderState* test_decoder = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr);
if (!test_decoder) return false;
const uint8_t* next_in = data;
size_t available_in = min<size_t>(size, brotli_decompression_probe_size);
uint8_t output[brotli_decompression_probe_size];
size_t available_out = sizeof(output);
uint8_t* next_out = output;
BrotliDecoderResult result = BrotliDecoderDecompressStream(
test_decoder,
&available_in,
&next_in,
&available_out,
&next_out,
nullptr
);
bool is_brotli = false;
if (
result != BROTLI_DECODER_RESULT_ERROR &&
(
available_out < sizeof(output) ||
available_in < min<size_t>(size, brotli_decompression_probe_size)
)
) {
is_brotli = true;
}
BrotliDecoderDestroyInstance(test_decoder);
if (is_brotli) {
br_decoder_state = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr);
BrotliDecoderSetParameter(br_decoder_state, BROTLI_DECODER_PARAM_LARGE_WINDOW, 1u);
state = TYPE::DECOMPRESS_BROTLI;
return true;
}
return false;
}
void
fini()
{
@@ -261,11 +467,21 @@ private:
if (state == TYPE::DECOMPRESS) end_stream_res = inflateEnd(&stream);
if (state == TYPE::COMPRESS) end_stream_res = deflateEnd(&stream);
if (end_stream_res != zlib_ok_return_value) {
if (br_encoder_state) {
BrotliEncoderDestroyInstance(br_encoder_state);
br_encoder_state = nullptr;
}
if (br_decoder_state) {
BrotliDecoderDestroyInstance(br_decoder_state);
br_decoder_state = nullptr;
}
if (end_stream_res != zlib_ok_return_value && end_stream_res != Z_DATA_ERROR) {
zlibDbgError << "Failed to clean state: " << getZlibError(end_stream_res);
}
state = TYPE::UNINITIALIZAED;
state = TYPE::UNINITIALIZED;
}
string
@@ -288,7 +504,16 @@ private:
}
z_stream stream;
enum class TYPE { UNINITIALIZAED, COMPRESS, DECOMPRESS } state = TYPE::UNINITIALIZAED;
enum class TYPE {
UNINITIALIZED,
COMPRESS,
DECOMPRESS,
COMPRESS_BROTLI,
DECOMPRESS_BROTLI
} state = TYPE::UNINITIALIZED;
BrotliEncoderState* br_encoder_state = nullptr;
BrotliDecoderState* br_decoder_state = nullptr;
};
void

View File

@@ -284,6 +284,25 @@ private:
const string test_files_dir_name = "test_files";
};
TEST_F(CompressionUtilsTest, BrotliBufferLimitTest)
{
// Create a large string with highly compressible data that will expand significantly
// This should trigger the buffer size limit if max_ratio is too restrictive
string large_compressible_string = string(10000, 'A') + string(10000, 'B') + "CCCCCC"; // 10KB of 'A' characters
Maybe<string> compressed = compressString(CompressionType::BROTLI, large_compressible_string);
EXPECT_TRUE(compressed.ok());
cout << "Compressed size: " << compressed.unpack().size() << ". compression ratio: "
<< static_cast<double>(large_compressible_string.size()) / compressed.unpack().size() << endl;
// This decompression should fail if max_ratio is too restrictive (like 1)
// because decompressed data (100KB) will be much larger than compressed data
Maybe<string> decompressed = decompressString(compressed.unpack());
ASSERT_TRUE(decompressed.ok());
EXPECT_EQ(large_compressible_string, decompressed.unpack());
}
TEST_F(CompressionUtilsTest, CompressAndDecompressSimpleString)
{
for (auto single_compression_type : compression_types) {
@@ -460,3 +479,177 @@ TEST_F(CompressionUtilsTest, DecompressPlainText)
HasSubstr("error in 'inflate': Invalid or corrupted stream data")
);
}
TEST_F(CompressionUtilsTest, BrotliCompressAndDecompressSimpleString)
{
Maybe<string> compressed_string_maybe = compressString(
CompressionType::BROTLI,
simple_test_string
);
EXPECT_TRUE(compressed_string_maybe.ok());
Maybe<string> decompressed_string_maybe = decompressString(compressed_string_maybe.unpack());
EXPECT_TRUE(decompressed_string_maybe.ok());
EXPECT_EQ(simple_test_string, decompressed_string_maybe.unpack());
}
TEST_F(CompressionUtilsTest, BrotliCompressAndDecompressChunkSizedString)
{
string test_string = readTestFileContents(chunk_sized_string_file_name);
Maybe<string> compressed_string_maybe = compressString(
CompressionType::BROTLI,
test_string
);
EXPECT_TRUE(compressed_string_maybe.ok());
Maybe<string> decompressed_string_maybe = decompressString(compressed_string_maybe.unpack());
EXPECT_TRUE(decompressed_string_maybe.ok());
EXPECT_EQ(test_string, decompressed_string_maybe.unpack());
}
TEST_F(CompressionUtilsTest, BrotliCompressMultipleChunkSizedStringAndDecompress)
{
string test_string = readTestFileContents(multi_chunk_sized_string_file_name);
Maybe<string> chunked_compress_result = chunkedCompressString(CompressionType::BROTLI, test_string);
EXPECT_TRUE(chunked_compress_result.ok());
Maybe<string> chunked_decompress_result = chunkedDecompressString(chunked_compress_result.unpack());
EXPECT_TRUE(chunked_decompress_result.ok());
EXPECT_EQ(chunked_decompress_result.unpack(), test_string);
}
TEST_F(CompressionUtilsTest, BrotliEmptyBuffer)
{
auto compression_stream = initCompressionStream();
stringstream compressed_stream;
Maybe<string> compressed_string = compressString(
CompressionType::BROTLI,
simple_test_string,
false,
compression_stream
);
EXPECT_TRUE(compressed_string.ok());
compressed_stream << compressed_string.unpack();
compressed_string = compressString(
CompressionType::BROTLI,
"",
true,
compression_stream
);
finiCompressionStream(compression_stream);
EXPECT_TRUE(compressed_string.ok());
compressed_stream << compressed_string.unpack();
int is_last_chunk;
auto decompression_stream = initCompressionStream();
Maybe<string> decompressed_string = decompressString(
compressed_stream.str(),
&is_last_chunk,
decompression_stream
);
EXPECT_TRUE(decompressed_string.ok());
EXPECT_EQ(decompressed_string.unpack(), simple_test_string);
finiCompressionStream(decompression_stream);
}
TEST_F(CompressionUtilsTest, BrotliCompressionRatio)
{
// Test if Brotli provides reasonable compression ratio for highly compressible content
string test_string = string(10000, 'A');
Maybe<string> gzip_compressed = compressString(CompressionType::GZIP, test_string);
Maybe<string> brotli_compressed = compressString(CompressionType::BROTLI, test_string);
EXPECT_TRUE(gzip_compressed.ok());
EXPECT_TRUE(brotli_compressed.ok());
// Both should compress well
EXPECT_LT(gzip_compressed.unpack().size(), test_string.size() / 10);
EXPECT_LT(brotli_compressed.unpack().size(), test_string.size() / 10);
}
TEST_F(CompressionUtilsTest, BrotliVariousSizedPayloads)
{
const vector<string> test_strings = {
"", // Empty string
"a", // Single character
"Hello, Brotli compression!", // Short string
string(1024, 'A'), // 1KB of repeating data
readTestFileContents(chunk_sized_string_file_name) // Test file
};
for (const auto& test_string : test_strings) {
Maybe<string> compressed_string_maybe = compressString(
CompressionType::BROTLI,
test_string
);
EXPECT_TRUE(compressed_string_maybe.ok());
Maybe<string> decompressed_string_maybe = decompressString(compressed_string_maybe.unpack());
EXPECT_TRUE(decompressed_string_maybe.ok());
EXPECT_EQ(test_string, decompressed_string_maybe.unpack());
}
}
TEST_F(CompressionUtilsTest, ExceptionHandling_NullDataPointer)
{
auto compression_stream = initCompressionStream();
CompressionResult result = compressData(
compression_stream,
CompressionType::GZIP,
100,
nullptr, // Null data pointer
1
);
EXPECT_EQ(result.ok, 0);
EXPECT_THAT(
capture_debug.str(),
HasSubstr("Compression failed Data pointer is NULL")
);
finiCompressionStream(compression_stream);
}
TEST_F(CompressionUtilsTest, ExceptionHandling_InvalidDecompressionData)
{
auto compression_stream = initCompressionStream();
unsigned char invalid_data[] = "This is not compressed data";
DecompressionResult result = decompressData(
compression_stream,
sizeof(invalid_data),
invalid_data
);
EXPECT_EQ(result.ok, 0);
EXPECT_THAT(
capture_debug.str(),
AnyOf(HasSubstr("Decompression failed"), HasSubstr("error in 'inflate'"))
);
finiCompressionStream(compression_stream);
}
TEST_F(CompressionUtilsTest, ExceptionHandling_ResourceCleanup)
{
// Verify no memory leaks by creating and destroying multiple streams
// with failing operations
unsigned char invalid_data[] = "This is not compressed data";
for (int i = 0; i < 10; i++) {
auto temp_stream = initCompressionStream();
decompressData(temp_stream, 5, invalid_data); // This should fail
finiCompressionStream(temp_stream);
}
// No crashes = success
}

View File

@@ -1,4 +1,4 @@
add_library(config config.cc config_specific.cc config_globals.cc)
add_library(config config.cc config_specific.cc config_globals.cc config_cache_stats.cc)
target_link_libraries(config agent_core_utilities)
link_directories(${BOOST_ROOT}/lib)

View File

@@ -19,6 +19,8 @@
#include <fstream>
#include <iostream>
#include <cctype>
#include <boost/uuid/uuid_generators.hpp>
#include <boost/uuid/uuid_io.hpp>
#include "agent_core_utilities.h"
#include "cereal/archives/json.hpp"
@@ -109,7 +111,7 @@ public:
void init();
const TypeWrapper & getConfiguration(const vector<string> &paths) const override;
PerContextValue getAllConfiguration(const std::vector<std::string> &paths) const;
PerContextValue getAllConfiguration(const vector<string> &paths) const;
const TypeWrapper & getResource(const vector<string> &paths) const override;
const TypeWrapper & getSetting(const vector<string> &paths) const override;
string getProfileAgentSetting(const string &setting_name) const override;
@@ -120,15 +122,22 @@ public:
const string & getFilesystemPathConfig() const override;
const string & getLogFilesPathConfig() const override;
bool
isConfigCacheEnabled() const override
{
return is_cache_enabled && !policy_load_id.empty();
}
const string & getPolicyLoadId() const override { return policy_load_id; }
string getPolicyConfigPath(
const string &name,
ConfigFileType type,
const string &tenant = "",
const string &profile = "") const override;
bool setConfiguration(TypeWrapper &&value, const std::vector<std::string> &paths) override;
bool setResource(TypeWrapper &&value, const std::vector<std::string> &paths) override;
bool setSetting(TypeWrapper &&value, const std::vector<std::string> &paths) override;
bool setConfiguration(TypeWrapper &&value, const vector<string> &paths) override;
bool setResource(TypeWrapper &&value, const vector<string> &paths) override;
bool setSetting(TypeWrapper &&value, const vector<string> &paths) override;
void registerExpectedConfigFile(const string &file_name, ConfigFileType type) override;
void registerExpectedConfiguration(unique_ptr<GenericConfig<true>> &&config) override;
@@ -145,6 +154,15 @@ public:
void registerConfigLoadCb(ConfigCb) override;
void registerConfigAbortCb(ConfigCb) override;
void clearOldTenants() override;
void resetConfigCache() override;
// Cache statistics interface implementation
uint64_t getCacheHits() const override;
uint64_t getCacheMisses() const override;
void resetCacheStats() override;
void enableCacheTracking() override;
void disableCacheTracking() override;
bool isCacheTrackingEnabled() const override;
private:
bool areTenantAndProfileActive(const TenantProfilePair &tenant_profile) const;
@@ -157,7 +175,8 @@ private:
void reloadConfigurationContinuesWrapper(const string &version, uint id);
vector<string> fillMultiTenantConfigFiles(const map<string, set<string>> &tenants);
vector<string> fillMultiTenantExpectedConfigFiles(const map<string, set<string>> &tenants);
map<string, string> getProfileAgentSetting() const;
map<string, string> & getProfileAgentSetting() const;
void resolveVsId() const;
string
@@ -234,7 +253,7 @@ private:
secondary_port_req_md
);
}
if (!service_config_status.ok()) {
dbgWarning(D_CONFIG)
<< "Could not send configuration to orchestrator 7778, error: "
@@ -343,6 +362,12 @@ private:
string default_config_directory_path = "/conf/";
string config_directory_path = "";
string error_to_report = "";
string current_policy_version = "";
mutable size_t cached_last_policy_count = 0;
mutable map<string, string> cached_agent_settings;
size_t policy_load_count = 0;
string policy_load_id = "";
bool is_cache_enabled = false;
TypeWrapper empty;
};
@@ -350,6 +375,7 @@ private:
void
ConfigComponent::Impl::preload()
{
resetConfigCache();
I_Environment *environment = Singleton::Consume<I_Environment>::by<ConfigComponent>();
auto executable = environment->get<string>("Base Executable Name");
if (!executable.ok() || *executable == "") {
@@ -733,6 +759,7 @@ ConfigComponent::Impl::periodicRegistrationRefresh()
{
I_Environment *environment = Singleton::Consume<I_Environment>::by<ConfigComponent>();
I_MainLoop *mainloop = Singleton::Consume<I_MainLoop>::by<ConfigComponent>();
I_Environment *env = Singleton::Consume<I_Environment>::by<ConfigComponent>();
while (true) {
auto env_listening_port = environment->get<int>("Listening Port");
@@ -742,15 +769,17 @@ ConfigComponent::Impl::periodicRegistrationRefresh()
<< "Internal rest server listening port is not yet set."
<< " Setting retry attempt to 500 milliseconds from now";
mainloop->yield(chrono::milliseconds(500));
} else if (!sendOrchestatorConfMsg(env_listening_port.unpack())) {
mainloop->yield(chrono::milliseconds(500));
} else {
} else if (sendOrchestatorConfMsg(env_listening_port.unpack())) {
dbgInfo(D_CONFIG) << "Configuration update registration with orchestrator succeeded.";
env->registerValue<bool>("isRegisteredWithOrchestrator", true);
uint next_iteration_in_sec = getConfigurationWithDefault<uint>(
600,
"Config Component",
"Refresh config update registration time interval"
);
mainloop->yield(chrono::seconds(next_iteration_in_sec));
} else {
mainloop->yield(chrono::milliseconds(500));
}
}
}
@@ -758,6 +787,7 @@ ConfigComponent::Impl::periodicRegistrationRefresh()
bool
ConfigComponent::Impl::loadConfiguration(vector<shared_ptr<JSONInputArchive>> &file_archives, bool is_async)
{
is_cache_enabled = false;
auto mainloop = is_async ? Singleton::Consume<I_MainLoop>::by<ConfigComponent>() : nullptr;
for (auto &cb : configuration_prepare_cbs) {
@@ -828,6 +858,15 @@ ConfigComponent::Impl::commitSuccess()
for (auto &cb : configuration_commit_cbs) {
cb();
}
policy_load_count++;
try {
policy_load_id = to_string(boost::uuids::random_generator()());
} catch (const boost::uuids::entropy_error &e) {
dbgWarning(D_CONFIG) << "Failed to create random id for policy_load_id";
policy_load_id = "";
}
is_cache_enabled = true;
initializeCacheTracking();
return true;
}
@@ -842,6 +881,7 @@ ConfigComponent::Impl::commitFailure(const string &error)
for (auto &cb : configuration_abort_cbs) {
cb();
}
policy_load_id = "";
return false;
}
@@ -933,24 +973,38 @@ ConfigComponent::Impl::reloadConfigurationImpl(const string &version, bool is_as
env->registerValue<bool>("Is Async Config Load", is_async);
bool res = loadConfiguration(archives, is_async);
env->unregisterKey<bool>("Is Async Config Load");
if (res) env->registerValue<string>("Current Policy Version", version);
if (res) {
env->registerValue<string>("Current Policy Version", version);
current_policy_version = version;
dbgTrace(D_CONFIG) << "Successfully loaded configuration. Version: " << version
<< ", Policy Load Count: " << policy_load_count
<< ", Policy Load ID: " << policy_load_id;
}
return res;
}
map<string, string>
map<string, string> &
ConfigComponent::Impl::getProfileAgentSetting() const
{
auto general_sets = getSettingWithDefault(AgentProfileSettings::default_profile_settings, "generalAgentSettings");
if (!is_cache_enabled ||
policy_load_count == 0 ||
cached_last_policy_count != policy_load_count) {
cached_agent_settings.clear();
auto general_sets = getSettingWithDefault(
AgentProfileSettings::default_profile_settings,
"generalAgentSettings");
cached_agent_settings = general_sets.getSettings();
auto settings = general_sets.getSettings();
auto profile_sets = getSettingWithDefault(AgentProfileSettings::default_profile_settings, "agentSettings");
auto profile_settings = profile_sets.getSettings();
for (const auto &profile_setting : profile_settings) {
settings.insert(profile_setting);
auto profile_sets = getSettingWithDefault(AgentProfileSettings::default_profile_settings, "agentSettings");
auto profile_settings = profile_sets.getSettings();
for (const auto &profile_setting : profile_settings) {
cached_agent_settings.insert(profile_setting);
}
cached_last_policy_count = policy_load_count;
}
return settings;
return cached_agent_settings;
}
void
@@ -963,7 +1017,7 @@ ConfigComponent::Impl::reloadConfigurationContinuesWrapper(const string &version
LoadNewConfigurationStatus in_progress(id, service_name, false, false);
auto routine_id = mainloop->addRecurringRoutine(
I_MainLoop::RoutineType::Timer,
std::chrono::seconds(30),
chrono::seconds(30),
[=] () { sendOrchestatorReloadStatusMsg(in_progress); },
"A-Synchronize reload configuraion monitoring"
);
@@ -1006,13 +1060,59 @@ ConfigComponent::Impl::resolveVsId() const
return;
}
void
ConfigComponent::Impl::resetConfigCache()
{
cached_last_policy_count = 0;
policy_load_count = 0;
policy_load_id = "";
is_cache_enabled = false;
cached_agent_settings.clear();
}
uint64_t
ConfigComponent::Impl::getCacheHits() const
{
return CacheStats::getHits();
}
uint64_t
ConfigComponent::Impl::getCacheMisses() const
{
return CacheStats::getMisses();
}
void
ConfigComponent::Impl::resetCacheStats()
{
CacheStats::reset();
}
void
ConfigComponent::Impl::enableCacheTracking()
{
CacheStats::enableTracking();
}
void
ConfigComponent::Impl::disableCacheTracking()
{
CacheStats::disableTracking();
}
bool
ConfigComponent::Impl::isCacheTrackingEnabled() const
{
return CacheStats::isTrackingEnabled();
}
ConfigComponent::ConfigComponent() : Component("ConfigComponent"), pimpl(make_unique<Impl>()) {}
ConfigComponent::~ConfigComponent() {}
void
ConfigComponent::preload()
{
registerExpectedConfiguration<string>("Config Component", "configuration path");
registerExpectedConfigurationWithCache<string>("assetId", "Config Component", "configuration path");
registerExpectedConfiguration<uint>("Config Component", "Refresh config update registration time interval");
registerExpectedConfiguration<bool>("Config Component", "Periodic Registration Refresh");
registerExpectedResource<bool>("Config Component", "Config Load Test");

View File

@@ -211,6 +211,8 @@ CPUManager::init()
}
}
registerConfigLoadCb([this]() { loadCPUConfig(); });
i_mainloop->addOneTimeRoutine(
I_MainLoop::RoutineType::Timer,
[this]() { checkCPUStatus(); },
@@ -265,8 +267,6 @@ void
CPUManager::checkCPUStatus()
{
while (true) {
loadCPUConfig();
auto is_orchestrator = Singleton::Consume<I_Environment>::by<CPUManager>()->get<bool>("Is Orchestrator");
if (is_orchestrator.ok() && is_orchestrator.unpack()) {
Maybe<double> current_general_cpu = i_cpu->getCurrentGeneralCPUUsage();

View File

@@ -47,8 +47,6 @@ public:
StrictMock<MockMainLoop> mock_ml;
StrictMock<MockTimeGet> mock_time;
I_Environment *i_env;
private:
ConfigComponent conf;
::Environment env;
};
@@ -261,8 +259,24 @@ TEST_F(CPUTest, noDebugTest)
StrictMock<MockCPU> mock_cpu;
CPUManager cpu;
cpu.preload();
setConfiguration<uint>(0, string("CPU"), string("debug period"));
cpu.init();
// Test loadConfiguration functionality like in compression_ut
string config_json =
"{"
" \"CPU\": {"
" \"debug period\": ["
" {"
" \"value\": 0"
" }"
" ]"
" }"
"}";
istringstream ss(config_json);
Singleton::Consume<Config::I_Config>::from(conf)->loadConfiguration(ss);
auto loaded_debug_period = getConfiguration<uint>("CPU", "debug period");
EXPECT_TRUE(loaded_debug_period.ok());
EXPECT_EQ((int)loaded_debug_period.unpack(), 0);
doFWError();
EXPECT_THAT(debug_output.str(), HasSubstr("!!!] FW error message\n"));

View File

@@ -50,6 +50,7 @@ CurlHttpClient::~CurlHttpClient()
curl_global_cleanup();
}
// LCOV_EXCL_START
void
CurlHttpClient::setProxy(const string& hosts)
{
@@ -68,7 +69,15 @@ CurlHttpClient::authEnabled(bool enabled)
{
auth_enabled = enabled;
}
// LCOV_EXCL_STOP
void
CurlHttpClient::setConfigs(const CurlHttpClientConfig& config)
{
this->config = config;
}
// LCOV_EXCL_START
HTTPResponse
CurlHttpClient::get(const string& url, const map<string, string>& headers)
{
@@ -106,6 +115,7 @@ CurlHttpClient::WriteCallback(void *contents, size_t size, size_t nmemb, string
userp->append(static_cast<char *>(contents), totalSize);
return totalSize;
}
// LCOV_EXCL_STOP
HTTPResponse
CurlHttpClient::perform_request(
@@ -128,6 +138,24 @@ CurlHttpClient::perform_request(
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_body);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, config.timeout_seconds);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, config.connect_timeout_seconds);
if (config.verbose_enabled) {
curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L);
}
curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, config.ssl_verify_peer ? 1L : 0L);
curl_easy_setopt(curl, CURLOPT_SSL_VERIFYHOST, config.ssl_verify_host ? 2L : 0L);
if (config.http_version != CURL_HTTP_VERSION_NONE) {
curl_easy_setopt(curl, CURLOPT_HTTP_VERSION, config.http_version);
}
if (!config.user_agent.empty()) {
curl_easy_setopt(curl, CURLOPT_USERAGENT, config.user_agent.c_str());
}
if (!no_proxy_hosts.empty()) {
dbgTrace(D_NGINX_MANAGER) << "Using proxy url: " << no_proxy_hosts;
curl_easy_setopt(curl, CURLOPT_PROXY, no_proxy_hosts.c_str());

View File

@@ -84,6 +84,15 @@ static map<Debug::DebugFlags, string> flags_to_setting_name = {
#undef DEFINE_FLAG
};
// Reverse mapping: flag name string to enum (using the same source as flags_to_setting_name)
static map<string, Debug::DebugFlags> string_to_flag_map = {
{"D_ALL", Debug::DebugFlags::D_ALL},
#define DEFINE_FLAG(flag_name, parent_name) \
{#flag_name, Debug::DebugFlags::flag_name},
#include "debug_flags.h"
#undef DEFINE_FLAG
};
static map<string, shared_ptr<Debug::DebugStream>> preparing_streams;
static FlagsArray global_flags_levels(FlagsArray::Fill(), default_level);
@@ -510,7 +519,7 @@ Debug::~Debug()
void
Debug::preload()
{
registerExpectedConfiguration<DebugConfiguration>("Debug");
registerExpectedConfigurationWithCache<DebugConfiguration>("assetId", "Debug");
registerExpectedConfiguration<string>("Debug I/S", "Fog Debug URI");
registerExpectedConfiguration<string>("Debug I/S", "Debug conf file path");
registerExpectedConfiguration<bool>("Debug I/S", "Enable bulk of debugs");
@@ -750,6 +759,25 @@ Debug::isFlagAtleastLevel(Debug::DebugFlags flag, Debug::DebugLevel level)
return global_flags_levels[flag] <= level;
}
void
Debug::setDebugFlag(Debug::DebugFlags flag, Debug::DebugLevel level)
{
global_flags_levels[flag] = level;
default_config.streams_in_context[0].flag_values[flag] = level;
//if the new level is lower than the current lowest, update it
if (lowest_global_level >= level) {
lowest_global_level = level;
return;
}
// if the new level is higher, recalculate lowest_global_level by scanning all flag levels
lowest_global_level = global_flags_levels[Debug::DebugFlags::D_ALL];
for (const auto &current_level : global_flags_levels) {
if (current_level < lowest_global_level) lowest_global_level = current_level;
}
}
void
Debug::setUnitTestFlag(Debug::DebugFlags flag, Debug::DebugLevel level)
{
@@ -783,6 +811,17 @@ Debug::getExecutableName()
return executable.ok() ? *executable : "";
}
bool
Debug::getDebugFlagFromString(const string &flag_name, DebugFlags &flag)
{
auto flag_it = string_to_flag_map.find(flag_name);
if (flag_it != string_to_flag_map.end()) {
flag = flag_it->second;
return true;
}
return false;
}
void
Debug::addActiveStream(const string &name)
{

View File

@@ -587,6 +587,13 @@ public:
return msg;
}
void
clearDebugMessage()
{
capture_debug.str("");
capture_debug.clear();
}
bool
loadConfiguration(const string &conf_str)
{
@@ -667,6 +674,7 @@ TEST_F(DebugConfigTest, debug_all)
CPTestTempfile debug_file;
loadConfiguration("{\"Output\": \"STDOUT\", \"D_PM\": \"Error\", \"D_ALL\": \"Trace\"}");
clearDebugMessage();
doFWWarning();
EXPECT_EQ(getDebugMessage(),

View File

@@ -15,286 +15,10 @@
#ifndef __NGINX_ATTACHMENT_COMMON_H__
#define __NGINX_ATTACHMENT_COMMON_H__
#include <stddef.h>
#include <stdint.h>
#include <sys/types.h>
#include <assert.h>
// This file has been deprecated. Do not add anything here.
// Any future additions should be added to nano_attachment_common.h
// For any inquiries please contact Daniel Yashin.
#define MAX_NGINX_UID_LEN 32
#define NUM_OF_NGINX_IPC_ELEMENTS 200
#define DEFAULT_KEEP_ALIVE_INTERVAL_MSEC 300000
#define SHARED_MEM_PATH "/dev/shm/"
#define SHARED_REGISTRATION_SIGNAL_PATH SHARED_MEM_PATH "check-point/cp-nano-attachment-registration"
#define SHARED_KEEP_ALIVE_PATH SHARED_MEM_PATH "check-point/cp-nano-attachment-registration-expiration-socket"
#define SHARED_VERDICT_SIGNAL_PATH SHARED_MEM_PATH "check-point/cp-nano-http-transaction-handler"
#define SHARED_ATTACHMENT_CONF_PATH SHARED_MEM_PATH "cp_nano_http_attachment_conf"
#define DEFAULT_STATIC_RESOURCES_PATH SHARED_MEM_PATH "static_resources"
#define INJECT_POS_IRRELEVANT -1
#define CORRUPTED_SESSION_ID 0
#define METRIC_PERIODIC_TIMEOUT 600
extern char shared_verdict_signal_path[];
extern int workers_amount_to_send;
typedef int64_t ngx_http_cp_inject_pos_t;
#ifdef __cplusplus
typedef enum class ngx_http_modification_type
#else
typedef enum ngx_http_modification_type
#endif
{
APPEND,
INJECT,
REPLACE
} ngx_http_modification_type_e;
#ifdef __cplusplus
typedef enum class ngx_http_chunk_type
#else
typedef enum ngx_http_chunk_type
#endif
{
REQUEST_START,
REQUEST_HEADER,
REQUEST_BODY,
REQUEST_END,
RESPONSE_CODE,
RESPONSE_HEADER,
RESPONSE_BODY,
RESPONSE_END,
CONTENT_LENGTH,
METRIC_DATA_FROM_PLUGIN,
HOLD_DATA,
COUNT
} ngx_http_chunk_type_e;
#ifdef __cplusplus
typedef enum class ngx_http_plugin_metric_type
#else
typedef enum ngx_http_plugin_metric_type
#endif
{
TRANSPARENTS_COUNT,
TOTAL_TRANSPARENTS_TIME,
INSPECTION_OPEN_FAILURES_COUNT,
INSPECTION_CLOSE_FAILURES_COUNT,
INSPECTION_SUCCESSES_COUNT,
INJECT_VERDICTS_COUNT,
DROP_VERDICTS_COUNT,
ACCEPT_VERDICTS_COUNT,
IRRELEVANT_VERDICTS_COUNT,
RECONF_VERDICTS_COUNT,
INSPECT_VERDICTS_COUNT,
HOLD_VERDICTS_COUNT,
AVERAGE_OVERALL_PPROCESSING_TIME_UNTIL_VERDICT,
MAX_OVERALL_PPROCESSING_TIME_UNTIL_VERDICT,
MIN_OVERALL_PPROCESSING_TIME_UNTIL_VERDICT,
AVERAGE_REQ_PPROCESSING_TIME_UNTIL_VERDICT,
MAX_REQ_PPROCESSING_TIME_UNTIL_VERDICT,
MIN_REQ_PPROCESSING_TIME_UNTIL_VERDICT,
AVERAGE_RES_PPROCESSING_TIME_UNTIL_VERDICT,
MAX_RES_PPROCESSING_TIME_UNTIL_VERDICT,
MIN_RES_PPROCESSING_TIME_UNTIL_VERDICT,
THREAD_TIMEOUT,
REG_THREAD_TIMEOUT,
REQ_HEADER_THREAD_TIMEOUT,
REQ_BODY_THREAD_TIMEOUT,
AVERAGE_REQ_BODY_SIZE_UPON_TIMEOUT,
MAX_REQ_BODY_SIZE_UPON_TIMEOUT,
MIN_REQ_BODY_SIZE_UPON_TIMEOUT,
RES_HEADER_THREAD_TIMEOUT,
RES_BODY_THREAD_TIMEOUT,
HOLD_THREAD_TIMEOUT,
AVERAGE_RES_BODY_SIZE_UPON_TIMEOUT,
MAX_RES_BODY_SIZE_UPON_TIMEOUT,
MIN_RES_BODY_SIZE_UPON_TIMEOUT,
THREAD_FAILURE,
REQ_PROCCESSING_TIMEOUT,
RES_PROCCESSING_TIMEOUT,
REQ_FAILED_TO_REACH_UPSTREAM,
REQ_FAILED_COMPRESSION_COUNT,
RES_FAILED_COMPRESSION_COUNT,
REQ_FAILED_DECOMPRESSION_COUNT,
RES_FAILED_DECOMPRESSION_COUNT,
REQ_SUCCESSFUL_COMPRESSION_COUNT,
RES_SUCCESSFUL_COMPRESSION_COUNT,
REQ_SUCCESSFUL_DECOMPRESSION_COUNT,
RES_SUCCESSFUL_DECOMPRESSION_COUNT,
CORRUPTED_ZIP_SKIPPED_SESSION_COUNT,
CPU_USAGE,
AVERAGE_VM_MEMORY_USAGE,
AVERAGE_RSS_MEMORY_USAGE,
MAX_VM_MEMORY_USAGE,
MAX_RSS_MEMORY_USAGE,
REQUEST_OVERALL_SIZE_COUNT,
RESPONSE_OVERALL_SIZE_COUNT,
METRIC_TYPES_COUNT
} ngx_http_plugin_metric_type_e;
#ifdef __cplusplus
typedef enum class ngx_http_cp_verdict
#else
typedef enum ngx_http_cp_verdict
#endif
{
TRAFFIC_VERDICT_INSPECT,
TRAFFIC_VERDICT_ACCEPT,
TRAFFIC_VERDICT_DROP,
TRAFFIC_VERDICT_INJECT,
TRAFFIC_VERDICT_IRRELEVANT,
TRAFFIC_VERDICT_RECONF,
TRAFFIC_VERDICT_WAIT,
LIMIT_RESPONSE_HEADERS
} ngx_http_cp_verdict_e;
#ifdef __cplusplus
typedef enum class ngx_http_cp_debug_level
#else
typedef enum ngx_http_cp_debug_level
#endif
{
DBG_LEVEL_TRACE,
DBG_LEVEL_DEBUG,
DBG_LEVEL_INFO,
DBG_LEVEL_WARNING,
DBG_LEVEL_ERROR,
#ifndef __cplusplus
DBG_LEVEL_ASSERT,
#endif
DBG_LEVEL_COUNT
} ngx_http_cp_debug_level_e;
#ifdef __cplusplus
typedef enum class ngx_http_meta_data
#else
typedef enum ngx_http_meta_data
#endif
{
HTTP_PROTOCOL_SIZE,
HTTP_PROTOCOL_DATA,
HTTP_METHOD_SIZE,
HTTP_METHOD_DATA,
HOST_NAME_SIZE,
HOST_NAME_DATA,
LISTENING_ADDR_SIZE,
LISTENING_ADDR_DATA,
LISTENING_PORT,
URI_SIZE,
URI_DATA,
CLIENT_ADDR_SIZE,
CLIENT_ADDR_DATA,
CLIENT_PORT,
PARSED_HOST_SIZE,
PARSED_HOST_DATA,
PARSED_URI_SIZE,
PARSED_URI_DATA,
WAF_TAG_SIZE,
WAF_TAG_DATA,
META_DATA_COUNT
} ngx_http_meta_data_e;
#ifdef __cplusplus
typedef enum class ngx_http_header_data
#else
typedef enum ngx_http_header_data
#endif
{
HEADER_KEY_SIZE,
HEADER_KEY_DATA,
HEADER_VAL_SIZE,
HEADER_VAL_DATA,
HEADER_DATA_COUNT
} ngx_http_header_data_e;
typedef enum ngx_http_inspection_mode
{
NON_BLOCKING_THREAD,
BLOCKING_THREAD,
NO_THREAD,
INSPECTION_MODE_COUNT
} ngx_http_inspection_mode_e;
#ifdef __cplusplus
typedef enum class ngx_web_response_type
#else
typedef enum ngx_web_response_type
#endif
{
CUSTOM_WEB_RESPONSE,
CUSTOM_WEB_BLOCK_PAGE_RESPONSE,
RESPONSE_CODE_ONLY,
REDIRECT_WEB_RESPONSE,
NO_WEB_RESPONSE
} ngx_web_response_type_e;
typedef struct __attribute__((__packed__)) ngx_http_cp_inject_data {
ngx_http_cp_inject_pos_t injection_pos;
ngx_http_modification_type_e mod_type;
uint16_t injection_size;
uint8_t is_header;
uint8_t orig_buff_index;
char data[0];
} ngx_http_cp_inject_data_t;
typedef struct __attribute__((__packed__)) ngx_http_cp_web_response_data {
uint8_t web_repsonse_type;
uint8_t uuid_size;
union {
struct __attribute__((__packed__)) ngx_http_cp_custom_web_response_data {
uint16_t response_code;
uint8_t title_size;
uint8_t body_size;
char data[0];
} custom_response_data;
struct __attribute__((__packed__)) ngx_http_cp_redirect_data {
uint8_t unused_dummy;
uint8_t add_event_id;
uint16_t redirect_location_size;
char redirect_location[0];
} redirect_data;
} response_data;
} ngx_http_cp_web_response_data_t;
static_assert(
sizeof(((ngx_http_cp_web_response_data_t*)0)->response_data.custom_response_data) ==
sizeof(((ngx_http_cp_web_response_data_t*)0)->response_data.redirect_data),
"custom_response_data must be equal to redirect_data in size"
);
typedef union __attribute__((__packed__)) ngx_http_cp_modify_data {
ngx_http_cp_inject_data_t inject_data[0];
ngx_http_cp_web_response_data_t web_response_data[0];
} ngx_http_cp_modify_data_t;
typedef struct __attribute__((__packed__)) ngx_http_cp_reply_from_service {
uint16_t verdict;
uint32_t session_id;
uint8_t modification_count;
ngx_http_cp_modify_data_t modify_data[0];
} ngx_http_cp_reply_from_service_t;
typedef struct __attribute__((__packed__)) ngx_http_cp_request_data {
uint16_t data_type;
uint32_t session_id;
unsigned char data[0];
} ngx_http_cp_request_data_t;
typedef struct __attribute__((__packed__)) ngx_http_cp_metric_data {
uint16_t data_type;
#ifdef __cplusplus
uint64_t data[static_cast<int>(ngx_http_plugin_metric_type::METRIC_TYPES_COUNT)];
#else
uint64_t data[METRIC_TYPES_COUNT];
#endif
} ngx_http_cp_metric_data_t;
#endif // __NGINX_ATTACHMENT_COMMON_H__

View File

@@ -17,7 +17,7 @@
#include <stdio.h>
#include "nginx_attachment_common.h"
#include "nano_attachment_common.h"
#ifdef __cplusplus
extern "C" {
@@ -29,7 +29,7 @@ typedef const char * c_str;
int initAttachmentConfig(c_str conf_file);
ngx_http_inspection_mode_e getInspectionMode();
NanoHttpInspectionMode getInspectionMode();
unsigned int getNumOfNginxIpcElements();
unsigned int getKeepAliveIntervalMsec();
unsigned int getDbgLevel();
@@ -61,11 +61,16 @@ unsigned int getMinRetriesForVerdict();
unsigned int getMaxRetriesForVerdict();
unsigned int getReqBodySizeTrigger();
unsigned int getRemoveResServerHeader();
unsigned int getDecompressionPoolSize();
unsigned int getRecompressionPoolSize();
unsigned int getIsBrotliInspectionEnabled();
unsigned int getWaitingForVerdictThreadTimeout();
int isIPAddress(c_str ip_str);
int isSkipSource(c_str ip_str);
unsigned int isPairedAffinityEnabled();
unsigned int isAsyncModeEnabled();
#ifdef __cplusplus
}

View File

@@ -36,6 +36,13 @@ class I_SignalHandler;
namespace Config { enum class Errors; }
std::ostream & operator<<(std::ostream &, const Config::Errors &);
template <typename Rep, typename Period>
std::ostream& operator<<(std::ostream& os, const std::chrono::duration<Rep, Period>& d)
{
os << d.count();
return os;
}
enum class AlertTeam { CORE, WAAP, SDWAN, IOT };
class AlertInfo
@@ -233,9 +240,11 @@ public:
static void setNewDefaultStdout(std::ostream *new_stream);
static void setUnitTestFlag(DebugFlags flag, DebugLevel level);
static void setDebugFlag(DebugFlags flag, DebugLevel level);
static std::string findDebugFilePrefix(const std::string &file_name);
static std::string getExecutableName();
static bool getDebugFlagFromString(const std::string &flag_name, DebugFlags &flag);
private:
template <typename T, typename... Args>

View File

@@ -32,7 +32,8 @@ class IntelligenceComponentV2
Singleton::Consume<I_MainLoop>,
Singleton::Consume<I_AgentDetails>,
Singleton::Consume<I_RestApi>,
Singleton::Consume<I_TimeGet>
Singleton::Consume<I_TimeGet>,
Singleton::Consume<I_Environment>
{
public:
IntelligenceComponentV2();

View File

@@ -8,6 +8,16 @@
#include "messaging/http_response.h"
#include "i_http_client.h"
struct CurlHttpClientConfig {
int timeout_seconds = 30;
int connect_timeout_seconds = 10;
bool verbose_enabled = false;
bool ssl_verify_peer = true;
bool ssl_verify_host = true;
long http_version = CURL_HTTP_VERSION_NONE;
std::string user_agent = "";
};
class CurlHttpClient : public I_HttpClient
{
public:
@@ -17,6 +27,7 @@ public:
void setProxy(const std::string& hosts) override;
void setBasicAuth(const std::string& username, const std::string& password) override;
void authEnabled(bool enabled) override;
void setConfigs(const CurlHttpClientConfig& config);
HTTPResponse
get(
@@ -70,6 +81,8 @@ private:
bool auth_enabled;
std::string username;
std::string password;
CurlHttpClientConfig config;
};
#endif // __CURL_HTTP_CLIENT_H__

View File

@@ -17,6 +17,8 @@
#include "maybe_res.h"
#include <string>
#include <memory>
static const std::string data1_file_name = "data1.a";
static const std::string data4_file_name = "data4.a";
@@ -29,6 +31,7 @@ static const std::string session_token_file_name = "data3.a";
class I_Encryptor
{
public:
// Base64
virtual std::string base64Encode(const std::string &input) = 0;
virtual std::string base64Decode(const std::string &input) = 0;

View File

@@ -47,6 +47,16 @@ public:
const std::string &routine_name,
bool is_primary = false
) = 0;
virtual RoutineID
addBalancedIntervalRoutine(
RoutineType priority,
std::chrono::microseconds interval,
Routine func,
const std::string &routine_name,
std::chrono::microseconds offset = std::chrono::microseconds(0),
bool is_primary = false
) = 0;
virtual RoutineID
addFileRoutine(

View File

@@ -85,8 +85,8 @@ public:
) = 0;
virtual Maybe<void, HTTPResponse> uploadFile(
const std::string & uri,
const std::string & upload_file_path,
const std::string &uri,
const std::string &upload_file_path,
const MessageCategory category = MessageCategory::GENERIC,
MessageMetadata message_metadata = MessageMetadata()
) = 0;
@@ -100,6 +100,8 @@ public:
virtual bool setFogConnection(MessageCategory category = MessageCategory::GENERIC) = 0;
virtual void clearConnections() = 0;
protected:
virtual ~I_Messaging() {}
};

View File

@@ -53,8 +53,13 @@ public:
const std::string &uri,
const std::function<std::string(const std::string &)> &callback
) = 0;
virtual bool addPostCall(
const std::string &uri,
const std::function<Maybe<std::string>(const std::string &)> &callback
) = 0;
virtual uint16_t getListeningPort() const = 0;
virtual uint16_t getStartingPortRange() const = 0;
protected:
~I_RestApi() {}

View File

@@ -36,6 +36,7 @@ public:
virtual void closeSocket(socketFd &socket) = 0;
virtual bool writeData(socketFd socket, const std::vector<char> &data) = 0;
virtual bool writeDataAsync(socketFd socket, const std::vector<char> &data) = 0;
virtual Maybe<std::vector<char>> receiveData(socketFd socket, uint data_size, bool is_blocking = true) = 0;
virtual bool isDataAvailable(socketFd socket) = 0;

View File

@@ -80,6 +80,15 @@ I_Messaging::sendSyncMessage(
);
if (!response_data.ok()) return response_data.passErr();
if (response_data.unpack().getHTTPStatusCode() != HTTPStatusCode::HTTP_OK) {
return genError(
HTTPResponse(
response_data.unpack().getHTTPStatusCode(),
response_data.unpack().getBody()
)
);
}
auto res_obj = req_obj.loadJson(response_data.unpack().getBody());
if (!res_obj) {
return genError(
@@ -114,6 +123,7 @@ I_Messaging::sendSyncMessageWithoutResponse(
category,
message_metadata
);
if (!response_data.ok()) {
dbgWarning(D_MESSAGING)
<< "Received error from server. Status code: "
@@ -122,6 +132,16 @@ I_Messaging::sendSyncMessageWithoutResponse(
<< response_data.getErr().getBody();
return false;
}
if (response_data.unpack().getHTTPStatusCode() != HTTPStatusCode::HTTP_OK) {
dbgWarning(D_MESSAGING)
<< "Unexpected status code from server. Status code: "
<< int(response_data.unpack().getHTTPStatusCode())
<< ", response body: "
<< response_data.unpack().getBody();
return false;
}
return true;
}

View File

@@ -15,6 +15,11 @@ public:
uint (RoutineType, std::chrono::microseconds, Routine, const std::string &, bool)
);
MOCK_METHOD6(
addBalancedIntervalRoutine,
uint (RoutineType, std::chrono::microseconds, Routine, const std::string &, std::chrono::microseconds, bool)
);
MOCK_METHOD5(
addFileRoutine,
uint (RoutineType, int, Routine, const std::string &, bool)

View File

@@ -54,6 +54,7 @@ public:
MOCK_METHOD4(setFogConnection, bool(const string &, uint16_t, bool, MessageCategory));
MOCK_METHOD0(setFogConnection, bool());
MOCK_METHOD1(setFogConnection, bool(MessageCategory));
MOCK_METHOD0(clearConnections, void());
};
static std::ostream &

View File

@@ -9,7 +9,12 @@ class MockRestApi : public Singleton::Provide<I_RestApi>::From<MockProvider<I_Re
{
public:
MOCK_CONST_METHOD0(getListeningPort, uint16_t());
MOCK_CONST_METHOD0(getStartingPortRange, uint16_t());
MOCK_METHOD2(addGetCall, bool(const std::string &, const std::function<std::string()> &));
MOCK_METHOD2(
addPostCall,
bool(const std::string &, const std::function<Maybe<std::string>(const std::string &)> &)
);
MOCK_METHOD2(
addWildcardGetCall,
bool(const std::string &, const std::function<std::string(const std::string &)> &)

View File

@@ -15,6 +15,7 @@ public:
MOCK_METHOD1(closeSocket, void (socketFd &));
MOCK_METHOD1(isDataAvailable, bool (socketFd));
MOCK_METHOD2(writeData, bool (socketFd, const std::vector<char> &));
MOCK_METHOD2(writeDataAsync, bool (socketFd, const std::vector<char> &));
MOCK_METHOD3(receiveData, Maybe<std::vector<char>> (socketFd, uint, bool is_blocking));
};

View File

@@ -55,7 +55,8 @@ class AgentDetails
Singleton::Consume<I_Encryptor>,
Singleton::Consume<I_ShellCmd>,
Singleton::Consume<I_Environment>,
Singleton::Consume<I_MainLoop>
Singleton::Consume<I_MainLoop>,
Singleton::Consume<I_Messaging>
{
public:
AgentDetails() : Component("AgentDetails") {}
@@ -80,7 +81,10 @@ public:
void setFogDomain(const std::string &_fog_domain) { fog_domain = _fog_domain; }
void setFogPort(const uint16_t _fog_port) { fog_port = _fog_port; }
void setProxy(const std::string &_proxy) { proxy = _proxy; }
void setProxy(const std::string &_proxy) {
previous_proxy = proxy;
proxy = _proxy;
}
void setAgentId(const std::string &_agent_id) { agent_id = _agent_id; }
void setProfileId(const std::string &_profile_id) { profile_id = _profile_id; }
void setTenantId(const std::string &_tenant_id) { tenant_id = _tenant_id; }
@@ -121,6 +125,7 @@ private:
OrchestrationMode orchestration_mode = OrchestrationMode::ONLINE;
std::string server = "Unknown";
bool is_proxy_configured_via_settings = false;
std::string previous_proxy = "";
std::map<ProxyProtocol, ProxyData> proxies;
static const std::map<std::string, I_AgentDetails::MachineType> machineTypes;

View File

@@ -18,10 +18,20 @@
#error "config_impl.h should not be included directly"
#endif // __CONFIG_H__
#include <fstream>
#include <unistd.h>
#include <atomic>
#include <cstdlib>
#include <algorithm>
#include <cctype>
namespace Config
{
class MockConfigProvider : Singleton::Provide<I_Config> {};
class MockConfigProvider
:
public Singleton::Provide<I_Config>,
public Singleton::Consume<I_Environment>
{};
template<typename String>
std::size_t
@@ -62,22 +72,290 @@ getVector(const Strings & ... strs)
return res;
}
// Utility function to create a separated string from a vector
inline std::string
makeSeparatedStr(const std::vector<std::string> &vec, const std::string &separator = ", ")
{
if (vec.empty()) return "";
if (vec.size() == 1) return vec[0];
std::string result = vec[0];
for (size_t i = 1; i < vec.size(); ++i) {
result += separator + vec[i];
}
return result;
}
} // namespace Config
// Efficient service type checking for caching
inline bool isHttpTransactionHandler() {
static bool is_http_transaction_handler = false;
static bool service_checked = false;
if (!service_checked) {
auto i_environment = Singleton::Consume<I_Environment>::by<Config::MockConfigProvider>();
if (i_environment != nullptr) {
auto maybe_service_name = i_environment->get<std::string>("Service Name");
if (maybe_service_name.ok()) {
is_http_transaction_handler = (maybe_service_name.unpack() == "HTTP Transaction Handler");
service_checked = true;
}
}
}
return is_http_transaction_handler;
}
// Context registration for cache-enabled configurations
template <typename ConfigurationType>
struct ContextRegistration {
static std::map<std::vector<std::string>, std::string> path_to_context_map;
static void registerContext(const std::vector<std::string>& paths, const std::string& context_type) {
path_to_context_map[paths] = context_type;
}
static std::string getContext(const std::vector<std::string>& paths) {
auto it = path_to_context_map.find(paths);
return (it != path_to_context_map.end()) ? it->second : "";
}
};
// Static member definition
template <typename ConfigurationType>
std::map<std::vector<std::string>, std::string> ContextRegistration<ConfigurationType>::path_to_context_map;
template <typename ConfigurationType>
struct ConfigCacheKey {
std::vector<std::string> paths;
std::string context_value;
std::string policy_load_id;
bool operator==(const ConfigCacheKey &other) const
{
return paths == other.paths &&
context_value == other.context_value &&
policy_load_id == other.policy_load_id;
}
bool match(
const std::vector<std::string>& other_paths,
const std::string& other_context_value,
const std::string& other_policy_load_id
) const
{
return paths == other_paths &&
context_value == other_context_value &&
policy_load_id == other_policy_load_id;
}
};
template <typename ConfigurationType>
struct ConfigCacheEntry {
ConfigCacheKey<ConfigurationType> key;
Maybe<ConfigurationType, Config::Errors> value;
ConfigCacheEntry()
: key(), value(genError(Config::Errors::MISSING_TAG)) {}
bool isValid() const { return !key.context_value.empty(); }
void invalidate()
{
key.context_value.clear();
value = genError(Config::Errors::MISSING_TAG);
}
};
template <typename ConfigurationType, typename ... Strings>
const Maybe<ConfigurationType, Config::Errors> &
getConfiguration(const Strings & ... strs)
{
auto i_config = Singleton::Consume<Config::I_Config>::from<Config::MockConfigProvider>();
return i_config->getConfiguration(Config::getVector(strs ...)).template getValue<ConfigurationType>();
const auto &paths = Config::getVector(strs ...);
return i_config->getConfiguration(paths).template getValue<ConfigurationType>();
};
// LCOV_EXCL_START - Helper function to isolate static variables from lcov function data mismatch
// Helper function to get cache array - isolates static variables
template <typename ConfigurationType>
ConfigCacheEntry<ConfigurationType>* getCacheArray() {
static ConfigCacheEntry<ConfigurationType> config_cache[3];
return config_cache;
}
// Cache statistics tracking
struct CacheStats {
static std::atomic<uint64_t> hits;
static std::atomic<uint64_t> misses;
static bool tracking_enabled;
static void recordHit() {
if (tracking_enabled) hits.fetch_add(1, std::memory_order_relaxed);
}
static void recordMiss() {
if (tracking_enabled) misses.fetch_add(1, std::memory_order_relaxed);
}
static uint64_t getHits() { return hits.load(std::memory_order_relaxed); }
static uint64_t getMisses() { return misses.load(std::memory_order_relaxed); }
static void reset() {
hits.store(0, std::memory_order_relaxed);
misses.store(0, std::memory_order_relaxed);
}
static void enableTracking() { tracking_enabled = true; }
static void disableTracking() { tracking_enabled = false; }
static bool isTrackingEnabled() { return tracking_enabled; }
};
// Initialize cache tracking from environment variable
inline void initializeCacheTracking() {
const char* enable_tracking = std::getenv("ENABLE_CONFIG_CACHE_TRACKING");
if (enable_tracking != nullptr) {
// Check for various "true" values
std::string tracking_value(enable_tracking);
std::transform(tracking_value.begin(), tracking_value.end(), tracking_value.begin(), ::tolower);
if (tracking_value == "true") {
CacheStats::enableTracking();
CacheStats::reset(); // Start with clean counters when enabling tracking
}
}
}
// LCOV_EXCL_STOP
template <typename ConfigurationType, typename ... Strings>
const Maybe<ConfigurationType, Config::Errors> &
getConfigurationWithCache(const Strings & ... strs)
{
// Step 1: Check if current service is HTTP Transaction Handler
if (!isHttpTransactionHandler()) {
return getConfiguration<ConfigurationType>(strs...);
}
// Step 2: Fast checks - get basic info
auto i_config = Singleton::Consume<Config::I_Config>::from<Config::MockConfigProvider>();
const auto &paths = Config::getVector(strs ...);
size_t idx = paths.size();
// Step 3: Quick validation checks (fastest)
bool idx_valid = (idx >= 1 && idx <= 3); // max_cache_key_size = 3
if (!idx_valid || !i_config->isConfigCacheEnabled()) {
return getConfiguration<ConfigurationType>(strs...);
}
// Step 4: Single map lookup - get context if registered, empty string if not
std::string context_type = ContextRegistration<ConfigurationType>::getContext(paths);
if (context_type.empty()) {
return getConfiguration<ConfigurationType>(strs...);
}
// Step 5: Now we know it's registered - get environment value using the context
std::string context_value;
auto i_environment = Singleton::Consume<I_Environment>::by<Config::MockConfigProvider>();
if (i_environment != nullptr) {
auto maybe_context_value = i_environment->get<std::string>(
(context_type == "triggerId") ? "triggers" : "asset_id");
if (maybe_context_value.ok()) {
context_value = maybe_context_value.unpack();
}
}
// Step 6: Final cache enablement check
if (context_value.empty()) {
return getConfiguration<ConfigurationType>(strs...);
}
// Step 7: Cache operations
auto* config_cache = getCacheArray<ConfigurationType>();
std::string policy_load_id = i_config->getPolicyLoadId();
// Check cache first
ConfigCacheEntry<ConfigurationType> &entry = config_cache[idx - 1];
if (entry.key.match(paths, context_value, policy_load_id)) {
// Cache hit
CacheStats::recordHit();
return entry.value;
}
// Cache miss - get configuration and update cache
CacheStats::recordMiss();
const auto &maybe_val = i_config->getConfiguration(paths).template getValue<ConfigurationType>();
// Update cache
config_cache[idx - 1].key = ConfigCacheKey<ConfigurationType>{paths, context_value, policy_load_id};
config_cache[idx - 1].value = maybe_val;
return maybe_val;
}
template <typename ConfigurationType, typename ... Strings>
const Maybe<ConfigurationType, Config::Errors> &
setConfigurationInCache(const Strings & ... strs)
{
// Step 1: Check if current service is HTTP Transaction Handler
if (!isHttpTransactionHandler()) {
return getConfiguration<ConfigurationType>(strs...);
}
// Step 2: Fast checks - get basic info
auto i_config = Singleton::Consume<Config::I_Config>::from<Config::MockConfigProvider>();
const auto &paths = Config::getVector(strs ...);
size_t idx = paths.size();
// Step 3: Quick validation checks (fastest)
bool idx_valid = (idx >= 1 && idx <= 3); // max_cache_key_size = 3
if (!idx_valid || !i_config->isConfigCacheEnabled()) {
// Early exit - no caching possible, just fetch and return
return getConfiguration<ConfigurationType>(strs...);
}
// Step 4: Single map lookup - get context if registered, empty string if not
std::string context_type = ContextRegistration<ConfigurationType>::getContext(paths);
if (context_type.empty()) {
// Not registered for caching - just fetch and return
return getConfiguration<ConfigurationType>(strs...);
}
// Step 5: Now we know it's registered - get environment value using the context
std::string context_value;
auto i_environment = Singleton::Consume<I_Environment>::by<Config::MockConfigProvider>();
if (i_environment != nullptr) {
auto maybe_context_value = i_environment->get<std::string>(
(context_type == "triggerId") ? "triggers" : "asset_id");
if (maybe_context_value.ok()) {
context_value = maybe_context_value.unpack();
}
}
// Step 6: Final cache enablement check
if (context_value.empty()) {
// No valid context value - just fetch and return
return getConfiguration<ConfigurationType>(strs...);
}
// Step 7: Always fetch configuration and update cache (no cache check first)
auto* config_cache = getCacheArray<ConfigurationType>();
std::string policy_load_id = i_config->getPolicyLoadId();
// Fetch configuration directly - no cache hit check
const auto &maybe_val = i_config->getConfiguration(paths).template getValue<ConfigurationType>();
// Update cache with fresh value
config_cache[idx - 1].key = ConfigCacheKey<ConfigurationType>{paths, context_value, policy_load_id};
config_cache[idx - 1].value = maybe_val;
return maybe_val;
}
template <typename ConfigurationType, typename ... Strings>
const ConfigurationType &
getConfigurationWithDefault(const ConfigurationType &deafult_val, const Strings & ... tags)
{
if (!Singleton::exists<Config::I_Config>()) return deafult_val;
auto &res = getConfiguration<ConfigurationType>(tags ...);
auto &res = getConfigurationWithCache<ConfigurationType>(tags ...);
return res.ok() ? res.unpack() : deafult_val;
}
@@ -235,6 +513,18 @@ registerExpectedConfiguration(const Strings & ... tags)
i_config->registerExpectedConfiguration(std::move(conf));
}
template <typename ConfigurationType, typename ... Strings>
void
registerExpectedConfigurationWithCache(const std::string& context_type, const Strings & ... tags)
{
// Register with the original system using existing function
registerExpectedConfiguration<ConfigurationType>(tags...);
// Register the context mapping
const auto &paths = Config::getVector(tags ...);
ContextRegistration<ConfigurationType>::registerContext(paths, context_type);
}
template <typename ResourceType, typename ... Strings>
void
registerExpectedResource(const Strings & ... tags)
@@ -254,3 +544,4 @@ registerExpectedSetting(const Strings & ... tags)
}
#endif // __CONFIG_IMPL_H__

View File

@@ -106,6 +106,18 @@ public:
virtual void clearOldTenants() = 0;
virtual bool isConfigCacheEnabled() const = 0;
virtual void resetConfigCache() = 0;
virtual const std::string & getPolicyLoadId() const = 0;
// Cache statistics access functions
virtual uint64_t getCacheHits() const = 0;
virtual uint64_t getCacheMisses() const = 0;
virtual void resetCacheStats() = 0;
virtual void enableCacheTracking() = 0;
virtual void disableCacheTracking() = 0;
virtual bool isCacheTrackingEnabled() const = 0;
protected:
virtual ~I_Config() {}
};

View File

@@ -76,15 +76,27 @@ public:
template <typename T, typename ... Attr>
void registerValue(const std::string &name, const T &value, Attr ... attr);
template <typename T, typename ... Attr>
void registerQuickAccessValue(const std::string &name, const T &value, Attr ... attr);
template <typename ... Params>
void registerValue(MetaDataType name, Params ... params);
template <typename ... Params>
void registerQuickAccessValue(MetaDataType name, Params ... params);
template <typename T, typename ... Attr>
void registerFunc(const std::string &name, std::function<T()> &&func, Attr ... attr);
template <typename T, typename ... Attr>
void registerQuickAccessFunc(const std::string &name, std::function<T()> &&func, Attr ... attr);
template <typename T, typename ... Attr>
void registerFunc(const std::string &name, std::function<Return<T>()> &&func, Attr ... attr);
template <typename T, typename ... Attr>
void registerQuickAccessFunc(const std::string &name, std::function<Return<T>()> &&func, Attr ... attr);
template <typename T>
void unregisterKey(const std::string &name);
@@ -105,6 +117,7 @@ public:
private:
std::map<Key, std::unique_ptr<AbstractValue>> values;
std::map<Key, std::unique_ptr<AbstractValue>> quick_access_values; // Common values for all contexts
};
class ScopedContext;

View File

@@ -22,6 +22,7 @@ DEFINE_FLAG(D_INFRA, D_ALL)
DEFINE_FLAG(D_COMPRESSION, D_INFRA)
DEFINE_FLAG(D_SHMEM, D_INFRA)
DEFINE_FLAG(D_CONFIG, D_INFRA)
DEFINE_FLAG(D_CONFIG_CACHE, D_INFRA)
DEFINE_FLAG(D_ENVIRONMENT, D_INFRA)
DEFINE_FLAG(D_INTELLIGENCE, D_INFRA)
DEFINE_FLAG(D_RULEBASE_CONFIG, D_INFRA)
@@ -74,6 +75,7 @@ DEFINE_FLAG(D_COMPONENT, D_ALL)
DEFINE_FLAG(D_WAAP_AUTOMATION, D_WAAP)
DEFINE_FLAG(D_WAAP_REGEX, D_WAAP)
DEFINE_FLAG(D_WAAP_SAMPLE_SCAN, D_WAAP)
DEFINE_FLAG(D_WAAP_HYPERSCAN, D_WAAP)
DEFINE_FLAG(D_WAAP_ASSET_STATE, D_WAAP)
DEFINE_FLAG(D_WAAP_CONFIDENCE_CALCULATOR, D_WAAP)
DEFINE_FLAG(D_WAAP_SERIALIZE, D_WAAP)
@@ -89,6 +91,7 @@ DEFINE_FLAG(D_COMPONENT, D_ALL)
DEFINE_FLAG(D_WAAP_STREAMING_PARSING, D_WAAP)
DEFINE_FLAG(D_WAAP_HEADERS, D_WAAP)
DEFINE_FLAG(D_WAAP_OVERRIDE, D_WAAP)
DEFINE_FLAG(D_WAAP_LEARN, D_WAAP)
DEFINE_FLAG(D_WAAP_SAMPLE_HANDLING, D_WAAP_GLOBAL)
DEFINE_FLAG(D_WAAP_SAMPLE_PREPROCESS, D_WAAP_SAMPLE_HANDLING)
@@ -119,8 +122,11 @@ DEFINE_FLAG(D_COMPONENT, D_ALL)
DEFINE_FLAG(D_IPS, D_COMPONENT)
DEFINE_FLAG(D_FILE_UPLOAD, D_COMPONENT)
DEFINE_FLAG(D_RATE_LIMIT, D_COMPONENT)
DEFINE_FLAG(D_AUTH_ENFORCE, D_COMPONENT)
DEFINE_FLAG(D_ANOMALY_DETECTION, D_COMPONENT)
DEFINE_FLAG(D_ROLLBACK_TESTING, D_COMPONENT)
DEFINE_FLAG(D_NGINX_MANAGER, D_COMPONENT)
DEFINE_FLAG(D_BROWSER_AGENT, D_COMPONENT)
DEFINE_FLAG(D_PARSER, D_COMPONENT)
DEFINE_FLAG(D_WS, D_COMPONENT)
@@ -168,6 +174,7 @@ DEFINE_FLAG(D_COMPONENT, D_ALL)
DEFINE_FLAG(D_NGINX_MESSAGE_READER, D_REVERSE_PROXY)
DEFINE_FLAG(D_ERROR_REPORTER, D_REVERSE_PROXY)
DEFINE_FLAG(D_UPSTREAM_KEEPALIVE, D_REVERSE_PROXY)
DEFINE_FLAG(D_UPSTREAM_HEALTH_CHECKER, D_REVERSE_PROXY)
DEFINE_FLAG(D_FORWARD_PROXY, D_REVERSE_PROXY)
DEFINE_FLAG(D_IDA, D_COMPONENT)
@@ -204,6 +211,7 @@ DEFINE_FLAG(D_COMPONENT, D_ALL)
DEFINE_FLAG(D_HORIZON_TELEMETRY, D_COMPONENT)
DEFINE_FLAG(D_PROMETHEUS, D_COMPONENT)
DEFINE_FLAG(D_AIGUARD, D_COMPONENT)
DEFINE_FLAG(D_ERM, D_COMPONENT)
DEFINE_FLAG(D_FLOW, D_ALL)
DEFINE_FLAG(D_DROP, D_FLOW)

View File

@@ -109,6 +109,14 @@ Context::registerValue(const std::string &name, const T &value, Attr ... attr)
registerFunc(name, std::move(new_func), attr ...);
}
template <typename T, typename ... Attr>
void
Context::registerQuickAccessValue(const std::string &name, const T &value, Attr ... attr)
{
std::function<Return<T>()> new_func = [value] () { return Return<T>(value); };
registerQuickAccessFunc(name, std::move(new_func), attr ...);
}
template <typename ... Params>
void
Context::registerValue(MetaDataType name, Params ... params)
@@ -116,6 +124,13 @@ Context::registerValue(MetaDataType name, Params ... params)
return registerValue(convertToString(name), params ...);
}
template <typename ... Params>
void
Context::registerQuickAccessValue(MetaDataType name, Params ... params)
{
return registerQuickAccessValue(convertToString(name), params ...);
}
template <typename T, typename ... Attr>
void
Context::registerFunc(const std::string &name, std::function<T()> &&func, Attr ... attr)
@@ -124,6 +139,14 @@ Context::registerFunc(const std::string &name, std::function<T()> &&func, Attr .
registerFunc(name, std::move(new_func), attr ...);
}
template <typename T, typename ... Attr>
void
Context::registerQuickAccessFunc(const std::string &name, std::function<T()> &&func, Attr ... attr)
{
std::function<Return<T>()> new_func = [func] () { return Return<T>(func()); };
registerQuickAccessFunc(name, std::move(new_func), attr ...);
}
template <typename T, typename ... Attr>
void
Context::registerFunc(const std::string &name, std::function<Return<T>()> &&func, Attr ... attr)
@@ -133,6 +156,15 @@ Context::registerFunc(const std::string &name, std::function<Return<T>()> &&func
values[key] = std::make_unique<Value<T>>(std::move(func));
}
template <typename T, typename ... Attr>
void
Context::registerQuickAccessFunc(const std::string &name, std::function<Return<T>()> &&func, Attr ... attr)
{
dbgTrace(D_ENVIRONMENT) << "Registering key : " << name;
Key key(name, typeid(T), EnvKeyAttr::ParamAttr(attr ...));
quick_access_values[key] = std::make_unique<Value<T>>(std::move(func));
}
template <typename T>
void
Context::unregisterKey(const std::string &name)
@@ -140,6 +172,7 @@ Context::unregisterKey(const std::string &name)
dbgTrace(D_ENVIRONMENT) << "Unregistering key : " << name;
Key key(name, typeid(T));
values.erase(key);
quick_access_values.erase(key);
}
template <typename T>
@@ -154,8 +187,12 @@ Context::Return<T>
Context::get(const std::string &name) const
{
Key key(name, typeid(T));
auto iter = values.find(key);
if (iter == values.end()) return genError(Error::NO_VALUE);
auto iter = quick_access_values.find(key);
if (iter == quick_access_values.end()) {
// If not found in quick access, search in the main values map
iter = values.find(key);
if (iter == values.end()) return genError(Error::NO_VALUE);
}
Value<T> *val = dynamic_cast<Value<T> *>(iter->second.get());
return val->get();
}

View File

@@ -188,6 +188,11 @@ public:
bool matches(const Invalidation &other) const;
void serialize(cereal::JSONInputArchive &ar);
Invalidation & addHeader(const std::string &key, const std::string &value);
Maybe<std::string> getHeader(const std::string &key) const;
const std::map<std::string, std::string> & getHeaders() const;
bool hasHeader(const std::string &key) const;
private:
bool attr_matches(const std::vector<StrAttributes> &current, const std::vector<StrAttributes> &other) const;
bool attr_matches(const std::vector<IpAttributes> &current, const std::vector<IpAttributes> &other) const;
@@ -200,6 +205,7 @@ private:
Maybe<InvalidationType> invalidation_type;
Maybe<uint> listening_id;
Maybe<std::string> registration_id;
std::map<std::string, std::string> headers;
};
} // namespace Intelligence

View File

@@ -41,6 +41,7 @@ struct PrometheusData
{
try {
ar(cereal::make_nvp("metric_name", name));
ar(cereal::make_nvp("unique_name", unique_name));
ar(cereal::make_nvp("metric_type", type));
ar(cereal::make_nvp("metric_description", description));
ar(cereal::make_nvp("labels", label));
@@ -51,6 +52,7 @@ struct PrometheusData
}
std::string name;
std::string unique_name;
std::string type;
std::string description;
std::string label;

View File

@@ -66,13 +66,15 @@ enum class Tags {
CROWDSEC,
PLAYGROUND,
API_DISCOVERY,
LB_HEALTH_STATUS,
NGINX_PROXY_MANAGER,
WEB_SERVER_APISIX,
DEPLOYMENT_DOCKER,
WEB_SERVER_SWAG,
WEB_SERVER_NGINX_UNIFIED,
AIGUARD,
CENTRAL_NGINX_MANAGER,
BROWSER_AGENT,
COUNT
};
@@ -162,7 +164,9 @@ enum class IssuingEngine {
IDA_SAML_IDN_BLADE_REGISTRATION,
IDA_SAML_IDN_CLIENT_IP_NOTIFY,
HORIZON_TELEMETRY_METRICS,
API_DISCOVERY
API_DISCOVERY,
LB_HEALTH_STATUS,
BROWSER_AGENT
};
} // namespace ReportIS

View File

@@ -180,10 +180,31 @@ public:
/// @brief Performs the REST call using the input stream.
/// @param in The input stream containing the JSON data for the REST call.
/// @param headers The HTTP headers from the request.
/// @return A Maybe object containing the result of the REST call (either the JSON data or an error message).
Maybe<std::string> performRestCall(std::istream &in);
Maybe<std::string> performRestCall(std::istream &in, const std::map<std::string, std::string> &headers);
/// @brief Performs the REST call using the input stream (backwards compatibility overload).
/// @param in The input stream containing the JSON data for the REST call.
/// @return A Maybe object containing the result of the REST call (either the JSON data or an error message).
Maybe<std::string> performRestCall(std::istream &in) {
return performRestCall(in, std::map<std::string, std::string>());
}
/// @brief Indicates whether this handler wants to receive HTTP headers.
/// @return True if the handler wants headers, false otherwise. Default is false.
virtual bool wantsHeaders() const { return false; }
/// @brief Sets the HTTP headers for this handler (used by bulk handlers to propagate headers).
/// @param headers The HTTP headers to set.
void setRequestHeaders(const std::map<std::string, std::string> &headers) {
request_headers = headers;
}
protected:
/// @brief HTTP headers from the current request (only populated if wantsHeaders() returns true).
std::map<std::string, std::string> request_headers;
/// @brief Determines if the direction is for input.
/// @param dir The direction of the communication.
/// @return True if the direction is for input, false otherwise.

View File

@@ -43,6 +43,14 @@ copyFile(
mode_t permission = (S_IWUSR | S_IRUSR)
);
bool
createFileWithContent(
const std::string &dest,
const std::string &content,
bool overide_if_exists,
mode_t permission = (S_IWUSR | S_IRUSR)
);
bool deleteFile(const std::string &path);
std::string convertToHumanReadable(uint64_t size_in_bytes);
std::string getFileName(const std::string &path);
@@ -86,6 +94,7 @@ std::string removeTrailingWhitespaces(std::string str);
std::string removeLeadingWhitespaces(std::string str);
std::string trim(std::string str);
std::string toLower(std::string str);
bool startsWith(const std::string &str, const std::string &prefix);
} // namespace Strings

View File

@@ -17,9 +17,11 @@
#include "cache.h"
#include "config.h"
#include "i_environment.h"
#include "intelligence_invalidation.h"
#include "intelligence_is_v2/intelligence_response.h"
#include "intelligence_request.h"
#include "intell_registration_event.h"
using namespace std;
using namespace chrono;
@@ -37,6 +39,8 @@ static const string queries_uri = "/api/v2/intelligence/assets/queries";
static const string fog_health_uri = "/access-manager/health/live";
static const string intelligence_health_uri = "/show-health";
static const string time_range_invalidation_uri = "/api/v2/intelligence/invalidation/get";
static const uint default_registration_interval_seconds = 720; // 12 minutes
static const uint min_registration_interval_seconds = 30;
class I_InvalidationCallBack
{
@@ -100,7 +104,7 @@ public:
res << "\"name\": \"" << (agent_id.empty() ? details->getAgentId() : agent_id) << "\", ";
auto rest = Singleton::Consume<I_RestApi>::by<IntelligenceComponentV2>();
res << "\"url\": \"http://127.0.0.1:" << rest->getListeningPort() <<"/set-new-invalidation\", ";
res << "\"capabilities\": { \"getBulkCallback\": " << "true" << " }, ";
res << "\"capabilities\": { \"getBulkCallback\": true, \"returnRegistrationTTL\": true }, ";
res << "\"dataMap\": [";
res << stream.str();
res << " ] }";
@@ -200,11 +204,17 @@ private:
class SingleReceivedInvalidation : public ServerRest
{
public:
void
doCall() override
{
Invalidation invalidation(class_name);
for (const auto& header : request_headers) {
dbgTrace(D_INTELLIGENCE) << "Adding header: " << header.first << " = " << header.second;
invalidation.addHeader(header.first, header.second);
}
if (category.isActive()) invalidation.setClassifier(ClassifierType::CATEGORY, category.get());
if (family.isActive()) invalidation.setClassifier(ClassifierType::FAMILY, family.get());
if (group.isActive()) invalidation.setClassifier(ClassifierType::GROUP, group.get());
@@ -268,10 +278,10 @@ private:
C2S_OPTIONAL_PARAM(string, invalidationType);
};
class ReceiveInvalidation : public ServerRest
{
public:
bool wantsHeaders() const override { return true; }
void
doCall() override
@@ -282,6 +292,8 @@ public:
: "error in format, expected bulk invalidations, not single");
for (SingleReceivedInvalidation &r : bulkArray.get()) {
// Copy headers from the bulk request to each individual invalidation
r.setRequestHeaders(request_headers);
r.doCall();
}
return;
@@ -360,7 +372,7 @@ public:
mainloop->addRecurringRoutine(
I_MainLoop::RoutineType::System,
chrono::minutes(12),
chrono::seconds(getRegistrationIntervalSec()),
[this] () { sendRecurringInvalidationRegistration(); },
"Sending intelligence invalidation"
);
@@ -467,6 +479,27 @@ public:
}
private:
uint
getRegistrationIntervalSec() const
{
uint interval_in_seconds = getConfigurationWithDefault(
default_registration_interval_seconds,
"intelligence",
"registration interval seconds"
);
if (interval_in_seconds < min_registration_interval_seconds) {
dbgWarning(D_INTELLIGENCE)
<< "Registration interval is too low, setting to minimum: "
<< min_registration_interval_seconds;
interval_in_seconds = min_registration_interval_seconds;
}
dbgInfo(D_INTELLIGENCE)
<< "Using registration interval: "
<< interval_in_seconds
<< " seconds";
return interval_in_seconds;
}
bool
hasLocalIntelligenceSupport() const
{
@@ -585,6 +618,7 @@ private:
sendIntelligenceRequestImpl(const Invalidation &invalidation, const MessageMetadata &local_req_md) const
{
dbgFlow(D_INTELLIGENCE) << "Sending intelligence invalidation";
auto res = message->sendSyncMessageWithoutResponse(
HTTPMethod::POST,
invalidation_uri,
@@ -634,15 +668,29 @@ private:
) const
{
dbgFlow(D_INTELLIGENCE) << "Sending intelligence invalidation registration";
auto res = message->sendSyncMessageWithoutResponse(
Maybe<string> registration_body = registration.genJson();
if (!registration_body.ok()) {
return genError("Could not generate intelligence invalidation registration body. Error: "
+ registration_body.getErr());
}
auto res = message->sendSyncMessage(
HTTPMethod::POST,
registration_uri,
registration,
registration_body.unpack(),
MessageCategory::INTELLIGENCE,
registration_req_md
);
if (res) return Response();
dbgWarning(D_INTELLIGENCE) << "Could not send intelligence invalidation registration.";
if (res.ok()){
string registration_response = res.unpack().getBody();
dbgInfo(D_INTELLIGENCE)
<< "Intelligence invalidation registration sent successfully";
IntelligenceRegistrationEvent(true, registration_response).notify();
return Response();
}
IntelligenceRegistrationEvent(false).notify();
dbgWarning(D_INTELLIGENCE) << "Could not send intelligence invalidation registration. Error: "
<< res.getErr().toString();
return genError("Could not send intelligence invalidation registration");
}
@@ -719,6 +767,13 @@ private:
auto rest = Singleton::Consume<I_RestApi>::by<IntelligenceComponentV2>();
auto agent = (agent_id.empty() ? details->getAgentId() : agent_id) + ":" + to_string(rest->getListeningPort());
headers["X-Source-Id"] = agent;
auto env = Singleton::Consume<I_Environment>::by<IntelligenceComponentV2>();
auto exec_name = env->get<string>("Base Executable Name");
if (exec_name.ok() && *exec_name != "") {
headers["X-Calling-Service"] = *exec_name;
} else {
dbgTrace(D_INTELLIGENCE) << "getHTTPHeaders: X-Calling-Service NOT added - exec_name not available";
}
return headers;
}
@@ -762,6 +817,7 @@ IntelligenceComponentV2::preload()
{
registerExpectedConfiguration<uint>("intelligence", "maximum request overall time");
registerExpectedConfiguration<uint>("intelligence", "maximum request lap time");
registerExpectedConfiguration<uint>("intelligence", "registration interval seconds");
registerExpectedConfiguration<bool>("intelligence", "support Invalidation");
registerExpectedSetting<string>("intelligence", "local intelligence server ip");
registerExpectedSetting<uint>("intelligence", primary_port_setting);

View File

@@ -344,6 +344,7 @@ public:
ON_CALL(mock_details, getFogDomain()).WillByDefault(Return(Maybe<string>(string("fog_domain.com"))));
ON_CALL(mock_details, getFogPort()).WillByDefault(Return(Maybe<uint16_t>(443)));
env.preload();
conf.preload();
intelligence.preload();
intelligence.init();
@@ -389,6 +390,8 @@ TEST_F(IntelligenceInvalidation, sending_incomplete_invalidation)
TEST_F(IntelligenceInvalidation, sending_public_invalidation)
{
Singleton::Consume<I_Environment>::from(env)->registerValue<string>("Base Executable Name", "idn");
auto invalidation = Invalidation("aaa")
.addMainAttr(main_attr)
.addAttr(attr)
@@ -422,10 +425,16 @@ TEST_F(IntelligenceInvalidation, sending_public_invalidation)
" } ] }";
EXPECT_EQ(invalidation_json, expected_json);
EXPECT_FALSE(md.getConnectionFlags().isSet(MessageConnectionConfig::UNSECURE_CONN));
auto headers = md.getHeaders();
EXPECT_NE(headers.find("X-Calling-Service"), headers.end()) << "X-Calling-Service header should be present";
EXPECT_EQ(headers.at("X-Calling-Service"), "idn");
}
TEST_F(IntelligenceInvalidation, multiple_assets_invalidation)
{
Singleton::Consume<I_Environment>::from(env)->registerValue<string>("Base Executable Name", "orchestration");
auto main_attr_2 = StrAttributes()
.addStringAttr("attr2", "22")
.addStringSetAttr("attr3", {"33", "44"});
@@ -440,11 +449,13 @@ TEST_F(IntelligenceInvalidation, multiple_assets_invalidation)
.setObjectType(Intelligence::ObjectType::ASSET);
string invalidation_json;
MessageMetadata md;
EXPECT_CALL(
messaging_mock,
sendSyncMessage(HTTPMethod::POST, invalidation_uri, _, MessageCategory::INTELLIGENCE, _)
).WillOnce(DoAll(
SaveArg<2>(&invalidation_json),
SaveArg<4>(&md),
Return(HTTPResponse(HTTPStatusCode::HTTP_OK, ""))
));
@@ -461,10 +472,16 @@ TEST_F(IntelligenceInvalidation, multiple_assets_invalidation)
"\"attributes\": [ { \"ipv4Addresses\": [ \"1.1.1.1\" ] } ]"
" } ] }";
EXPECT_EQ(invalidation_json, expected_json);
auto headers = md.getHeaders();
EXPECT_NE(headers.find("X-Calling-Service"), headers.end());
EXPECT_EQ(headers.at("X-Calling-Service"), "orchestration");
}
TEST_F(IntelligenceInvalidation, sending_private_invalidation)
{
Singleton::Consume<I_Environment>::from(env)->registerValue<string>("Base Executable Name", "idn");
auto invalidation = Invalidation("aaa")
.addMainAttr(main_attr)
.addAttr(attr)
@@ -511,6 +528,10 @@ TEST_F(IntelligenceInvalidation, sending_private_invalidation)
" } ] }";
EXPECT_EQ(invalidation_json, expected_json);
EXPECT_TRUE(md.getConnectionFlags().isSet(MessageConnectionConfig::UNSECURE_CONN));
auto headers = md.getHeaders();
EXPECT_NE(headers.find("X-Calling-Service"), headers.end());
EXPECT_EQ(headers.at("X-Calling-Service"), "idn");
}
TEST_F(IntelligenceInvalidation, register_for_invalidation)
@@ -554,7 +575,7 @@ TEST_F(IntelligenceInvalidation, register_for_invalidation)
EXPECT_THAT(body, HasSubstr("\"mainAttributes\": [ { \"attr2\": \"2\" } ]"));
EXPECT_THAT(body, HasSubstr("\"attributes\": [ { \"ipv4Addresses\": [ \"1.1.1.1\" ] } ]"));
EXPECT_TRUE(md.getConnectionFlags().isSet(MessageConnectionConfig::UNSECURE_CONN));
EXPECT_THAT(body, HasSubstr("\"capabilities\": { \"getBulkCallback\": true }"));
EXPECT_THAT(body, HasSubstr("\"capabilities\": { \"getBulkCallback\": true, \"returnRegistrationTTL\": true }"));
}
@@ -599,7 +620,7 @@ TEST_F(IntelligenceInvalidation, register_for_multiple_assets_invalidation)
));
EXPECT_NE(i_intelligence->registerInvalidation(invalidation, callback), 0);
EXPECT_THAT(body, HasSubstr("\"capabilities\": { \"getBulkCallback\": true }"));
EXPECT_THAT(body, HasSubstr("\"capabilities\": { \"getBulkCallback\": true, \"returnRegistrationTTL\": true }"));
EXPECT_THAT(
body,

View File

@@ -274,44 +274,44 @@ Invalidation::serialize(cereal::JSONInputArchive &ar)
try {
ar(cereal::make_nvp("class", class_));
} catch (const cereal::Exception &e) {
dbgError(D_INTELLIGENCE) << e.what();
dbgWarning(D_INTELLIGENCE) << e.what();
}
try {
ar(cereal::make_nvp("category", category));
} catch (const cereal::Exception &e) {
dbgError(D_INTELLIGENCE) << e.what();
dbgTrace(D_INTELLIGENCE) << e.what();
}
try {
ar(cereal::make_nvp("family", family));
} catch (const cereal::Exception &e) {
dbgError(D_INTELLIGENCE) << e.what();
dbgTrace(D_INTELLIGENCE) << e.what();
}
try {
ar(cereal::make_nvp("group", group));
} catch (const cereal::Exception &e) {
dbgError(D_INTELLIGENCE) << e.what();
dbgTrace(D_INTELLIGENCE) << e.what();
}
try {
ar(cereal::make_nvp("order", order));
} catch (const cereal::Exception &e) {
dbgError(D_INTELLIGENCE) << e.what();
dbgTrace(D_INTELLIGENCE) << e.what();
}
try {
ar(cereal::make_nvp("kind", kind));
} catch (const cereal::Exception &e) {
dbgError(D_INTELLIGENCE) << e.what();
dbgTrace(D_INTELLIGENCE) << e.what();
}
try {
ar(cereal::make_nvp("mainAttributes", main_attributes));
ar(cereal::make_nvp("attributes", attributes));
} catch (const cereal::Exception &e) {
dbgError(D_INTELLIGENCE) << e.what();
dbgTrace(D_INTELLIGENCE) << e.what();
}
try {
@@ -323,21 +323,21 @@ Invalidation::serialize(cereal::JSONInputArchive &ar)
throw std::invalid_argument("Invalid string for ObjectType: " + object_type_);
}
} catch (const cereal::Exception &e) {
dbgError(D_INTELLIGENCE) << e.what();
dbgTrace(D_INTELLIGENCE) << e.what();
}
try {
ar(cereal::make_nvp("sourceId", source_id_));
source_id = source_id_;
} catch (const cereal::Exception &e) {
dbgError(D_INTELLIGENCE) << e.what();
dbgTrace(D_INTELLIGENCE) << e.what();
}
try {
ar(cereal::make_nvp("invalidationRegistrationId", registration_id_));
registration_id = registration_id_;
} catch (const cereal::Exception &e) {
dbgError(D_INTELLIGENCE) << e.what();
dbgTrace(D_INTELLIGENCE) << e.what();
}
try {
@@ -349,14 +349,14 @@ Invalidation::serialize(cereal::JSONInputArchive &ar)
throw std::invalid_argument("Invalid string for InvalidationType: " + invalidation_type_);
}
} catch (const cereal::Exception &e) {
dbgError(D_INTELLIGENCE) << e.what();
dbgWarning(D_INTELLIGENCE) << e.what();
}
try {
ar(cereal::make_nvp("listeningId", listening_id_));
listening_id = listening_id_;
} catch (const cereal::Exception &e) {
dbgError(D_INTELLIGENCE) << e.what();
dbgTrace(D_INTELLIGENCE) << e.what();
}
classifiers[ClassifierType::CLASS] = class_;
@@ -381,6 +381,35 @@ Invalidation::addMainAttr(const StrAttributes &attr)
return *this;
}
Invalidation &
Invalidation::addHeader(const string &name, const string &value)
{
headers[name] = value;
return *this;
}
Maybe<string>
Invalidation::getHeader(const string &name) const
{
auto it = headers.find(name);
if (it != headers.end()) {
return it->second;
}
return genError("Header not found: " + name);
}
const map<string, string> &
Invalidation::getHeaders() const
{
return headers;
}
bool
Invalidation::hasHeader(const string &name) const
{
return headers.find(name) != headers.end();
}
Maybe<string>
Invalidation::getRegistrationID() const{
return registration_id;
@@ -660,11 +689,26 @@ IpAttributes::serialize(cereal::JSONInputArchive &ar)
{
try {
ar(cereal::make_nvp("ipv4Addresses", ipv4_addresses));
} catch (cereal::Exception &e) {
dbgTrace(D_INTELLIGENCE) << e.what();
}
try {
ar(cereal::make_nvp("ipv4AddressesRange", ipv4_address_ranges));
} catch (cereal::Exception &e) {
dbgTrace(D_INTELLIGENCE) << e.what();
}
try {
ar(cereal::make_nvp("ipv6Addresses", ipv6_addresses));
} catch (cereal::Exception &e) {
dbgTrace(D_INTELLIGENCE) << e.what();
}
try {
ar(cereal::make_nvp("ipv6AddressesRange", ipv6_address_ranges));
} catch (cereal::Exception &e) {
dbgError(D_INTELLIGENCE) << e.what();
dbgTrace(D_INTELLIGENCE) << e.what();
}
}

View File

@@ -13,12 +13,14 @@
#include "mainloop.h"
#include <cstddef>
#include <memory>
#include <system_error>
#include <map>
#include <sstream>
#include <poll.h>
#include <unistd.h>
#include <functional>
#include "config.h"
#include "coroutine.h"
@@ -62,6 +64,16 @@ public:
bool is_primary
) override;
RoutineID
addBalancedIntervalRoutine(
RoutineType priority,
std::chrono::microseconds interval,
Routine func,
const std::string &routine_name,
chrono::microseconds offset,
bool is_primary
) override;
RoutineID
addFileRoutine(
RoutineType priority,
@@ -430,6 +442,48 @@ MainloopComponent::Impl::addRecurringRoutine(
return addOneTimeRoutine(priority, func_wrapper, routine_name, is_primary);
}
I_MainLoop::RoutineID
MainloopComponent::Impl::addBalancedIntervalRoutine(
RoutineType priority,
chrono::microseconds interval,
Routine func,
const string &routine_name,
chrono::microseconds offset,
bool is_primary
)
{
Routine func_wrapper = [this, interval, offset, func, routine_name]()
{
using namespace std::chrono;
I_TimeGet *timer = Singleton::Consume<I_TimeGet>::by<MainloopComponent>();
typedef duration<size_t, ratio<86400>> days;
static const microseconds one_day_in_microseconds = duration_cast<microseconds>(days(1));
while (true) {
microseconds now = timer->getWalltime();
size_t whole_days = now.count() / one_day_in_microseconds.count();
microseconds time_since_midnight = now - microseconds(whole_days * one_day_in_microseconds.count());
// Calculate next aligned execution time from midnight
size_t intervals_from_midnight = time_since_midnight / interval;
microseconds next_aligned_time = (intervals_from_midnight + 1) * interval;
// Calculate wait time until next execution
microseconds target_time = next_aligned_time + offset;
microseconds wait_time = target_time - time_since_midnight;
if (wait_time > interval) {
wait_time -= interval;
}
dbgTrace(D_MAINLOOP) << "Balanced interval routine waiting " << wait_time << " microseconds";
yield(wait_time);
func();
}
};
return addOneTimeRoutine(priority, func_wrapper, routine_name, is_primary);
}
I_MainLoop::RoutineID
MainloopComponent::Impl::addFileRoutine(
RoutineType priority,

View File

@@ -24,7 +24,7 @@ USE_DEBUG_FLAG(D_MAINLOOP);
class EndTest
{
};
typedef std::chrono::duration<size_t, std::ratio<86400>> days;
class MainloopTest : public Test
{
public:
@@ -74,6 +74,45 @@ public:
);
}
void
setupBalancedIntervalTest(
chrono::microseconds start_time,
chrono::microseconds interval,
const string &routine_name,
chrono::microseconds offset = chrono::microseconds(0),
chrono::microseconds time_advance = chrono::hours(24)
)
{
chrono::microseconds time = start_time;
EXPECT_CALL(mock_time, getWalltime()).WillRepeatedly(InvokeWithoutArgs([&]() { return time; }));
EXPECT_CALL(mock_time, getMonotonicTime())
.WillRepeatedly(InvokeWithoutArgs(
[&]()
{
auto old_time = time;
time += time_advance;
return old_time;
}
));
auto callback = [this]() { mainloop->stop(); };
mainloop->addBalancedIntervalRoutine(
I_MainLoop::RoutineType::RealTime, interval, callback, routine_name, offset, true
);
// Run the mainloop
mainloop->run();
}
void
expectWaitTimeInDebug(chrono::microseconds expected_wait_time)
{
string expected_debug =
"Balanced interval routine waiting " + to_string(expected_wait_time.count()) + " microseconds";
EXPECT_THAT(capture_debug.str(), HasSubstr(expected_debug));
}
I_Environment::ActiveContexts active_context;
NiceMock<MockTimeGet> mock_time;
@@ -532,3 +571,100 @@ TEST_F(MainloopTest, check_routine_name)
HasSubstr("Starting execution of corutine. Routine named: check routine name test")
);
}
TEST_F(MainloopTest, balanced_interval_empty_routine_name_hour_start)
{
Debug::setUnitTestFlag(D_MAINLOOP, Debug::DebugLevel::TRACE);
// Start on day 1 at 4:00 (at the start of the hour)
// Interval is 2 hours
// Because routine_name is empty, no shifting will be done so remaining time will be exactly equal to interval
const std::string routine_name = "";
const std::chrono::milliseconds time(days(1) + std::chrono::hours(4));
const std::chrono::minutes interval(std::chrono::hours(2));
setupBalancedIntervalTest(time, interval, routine_name);
expectWaitTimeInDebug(interval);
}
TEST_F(MainloopTest, balanced_interval_empty_routine_name_middle_of_the_hour)
{
Debug::setUnitTestFlag(D_MAINLOOP, Debug::DebugLevel::TRACE);
// Start 1 day and 5:54 (in the middle of the second hour)
// Interval is 2 hours
// Because routine_name is empty, no shifting will be done so remaining time will be 6 minutes
const std::string routine_name = "";
const std::chrono::milliseconds time(days(1) + std::chrono::hours(5) + std::chrono::minutes(54));
const std::chrono::minutes interval(std::chrono::hours(2));
setupBalancedIntervalTest(time, interval, routine_name);
expectWaitTimeInDebug(std::chrono::minutes(6));
}
TEST_F(MainloopTest, balanced_interval_non_empty_routine_name_hour_start)
{
Debug::setUnitTestFlag(D_MAINLOOP, Debug::DebugLevel::TRACE);
// Start on day 1 at 4:00 (at the start of the hour)
// Interval is 2 hours
// The routine_name is chosen so that it results in hashed slot #2 which shifts the remaining time by 2 * 10
// minutes So, the remaining time to wait would be original interval + 20 minutes
const std::string routine_name = "5e9dac5d204a8f35b264a932";
const std::chrono::milliseconds time(days(1) + std::chrono::hours(4));
const std::chrono::minutes interval(std::chrono::hours(2));
const std::chrono::microseconds offset = std::chrono::minutes(2*10);
setupBalancedIntervalTest(time, interval, routine_name, offset);
expectWaitTimeInDebug(offset);
}
TEST_F(MainloopTest, balanced_interval_non_empty_routine_name_hour_start_post_offset)
{
Debug::setUnitTestFlag(D_MAINLOOP, Debug::DebugLevel::TRACE);
// Start on day 1 at 4:40 (at the start of the hour)
// Interval is 2 hours
// The routine_name is chosen so that it results in hashed slot #2 which shifts the remaining time by 2 * 10
// minutes So, the remaining time to wait would be original interval - 20 minutes
const std::string routine_name = "5e9dac5d204a8f35b264a932";
const std::chrono::milliseconds time(days(1) + std::chrono::hours(4) + std::chrono::minutes(40));
const std::chrono::minutes interval(std::chrono::hours(2));
const std::chrono::microseconds offset = std::chrono::minutes(2*10);
setupBalancedIntervalTest(time, interval, routine_name, offset);
expectWaitTimeInDebug(interval - offset);
}
TEST_F(MainloopTest, balanced_interval_non_empty_routine_name_middle_of_the_hour)
{
Debug::setUnitTestFlag(D_MAINLOOP, Debug::DebugLevel::TRACE);
// Start 1 day and 5:54 (in the middle of the second hour)
// Interval is 2 hours
// The routine_name is chosen so that it results in hashed slot #2 which shifts the remaining time by 2 * 10
// minutes So, the remaining time to wait would be 6 minutes + 20 minutes shift
const std::string routine_name = "5e9dac5d204a8f35b264a932";
const std::chrono::milliseconds time(days(1) + std::chrono::hours(5) + std::chrono::minutes(54));
const std::chrono::minutes interval(std::chrono::hours(2));
const std::chrono::microseconds offset = std::chrono::minutes(2*10);
setupBalancedIntervalTest(time, interval, routine_name, offset);
expectWaitTimeInDebug(std::chrono::minutes(6) + offset);
}
TEST_F(MainloopTest, balanced_interval_another_asset_hour_start)
{
Debug::setUnitTestFlag(D_MAINLOOP, Debug::DebugLevel::TRACE);
// Start on day 1 at 4:00 (at the start of the hour)
// Interval is 2 hours
// The routine_name is chosen so that it results in hashed slot #7 which shifts the remaining time by 7 * 10
// minutes So, the remaining time to wait would be 70 minutes
const std::string routine_name = "5e9da89572f6f9af9bebc0da";
const std::chrono::milliseconds time(days(1) + std::chrono::hours(4));
const std::chrono::minutes interval(std::chrono::hours(2));
const std::chrono::microseconds offset = std::chrono::minutes(7*10);
setupBalancedIntervalTest(time, interval, routine_name, offset);
expectWaitTimeInDebug(offset);
}

View File

@@ -39,7 +39,7 @@ using namespace smartBIO;
USE_DEBUG_FLAG(D_CONNECTION);
static const HTTPResponse sending_timeout(HTTPStatusCode::HTTP_UNKNOWN, "Failed to send all data in time");
static const HTTPResponse receving_timeout(HTTPStatusCode::HTTP_UNKNOWN, "Failed to receive all data in time");
static const HTTPResponse receiving_timeout(HTTPStatusCode::HTTP_UNKNOWN, "Failed to receive all data in time");
static const HTTPResponse parsing_error(HTTPStatusCode::HTTP_UNKNOWN, "Failed to parse the HTTP response");
static const HTTPResponse close_error(
HTTPStatusCode::HTTP_UNKNOWN,
@@ -271,18 +271,11 @@ private:
return *details_ssl_dir;
}
// Use detail_resolver to determine platform-specific certificate directory
#if defined(alpine)
string platform = "alpine";
return "/etc/ssl/certs/";
#else
string platform = "linux";
#endif
if (platform == "alpine") {
return "/etc/ssl/certs/";
}
return "/usr/lib/ssl/certs/";
#endif
}
Maybe<void>
@@ -741,20 +734,54 @@ private:
}
}
auto receiving_end_time = i_time->getMonotonicTime() + getConnectionTimeout();
auto base_timeout_config = getProfileAgentSettingWithDefault<uint>(
10,
"agent.config.message.chunk.connection.timeout"
);
auto base_timeout = chrono::seconds(base_timeout_config); // 10 seconds between data chunks
auto global_timeout_config = getProfileAgentSettingWithDefault<uint>(
600,
"agent.config.message.global.connection.timeout"
);
auto global_timeout = chrono::seconds(global_timeout_config); // 600 seconds maximum for entire download
auto receiving_end_time = i_time->getMonotonicTime() + base_timeout;
auto global_end_time = i_time->getMonotonicTime() + global_timeout;
HTTPResponseParser http_parser;
dbgTrace(D_CONNECTION) << "Sent the message, now waiting for response";
dbgTrace(D_CONNECTION)
<< "Sent the message, now waiting for response (global timeout: "
<< global_timeout.count()
<< " seconds)";
while (!http_parser.hasReachedError()) {
// Check global timeout first
if (i_time->getMonotonicTime() > global_end_time) {
should_close_connection = true;
dbgWarning(D_CONNECTION)
<< "Global receive timeout reached after "
<< global_timeout.count() << " seconds";
return genError(receiving_timeout);
}
// Check per-chunk timeout
if (i_time->getMonotonicTime() > receiving_end_time) {
should_close_connection = true;
return genError(receving_timeout);
};
dbgWarning(D_CONNECTION) << "No data received for " << base_timeout.count() << " seconds";
return genError(receiving_timeout);
}
auto receieved = receiveData();
if (!receieved.ok()) {
should_close_connection = true;
return receieved.passErr();
}
// Reset timeout each time we receive data
if (!receieved.unpack().empty()) {
receiving_end_time = i_time->getMonotonicTime() + base_timeout;
}
auto response = http_parser.parseData(*receieved, is_connect);
i_mainloop->yield(receieved.unpack().empty());
if (response.ok()) {
dbgTrace(D_MESSAGING) << printOut(response.unpack().toString());

View File

@@ -48,6 +48,13 @@ public:
return establishNewConnection(metadata, category);
}
void
clearConnections() override
{
dbgTrace(D_CONNECTION) << "Clearing all persistent connections";
persistent_connections.clear();
}
Maybe<Connection>
getPersistentConnection(const string &host_name, uint16_t port, MessageCategory category) override
{

View File

@@ -24,6 +24,7 @@
#include "rest.h"
#include "rest_server.h"
#include "dummy_socket.h"
#include <atomic>
using namespace std;
using namespace testing;
@@ -100,6 +101,11 @@ TEST_F(TestConnectionComp, testSetAndGetConnection)
EXPECT_EQ(get_conn.getConnKey().getHostName(), "127.0.0.1");
EXPECT_EQ(get_conn.getConnKey().getPort(), 8080);
EXPECT_EQ(get_conn.getConnKey().getCategory(), MessageCategory::LOG);
i_conn->clearConnections();
maybe_get_connection = i_conn->getPersistentConnection("127.0.0.1", 8080, MessageCategory::LOG);
ASSERT_FALSE(maybe_get_connection.ok());
}
TEST_F(TestConnectionComp, testEstablishNewConnection)
@@ -279,19 +285,27 @@ TEST_F(TestConnectionComp, testSendRequestWithOneTimeFogConnection)
auto req = HTTPRequest::prepareRequest(conn, HTTPMethod::POST, "/test", conn_metadata.getHeaders(), "test-body");
ASSERT_TRUE(req.ok());
// Ensure we accept+respond exactly once regardless of yield overload order
std::atomic<bool> responded{false};
EXPECT_CALL(mock_mainloop, yield(A<std::chrono::microseconds>()))
.WillOnce(
InvokeWithoutArgs(
[&]() {
cerr << "accepting socket" << endl;
dummy_socket.acceptSocket();
dummy_socket.writeToSocket("HTTP/1.1 200 OK\r\nContent-Length: 7\r\n\r\nmy-test");
}
)
).WillRepeatedly(Return());
.WillRepeatedly(InvokeWithoutArgs([&]() {
if (!responded.exchange(true)) {
cerr << "accepting socket" << endl;
dummy_socket.acceptSocket();
dummy_socket.writeToSocket("HTTP/1.1 200 OK\r\nContent-Length: 7\r\n\r\nmy-test");
}
}));
EXPECT_CALL(mock_mainloop, yield(A<bool>()))
.WillRepeatedly(InvokeWithoutArgs([&]() {
if (!responded.exchange(true)) {
cerr << "accepting socket while receiving" << endl;
dummy_socket.acceptSocket();
dummy_socket.writeToSocket("HTTP/1.1 200 OK\r\nContent-Length: 7\r\n\r\nmy-test");
}
}));
EXPECT_CALL(mock_timer, getMonotonicTime())
.WillRepeatedly(Invoke([]() { static int j = 0; return chrono::microseconds(++j * 10); }));
.WillRepeatedly(Invoke([]() { static int j = 0; return chrono::microseconds(++j * 1000 * 1000); }));
auto maybe_response = i_conn->sendRequest(conn, *req);
if (!maybe_response.ok()) {

View File

@@ -27,6 +27,7 @@ class I_MessagingConnection
{
public:
virtual Maybe<Connection> establishConnection(const MessageMetadata &metadata, MessageCategory category) = 0;
virtual void clearConnections() = 0;
virtual Maybe<Connection> getPersistentConnection(
const std::string &host_name, uint16_t port, MessageCategory category

View File

@@ -71,6 +71,7 @@ public:
bool setFogConnection(const std::string &host, uint16_t port, bool is_secure, MessageCategory category);
bool setFogConnection(MessageCategory category);
void clearConnections();
private:
Maybe<Connection> getConnection(MessageCategory category, const MessageMetadata &message_metadata);
@@ -96,6 +97,7 @@ private:
I_MessageBuffer *i_messaging_buffer;
I_AgentDetails *agent_details;
bool should_buffer_failed_messages;
std::string proxy_addr;
TemporaryCache<std::string, HTTPResponse> fog_get_requests_cache;
};

View File

@@ -29,6 +29,7 @@ public:
MOCK_METHOD3(mockSendRequest, Maybe<HTTPResponse, HTTPResponse>(Connection &, HTTPRequest, bool));
MOCK_METHOD0(clearConnections, void());
MOCK_METHOD3(getPersistentConnection, Maybe<Connection>(const string &, uint16_t, MessageCategory));
MOCK_METHOD1(getFogConnectionByCategory, Maybe<Connection>(MessageCategory));
};

View File

@@ -97,6 +97,12 @@ public:
return messaging_comp.setFogConnection(category);
}
void
clearConnections() override
{
messaging_comp.clearConnections();
}
private:
MessagingComp messaging_comp;
ConnectionComponent connection_comp;
@@ -119,7 +125,7 @@ void
Messaging::preload()
{
registerExpectedConfiguration<int>("message", "Cache timeout");
registerExpectedConfiguration<uint>("message", "Connection timeout");
registerExpectedConfigurationWithCache<uint>("assetId", "message", "Connection timeout");
registerExpectedConfiguration<uint>("message", "Connection handshake timeout");
registerExpectedConfiguration<bool>("message", "Verify SSL pinning");
registerExpectedConfiguration<bool>("message", "Buffer Failed Requests");

View File

@@ -204,12 +204,15 @@ MessagingBufferComponent::Impl::pushNewBufferedMessage(
Maybe<BufferedMessage>
MessagingBufferComponent::Impl::peekMessage()
{
auto move_cmd =
"if [ -s " + buffer_input + " ] && [ ! -s " + buffer_output + " ];"
"then mv " + buffer_input + " " + buffer_output + ";"
"fi";
shell_cmd->getExecOutput(move_cmd);
// Native replacement for shell mv command
struct stat stat_input, stat_output;
bool input_exists = (stat(buffer_input.c_str(), &stat_input) == 0 && stat_input.st_size > 0);
bool output_exists = (stat(buffer_output.c_str(), &stat_output) == 0 && stat_output.st_size > 0);
if (input_exists && !output_exists) {
if (rename(buffer_input.c_str(), buffer_output.c_str()) != 0) {
dbgWarning(D_MESSAGING_BUFFER) << "Failed to move buffer input to output: " << strerror(errno);
}
}
if (!checkExistence(buffer_output)) return genError(buffer_output + " does not exist");

View File

@@ -72,6 +72,11 @@ MessagingComp::init()
auto i_time_get = Singleton::Consume<I_TimeGet>::by<Messaging>();
auto cache_timeout = getConfigurationWithDefault<int>(40, "message", "Cache timeout");
fog_get_requests_cache.startExpiration(chrono::seconds(cache_timeout), i_mainloop, i_time_get);
proxy_addr = getConfigurationWithDefault<string>(
getProfileAgentSettingWithDefault<string>("", "proxy.address"),
"message",
"Proxy Address"
);
should_buffer_failed_messages = getConfigurationWithDefault<bool>(
getProfileAgentSettingWithDefault<bool>(true, "eventBuffer.bufferFailedRequests"),
@@ -125,7 +130,7 @@ MessagingComp::sendMessage(
dbgWarning(D_MESSAGING) << "Failed to get connection. Error: " << maybe_conn.getErr();
return genError<HTTPResponse>(HTTPStatusCode::HTTP_UNKNOWN, maybe_conn.getErr());
}
Connection conn = maybe_conn.unpack();
if (message_metadata.shouldSuspend() && conn.isSuspended()) {
return suspendMessage(body, method, uri, category, message_metadata);
@@ -133,12 +138,11 @@ MessagingComp::sendMessage(
bool is_to_fog = isMessageToFog(message_metadata);
auto metadata = message_metadata;
if (is_to_fog) {
if (method == HTTPMethod::GET && fog_get_requests_cache.doesKeyExists(uri)) {
HTTPResponse res = fog_get_requests_cache.getEntry(uri);
dbgTrace(D_MESSAGING) << "Response returned from Fog cache. res body: " << res.getBody();
return fog_get_requests_cache.getEntry(uri);
}
@@ -197,7 +201,6 @@ MessagingComp::sendSyncMessage(
)
{
Maybe<HTTPResponse, HTTPResponse> is_msg_send = sendMessage(method, uri, body, category, message_metadata);
if (is_msg_send.ok()) return *is_msg_send;
if (should_buffer_failed_messages && message_metadata.shouldBufferMessage()) {
@@ -412,3 +415,10 @@ MessagingComp::suspendMessage(
HTTPStatusCode::HTTP_SUSPEND, "The connection is suspended due to consecutive message sending errors."
);
}
void
MessagingComp::clearConnections()
{
dbgTrace(D_MESSAGING) << "Clearing all connections (called from AgentDetails)";
i_conn->clearConnections();
}

View File

@@ -62,6 +62,7 @@ public:
{
EXPECT_CALL(mock_agent_details, getFogDomain()).WillRepeatedly(Return(string(fog_addr)));
EXPECT_CALL(mock_agent_details, getFogPort()).WillRepeatedly(Return(fog_port));
EXPECT_CALL(mock_agent_details, getProxy()).WillRepeatedly(Return(string("")));
EXPECT_CALL(mock_agent_details, getOpenSSLDir()).WillRepeatedly(Return(string("/usr/lib/ssl/certs/")));
EXPECT_CALL(mock_agent_details, getAccessToken()).WillRepeatedly(Return(string("accesstoken")));
EXPECT_CALL(mock_agent_details, readAgentDetails()).WillRepeatedly(Return(true));
@@ -262,6 +263,28 @@ operator==(const MessageMetadata &one, const MessageMetadata &two)
one.isDualAuth() == two.isDualAuth();
}
TEST_F(TestMessagingComp, testClearConnections)
{
setAgentDetails();
MessageCategory category = MessageCategory::GENERIC;
MessageConnectionKey conn_key(fog_addr, fog_port, category);
MessageMetadata metadata(fog_addr, fog_port, true);
MessageProxySettings proxy_settings("7.7.7.7", "cred", 8080);
metadata.setProxySettings(proxy_settings);
Connection conn(conn_key, metadata);
EXPECT_CALL(mock_messaging_connection, establishConnection(metadata, category)).WillOnce(Return(conn));
EXPECT_TRUE(messaging_comp.setFogConnection(category));
EXPECT_CALL(mock_messaging_connection, clearConnections()).Times(1);
messaging_comp.clearConnections();
EXPECT_CALL(mock_messaging_connection, establishConnection(metadata, category)).WillOnce(Return(conn));
EXPECT_TRUE(messaging_comp.setFogConnection(category));
}
TEST_F(TestMessagingComp, testSetFogConnection)
{
setAgentDetails();

View File

@@ -33,16 +33,35 @@ MetricMetadata::Description operator"" _desc(const char *str, size_t) { return M
static const set<string> default_metrics = {
"watchdogProcessStartupEventsSum",
"reservedNgenA",
"reservedNgenB",
"reservedNgenC"
"reservedNgenD"
"reservedNgenE",
"reservedNgenF",
"reservedNgenG"
"reservedNgenH",
"reservedNgenI",
"reservedNgenJ",
"reservedNgenA_WAAP telemetry",
"reservedNgenB_WAAP telemetry",
"reservedNgenC_WAAP telemetry",
"reservedNgenD_WAAP telemetry",
"reservedNgenE_WAAP telemetry",
"reservedNgenF_WAAP telemetry",
"reservedNgenG_WAAP telemetry",
"reservedNgenH_WAAP telemetry",
"reservedNgenI_WAAP telemetry",
"reservedNgenJ_WAAP telemetry",
"reservedNgenA_WAAP traffic telemetry",
"reservedNgenB_WAAP traffic telemetry",
"reservedNgenC_WAAP traffic telemetry",
"reservedNgenD_WAAP traffic telemetry",
"reservedNgenE_WAAP traffic telemetry",
"reservedNgenF_WAAP traffic telemetry",
"reservedNgenG_WAAP traffic telemetry",
"reservedNgenH_WAAP traffic telemetry",
"reservedNgenI_WAAP traffic telemetry",
"reservedNgenJ_WAAP traffic telemetry",
"reservedNgenA_WAAP attack type telemetry",
"reservedNgenB_WAAP attack type telemetry",
"reservedNgenC_WAAP attack type telemetry",
"reservedNgenD_WAAP attack type telemetry",
"reservedNgenE_WAAP attack type telemetry",
"reservedNgenF_WAAP attack type telemetry",
"reservedNgenG_WAAP attack type telemetry",
"reservedNgenH_WAAP attack type telemetry",
"reservedNgenI_WAAP attack type telemetry",
"numberOfProtectedAssetsSample",
"preventEngineMatchesSample",
"detectEngineMatchesSample",
@@ -115,6 +134,7 @@ MetricCalc::getPrometheusMetrics(const std::string &metric_name, const string &a
PrometheusData res;
res.name = getMetricDotName() != "" ? getMetricDotName() : getMetricName();
res.unique_name = res.name + "_" + metric_name;
res.type = getMetricType() == MetricType::GAUGE ? "gauge" : "counter";
res.description = getMetircDescription();

View File

@@ -575,6 +575,7 @@ TEST_F(MetricTest, getPromeathusMetric)
" \"metrics\": [\n"
" {\n"
" \"metric_name\": \"cpuMax\",\n"
" \"unique_name\": \"cpuMax_CPU usage\",\n"
" \"metric_type\": \"gauge\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -583,6 +584,7 @@ TEST_F(MetricTest, getPromeathusMetric)
" },\n"
" {\n"
" \"metric_name\": \"cpuMin\",\n"
" \"unique_name\": \"cpuMin_CPU usage\",\n"
" \"metric_type\": \"gauge\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -591,6 +593,7 @@ TEST_F(MetricTest, getPromeathusMetric)
" },\n"
" {\n"
" \"metric_name\": \"cpuAvg\",\n"
" \"unique_name\": \"cpuAvg_CPU usage\",\n"
" \"metric_type\": \"gauge\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -599,6 +602,7 @@ TEST_F(MetricTest, getPromeathusMetric)
" },\n"
" {\n"
" \"metric_name\": \"cpuCurrent\",\n"
" \"unique_name\": \"cpuCurrent_CPU usage\",\n"
" \"metric_type\": \"gauge\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -607,6 +611,7 @@ TEST_F(MetricTest, getPromeathusMetric)
" },\n"
" {\n"
" \"metric_name\": \"cpuCounter\",\n"
" \"unique_name\": \"cpuCounter_CPU usage\",\n"
" \"metric_type\": \"gauge\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -615,6 +620,7 @@ TEST_F(MetricTest, getPromeathusMetric)
" },\n"
" {\n"
" \"metric_name\": \"cpuTotalCounter\",\n"
" \"unique_name\": \"cpuTotalCounter_CPU usage\",\n"
" \"metric_type\": \"counter\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -665,6 +671,7 @@ TEST_F(MetricTest, getPromeathusMultiMap)
" \"metrics\": [\n"
" {\n"
" \"metric_name\": \"request.total\",\n"
" \"unique_name\": \"GET_Bytes per URL\",\n"
" \"metric_type\": \"counter\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -673,6 +680,7 @@ TEST_F(MetricTest, getPromeathusMultiMap)
" },\n"
" {\n"
" \"metric_name\": \"request.total\",\n"
" \"unique_name\": \"POST_Bytes per URL\",\n"
" \"metric_type\": \"counter\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -681,6 +689,7 @@ TEST_F(MetricTest, getPromeathusMultiMap)
" },\n"
" {\n"
" \"metric_name\": \"request.total\",\n"
" \"unique_name\": \"GET_Bytes per URL\",\n"
" \"metric_type\": \"counter\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -750,6 +759,7 @@ TEST_F(MetricTest, getPromeathusTwoMetrics)
" \"metrics\": [\n"
" {\n"
" \"metric_name\": \"request.total\",\n"
" \"unique_name\": \"GET_Bytes per URL\",\n"
" \"metric_type\": \"counter\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -758,6 +768,7 @@ TEST_F(MetricTest, getPromeathusTwoMetrics)
" },\n"
" {\n"
" \"metric_name\": \"request.total\",\n"
" \"unique_name\": \"POST_Bytes per URL\",\n"
" \"metric_type\": \"counter\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -766,6 +777,7 @@ TEST_F(MetricTest, getPromeathusTwoMetrics)
" },\n"
" {\n"
" \"metric_name\": \"request.total\",\n"
" \"unique_name\": \"GET_Bytes per URL\",\n"
" \"metric_type\": \"counter\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -774,6 +786,7 @@ TEST_F(MetricTest, getPromeathusTwoMetrics)
" },\n"
" {\n"
" \"metric_name\": \"cpuMax\",\n"
" \"unique_name\": \"cpuMax_CPU usage\",\n"
" \"metric_type\": \"gauge\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -782,6 +795,7 @@ TEST_F(MetricTest, getPromeathusTwoMetrics)
" },\n"
" {\n"
" \"metric_name\": \"cpuMin\",\n"
" \"unique_name\": \"cpuMin_CPU usage\",\n"
" \"metric_type\": \"gauge\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -790,6 +804,7 @@ TEST_F(MetricTest, getPromeathusTwoMetrics)
" },\n"
" {\n"
" \"metric_name\": \"cpuAvg\",\n"
" \"unique_name\": \"cpuAvg_CPU usage\",\n"
" \"metric_type\": \"gauge\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -798,6 +813,7 @@ TEST_F(MetricTest, getPromeathusTwoMetrics)
" },\n"
" {\n"
" \"metric_name\": \"cpuCurrent\",\n"
" \"unique_name\": \"cpuCurrent_CPU usage\",\n"
" \"metric_type\": \"gauge\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -806,6 +822,7 @@ TEST_F(MetricTest, getPromeathusTwoMetrics)
" },\n"
" {\n"
" \"metric_name\": \"cpuCounter\",\n"
" \"unique_name\": \"cpuCounter_CPU usage\",\n"
" \"metric_type\": \"gauge\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","
@@ -814,6 +831,7 @@ TEST_F(MetricTest, getPromeathusTwoMetrics)
" },\n"
" {\n"
" \"metric_name\": \"cpuTotalCounter\",\n"
" \"unique_name\": \"cpuTotalCounter_CPU usage\",\n"
" \"metric_type\": \"counter\",\n"
" \"metric_description\": \"\",\n"
" \"labels\": \"{agent=\\\"Unknown\\\",assetId=\\\"asset id\\\",id=\\\"87\\\","

View File

@@ -110,13 +110,16 @@ TagAndEnumManagement::convertStringToTag(const string &tag)
{"Horizon Telemetry Metrics", ReportIS::Tags::HORIZON_TELEMETRY_METRICS},
{"Crowdsec", ReportIS::Tags::CROWDSEC},
{"apiDiscoveryCloudMessaging", ReportIS::Tags::API_DISCOVERY},
{"lbHealthStatusEngine", ReportIS::Tags::LB_HEALTH_STATUS},
{"Playground", ReportIS::Tags::PLAYGROUND},
{"Nginx Proxy Manager", ReportIS::Tags::NGINX_PROXY_MANAGER},
{"APISIX Server", ReportIS::Tags::WEB_SERVER_APISIX},
{"Docker Deployment", ReportIS::Tags::DEPLOYMENT_DOCKER},
{"SWAG Server", ReportIS::Tags::WEB_SERVER_SWAG},
{"NGINX Unified Server", ReportIS::Tags::WEB_SERVER_NGINX_UNIFIED},
{"AI Guard", ReportIS::Tags::AIGUARD}
{"AI Guard", ReportIS::Tags::AIGUARD},
{"Central NGINX Manager", ReportIS::Tags::CENTRAL_NGINX_MANAGER},
{"Browser Agent", ReportIS::Tags::BROWSER_AGENT}
};
auto report_is_tag = strings_to_tags.find(tag);
@@ -280,6 +283,8 @@ TagAndEnumManagement::convertToString(const IssuingEngine &issuing_engine)
case IssuingEngine::IDA_SAML_IDN_CLIENT_IP_NOTIFY: return "quantumIPNotifyIdn";
case IssuingEngine::API_DISCOVERY: return "apiDiscoveryCloudMessaging";
case IssuingEngine::HORIZON_TELEMETRY_METRICS: return "horizonTelemetryMetrics";
case IssuingEngine::LB_HEALTH_STATUS: return "lbHealthStatusEngine";
case IssuingEngine::BROWSER_AGENT: return "browserAgentEngine";
}
dbgAssertOpt(false) << alert << "Reached impossible engine value of: " << static_cast<int>(issuing_engine);
@@ -323,12 +328,15 @@ EnumArray<Tags, string> TagAndEnumManagement::tags_translation_arr {
"Crowdsec",
"Playground",
"apiDiscoveryCloudMessaging",
"lbHealthStatusEngine",
"Nginx Proxy Manager",
"APISIX Server",
"Docker Deployment",
"SWAG Server",
"NGINX Unified Server",
"AI Guard"
"AI Guard",
"Central NGINX Manager",
"Browser Agent"
};
EnumArray<AudienceTeam, string> TagAndEnumManagement::audience_team_translation {

View File

@@ -16,6 +16,7 @@
#include <string>
#include <istream>
#include <map>
#include "maybe_res.h"
@@ -23,11 +24,20 @@ class I_RestInvoke
{
public:
virtual Maybe<std::string> getSchema(const std::string &uri) const = 0;
virtual Maybe<std::string> invokeRest(const std::string &uri, std::istream &in) const = 0;
virtual Maybe<std::string> invokeRest(
const std::string &uri,
std::istream &in,
const std::map<std::string, std::string> &headers
) const = 0;
virtual bool isGetCall(const std::string &uri) const = 0;
virtual std::string invokeGet(const std::string &uri) const = 0;
virtual bool isPostCall(const std::string &uri) const = 0;
virtual Maybe<std::string> invokePost(const std::string &uri, const std::string &body) const = 0;
virtual bool shouldCaptureHeaders(const std::string &uri) const = 0;
protected:
~I_RestInvoke() {}
};

View File

@@ -29,8 +29,11 @@ RestHelper::reportError(std::string const &err)
}
Maybe<string>
ServerRest::performRestCall(istream &in)
ServerRest::performRestCall(istream &in, const map<string, string> &headers)
{
if (wantsHeaders()) {
request_headers = headers;
}
try {
try {
int firstChar = in.peek();

View File

@@ -75,17 +75,52 @@ RestConn::parseConn() const
dbgDebug(D_API) << "Call identifier: " << identifier;
uint len = 0;
map<string, string> headers;
bool should_capture_headers = invoke->shouldCaptureHeaders(identifier);
while (true) {
line = readLine();
if (line.size() < 3) break;
os.str(line);
string head, data;
os >> head >> data;
if (compareStringCaseInsensitive(head, "Content-Length:")) {
try {
len = stoi(data, nullptr);
} catch (...) {
if (should_capture_headers) {
size_t colon_pos = line.find(':');
if (colon_pos == string::npos) continue;
string head = line.substr(0, colon_pos);
string data = line.substr(colon_pos + 1);
size_t data_start = data.find_first_not_of(" \t\r\n");
if (data_start != string::npos) {
data = data.substr(data_start);
} else {
data = "";
}
size_t data_end = data.find_last_not_of(" \t\r\n");
if (data_end != string::npos) {
data = data.substr(0, data_end + 1);
}
if (!head.empty()) {
headers[head] = data;
dbgTrace(D_API) << "Captured header: " << head << " = " << data;
}
if (compareStringCaseInsensitive(head, "Content-Length")) {
try {
len = stoi(data, nullptr);
} catch (...) {
}
}
} else {
os.str(line);
string head, data;
os >> head >> data;
if (compareStringCaseInsensitive(head, "Content-Length:")) {
try {
len = stoi(data, nullptr);
} catch (...) {
}
}
}
}
@@ -113,7 +148,19 @@ RestConn::parseConn() const
dbgTrace(D_API) << "Message content: " << body.str();
Maybe<string> res = (method == "POST") ? invoke->invokeRest(identifier, body) : invoke->getSchema(identifier);
if (method == "POST" && invoke->isPostCall(identifier)) {
Maybe<string> result = invoke->invokePost(identifier, body.str());
if (!result.ok()) {
dbgWarning(D_API) << "Failed to invoke POST call: " << result.getErr();
sendResponse("500 Internal Server Error", result.getErr());
return;
}
sendResponse("200 OK", result.unpack());
return;
}
Maybe<string> res = (method == "POST") ?
invoke->invokeRest(identifier, body, headers) : invoke->getSchema(identifier);
if (res.ok()) {
sendResponse("200 OK", res.unpack());

View File

@@ -15,6 +15,7 @@
#define __REST_CONN_H__
#include <string>
#include <map>
#include "i_mainloop.h"
#include "i_rest_invoke.h"

View File

@@ -53,11 +53,16 @@ public:
bool addRestCall(RestAction oper, const string &uri, unique_ptr<RestInit> &&init) override;
bool addGetCall(const string &uri, const function<string()> &cb) override;
bool addWildcardGetCall(const string &uri, const function<string(const string &)> &callback);
bool addPostCall(const string &uri, const function<Maybe<string>(const string &)> &callback) override;
uint16_t getListeningPort() const override { return listening_port; }
uint16_t getStartingPortRange() const override { return starting_port_range; }
Maybe<string> getSchema(const string &uri) const override;
Maybe<string> invokeRest(const string &uri, istream &in) const override;
Maybe<string> invokeRest(const string &uri, istream &in, const map<string, string> &headers) const override;
bool isGetCall(const string &uri) const override;
string invokeGet(const string &uri) const override;
bool isPostCall(const string &uri) const override;
Maybe<string> invokePost(const string &uri, const string &body) const override;
bool shouldCaptureHeaders(const string &uri) const override;
private:
void prepareConfiguration();
@@ -71,7 +76,9 @@ private:
map<string, unique_ptr<RestInit>> rest_calls;
map<string, function<string()>> get_calls;
map<string, function<string(const string &)>> wildcard_get_calls;
map<string, function<Maybe<string>(const string &)>> post_calls;
uint16_t listening_port = 0;
uint16_t starting_port_range = 0;
vector<uint16_t> port_range;
};
@@ -241,6 +248,9 @@ RestServer::Impl::prepareConfiguration()
range_start = 0;
range_end = 1;
}
starting_port_range = range_start.unpack();
dbgInfo(D_API) << "Rest port range start: " << *range_start << ", end: " << *range_end;
// starting_port_range = *range_start;
port_range.resize(*range_end - *range_start);
for (uint16_t i = 0, port = *range_start; i < port_range.size(); i++, port++) {
port_range[i] = port;
@@ -379,6 +389,14 @@ RestServer::Impl::addWildcardGetCall(const string &uri, const function<string(co
return wildcard_get_calls.emplace(uri, callback).second;
}
bool
RestServer::Impl::addPostCall(const string &uri, const function<Maybe<string>(const string&)> &callback)
{
if (rest_calls.find(uri) != rest_calls.end()) return false;
if (get_calls.find(uri) != get_calls.end()) return false;
return post_calls.emplace(uri, callback).second;
}
Maybe<string>
RestServer::Impl::getSchema(const string &uri) const
{
@@ -392,12 +410,12 @@ RestServer::Impl::getSchema(const string &uri) const
}
Maybe<string>
RestServer::Impl::invokeRest(const string &uri, istream &in) const
RestServer::Impl::invokeRest(const string &uri, istream &in, const map<string, string> &headers) const
{
auto iter = rest_calls.find(uri);
if (iter == rest_calls.end()) return genError("No matching REST call was found");
auto instance = iter->second->getRest();
return instance->performRestCall(in);
return instance->performRestCall(in, headers);
}
bool
@@ -412,6 +430,13 @@ RestServer::Impl::isGetCall(const string &uri) const
return false;
}
bool
RestServer::Impl::isPostCall(const string &uri) const
{
if (post_calls.find(uri) != post_calls.end()) return true;
return false;
}
string
RestServer::Impl::invokeGet(const string &uri) const
{
@@ -425,6 +450,23 @@ RestServer::Impl::invokeGet(const string &uri) const
return "";
}
Maybe<string>
RestServer::Impl::invokePost(const string &uri, const string &body) const
{
auto instance = post_calls.find(uri);
if (instance != post_calls.end()) return instance->second(body);
return genError("No matching POST call was found for URI: " + uri);
}
bool
RestServer::Impl::shouldCaptureHeaders(const string &uri) const
{
auto iter = rest_calls.find(uri);
if (iter == rest_calls.end()) return false;
auto instance = iter->second->getRest();
return instance->wantsHeaders();
}
string
RestServer::Impl::changeActionToString(RestAction oper)
{

View File

@@ -414,3 +414,222 @@ TEST_F(RestConfigTest, not_loopback_flow)
"HTTP/1.1 500 Internal Server Error\r\nContent-Type: application/json\r\nContent-Length: 0\r\n\r\n"
);
}
TEST_F(RestConfigTest, getStartingPortRange)
{
// Use a configuration with port range instead of primary/alternative ports
string config_json_with_range =
"{\n"
" \"connection\": {\n"
" \"Nano service API Port Range start\": [\n"
" {\n"
" \"value\": 8000\n"
" }\n"
" ],\n"
" \"Nano service API Port Range end\": [\n"
" {\n"
" \"value\": 8010\n"
" }\n"
" ]\n"
" }\n"
"}\n";
istringstream ss(config_json_with_range);
Singleton::Consume<Config::I_Config>::from(config)->loadConfiguration(ss);
rest_server.init();
auto i_rest = Singleton::Consume<I_RestApi>::from(rest_server);
EXPECT_EQ(i_rest->getStartingPortRange(), 8000);
auto mainloop = Singleton::Consume<I_MainLoop>::from(mainloop_comp);
I_MainLoop::Routine stop_routine = [mainloop] () { mainloop->stopAll(); };
mainloop->addOneTimeRoutine(
I_MainLoop::RoutineType::RealTime,
stop_routine,
"RestConfigTest-getStartingPortRange stop routine",
false
);
mainloop->run();
}
TEST_F(RestConfigTest, addPostCall)
{
rest_server.init();
time_proxy.init();
mainloop_comp.init();
auto i_rest = Singleton::Consume<I_RestApi>::from(rest_server);
// Test addPostCall
ASSERT_TRUE(i_rest->addPostCall("test-post", [](const string &body) {
return "Received: " + body;
}));
// Test that adding the same POST call twice fails
ASSERT_FALSE(i_rest->addPostCall("test-post", [](const string &) -> Maybe<string> {
return string("Different handler");
}));
// Test that adding POST call with existing GET call fails
ASSERT_TRUE(i_rest->addGetCall("test-get", []() { return "get response"; }));
ASSERT_FALSE(i_rest->addPostCall("test-get", [](const string &) -> Maybe<string> {
return string("post response");
}));
auto mainloop = Singleton::Consume<I_MainLoop>::from(mainloop_comp);
I_MainLoop::Routine stop_routine = [mainloop] () { mainloop->stopAll(); };
mainloop->addOneTimeRoutine(
I_MainLoop::RoutineType::RealTime,
stop_routine,
"RestConfigTest-addPostCall stop routine",
false
);
mainloop->run();
}
TEST_F(RestConfigTest, post_call_integration_test)
{
env.preload();
Singleton::Consume<I_Environment>::from(env)->registerValue<string>("Base Executable Name", "tmp_test_file");
config.preload();
config.init();
rest_server.init();
time_proxy.init();
mainloop_comp.init();
auto i_rest = Singleton::Consume<I_RestApi>::from(rest_server);
// Add a POST endpoint that echoes back the request body with prefix
ASSERT_TRUE(i_rest->addPostCall("echo", [](const string &body) -> Maybe<string> {
return string("Echo: ") + body;
}));
int file_descriptor = socket(AF_INET, SOCK_STREAM, 0);
EXPECT_NE(file_descriptor, -1);
auto primary_port = getConfiguration<uint>("connection", "Nano service API Port Alternative");
struct sockaddr_in sa;
sa.sin_family = AF_INET;
sa.sin_port = htons(primary_port.unpack());
sa.sin_addr.s_addr = inet_addr("127.0.0.1");
int socket_enable = 1;
EXPECT_EQ(setsockopt(file_descriptor, SOL_SOCKET, SO_REUSEADDR, &socket_enable, sizeof(int)), 0);
EXPECT_CALL(messaging, sendSyncMessage(_, _, _, _, _))
.WillRepeatedly(Return(HTTPResponse(HTTPStatusCode::HTTP_OK, "")));
auto mainloop = Singleton::Consume<I_MainLoop>::from(mainloop_comp);
I_MainLoop::Routine stop_routine = [&] () {
EXPECT_EQ(connect(file_descriptor, (struct sockaddr*)&sa, sizeof(struct sockaddr)), 0)
<< "file_descriptor Error: " << strerror(errno);
string test_body = "Hello World";
string msg = "POST /echo HTTP/1.1\r\nContent-Length: " +
to_string(test_body.length())
+ "\r\n\r\n" + test_body;
EXPECT_EQ(write(file_descriptor, msg.data(), msg.size()), static_cast<int>(msg.size()));
struct pollfd s_poll;
s_poll.fd = file_descriptor;
s_poll.events = POLLIN;
s_poll.revents = 0;
while(poll(&s_poll, 1, 0) <= 0) {
mainloop->yield(true);
}
mainloop->stopAll();
};
mainloop->addOneTimeRoutine(
I_MainLoop::RoutineType::RealTime,
stop_routine,
"RestConfigTest-post_call_integration_test stop routine",
true
);
mainloop->run();
char response[1000];
int bytes_read = read(file_descriptor, response, 1000);
EXPECT_GT(bytes_read, 0);
string response_str(response, bytes_read);
EXPECT_THAT(response_str, HasSubstr("HTTP/1.1 200 OK"));
EXPECT_THAT(response_str, HasSubstr("Echo: Hello World"));
close(file_descriptor);
}
TEST_F(RestConfigTest, post_call_generic_error_test)
{
env.preload();
Singleton::Consume<I_Environment>::from(env)->registerValue<string>("Base Executable Name", "tmp_test_file");
config.preload();
config.init();
rest_server.init();
time_proxy.init();
mainloop_comp.init();
auto i_rest = Singleton::Consume<I_RestApi>::from(rest_server);
// Add a POST endpoint that returns a generic error
ASSERT_TRUE(i_rest->addPostCall("error-test", [](const string &) -> Maybe<string> {
return genError("Test error message");
}));
int file_descriptor = socket(AF_INET, SOCK_STREAM, 0);
EXPECT_NE(file_descriptor, -1);
auto primary_port = getConfiguration<uint>("connection", "Nano service API Port Alternative");
struct sockaddr_in sa;
sa.sin_family = AF_INET;
sa.sin_port = htons(primary_port.unpack());
sa.sin_addr.s_addr = inet_addr("127.0.0.1");
int socket_enable = 1;
EXPECT_EQ(setsockopt(file_descriptor, SOL_SOCKET, SO_REUSEADDR, &socket_enable, sizeof(int)), 0);
EXPECT_CALL(messaging, sendSyncMessage(_, _, _, _, _))
.WillRepeatedly(Return(HTTPResponse(HTTPStatusCode::HTTP_OK, "")));
auto mainloop = Singleton::Consume<I_MainLoop>::from(mainloop_comp);
I_MainLoop::Routine stop_routine = [&] () {
EXPECT_EQ(connect(file_descriptor, (struct sockaddr*)&sa, sizeof(struct sockaddr)), 0)
<< "file_descriptor Error: " << strerror(errno);
string test_body = "Test request body";
string msg = "POST /error-test HTTP/1.1\r\nContent-Length: " +
to_string(test_body.length())
+ "\r\n\r\n" + test_body;
EXPECT_EQ(write(file_descriptor, msg.data(), msg.size()), static_cast<int>(msg.size()));
struct pollfd s_poll;
s_poll.fd = file_descriptor;
s_poll.events = POLLIN;
s_poll.revents = 0;
while(poll(&s_poll, 1, 0) <= 0) {
mainloop->yield(true);
}
mainloop->stopAll();
};
mainloop->addOneTimeRoutine(
I_MainLoop::RoutineType::RealTime,
stop_routine,
"RestConfigTest-post_call_generic_error_test stop routine",
true
);
mainloop->run();
char response[1000];
int bytes_read = read(file_descriptor, response, 1000);
EXPECT_GT(bytes_read, 0);
string response_str(response, bytes_read);
EXPECT_THAT(response_str, HasSubstr("HTTP/1.1 500 Internal Server Error"));
EXPECT_THAT(response_str, HasSubstr("Test error message"));
close(file_descriptor);
}

View File

@@ -444,6 +444,7 @@ TEST(RestSchema, server_schema)
EXPECT_CALL(mock_agent_details, getAccessToken()).WillRepeatedly(testing::Return(string("accesstoken")));
EXPECT_CALL(mock_agent_details, getFogDomain()).WillRepeatedly(testing::Return(string("127.0.0.1")));
EXPECT_CALL(mock_agent_details, getFogPort()).WillRepeatedly(testing::Return(9777));
EXPECT_CALL(mock_agent_details, getProxy()).WillRepeatedly(testing::Return(string("")));
string config_json =
"{"

View File

@@ -199,6 +199,11 @@ void
resetIpc(SharedMemoryIPC *ipc, uint16_t num_of_data_segments)
{
writeDebug(TraceLevel, "Reseting IPC queues\n");
if (!ipc || !ipc->rx_queue || !ipc->tx_queue) {
writeDebug(WarningLevel, "resetIpc called with NULL ipc pointer\n");
return;
}
resetRingQueue(ipc->rx_queue, num_of_data_segments);
resetRingQueue(ipc->tx_queue, num_of_data_segments);
}
@@ -208,6 +213,11 @@ destroyIpc(SharedMemoryIPC *shmem, int is_owner)
{
writeDebug(TraceLevel, "Destroying IPC queues\n");
if (!shmem) {
writeDebug(WarningLevel, "Destroying IPC queues called with NULL shmem pointer\n");
return;
}
if (shmem->rx_queue != NULL) {
destroySharedRingQueue(shmem->rx_queue, is_owner, isTowardsOwner(is_owner, 0));
shmem->rx_queue = NULL;
@@ -225,6 +235,10 @@ dumpIpcMemory(SharedMemoryIPC *ipc)
{
writeDebug(WarningLevel, "Ipc memory dump:\n");
writeDebug(WarningLevel, "RX queue:\n");
if (!ipc || !ipc->rx_queue) {
writeDebug(WarningLevel, "RX queue is NULL\n");
return;
}
dumpRingQueueShmem(ipc->rx_queue);
writeDebug(WarningLevel, "TX queue:\n");
dumpRingQueueShmem(ipc->tx_queue);
@@ -234,6 +248,10 @@ int
sendData(SharedMemoryIPC *ipc, const uint16_t data_to_send_size, const char *data_to_send)
{
writeDebug(TraceLevel, "Sending data of size %u\n", data_to_send_size);
if (!ipc || !ipc->tx_queue) {
writeDebug(WarningLevel, "sendData called with NULL ipc pointer\n");
return -1;
}
return pushToQueue(ipc->tx_queue, data_to_send, data_to_send_size);
}
@@ -247,12 +265,22 @@ sendChunkedData(
{
writeDebug(TraceLevel, "Sending %u chunks of data\n", num_of_data_elem);
if (!ipc) {
writeDebug(WarningLevel, "sendChunkedData called with NULL ipc pointer\n");
return -1;
}
return pushBuffersToQueue(ipc->tx_queue, data_elem_to_send, data_to_send_sizes, num_of_data_elem);
}
int
receiveData(SharedMemoryIPC *ipc, uint16_t *received_data_size, const char **received_data)
{
if (!ipc) {
writeDebug(WarningLevel, "receiveData called with NULL ipc pointer\n");
return -1;
}
int res = peekToQueue(ipc->rx_queue, received_data, received_data_size);
writeDebug(TraceLevel, "Received data from queue. Res: %d, data size: %u\n", res, *received_data_size);
return res;
@@ -261,6 +289,10 @@ receiveData(SharedMemoryIPC *ipc, uint16_t *received_data_size, const char **rec
int
popData(SharedMemoryIPC *ipc)
{
if (!ipc) {
writeDebug(WarningLevel, "popData called with NULL ipc pointer\n");
return -1;
}
int res = popFromQueue(ipc->rx_queue);
writeDebug(TraceLevel, "Popped data from queue. Res: %d\n", res);
return res;
@@ -269,6 +301,10 @@ popData(SharedMemoryIPC *ipc)
int
isDataAvailable(SharedMemoryIPC *ipc)
{
if (!ipc) {
writeDebug(WarningLevel, "isDataAvailable called with NULL ipc pointer\n");
return 0;
}
int res = !isQueueEmpty(ipc->rx_queue);
writeDebug(TraceLevel, "Checking if there is data pending to be read. Res: %d\n", res);
return res;
@@ -277,6 +313,11 @@ isDataAvailable(SharedMemoryIPC *ipc)
int
isCorruptedShmem(SharedMemoryIPC *ipc, int is_owner)
{
if (!ipc) {
writeDebug(WarningLevel, "isCorruptedShmem called with NULL ipc pointer\n");
return 1;
}
if (isCorruptedQueue(ipc->rx_queue, isTowardsOwner(is_owner, 0)) ||
isCorruptedQueue(ipc->tx_queue, isTowardsOwner(is_owner, 1))
) {

View File

@@ -54,6 +54,7 @@ public:
is_server_socket = from.is_server_socket;
socket_int = from.socket_int;
from.socket_int = -1;
i_mainloop = Singleton::Consume<I_MainLoop>::by<SocketIS>();
}
virtual ~SocketInternal()
@@ -112,6 +113,115 @@ public:
return true;
}
bool
writeDataAsync(const vector<char> &data)
{
uint32_t bytes_sent = 0;
bool is_first_iter = true;
uint32_t max_retries = 10;
uint32_t retry_count = 0;
while (bytes_sent < data.size() && retry_count < max_retries) {
if (!is_first_iter && !is_blocking) {
dbgTrace(D_SOCKET)
<< "Trying to yield before writing to socket again. Bytes written: "
<< bytes_sent
<< ", Total bytes: "
<< data.size();
if (!i_mainloop) {
i_mainloop = Singleton::Consume<I_MainLoop>::by<SocketIS>();
}
i_mainloop->yield(false);
}
is_first_iter = false;
int res = send(socket_int, data.data() + bytes_sent, data.size() - bytes_sent, MSG_NOSIGNAL);
if (res <= 0) {
int err = errno;
// Check if it's a temporary error that can be retried
if (res == -1 && (err == EAGAIN || err == EWOULDBLOCK)) {
dbgTrace(D_SOCKET)
<< "Send would block (EAGAIN/EWOULDBLOCK), waiting for socket to become writable. "
<< "Bytes sent so far: "
<< bytes_sent;
// Use poll to wait for socket to become writable with 10ms timeout
struct pollfd pfd;
pfd.fd = socket_int;
pfd.events = POLLOUT;
pfd.revents = 0;
int poll_result = poll(&pfd, 1, 10);
if (poll_result > 0 && (pfd.revents & POLLOUT)) {
dbgTrace(D_SOCKET) << "Socket became writable, retrying send";
retry_count++;
continue;
} else if (poll_result == 0) {
dbgWarning(D_SOCKET)
<< "Timeout waiting for socket to become writable after 100ms. "
<< "Bytes sent: " << bytes_sent << "/" << data.size();
retry_count++;
continue;
} else {
dbgWarning(D_SOCKET)
<< "Poll failed while waiting for writable socket. Error: "
<< strerror(errno);
return false;
}
}
if (
res == 0
|| err == EPIPE
|| err == ECONNRESET
|| err == ENOTCONN
|| err == ESHUTDOWN
|| err == EBADF
|| err == EINVAL
) {
dbgWarning(D_SOCKET)
<< "Fatal error sending data. Error: "
<< strerror(err)
<< ", bytes sent: "
<< bytes_sent
<< "/"
<< data.size();
return false;
}
if (err == EINTR) {
dbgTrace(D_SOCKET) << "Send interrupted (EINTR), retrying immediately";
retry_count++;
continue;
}
dbgWarning(D_SOCKET)
<< "Unexpected error sending data. Error: "
<< strerror(err)
<< ", errno: "
<< err
<< ", bytes sent: "
<< bytes_sent
<< "/"
<< data.size();
return false;
}
bytes_sent += res;
retry_count = 0;
}
if (retry_count >= max_retries) {
dbgWarning(D_SOCKET) << "Reached max retries (" << max_retries << ") for socket write";
return false;
}
return true;
}
bool
isDataAvailable()
{
@@ -223,6 +333,7 @@ protected:
bool is_blocking = false;
bool is_server_socket = true;
int socket_int = -1;
I_MainLoop *i_mainloop = nullptr;
private:
Maybe<string>
@@ -718,6 +829,7 @@ public:
void closeSocket(socketFd &socket_fd) override;
bool writeData(socketFd socket_fd, const vector<char> &data) override;
bool writeDataAsync(socketFd socket_fd, const vector<char> &data) override;
Maybe<vector<char>> receiveData(socketFd socket_fd, uint data_size, bool is_blocking = true) override;
bool isDataAvailable(socketFd socket) override;
@@ -820,6 +932,18 @@ SocketIS::Impl::writeData(socketFd socket_fd, const vector<char> &data)
return sock->second->writeData(data);
}
bool
SocketIS::Impl::writeDataAsync(socketFd socket_fd, const vector<char> &data)
{
auto sock = active_sockets.find(socket_fd);
if (sock == active_sockets.end()) {
dbgWarning(D_SOCKET) << "The provided socket file descriptor does not exist. Socket FD: " << socket_fd;
return false;
}
return sock->second->writeDataAsync(data);
}
Maybe<vector<char>>
SocketIS::Impl::receiveData(socketFd socket_fd, uint data_size, bool is_blocking)
{