Browse Source

Refactoring to make it ready for KeepAlive connection on Client

yhirose 5 years ago
parent
commit
e022b8b80b
2 changed files with 169 additions and 195 deletions
  1. 161 192
      httplib.h
  2. 8 3
      test/test.cc

+ 161 - 192
httplib.h

@@ -800,7 +800,9 @@ public:
   bool send(const std::vector<Request> &requests,
             std::vector<Response> &responses);
 
-  virtual void stop();
+  size_t is_socket_open() const;
+
+  void stop();
 
   CPPHTTPLIB_DEPRECATED void set_timeout_sec(time_t timeout_sec);
   void set_connection_timeout(time_t sec, time_t usec = 0);
@@ -831,26 +833,31 @@ public:
   void set_logger(Logger logger);
 
 protected:
-  struct Endpoint {
+  struct Socket {
     socket_t sock = INVALID_SOCKET;
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
     SSL *ssl = nullptr;
 #endif
+
+    bool is_open() const { return sock != INVALID_SOCKET; }
   };
 
-  virtual bool create_and_connect_socket(Endpoint &endpoint);
-  virtual void close_socket(Endpoint &endpoint, bool process_socket_ret);
+  virtual bool create_and_connect_socket(Socket &socket);
+  virtual void close_socket(Socket &socket, bool process_socket_ret);
 
   bool process_request(Stream &strm, const Request &req, Response &res,
-                       bool last_connection, bool &connection_close);
-
-  std::vector<Endpoint> endpoints_;
-  std::mutex endpoints_mutex_;
+                       bool &connection_close);
 
+  // Socket endoint information
   const std::string host_;
   const int port_;
   const std::string host_and_port_;
 
+  // Current open socket
+  Socket socket_;
+  mutable std::mutex socket_mutex_;
+  std::recursive_mutex request_mutex_;
+
   // Settings
   std::string client_cert_path_;
   std::string client_key_path_;
@@ -923,13 +930,10 @@ protected:
 private:
   socket_t create_client_socket() const;
   bool read_response_line(Stream &strm, Response &res);
-  bool write_request(Stream &strm, const Request &req, bool last_connection);
+  bool write_request(Stream &strm, const Request &req);
   bool redirect(const Request &req, Response &res);
   bool handle_request(Stream &strm, const Request &req, Response &res,
-                      bool last_connection, bool &connection_close);
-#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
-  bool connect_with_proxy(socket_t sock, Response &res, bool &error);
-#endif
+                      bool &connection_close);
 
   std::shared_ptr<Response> send_with_content_provider(
       const char *method, const char *path, const Headers &headers,
@@ -937,7 +941,7 @@ private:
       ContentProvider content_provider, const char *content_type);
 
   virtual bool
-  process_socket(Endpoint &endpoint, size_t request_count,
+  process_socket(Socket &socket, size_t request_count,
                  std::function<bool(Stream &strm, bool last_connection,
                                     bool &connection_close)>
                      callback);
@@ -1026,8 +1030,6 @@ public:
 
   ~SSLClient() override;
 
-  void stop() override;
-
   bool is_valid() const override;
 
   void set_ca_cert_path(const char *ca_cert_file_path,
@@ -1042,16 +1044,17 @@ public:
   SSL_CTX *ssl_context() const;
 
 private:
-  bool create_and_connect_socket(Endpoint &endpoint) override;
-  void close_socket(Endpoint &endpoint, bool process_socket_ret) override;
+  bool create_and_connect_socket(Socket &socket) override;
+  bool connect_with_proxy(Socket &sock, bool &error);
+  void close_socket(Socket &socket, bool process_socket_ret) override;
 
-  bool process_socket(Endpoint &endpoint, size_t request_count,
+  bool process_socket(Socket &socket, size_t request_count,
                       std::function<bool(Stream &strm, bool last_connection,
                                          bool &connection_close)>
                           callback) override;
   bool is_ssl() const override;
 
-  bool initialize_ssl(Endpoint &endpoint);
+  bool initialize_ssl(Socket &socket);
 
   bool verify_host(X509 *server_cert) const;
   bool verify_host_with_subject_alt_name(X509 *server_cert) const;
@@ -1303,6 +1306,8 @@ public:
     return cli_->send(requests, responses);
   }
 
+  bool is_socket_open() { return cli_->is_socket_open(); }
+
   void stop() { cli_->stop(); }
 
   Client2 &set_connection_timeout(time_t sec, time_t usec) {
@@ -4330,7 +4335,12 @@ inline Client::Client(const std::string &host, int port,
       host_and_port_(host_ + ":" + std::to_string(port_)),
       client_cert_path_(client_cert_path), client_key_path_(client_key_path) {}
 
-inline Client::~Client() {}
+inline Client::~Client() {
+  assert(socket_.sock == INVALID_SOCKET);
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+  assert(socket_.ssl == nullptr);
+#endif
+}
 
 inline bool Client::is_valid() const { return true; }
 
@@ -4345,24 +4355,19 @@ inline socket_t Client::create_client_socket() const {
                                       connection_timeout_usec_, interface_);
 }
 
-inline bool Client::create_and_connect_socket(Endpoint &endpoint) {
+inline bool Client::create_and_connect_socket(Socket &socket) {
   auto sock = create_client_socket();
   if (sock == INVALID_SOCKET) { return false; }
-
-#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
-  if (is_ssl() && !proxy_host_.empty()) {
-    Response res;
-    bool error;
-    if (!connect_with_proxy(sock, res, error)) { return error; }
-  }
-#endif
-  endpoint.sock = sock;
+  socket.sock = sock;
   return true;
 }
 
-inline void Client::close_socket(Endpoint &endpoint,
-                                 bool /*process_socket_ret*/) {
-  detail::close_socket(endpoint.sock);
+inline void Client::close_socket(Socket &socket, bool /*process_socket_ret*/) {
+  detail::close_socket(socket.sock);
+  socket_.sock = INVALID_SOCKET;
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+  socket_.ssl = nullptr;
+#endif
 }
 
 inline bool Client::read_response_line(Stream &strm, Response &res) {
@@ -4384,32 +4389,23 @@ inline bool Client::read_response_line(Stream &strm, Response &res) {
 }
 
 inline bool Client::send(const Request &req, Response &res) {
-  Endpoint endpoint;
-  if (!create_and_connect_socket(endpoint)) { return false; }
+  std::lock_guard<std::recursive_mutex> guard(request_mutex_);
+  auto need_new_socket = !is_socket_open();
 
-  {
-    std::lock_guard<std::mutex> guard(endpoints_mutex_);
-    endpoints_.push_back(endpoint);
+  if (need_new_socket) {
+    std::lock_guard<std::mutex> guard(socket_mutex_);
+    if (!create_and_connect_socket(socket_)) { return false; }
   }
 
   auto ret = process_socket(
-      endpoint, 1,
-      [&](Stream &strm, bool last_connection, bool &connection_close) {
-        return handle_request(strm, req, res, last_connection,
-                              connection_close);
+      socket_, 1,
+      [&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
+        return handle_request(strm, req, res, connection_close);
       });
 
-  {
-    std::lock_guard<std::mutex> guard(endpoints_mutex_);
-
-    auto it = std::find_if(
-        endpoints_.begin(), endpoints_.end(),
-        [&](Endpoint &endpoint2) { return endpoint.sock == endpoint2.sock; });
-
-    if (it != endpoints_.end()) {
-      close_socket(endpoint, ret);
-      endpoints_.erase(it);
-    }
+  if (need_new_socket) {
+    std::lock_guard<std::mutex> guard(socket_mutex_);
+    if (socket_.is_open()) { close_socket(socket_, ret); }
   }
 
   return ret;
@@ -4417,43 +4413,30 @@ inline bool Client::send(const Request &req, Response &res) {
 
 inline bool Client::send(const std::vector<Request> &requests,
                          std::vector<Response> &responses) {
+  std::lock_guard<std::recursive_mutex> guard(request_mutex_);
+
   size_t i = 0;
   while (i < requests.size()) {
-    Endpoint endpoint;
-    if (!create_and_connect_socket(endpoint)) { return false; }
-
     {
-      std::lock_guard<std::mutex> guard(endpoints_mutex_);
-      endpoints_.push_back(endpoint);
+      std::lock_guard<std::mutex> guard(socket_mutex_);
+      if (!create_and_connect_socket(socket_)) { return false; }
     }
 
     auto request_count = (std::min)(requests.size() - i, keep_alive_max_count_);
 
-    auto ret = process_socket(endpoint, request_count,
-                              [&](Stream &strm, bool last_connection,
-                                  bool &connection_close) -> bool {
-                                auto &req = requests[i++];
-                                auto res = Response();
-                                auto ret = handle_request(strm, req, res,
-                                                          last_connection,
-                                                          connection_close);
-                                if (ret) {
-                                  responses.emplace_back(std::move(res));
-                                }
-                                return ret;
-                              });
+    auto ret = process_socket(
+        socket_, request_count,
+        [&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
+          auto &req = requests[i++];
+          auto res = Response();
+          auto ret = handle_request(strm, req, res, connection_close);
+          if (ret) { responses.emplace_back(std::move(res)); }
+          return ret;
+        });
 
     {
-      std::lock_guard<std::mutex> guard(endpoints_mutex_);
-
-      auto it = std::find_if(
-          endpoints_.begin(), endpoints_.end(),
-          [&](Endpoint &endpoint2) { return endpoint.sock == endpoint2.sock; });
-
-      if (it != endpoints_.end()) {
-        close_socket(endpoint, ret);
-        endpoints_.erase(it);
-      }
+      std::lock_guard<std::mutex> guard(socket_mutex_);
+      if (socket_.is_open()) { close_socket(socket_, ret); }
     }
 
     if (!ret) { return false; }
@@ -4463,8 +4446,7 @@ inline bool Client::send(const std::vector<Request> &requests,
 }
 
 inline bool Client::handle_request(Stream &strm, const Request &req,
-                                   Response &res, bool last_connection,
-                                   bool &connection_close) {
+                                   Response &res, bool &connection_close) {
   if (req.path.empty()) { return false; }
 
   bool ret;
@@ -4472,9 +4454,9 @@ inline bool Client::handle_request(Stream &strm, const Request &req,
   if (!is_ssl() && !proxy_host_.empty()) {
     auto req2 = req;
     req2.path = "http://" + host_and_port_ + req.path;
-    ret = process_request(strm, req2, res, last_connection, connection_close);
+    ret = process_request(strm, req2, res, connection_close);
   } else {
-    ret = process_request(strm, req, res, last_connection, connection_close);
+    ret = process_request(strm, req, res, connection_close);
   }
 
   if (!ret) { return false; }
@@ -4515,64 +4497,6 @@ inline bool Client::handle_request(Stream &strm, const Request &req,
   return ret;
 }
 
-#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
-inline bool Client::connect_with_proxy(socket_t sock, Response &res,
-                                       bool &error) {
-  error = true;
-  Response res2;
-
-  if (!detail::process_socket_core(
-          true, sock, 1, [&](bool /*last_connection*/, bool &connection_close) {
-            detail::SocketStream strm(sock, read_timeout_sec_,
-                                      read_timeout_usec_, write_timeout_sec_,
-                                      write_timeout_usec_);
-            Request req2;
-            req2.method = "CONNECT";
-            req2.path = host_and_port_;
-            return process_request(strm, req2, res2, false, connection_close);
-          })) {
-    detail::close_socket(sock);
-    error = false;
-    return false;
-  }
-
-  if (res2.status == 407) {
-    if (!proxy_digest_auth_username_.empty() &&
-        !proxy_digest_auth_password_.empty()) {
-      std::map<std::string, std::string> auth;
-      if (parse_www_authenticate(res2, auth, true)) {
-        Response res3;
-        if (!detail::process_socket_core(
-                true, sock, 1,
-                [&](bool /*last_connection*/, bool &connection_close) {
-                  detail::SocketStream strm(
-                      sock, read_timeout_sec_, read_timeout_usec_,
-                      write_timeout_sec_, write_timeout_usec_);
-                  Request req3;
-                  req3.method = "CONNECT";
-                  req3.path = host_and_port_;
-                  req3.headers.insert(make_digest_authentication_header(
-                      req3, auth, 1, random_string(10),
-                      proxy_digest_auth_username_, proxy_digest_auth_password_,
-                      true));
-                  return process_request(strm, req3, res3, false,
-                                         connection_close);
-                })) {
-          detail::close_socket(sock);
-          error = false;
-          return false;
-        }
-      }
-    } else {
-      res = res2;
-      return false;
-    }
-  }
-
-  return true;
-}
-#endif
-
 inline bool Client::redirect(const Request &req, Response &res) {
   if (req.redirect_count == 0) { return false; }
 
@@ -4622,8 +4546,7 @@ inline bool Client::redirect(const Request &req, Response &res) {
   }
 }
 
-inline bool Client::write_request(Stream &strm, const Request &req,
-                                  bool last_connection) {
+inline bool Client::write_request(Stream &strm, const Request &req) {
   detail::BufferStream bstrm;
 
   // Request line
@@ -4633,8 +4556,6 @@ inline bool Client::write_request(Stream &strm, const Request &req,
 
   // Additonal headers
   Headers headers;
-  if (last_connection) { headers.emplace("Connection", "close"); }
-
   if (!req.has_header("Host")) {
     if (is_ssl()) {
       if (port_ == 443) {
@@ -4777,10 +4698,9 @@ inline std::shared_ptr<Response> Client::send_with_content_provider(
 }
 
 inline bool Client::process_request(Stream &strm, const Request &req,
-                                    Response &res, bool last_connection,
-                                    bool &connection_close) {
+                                    Response &res, bool &connection_close) {
   // Send request
-  if (!write_request(strm, req, last_connection)) { return false; }
+  if (!write_request(strm, req)) { return false; }
 
   // Receive response and headers
   if (!read_response_line(strm, res) ||
@@ -4824,12 +4744,12 @@ inline bool Client::process_request(Stream &strm, const Request &req,
 }
 
 inline bool
-Client::process_socket(Endpoint &endpoint, size_t request_count,
+Client::process_socket(Socket &socket, size_t request_count,
                        std::function<bool(Stream &strm, bool last_connection,
                                           bool &connection_close)>
                            callback) {
   return detail::process_socket(
-      true, endpoint.sock, request_count, read_timeout_sec_, read_timeout_usec_,
+      true, socket.sock, request_count, read_timeout_sec_, read_timeout_usec_,
       write_timeout_sec_, write_timeout_usec_, callback);
 }
 
@@ -5125,13 +5045,17 @@ inline std::shared_ptr<Response> Client::Options(const char *path,
   return send(req, *res) ? res : nullptr;
 }
 
+inline size_t Client::is_socket_open() const {
+  std::lock_guard<std::mutex> guard(socket_mutex_);
+  return socket_.is_open();
+}
+
 inline void Client::stop() {
-  std::lock_guard<std::mutex> guard(endpoints_mutex_);
-  for (auto &endpoint : endpoints_) {
-    detail::shutdown_socket(endpoint.sock);
-    detail::close_socket(endpoint.sock);
+  std::lock_guard<std::mutex> guard(socket_mutex_);
+  if (socket_.is_open()) {
+    detail::shutdown_socket(socket_.sock);
+    close_socket(socket_, true);
   }
-  endpoints_.clear();
 }
 
 inline void Client::set_timeout_sec(time_t timeout_sec) {
@@ -5494,25 +5418,6 @@ inline SSLClient::~SSLClient() {
   if (ctx_) { SSL_CTX_free(ctx_); }
 }
 
-inline void SSLClient::stop() {
-  auto endpoints = endpoints_;
-  {
-    std::lock_guard<std::mutex> guard(endpoints_mutex_);
-    for (auto &endpoint : endpoints_) {
-      detail::shutdown_socket(endpoint.sock);
-      detail::close_socket(endpoint.sock);
-    }
-    endpoints_.clear();
-  }
-
-  std::this_thread::sleep_for(std::chrono::milliseconds(100));
-
-  for (auto &endpoint : endpoints) {
-    SSL_shutdown(endpoint.ssl);
-    SSL_free(endpoint.ssl);
-  }
-}
-
 inline bool SSLClient::is_valid() const { return ctx_; }
 
 inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path,
@@ -5535,14 +5440,75 @@ inline long SSLClient::get_openssl_verify_result() const {
 
 inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; }
 
-inline bool SSLClient::create_and_connect_socket(Endpoint &endpoint) {
-  return is_valid() && Client::create_and_connect_socket(endpoint) &&
-         initialize_ssl(endpoint);
+inline bool SSLClient::create_and_connect_socket(Socket &socket) {
+  if (is_valid() && Client::create_and_connect_socket(socket) &&
+      initialize_ssl(socket)) {
+    if (!proxy_host_.empty()) {
+      bool error;
+      if (!connect_with_proxy(socket, error)) { return error; }
+    }
+    return true;
+  }
+  return false;
+}
+
+inline bool SSLClient::connect_with_proxy(Socket &socket, bool &error) {
+  error = true;
+  Response res;
+
+  if (!detail::process_socket_core(
+          true, socket.sock, 1,
+          [&](bool /*last_connection*/, bool &connection_close) {
+            detail::SocketStream strm(socket.sock, read_timeout_sec_,
+                                      read_timeout_usec_, write_timeout_sec_,
+                                      write_timeout_usec_);
+            Request req2;
+            req2.method = "CONNECT";
+            req2.path = host_and_port_;
+            return process_request(strm, req2, res, connection_close);
+          })) {
+    close_socket(socket, true);
+    error = false;
+    return false;
+  }
+
+  if (res.status == 407) {
+    if (!proxy_digest_auth_username_.empty() &&
+        !proxy_digest_auth_password_.empty()) {
+      std::map<std::string, std::string> auth;
+      if (parse_www_authenticate(res, auth, true)) {
+        Response res3;
+        if (!detail::process_socket_core(
+                true, socket.sock, 1,
+                [&](bool /*last_connection*/, bool &connection_close) {
+                  detail::SocketStream strm(
+                      socket.sock, read_timeout_sec_, read_timeout_usec_,
+                      write_timeout_sec_, write_timeout_usec_);
+                  Request req3;
+                  req3.method = "CONNECT";
+                  req3.path = host_and_port_;
+                  req3.headers.insert(make_digest_authentication_header(
+                      req3, auth, 1, random_string(10),
+                      proxy_digest_auth_username_, proxy_digest_auth_password_,
+                      true));
+                  return process_request(strm, req3, res3, connection_close);
+                })) {
+          close_socket(socket, true);
+          error = false;
+          return false;
+        }
+      }
+    } else {
+      return false;
+    }
+  }
+
+  return true;
 }
 
-inline bool SSLClient::initialize_ssl(Endpoint &endpoint) {
+inline bool SSLClient::initialize_ssl(Socket &socket) {
   auto ssl = detail::ssl_new(
-      endpoint.sock, ctx_, ctx_mutex_,
+      socket.sock, ctx_, ctx_mutex_,
       [&](SSL *ssl) {
         if (ca_cert_file_path_.empty() && ca_cert_store_ == nullptr) {
           SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr);
@@ -5585,29 +5551,32 @@ inline bool SSLClient::initialize_ssl(Endpoint &endpoint) {
       });
 
   if (ssl) {
-    endpoint.ssl = ssl;
+    socket.ssl = ssl;
     return true;
   }
 
-  detail::close_socket(endpoint.sock);
+  close_socket(socket, false);
   return false;
 }
 
-inline void SSLClient::close_socket(Endpoint &endpoint,
-                                    bool process_socket_ret) {
-  assert(endpoint.ssl);
-  detail::ssl_delete(ctx_mutex_, endpoint.ssl, process_socket_ret);
-  detail::close_socket(endpoint.sock);
+inline void SSLClient::close_socket(Socket &socket, bool process_socket_ret) {
+  detail::close_socket(socket.sock);
+  socket_.sock = INVALID_SOCKET;
+  std::this_thread::sleep_for(std::chrono::milliseconds(10));
+  if (socket.ssl) {
+    detail::ssl_delete(ctx_mutex_, socket.ssl, process_socket_ret);
+    socket_.ssl = nullptr;
+  }
 }
 
 inline bool
-SSLClient::process_socket(Endpoint &endpoint, size_t request_count,
+SSLClient::process_socket(Socket &socket, size_t request_count,
                           std::function<bool(Stream &strm, bool last_connection,
                                              bool &connection_close)>
                               callback) {
-  assert(endpoint.ssl);
+  assert(socket.ssl);
   return detail::process_socket_ssl(
-      endpoint.ssl, true, endpoint.sock, request_count, read_timeout_sec_,
+      socket.ssl, true, socket.sock, request_count, read_timeout_sec_,
       read_timeout_usec_, write_timeout_sec_, write_timeout_usec_,
       [&](Stream &strm, bool last_connection, bool &connection_close) {
         return callback(strm, last_connection, connection_close);

+ 8 - 3
test/test.cc

@@ -1767,15 +1767,20 @@ TEST_F(ServerTest, GetStreamedEndless) {
 
 TEST_F(ServerTest, ClientStop) {
   std::vector<std::thread> threads;
-  for (auto i = 0; i < 3; i++) {
+  for (auto i = 0; i < 100; i++) {
     threads.emplace_back(thread([&]() {
       auto res = cli_.Get("/streamed-cancel",
                           [&](const char *, uint64_t) { return true; });
       ASSERT_TRUE(res == nullptr);
     }));
   }
-  std::this_thread::sleep_for(std::chrono::seconds(3));
-  cli_.stop();
+
+  std::this_thread::sleep_for(std::chrono::seconds(1));
+
+  while (cli_.is_socket_open()) {
+    cli_.stop();
+    std::this_thread::sleep_for(std::chrono::milliseconds(10));
+  }
   for (auto &t : threads) {
     t.join();
   }