Browse Source

Added middleware support (#816)

yhirose 5 years ago
parent
commit
f008fe4539
2 changed files with 105 additions and 9 deletions
  1. 34 9
      httplib.h
  2. 71 0
      test/test.cc

+ 34 - 9
httplib.h

@@ -597,6 +597,7 @@ inline void default_socket_options(socket_t sock) {
 class Server {
 class Server {
 public:
 public:
   using Handler = std::function<void(const Request &, Response &)>;
   using Handler = std::function<void(const Request &, Response &)>;
+  using HandlerWithReturn = std::function<bool(const Request &, Response &)>;
   using HandlerWithContentReader = std::function<void(
   using HandlerWithContentReader = std::function<void(
       const Request &, Response &, const ContentReader &content_reader)>;
       const Request &, Response &, const ContentReader &content_reader)>;
   using Expect100ContinueHandler =
   using Expect100ContinueHandler =
@@ -627,7 +628,11 @@ public:
                                                const char *mime);
                                                const char *mime);
   void set_file_request_handler(Handler handler);
   void set_file_request_handler(Handler handler);
 
 
+  void set_error_handler(HandlerWithReturn handler);
   void set_error_handler(Handler handler);
   void set_error_handler(Handler handler);
+  void set_pre_routing_handler(HandlerWithReturn handler);
+  void set_post_routing_handler(Handler handler);
+
   void set_expect_100_continue_handler(Expect100ContinueHandler handler);
   void set_expect_100_continue_handler(Expect100ContinueHandler handler);
   void set_logger(Logger logger);
   void set_logger(Logger logger);
 
 
@@ -734,7 +739,9 @@ private:
   Handlers delete_handlers_;
   Handlers delete_handlers_;
   HandlersForContentReader delete_handlers_for_content_reader_;
   HandlersForContentReader delete_handlers_for_content_reader_;
   Handlers options_handlers_;
   Handlers options_handlers_;
-  Handler error_handler_;
+  HandlerWithReturn error_handler_;
+  HandlerWithReturn pre_routing_handler_;
+  Handler post_routing_handler_;
   Logger logger_;
   Logger logger_;
   Expect100ContinueHandler expect_100_continue_handler_;
   Expect100ContinueHandler expect_100_continue_handler_;
 
 
@@ -4160,14 +4167,23 @@ inline void Server::set_file_request_handler(Handler handler) {
   file_request_handler_ = std::move(handler);
   file_request_handler_ = std::move(handler);
 }
 }
 
 
-inline void Server::set_error_handler(Handler handler) {
+inline void Server::set_error_handler(HandlerWithReturn handler) {
   error_handler_ = std::move(handler);
   error_handler_ = std::move(handler);
 }
 }
 
 
-inline void Server::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; }
+inline void Server::set_error_handler(Handler handler) {
+  error_handler_ = [handler](const Request &req, Response &res) {
+    handler(req, res);
+    return true;
+  };
+}
 
 
-inline void Server::set_socket_options(SocketOptions socket_options) {
-  socket_options_ = std::move(socket_options);
+inline void Server::set_pre_routing_handler(HandlerWithReturn handler) {
+  pre_routing_handler_ = std::move(handler);
+}
+
+inline void Server::set_post_routing_handler(Handler handler) {
+  post_routing_handler_ = std::move(handler);
 }
 }
 
 
 inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); }
 inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); }
@@ -4177,6 +4193,12 @@ Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) {
   expect_100_continue_handler_ = std::move(handler);
   expect_100_continue_handler_ = std::move(handler);
 }
 }
 
 
+inline void Server::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; }
+
+inline void Server::set_socket_options(SocketOptions socket_options) {
+  socket_options_ = std::move(socket_options);
+}
+
 inline void Server::set_keep_alive_max_count(size_t count) {
 inline void Server::set_keep_alive_max_count(size_t count) {
   keep_alive_max_count_ = count;
   keep_alive_max_count_ = count;
 }
 }
@@ -4268,8 +4290,7 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
                                         bool need_apply_ranges) {
                                         bool need_apply_ranges) {
   assert(res.status != -1);
   assert(res.status != -1);
 
 
-  if (400 <= res.status && error_handler_) {
-    error_handler_(req, res);
+  if (400 <= res.status && error_handler_ && error_handler_(req, res)) {
     need_apply_ranges = true;
     need_apply_ranges = true;
   }
   }
 
 
@@ -4277,7 +4298,7 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
   std::string boundary;
   std::string boundary;
   if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); }
   if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); }
 
 
-  // Preapre additional headers
+  // Prepare additional headers
   if (close_connection || req.get_header_value("Connection") == "close") {
   if (close_connection || req.get_header_value("Connection") == "close") {
     res.set_header("Connection", "close");
     res.set_header("Connection", "close");
   } else {
   } else {
@@ -4301,6 +4322,8 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
     res.set_header("Accept-Ranges", "bytes");
     res.set_header("Accept-Ranges", "bytes");
   }
   }
 
 
+  if (post_routing_handler_) { post_routing_handler_(req, res); }
+
   // Response line and headers
   // Response line and headers
   {
   {
     detail::BufferStream bstrm;
     detail::BufferStream bstrm;
@@ -4604,6 +4627,8 @@ inline bool Server::listen_internal() {
 }
 }
 
 
 inline bool Server::routing(Request &req, Response &res, Stream &strm) {
 inline bool Server::routing(Request &req, Response &res, Stream &strm) {
+  if (pre_routing_handler_ && pre_routing_handler_(req, res)) { return true; }
+
   // File handler
   // File handler
   bool is_head_request = req.method == "HEAD";
   bool is_head_request = req.method == "HEAD";
   if ((req.method == "GET" || is_head_request) &&
   if ((req.method == "GET" || is_head_request) &&
@@ -5302,7 +5327,7 @@ inline bool ClientImpl::write_content_with_provider(Stream &strm,
 
 
 inline bool ClientImpl::write_request(Stream &strm, const Request &req,
 inline bool ClientImpl::write_request(Stream &strm, const Request &req,
                                       bool close_connection, Error &error) {
                                       bool close_connection, Error &error) {
-  // Prepare additonal headers
+  // Prepare additional headers
   Headers headers;
   Headers headers;
   if (close_connection) { headers.emplace("Connection", "close"); }
   if (close_connection) { headers.emplace("Connection", "close"); }
 
 

+ 71 - 0
test/test.cc

@@ -953,6 +953,77 @@ TEST(ErrorHandlerTest, ContentLength) {
   ASSERT_FALSE(svr.is_running());
   ASSERT_FALSE(svr.is_running());
 }
 }
 
 
+TEST(RoutingHandlerTest, PreRoutingHandler) {
+  Server svr;
+
+  svr.set_pre_routing_handler([](const Request &req, Response &res) {
+    if (req.path == "/routing_handler") {
+      res.set_header("PRE_ROUTING", "on");
+      res.set_content("Routing Handler", "text/plain");
+      return true;
+    }
+    return false;
+  });
+
+  svr.set_error_handler([](const Request & /*req*/, Response &res) {
+    res.set_content("Error", "text/html");
+  });
+
+  svr.set_post_routing_handler([](const Request &req, Response &res) {
+    if (req.path == "/routing_handler") {
+      res.set_header("POST_ROUTING", "on");
+    }
+  });
+
+  svr.Get("/hi", [](const Request & /*req*/, Response &res) {
+    res.set_content("Hello World!\n", "text/plain");
+  });
+
+  auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
+
+  // Give GET time to get a few messages.
+  std::this_thread::sleep_for(std::chrono::seconds(1));
+
+  {
+    Client cli(HOST, PORT);
+
+    auto res = cli.Get("/routing_handler");
+    ASSERT_TRUE(res);
+    EXPECT_EQ(200, res->status);
+    EXPECT_EQ("Routing Handler", res->body);
+    EXPECT_EQ(1, res->get_header_value_count("PRE_ROUTING"));
+    EXPECT_EQ("on", res->get_header_value("PRE_ROUTING"));
+    EXPECT_EQ(1, res->get_header_value_count("POST_ROUTING"));
+    EXPECT_EQ("on", res->get_header_value("POST_ROUTING"));
+  }
+
+  {
+    Client cli(HOST, PORT);
+
+    auto res = cli.Get("/hi");
+    ASSERT_TRUE(res);
+    EXPECT_EQ(200, res->status);
+    EXPECT_EQ("Hello World!\n", res->body);
+    EXPECT_EQ(0, res->get_header_value_count("PRE_ROUTING"));
+    EXPECT_EQ(0, res->get_header_value_count("POST_ROUTING"));
+  }
+
+  {
+    Client cli(HOST, PORT);
+
+    auto res = cli.Get("/aaa");
+    ASSERT_TRUE(res);
+    EXPECT_EQ(404, res->status);
+    EXPECT_EQ("Error", res->body);
+    EXPECT_EQ(0, res->get_header_value_count("PRE_ROUTING"));
+    EXPECT_EQ(0, res->get_header_value_count("POST_ROUTING"));
+  }
+
+  svr.stop();
+  thread.join();
+  ASSERT_FALSE(svr.is_running());
+}
+
 TEST(InvalidFormatTest, StatusCode) {
 TEST(InvalidFormatTest, StatusCode) {
   Server svr;
   Server svr;