Browse Source

Revert "Accept large data transfer over SSL (#1261)"

This reverts commit 307b729549a5243fde63b46e592d04793f1ec73f.
yhirose 3 years ago
parent
commit
dae318495f
2 changed files with 42 additions and 87 deletions
  1. 42 43
      httplib.h
  2. 0 44
      test/test.cc

+ 42 - 43
httplib.h

@@ -7228,63 +7228,62 @@ inline bool SSLSocketStream::is_writable() const {
 }
 
 inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
-  size_t readbytes = 0;
   if (SSL_pending(ssl_) > 0) {
-    auto ret = SSL_read_ex(ssl_, ptr, size, &readbytes);
-    if (ret == 1) { return static_cast<ssize_t>(readbytes); }
-    if (SSL_get_error(ssl_, ret) == SSL_ERROR_ZERO_RETURN) { return 0; }
-    return -1;
-  }
-  if (!is_readable()) { return -1; }
-
-  auto ret = SSL_read_ex(ssl_, ptr, size, &readbytes);
-  if (ret == 1) { return static_cast<ssize_t>(readbytes); }
-  auto err = SSL_get_error(ssl_, ret);
-  int n = 1000;
+    return SSL_read(ssl_, ptr, static_cast<int>(size));
+  } else if (is_readable()) {
+    auto ret = SSL_read(ssl_, ptr, static_cast<int>(size));
+    if (ret < 0) {
+      auto err = SSL_get_error(ssl_, ret);
+      int n = 1000;
 #ifdef _WIN32
-  while (--n >= 0 &&
-         (err == SSL_ERROR_WANT_READ ||
-          (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) {
+      while (--n >= 0 && (err == SSL_ERROR_WANT_READ ||
+                          (err == SSL_ERROR_SYSCALL &&
+                           WSAGetLastError() == WSAETIMEDOUT))) {
 #else
-  while (--n >= 0 && err == SSL_ERROR_WANT_READ) {
+      while (--n >= 0 && err == SSL_ERROR_WANT_READ) {
 #endif
-    if (SSL_pending(ssl_) > 0) {
-      ret = SSL_read_ex(ssl_, ptr, size, &readbytes);
-      if (ret == 1) { return static_cast<ssize_t>(readbytes); }
-      if (SSL_get_error(ssl_, ret) == SSL_ERROR_ZERO_RETURN) { return 0; }
-      return -1;
+        if (SSL_pending(ssl_) > 0) {
+          return SSL_read(ssl_, ptr, static_cast<int>(size));
+        } else if (is_readable()) {
+          std::this_thread::sleep_for(std::chrono::milliseconds(1));
+          ret = SSL_read(ssl_, ptr, static_cast<int>(size));
+          if (ret >= 0) { return ret; }
+          err = SSL_get_error(ssl_, ret);
+        } else {
+          return -1;
+        }
+      }
     }
-    if (!is_readable()) { return -1; }
-    std::this_thread::sleep_for(std::chrono::milliseconds(1));
-    ret = SSL_read_ex(ssl_, ptr, size, &readbytes);
-    if (ret == 1) { return static_cast<ssize_t>(readbytes); }
-    err = SSL_get_error(ssl_, ret);
+    return ret;
   }
-  if (err == SSL_ERROR_ZERO_RETURN) { return 0; }
   return -1;
 }
 
 inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
-  if (!is_writable()) { return -1; }
-  size_t written = 0;
-  auto ret = SSL_write_ex(ssl_, ptr, size, &written);
-  if (ret == 1) { return static_cast<ssize_t>(written); }
-  auto err = SSL_get_error(ssl_, ret);
-  int n = 1000;
+  if (is_writable()) {
+    auto ret = SSL_write(ssl_, ptr, static_cast<int>(size));
+    if (ret < 0) {
+      auto err = SSL_get_error(ssl_, ret);
+      int n = 1000;
 #ifdef _WIN32
-  while (--n >= 0 &&
-         (err == SSL_ERROR_WANT_WRITE ||
-          (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) {
+      while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE ||
+                          (err == SSL_ERROR_SYSCALL &&
+                           WSAGetLastError() == WSAETIMEDOUT))) {
 #else
-  while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) {
+      while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) {
 #endif
-    if (!is_writable()) { return -1; }
-    std::this_thread::sleep_for(std::chrono::milliseconds(1));
-    ret = SSL_write_ex(ssl_, ptr, size, &written);
-    if (ret == 1) { return static_cast<ssize_t>(written); }
-    err = SSL_get_error(ssl_, ret);
+        if (is_writable()) {
+          std::this_thread::sleep_for(std::chrono::milliseconds(1));
+          ret = SSL_write(ssl_, ptr, static_cast<int>(size));
+          if (ret >= 0) { return ret; }
+          err = SSL_get_error(ssl_, ret);
+        } else {
+          return -1;
+        }
+      }
+    }
+    return ret;
   }
-  if (err == SSL_ERROR_ZERO_RETURN) { return 0; }
   return -1;
 }
 

+ 0 - 44
test/test.cc

@@ -4660,50 +4660,6 @@ TEST(SSLClientServerTest, CustomizeServerSSLCtx) {
 
   t.join();
 }
-
-// Disabled due to the out-of-memory problem on GitHub Actions Workflows
-TEST(SSLClientServerTest, DISABLED_LargeDataTransfer) {
-
-  // prepare large data
-  std::random_device seed_gen;
-  std::mt19937 random(seed_gen());
-  constexpr auto large_size_byte = 2147483648UL + 1048576UL; // 2GiB + 1MiB
-  std::vector<std::uint32_t> binary(large_size_byte / sizeof(std::uint32_t));
-  std::generate(binary.begin(), binary.end(), [&random]() { return random(); });
-
-  // server
-  SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);
-  ASSERT_TRUE(svr.is_valid());
-
-  svr.Post("/binary", [&](const Request &req, Response &res) {
-    EXPECT_EQ(large_size_byte, req.body.size());
-    EXPECT_EQ(0, std::memcmp(binary.data(), req.body.data(), large_size_byte));
-    res.set_content(req.body, "application/octet-stream");
-  });
-
-  auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); });
-  while (!svr.is_running()) {
-    std::this_thread::sleep_for(std::chrono::milliseconds(1));
-  }
-
-  // client POST
-  SSLClient cli("localhost", PORT);
-  cli.enable_server_certificate_verification(false);
-  cli.set_read_timeout(std::chrono::seconds(100));
-  cli.set_write_timeout(std::chrono::seconds(100));
-  auto res = cli.Post("/binary", reinterpret_cast<char *>(binary.data()),
-                      large_size_byte, "application/octet-stream");
-
-  // compare
-  EXPECT_EQ(200, res->status);
-  EXPECT_EQ(large_size_byte, res->body.size());
-  EXPECT_EQ(0, std::memcmp(binary.data(), res->body.data(), large_size_byte));
-
-  // cleanup
-  svr.stop();
-  listen_thread.join();
-  ASSERT_FALSE(svr.is_running());
-}
 #endif
 
 #ifdef _WIN32