Browse Source

Added Endpoint structure in Client

yhirose 5 years ago
parent
commit
f80b6bd980
2 changed files with 290 additions and 218 deletions
  1. 279 211
      httplib.h
  2. 11 7
      test/test.cc

+ 279 - 211
httplib.h

@@ -194,7 +194,6 @@ using socket_t = int;
 #include <mutex>
 #include <random>
 #include <regex>
-#include <set>
 #include <string>
 #include <sys/stat.h>
 #include <thread>
@@ -801,7 +800,7 @@ public:
   bool send(const std::vector<Request> &requests,
             std::vector<Response> &responses);
 
-  void stop();
+  virtual void stop();
 
   CPPHTTPLIB_DEPRECATED void set_timeout_sec(time_t timeout_sec);
   void set_connection_timeout(time_t sec, time_t usec = 0);
@@ -832,11 +831,21 @@ public:
   void set_logger(Logger logger);
 
 protected:
+  struct Endpoint {
+    socket_t sock = INVALID_SOCKET;
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+    SSL *ssl = nullptr;
+#endif
+  };
+
+  virtual bool create_and_connect_socket(Endpoint &endpoint);
+  virtual void close_socket(Endpoint &endpoint, bool process_socket_ret);
+
   bool process_request(Stream &strm, const Request &req, Response &res,
                        bool last_connection, bool &connection_close);
 
-  std::set<socket_t> cli_socks_;
-  std::mutex cli_socks_mutex_;
+  std::vector<Endpoint> endpoints_;
+  std::mutex endpoints_mutex_;
 
   const std::string host_;
   const int port_;
@@ -913,14 +922,13 @@ protected:
 
 private:
   socket_t create_client_socket() const;
-  bool create_and_connect_socket(socket_t &sock);
   bool read_response_line(Stream &strm, Response &res);
   bool write_request(Stream &strm, const Request &req, bool last_connection);
   bool redirect(const Request &req, Response &res);
   bool handle_request(Stream &strm, const Request &req, Response &res,
                       bool last_connection, bool &connection_close);
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
-  bool connect(socket_t sock, Response &res, bool &error);
+  bool connect_with_proxy(socket_t sock, Response &res, bool &error);
 #endif
 
   std::shared_ptr<Response> send_with_content_provider(
@@ -928,11 +936,11 @@ private:
       const std::string &body, size_t content_length,
       ContentProvider content_provider, const char *content_type);
 
-  virtual bool process_and_close_socket(
-      socket_t sock, size_t request_count,
-      std::function<bool(Stream &strm, bool last_connection,
-                         bool &connection_close)>
-          callback);
+  virtual bool
+  process_socket(Endpoint &endpoint, size_t request_count,
+                 std::function<bool(Stream &strm, bool last_connection,
+                                    bool &connection_close)>
+                     callback);
 
   virtual bool is_ssl() const;
 };
@@ -1018,6 +1026,8 @@ public:
 
   ~SSLClient() override;
 
+  void stop() override;
+
   bool is_valid() const override;
 
   void set_ca_cert_path(const char *ca_cert_file_path,
@@ -1032,13 +1042,17 @@ public:
   SSL_CTX *ssl_context() const;
 
 private:
-  bool process_and_close_socket(
-      socket_t sock, size_t request_count,
-      std::function<bool(Stream &strm, bool last_connection,
-                         bool &connection_close)>
-          callback) override;
+  bool create_and_connect_socket(Endpoint &endpoint) override;
+  void close_socket(Endpoint &endpoint, bool process_socket_ret) override;
+
+  bool process_socket(Endpoint &endpoint, size_t request_count,
+                      std::function<bool(Stream &strm, bool last_connection,
+                                         bool &connection_close)>
+                          callback) override;
   bool is_ssl() const override;
 
+  bool initialize_ssl(Endpoint &endpoint);
+
   bool verify_host(X509 *server_cert) const;
   bool verify_host_with_subject_alt_name(X509 *server_cert) const;
   bool verify_host_with_common_name(X509 *server_cert) const;
@@ -1845,10 +1859,8 @@ private:
 };
 
 template <typename T>
-inline bool process_socket(bool is_client_request, socket_t sock,
-                           size_t keep_alive_max_count, time_t read_timeout_sec,
-                           time_t read_timeout_usec, time_t write_timeout_sec,
-                           time_t write_timeout_usec, T callback) {
+inline bool process_socket_core(bool is_client_request, socket_t sock,
+                                size_t keep_alive_max_count, T callback) {
   assert(keep_alive_max_count > 0);
 
   auto ret = false;
@@ -1859,37 +1871,34 @@ inline bool process_socket(bool is_client_request, socket_t sock,
            (is_client_request ||
             select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
                         CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
-      SocketStream strm(sock, read_timeout_sec, read_timeout_usec,
-                        write_timeout_sec, write_timeout_usec);
       auto last_connection = count == 1;
       auto connection_close = false;
 
-      ret = callback(strm, last_connection, connection_close);
+      ret = callback(last_connection, connection_close);
       if (!ret || connection_close) { break; }
 
       count--;
     }
   } else { // keep_alive_max_count  is 0 or 1
-    SocketStream strm(sock, read_timeout_sec, read_timeout_usec,
-                      write_timeout_sec, write_timeout_usec);
     auto dummy_connection_close = false;
-    ret = callback(strm, true, dummy_connection_close);
+    ret = callback(true, dummy_connection_close);
   }
 
   return ret;
 }
 
 template <typename T>
-inline bool
-process_and_close_socket(bool is_client_request, socket_t sock,
-                         size_t keep_alive_max_count, time_t read_timeout_sec,
-                         time_t read_timeout_usec, time_t write_timeout_sec,
-                         time_t write_timeout_usec, T callback) {
-  auto ret = process_socket(is_client_request, sock, keep_alive_max_count,
-                            read_timeout_sec, read_timeout_usec,
-                            write_timeout_sec, write_timeout_usec, callback);
-  close_socket(sock);
-  return ret;
+inline bool process_socket(bool is_client_request, socket_t sock,
+                           size_t keep_alive_max_count, time_t read_timeout_sec,
+                           time_t read_timeout_usec, time_t write_timeout_sec,
+                           time_t write_timeout_usec, T callback) {
+  return process_socket_core(
+      is_client_request, sock, keep_alive_max_count,
+      [&](bool last_connection, bool connection_close) {
+        SocketStream strm(sock, read_timeout_sec, read_timeout_usec,
+                          write_timeout_sec, write_timeout_usec);
+        return callback(strm, last_connection, connection_close);
+      });
 }
 
 inline int shutdown_socket(socket_t sock) {
@@ -4295,13 +4304,16 @@ Server::process_request(Stream &strm, bool last_connection,
 inline bool Server::is_valid() const { return true; }
 
 inline bool Server::process_and_close_socket(socket_t sock) {
-  return detail::process_and_close_socket(
+  auto ret = detail::process_socket(
       false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_,
       write_timeout_sec_, write_timeout_usec_,
       [this](Stream &strm, bool last_connection, bool &connection_close) {
         return process_request(strm, last_connection, connection_close,
                                nullptr);
       });
+
+  detail::close_socket(sock);
+  return ret;
 }
 
 // HTTP client implementation
@@ -4333,20 +4345,26 @@ inline socket_t Client::create_client_socket() const {
                                       connection_timeout_usec_, interface_);
 }
 
-inline bool Client::create_and_connect_socket(socket_t &sock) {
-  sock = create_client_socket();
+inline bool Client::create_and_connect_socket(Endpoint &endpoint) {
+  auto sock = create_client_socket();
   if (sock == INVALID_SOCKET) { return false; }
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
   if (is_ssl() && !proxy_host_.empty()) {
     Response res;
     bool error;
-    if (!connect(sock, res, error)) { return error; }
+    if (!connect_with_proxy(sock, res, error)) { return error; }
   }
 #endif
+  endpoint.sock = sock;
   return true;
 }
 
+inline void Client::close_socket(Endpoint &endpoint,
+                                 bool /*process_socket_ret*/) {
+  detail::close_socket(endpoint.sock);
+}
+
 inline bool Client::read_response_line(Stream &strm, Response &res) {
   std::array<char, 2048> buf;
 
@@ -4366,23 +4384,32 @@ inline bool Client::read_response_line(Stream &strm, Response &res) {
 }
 
 inline bool Client::send(const Request &req, Response &res) {
-  socket_t sock = INVALID_SOCKET;
-  if (!create_and_connect_socket(sock)) { return false; }
+  Endpoint endpoint;
+  if (!create_and_connect_socket(endpoint)) { return false; }
 
   {
-    std::lock_guard<std::mutex> guard(cli_socks_mutex_);
-    cli_socks_.insert(sock);
+    std::lock_guard<std::mutex> guard(endpoints_mutex_);
+    endpoints_.push_back(endpoint);
   }
 
-  auto ret = process_and_close_socket(
-      sock, 1, [&](Stream &strm, bool last_connection, bool &connection_close) {
+  auto ret = process_socket(
+      endpoint, 1,
+      [&](Stream &strm, bool last_connection, bool &connection_close) {
         return handle_request(strm, req, res, last_connection,
                               connection_close);
       });
 
   {
-    std::lock_guard<std::mutex> guard(cli_socks_mutex_);
-    cli_socks_.erase(sock);
+    std::lock_guard<std::mutex> guard(endpoints_mutex_);
+
+    auto it = std::find_if(
+        endpoints_.begin(), endpoints_.end(),
+        [&](Endpoint &endpoint2) { return endpoint.sock == endpoint2.sock; });
+
+    if (it != endpoints_.end()) {
+      close_socket(endpoint, ret);
+      endpoints_.erase(it);
+    }
   }
 
   return ret;
@@ -4392,29 +4419,41 @@ inline bool Client::send(const std::vector<Request> &requests,
                          std::vector<Response> &responses) {
   size_t i = 0;
   while (i < requests.size()) {
-    socket_t sock = INVALID_SOCKET;
-    if (!create_and_connect_socket(sock)) { return false; }
+    Endpoint endpoint;
+    if (!create_and_connect_socket(endpoint)) { return false; }
 
     {
-      std::lock_guard<std::mutex> guard(cli_socks_mutex_);
-      cli_socks_.insert(sock);
+      std::lock_guard<std::mutex> guard(endpoints_mutex_);
+      endpoints_.push_back(endpoint);
     }
 
-    auto ret = process_and_close_socket(
-        sock, requests.size() - i,
-        [&](Stream &strm, bool last_connection,
-            bool &connection_close) -> bool {
-          auto &req = requests[i++];
-          auto res = Response();
-          auto ret =
-              handle_request(strm, req, res, last_connection, connection_close);
-          if (ret) { responses.emplace_back(std::move(res)); }
-          return ret;
-        });
+    auto request_count = (std::min)(requests.size() - i, keep_alive_max_count_);
+
+    auto ret = process_socket(endpoint, request_count,
+                              [&](Stream &strm, bool last_connection,
+                                  bool &connection_close) -> bool {
+                                auto &req = requests[i++];
+                                auto res = Response();
+                                auto ret = handle_request(strm, req, res,
+                                                          last_connection,
+                                                          connection_close);
+                                if (ret) {
+                                  responses.emplace_back(std::move(res));
+                                }
+                                return ret;
+                              });
 
     {
-      std::lock_guard<std::mutex> guard(cli_socks_mutex_);
-      cli_socks_.erase(sock);
+      std::lock_guard<std::mutex> guard(endpoints_mutex_);
+
+      auto it = std::find_if(
+          endpoints_.begin(), endpoints_.end(),
+          [&](Endpoint &endpoint2) { return endpoint.sock == endpoint2.sock; });
+
+      if (it != endpoints_.end()) {
+        close_socket(endpoint, ret);
+        endpoints_.erase(it);
+      }
     }
 
     if (!ret) { return false; }
@@ -4477,14 +4516,16 @@ inline bool Client::handle_request(Stream &strm, const Request &req,
 }
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
-inline bool Client::connect(socket_t sock, Response &res, bool &error) {
+inline bool Client::connect_with_proxy(socket_t sock, Response &res,
+                                       bool &error) {
   error = true;
   Response res2;
 
-  if (!detail::process_socket(
-          true, sock, 1, read_timeout_sec_, read_timeout_usec_,
-          write_timeout_sec_, write_timeout_usec_,
-          [&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
+  if (!detail::process_socket_core(
+          true, sock, 1, [&](bool /*last_connection*/, bool &connection_close) {
+            detail::SocketStream strm(sock, read_timeout_sec_,
+                                      read_timeout_usec_, write_timeout_sec_,
+                                      write_timeout_usec_);
             Request req2;
             req2.method = "CONNECT";
             req2.path = host_and_port_;
@@ -4501,11 +4542,12 @@ inline bool Client::connect(socket_t sock, Response &res, bool &error) {
       std::map<std::string, std::string> auth;
       if (parse_www_authenticate(res2, auth, true)) {
         Response res3;
-        if (!detail::process_socket(
-                true, sock, 1, read_timeout_sec_, read_timeout_usec_,
-                write_timeout_sec_, write_timeout_usec_,
-                [&](Stream &strm, bool /*last_connection*/,
-                    bool &connection_close) {
+        if (!detail::process_socket_core(
+                true, sock, 1,
+                [&](bool /*last_connection*/, bool &connection_close) {
+                  detail::SocketStream strm(
+                      sock, read_timeout_sec_, read_timeout_usec_,
+                      write_timeout_sec_, write_timeout_usec_);
                   Request req3;
                   req3.method = "CONNECT";
                   req3.path = host_and_port_;
@@ -4781,14 +4823,13 @@ inline bool Client::process_request(Stream &strm, const Request &req,
   return true;
 }
 
-inline bool Client::process_and_close_socket(
-    socket_t sock, size_t request_count,
-    std::function<bool(Stream &strm, bool last_connection,
-                       bool &connection_close)>
-        callback) {
-  request_count = (std::min)(request_count, keep_alive_max_count_);
-  return detail::process_and_close_socket(
-      true, sock, request_count, read_timeout_sec_, read_timeout_usec_,
+inline bool
+Client::process_socket(Endpoint &endpoint, size_t request_count,
+                       std::function<bool(Stream &strm, bool last_connection,
+                                          bool &connection_close)>
+                           callback) {
+  return detail::process_socket(
+      true, endpoint.sock, request_count, read_timeout_sec_, read_timeout_usec_,
       write_timeout_sec_, write_timeout_usec_, callback);
 }
 
@@ -5085,12 +5126,12 @@ inline std::shared_ptr<Response> Client::Options(const char *path,
 }
 
 inline void Client::stop() {
-  std::lock_guard<std::mutex> guard(cli_socks_mutex_);
-  for (auto &sock : cli_socks_) {
-    detail::shutdown_socket(sock);
-    detail::close_socket(sock);
+  std::lock_guard<std::mutex> guard(endpoints_mutex_);
+  for (auto &endpoint : endpoints_) {
+    detail::shutdown_socket(endpoint.sock);
+    detail::close_socket(endpoint.sock);
   }
-  cli_socks_.clear();
+  endpoints_.clear();
 }
 
 inline void Client::set_timeout_sec(time_t timeout_sec) {
@@ -5164,77 +5205,55 @@ inline void Client::set_logger(Logger logger) { logger_ = std::move(logger); }
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 namespace detail {
 
-template <typename U, typename V, typename T>
-inline bool process_and_close_socket_ssl(
-    bool is_client_request, socket_t sock, size_t keep_alive_max_count,
-    time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec,
-    time_t write_timeout_usec, SSL_CTX *ctx, std::mutex &ctx_mutex,
-    U SSL_connect_or_accept, V setup, T callback) {
-  assert(keep_alive_max_count > 0);
-
+template <typename U, typename V>
+inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex,
+                    U SSL_connect_or_accept, V setup) {
   SSL *ssl = nullptr;
   {
     std::lock_guard<std::mutex> guard(ctx_mutex);
     ssl = SSL_new(ctx);
   }
 
-  if (!ssl) {
-    close_socket(sock);
-    return false;
-  }
-
-  auto bio = BIO_new_socket(static_cast<int>(sock), BIO_NOCLOSE);
-  SSL_set_bio(ssl, bio, bio);
+  if (ssl) {
+    auto bio = BIO_new_socket(static_cast<int>(sock), BIO_NOCLOSE);
+    SSL_set_bio(ssl, bio, bio);
 
-  if (!setup(ssl)) {
-    SSL_shutdown(ssl);
-    {
-      std::lock_guard<std::mutex> guard(ctx_mutex);
-      SSL_free(ssl);
-    }
-
-    close_socket(sock);
-    return false;
-  }
-
-  auto ret = false;
-
-  if (SSL_connect_or_accept(ssl) == 1) {
-    if (keep_alive_max_count > 1) {
-      auto count = keep_alive_max_count;
-      while (count > 0 &&
-             (is_client_request ||
-              select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
-                          CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
-        SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec,
-                             write_timeout_sec, write_timeout_usec);
-        auto last_connection = count == 1;
-        auto connection_close = false;
-
-        ret = callback(ssl, strm, last_connection, connection_close);
-        if (!ret || connection_close) { break; }
-
-        count--;
+    if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) {
+      SSL_shutdown(ssl);
+      {
+        std::lock_guard<std::mutex> guard(ctx_mutex);
+        SSL_free(ssl);
       }
-    } else {
-      SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec,
-                           write_timeout_sec, write_timeout_usec);
-      auto dummy_connection_close = false;
-      ret = callback(ssl, strm, true, dummy_connection_close);
+      return nullptr;
     }
   }
 
-  if (ret) {
+  return ssl;
+}
+
+inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl,
+                       bool process_socket_ret) {
+  if (process_socket_ret) {
     SSL_shutdown(ssl); // shutdown only if not already closed by remote
   }
-  {
-    std::lock_guard<std::mutex> guard(ctx_mutex);
-    SSL_free(ssl);
-  }
 
-  close_socket(sock);
+  std::lock_guard<std::mutex> guard(ctx_mutex);
+  SSL_free(ssl);
+}
 
-  return ret;
+template <typename T>
+inline bool
+process_socket_ssl(SSL *ssl, bool is_client_request, socket_t sock,
+                   size_t keep_alive_max_count, time_t read_timeout_sec,
+                   time_t read_timeout_usec, time_t write_timeout_sec,
+                   time_t write_timeout_usec, T callback) {
+  return process_socket_core(
+      is_client_request, sock, keep_alive_max_count,
+      [&](bool last_connection, bool connection_close) {
+        SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec,
+                             write_timeout_sec, write_timeout_usec);
+        return callback(strm, last_connection, connection_close);
+      });
 }
 
 #if OPENSSL_VERSION_NUMBER < 0x10100000L
@@ -5311,8 +5330,7 @@ inline bool SSLSocketStream::is_writable() const {
 }
 
 inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
-  if (SSL_pending(ssl_) > 0 ||
-      select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0) {
+  if (SSL_pending(ssl_) > 0 || is_readable()) {
     return SSL_read(ssl_, ptr, static_cast<int>(size));
   }
   return -1;
@@ -5405,15 +5423,25 @@ inline SSLServer::~SSLServer() {
 inline bool SSLServer::is_valid() const { return ctx_; }
 
 inline bool SSLServer::process_and_close_socket(socket_t sock) {
-  return detail::process_and_close_socket_ssl(
-      false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_,
-      write_timeout_sec_, write_timeout_usec_, ctx_, ctx_mutex_, SSL_accept,
-      [](SSL * /*ssl*/) { return true; },
-      [this](SSL *ssl, Stream &strm, bool last_connection,
-             bool &connection_close) {
-        return process_request(strm, last_connection, connection_close,
-                               [&](Request &req) { req.ssl = ssl; });
-      });
+  auto ssl = detail::ssl_new(sock, ctx_, ctx_mutex_, SSL_accept,
+                             [](SSL * /*ssl*/) { return true; });
+
+  if (ssl) {
+    auto ret = detail::process_socket_ssl(
+        ssl, false, sock, keep_alive_max_count_, read_timeout_sec_,
+        read_timeout_usec_, write_timeout_sec_, write_timeout_usec_,
+        [this, ssl](Stream &strm, bool last_connection,
+                    bool &connection_close) {
+          return process_request(strm, last_connection, connection_close,
+                                 [&](Request &req) { req.ssl = ssl; });
+        });
+
+    detail::ssl_delete(ctx_mutex_, ssl, ret);
+    return ret;
+  }
+
+  detail::close_socket(sock);
+  return false;
 }
 
 // SSL HTTP client implementation
@@ -5466,6 +5494,25 @@ inline SSLClient::~SSLClient() {
   if (ctx_) { SSL_CTX_free(ctx_); }
 }
 
+inline void SSLClient::stop() {
+  auto endpoints = endpoints_;
+  {
+    std::lock_guard<std::mutex> guard(endpoints_mutex_);
+    for (auto &endpoint : endpoints_) {
+      detail::shutdown_socket(endpoint.sock);
+      detail::close_socket(endpoint.sock);
+    }
+    endpoints_.clear();
+  }
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+  for (auto &endpoint : endpoints) {
+    SSL_shutdown(endpoint.ssl);
+    SSL_free(endpoint.ssl);
+  }
+}
+
 inline bool SSLClient::is_valid() const { return ctx_; }
 
 inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path,
@@ -5488,62 +5535,83 @@ inline long SSLClient::get_openssl_verify_result() const {
 
 inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; }
 
-inline bool SSLClient::process_and_close_socket(
-    socket_t sock, size_t request_count,
-    std::function<bool(Stream &strm, bool last_connection,
-                       bool &connection_close)>
-        callback) {
-
-  request_count = std::min(request_count, keep_alive_max_count_);
-
-  return is_valid() &&
-         detail::process_and_close_socket_ssl(
-             true, sock, request_count, read_timeout_sec_, read_timeout_usec_,
-             write_timeout_sec_, write_timeout_usec_, ctx_, ctx_mutex_,
-             [&](SSL *ssl) {
-               if (ca_cert_file_path_.empty() && ca_cert_store_ == nullptr) {
-                 SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr);
-               } else if (!ca_cert_file_path_.empty()) {
-                 if (!SSL_CTX_load_verify_locations(
-                         ctx_, ca_cert_file_path_.c_str(), nullptr)) {
-                   return false;
-                 }
-                 SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr);
-               } else if (ca_cert_store_ != nullptr) {
-                 if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store_) {
-                   SSL_CTX_set_cert_store(ctx_, ca_cert_store_);
-                 }
-                 SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr);
-               }
-
-               if (SSL_connect(ssl) != 1) { return false; }
-
-               if (server_certificate_verification_) {
-                 verify_result_ = SSL_get_verify_result(ssl);
-
-                 if (verify_result_ != X509_V_OK) { return false; }
-
-                 auto server_cert = SSL_get_peer_certificate(ssl);
-
-                 if (server_cert == nullptr) { return false; }
-
-                 if (!verify_host(server_cert)) {
-                   X509_free(server_cert);
-                   return false;
-                 }
-                 X509_free(server_cert);
-               }
-
-               return true;
-             },
-             [&](SSL *ssl) {
-               SSL_set_tlsext_host_name(ssl, host_.c_str());
-               return true;
-             },
-             [&](SSL * /*ssl*/, Stream &strm, bool last_connection,
-                 bool &connection_close) {
-               return callback(strm, last_connection, connection_close);
-             });
+inline bool SSLClient::create_and_connect_socket(Endpoint &endpoint) {
+  return is_valid() && Client::create_and_connect_socket(endpoint) &&
+         initialize_ssl(endpoint);
+}
+
+inline bool SSLClient::initialize_ssl(Endpoint &endpoint) {
+  auto ssl = detail::ssl_new(
+      endpoint.sock, ctx_, ctx_mutex_,
+      [&](SSL *ssl) {
+        if (ca_cert_file_path_.empty() && ca_cert_store_ == nullptr) {
+          SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr);
+        } else if (!ca_cert_file_path_.empty()) {
+          if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(),
+                                             nullptr)) {
+            return false;
+          }
+          SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr);
+        } else if (ca_cert_store_ != nullptr) {
+          if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store_) {
+            SSL_CTX_set_cert_store(ctx_, ca_cert_store_);
+          }
+          SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr);
+        }
+
+        if (SSL_connect(ssl) != 1) { return false; }
+
+        if (server_certificate_verification_) {
+          verify_result_ = SSL_get_verify_result(ssl);
+
+          if (verify_result_ != X509_V_OK) { return false; }
+
+          auto server_cert = SSL_get_peer_certificate(ssl);
+
+          if (server_cert == nullptr) { return false; }
+
+          if (!verify_host(server_cert)) {
+            X509_free(server_cert);
+            return false;
+          }
+          X509_free(server_cert);
+        }
+
+        return true;
+      },
+      [&](SSL *ssl) {
+        SSL_set_tlsext_host_name(ssl, host_.c_str());
+        return true;
+      });
+
+  if (ssl) {
+    endpoint.ssl = ssl;
+    return true;
+  }
+
+  detail::close_socket(endpoint.sock);
+  return false;
+}
+
+inline void SSLClient::close_socket(Endpoint &endpoint,
+                                    bool process_socket_ret) {
+  assert(endpoint.ssl);
+  detail::ssl_delete(ctx_mutex_, endpoint.ssl, process_socket_ret);
+  detail::close_socket(endpoint.sock);
+}
+
+inline bool
+SSLClient::process_socket(Endpoint &endpoint, size_t request_count,
+                          std::function<bool(Stream &strm, bool last_connection,
+                                             bool &connection_close)>
+                              callback) {
+  assert(endpoint.ssl);
+  return detail::process_socket_ssl(
+      endpoint.ssl, true, endpoint.sock, request_count, read_timeout_sec_,
+      read_timeout_usec_, write_timeout_sec_, write_timeout_usec_,
+      [&](Stream &strm, bool last_connection, bool &connection_close) {
+        return callback(strm, last_connection, connection_close);
+      });
 }
 
 inline bool SSLClient::is_ssl() const { return true; }

+ 11 - 7
test/test.cc

@@ -1767,16 +1767,16 @@ TEST_F(ServerTest, GetStreamedEndless) {
 
 TEST_F(ServerTest, ClientStop) {
   std::vector<std::thread> threads;
-  for (auto i = 0; i < 10; i++) {
+  for (auto i = 0; i < 8; i++) {
     threads.emplace_back(thread([&]() {
       auto res = cli_.Get("/streamed-cancel",
                           [&](const char *, uint64_t) { return true; });
       ASSERT_TRUE(res == nullptr);
     }));
   }
-  std::this_thread::sleep_for(std::chrono::seconds(1));
+  std::this_thread::sleep_for(std::chrono::seconds(3));
   cli_.stop();
-  for (auto& t: threads) {
+  for (auto &t : threads) {
     t.join();
   }
 }
@@ -2299,13 +2299,13 @@ TEST_F(ServerTest, MultipartFormDataGzip) {
 // Sends a raw request to a server listening at HOST:PORT.
 static bool send_request(time_t read_timeout_sec, const std::string &req,
                          std::string *resp = nullptr) {
-  auto client_sock = detail::create_client_socket(
-      HOST, PORT, nullptr,
-      /*timeout_sec=*/5, 0, std::string());
+  auto client_sock =
+      detail::create_client_socket(HOST, PORT, nullptr,
+                                   /*timeout_sec=*/5, 0, std::string());
 
   if (client_sock == INVALID_SOCKET) { return false; }
 
-  return detail::process_and_close_socket(
+  auto ret = detail::process_socket(
       true, client_sock, 1, read_timeout_sec, 0, 0, 0,
       [&](Stream &strm, bool /*last_connection*/, bool &
           /*connection_close*/) -> bool {
@@ -2322,6 +2322,10 @@ static bool send_request(time_t read_timeout_sec, const std::string &req,
         }
         return true;
       });
+
+  detail::close_socket(client_sock);
+
+  return ret;
 }
 
 TEST(ServerRequestParsingTest, TrimWhitespaceFromHeaderValues) {