Răsfoiți Sursa

Added set_default_headers (Fix #600)

yhirose 5 ani în urmă
părinte
comite
04002d57bd
2 a modificat fișierele cu 93 adăugiri și 16 ștergeri
  1. 68 7
      httplib.h
  2. 25 9
      test/test.cc

+ 68 - 7
httplib.h

@@ -702,9 +702,16 @@ public:
   std::shared_ptr<Response> Get(const char *path, const Headers &headers,
                                 ContentReceiver content_receiver,
                                 Progress progress);
+  std::shared_ptr<Response> Get(const char *path,
+                                ResponseHandler response_handler,
+                                ContentReceiver content_receiver);
   std::shared_ptr<Response> Get(const char *path, const Headers &headers,
                                 ResponseHandler response_handler,
                                 ContentReceiver content_receiver);
+  std::shared_ptr<Response> Get(const char *path,
+                                ResponseHandler response_handler,
+                                ContentReceiver content_receiver,
+                                Progress progress);
   std::shared_ptr<Response> Get(const char *path, const Headers &headers,
                                 ResponseHandler response_handler,
                                 ContentReceiver content_receiver,
@@ -781,6 +788,8 @@ public:
 
   void stop();
 
+  void set_default_headers(Headers headers);
+
   void set_tcp_nodelay(bool on);
   void set_socket_options(SocketOptions socket_options);
 
@@ -838,6 +847,9 @@ protected:
   mutable std::mutex socket_mutex_;
   std::recursive_mutex request_mutex_;
 
+  // Default headers
+  Headers default_headers_;
+
   // Settings
   std::string client_cert_path_;
   std::string client_key_path_;
@@ -967,6 +979,9 @@ public:
   std::shared_ptr<Response> Get(const char *path, const Headers &headers,
                                 ContentReceiver content_receiver,
                                 Progress progress);
+  std::shared_ptr<Response> Get(const char *path,
+                                ResponseHandler response_handler,
+                                ContentReceiver content_receiver);
   std::shared_ptr<Response> Get(const char *path, const Headers &headers,
                                 ResponseHandler response_handler,
                                 ContentReceiver content_receiver);
@@ -974,6 +989,10 @@ public:
                                 ResponseHandler response_handler,
                                 ContentReceiver content_receiver,
                                 Progress progress);
+  std::shared_ptr<Response> Get(const char *path,
+                                ResponseHandler response_handler,
+                                ContentReceiver content_receiver,
+                                Progress progress);
 
   std::shared_ptr<Response> Head(const char *path);
   std::shared_ptr<Response> Head(const char *path, const Headers &headers);
@@ -1044,6 +1063,8 @@ public:
 
   void stop();
 
+  void set_default_headers(Headers headers);
+
   void set_tcp_nodelay(bool on);
   void set_socket_options(SocketOptions socket_options);
 
@@ -3285,7 +3306,7 @@ make_basic_authentication_header(const std::string &username,
 
 inline std::pair<std::string, std::string>
 make_bearer_token_authentication_header(const std::string &token,
-                                 bool is_proxy = false) {
+                                        bool is_proxy = false) {
   auto field = "Bearer " + token;
   auto key = is_proxy ? "Proxy-Authorization" : "Authorization";
   return std::make_pair(key, field);
@@ -4788,7 +4809,8 @@ inline std::shared_ptr<Response> ClientImpl::send_with_content_provider(
     ContentProvider content_provider, const char *content_type) {
   Request req;
   req.method = method;
-  req.headers = headers;
+  req.headers = default_headers_;
+  req.headers.insert(headers.begin(), headers.end());
   req.path = path;
 
   if (content_type) { req.headers.emplace("Content-Type", content_type); }
@@ -4928,7 +4950,8 @@ ClientImpl::Get(const char *path, const Headers &headers, Progress progress) {
   Request req;
   req.method = "GET";
   req.path = path;
-  req.headers = headers;
+  req.headers = default_headers_;
+  req.headers.insert(headers.begin(), headers.end());
   req.progress = std::move(progress);
 
   auto res = std::make_shared<Response>();
@@ -4960,6 +4983,13 @@ ClientImpl::Get(const char *path, const Headers &headers,
              std::move(progress));
 }
 
+inline std::shared_ptr<Response>
+ClientImpl::Get(const char *path, ResponseHandler response_handler,
+                ContentReceiver content_receiver) {
+  return Get(path, Headers(), std::move(response_handler), content_receiver,
+             Progress());
+}
+
 inline std::shared_ptr<Response>
 ClientImpl::Get(const char *path, const Headers &headers,
                 ResponseHandler response_handler,
@@ -4968,6 +4998,13 @@ ClientImpl::Get(const char *path, const Headers &headers,
              Progress());
 }
 
+inline std::shared_ptr<Response>
+ClientImpl::Get(const char *path, ResponseHandler response_handler,
+                ContentReceiver content_receiver, Progress progress) {
+  return Get(path, Headers(), std::move(response_handler), content_receiver,
+             progress);
+}
+
 inline std::shared_ptr<Response>
 ClientImpl::Get(const char *path, const Headers &headers,
                 ResponseHandler response_handler,
@@ -4975,7 +5012,8 @@ ClientImpl::Get(const char *path, const Headers &headers,
   Request req;
   req.method = "GET";
   req.path = path;
-  req.headers = headers;
+  req.headers = default_headers_;
+  req.headers.insert(headers.begin(), headers.end());
   req.response_handler = std::move(response_handler);
   req.content_receiver = std::move(content_receiver);
   req.progress = std::move(progress);
@@ -4992,7 +5030,8 @@ inline std::shared_ptr<Response> ClientImpl::Head(const char *path,
                                                   const Headers &headers) {
   Request req;
   req.method = "HEAD";
-  req.headers = headers;
+  req.headers = default_headers_;
+  req.headers.insert(headers.begin(), headers.end());
   req.path = path;
 
   auto res = std::make_shared<Response>();
@@ -5171,7 +5210,8 @@ inline std::shared_ptr<Response> ClientImpl::Delete(const char *path,
                                                     const char *content_type) {
   Request req;
   req.method = "DELETE";
-  req.headers = headers;
+  req.headers = default_headers_;
+  req.headers.insert(headers.begin(), headers.end());
   req.path = path;
 
   if (content_type) { req.headers.emplace("Content-Type", content_type); }
@@ -5190,8 +5230,9 @@ inline std::shared_ptr<Response> ClientImpl::Options(const char *path,
                                                      const Headers &headers) {
   Request req;
   req.method = "OPTIONS";
+  req.headers = default_headers_;
+  req.headers.insert(headers.begin(), headers.end());
   req.path = path;
-  req.headers = headers;
 
   auto res = std::make_shared<Response>();
 
@@ -5250,6 +5291,10 @@ inline void ClientImpl::set_keep_alive(bool on) { keep_alive_ = on; }
 
 inline void ClientImpl::set_follow_location(bool on) { follow_location_ = on; }
 
+inline void ClientImpl::set_default_headers(Headers headers) {
+  default_headers_ = std::move(headers);
+}
+
 inline void ClientImpl::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; }
 
 inline void ClientImpl::set_socket_options(SocketOptions socket_options) {
@@ -6001,12 +6046,24 @@ inline std::shared_ptr<Response> Client::Get(const char *path,
                                              Progress progress) {
   return cli_->Get(path, headers, content_receiver, progress);
 }
+inline std::shared_ptr<Response> Client::Get(const char *path,
+                                             ResponseHandler response_handler,
+                                             ContentReceiver content_receiver) {
+  return cli_->Get(path, Headers(), response_handler, content_receiver);
+}
 inline std::shared_ptr<Response> Client::Get(const char *path,
                                              const Headers &headers,
                                              ResponseHandler response_handler,
                                              ContentReceiver content_receiver) {
   return cli_->Get(path, headers, response_handler, content_receiver);
 }
+inline std::shared_ptr<Response> Client::Get(const char *path,
+                                             ResponseHandler response_handler,
+                                             ContentReceiver content_receiver,
+                                             Progress progress) {
+  return cli_->Get(path, Headers(), response_handler, content_receiver,
+                   progress);
+}
 inline std::shared_ptr<Response> Client::Get(const char *path,
                                              const Headers &headers,
                                              ResponseHandler response_handler,
@@ -6157,6 +6214,10 @@ inline size_t Client::is_socket_open() const { return cli_->is_socket_open(); }
 
 inline void Client::stop() { cli_->stop(); }
 
+inline void Client::set_default_headers(Headers headers) {
+  cli_->set_default_headers(std::move(headers));
+}
+
 inline void Client::set_tcp_nodelay(bool on) { cli_->set_tcp_nodelay(on); }
 inline void Client::set_socket_options(SocketOptions socket_options) {
   cli_->set_socket_options(socket_options);

+ 25 - 9
test/test.cc

@@ -354,7 +354,7 @@ TEST(ChunkedEncodingTest, WithResponseHandlerAndContentReceiver) {
 
   std::string body;
   auto res = cli.Get(
-      "/httpgallery/chunked/chunkedimage.aspx?0.4153841143030137", Headers(),
+      "/httpgallery/chunked/chunkedimage.aspx?0.4153841143030137",
       [&](const Response &response) {
         EXPECT_EQ(200, response.status);
         return true;
@@ -372,6 +372,26 @@ TEST(ChunkedEncodingTest, WithResponseHandlerAndContentReceiver) {
   EXPECT_EQ(out, body);
 }
 
+TEST(DefaultHeadersTest, FromHTTPBin) {
+  Client cli("httpbin.org");
+  cli.set_default_headers({make_range_header({{1, 10}})});
+  cli.set_connection_timeout(5);
+
+  {
+    auto res = cli.Get("/range/32");
+    ASSERT_TRUE(res != nullptr);
+    EXPECT_EQ("bcdefghijk", res->body);
+    EXPECT_EQ(206, res->status);
+  }
+
+  {
+    auto res = cli.Get("/range/32");
+    ASSERT_TRUE(res != nullptr);
+    EXPECT_EQ("bcdefghijk", res->body);
+    EXPECT_EQ(206, res->status);
+  }
+}
+
 TEST(RangeTest, FromHTTPBin) {
   auto host = "httpbin.org";
 
@@ -385,8 +405,7 @@ TEST(RangeTest, FromHTTPBin) {
   cli.set_connection_timeout(5);
 
   {
-    Headers headers;
-    auto res = cli.Get("/range/32", headers);
+    auto res = cli.Get("/range/32");
     ASSERT_TRUE(res != nullptr);
     EXPECT_EQ("abcdefghijklmnopqrstuvwxyzabcdef", res->body);
     EXPECT_EQ(200, res->status);
@@ -541,8 +560,7 @@ TEST(CancelTest, WithCancelLargePayload) {
   cli.set_connection_timeout(5);
 
   uint32_t count = 0;
-  Headers headers;
-  auto res = cli.Get("/range/65536", headers,
+  auto res = cli.Get("/range/65536",
                      [&count](uint64_t, uint64_t) { return (count++ == 0); });
   ASSERT_TRUE(res == nullptr);
 }
@@ -2319,8 +2337,7 @@ TEST_F(ServerTest, Gzip) {
 }
 
 TEST_F(ServerTest, GzipWithoutAcceptEncoding) {
-  Headers headers;
-  auto res = cli_.Get("/compress", headers);
+  auto res = cli_.Get("/compress");
 
   ASSERT_TRUE(res != nullptr);
   EXPECT_TRUE(res->get_header_value("Content-Encoding").empty());
@@ -2369,9 +2386,8 @@ TEST_F(ServerTest, GzipWithoutDecompressing) {
 }
 
 TEST_F(ServerTest, GzipWithContentReceiverWithoutAcceptEncoding) {
-  Headers headers;
   std::string body;
-  auto res = cli_.Get("/compress", headers,
+  auto res = cli_.Get("/compress",
                       [&](const char *data, uint64_t data_length) {
                         EXPECT_EQ(data_length, 100);
                         body.append(data, data_length);