yhirose 1 year ago
parent
commit
9c91b6f4a6
2 changed files with 64 additions and 9 deletions
  1. 18 9
      httplib.h
  2. 46 0
      test/test.cc

+ 18 - 9
httplib.h

@@ -932,6 +932,7 @@ public:
   bool is_running() const;
   bool is_running() const;
   void wait_until_ready() const;
   void wait_until_ready() const;
   void stop();
   void stop();
+  void decommission();
 
 
   std::function<TaskQueue *(void)> new_task_queue;
   std::function<TaskQueue *(void)> new_task_queue;
 
 
@@ -1006,7 +1007,7 @@ private:
   virtual bool process_and_close_socket(socket_t sock);
   virtual bool process_and_close_socket(socket_t sock);
 
 
   std::atomic<bool> is_running_{false};
   std::atomic<bool> is_running_{false};
-  std::atomic<bool> done_{false};
+  std::atomic<bool> is_decommisioned{false};
 
 
   struct MountPointEntry {
   struct MountPointEntry {
     std::string mount_point;
     std::string mount_point;
@@ -6111,27 +6112,27 @@ inline Server &Server::set_payload_max_length(size_t length) {
 
 
 inline bool Server::bind_to_port(const std::string &host, int port,
 inline bool Server::bind_to_port(const std::string &host, int port,
                                  int socket_flags) {
                                  int socket_flags) {
-  return bind_internal(host, port, socket_flags) >= 0;
+  auto ret = bind_internal(host, port, socket_flags);
+  if (ret == -1) { is_decommisioned = true; }
+  return ret >= 0;
 }
 }
 inline int Server::bind_to_any_port(const std::string &host, int socket_flags) {
 inline int Server::bind_to_any_port(const std::string &host, int socket_flags) {
-  return bind_internal(host, 0, socket_flags);
+  auto ret = bind_internal(host, 0, socket_flags);
+  if (ret == -1) { is_decommisioned = true; }
+  return ret;
 }
 }
 
 
-inline bool Server::listen_after_bind() {
-  auto se = detail::scope_exit([&]() { done_ = true; });
-  return listen_internal();
-}
+inline bool Server::listen_after_bind() { return listen_internal(); }
 
 
 inline bool Server::listen(const std::string &host, int port,
 inline bool Server::listen(const std::string &host, int port,
                            int socket_flags) {
                            int socket_flags) {
-  auto se = detail::scope_exit([&]() { done_ = true; });
   return bind_to_port(host, port, socket_flags) && listen_internal();
   return bind_to_port(host, port, socket_flags) && listen_internal();
 }
 }
 
 
 inline bool Server::is_running() const { return is_running_; }
 inline bool Server::is_running() const { return is_running_; }
 
 
 inline void Server::wait_until_ready() const {
 inline void Server::wait_until_ready() const {
-  while (!is_running() && !done_) {
+  while (!is_running_ && !is_decommisioned) {
     std::this_thread::sleep_for(std::chrono::milliseconds{1});
     std::this_thread::sleep_for(std::chrono::milliseconds{1});
   }
   }
 }
 }
@@ -6143,8 +6144,11 @@ inline void Server::stop() {
     detail::shutdown_socket(sock);
     detail::shutdown_socket(sock);
     detail::close_socket(sock);
     detail::close_socket(sock);
   }
   }
+  is_decommisioned = false;
 }
 }
 
 
+inline void Server::decommission() { is_decommisioned = true; }
+
 inline bool Server::parse_request_line(const char *s, Request &req) const {
 inline bool Server::parse_request_line(const char *s, Request &req) const {
   auto len = strlen(s);
   auto len = strlen(s);
   if (len < 2 || s[len - 2] != '\r' || s[len - 1] != '\n') { return false; }
   if (len < 2 || s[len - 2] != '\r' || s[len - 1] != '\n') { return false; }
@@ -6499,6 +6503,8 @@ Server::create_server_socket(const std::string &host, int port,
 
 
 inline int Server::bind_internal(const std::string &host, int port,
 inline int Server::bind_internal(const std::string &host, int port,
                                  int socket_flags) {
                                  int socket_flags) {
+  if (is_decommisioned) { return -1; }
+
   if (!is_valid()) { return -1; }
   if (!is_valid()) { return -1; }
 
 
   svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_);
   svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_);
@@ -6524,6 +6530,8 @@ inline int Server::bind_internal(const std::string &host, int port,
 }
 }
 
 
 inline bool Server::listen_internal() {
 inline bool Server::listen_internal() {
+  if (is_decommisioned) { return false; }
+
   auto ret = true;
   auto ret = true;
   is_running_ = true;
   is_running_ = true;
   auto se = detail::scope_exit([&]() { is_running_ = false; });
   auto se = detail::scope_exit([&]() { is_running_ = false; });
@@ -6613,6 +6621,7 @@ inline bool Server::listen_internal() {
     task_queue->shutdown();
     task_queue->shutdown();
   }
   }
 
 
+  is_decommisioned = !ret;
   return ret;
   return ret;
 }
 }
 
 

+ 46 - 0
test/test.cc

@@ -4926,6 +4926,52 @@ TEST(ServerStopTest, ListenFailure) {
   t.join();
   t.join();
 }
 }
 
 
+TEST(ServerStopTest, Decommision) {
+  Server svr;
+
+  svr.Get("/hi", [&](const Request &, Response &res) { res.body = "hi..."; });
+
+  for (int i = 0; i < 4; i++) {
+    auto is_even = !(i % 2);
+
+    std::thread t{[&] {
+      try {
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+        if (is_even) {
+          throw std::runtime_error("Some thing that happens to go wrong.");
+        }
+
+        svr.listen(HOST, PORT);
+      } catch (...) { svr.decommission(); }
+    }};
+
+    svr.wait_until_ready();
+
+    // Server is up
+    {
+      Client cli(HOST, PORT);
+      auto res = cli.Get("/hi");
+      if (is_even) {
+        EXPECT_FALSE(res);
+      } else {
+        EXPECT_TRUE(res);
+        EXPECT_EQ("hi...", res->body);
+      }
+    }
+
+    svr.stop();
+    t.join();
+
+    // Server is down...
+    {
+      Client cli(HOST, PORT);
+      auto res = cli.Get("/hi");
+      EXPECT_FALSE(res);
+    }
+  }
+}
+
 TEST(StreamingTest, NoContentLengthStreaming) {
 TEST(StreamingTest, NoContentLengthStreaming) {
   Server svr;
   Server svr;