Browse Source

Keep-alive connection support on client (Fix #36)

yhirose 6 years ago
parent
commit
1e82359329
2 changed files with 171 additions and 55 deletions
  1. 135 55
      httplib.h
  2. 36 0
      test/test.cc

+ 135 - 55
httplib.h

@@ -171,6 +171,9 @@ struct Request {
   Ranges ranges;
   Match matches;
 
+  ContentReceiver content_receiver;
+  Progress progress;
+
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
   const SSL *ssl;
 #endif
@@ -195,10 +198,6 @@ struct Response {
   Headers headers;
   std::string body;
 
-  ContentReceiver content_receiver;
-
-  Progress progress;
-
   bool has_header(const char *key) const;
   std::string get_header_value(const char *key, size_t id = 0) const;
   size_t get_header_value_count(const char *key) const;
@@ -456,7 +455,7 @@ private:
                                    Response &res, const std::string &boundary,
                                    const std::string &content_type);
 
-  virtual bool read_and_close_socket(socket_t sock);
+  virtual bool process_and_close_socket(socket_t sock);
 
   std::atomic<bool> is_running_;
   std::atomic<socket_t> svr_sock_;
@@ -533,6 +532,10 @@ public:
 
   bool send(Request &req, Response &res);
 
+  bool send(std::vector<Request> &requests, std::vector<Response>& responses);
+
+  void set_keep_alive_max_count(size_t count);
+
 protected:
   bool process_request(Stream &strm, Request &req, Response &res,
                        bool &connection_close);
@@ -541,17 +544,48 @@ protected:
   const int port_;
   time_t timeout_sec_;
   const std::string host_and_port_;
+  size_t keep_alive_max_count_;
 
 private:
   socket_t create_client_socket() const;
   bool read_response_line(Stream &strm, Response &res);
   void write_request(Stream &strm, Request &req);
 
-  virtual bool read_and_close_socket(socket_t sock, Request &req,
-                                     Response &res);
+  virtual bool process_and_close_socket(
+      socket_t sock, size_t request_count,
+      std::function<bool(Stream &strm, bool last_connection,
+                         bool &connection_close)>
+          callback);
+
   virtual bool is_ssl() const;
 };
 
+inline void Get(std::vector<Request> &requests, const char *path, const Headers &headers) {
+  Request req;
+  req.method = "GET";
+  req.path = path;
+  req.headers = headers;
+  requests.emplace_back(std::move(req));
+}
+
+inline void Get(std::vector<Request> &requests, const char *path) {
+  Get(requests, path, Headers());
+}
+
+inline void Post(std::vector<Request> &requests, const char *path, const Headers &headers, const std::string &body, const char *content_type) {
+  Request req;
+  req.method = "POST";
+  req.path = path;
+  req.headers = headers;
+  req.headers.emplace("Content-Type", content_type);
+  req.body = body;
+  requests.emplace_back(std::move(req));
+}
+
+inline void Post(std::vector<Request> &requests, const char *path, const std::string &body, const char *content_type) {
+  Post(requests, path, Headers(), body, content_type);
+}
+
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 class SSLSocketStream : public Stream {
 public:
@@ -580,7 +614,7 @@ public:
   virtual bool is_valid() const;
 
 private:
-  virtual bool read_and_close_socket(socket_t sock);
+  virtual bool process_and_close_socket(socket_t sock);
 
   SSL_CTX *ctx_;
   std::mutex ctx_mutex_;
@@ -603,8 +637,11 @@ public:
   long get_openssl_verify_result() const;
 
 private:
-  virtual bool read_and_close_socket(socket_t sock, Request &req,
-                                     Response &res);
+  virtual bool process_and_close_socket(
+      socket_t sock, size_t request_count,
+      std::function<bool(Stream &strm, bool last_connection,
+                         bool &connection_close)>
+          callback);
   virtual bool is_ssl() const;
 
   bool verify_host(X509 *server_cert) const;
@@ -928,15 +965,18 @@ inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) {
 }
 
 template <typename T>
-inline bool read_and_close_socket(socket_t sock, size_t keep_alive_max_count,
-                                  T callback) {
+inline bool process_and_close_socket(bool is_client_request, socket_t sock,
+                                     size_t keep_alive_max_count, T callback) {
+  assert(keep_alive_max_count > 0);
+
   bool ret = false;
 
-  if (keep_alive_max_count > 0) {
+  if (keep_alive_max_count > 1) {
     auto count = keep_alive_max_count;
     while (count > 0 &&
-           detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
-                               CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) {
+           (is_client_request ||
+            detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
+                                CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
       SocketStream strm(sock);
       auto last_connection = count == 1;
       auto connection_close = false;
@@ -2315,9 +2355,7 @@ inline bool Server::handle_file_request(Request &req, Response &res) {
       auto type = detail::find_content_type(path);
       if (type) { res.set_header("Content-Type", type); }
       res.status = 200;
-      if (file_request_handler_) {
-        file_request_handler_(req, res);
-      }
+      if (file_request_handler_) { file_request_handler_(req, res); }
       return true;
     }
   }
@@ -2398,7 +2436,7 @@ inline bool Server::listen_internal() {
         break;
       }
 
-      task_queue->enqueue([=]() { read_and_close_socket(sock); });
+      task_queue->enqueue([=]() { process_and_close_socket(sock); });
     }
 
     task_queue->shutdown();
@@ -2528,9 +2566,9 @@ Server::process_request(Stream &strm, bool last_connection,
 
 inline bool Server::is_valid() const { return true; }
 
-inline bool Server::read_and_close_socket(socket_t sock) {
-  return detail::read_and_close_socket(
-      sock, keep_alive_max_count_,
+inline bool Server::process_and_close_socket(socket_t sock) {
+  return detail::process_and_close_socket(
+      false, sock, keep_alive_max_count_,
       [this](Stream &strm, bool last_connection, bool &connection_close) {
         return process_request(strm, last_connection, connection_close,
                                nullptr);
@@ -2540,7 +2578,8 @@ inline bool Server::read_and_close_socket(socket_t sock) {
 // HTTP client implementation
 inline Client::Client(const char *host, int port, time_t timeout_sec)
     : host_(host), port_(port), timeout_sec_(timeout_sec),
-      host_and_port_(host_ + ":" + std::to_string(port_)) {}
+      host_and_port_(host_ + ":" + std::to_string(port_)),
+      keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT) {}
 
 inline Client::~Client() {}
 
@@ -2590,7 +2629,37 @@ inline bool Client::send(Request &req, Response &res) {
   auto sock = create_client_socket();
   if (sock == INVALID_SOCKET) { return false; }
 
-  return read_and_close_socket(sock, req, res);
+  return process_and_close_socket(
+      sock, 1,
+      [&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
+        return process_request(strm, req, res, connection_close);
+      });
+}
+
+inline bool Client::send(std::vector<Request> &requests, std::vector<Response>& responses) {
+  size_t i = 0;
+  while (i < requests.size()) {
+    auto sock = create_client_socket();
+    if (sock == INVALID_SOCKET) { return false; }
+
+    if (!process_and_close_socket(
+            sock, requests.size() - i,
+            [&](Stream &strm, bool last_connection, bool &connection_close) {
+              auto &req = requests[i];
+              auto res = Response();
+              i++;
+
+              if (req.path.empty()) { return false; }
+              if (last_connection) { req.set_header("Connection", "close"); }
+              auto ret = process_request(strm, req, res, connection_close);
+              if (ret) { responses.emplace_back(std::move(res)); }
+              return ret;
+            })) {
+      return false;
+    }
+  }
+
+  return true;
 }
 
 inline void Client::write_request(Stream &strm, Request &req) {
@@ -2677,10 +2746,10 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res,
       return true;
     };
 
-    if (res.content_receiver) {
+    if (req.content_receiver) {
       auto offset = std::make_shared<uint64_t>();
       auto length = get_header_value_uint64(res.headers, "Content-Length", 0);
-      auto receiver = res.content_receiver;
+      auto receiver = req.content_receiver;
       out = [offset, length, receiver](const char *buf, size_t n) {
         auto ret = receiver(buf, n, *offset, length);
         (*offset) += n;
@@ -2690,7 +2759,7 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res,
 
     int dummy_status;
     if (!detail::read_content(strm, res, std::numeric_limits<uint64_t>::max(),
-                              dummy_status, res.progress, out)) {
+                              dummy_status, req.progress, out)) {
       return false;
     }
   }
@@ -2698,13 +2767,13 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res,
   return true;
 }
 
-inline bool Client::read_and_close_socket(socket_t sock, Request &req,
-                                          Response &res) {
-  return detail::read_and_close_socket(
-      sock, 0,
-      [&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
-        return process_request(strm, req, res, connection_close);
-      });
+inline bool Client::process_and_close_socket(
+    socket_t sock, size_t request_count,
+    std::function<bool(Stream &strm, bool last_connection,
+                       bool &connection_close)>
+        callback) {
+  request_count = std::min(request_count, keep_alive_max_count_);
+  return detail::process_and_close_socket(true, sock, request_count, callback);
 }
 
 inline bool Client::is_ssl() const { return false; }
@@ -2720,10 +2789,9 @@ Client::Get(const char *path, const Headers &headers, Progress progress) {
   req.method = "GET";
   req.path = path;
   req.headers = headers;
+  req.progress = progress;
 
   auto res = std::make_shared<Response>();
-  res->progress = progress;
-
   return send(req, *res) ? res : nullptr;
 }
 
@@ -2741,11 +2809,10 @@ inline std::shared_ptr<Response> Client::Get(const char *path,
   req.method = "GET";
   req.path = path;
   req.headers = headers;
+  req.content_receiver = content_receiver;
+  req.progress = progress;
 
   auto res = std::make_shared<Response>();
-  res->content_receiver = content_receiver;
-  res->progress = progress;
-
   return send(req, *res) ? res : nullptr;
 }
 
@@ -2930,6 +2997,10 @@ inline std::shared_ptr<Response> Client::Options(const char *path,
   return send(req, *res) ? res : nullptr;
 }
 
+inline void Client::set_keep_alive_max_count(size_t count) {
+  keep_alive_max_count_ = count;
+}
+
 /*
  * SSL Implementation
  */
@@ -2937,10 +3008,13 @@ inline std::shared_ptr<Response> Client::Options(const char *path,
 namespace detail {
 
 template <typename U, typename V, typename T>
-inline bool
-read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count,
-                          SSL_CTX *ctx, std::mutex &ctx_mutex,
-                          U SSL_connect_or_accept, V setup, T callback) {
+inline bool process_and_close_socket_ssl(bool is_client_request, socket_t sock,
+                                         size_t keep_alive_max_count,
+                                         SSL_CTX *ctx, std::mutex &ctx_mutex,
+                                         U SSL_connect_or_accept, V setup,
+                                         T callback) {
+  assert(keep_alive_max_count > 0);
+
   SSL *ssl = nullptr;
   {
     std::lock_guard<std::mutex> guard(ctx_mutex);
@@ -2969,11 +3043,12 @@ read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count,
   bool ret = false;
 
   if (SSL_connect_or_accept(ssl) == 1) {
-    if (keep_alive_max_count > 0) {
+    if (keep_alive_max_count > 1) {
       auto count = keep_alive_max_count;
       while (count > 0 &&
-             detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
-                                 CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) {
+             (is_client_request ||
+              detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
+                                  CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
         SSLSocketStream strm(sock, ssl);
         auto last_connection = count == 1;
         auto connection_close = false;
@@ -3123,9 +3198,9 @@ inline SSLServer::~SSLServer() {
 
 inline bool SSLServer::is_valid() const { return ctx_; }
 
-inline bool SSLServer::read_and_close_socket(socket_t sock) {
-  return detail::read_and_close_socket_ssl(
-      sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept,
+inline bool SSLServer::process_and_close_socket(socket_t sock) {
+  return detail::process_and_close_socket_ssl(
+      false, sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept,
       [](SSL * /*ssl*/) { return true; },
       [this](SSL *ssl, Stream &strm, bool last_connection,
              bool &connection_close) {
@@ -3176,12 +3251,17 @@ inline long SSLClient::get_openssl_verify_result() const {
   return verify_result_;
 }
 
-inline bool SSLClient::read_and_close_socket(socket_t sock, Request &req,
-                                             Response &res) {
+inline bool SSLClient::process_and_close_socket(
+    socket_t sock, size_t request_count,
+    std::function<bool(Stream &strm, bool last_connection,
+                       bool &connection_close)>
+        callback) {
+
+  request_count = std::min(request_count, keep_alive_max_count_);
 
   return is_valid() &&
-         detail::read_and_close_socket_ssl(
-             sock, 0, ctx_, ctx_mutex_,
+         detail::process_and_close_socket_ssl(
+             true, sock, request_count, ctx_, ctx_mutex_,
              [&](SSL *ssl) {
                if (ca_cert_file_path_.empty()) {
                  SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr);
@@ -3217,9 +3297,9 @@ inline bool SSLClient::read_and_close_socket(socket_t sock, Request &req,
                SSL_set_tlsext_host_name(ssl, host_.c_str());
                return true;
              },
-             [&](SSL * /*ssl*/, Stream &strm, bool /*last_connection*/,
+             [&](SSL * /*ssl*/, Stream &strm, bool last_connection,
                  bool &connection_close) {
-               return process_request(strm, req, res, connection_close);
+               return callback(strm, last_connection, connection_close);
              });
 }
 

+ 36 - 0
test/test.cc

@@ -1280,6 +1280,42 @@ TEST_F(ServerTest, NoMultipleHeaders) {
   EXPECT_EQ(200, res->status);
 }
 
+TEST_F(ServerTest, KeepAlive) {
+  cli_.set_keep_alive_max_count(4);
+
+  std::vector<Request> requests;
+  Get(requests, "/hi");
+  Get(requests, "/hi");
+  Get(requests, "/hi");
+  Get(requests, "/not-exist");
+  Post(requests, "/empty", "", "text/plain");
+
+  std::vector<Response> responses;
+  auto ret = cli_.send(requests, responses);
+
+  ASSERT_TRUE(ret == true);
+  ASSERT_TRUE(requests.size() == responses.size());
+
+  for (int i = 0; i < 3; i++) {
+    auto& res = responses[i];
+    EXPECT_EQ(200, res.status);
+    EXPECT_EQ("text/plain", res.get_header_value("Content-Type"));
+    EXPECT_EQ("Hello World!", res.body);
+  }
+
+  {
+    auto& res = responses[3];
+    EXPECT_EQ(404, res.status);
+  }
+
+  {
+    auto& res = responses[4];
+    EXPECT_EQ(200, res.status);
+    EXPECT_EQ("text/plain", res.get_header_value("Content-Type"));
+    EXPECT_EQ("empty", res.body);
+  }
+}
+
 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
 TEST_F(ServerTest, Gzip) {
   Headers headers;