Browse Source

Thread pool support

yhirose 6 years ago
parent
commit
9785cd47f2
1 changed files with 104 additions and 14 deletions
  1. 104 14
      httplib.h

+ 104 - 14
httplib.h

@@ -61,12 +61,14 @@ typedef int socket_t;
 
 
 #include <assert.h>
 #include <assert.h>
 #include <atomic>
 #include <atomic>
+#include <condition_variable>
 #include <fcntl.h>
 #include <fcntl.h>
 #include <fstream>
 #include <fstream>
 #include <functional>
 #include <functional>
 #include <map>
 #include <map>
 #include <memory>
 #include <memory>
 #include <mutex>
 #include <mutex>
+#include <list>
 #include <random>
 #include <random>
 #include <regex>
 #include <regex>
 #include <string>
 #include <string>
@@ -101,6 +103,7 @@ inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) {
 #define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192
 #define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192
 #define CPPHTTPLIB_PAYLOAD_MAX_LENGTH (std::numeric_limits<size_t>::max)()
 #define CPPHTTPLIB_PAYLOAD_MAX_LENGTH (std::numeric_limits<size_t>::max)()
 #define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u)
 #define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u)
+#define CPPHTTPLIB_THREAD_POOL_COUNT 8
 
 
 namespace httplib {
 namespace httplib {
 
 
@@ -269,16 +272,101 @@ class TaskQueue {
 public:
 public:
   TaskQueue() {}
   TaskQueue() {}
   virtual ~TaskQueue() {}
   virtual ~TaskQueue() {}
-  virtual void enque(std::function<void(void)> fn) = 0;
+  virtual void enqueue(std::function<void(void)> fn) = 0;
   virtual void shutdown() = 0;
   virtual void shutdown() = 0;
 };
 };
 
 
-class ThreadsTaskQueue : public TaskQueue {
+#if CPPHTTPLIB_THREAD_POOL_COUNT > 0
+class ThreadPool : public TaskQueue {
 public:
 public:
-  ThreadsTaskQueue() : running_threads_(0) {}
-  virtual ~ThreadsTaskQueue() {}
+  ThreadPool(size_t n) : shutdown_(false), remaining_(0) {
+    while (n) {
+      auto t = std::make_shared<std::thread>(worker(*this));
+      threads_.push_back(t);
+      n--;
+    }
+  }
+
+  ThreadPool(const ThreadPool &) = delete;
+  virtual ~ThreadPool() {}
+
+  virtual void enqueue(std::function<void()> fn) override {
+    std::unique_lock<std::mutex> lock(mutex_);
+    jobs_.push_back(fn);
+    cond_.notify_one();
+  }
+
+  virtual void shutdown() override {
+    // Handle all remaining jobs...
+    for (;;) {
+      std::unique_lock<std::mutex> lock(mutex_);
+      if (jobs_.empty()) break;
+      cond_.notify_one();
+    }
+
+    // Stop all worker threads...
+    {
+      std::unique_lock<std::mutex> lock(mutex_);
+      shutdown_ = true;
+      remaining_ = threads_.size();
+    }
+
+    for (;;) {
+      std::unique_lock<std::mutex> lock(mutex_);
+      if (!remaining_) break;
+      cond_.notify_all();
+    }
+
+    // Join...
+    for (auto t : threads_) {
+      t->join();
+    }
+  }
+
+private:
+  struct worker {
+    worker(ThreadPool &pool) : pool_(pool) {}
+
+    void operator()() {
+      for (;;) {
+        std::unique_lock<std::mutex> lock(pool_.mutex_);
+
+        pool_.cond_.wait(
+            lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; });
 
 
-  virtual void enque(std::function<void(void)> fn) override {
+        if (pool_.shutdown_) { break; }
+
+        auto fn = pool_.jobs_.front();
+        pool_.jobs_.pop_front();
+
+        assert(true == (bool)fn);
+        fn();
+      }
+
+      std::unique_lock<std::mutex> lock(pool_.mutex_);
+      pool_.remaining_--;
+    }
+
+    ThreadPool &pool_;
+  };
+  friend struct worker;
+
+  std::vector<std::shared_ptr<std::thread>> threads_;
+  std::list<std::function<void()>> jobs_;
+
+  bool shutdown_;
+  size_t remaining_;
+
+  std::condition_variable cond_;
+  std::mutex mutex_;
+};
+#else
+class Threads : public TaskQueue {
+public:
+  Threads() : running_threads_(0) {}
+  virtual ~Threads() {}
+
+  virtual void enqueue(std::function<void(void)> fn) override {
     std::thread([=]() {
     std::thread([=]() {
       {
       {
         std::lock_guard<std::mutex> guard(running_threads_mutex_);
         std::lock_guard<std::mutex> guard(running_threads_mutex_);
@@ -306,6 +394,7 @@ private:
   std::mutex running_threads_mutex_;
   std::mutex running_threads_mutex_;
   int running_threads_;
   int running_threads_;
 };
 };
+#endif
 
 
 class Server {
 class Server {
 public:
 public:
@@ -342,7 +431,7 @@ public:
   bool is_running() const;
   bool is_running() const;
   void stop();
   void stop();
 
 
-  std::function<TaskQueue*(void)> new_task_queue;
+  std::function<TaskQueue *(void)> new_task_queue;
 
 
 protected:
 protected:
   bool process_request(Stream &strm, bool last_connection,
   bool process_request(Stream &strm, bool last_connection,
@@ -1934,6 +2023,13 @@ inline Server::Server()
 #ifndef _WIN32
 #ifndef _WIN32
   signal(SIGPIPE, SIG_IGN);
   signal(SIGPIPE, SIG_IGN);
 #endif
 #endif
+  new_task_queue = [] {
+#if CPPHTTPLIB_THREAD_POOL_COUNT > 0
+    return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT);
+#else
+    return new Threads();
+#endif
+  };
 }
 }
 
 
 inline Server::~Server() {}
 inline Server::~Server() {}
@@ -2235,13 +2331,7 @@ inline bool Server::listen_internal() {
   is_running_ = true;
   is_running_ = true;
 
 
   {
   {
-    std::unique_ptr<TaskQueue> task_queue;
-
-    if (new_task_queue) {
-      task_queue.reset(new_task_queue());
-    } else {
-      task_queue.reset(new ThreadsTaskQueue());
-    }
+    std::unique_ptr<TaskQueue> task_queue(new_task_queue());
 
 
     for (;;) {
     for (;;) {
       if (svr_sock_ == INVALID_SOCKET) {
       if (svr_sock_ == INVALID_SOCKET) {
@@ -2267,7 +2357,7 @@ inline bool Server::listen_internal() {
         break;
         break;
       }
       }
 
 
-      task_queue->enque([=]() { read_and_close_socket(sock); });
+      task_queue->enqueue([=]() { read_and_close_socket(sock); });
     }
     }
 
 
     task_queue->shutdown();
     task_queue->shutdown();