2
0
yhirose 6 жил өмнө
parent
commit
2f72845008
2 өөрчлөгдсөн 48 нэмэгдсэн , 25 устгасан
  1. 8 25
      httplib.h
  2. 40 0
      test/test.cc

+ 8 - 25
httplib.h

@@ -1897,24 +1897,20 @@ inline ssize_t write_content(Stream &strm, ContentProvider content_provider,
     };
     };
     data_sink.done = [&](void) { written_length = -1; };
     data_sink.done = [&](void) { written_length = -1; };
 
 
-    content_provider(offset, end_offset - offset,
-                     // [&](const char *d, size_t l) {
-                     //   offset += l;
-                     //   written_length = strm.write(d, l);
-                     // },
-                     // [&](void) { written_length = -1; }
-                     data_sink);
+    content_provider(offset, end_offset - offset, data_sink);
     if (written_length < 0) { return written_length; }
     if (written_length < 0) { return written_length; }
   }
   }
   return static_cast<ssize_t>(offset - begin_offset);
   return static_cast<ssize_t>(offset - begin_offset);
 }
 }
 
 
+template <typename T>
 inline ssize_t write_content_chunked(Stream &strm,
 inline ssize_t write_content_chunked(Stream &strm,
-                                     ContentProvider content_provider) {
+                                     ContentProvider content_provider,
+                                     T is_shutting_down) {
   size_t offset = 0;
   size_t offset = 0;
   auto data_available = true;
   auto data_available = true;
   ssize_t total_written_length = 0;
   ssize_t total_written_length = 0;
-  while (data_available) {
+  while (data_available && !is_shutting_down()) {
     ssize_t written_length = 0;
     ssize_t written_length = 0;
 
 
     DataSink data_sink;
     DataSink data_sink;
@@ -1931,21 +1927,7 @@ inline ssize_t write_content_chunked(Stream &strm,
       written_length = strm.write("0\r\n\r\n");
       written_length = strm.write("0\r\n\r\n");
     };
     };
 
 
-    content_provider(
-        offset, 0,
-        // [&](const char *d, size_t l) {
-        //   data_available = l > 0;
-        //   offset += l;
-        //
-        //   // Emit chunked response header and footer for each chunk
-        //   auto chunk = from_i_to_hex(l) + "\r\n" + std::string(d, l) +
-        //   "\r\n"; written_length = strm.write(chunk);
-        // },
-        // [&](void) {
-        //   data_available = false;
-        //   written_length = strm.write("0\r\n\r\n");
-        // }
-        data_sink);
+    content_provider(offset, 0, data_sink);
 
 
     if (written_length < 0) { return written_length; }
     if (written_length < 0) { return written_length; }
     total_written_length += written_length;
     total_written_length += written_length;
@@ -3088,7 +3070,8 @@ Server::write_content_with_provider(Stream &strm, const Request &req,
       }
       }
     }
     }
   } else {
   } else {
-    if (detail::write_content_chunked(strm, res.content_provider) < 0) {
+    auto is_shutting_down = [this]() { return this->svr_sock_ == INVALID_SOCKET; };
+    if (detail::write_content_chunked(strm, res.content_provider, is_shutting_down) < 0) {
       return false;
       return false;
     }
     }
   }
   }

+ 40 - 0
test/test.cc

@@ -1967,6 +1967,46 @@ TEST(ServerRequestParsingTest, ReadHeadersRegexComplexity) {
   EXPECT_TRUE(listen_thread_ok);
   EXPECT_TRUE(listen_thread_ok);
 }
 }
 
 
+TEST(ServerStopTest, StopServerWithChunkedTransmission) {
+  Server svr;
+
+  svr.Get("/events", [](const Request &req, Response &res) {
+    res.set_header("Content-Type", "text/event-stream");
+    res.set_header("Cache-Control", "no-cache");
+    res.set_chunked_content_provider([](size_t offset, const DataSink &sink) {
+      char buffer[27];
+      int size = sprintf(buffer, "data:%ld\n\n", offset);
+      sink.write(buffer, size);
+      std::this_thread::sleep_for(std::chrono::seconds(1));
+    });
+  });
+
+  auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); });
+  while (!svr.is_running()) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+  }
+
+  Client client(HOST, PORT);
+  const Headers headers = {{"Accept", "text/event-stream"},
+                           {"Connection", "Keep-Alive"}};
+
+  auto get_thread = std::thread([&client, &headers]() {
+    std::shared_ptr<Response> res =
+        client.Get("/events", headers,
+                   [](const char *data, size_t len) -> bool { return true; });
+  });
+
+  // Give GET time to get a few messages.
+  std::this_thread::sleep_for(std::chrono::seconds(2));
+
+  svr.stop();
+
+  listen_thread.join();
+  get_thread.join();
+
+  ASSERT_FALSE(svr.is_running());
+}
+
 class ServerTestWithAI_PASSIVE : public ::testing::Test {
 class ServerTestWithAI_PASSIVE : public ::testing::Test {
 protected:
 protected:
   ServerTestWithAI_PASSIVE()
   ServerTestWithAI_PASSIVE()