summaryrefslogtreecommitdiff
path: root/src/server/connection.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/server/connection.cc')
-rw-r--r--src/server/connection.cc258
1 files changed, 258 insertions, 0 deletions
diff --git a/src/server/connection.cc b/src/server/connection.cc
new file mode 100644
index 0000000..c662d6d
--- /dev/null
+++ b/src/server/connection.cc
@@ -0,0 +1,258 @@
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include <stdexcept>
+#include <unistd.h>
+#include <ios>
+#include <arpa/inet.h>
+#include <cerrno>
+
+#include "connection.h"
+#include "../p18/commands.h"
+#include "../p18/response.h"
+#include "../logging.h"
+#include "../common.h"
+#include "hexdump/hexdump.h"
+#include "signal.h"
+
+#define CHECK_ARGUMENTS_LENGTH(__size__) \
+ if (arguments.size() != (__size__)) { \
+ std::ostringstream error; \
+ error << "invalid arguments count: expected " << (__size__) << ", got " << arguments.size(); \
+ throw std::invalid_argument(error.str()); \
+ }
+
+#define CHECK_ARGUMENTS_MIN_LENGTH(__size__) \
+ if (arguments.size() < (__size__)) { \
+ std::ostringstream error; \
+ error << "invalid arguments count: expected " << (__size__) << ", got " << arguments.size(); \
+ throw std::invalid_argument(error.str()); \
+ }
+
+
+namespace server {
+
+Connection::Connection(int sock, struct sockaddr_in addr, Server* server)
+ : sock_(sock), addr_(addr), server_(server)
+{
+ if (server_->verbose())
+ mylog << "new connection from " << ipv4();
+
+ thread_ = std::thread(&Connection::run, this);
+ thread_.detach();
+}
+
+Connection::~Connection() {
+ if (server_->verbose())
+ mylog << "closing socket..";
+
+ if (close(sock_) == -1)
+ myerr << ipv4() << ": close: " << strerror(errno);
+
+ server_->removeConnection(this);
+}
+
+void Connection::run() {
+ static int bufSize = 2048;
+ char buf[bufSize];
+
+ while (true) {
+ long rcvd = readLoop(buf, bufSize - 1);
+ if (rcvd == -1) {
+ if (errno != EINTR && server_->verbose())
+ myerr << ipv4() << ": recv: " << std::string(strerror(errno));
+ break;
+ }
+ if (rcvd == 0)
+ break;
+
+ buf[rcvd] = '\0';
+ if (*buf == '\4')
+ break;
+
+ Response resp = processRequest(buf);
+ if (!sendResponse(resp))
+ break;
+ }
+
+ delete this;
+}
+
+int Connection::readLoop(char* buf, size_t bufSize) const {
+ char* bufptr = buf;
+ int left = static_cast<int>(bufSize);
+ int readed = 0;
+
+ while (left > 0) {
+ size_t rcvd = recv(sock_, bufptr, left, 0);
+ if (rcvd == -1)
+ return -1;
+ if (rcvd == 0)
+ break;
+
+ readed += static_cast<int>(rcvd);
+ if (*bufptr == '\4')
+ break;
+
+ left -= static_cast<int>(rcvd);
+ bufptr += rcvd;
+
+ bufptr[rcvd] = '\0';
+ char* ptr = strstr(buf, "\r\n");
+ if (ptr)
+ break;
+ }
+
+ return readed;
+}
+
+bool Connection::writeLoop(const char* buf, size_t bufSize) const {
+ const char* bufptr = buf;
+ int left = static_cast<int>(bufSize);
+
+ while (left > 0) {
+ size_t bytesSent = send(sock_, bufptr, left, 0);
+ if (bytesSent == -1) {
+ if (errno != EINTR && server_->verbose())
+ myerr << ipv4() << ": send: " << std::string(strerror(errno));
+ return false;
+ }
+
+ left -= static_cast<int>(bytesSent);
+ bufptr += bytesSent;
+ }
+
+ return true;
+}
+
+bool Connection::sendResponse(Response& resp) const {
+ std::ostringstream sbuf;
+ sbuf << resp;
+
+ std::string s = sbuf.str();
+ const char* buf = s.c_str();
+ size_t bufSize = s.size();
+
+ return writeLoop(buf, bufSize);
+}
+
+std::string Connection::ipv4() const {
+ char ip[INET_ADDRSTRLEN] = {0};
+ const char* result = inet_ntop(AF_INET, (const void*)&addr_.sin_addr, ip, sizeof(ip));
+ if (result == nullptr)
+ return "?";
+
+ std::ostringstream buf;
+ buf << ip << ":" << htons(addr_.sin_port);
+ return buf.str();
+}
+
+Response Connection::processRequest(char* buf) {
+ std::stringstream sbuf;
+ int n = 0;
+ std::vector<std::string> arguments;
+ RequestType type;
+
+ Response resp;
+ resp.type = ResponseType::OK;
+
+ try {
+ char* last = nullptr;
+ const char* delim = " ";
+ for (char* token = strtok_r(buf, delim, &last);
+ token != nullptr;
+ token = strtok_r(nullptr, delim, &last)) {
+
+ char* ptr = strstr(token, "\r\n");
+ if (ptr)
+ *ptr = '\0';
+
+ if (!n++) {
+ std::string s = std::string(token);
+
+ if (s == "format")
+ type = RequestType::Format;
+
+ else if (s == "v")
+ type = RequestType::Version;
+
+ else if (s == "exec")
+ type = RequestType::Execute;
+
+ else if (s == "raw")
+ type = RequestType::Raw;
+
+ else
+ throw std::invalid_argument("invalid token: " + s);
+
+ } else if (strlen(token) > 0)
+ arguments.emplace_back(token);
+ }
+
+ switch (type) {
+ case RequestType::Version: {
+ CHECK_ARGUMENTS_LENGTH(1)
+ auto v = static_cast<unsigned>(std::stoul(arguments[0]));
+ if (v != 1)
+ throw std::invalid_argument("invalid protocol version");
+ options_.version = v;
+ break;
+ }
+
+ case RequestType::Format:
+ CHECK_ARGUMENTS_LENGTH(1)
+ options_.format = format_from_string(arguments[0]);
+ break;
+
+ case RequestType::Execute: {
+ CHECK_ARGUMENTS_MIN_LENGTH(1)
+
+ std::string& command = arguments[0];
+ auto commandArguments = std::vector<std::string>();
+
+ auto argumentsSlice = std::vector<std::string>(arguments.begin()+1, arguments.end());
+
+ p18::CommandInput input{&argumentsSlice};
+ p18::CommandType commandType = p18::validate_input(command, commandArguments, (void*)&input);
+
+ auto response = server_->executeCommand(commandType, commandArguments);
+ resp.buf << *(response->format(options_.format).get());
+
+ break;
+ }
+
+ case RequestType::Raw: {
+ throw std::runtime_error("not implemented");
+// CHECK_ARGUMENTS_LENGTH(1)
+// std::string& raw = arguments[0];
+//
+// resp.type = ResponseType::Error;
+// resp.buf << "not implemented";
+ break;
+ }
+ }
+ }
+ // we except std::invalid_argument and std::runtime_error
+ catch (std::exception& e) {
+ resp.type = ResponseType::Error;
+
+ auto err = p18::response_type::ErrorResponse(e.what());
+ resp.buf << *(err.format(options_.format));
+ }
+
+ return resp;
+}
+
+std::ostream& operator<<(std::ostream& os, Response& resp) {
+ os << (resp.type == ResponseType::OK ? "ok" : "err");
+
+ resp.buf.seekp(0, std::ios::end);
+ size_t size = resp.buf.tellp();
+ if (size) {
+ resp.buf.seekp(0);
+ os << "\r\n" << resp.buf.str();
+ }
+
+ return os << "\r\n\r\n";
+}
+
+} \ No newline at end of file