yhirose 5 years ago
parent
commit
a2e4af54b7
2 changed files with 63 additions and 38 deletions
  1. 21 9
      httplib.h
  2. 42 29
      test/test.cc

+ 21 - 9
httplib.h

@@ -695,6 +695,8 @@ public:
   bool send(const std::vector<Request> &requests,
             std::vector<Response> &responses);
 
+  void stop();
+
   void set_timeout_sec(time_t timeout_sec);
 
   void set_read_timeout(time_t sec, time_t usec);
@@ -727,6 +729,8 @@ protected:
   bool process_request(Stream &strm, const Request &req, Response &res,
                        bool last_connection, bool &connection_close);
 
+  std::atomic<socket_t> sock_;
+
   const std::string host_;
   const int port_;
   const std::string host_and_port_;
@@ -3714,7 +3718,7 @@ inline bool Server::process_and_close_socket(socket_t sock) {
 inline Client::Client(const std::string &host, int port,
                       const std::string &client_cert_path,
                       const std::string &client_key_path)
-    : host_(host), port_(port),
+    : sock_(INVALID_SOCKET), host_(host), port_(port),
       host_and_port_(host_ + ":" + std::to_string(port_)),
       client_cert_path_(client_cert_path), client_key_path_(client_key_path) {}
 
@@ -3750,18 +3754,18 @@ inline bool Client::read_response_line(Stream &strm, Response &res) {
 }
 
 inline bool Client::send(const Request &req, Response &res) {
-  auto sock = create_client_socket();
-  if (sock == INVALID_SOCKET) { return false; }
+  sock_ = create_client_socket();
+  if (sock_ == INVALID_SOCKET) { return false; }
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
   if (is_ssl() && !proxy_host_.empty()) {
     bool error;
-    if (!connect(sock, res, error)) { return error; }
+    if (!connect(sock_, res, error)) { return error; }
   }
 #endif
 
   return process_and_close_socket(
-      sock, 1, [&](Stream &strm, bool last_connection, bool &connection_close) {
+      sock_, 1, [&](Stream &strm, bool last_connection, bool &connection_close) {
         return handle_request(strm, req, res, last_connection,
                               connection_close);
       });
@@ -3771,18 +3775,18 @@ inline bool Client::send(const std::vector<Request> &requests,
                          std::vector<Response> &responses) {
   size_t i = 0;
   while (i < requests.size()) {
-    auto sock = create_client_socket();
-    if (sock == INVALID_SOCKET) { return false; }
+    sock_ = create_client_socket();
+    if (sock_ == INVALID_SOCKET) { return false; }
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
     if (is_ssl() && !proxy_host_.empty()) {
       Response res;
       bool error;
-      if (!connect(sock, res, error)) { return false; }
+      if (!connect(sock_, res, error)) { return false; }
     }
 #endif
 
-    if (!process_and_close_socket(sock, requests.size() - i,
+    if (!process_and_close_socket(sock_, requests.size() - i,
                                   [&](Stream &strm, bool last_connection,
                                       bool &connection_close) -> bool {
                                     auto &req = requests[i++];
@@ -4446,6 +4450,14 @@ inline std::shared_ptr<Response> Client::Options(const char *path,
   return send(req, *res) ? res : nullptr;
 }
 
+inline void Client::stop() {
+  if (sock_ != INVALID_SOCKET) {
+    std::atomic<socket_t> sock(sock_.exchange(INVALID_SOCKET));
+    detail::shutdown_socket(sock);
+    detail::close_socket(sock);
+  }
+}
+
 inline void Client::set_timeout_sec(time_t timeout_sec) {
   timeout_sec_ = timeout_sec;
 }

+ 42 - 29
test/test.cc

@@ -803,7 +803,8 @@ protected:
                auto remote_addr = req.headers.find("REMOTE_ADDR")->second;
                EXPECT_TRUE(req.has_header("REMOTE_PORT"));
                EXPECT_EQ(req.remote_addr, req.get_header_value("REMOTE_ADDR"));
-               EXPECT_EQ(req.remote_port, std::stoi(req.get_header_value("REMOTE_PORT")));
+               EXPECT_EQ(req.remote_port,
+                         std::stoi(req.get_header_value("REMOTE_PORT")));
                res.set_content(remote_addr.c_str(), "text/plain");
              })
         .Get("/endwith%",
@@ -979,12 +980,12 @@ protected:
                 res.set_content("empty-no-content-type", "text/plain");
               })
         .Put("/empty-no-content-type",
-              [&](const Request &req, Response &res) {
-                EXPECT_EQ(req.body, "");
-                EXPECT_FALSE(req.has_header("Content-Type"));
-                EXPECT_EQ("0", req.get_header_value("Content-Length"));
-                res.set_content("empty-no-content-type", "text/plain");
-              })
+             [&](const Request &req, Response &res) {
+               EXPECT_EQ(req.body, "");
+               EXPECT_FALSE(req.has_header("Content-Type"));
+               EXPECT_EQ("0", req.get_header_value("Content-Length"));
+               res.set_content("empty-no-content-type", "text/plain");
+             })
         .Put("/put",
              [&](const Request &req, Response &res) {
                EXPECT_EQ(req.body, "PUT");
@@ -1746,6 +1747,18 @@ TEST_F(ServerTest, GetStreamedEndless) {
   ASSERT_TRUE(res == nullptr);
 }
 
+TEST_F(ServerTest, ClientStop) {
+  thread t = thread([&]() {
+    auto res =
+        cli_.Get("/streamed-cancel",
+                 [&](const char *, uint64_t) { return true; });
+    ASSERT_TRUE(res == nullptr);
+  });
+  std::this_thread::sleep_for(std::chrono::seconds(1));
+  cli_.stop();
+  t.join();
+}
+
 TEST_F(ServerTest, GetWithRange1) {
   auto res = cli_.Get("/with-range", {{make_range_header({{3, 5}})}});
   ASSERT_TRUE(res != nullptr);
@@ -2323,40 +2336,40 @@ TEST(ServerRequestParsingTest, ReadHeadersRegexComplexity2) {
 TEST(ServerRequestParsingTest, InvalidFirstChunkLengthInRequest) {
   std::string out;
 
-  test_raw_request(
-      "PUT /put_hi HTTP/1.1\r\n"
-      "Content-Type: text/plain\r\n"
-      "Transfer-Encoding: chunked\r\n"
-      "\r\n"
-      "nothex\r\n", &out);
+  test_raw_request("PUT /put_hi HTTP/1.1\r\n"
+                   "Content-Type: text/plain\r\n"
+                   "Transfer-Encoding: chunked\r\n"
+                   "\r\n"
+                   "nothex\r\n",
+                   &out);
   EXPECT_EQ("HTTP/1.1 400 Bad Request", out.substr(0, 24));
 }
 
 TEST(ServerRequestParsingTest, InvalidSecondChunkLengthInRequest) {
   std::string out;
 
-  test_raw_request(
-      "PUT /put_hi HTTP/1.1\r\n"
-      "Content-Type: text/plain\r\n"
-      "Transfer-Encoding: chunked\r\n"
-      "\r\n"
-      "3\r\n"
-      "xyz\r\n"
-      "NaN\r\n", &out);
+  test_raw_request("PUT /put_hi HTTP/1.1\r\n"
+                   "Content-Type: text/plain\r\n"
+                   "Transfer-Encoding: chunked\r\n"
+                   "\r\n"
+                   "3\r\n"
+                   "xyz\r\n"
+                   "NaN\r\n",
+                   &out);
   EXPECT_EQ("HTTP/1.1 400 Bad Request", out.substr(0, 24));
 }
 
 TEST(ServerRequestParsingTest, ChunkLengthTooHighInRequest) {
   std::string out;
 
-  test_raw_request(
-      "PUT /put_hi HTTP/1.1\r\n"
-      "Content-Type: text/plain\r\n"
-      "Transfer-Encoding: chunked\r\n"
-      "\r\n"
-      // Length is too large for 64 bits.
-      "1ffffffffffffffff\r\n"
-      "xyz\r\n", &out);
+  test_raw_request("PUT /put_hi HTTP/1.1\r\n"
+                   "Content-Type: text/plain\r\n"
+                   "Transfer-Encoding: chunked\r\n"
+                   "\r\n"
+                   // Length is too large for 64 bits.
+                   "1ffffffffffffffff\r\n"
+                   "xyz\r\n",
+                   &out);
   EXPECT_EQ("HTTP/1.1 400 Bad Request", out.substr(0, 24));
 }