yhirose 5 years ago
parent
commit
f086bf5310
2 changed files with 69 additions and 34 deletions
  1. 12 15
      httplib.h
  2. 57 19
      test/test.cc

+ 12 - 15
httplib.h

@@ -692,7 +692,7 @@ private:
                                    const Request &req, Response &res);
   bool write_response_core(Stream &strm, bool close_connection,
                            const Request &req, Response &res,
-                           std::string &content_type, std::string &boundary);
+                           bool need_apply_ranges);
   bool write_content_with_provider(Stream &strm, const Request &req,
                                    Response &res, const std::string &boundary,
                                    const std::string &content_type);
@@ -3769,7 +3769,8 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) {
   }
   return recv(sock_, ptr, static_cast<int>(size), CPPHTTPLIB_RECV_FLAGS);
 #else
-  return handle_EINTR([&]() { return recv(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); });
+  return handle_EINTR(
+      [&]() { return recv(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); });
 #endif
 }
 
@@ -3782,7 +3783,8 @@ inline ssize_t SocketStream::write(const char *ptr, size_t size) {
   }
   return send(sock_, ptr, static_cast<int>(size), CPPHTTPLIB_SEND_FLAGS);
 #else
-  return handle_EINTR([&]() { return send(sock_, ptr, size, CPPHTTPLIB_SEND_FLAGS); });
+  return handle_EINTR(
+      [&]() { return send(sock_, ptr, size, CPPHTTPLIB_SEND_FLAGS); });
 #endif
 }
 
@@ -4022,32 +4024,27 @@ inline bool Server::parse_request_line(const char *s, Request &req) {
 
 inline bool Server::write_response(Stream &strm, bool close_connection,
                                    const Request &req, Response &res) {
-  std::string content_type;
-  std::string boundary;
-  return write_response_core(strm, close_connection, req, res, content_type,
-                             boundary);
+  return write_response_core(strm, close_connection, req, res, false);
 }
 
 inline bool Server::write_response_with_content(Stream &strm,
                                                 bool close_connection,
                                                 const Request &req,
                                                 Response &res) {
-  std::string content_type;
-  std::string boundary;
-  apply_ranges(req, res, content_type, boundary);
-
-  return write_response_core(strm, close_connection, req, res, content_type,
-                             boundary);
+  return write_response_core(strm, close_connection, req, res, true);
 }
 
 inline bool Server::write_response_core(Stream &strm, bool close_connection,
                                         const Request &req, Response &res,
-                                        std::string &content_type,
-                                        std::string &boundary) {
+                                        bool need_apply_ranges) {
   assert(res.status != -1);
 
   if (400 <= res.status && error_handler_) { error_handler_(req, res); }
 
+  std::string content_type;
+  std::string boundary;
+  if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); }
+
   // Headers
   if (close_connection || req.get_header_value("Connection") == "close") {
     res.set_header("Connection", "close");

+ 57 - 19
test/test.cc

@@ -841,7 +841,7 @@ TEST(UrlWithSpace, Redirect) {
 }
 #endif
 
-TEST(Server, BindDualStack) {
+TEST(BindServerTest, BindDualStack) {
   Server svr;
 
   svr.Get("/1", [&](const Request & /*req*/, Response &res) {
@@ -874,7 +874,7 @@ TEST(Server, BindDualStack) {
   ASSERT_FALSE(svr.is_running());
 }
 
-TEST(Server, BindAndListenSeparately) {
+TEST(BindServerTest, BindAndListenSeparately) {
   Server svr;
   int port = svr.bind_to_any_port("0.0.0.0");
   ASSERT_TRUE(svr.is_valid());
@@ -883,7 +883,7 @@ TEST(Server, BindAndListenSeparately) {
 }
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
-TEST(SSLServer, BindAndListenSeparately) {
+TEST(BindServerTest, BindAndListenSeparatelySSL) {
   SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE, CLIENT_CA_CERT_FILE,
                 CLIENT_CA_CERT_DIR);
   int port = svr.bind_to_any_port("0.0.0.0");
@@ -893,6 +893,41 @@ TEST(SSLServer, BindAndListenSeparately) {
 }
 #endif
 
+TEST(ErrorHandlerTest, ContentLength) {
+  Server svr;
+
+  svr.set_error_handler([](const Request & /*req*/, Response &res) {
+    res.status = 200;
+    res.set_content("abcdefghijklmnopqrstuvwxyz",
+                    "text/html"); // <= Content-Length still 13
+  });
+
+  svr.Get("/hi", [](const Request & /*req*/, Response &res) {
+    res.set_content("Hello World!\n", "text/plain");
+    res.status = 524;
+  });
+
+  auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
+
+  // Give GET time to get a few messages.
+  std::this_thread::sleep_for(std::chrono::seconds(1));
+
+  {
+    Client cli(HOST, PORT);
+
+    auto res = cli.Get("/hi");
+    ASSERT_TRUE(res);
+    EXPECT_EQ(200, res->status);
+    EXPECT_EQ("text/html", res->get_header_value("Content-Type"));
+    EXPECT_EQ("26", res->get_header_value("Content-Length"));
+    EXPECT_EQ("abcdefghijklmnopqrstuvwxyz", res->body);
+  }
+
+  svr.stop();
+  thread.join();
+  ASSERT_FALSE(svr.is_running());
+}
+
 class ServerTest : public ::testing::Test {
 protected:
   ServerTest()
@@ -3473,24 +3508,27 @@ TEST(SSLClientServerTest, TrustDirOptional) {
 
 TEST(SSLClientServerTest, SSLConnectTimeout) {
   class NoListenSSLServer : public SSLServer {
-    public:
-      NoListenSSLServer(const char *cert_path, const char *private_key_path, const char *client_ca_cert_file_path,
-                        const char *client_ca_cert_dir_path = nullptr)
-        : SSLServer(cert_path, private_key_path, client_ca_cert_file_path, client_ca_cert_dir_path)
-        , stop_(false)
-      {}
-
-      bool stop_;
-    private:
-      bool process_and_close_socket(socket_t /*sock*/) override {
-        // Don't create SSL context
-        while (!stop_) {
-           std::this_thread::sleep_for(std::chrono::milliseconds(100));
-        }
-        return true;
+  public:
+    NoListenSSLServer(const char *cert_path, const char *private_key_path,
+                      const char *client_ca_cert_file_path,
+                      const char *client_ca_cert_dir_path = nullptr)
+        : SSLServer(cert_path, private_key_path, client_ca_cert_file_path,
+                    client_ca_cert_dir_path),
+          stop_(false) {}
+
+    bool stop_;
+
+  private:
+    bool process_and_close_socket(socket_t /*sock*/) override {
+      // Don't create SSL context
+      while (!stop_) {
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
       }
+      return true;
+    }
   };
-  NoListenSSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE, CLIENT_CA_CERT_FILE);
+  NoListenSSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE,
+                        CLIENT_CA_CERT_FILE);
   ASSERT_TRUE(svr.is_valid());
 
   svr.Get("/test", [&](const Request &, Response &res) {