sync code

This commit is contained in:
Ned Wright
2026-01-03 18:59:01 +00:00
parent c1058db57d
commit 2105628f05
188 changed files with 8272 additions and 2723 deletions

View File

@@ -16,6 +16,7 @@
#include <string>
#include <istream>
#include <map>
#include "maybe_res.h"
@@ -23,11 +24,20 @@ class I_RestInvoke
{
public:
virtual Maybe<std::string> getSchema(const std::string &uri) const = 0;
virtual Maybe<std::string> invokeRest(const std::string &uri, std::istream &in) const = 0;
virtual Maybe<std::string> invokeRest(
const std::string &uri,
std::istream &in,
const std::map<std::string, std::string> &headers
) const = 0;
virtual bool isGetCall(const std::string &uri) const = 0;
virtual std::string invokeGet(const std::string &uri) const = 0;
virtual bool isPostCall(const std::string &uri) const = 0;
virtual Maybe<std::string> invokePost(const std::string &uri, const std::string &body) const = 0;
virtual bool shouldCaptureHeaders(const std::string &uri) const = 0;
protected:
~I_RestInvoke() {}
};

View File

@@ -29,8 +29,11 @@ RestHelper::reportError(std::string const &err)
}
Maybe<string>
ServerRest::performRestCall(istream &in)
ServerRest::performRestCall(istream &in, const map<string, string> &headers)
{
if (wantsHeaders()) {
request_headers = headers;
}
try {
try {
int firstChar = in.peek();

View File

@@ -75,17 +75,52 @@ RestConn::parseConn() const
dbgDebug(D_API) << "Call identifier: " << identifier;
uint len = 0;
map<string, string> headers;
bool should_capture_headers = invoke->shouldCaptureHeaders(identifier);
while (true) {
line = readLine();
if (line.size() < 3) break;
os.str(line);
string head, data;
os >> head >> data;
if (compareStringCaseInsensitive(head, "Content-Length:")) {
try {
len = stoi(data, nullptr);
} catch (...) {
if (should_capture_headers) {
size_t colon_pos = line.find(':');
if (colon_pos == string::npos) continue;
string head = line.substr(0, colon_pos);
string data = line.substr(colon_pos + 1);
size_t data_start = data.find_first_not_of(" \t\r\n");
if (data_start != string::npos) {
data = data.substr(data_start);
} else {
data = "";
}
size_t data_end = data.find_last_not_of(" \t\r\n");
if (data_end != string::npos) {
data = data.substr(0, data_end + 1);
}
if (!head.empty()) {
headers[head] = data;
dbgTrace(D_API) << "Captured header: " << head << " = " << data;
}
if (compareStringCaseInsensitive(head, "Content-Length")) {
try {
len = stoi(data, nullptr);
} catch (...) {
}
}
} else {
os.str(line);
string head, data;
os >> head >> data;
if (compareStringCaseInsensitive(head, "Content-Length:")) {
try {
len = stoi(data, nullptr);
} catch (...) {
}
}
}
}
@@ -113,7 +148,19 @@ RestConn::parseConn() const
dbgTrace(D_API) << "Message content: " << body.str();
Maybe<string> res = (method == "POST") ? invoke->invokeRest(identifier, body) : invoke->getSchema(identifier);
if (method == "POST" && invoke->isPostCall(identifier)) {
Maybe<string> result = invoke->invokePost(identifier, body.str());
if (!result.ok()) {
dbgWarning(D_API) << "Failed to invoke POST call: " << result.getErr();
sendResponse("500 Internal Server Error", result.getErr());
return;
}
sendResponse("200 OK", result.unpack());
return;
}
Maybe<string> res = (method == "POST") ?
invoke->invokeRest(identifier, body, headers) : invoke->getSchema(identifier);
if (res.ok()) {
sendResponse("200 OK", res.unpack());

View File

@@ -15,6 +15,7 @@
#define __REST_CONN_H__
#include <string>
#include <map>
#include "i_mainloop.h"
#include "i_rest_invoke.h"

View File

@@ -53,11 +53,16 @@ public:
bool addRestCall(RestAction oper, const string &uri, unique_ptr<RestInit> &&init) override;
bool addGetCall(const string &uri, const function<string()> &cb) override;
bool addWildcardGetCall(const string &uri, const function<string(const string &)> &callback);
bool addPostCall(const string &uri, const function<Maybe<string>(const string &)> &callback) override;
uint16_t getListeningPort() const override { return listening_port; }
uint16_t getStartingPortRange() const override { return starting_port_range; }
Maybe<string> getSchema(const string &uri) const override;
Maybe<string> invokeRest(const string &uri, istream &in) const override;
Maybe<string> invokeRest(const string &uri, istream &in, const map<string, string> &headers) const override;
bool isGetCall(const string &uri) const override;
string invokeGet(const string &uri) const override;
bool isPostCall(const string &uri) const override;
Maybe<string> invokePost(const string &uri, const string &body) const override;
bool shouldCaptureHeaders(const string &uri) const override;
private:
void prepareConfiguration();
@@ -71,7 +76,9 @@ private:
map<string, unique_ptr<RestInit>> rest_calls;
map<string, function<string()>> get_calls;
map<string, function<string(const string &)>> wildcard_get_calls;
map<string, function<Maybe<string>(const string &)>> post_calls;
uint16_t listening_port = 0;
uint16_t starting_port_range = 0;
vector<uint16_t> port_range;
};
@@ -241,6 +248,9 @@ RestServer::Impl::prepareConfiguration()
range_start = 0;
range_end = 1;
}
starting_port_range = range_start.unpack();
dbgInfo(D_API) << "Rest port range start: " << *range_start << ", end: " << *range_end;
// starting_port_range = *range_start;
port_range.resize(*range_end - *range_start);
for (uint16_t i = 0, port = *range_start; i < port_range.size(); i++, port++) {
port_range[i] = port;
@@ -379,6 +389,14 @@ RestServer::Impl::addWildcardGetCall(const string &uri, const function<string(co
return wildcard_get_calls.emplace(uri, callback).second;
}
bool
RestServer::Impl::addPostCall(const string &uri, const function<Maybe<string>(const string&)> &callback)
{
if (rest_calls.find(uri) != rest_calls.end()) return false;
if (get_calls.find(uri) != get_calls.end()) return false;
return post_calls.emplace(uri, callback).second;
}
Maybe<string>
RestServer::Impl::getSchema(const string &uri) const
{
@@ -392,12 +410,12 @@ RestServer::Impl::getSchema(const string &uri) const
}
Maybe<string>
RestServer::Impl::invokeRest(const string &uri, istream &in) const
RestServer::Impl::invokeRest(const string &uri, istream &in, const map<string, string> &headers) const
{
auto iter = rest_calls.find(uri);
if (iter == rest_calls.end()) return genError("No matching REST call was found");
auto instance = iter->second->getRest();
return instance->performRestCall(in);
return instance->performRestCall(in, headers);
}
bool
@@ -412,6 +430,13 @@ RestServer::Impl::isGetCall(const string &uri) const
return false;
}
bool
RestServer::Impl::isPostCall(const string &uri) const
{
if (post_calls.find(uri) != post_calls.end()) return true;
return false;
}
string
RestServer::Impl::invokeGet(const string &uri) const
{
@@ -425,6 +450,23 @@ RestServer::Impl::invokeGet(const string &uri) const
return "";
}
Maybe<string>
RestServer::Impl::invokePost(const string &uri, const string &body) const
{
auto instance = post_calls.find(uri);
if (instance != post_calls.end()) return instance->second(body);
return genError("No matching POST call was found for URI: " + uri);
}
bool
RestServer::Impl::shouldCaptureHeaders(const string &uri) const
{
auto iter = rest_calls.find(uri);
if (iter == rest_calls.end()) return false;
auto instance = iter->second->getRest();
return instance->wantsHeaders();
}
string
RestServer::Impl::changeActionToString(RestAction oper)
{

View File

@@ -414,3 +414,222 @@ TEST_F(RestConfigTest, not_loopback_flow)
"HTTP/1.1 500 Internal Server Error\r\nContent-Type: application/json\r\nContent-Length: 0\r\n\r\n"
);
}
TEST_F(RestConfigTest, getStartingPortRange)
{
// Use a configuration with port range instead of primary/alternative ports
string config_json_with_range =
"{\n"
" \"connection\": {\n"
" \"Nano service API Port Range start\": [\n"
" {\n"
" \"value\": 8000\n"
" }\n"
" ],\n"
" \"Nano service API Port Range end\": [\n"
" {\n"
" \"value\": 8010\n"
" }\n"
" ]\n"
" }\n"
"}\n";
istringstream ss(config_json_with_range);
Singleton::Consume<Config::I_Config>::from(config)->loadConfiguration(ss);
rest_server.init();
auto i_rest = Singleton::Consume<I_RestApi>::from(rest_server);
EXPECT_EQ(i_rest->getStartingPortRange(), 8000);
auto mainloop = Singleton::Consume<I_MainLoop>::from(mainloop_comp);
I_MainLoop::Routine stop_routine = [mainloop] () { mainloop->stopAll(); };
mainloop->addOneTimeRoutine(
I_MainLoop::RoutineType::RealTime,
stop_routine,
"RestConfigTest-getStartingPortRange stop routine",
false
);
mainloop->run();
}
TEST_F(RestConfigTest, addPostCall)
{
rest_server.init();
time_proxy.init();
mainloop_comp.init();
auto i_rest = Singleton::Consume<I_RestApi>::from(rest_server);
// Test addPostCall
ASSERT_TRUE(i_rest->addPostCall("test-post", [](const string &body) {
return "Received: " + body;
}));
// Test that adding the same POST call twice fails
ASSERT_FALSE(i_rest->addPostCall("test-post", [](const string &) -> Maybe<string> {
return string("Different handler");
}));
// Test that adding POST call with existing GET call fails
ASSERT_TRUE(i_rest->addGetCall("test-get", []() { return "get response"; }));
ASSERT_FALSE(i_rest->addPostCall("test-get", [](const string &) -> Maybe<string> {
return string("post response");
}));
auto mainloop = Singleton::Consume<I_MainLoop>::from(mainloop_comp);
I_MainLoop::Routine stop_routine = [mainloop] () { mainloop->stopAll(); };
mainloop->addOneTimeRoutine(
I_MainLoop::RoutineType::RealTime,
stop_routine,
"RestConfigTest-addPostCall stop routine",
false
);
mainloop->run();
}
TEST_F(RestConfigTest, post_call_integration_test)
{
env.preload();
Singleton::Consume<I_Environment>::from(env)->registerValue<string>("Base Executable Name", "tmp_test_file");
config.preload();
config.init();
rest_server.init();
time_proxy.init();
mainloop_comp.init();
auto i_rest = Singleton::Consume<I_RestApi>::from(rest_server);
// Add a POST endpoint that echoes back the request body with prefix
ASSERT_TRUE(i_rest->addPostCall("echo", [](const string &body) -> Maybe<string> {
return string("Echo: ") + body;
}));
int file_descriptor = socket(AF_INET, SOCK_STREAM, 0);
EXPECT_NE(file_descriptor, -1);
auto primary_port = getConfiguration<uint>("connection", "Nano service API Port Alternative");
struct sockaddr_in sa;
sa.sin_family = AF_INET;
sa.sin_port = htons(primary_port.unpack());
sa.sin_addr.s_addr = inet_addr("127.0.0.1");
int socket_enable = 1;
EXPECT_EQ(setsockopt(file_descriptor, SOL_SOCKET, SO_REUSEADDR, &socket_enable, sizeof(int)), 0);
EXPECT_CALL(messaging, sendSyncMessage(_, _, _, _, _))
.WillRepeatedly(Return(HTTPResponse(HTTPStatusCode::HTTP_OK, "")));
auto mainloop = Singleton::Consume<I_MainLoop>::from(mainloop_comp);
I_MainLoop::Routine stop_routine = [&] () {
EXPECT_EQ(connect(file_descriptor, (struct sockaddr*)&sa, sizeof(struct sockaddr)), 0)
<< "file_descriptor Error: " << strerror(errno);
string test_body = "Hello World";
string msg = "POST /echo HTTP/1.1\r\nContent-Length: " +
to_string(test_body.length())
+ "\r\n\r\n" + test_body;
EXPECT_EQ(write(file_descriptor, msg.data(), msg.size()), static_cast<int>(msg.size()));
struct pollfd s_poll;
s_poll.fd = file_descriptor;
s_poll.events = POLLIN;
s_poll.revents = 0;
while(poll(&s_poll, 1, 0) <= 0) {
mainloop->yield(true);
}
mainloop->stopAll();
};
mainloop->addOneTimeRoutine(
I_MainLoop::RoutineType::RealTime,
stop_routine,
"RestConfigTest-post_call_integration_test stop routine",
true
);
mainloop->run();
char response[1000];
int bytes_read = read(file_descriptor, response, 1000);
EXPECT_GT(bytes_read, 0);
string response_str(response, bytes_read);
EXPECT_THAT(response_str, HasSubstr("HTTP/1.1 200 OK"));
EXPECT_THAT(response_str, HasSubstr("Echo: Hello World"));
close(file_descriptor);
}
TEST_F(RestConfigTest, post_call_generic_error_test)
{
env.preload();
Singleton::Consume<I_Environment>::from(env)->registerValue<string>("Base Executable Name", "tmp_test_file");
config.preload();
config.init();
rest_server.init();
time_proxy.init();
mainloop_comp.init();
auto i_rest = Singleton::Consume<I_RestApi>::from(rest_server);
// Add a POST endpoint that returns a generic error
ASSERT_TRUE(i_rest->addPostCall("error-test", [](const string &) -> Maybe<string> {
return genError("Test error message");
}));
int file_descriptor = socket(AF_INET, SOCK_STREAM, 0);
EXPECT_NE(file_descriptor, -1);
auto primary_port = getConfiguration<uint>("connection", "Nano service API Port Alternative");
struct sockaddr_in sa;
sa.sin_family = AF_INET;
sa.sin_port = htons(primary_port.unpack());
sa.sin_addr.s_addr = inet_addr("127.0.0.1");
int socket_enable = 1;
EXPECT_EQ(setsockopt(file_descriptor, SOL_SOCKET, SO_REUSEADDR, &socket_enable, sizeof(int)), 0);
EXPECT_CALL(messaging, sendSyncMessage(_, _, _, _, _))
.WillRepeatedly(Return(HTTPResponse(HTTPStatusCode::HTTP_OK, "")));
auto mainloop = Singleton::Consume<I_MainLoop>::from(mainloop_comp);
I_MainLoop::Routine stop_routine = [&] () {
EXPECT_EQ(connect(file_descriptor, (struct sockaddr*)&sa, sizeof(struct sockaddr)), 0)
<< "file_descriptor Error: " << strerror(errno);
string test_body = "Test request body";
string msg = "POST /error-test HTTP/1.1\r\nContent-Length: " +
to_string(test_body.length())
+ "\r\n\r\n" + test_body;
EXPECT_EQ(write(file_descriptor, msg.data(), msg.size()), static_cast<int>(msg.size()));
struct pollfd s_poll;
s_poll.fd = file_descriptor;
s_poll.events = POLLIN;
s_poll.revents = 0;
while(poll(&s_poll, 1, 0) <= 0) {
mainloop->yield(true);
}
mainloop->stopAll();
};
mainloop->addOneTimeRoutine(
I_MainLoop::RoutineType::RealTime,
stop_routine,
"RestConfigTest-post_call_generic_error_test stop routine",
true
);
mainloop->run();
char response[1000];
int bytes_read = read(file_descriptor, response, 1000);
EXPECT_GT(bytes_read, 0);
string response_str(response, bytes_read);
EXPECT_THAT(response_str, HasSubstr("HTTP/1.1 500 Internal Server Error"));
EXPECT_THAT(response_str, HasSubstr("Test error message"));
close(file_descriptor);
}

View File

@@ -444,6 +444,7 @@ TEST(RestSchema, server_schema)
EXPECT_CALL(mock_agent_details, getAccessToken()).WillRepeatedly(testing::Return(string("accesstoken")));
EXPECT_CALL(mock_agent_details, getFogDomain()).WillRepeatedly(testing::Return(string("127.0.0.1")));
EXPECT_CALL(mock_agent_details, getFogPort()).WillRepeatedly(testing::Return(9777));
EXPECT_CALL(mock_agent_details, getProxy()).WillRepeatedly(testing::Return(string("")));
string config_json =
"{"