Browse Source

Add optional user defined header writer (#1683)

* Add optional user defined header writer

* Fix errors and add test
PabloMK7 2 years ago
parent
commit
a609330e4c
2 changed files with 74 additions and 2 deletions
  1. 34 2
      httplib.h
  2. 40 0
      test/test.cc

+ 34 - 2
httplib.h

@@ -737,6 +737,8 @@ private:
   std::regex regex_;
   std::regex regex_;
 };
 };
 
 
+ssize_t write_headers(Stream &strm, const Headers &headers);
+
 } // namespace detail
 } // namespace detail
 
 
 class Server {
 class Server {
@@ -800,6 +802,8 @@ public:
   Server &set_socket_options(SocketOptions socket_options);
   Server &set_socket_options(SocketOptions socket_options);
 
 
   Server &set_default_headers(Headers headers);
   Server &set_default_headers(Headers headers);
+  Server &
+  set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);
 
 
   Server &set_keep_alive_max_count(size_t count);
   Server &set_keep_alive_max_count(size_t count);
   Server &set_keep_alive_timeout(time_t sec);
   Server &set_keep_alive_timeout(time_t sec);
@@ -934,6 +938,8 @@ private:
   SocketOptions socket_options_ = default_socket_options;
   SocketOptions socket_options_ = default_socket_options;
 
 
   Headers default_headers_;
   Headers default_headers_;
+  std::function<ssize_t(Stream &, Headers &)> header_writer_ =
+      detail::write_headers;
 };
 };
 
 
 enum class Error {
 enum class Error {
@@ -1164,6 +1170,9 @@ public:
 
 
   void set_default_headers(Headers headers);
   void set_default_headers(Headers headers);
 
 
+  void
+  set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);
+
   void set_address_family(int family);
   void set_address_family(int family);
   void set_tcp_nodelay(bool on);
   void set_tcp_nodelay(bool on);
   void set_socket_options(SocketOptions socket_options);
   void set_socket_options(SocketOptions socket_options);
@@ -1273,6 +1282,10 @@ protected:
   // Default headers
   // Default headers
   Headers default_headers_;
   Headers default_headers_;
 
 
+  // Header writer
+  std::function<ssize_t(Stream &, Headers &)> header_writer_ =
+      detail::write_headers;
+
   // Settings
   // Settings
   std::string client_cert_path_;
   std::string client_cert_path_;
   std::string client_key_path_;
   std::string client_key_path_;
@@ -1539,6 +1552,9 @@ public:
 
 
   void set_default_headers(Headers headers);
   void set_default_headers(Headers headers);
 
 
+  void
+  set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);
+
   void set_address_family(int family);
   void set_address_family(int family);
   void set_tcp_nodelay(bool on);
   void set_tcp_nodelay(bool on);
   void set_socket_options(SocketOptions socket_options);
   void set_socket_options(SocketOptions socket_options);
@@ -5672,6 +5688,12 @@ inline Server &Server::set_default_headers(Headers headers) {
   return *this;
   return *this;
 }
 }
 
 
+inline Server &Server::set_header_writer(
+    std::function<ssize_t(Stream &, Headers &)> const &writer) {
+  header_writer_ = writer;
+  return *this;
+}
+
 inline Server &Server::set_keep_alive_max_count(size_t count) {
 inline Server &Server::set_keep_alive_max_count(size_t count) {
   keep_alive_max_count_ = count;
   keep_alive_max_count_ = count;
   return *this;
   return *this;
@@ -5866,7 +5888,7 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
       return false;
       return false;
     }
     }
 
 
-    if (!detail::write_headers(bstrm, res.headers)) { return false; }
+    if (!header_writer_(bstrm, res.headers)) { return false; }
 
 
     // Flush buffer
     // Flush buffer
     auto &data = bstrm.get_buffer();
     auto &data = bstrm.get_buffer();
@@ -7105,7 +7127,7 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req,
     const auto &path = url_encode_ ? detail::encode_url(req.path) : req.path;
     const auto &path = url_encode_ ? detail::encode_url(req.path) : req.path;
     bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str());
     bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str());
 
 
-    detail::write_headers(bstrm, req.headers);
+    header_writer_(bstrm, req.headers);
 
 
     // Flush buffer
     // Flush buffer
     auto &data = bstrm.get_buffer();
     auto &data = bstrm.get_buffer();
@@ -7916,6 +7938,11 @@ inline void ClientImpl::set_default_headers(Headers headers) {
   default_headers_ = std::move(headers);
   default_headers_ = std::move(headers);
 }
 }
 
 
+inline void ClientImpl::set_header_writer(
+    std::function<ssize_t(Stream &, Headers &)> const &writer) {
+  header_writer_ = writer;
+}
+
 inline void ClientImpl::set_address_family(int family) {
 inline void ClientImpl::set_address_family(int family) {
   address_family_ = family;
   address_family_ = family;
 }
 }
@@ -9110,6 +9137,11 @@ inline void Client::set_default_headers(Headers headers) {
   cli_->set_default_headers(std::move(headers));
   cli_->set_default_headers(std::move(headers));
 }
 }
 
 
+inline void Client::set_header_writer(
+    std::function<ssize_t(Stream &, Headers &)> const &writer) {
+  cli_->set_header_writer(writer);
+}
+
 inline void Client::set_address_family(int family) {
 inline void Client::set_address_family(int family) {
   cli_->set_address_family(family);
   cli_->set_address_family(family);
 }
 }

+ 40 - 0
test/test.cc

@@ -1592,6 +1592,46 @@ TEST(URLFragmentTest, WithFragment) {
   }
   }
 }
 }
 
 
+TEST(HeaderWriter, SetHeaderWriter) {
+  Server svr;
+
+  svr.set_header_writer([](Stream &strm, Headers &hdrs) {
+    hdrs.emplace("CustomServerHeader", "CustomServerValue");
+    return detail::write_headers(strm, hdrs);
+  });
+  svr.Get("/hi", [](const Request &req, Response &res) {
+    auto it = req.headers.find("CustomClientHeader");
+    EXPECT_TRUE(it != req.headers.end());
+    EXPECT_EQ(it->second, "CustomClientValue");
+    res.set_content("Hello World!\n", "text/plain");
+  });
+
+  auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
+  auto se = detail::scope_exit([&] {
+    svr.stop();
+    thread.join();
+    ASSERT_FALSE(svr.is_running());
+  });
+
+  std::this_thread::sleep_for(std::chrono::seconds(1));
+
+  {
+    Client cli(HOST, PORT);
+    cli.set_header_writer([](Stream &strm, Headers &hdrs) {
+      hdrs.emplace("CustomClientHeader", "CustomClientValue");
+      return detail::write_headers(strm, hdrs);
+    });
+
+    auto res = cli.Get("/hi");
+    EXPECT_TRUE(res);
+    EXPECT_EQ(200, res->status);
+
+    auto it = res->headers.find("CustomServerHeader");
+    EXPECT_TRUE(it != res->headers.end());
+    EXPECT_EQ(it->second, "CustomServerValue");
+  }
+}
+
 class ServerTest : public ::testing::Test {
 class ServerTest : public ::testing::Test {
 protected:
 protected:
   ServerTest()
   ServerTest()