Browse Source

Proxy support for Keep-Alive requests

yhirose 6 years ago
parent
commit
de844e67ef
3 changed files with 205 additions and 114 deletions
  1. 112 96
      httplib.h
  2. 0 1
      test/test.cc
  3. 93 17
      test/test_proxy.cc

+ 112 - 96
httplib.h

@@ -783,6 +783,9 @@ private:
   bool read_response_line(Stream &strm, Response &res);
   bool write_request(Stream &strm, const Request &req, bool last_connection);
   bool redirect(const Request &req, Response &res);
+  bool connect(socket_t sock, Response &res, bool &error);
+  bool handle_request(Stream &strm, const Request &req, Response &res,
+                      bool last_connection, bool &connection_close);
 
   std::shared_ptr<Response> send_with_content_provider(
       const char *method, const char *path, const Headers &headers,
@@ -2750,9 +2753,7 @@ inline Server::Server()
 #ifndef _WIN32
   signal(SIGPIPE, SIG_IGN);
 #endif
-  new_task_queue = [] {
-    return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT);
-  };
+  new_task_queue = [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); };
 }
 
 inline Server::~Server() {}
@@ -3445,77 +3446,128 @@ inline bool Client::read_response_line(Stream &strm, Response &res) {
 }
 
 inline bool Client::send(const Request &req, Response &res) {
-  if (req.path.empty()) { return false; }
-
   auto sock = create_client_socket();
   if (sock == INVALID_SOCKET) { return false; }
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
-  // CONNECT
   if (is_ssl() && !proxy_host_.empty()) {
-    Response res2;
-    if (!detail::process_socket(
-            true, sock, 1, read_timeout_sec_, read_timeout_usec_,
-            [&](Stream &strm, bool /*last_connection*/,
-                bool &connection_close) {
-              Request req2;
-              req2.method = "CONNECT";
-              req2.path = host_and_port_;
-              return process_request(strm, req2, res2, false, connection_close);
-            })) {
-      detail::close_socket(sock);
+    bool error;
+    if (!connect(sock, res, error)) { return error; }
+  }
+#endif
+
+  return process_and_close_socket(
+      sock, 1, [&](Stream &strm, bool last_connection, bool &connection_close) {
+        return handle_request(strm, req, res, last_connection,
+                              connection_close);
+      });
+}
+
+inline bool Client::send(const std::vector<Request> &requests,
+                         std::vector<Response> &responses) {
+  size_t i = 0;
+  while (i < requests.size()) {
+    auto sock = create_client_socket();
+    if (sock == INVALID_SOCKET) { return false; }
+
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+    if (is_ssl() && !proxy_host_.empty()) {
+      Response res;
+      bool error;
+      if (!connect(sock, res, error)) { return false; }
+    }
+#endif
+
+    if (!process_and_close_socket(sock, requests.size() - i,
+                                  [&](Stream &strm, bool last_connection,
+                                      bool &connection_close) -> bool {
+                                    auto &req = requests[i++];
+                                    auto res = Response();
+                                    auto ret = handle_request(strm, req, res,
+                                                              last_connection,
+                                                              connection_close);
+                                    if (ret) {
+                                      responses.emplace_back(std::move(res));
+                                    }
+                                    return ret;
+                                  })) {
       return false;
     }
+  }
 
-    if (res2.status == 407) {
-      if (!proxy_digest_auth_username_.empty() &&
-          !proxy_digest_auth_password_.empty()) {
-        std::map<std::string, std::string> auth;
-        if (parse_www_authenticate(res2, auth, true)) {
-          Response res2;
-          if (!detail::process_socket(
-                  true, sock, 1, read_timeout_sec_, read_timeout_usec_,
-                  [&](Stream &strm, bool /*last_connection*/,
-                      bool &connection_close) {
-                    Request req2;
-                    req2.method = "CONNECT";
-                    req2.path = host_and_port_;
-                    req2.headers.insert(make_digest_authentication_header(
-                        req2, auth, 1, random_string(10),
-                        proxy_digest_auth_username_,
-                        proxy_digest_auth_password_, true));
-                    return process_request(strm, req2, res2, false,
-                                           connection_close);
-                  })) {
-            detail::close_socket(sock);
-            return false;
-          }
+  return true;
+}
+
+inline bool Client::connect(socket_t sock, Response &res, bool &error) {
+  error = true;
+  Response res2;
+
+  if (!detail::process_socket(
+          true, sock, 1, read_timeout_sec_, read_timeout_usec_,
+          [&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
+            Request req2;
+            req2.method = "CONNECT";
+            req2.path = host_and_port_;
+            return process_request(strm, req2, res2, false, connection_close);
+          })) {
+    detail::close_socket(sock);
+    error = false;
+    return false;
+  }
+
+  if (res2.status == 407) {
+    if (!proxy_digest_auth_username_.empty() &&
+        !proxy_digest_auth_password_.empty()) {
+      std::map<std::string, std::string> auth;
+      if (parse_www_authenticate(res2, auth, true)) {
+        Response res3;
+        if (!detail::process_socket(
+                true, sock, 1, read_timeout_sec_, read_timeout_usec_,
+                [&](Stream &strm, bool /*last_connection*/,
+                    bool &connection_close) {
+                  Request req3;
+                  req3.method = "CONNECT";
+                  req3.path = host_and_port_;
+                  req3.headers.insert(make_digest_authentication_header(
+                      req3, auth, 1, random_string(10),
+                      proxy_digest_auth_username_, proxy_digest_auth_password_,
+                      true));
+                  return process_request(strm, req3, res3, false,
+                                         connection_close);
+                })) {
+          detail::close_socket(sock);
+          error = false;
+          return false;
         }
-      } else {
-        res = res2;
-        return true;
       }
+    } else {
+      res = res2;
+      return false;
     }
   }
-#endif
 
-  if (!process_and_close_socket(
-          sock, 1,
-          [&](Stream &strm, bool last_connection, bool &connection_close) {
-            if (!is_ssl() && !proxy_host_.empty()) {
-              auto req2 = req;
-              req2.path = "http://" + host_and_port_ + req.path;
-              return process_request(strm, req2, res, last_connection,
-                                     connection_close);
-            }
-            return process_request(strm, req, res, last_connection,
-                                   connection_close);
-          })) {
-    return false;
+  return true;
+}
+
+inline bool Client::handle_request(Stream &strm, const Request &req,
+                                   Response &res, bool last_connection,
+                                   bool &connection_close) {
+  if (req.path.empty()) { return false; }
+
+  bool ret;
+
+  if (!is_ssl() && !proxy_host_.empty()) {
+    auto req2 = req;
+    req2.path = "http://" + host_and_port_ + req.path;
+    ret = process_request(strm, req2, res, last_connection, connection_close);
+  } else {
+    ret = process_request(strm, req, res, last_connection, connection_close);
   }
 
+  if (!ret) { return false; }
+
   if (300 < res.status && res.status < 400 && follow_location_) {
-    return redirect(req, res);
+    ret = redirect(req, res);
   }
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
@@ -3537,50 +3589,14 @@ inline bool Client::send(const Request &req, Response &res) {
 
         Response new_res;
 
-        auto ret = send(new_req, new_res);
+        ret = send(new_req, new_res);
         if (ret) { res = new_res; }
-        return ret;
       }
     }
   }
 #endif
 
-  return true;
-}
-
-inline bool Client::send(const std::vector<Request> &requests,
-                         std::vector<Response> &responses) {
-  size_t i = 0;
-  while (i < requests.size()) {
-    auto sock = create_client_socket();
-    if (sock == INVALID_SOCKET) { return false; }
-
-    if (!process_and_close_socket(
-            sock, requests.size() - i,
-            [&](Stream &strm, bool last_connection,
-                bool &connection_close) -> bool {
-              auto &req = requests[i];
-              auto res = Response();
-              i++;
-
-              if (req.path.empty()) { return false; }
-              auto ret = process_request(strm, req, res, last_connection,
-                                         connection_close);
-
-              if (ret && follow_location_ &&
-                  (300 < res.status && res.status < 400)) {
-                ret = redirect(req, res);
-              }
-
-              if (ret) { responses.emplace_back(std::move(res)); }
-
-              return ret;
-            })) {
-      return false;
-    }
-  }
-
-  return true;
+  return ret;
 }
 
 inline bool Client::redirect(const Request &req, Response &res) {

+ 0 - 1
test/test.cc

@@ -514,7 +514,6 @@ TEST(DigestAuthTest, FromHTTPWatch) {
         "/digest-auth/auth/hello/world/MD5",
         "/digest-auth/auth/hello/world/SHA-256",
         "/digest-auth/auth/hello/world/SHA-512",
-        "/digest-auth/auth-init/hello/world/MD5",
         "/digest-auth/auth-int/hello/world/MD5",
     };
 

+ 93 - 17
test/test_proxy.cc

@@ -36,7 +36,7 @@ TEST(ProxyTest, SSLDigest) {
 
 // ----------------------------------------------------------------------------
 
-void RedirectTestHTTPBin(Client& cli, const char *path, bool basic) {
+void RedirectProxyText(Client& cli, const char *path, bool basic) {
   cli.set_proxy("localhost", basic ? 3128 : 3129);
   if (basic) {
     cli.set_proxy_basic_auth("hello", "world");
@@ -52,45 +52,45 @@ void RedirectTestHTTPBin(Client& cli, const char *path, bool basic) {
 
 TEST(RedirectTest, HTTPBinNoSSLBasic) {
   Client cli("httpbin.org");
-  RedirectTestHTTPBin(cli, "/redirect/2", true);
+  RedirectProxyText(cli, "/redirect/2", true);
 }
 
 TEST(RedirectTest, HTTPBinNoSSLDigest) {
   Client cli("httpbin.org");
-  RedirectTestHTTPBin(cli, "/redirect/2", false);
+  RedirectProxyText(cli, "/redirect/2", false);
 }
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 TEST(RedirectTest, HTTPBinSSLBasic) {
   SSLClient cli("httpbin.org");
-  RedirectTestHTTPBin(cli, "/redirect/2", true);
+  RedirectProxyText(cli, "/redirect/2", true);
 }
 
 TEST(RedirectTest, HTTPBinSSLDigest) {
   SSLClient cli("httpbin.org");
-  RedirectTestHTTPBin(cli, "/redirect/2", false);
+  RedirectProxyText(cli, "/redirect/2", false);
 }
 #endif
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 TEST(RedirectTest, YouTubeNoSSLBasic) {
   Client cli("youtube.com");
-  RedirectTestHTTPBin(cli, "/", true);
+  RedirectProxyText(cli, "/", true);
 }
 
 TEST(RedirectTest, YouTubeNoSSLDigest) {
   Client cli("youtube.com");
-  RedirectTestHTTPBin(cli, "/", false);
+  RedirectProxyText(cli, "/", false);
 }
 
 TEST(RedirectTest, YouTubeSSLBasic) {
   SSLClient cli("youtube.com");
-  RedirectTestHTTPBin(cli, "/", true);
+  RedirectProxyText(cli, "/", true);
 }
 
 TEST(RedirectTest, YouTubeSSLDigest) {
   SSLClient cli("youtube.com");
-  RedirectTestHTTPBin(cli, "/", false);
+  RedirectProxyText(cli, "/", false);
 }
 #endif
 
@@ -111,8 +111,7 @@ void BaseAuthTestFromHTTPWatch(Client& cli) {
         cli.Get("/basic-auth/hello/world",
                 {make_basic_authentication_header("hello", "world")});
     ASSERT_TRUE(res != nullptr);
-    EXPECT_EQ(res->body,
-              "{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n");
+    EXPECT_EQ("{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n", res->body);
     EXPECT_EQ(200, res->status);
   }
 
@@ -120,8 +119,7 @@ void BaseAuthTestFromHTTPWatch(Client& cli) {
     cli.set_basic_auth("hello", "world");
     auto res = cli.Get("/basic-auth/hello/world");
     ASSERT_TRUE(res != nullptr);
-    EXPECT_EQ(res->body,
-              "{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n");
+    EXPECT_EQ("{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n", res->body);
     EXPECT_EQ(200, res->status);
   }
 
@@ -150,6 +148,7 @@ TEST(BaseAuthTest, SSL) {
   SSLClient cli("httpbin.org");
   BaseAuthTestFromHTTPWatch(cli);
 }
+#endif
 
 // ----------------------------------------------------------------------------
 
@@ -169,7 +168,6 @@ void DigestAuthTestFromHTTPWatch(Client& cli) {
         "/digest-auth/auth/hello/world/MD5",
         "/digest-auth/auth/hello/world/SHA-256",
         "/digest-auth/auth/hello/world/SHA-512",
-        "/digest-auth/auth-init/hello/world/MD5",
         "/digest-auth/auth-int/hello/world/MD5",
     };
 
@@ -177,8 +175,7 @@ void DigestAuthTestFromHTTPWatch(Client& cli) {
     for (auto path : paths) {
       auto res = cli.Get(path.c_str());
       ASSERT_TRUE(res != nullptr);
-      EXPECT_EQ(res->body,
-                "{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n");
+      EXPECT_EQ("{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n", res->body);
       EXPECT_EQ(200, res->status);
     }
 
@@ -197,7 +194,6 @@ void DigestAuthTestFromHTTPWatch(Client& cli) {
     }
   }
 }
-#endif
 
 TEST(DigestAuthTest, SSL) {
   SSLClient cli("httpbin.org");
@@ -209,3 +205,83 @@ TEST(DigestAuthTest, NoSSL) {
   DigestAuthTestFromHTTPWatch(cli);
 }
 #endif
+
+// ----------------------------------------------------------------------------
+
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+TEST(KeepAliveText, NoSSLWithDigest) {
+  Client cli("httpbin.org");
+  cli.set_keep_alive_max_count(4);
+  cli.set_follow_location(true);
+  cli.set_digest_auth("hello", "world");
+  cli.set_proxy("localhost", 3129);
+  cli.set_proxy_digest_auth("hello", "world");
+
+  std::vector<Request> requests;
+
+  Get(requests, "/get");
+  Get(requests, "/redirect/2");
+  Get(requests, "/digest-auth/auth/hello/world/MD5");
+
+  std::vector<Response> responses;
+  auto ret = cli.send(requests, responses);
+  ASSERT_TRUE(ret == true);
+  ASSERT_TRUE(requests.size() == responses.size());
+
+  auto i = 0;
+
+  {
+    auto &res = responses[i++];
+    EXPECT_EQ(200, res.status);
+  }
+
+  {
+    auto &res = responses[i++];
+    EXPECT_EQ(200, res.status);
+  }
+
+  {
+    auto &res = responses[i++];
+    EXPECT_EQ("{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n", res.body);
+    EXPECT_EQ(200, res.status);
+  }
+}
+
+TEST(KeepAliveText, SSLWithDigest) {
+  SSLClient cli("httpbin.org");
+  cli.set_keep_alive_max_count(4);
+  cli.set_follow_location(true);
+  cli.set_digest_auth("hello", "world");
+  cli.set_proxy("localhost", 3129);
+  cli.set_proxy_digest_auth("hello", "world");
+
+  std::vector<Request> requests;
+
+  Get(requests, "/get");
+  Get(requests, "/redirect/2");
+  Get(requests, "/digest-auth/auth/hello/world/MD5");
+
+  std::vector<Response> responses;
+  auto ret = cli.send(requests, responses);
+  ASSERT_TRUE(ret == true);
+  ASSERT_TRUE(requests.size() == responses.size());
+
+  auto i = 0;
+
+  {
+    auto &res = responses[i++];
+    EXPECT_EQ(200, res.status);
+  }
+
+  {
+    auto &res = responses[i++];
+    EXPECT_EQ(200, res.status);
+  }
+
+  {
+    auto &res = responses[i++];
+    EXPECT_EQ("{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n", res.body);
+    EXPECT_EQ(200, res.status);
+  }
+}
+#endif