Browse Source

Accept large data transfer over SSL (#1261)

* Add large data transfer test

* Replace `SSL_read` and `SSL_write` with `ex` functions

* Reflect review comment

* Fix return value of `SSLSocketStream::read/write`

* Fix return value in the case of `SSL_ERROR_ZERO_RETURN`

* Disable `LargeDataTransfer` test due to OoM in CI
Yoshiki Matsuda 3 years ago
parent
commit
307b729549
2 changed files with 87 additions and 42 deletions
  1. 43 42
      httplib.h
  2. 44 0
      test/test.cc

+ 43 - 42
httplib.h

@@ -7221,62 +7221,63 @@ inline bool SSLSocketStream::is_writable() const {
 }
 }
 
 
 inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
 inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
+  size_t readbytes = 0;
   if (SSL_pending(ssl_) > 0) {
   if (SSL_pending(ssl_) > 0) {
-    return SSL_read(ssl_, ptr, static_cast<int>(size));
-  } else if (is_readable()) {
-    auto ret = SSL_read(ssl_, ptr, static_cast<int>(size));
-    if (ret < 0) {
-      auto err = SSL_get_error(ssl_, ret);
-      int n = 1000;
+    auto ret = SSL_read_ex(ssl_, ptr, size, &readbytes);
+    if (ret == 1) { return static_cast<ssize_t>(readbytes); }
+    if (SSL_get_error(ssl_, ret) == SSL_ERROR_ZERO_RETURN) { return 0; }
+    return -1;
+  }
+  if (!is_readable()) { return -1; }
+
+  auto ret = SSL_read_ex(ssl_, ptr, size, &readbytes);
+  if (ret == 1) { return static_cast<ssize_t>(readbytes); }
+  auto err = SSL_get_error(ssl_, ret);
+  int n = 1000;
 #ifdef _WIN32
 #ifdef _WIN32
-      while (--n >= 0 && (err == SSL_ERROR_WANT_READ ||
-                          (err == SSL_ERROR_SYSCALL &&
-                           WSAGetLastError() == WSAETIMEDOUT))) {
+  while (--n >= 0 &&
+         (err == SSL_ERROR_WANT_READ ||
+          (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) {
 #else
 #else
-      while (--n >= 0 && err == SSL_ERROR_WANT_READ) {
+  while (--n >= 0 && err == SSL_ERROR_WANT_READ) {
 #endif
 #endif
-        if (SSL_pending(ssl_) > 0) {
-          return SSL_read(ssl_, ptr, static_cast<int>(size));
-        } else if (is_readable()) {
-          std::this_thread::sleep_for(std::chrono::milliseconds(1));
-          ret = SSL_read(ssl_, ptr, static_cast<int>(size));
-          if (ret >= 0) { return ret; }
-          err = SSL_get_error(ssl_, ret);
-        } else {
-          return -1;
-        }
-      }
+    if (SSL_pending(ssl_) > 0) {
+      ret = SSL_read_ex(ssl_, ptr, size, &readbytes);
+      if (ret == 1) { return static_cast<ssize_t>(readbytes); }
+      if (SSL_get_error(ssl_, ret) == SSL_ERROR_ZERO_RETURN) { return 0; }
+      return -1;
     }
     }
-    return ret;
+    if (!is_readable()) { return -1; }
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+    ret = SSL_read_ex(ssl_, ptr, size, &readbytes);
+    if (ret == 1) { return static_cast<ssize_t>(readbytes); }
+    err = SSL_get_error(ssl_, ret);
   }
   }
+  if (err == SSL_ERROR_ZERO_RETURN) { return 0; }
   return -1;
   return -1;
 }
 }
 
 
 inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
 inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
-  if (is_writable()) {
-    auto ret = SSL_write(ssl_, ptr, static_cast<int>(size));
-    if (ret < 0) {
-      auto err = SSL_get_error(ssl_, ret);
-      int n = 1000;
+  if (!is_writable()) { return -1; }
+  size_t written = 0;
+  auto ret = SSL_write_ex(ssl_, ptr, size, &written);
+  if (ret == 1) { return static_cast<ssize_t>(written); }
+  auto err = SSL_get_error(ssl_, ret);
+  int n = 1000;
 #ifdef _WIN32
 #ifdef _WIN32
-      while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE ||
-                          (err == SSL_ERROR_SYSCALL &&
-                           WSAGetLastError() == WSAETIMEDOUT))) {
+  while (--n >= 0 &&
+         (err == SSL_ERROR_WANT_WRITE ||
+          (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) {
 #else
 #else
-      while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) {
+  while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) {
 #endif
 #endif
-        if (is_writable()) {
-          std::this_thread::sleep_for(std::chrono::milliseconds(1));
-          ret = SSL_write(ssl_, ptr, static_cast<int>(size));
-          if (ret >= 0) { return ret; }
-          err = SSL_get_error(ssl_, ret);
-        } else {
-          return -1;
-        }
-      }
-    }
-    return ret;
+    if (!is_writable()) { return -1; }
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+    ret = SSL_write_ex(ssl_, ptr, size, &written);
+    if (ret == 1) { return static_cast<ssize_t>(written); }
+    err = SSL_get_error(ssl_, ret);
   }
   }
+  if (err == SSL_ERROR_ZERO_RETURN) { return 0; }
   return -1;
   return -1;
 }
 }
 
 

+ 44 - 0
test/test.cc

@@ -4660,6 +4660,50 @@ TEST(SSLClientServerTest, CustomizeServerSSLCtx) {
 
 
   t.join();
   t.join();
 }
 }
+
+// Disabled due to the out-of-memory problem on GitHub Actions Workflows
+TEST(SSLClientServerTest, DISABLED_LargeDataTransfer) {
+
+  // prepare large data
+  std::random_device seed_gen;
+  std::mt19937 random(seed_gen());
+  constexpr auto large_size_byte = 2147483648UL + 1048576UL; // 2GiB + 1MiB
+  std::vector<std::uint32_t> binary(large_size_byte / sizeof(std::uint32_t));
+  std::generate(binary.begin(), binary.end(), [&random]() { return random(); });
+
+  // server
+  SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);
+  ASSERT_TRUE(svr.is_valid());
+
+  svr.Post("/binary", [&](const Request &req, Response &res) {
+    EXPECT_EQ(large_size_byte, req.body.size());
+    EXPECT_EQ(0, std::memcmp(binary.data(), req.body.data(), large_size_byte));
+    res.set_content(req.body, "application/octet-stream");
+  });
+
+  auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); });
+  while (!svr.is_running()) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+  }
+
+  // client POST
+  SSLClient cli("localhost", PORT);
+  cli.enable_server_certificate_verification(false);
+  cli.set_read_timeout(std::chrono::seconds(100));
+  cli.set_write_timeout(std::chrono::seconds(100));
+  auto res = cli.Post("/binary", reinterpret_cast<char *>(binary.data()),
+                      large_size_byte, "application/octet-stream");
+
+  // compare
+  EXPECT_EQ(200, res->status);
+  EXPECT_EQ(large_size_byte, res->body.size());
+  EXPECT_EQ(0, std::memcmp(binary.data(), res->body.data(), large_size_byte));
+
+  // cleanup
+  svr.stop();
+  listen_thread.join();
+  ASSERT_FALSE(svr.is_running());
+}
 #endif
 #endif
 
 
 #ifdef _WIN32
 #ifdef _WIN32