Browse Source

Added set_default_headers on Server

yhirose 4 years ago
parent
commit
77a77f6d2d
2 changed files with 72 additions and 29 deletions
  1. 15 0
      httplib.h
  2. 57 29
      test/test.cc

+ 15 - 0
httplib.h

@@ -667,6 +667,8 @@ public:
   Server &set_tcp_nodelay(bool on);
   Server &set_socket_options(SocketOptions socket_options);
 
+  Server &set_default_headers(Headers headers);
+
   Server &set_keep_alive_max_count(size_t count);
   Server &set_keep_alive_timeout(time_t sec);
 
@@ -786,6 +788,8 @@ private:
   int address_family_ = AF_UNSPEC;
   bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY;
   SocketOptions socket_options_ = default_socket_options;
+
+  Headers default_headers_;
 };
 
 enum class Error {
@@ -4427,6 +4431,11 @@ inline Server &Server::set_socket_options(SocketOptions socket_options) {
   return *this;
 }
 
+inline Server &Server::set_default_headers(Headers headers) {
+  default_headers_ = std::move(headers);
+  return *this;
+}
+
 inline Server &Server::set_keep_alive_max_count(size_t count) {
   keep_alive_max_count_ = count;
   return *this;
@@ -5131,6 +5140,12 @@ Server::process_request(Stream &strm, bool close_connection,
 
   res.version = "HTTP/1.1";
 
+  for (const auto &header : default_headers_) {
+    if (res.headers.find(header.first) == res.headers.end()) {
+      res.headers.insert(header);
+    }
+  }
+
 #ifdef _WIN32
   // TODO: Increase FD_SETSIZE statically (libzmq), dynamically (MySQL).
 #else

+ 57 - 29
test/test.cc

@@ -5,9 +5,9 @@
 #include <atomic>
 #include <chrono>
 #include <future>
+#include <sstream>
 #include <stdexcept>
 #include <thread>
-#include <sstream>
 
 #define SERVER_CERT_FILE "./cert.pem"
 #define SERVER_CERT2_FILE "./cert2.pem"
@@ -437,26 +437,6 @@ TEST(ChunkedEncodingTest, WithResponseHandlerAndContentReceiver) {
   EXPECT_EQ(out, body);
 }
 
-TEST(DefaultHeadersTest, FromHTTPBin) {
-  Client cli("httpbin.org");
-  cli.set_default_headers({make_range_header({{1, 10}})});
-  cli.set_connection_timeout(5);
-
-  {
-    auto res = cli.Get("/range/32");
-    ASSERT_TRUE(res);
-    EXPECT_EQ("bcdefghijk", res->body);
-    EXPECT_EQ(206, res->status);
-  }
-
-  {
-    auto res = cli.Get("/range/32");
-    ASSERT_TRUE(res);
-    EXPECT_EQ("bcdefghijk", res->body);
-    EXPECT_EQ(206, res->status);
-  }
-}
-
 TEST(RangeTest, FromHTTPBin) {
   auto host = "httpbin.org";
 
@@ -968,7 +948,7 @@ TEST(RedirectFromPageWithContent, Redirect) {
 TEST(PathUrlEncodeTest, PathUrlEncode) {
   Server svr;
 
-  svr.Get("/foo", [](const Request & req, Response &res) {
+  svr.Get("/foo", [](const Request &req, Response &res) {
     auto a = req.params.find("a");
     if (a != req.params.end()) {
       res.set_content((*a).second, "text/plain");
@@ -1420,7 +1400,8 @@ protected:
                      const auto &d = *data;
                      auto out_len =
                          std::min(static_cast<size_t>(length), DATA_CHUNK_SIZE);
-                     auto ret = sink.write(&d[static_cast<size_t>(offset)], out_len);
+                     auto ret =
+                         sink.write(&d[static_cast<size_t>(offset)], out_len);
                      EXPECT_TRUE(ret);
                      return true;
                    },
@@ -3199,12 +3180,11 @@ static bool send_request(time_t read_timeout_sec, const std::string &req,
                          std::string *resp = nullptr) {
   auto error = Error::Success;
 
-  auto client_sock =
-      detail::create_client_socket(HOST, PORT, AF_UNSPEC, false, nullptr,
-                                   /*connection_timeout_sec=*/5, 0,
-                                   /*read_timeout_sec=*/5, 0,
-                                   /*write_timeout_sec=*/5, 0,
-                                   std::string(), error);
+  auto client_sock = detail::create_client_socket(
+      HOST, PORT, AF_UNSPEC, false, nullptr,
+      /*connection_timeout_sec=*/5, 0,
+      /*read_timeout_sec=*/5, 0,
+      /*write_timeout_sec=*/5, 0, std::string(), error);
 
   if (client_sock == INVALID_SOCKET) { return false; }
 
@@ -3684,6 +3664,54 @@ TEST(GetWithParametersTest, GetWithParameters2) {
   ASSERT_FALSE(svr.is_running());
 }
 
+TEST(ClientDefaultHeadersTest, DefaultHeaders) {
+  Client cli("httpbin.org");
+  cli.set_default_headers({make_range_header({{1, 10}})});
+  cli.set_connection_timeout(5);
+
+  {
+    auto res = cli.Get("/range/32");
+    ASSERT_TRUE(res);
+    EXPECT_EQ("bcdefghijk", res->body);
+    EXPECT_EQ(206, res->status);
+  }
+
+  {
+    auto res = cli.Get("/range/32");
+    ASSERT_TRUE(res);
+    EXPECT_EQ("bcdefghijk", res->body);
+    EXPECT_EQ(206, res->status);
+  }
+}
+
+TEST(ServerDefaultHeadersTest, DefaultHeaders) {
+  Server svr;
+  svr.set_default_headers({{"Hello", "World"}});
+
+  svr.Get("/", [&](const Request & /*req*/, Response &res) {
+    res.set_content("ok", "text/plain");
+  });
+
+  auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); });
+  while (!svr.is_running()) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+  }
+  std::this_thread::sleep_for(std::chrono::seconds(1));
+
+  Client cli("localhost", PORT);
+
+  auto res = cli.Get("/");
+
+  ASSERT_TRUE(res);
+  EXPECT_EQ(200, res->status);
+  EXPECT_EQ("ok", res->body);
+  EXPECT_EQ("World", res->get_header_value("Hello"));
+
+  svr.stop();
+  listen_thread.join();
+  ASSERT_FALSE(svr.is_running());
+}
+
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 TEST(KeepAliveTest, ReadTimeoutSSL) {
   SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);