yhirose 5 years ago
parent
commit
7cd25fbd63
4 changed files with 301 additions and 333 deletions
  1. 6 20
      README.md
  2. 217 225
      httplib.h
  3. 56 44
      test/test.cc
  4. 22 44
      test/test_proxy.cc

+ 6 - 20
README.md

@@ -482,29 +482,15 @@ httplib::make_range_header({{0, 0}, {-1, 1}})        // 'Range: bytes=0-0, -1'
 ### Keep-Alive connection
 
 ```cpp
-cli.set_keep_alive_max_count(2); // Default is 5
+httplib::Client cli("localhost", 1234);
 
-std::vector<Request> requests;
-Get(requests, "/get-request1");
-Get(requests, "/get-request2");
-Post(requests, "/post-request1", "text", "text/plain");
-Post(requests, "/post-request2", "text", "text/plain");
+cli.Get("/hello");         // with "Connection: close"
 
-const size_t DATA_CHUNK_SIZE = 4;
-std::string data("abcdefg");
-Post(requests, "/post-request-with-content-provider",
-  data.size(),
-  [&](size_t offset, size_t length, DataSink &sink){
-    sink.write(&data[offset], std::min(length, DATA_CHUNK_SIZE));
-  },
-  "text/plain");
+cli.set_keep_alive(true);
+cli.Get("/world");
 
-std::vector<Response> responses;
-if (cli.send(requests, responses)) {
-  for (const auto& res: responses) {
-    ...
-  }
-}
+cli.set_keep_alive(false);
+cli.Get("/last-request");  // with "Connection: close"
 ```
 
 ### Redirect

+ 217 - 225
httplib.h

@@ -188,6 +188,7 @@ using socket_t = int;
 #include <fcntl.h>
 #include <fstream>
 #include <functional>
+#include <iostream>
 #include <list>
 #include <map>
 #include <memory>
@@ -593,10 +594,11 @@ public:
   std::function<TaskQueue *(void)> new_task_queue;
 
 protected:
-  bool process_request(Stream &strm, bool last_connection,
-                       bool &connection_close,
+  bool process_request(Stream &strm, bool close_connection,
+                       bool &connection_closed,
                        const std::function<void(Request &)> &setup_request);
 
+  std::atomic<socket_t> svr_sock_;
   size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT;
   time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND;
   time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND;
@@ -624,7 +626,7 @@ private:
                                            HandlersForContentReader &handlers);
 
   bool parse_request_line(const char *s, Request &req);
-  bool write_response(Stream &strm, bool last_connection, const Request &req,
+  bool write_response(Stream &strm, bool close_connection, const Request &req,
                       Response &res);
   bool write_content_with_provider(Stream &strm, const Request &req,
                                    Response &res, const std::string &boundary,
@@ -643,7 +645,6 @@ private:
   virtual bool process_and_close_socket(socket_t sock);
 
   std::atomic<bool> is_running_;
-  std::atomic<socket_t> svr_sock_;
   std::vector<std::pair<std::string, std::string>> base_dirs_;
   std::map<std::string, std::string> file_extension_and_mimetype_map_;
   Handler file_request_handler_;
@@ -797,9 +798,6 @@ public:
 
   bool send(const Request &req, Response &res);
 
-  bool send(const std::vector<Request> &requests,
-            std::vector<Response> &responses);
-
   size_t is_socket_open() const;
 
   void stop();
@@ -809,13 +807,12 @@ public:
   void set_read_timeout(time_t sec, time_t usec = 0);
   void set_write_timeout(time_t sec, time_t usec = 0);
 
-  void set_keep_alive_max_count(size_t count);
-
   void set_basic_auth(const char *username, const char *password);
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
   void set_digest_auth(const char *username, const char *password);
 #endif
 
+  void set_keep_alive(bool on);
   void set_follow_location(bool on);
 
   void set_compress(bool on);
@@ -846,7 +843,7 @@ protected:
   virtual void close_socket(Socket &socket, bool process_socket_ret);
 
   bool process_request(Stream &strm, const Request &req, Response &res,
-                       bool &connection_close);
+                       bool close_connection);
 
   // Socket endoint information
   const std::string host_;
@@ -869,8 +866,6 @@ protected:
   time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND;
   time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND;
 
-  size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT;
-
   std::string basic_auth_username_;
   std::string basic_auth_password_;
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
@@ -878,6 +873,7 @@ protected:
   std::string digest_auth_password_;
 #endif
 
+  bool keep_alive_ = false;
   bool follow_location_ = false;
 
   bool compress_ = false;
@@ -905,13 +901,13 @@ protected:
     read_timeout_usec_ = rhs.read_timeout_usec_;
     write_timeout_sec_ = rhs.write_timeout_sec_;
     write_timeout_usec_ = rhs.write_timeout_usec_;
-    keep_alive_max_count_ = rhs.keep_alive_max_count_;
     basic_auth_username_ = rhs.basic_auth_username_;
     basic_auth_password_ = rhs.basic_auth_password_;
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
     digest_auth_username_ = rhs.digest_auth_username_;
     digest_auth_password_ = rhs.digest_auth_password_;
 #endif
+    keep_alive_ = rhs.keep_alive_;
     follow_location_ = rhs.follow_location_;
     compress_ = rhs.compress_;
     decompress_ = rhs.decompress_;
@@ -930,22 +926,18 @@ protected:
 private:
   socket_t create_client_socket() const;
   bool read_response_line(Stream &strm, Response &res);
-  bool write_request(Stream &strm, const Request &req);
+  bool write_request(Stream &strm, const Request &req, bool close_connection);
   bool redirect(const Request &req, Response &res);
   bool handle_request(Stream &strm, const Request &req, Response &res,
-                      bool &connection_close);
+                      bool close_connection);
 
   std::shared_ptr<Response> send_with_content_provider(
       const char *method, const char *path, const Headers &headers,
       const std::string &body, size_t content_length,
       ContentProvider content_provider, const char *content_type);
 
-  virtual bool
-  process_socket(Socket &socket, size_t request_count,
-                 std::function<bool(Stream &strm, bool last_connection,
-                                    bool &connection_close)>
-                     callback);
-
+  virtual bool process_socket(Socket &socket,
+                              std::function<bool(Stream &strm)> callback);
   virtual bool is_ssl() const;
 };
 
@@ -1045,15 +1037,13 @@ public:
 
 private:
   bool create_and_connect_socket(Socket &socket) override;
-  bool connect_with_proxy(Socket &sock, bool &error);
   void close_socket(Socket &socket, bool process_socket_ret) override;
 
-  bool process_socket(Socket &socket, size_t request_count,
-                      std::function<bool(Stream &strm, bool last_connection,
-                                         bool &connection_close)>
-                          callback) override;
+  bool process_socket(Socket &socket,
+                      std::function<bool(Stream &strm)> callback) override;
   bool is_ssl() const override;
 
+  bool connect_with_proxy(Socket &sock, Response &res, bool &success);
   bool initialize_ssl(Socket &socket);
 
   bool verify_host(X509 *server_cert) const;
@@ -1070,6 +1060,8 @@ private:
   X509_STORE *ca_cert_store_ = nullptr;
   bool server_certificate_verification_ = false;
   long verify_result_ = 0;
+
+  friend class Client;
 };
 #endif
 
@@ -1301,11 +1293,6 @@ public:
 
   bool send(const Request &req, Response &res) { return cli_->send(req, res); }
 
-  bool send(const std::vector<Request> &requests,
-            std::vector<Response> &responses) {
-    return cli_->send(requests, responses);
-  }
-
   bool is_socket_open() { return cli_->is_socket_open(); }
 
   void stop() { cli_->stop(); }
@@ -1320,11 +1307,6 @@ public:
     return *this;
   }
 
-  Client2 &set_keep_alive_max_count(size_t count) {
-    cli_->set_keep_alive_max_count(count);
-    return *this;
-  }
-
   Client2 &set_basic_auth(const char *username, const char *password) {
     cli_->set_basic_auth(username, password);
     return *this;
@@ -1337,6 +1319,11 @@ public:
   }
 #endif
 
+  Client2 &set_keep_alive(bool on) {
+    cli_->set_keep_alive(on);
+    return *this;
+  }
+
   Client2 &set_follow_location(bool on) {
     cli_->set_follow_location(on);
     return *this;
@@ -1863,49 +1850,75 @@ private:
   size_t position = 0;
 };
 
-template <typename T>
-inline bool process_socket_core(bool is_client_request, socket_t sock,
-                                size_t keep_alive_max_count, T callback) {
-  assert(keep_alive_max_count > 0);
-
-  auto ret = false;
-
-  if (keep_alive_max_count > 1) {
-    auto count = keep_alive_max_count;
-    while (count > 0 &&
-           (is_client_request ||
-            select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
-                        CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
-      auto last_connection = count == 1;
-      auto connection_close = false;
-
-      ret = callback(last_connection, connection_close);
-      if (!ret || connection_close) { break; }
-
-      count--;
+inline bool keep_alive(socket_t sock, std::function<bool()> is_shutting_down) {
+  using namespace std::chrono;
+  auto start = steady_clock::now();
+  while (true) {
+    auto val = select_read(sock, 0, 10000);
+    if (is_shutting_down && is_shutting_down()) {
+      return false;
+    } else if (val < 0) {
+      return false;
+    } else if (val == 0) {
+      auto current = steady_clock::now();
+      auto sec = duration_cast<seconds>(current - start);
+      if (sec.count() > CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND) {
+        return false;
+      } else if (sec.count() == CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND) {
+        auto usec = duration_cast<nanoseconds>(current - start);
+        if (usec.count() > CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) {
+          return false;
+        }
+      }
+      std::this_thread::sleep_for(std::chrono::milliseconds(1));
+    } else {
+      return true;
     }
-  } else { // keep_alive_max_count  is 0 or 1
-    auto dummy_connection_close = false;
-    ret = callback(true, dummy_connection_close);
   }
+}
 
+template <typename T, typename U>
+inline bool process_server_socket_core(socket_t sock,
+                                       size_t keep_alive_max_count,
+                                       T is_shutting_down, U callback) {
+  assert(keep_alive_max_count > 0);
+  auto ret = false;
+  auto count = keep_alive_max_count;
+  while (count > 0 && keep_alive(sock, is_shutting_down)) {
+    auto close_connection = count == 1;
+    auto connection_closed = false;
+    ret = callback(close_connection, connection_closed);
+    if (!ret || connection_closed) { break; }
+    count--;
+  }
   return ret;
 }
 
-template <typename T>
-inline bool process_socket(bool is_client_request, socket_t sock,
-                           size_t keep_alive_max_count, time_t read_timeout_sec,
-                           time_t read_timeout_usec, time_t write_timeout_sec,
-                           time_t write_timeout_usec, T callback) {
-  return process_socket_core(
-      is_client_request, sock, keep_alive_max_count,
-      [&](bool last_connection, bool connection_close) {
+template <typename T, typename U>
+inline bool
+process_server_socket(socket_t sock, size_t keep_alive_max_count,
+                      time_t read_timeout_sec, time_t read_timeout_usec,
+                      time_t write_timeout_sec, time_t write_timeout_usec,
+                      T is_shutting_down, U callback) {
+  return process_server_socket_core(
+      sock, keep_alive_max_count, is_shutting_down,
+      [&](bool close_connection, bool connection_closed) {
         SocketStream strm(sock, read_timeout_sec, read_timeout_usec,
                           write_timeout_sec, write_timeout_usec);
-        return callback(strm, last_connection, connection_close);
+        return callback(strm, close_connection, connection_closed);
       });
 }
 
+template <typename T>
+inline bool process_client_socket(socket_t sock, time_t read_timeout_sec,
+                                  time_t read_timeout_usec,
+                                  time_t write_timeout_sec,
+                                  time_t write_timeout_usec, T callback) {
+  SocketStream strm(sock, read_timeout_sec, read_timeout_usec,
+                    write_timeout_sec, write_timeout_usec);
+  return callback(strm);
+}
+
 inline int shutdown_socket(socket_t sock) {
 #ifdef _WIN32
   return shutdown(sock, SD_BOTH);
@@ -2545,7 +2558,6 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status,
   }
 
   if (!ret) { status = exceed_payload_max_length ? 413 : 400; }
-
   return ret;
 }
 
@@ -2582,8 +2594,9 @@ inline bool write_data(Stream &strm, const char *d, size_t l) {
   return true;
 }
 
+template <typename T>
 inline ssize_t write_content(Stream &strm, ContentProvider content_provider,
-                             size_t offset, size_t length) {
+                             size_t offset, size_t length, T is_shutting_down) {
   size_t begin_offset = offset;
   size_t end_offset = offset + length;
 
@@ -2598,7 +2611,7 @@ inline ssize_t write_content(Stream &strm, ContentProvider content_provider,
   };
   data_sink.is_writable = [&](void) { return ok && strm.is_writable(); };
 
-  while (ok && offset < end_offset) {
+  while (ok && offset < end_offset && !is_shutting_down()) {
     if (!content_provider(offset, end_offset - offset, data_sink)) {
       return -1;
     }
@@ -3110,16 +3123,19 @@ get_multipart_ranges_data_length(const Request &req, Response &res,
   return data_length;
 }
 
+template <typename T>
 inline bool write_multipart_ranges_data(Stream &strm, const Request &req,
                                         Response &res,
                                         const std::string &boundary,
-                                        const std::string &content_type) {
+                                        const std::string &content_type,
+                                        T is_shutting_down) {
   return process_multipart_ranges_data(
       req, res, boundary, content_type,
       [&](const std::string &token) { strm.write(token); },
       [&](const char *token) { strm.write(token); },
       [&](size_t offset, size_t length) {
-        return write_content(strm, res.content_provider_, offset, length) >= 0;
+        return write_content(strm, res.content_provider_, offset, length,
+                             is_shutting_down) >= 0;
       });
 }
 
@@ -3576,7 +3592,7 @@ inline const std::string &BufferStream::get_buffer() const { return buffer; }
 } // namespace detail
 
 // HTTP server implementation
-inline Server::Server() : is_running_(false), svr_sock_(INVALID_SOCKET) {
+inline Server::Server() : svr_sock_(INVALID_SOCKET), is_running_(false) {
 #ifndef _WIN32
   signal(SIGPIPE, SIG_IGN);
 #endif
@@ -3758,7 +3774,7 @@ inline bool Server::parse_request_line(const char *s, Request &req) {
   return false;
 }
 
-inline bool Server::write_response(Stream &strm, bool last_connection,
+inline bool Server::write_response(Stream &strm, bool close_connection,
                                    const Request &req, Response &res) {
   assert(res.status != -1);
 
@@ -3773,11 +3789,11 @@ inline bool Server::write_response(Stream &strm, bool last_connection,
   }
 
   // Headers
-  if (last_connection || req.get_header_value("Connection") == "close") {
+  if (close_connection || req.get_header_value("Connection") == "close") {
     res.set_header("Connection", "close");
   }
 
-  if (!last_connection && req.get_header_value("Connection") == "Keep-Alive") {
+  if (!close_connection && req.get_header_value("Connection") == "Keep-Alive") {
     res.set_header("Connection", "Keep-Alive");
   }
 
@@ -3891,10 +3907,14 @@ inline bool
 Server::write_content_with_provider(Stream &strm, const Request &req,
                                     Response &res, const std::string &boundary,
                                     const std::string &content_type) {
+  auto is_shutting_down = [this]() {
+    return this->svr_sock_ == INVALID_SOCKET;
+  };
+
   if (res.content_length_) {
     if (req.ranges.empty()) {
       if (detail::write_content(strm, res.content_provider_, 0,
-                                res.content_length_) < 0) {
+                                res.content_length_, is_shutting_down) < 0) {
         return false;
       }
     } else if (req.ranges.size() == 1) {
@@ -3902,20 +3922,17 @@ Server::write_content_with_provider(Stream &strm, const Request &req,
           detail::get_range_offset_and_length(req, res.content_length_, 0);
       auto offset = offsets.first;
       auto length = offsets.second;
-      if (detail::write_content(strm, res.content_provider_, offset, length) <
-          0) {
+      if (detail::write_content(strm, res.content_provider_, offset, length,
+                                is_shutting_down) < 0) {
         return false;
       }
     } else {
-      if (!detail::write_multipart_ranges_data(strm, req, res, boundary,
-                                               content_type)) {
+      if (!detail::write_multipart_ranges_data(
+              strm, req, res, boundary, content_type, is_shutting_down)) {
         return false;
       }
     }
   } else {
-    auto is_shutting_down = [this]() {
-      return this->svr_sock_ == INVALID_SOCKET;
-    };
     if (detail::write_content_chunked(strm, res.content_provider_,
                                       is_shutting_down) < 0) {
       return false;
@@ -4241,8 +4258,8 @@ inline bool Server::dispatch_request_for_content_reader(
 }
 
 inline bool
-Server::process_request(Stream &strm, bool last_connection,
-                        bool &connection_close,
+Server::process_request(Stream &strm, bool close_connection,
+                        bool &connection_closed,
                         const std::function<void(Request &)> &setup_request) {
   std::array<char, 2048> buf{};
 
@@ -4261,23 +4278,23 @@ Server::process_request(Stream &strm, bool last_connection,
     Headers dummy;
     detail::read_headers(strm, dummy);
     res.status = 414;
-    return write_response(strm, last_connection, req, res);
+    return write_response(strm, close_connection, req, res);
   }
 
   // Request line and headers
   if (!parse_request_line(line_reader.ptr(), req) ||
       !detail::read_headers(strm, req.headers)) {
     res.status = 400;
-    return write_response(strm, last_connection, req, res);
+    return write_response(strm, close_connection, req, res);
   }
 
   if (req.get_header_value("Connection") == "close") {
-    connection_close = true;
+    connection_closed = true;
   }
 
   if (req.version == "HTTP/1.0" &&
       req.get_header_value("Connection") != "Keep-Alive") {
-    connection_close = true;
+    connection_closed = true;
   }
 
   strm.get_remote_ip_and_port(req.remote_addr, req.remote_port);
@@ -4304,7 +4321,7 @@ Server::process_request(Stream &strm, bool last_connection,
       strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status,
                         detail::status_message(status));
       break;
-    default: return write_response(strm, last_connection, req, res);
+    default: return write_response(strm, close_connection, req, res);
     }
   }
 
@@ -4315,20 +4332,23 @@ Server::process_request(Stream &strm, bool last_connection,
     if (res.status == -1) { res.status = 404; }
   }
 
-  return write_response(strm, last_connection, req, res);
+  return write_response(strm, close_connection, req, res);
 }
 
 inline bool Server::is_valid() const { return true; }
 
 inline bool Server::process_and_close_socket(socket_t sock) {
-  auto ret = detail::process_socket(
-      false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_,
+  auto ret = detail::process_server_socket(
+      sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_,
       write_timeout_sec_, write_timeout_usec_,
-      [this](Stream &strm, bool last_connection, bool &connection_close) {
-        return process_request(strm, last_connection, connection_close,
+      [this]() { return this->svr_sock_ == INVALID_SOCKET; },
+      [this](Stream &strm, bool close_connection, bool &connection_closed) {
+        return process_request(strm, close_connection, connection_closed,
                                nullptr);
       });
 
+  std::this_thread::sleep_for(std::chrono::milliseconds(1));
+  detail::shutdown_socket(sock);
   detail::close_socket(sock);
   return ret;
 }
@@ -4347,12 +4367,7 @@ inline Client::Client(const std::string &host, int port,
       host_and_port_(host_ + ":" + std::to_string(port_)),
       client_cert_path_(client_cert_path), client_key_path_(client_key_path) {}
 
-inline Client::~Client() {
-  assert(socket_.sock == INVALID_SOCKET);
-#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
-  assert(socket_.ssl == nullptr);
-#endif
-}
+inline Client::~Client() { stop(); }
 
 inline bool Client::is_valid() const { return true; }
 
@@ -4402,63 +4417,49 @@ inline bool Client::read_response_line(Stream &strm, Response &res) {
 
 inline bool Client::send(const Request &req, Response &res) {
   std::lock_guard<std::recursive_mutex> request_mutex_guard(request_mutex_);
-  auto need_new_socket = !is_socket_open();
 
-  if (need_new_socket) {
+  {
     std::lock_guard<std::mutex> guard(socket_mutex_);
-    if (!create_and_connect_socket(socket_)) { return false; }
-  }
-
-  auto ret = process_socket(
-      socket_, 1,
-      [&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
-        return handle_request(strm, req, res, connection_close);
-      });
 
-  if (need_new_socket) {
-    std::lock_guard<std::mutex> guard(socket_mutex_);
-    if (socket_.is_open()) { close_socket(socket_, ret); }
-  }
+    auto is_alive = false;
+    if (socket_.is_open()) {
+      is_alive = detail::select_write(socket_.sock, 0, 0) > 0;
+      if (!is_alive) { close_socket(socket_, false); }
+    }
 
-  return ret;
-}
+    if (!is_alive) {
+      if (!create_and_connect_socket(socket_)) { return false; }
 
-inline bool Client::send(const std::vector<Request> &requests,
-                         std::vector<Response> &responses) {
-  std::lock_guard<std::recursive_mutex> request_mutex_guard(request_mutex_);
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+      // TODO: refactoring
+      if (is_ssl()) {
+        auto &scli = static_cast<SSLClient &>(*this);
+        if (!proxy_host_.empty()) {
+          bool success = false;
+          if (!scli.connect_with_proxy(socket_, res, success)) {
+            return success;
+          }
+        }
 
-  size_t i = 0;
-  while (i < requests.size()) {
-    {
-      std::lock_guard<std::mutex> guard(socket_mutex_);
-      if (!create_and_connect_socket(socket_)) { return false; }
+        if (!scli.initialize_ssl(socket_)) { return false; }
+      }
+#endif
     }
+  }
 
-    auto request_count = (std::min)(requests.size() - i, keep_alive_max_count_);
-
-    auto ret = process_socket(
-        socket_, request_count,
-        [&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
-          auto &req = requests[i++];
-          auto res = Response();
-          auto ret = handle_request(strm, req, res, connection_close);
-          if (ret) { responses.emplace_back(std::move(res)); }
-          return ret;
-        });
+  auto close_connection = !keep_alive_;
 
-    {
-      std::lock_guard<std::mutex> guard(socket_mutex_);
-      if (socket_.is_open()) { close_socket(socket_, ret); }
-    }
+  auto ret = process_socket(socket_, [&](Stream &strm) {
+    return handle_request(strm, req, res, close_connection);
+  });
 
-    if (!ret) { return false; }
-  }
+  if (close_connection) { stop(); }
 
-  return true;
+  return ret;
 }
 
 inline bool Client::handle_request(Stream &strm, const Request &req,
-                                   Response &res, bool &connection_close) {
+                                   Response &res, bool close_connection) {
   if (req.path.empty()) { return false; }
 
   bool ret;
@@ -4466,9 +4467,9 @@ inline bool Client::handle_request(Stream &strm, const Request &req,
   if (!is_ssl() && !proxy_host_.empty()) {
     auto req2 = req;
     req2.path = "http://" + host_and_port_ + req.path;
-    ret = process_request(strm, req2, res, connection_close);
+    ret = process_request(strm, req2, res, close_connection);
   } else {
-    ret = process_request(strm, req, res, connection_close);
+    ret = process_request(strm, req, res, close_connection);
   }
 
   if (!ret) { return false; }
@@ -4558,7 +4559,8 @@ inline bool Client::redirect(const Request &req, Response &res) {
   }
 }
 
-inline bool Client::write_request(Stream &strm, const Request &req) {
+inline bool Client::write_request(Stream &strm, const Request &req,
+                                  bool close_connection) {
   detail::BufferStream bstrm;
 
   // Request line
@@ -4568,6 +4570,8 @@ inline bool Client::write_request(Stream &strm, const Request &req) {
 
   // Additonal headers
   Headers headers;
+  if (close_connection) { headers.emplace("Connection", "close"); }
+
   if (!req.has_header("Host")) {
     if (is_ssl()) {
       if (port_ == 443) {
@@ -4710,9 +4714,9 @@ inline std::shared_ptr<Response> Client::send_with_content_provider(
 }
 
 inline bool Client::process_request(Stream &strm, const Request &req,
-                                    Response &res, bool &connection_close) {
+                                    Response &res, bool close_connection) {
   // Send request
-  if (!write_request(strm, req)) { return false; }
+  if (!write_request(strm, req, close_connection)) { return false; }
 
   // Receive response and headers
   if (!read_response_line(strm, res) ||
@@ -4720,11 +4724,6 @@ inline bool Client::process_request(Stream &strm, const Request &req,
     return false;
   }
 
-  if (res.get_header_value("Connection") == "close" ||
-      res.version == "HTTP/1.0") {
-    connection_close = true;
-  }
-
   if (req.response_handler) {
     if (!req.response_handler(res)) { return false; }
   }
@@ -4749,20 +4748,22 @@ inline bool Client::process_request(Stream &strm, const Request &req,
     }
   }
 
+  if (res.get_header_value("Connection") == "close" ||
+      res.version == "HTTP/1.0") {
+    stop();
+  }
+
   // Log
   if (logger_) { logger_(req, res); }
 
   return true;
 }
 
-inline bool
-Client::process_socket(Socket &socket, size_t request_count,
-                       std::function<bool(Stream &strm, bool last_connection,
-                                          bool &connection_close)>
-                           callback) {
-  return detail::process_socket(
-      true, socket.sock, request_count, read_timeout_sec_, read_timeout_usec_,
-      write_timeout_sec_, write_timeout_usec_, callback);
+inline bool Client::process_socket(Socket &socket,
+                                   std::function<bool(Stream &strm)> callback) {
+  return detail::process_client_socket(socket.sock, read_timeout_sec_,
+                                       read_timeout_usec_, write_timeout_sec_,
+                                       write_timeout_usec_, callback);
 }
 
 inline bool Client::is_ssl() const { return false; }
@@ -5066,9 +5067,9 @@ inline void Client::stop() {
   std::lock_guard<std::mutex> guard(socket_mutex_);
   if (socket_.is_open()) {
     detail::shutdown_socket(socket_.sock);
-    std::this_thread::sleep_for(std::chrono::milliseconds(10));
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
     close_socket(socket_, true);
-    std::this_thread::sleep_for(std::chrono::milliseconds(10));
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
   }
 }
 
@@ -5091,10 +5092,6 @@ inline void Client::set_write_timeout(time_t sec, time_t usec) {
   write_timeout_usec_ = usec;
 }
 
-inline void Client::set_keep_alive_max_count(size_t count) {
-  keep_alive_max_count_ = count;
-}
-
 inline void Client::set_basic_auth(const char *username, const char *password) {
   basic_auth_username_ = username;
   basic_auth_password_ = password;
@@ -5108,6 +5105,8 @@ inline void Client::set_digest_auth(const char *username,
 }
 #endif
 
+inline void Client::set_keep_alive(bool on) { keep_alive_ = on; }
+
 inline void Client::set_follow_location(bool on) { follow_location_ = on; }
 
 inline void Client::set_compress(bool on) { compress_ = on; }
@@ -5181,19 +5180,29 @@ inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl,
 
 template <typename T>
 inline bool
-process_socket_ssl(SSL *ssl, bool is_client_request, socket_t sock,
-                   size_t keep_alive_max_count, time_t read_timeout_sec,
-                   time_t read_timeout_usec, time_t write_timeout_sec,
-                   time_t write_timeout_usec, T callback) {
-  return process_socket_core(
-      is_client_request, sock, keep_alive_max_count,
-      [&](bool last_connection, bool connection_close) {
+process_server_socket_ssl(SSL *ssl, socket_t sock, size_t keep_alive_max_count,
+                          time_t read_timeout_sec, time_t read_timeout_usec,
+                          time_t write_timeout_sec, time_t write_timeout_usec,
+                          std::function<bool()> is_shutting_down, T callback) {
+  return process_server_socket_core(
+      sock, keep_alive_max_count, is_shutting_down,
+      [&](bool close_connection, bool connection_closed) {
         SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec,
                              write_timeout_sec, write_timeout_usec);
-        return callback(strm, last_connection, connection_close);
+        return callback(strm, close_connection, connection_closed);
       });
 }
 
+template <typename T>
+inline bool
+process_client_socket_ssl(SSL *ssl, socket_t sock, time_t read_timeout_sec,
+                          time_t read_timeout_usec, time_t write_timeout_sec,
+                          time_t write_timeout_usec, T callback) {
+  SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec,
+                       write_timeout_sec, write_timeout_usec);
+  return callback(strm);
+}
+
 #if OPENSSL_VERSION_NUMBER < 0x10100000L
 static std::shared_ptr<std::vector<std::mutex>> openSSL_locks_;
 
@@ -5365,12 +5374,13 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) {
                              [](SSL * /*ssl*/) { return true; });
 
   if (ssl) {
-    auto ret = detail::process_socket_ssl(
-        ssl, false, sock, keep_alive_max_count_, read_timeout_sec_,
-        read_timeout_usec_, write_timeout_sec_, write_timeout_usec_,
-        [this, ssl](Stream &strm, bool last_connection,
-                    bool &connection_close) {
-          return process_request(strm, last_connection, connection_close,
+    auto ret = detail::process_server_socket_ssl(
+        ssl, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_,
+        write_timeout_sec_, write_timeout_usec_,
+        [this]() { return this->svr_sock_ == INVALID_SOCKET; },
+        [this, ssl](Stream &strm, bool close_connection,
+                    bool &connection_closed) {
+          return process_request(strm, close_connection, connection_closed,
                                  [&](Request &req) { req.ssl = ssl; });
         });
 
@@ -5455,49 +5465,36 @@ inline long SSLClient::get_openssl_verify_result() const {
 inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; }
 
 inline bool SSLClient::create_and_connect_socket(Socket &socket) {
-  if (is_valid() && Client::create_and_connect_socket(socket) &&
-      initialize_ssl(socket)) {
-    if (!proxy_host_.empty()) {
-      bool error;
-      if (!connect_with_proxy(socket, error)) { return error; }
-    }
-    return true;
-  }
-  return false;
+  return is_valid() && Client::create_and_connect_socket(socket);
 }
 
-inline bool SSLClient::connect_with_proxy(Socket &socket, bool &error) {
-  error = true;
-  Response res;
+inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res,
+                                          bool &success) {
+  success = true;
+  Response res2;
 
-  if (!detail::process_socket_core(
-          true, socket.sock, 1,
-          [&](bool /*last_connection*/, bool &connection_close) {
-            detail::SocketStream strm(socket.sock, read_timeout_sec_,
-                                      read_timeout_usec_, write_timeout_sec_,
-                                      write_timeout_usec_);
+  if (!detail::process_client_socket(
+          socket.sock, read_timeout_sec_, read_timeout_usec_,
+          write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) {
             Request req2;
             req2.method = "CONNECT";
             req2.path = host_and_port_;
-            return process_request(strm, req2, res, connection_close);
+            return process_request(strm, req2, res2, false);
           })) {
     close_socket(socket, true);
-    error = false;
+    success = false;
     return false;
   }
 
-  if (res.status == 407) {
+  if (res2.status == 407) {
     if (!proxy_digest_auth_username_.empty() &&
         !proxy_digest_auth_password_.empty()) {
       std::map<std::string, std::string> auth;
-      if (parse_www_authenticate(res, auth, true)) {
+      if (parse_www_authenticate(res2, auth, true)) {
         Response res3;
-        if (!detail::process_socket_core(
-                true, socket.sock, 1,
-                [&](bool /*last_connection*/, bool &connection_close) {
-                  detail::SocketStream strm(
-                      socket.sock, read_timeout_sec_, read_timeout_usec_,
-                      write_timeout_sec_, write_timeout_usec_);
+        if (!detail::process_client_socket(
+                socket.sock, read_timeout_sec_, read_timeout_usec_,
+                write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) {
                   Request req3;
                   req3.method = "CONNECT";
                   req3.path = host_and_port_;
@@ -5505,14 +5502,15 @@ inline bool SSLClient::connect_with_proxy(Socket &socket, bool &error) {
                       req3, auth, 1, random_string(10),
                       proxy_digest_auth_username_, proxy_digest_auth_password_,
                       true));
-                  return process_request(strm, req3, res3, connection_close);
+                  return process_request(strm, req3, res3, false);
                 })) {
           close_socket(socket, true);
-          error = false;
+          success = false;
           return false;
         }
       }
     } else {
+      res = res2;
       return false;
     }
   }
@@ -5583,17 +5581,12 @@ inline void SSLClient::close_socket(Socket &socket, bool process_socket_ret) {
 }
 
 inline bool
-SSLClient::process_socket(Socket &socket, size_t request_count,
-                          std::function<bool(Stream &strm, bool last_connection,
-                                             bool &connection_close)>
-                              callback) {
+SSLClient::process_socket(Socket &socket,
+                          std::function<bool(Stream &strm)> callback) {
   assert(socket.ssl);
-  return detail::process_socket_ssl(
-      socket.ssl, true, socket.sock, request_count, read_timeout_sec_,
-      read_timeout_usec_, write_timeout_sec_, write_timeout_usec_,
-      [&](Stream &strm, bool last_connection, bool &connection_close) {
-        return callback(strm, last_connection, connection_close);
-      });
+  return detail::process_client_socket_ssl(
+      socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_,
+      write_timeout_sec_, write_timeout_usec_, callback);
 }
 
 inline bool SSLClient::is_ssl() const { return true; }
@@ -5678,7 +5671,6 @@ SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const {
   }
 
   GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names);
-
   return ret;
 }
 

+ 56 - 44
test/test.cc

@@ -1136,6 +1136,10 @@ protected:
                 EXPECT_EQ(req.get_param_value("key"), "value");
                 EXPECT_EQ(req.body, "content");
               })
+        .Get("/last-request",
+             [&](const Request & req, Response &/*res*/) {
+               EXPECT_EQ("close", req.get_header_value("Connection"));
+             })
 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
         .Get("/gzip",
              [&](const Request & /*req*/, Response &res) {
@@ -2127,42 +2131,48 @@ TEST_F(ServerTest, HTTP2Magic) {
 }
 
 TEST_F(ServerTest, KeepAlive) {
-  cli_.set_keep_alive_max_count(4);
-
-  std::vector<Request> requests;
-  Get(requests, "/hi");
-  Get(requests, "/hi");
-  Get(requests, "/hi");
-  Get(requests, "/not-exist");
-  Post(requests, "/empty", "", "text/plain");
-  Post(
-      requests, "/empty", 0,
-      [&](size_t, size_t, httplib::DataSink &) { return true; }, "text/plain");
-
-  std::vector<Response> responses;
-  auto ret = cli_.send(requests, responses);
-
-  ASSERT_TRUE(ret == true);
-  ASSERT_TRUE(requests.size() == responses.size());
-
-  for (size_t i = 0; i < 3; i++) {
-    auto &res = responses[i];
-    EXPECT_EQ(200, res.status);
-    EXPECT_EQ("text/plain", res.get_header_value("Content-Type"));
-    EXPECT_EQ("Hello World!", res.body);
-  }
+  auto res = cli_.Get("/hi");
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ(200, res->status);
+  EXPECT_EQ("text/plain", res->get_header_value("Content-Type"));
+  EXPECT_EQ("Hello World!", res->body);
 
-  {
-    auto &res = responses[3];
-    EXPECT_EQ(404, res.status);
-  }
+  res = cli_.Get("/hi");
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ(200, res->status);
+  EXPECT_EQ("text/plain", res->get_header_value("Content-Type"));
+  EXPECT_EQ("Hello World!", res->body);
 
-  for (size_t i = 4; i < 6; i++) {
-    auto &res = responses[i];
-    EXPECT_EQ(200, res.status);
-    EXPECT_EQ("text/plain", res.get_header_value("Content-Type"));
-    EXPECT_EQ("empty", res.body);
-  }
+  res = cli_.Get("/hi");
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ(200, res->status);
+  EXPECT_EQ("text/plain", res->get_header_value("Content-Type"));
+  EXPECT_EQ("Hello World!", res->body);
+
+  res = cli_.Get("/not-exist");
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ(404, res->status);
+
+  res = cli_.Post("/empty", "", "text/plain");
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ(200, res->status);
+  EXPECT_EQ("text/plain", res->get_header_value("Content-Type"));
+  EXPECT_EQ("empty", res->body);
+  EXPECT_EQ("close", res->get_header_value("Connection"));
+
+  res = cli_.Post(
+      "/empty", 0, [&](size_t, size_t, httplib::DataSink &) { return true; },
+      "text/plain");
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ(200, res->status);
+  EXPECT_EQ("text/plain", res->get_header_value("Content-Type"));
+  EXPECT_EQ("empty", res->body);
+
+  cli_.set_keep_alive(false);
+  res = cli_.Get("/last-request");
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ(200, res->status);
+  EXPECT_EQ("close", res->get_header_value("Connection"));
 }
 
 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
@@ -2310,10 +2320,8 @@ static bool send_request(time_t read_timeout_sec, const std::string &req,
 
   if (client_sock == INVALID_SOCKET) { return false; }
 
-  auto ret = detail::process_socket(
-      true, client_sock, 1, read_timeout_sec, 0, 0, 0,
-      [&](Stream &strm, bool /*last_connection*/, bool &
-          /*connection_close*/) -> bool {
+  auto ret = detail::process_client_socket(
+      client_sock, read_timeout_sec, 0, 0, 0, [&](Stream &strm) {
         if (req.size() !=
             static_cast<size_t>(strm.write(req.data(), req.size()))) {
           return false;
@@ -2515,8 +2523,7 @@ TEST(ServerStopTest, StopServerWithChunkedTransmission) {
   }
 
   Client client(HOST, PORT);
-  const Headers headers = {{"Accept", "text/event-stream"},
-                           {"Connection", "Keep-Alive"}};
+  const Headers headers = {{"Accept", "text/event-stream"}};
 
   auto get_thread = std::thread([&client, &headers]() {
     std::shared_ptr<Response> res = client.Get(
@@ -2742,19 +2749,24 @@ TEST(SSLClientTest, ServerNameIndication) {
   ASSERT_EQ(200, res->status);
 }
 
-TEST(SSLClientTest, ServerCertificateVerification) {
+TEST(SSLClientTest, ServerCertificateVerification1) {
   SSLClient cli("google.com");
-
   auto res = cli.Get("/");
   ASSERT_TRUE(res != nullptr);
   ASSERT_EQ(301, res->status);
+}
 
+TEST(SSLClientTest, ServerCertificateVerification2) {
+  SSLClient cli("google.com");
   cli.enable_server_certificate_verification(true);
-  res = cli.Get("/");
+  auto res = cli.Get("/");
   ASSERT_TRUE(res == nullptr);
+}
 
+TEST(SSLClientTest, ServerCertificateVerification3) {
+  SSLClient cli("google.com");
   cli.set_ca_cert_path(CA_CERT_FILE);
-  res = cli.Get("/");
+  auto res = cli.Get("/");
   ASSERT_TRUE(res != nullptr);
   ASSERT_EQ(301, res->status);
 }

+ 22 - 44
test/test_proxy.cc

@@ -222,66 +222,45 @@ void KeepAliveTest(Client& cli, bool basic) {
 #endif
   }
 
-  cli.set_keep_alive_max_count(4);
   cli.set_follow_location(true);
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
   cli.set_digest_auth("hello", "world");
-
-  std::vector<Request> requests;
-
-  Get(requests, "/get");
-  Get(requests, "/redirect/2");
-
-  std::vector<std::string> paths = {
-      "/digest-auth/auth/hello/world/MD5",
-      "/digest-auth/auth/hello/world/SHA-256",
-      "/digest-auth/auth/hello/world/SHA-512",
-      "/digest-auth/auth-int/hello/world/MD5",
-  };
-
-  for (auto path : paths) {
-    Get(requests, path.c_str());
-  }
+#endif
 
   {
-    int count = 100;
-    while (count--) {
-      Get(requests, "/get");
-    }
+    auto res = cli.Get("/get");
+    EXPECT_EQ(200, res->status);
   }
-
-  std::vector<Response> responses;
-  auto ret = cli.send(requests, responses);
-  ASSERT_TRUE(ret == true);
-  ASSERT_TRUE(requests.size() == responses.size());
-
-  size_t i = 0;
-
   {
-    auto &res = responses[i++];
-    EXPECT_EQ(200, res.status);
+    auto res = cli.Get("/redirect/2");
+    EXPECT_EQ(200, res->status);
   }
 
   {
-    auto &res = responses[i++];
-    EXPECT_EQ(200, res.status);
-  }
+    std::vector<std::string> paths = {
+        "/digest-auth/auth/hello/world/MD5",
+        "/digest-auth/auth/hello/world/SHA-256",
+        "/digest-auth/auth/hello/world/SHA-512",
+        "/digest-auth/auth-int/hello/world/MD5",
+    };
 
+    for (auto path: paths) {
+      auto res = cli.Get(path.c_str());
+      EXPECT_EQ("{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n", res->body);
+      EXPECT_EQ(200, res->status);
+    }
+  }
 
   {
-    int count = static_cast<int>(paths.size());
+    int count = 100;
     while (count--) {
-      auto &res = responses[i++];
-      EXPECT_EQ("{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n", res.body);
-      EXPECT_EQ(200, res.status);
+      auto res = cli.Get("/get");
+      EXPECT_EQ(200, res->status);
     }
   }
-
-  for (; i < responses.size(); i++) {
-    auto &res = responses[i];
-    EXPECT_EQ(200, res.status);
-  }
 }
 
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 TEST(KeepAliveTest, NoSSLWithBasic) {
   Client cli("httpbin.org");
   KeepAliveTest(cli, true);
@@ -292,7 +271,6 @@ TEST(KeepAliveTest, SSLWithBasic) {
   KeepAliveTest(cli, true);
 }
 
-#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 TEST(KeepAliveTest, NoSSLWithDigest) {
   Client cli("httpbin.org");
   KeepAliveTest(cli, false);