Browse Source

Accept large data transfer over SSL (Fix #1261, Close #1312)

yhirose 3 years ago
parent
commit
caa31aafda
2 changed files with 56 additions and 6 deletions
  1. 12 6
      httplib.h
  2. 44 0
      test/test.cc

+ 12 - 6
httplib.h

@@ -193,7 +193,6 @@ using socket_t = int;
 #endif
 #endif
 #endif //_WIN32
 #endif //_WIN32
 
 
-#include <cstring>
 #include <algorithm>
 #include <algorithm>
 #include <array>
 #include <array>
 #include <atomic>
 #include <atomic>
@@ -201,6 +200,7 @@ using socket_t = int;
 #include <cctype>
 #include <cctype>
 #include <climits>
 #include <climits>
 #include <condition_variable>
 #include <condition_variable>
+#include <cstring>
 #include <errno.h>
 #include <errno.h>
 #include <fcntl.h>
 #include <fcntl.h>
 #include <fstream>
 #include <fstream>
@@ -5098,14 +5098,16 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
 
 
     // Flush buffer
     // Flush buffer
     auto &data = bstrm.get_buffer();
     auto &data = bstrm.get_buffer();
-    strm.write(data.data(), data.size());
+    detail::write_data(strm, data.data(), data.size());
   }
   }
 
 
   // Body
   // Body
   auto ret = true;
   auto ret = true;
   if (req.method != "HEAD") {
   if (req.method != "HEAD") {
     if (!res.body.empty()) {
     if (!res.body.empty()) {
-      if (!strm.write(res.body)) { ret = false; }
+      if (!detail::write_data(strm, res.body.data(), res.body.size())) {
+        ret = false;
+      }
     } else if (res.content_provider_) {
     } else if (res.content_provider_) {
       if (write_content_with_provider(strm, req, res, boundary, content_type)) {
       if (write_content_with_provider(strm, req, res, boundary, content_type)) {
         res.content_provider_success_ = true;
         res.content_provider_success_ = true;
@@ -6322,7 +6324,8 @@ inline std::unique_ptr<Response> ClientImpl::send_with_content_provider(
           auto last = offset + data_len == content_length;
           auto last = offset + data_len == content_length;
 
 
           auto ret = compressor.compress(
           auto ret = compressor.compress(
-              data, data_len, last, [&](const char *compressed_data, size_t compressed_data_len) {
+              data, data_len, last,
+              [&](const char *compressed_data, size_t compressed_data_len) {
                 req.body.append(compressed_data, compressed_data_len);
                 req.body.append(compressed_data, compressed_data_len);
                 return true;
                 return true;
               });
               });
@@ -7261,7 +7264,10 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
 
 
 inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
 inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
   if (is_writable()) {
   if (is_writable()) {
-    auto ret = SSL_write(ssl_, ptr, static_cast<int>(size));
+    auto handle_size = static_cast<int>(
+        std::min<size_t>(size, std::numeric_limits<int>::max()));
+
+    auto ret = SSL_write(ssl_, ptr, static_cast<int>(handle_size));
     if (ret < 0) {
     if (ret < 0) {
       auto err = SSL_get_error(ssl_, ret);
       auto err = SSL_get_error(ssl_, ret);
       int n = 1000;
       int n = 1000;
@@ -7274,7 +7280,7 @@ inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
 #endif
 #endif
         if (is_writable()) {
         if (is_writable()) {
           std::this_thread::sleep_for(std::chrono::milliseconds(1));
           std::this_thread::sleep_for(std::chrono::milliseconds(1));
-          ret = SSL_write(ssl_, ptr, static_cast<int>(size));
+          ret = SSL_write(ssl_, ptr, static_cast<int>(handle_size));
           if (ret >= 0) { return ret; }
           if (ret >= 0) { return ret; }
           err = SSL_get_error(ssl_, ret);
           err = SSL_get_error(ssl_, ret);
         } else {
         } else {

+ 44 - 0
test/test.cc

@@ -4660,6 +4660,50 @@ TEST(SSLClientServerTest, CustomizeServerSSLCtx) {
 
 
   t.join();
   t.join();
 }
 }
+
+// Disabled due to the out-of-memory problem on GitHub Actions Workflows
+TEST(SSLClientServerTest, DISABLED_LargeDataTransfer) {
+
+  // prepare large data
+  std::random_device seed_gen;
+  std::mt19937 random(seed_gen());
+  constexpr auto large_size_byte = 2147483648UL + 1048576UL; // 2GiB + 1MiB
+  std::vector<std::uint32_t> binary(large_size_byte / sizeof(std::uint32_t));
+  std::generate(binary.begin(), binary.end(), [&random]() { return random(); });
+
+  // server
+  SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);
+  ASSERT_TRUE(svr.is_valid());
+
+  svr.Post("/binary", [&](const Request &req, Response &res) {
+    EXPECT_EQ(large_size_byte, req.body.size());
+    EXPECT_EQ(0, std::memcmp(binary.data(), req.body.data(), large_size_byte));
+    res.set_content(req.body, "application/octet-stream");
+  });
+
+  auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); });
+  while (!svr.is_running()) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+  }
+
+  // client POST
+  SSLClient cli("localhost", PORT);
+  cli.enable_server_certificate_verification(false);
+  cli.set_read_timeout(std::chrono::seconds(100));
+  cli.set_write_timeout(std::chrono::seconds(100));
+  auto res = cli.Post("/binary", reinterpret_cast<char *>(binary.data()),
+                      large_size_byte, "application/octet-stream");
+
+  // compare
+  EXPECT_EQ(200, res->status);
+  EXPECT_EQ(large_size_byte, res->body.size());
+  EXPECT_EQ(0, std::memcmp(binary.data(), res->body.data(), large_size_byte));
+
+  // cleanup
+  svr.stop();
+  listen_thread.join();
+  ASSERT_FALSE(svr.is_running());
+}
 #endif
 #endif
 
 
 #ifdef _WIN32
 #ifdef _WIN32