Browse Source

Add exception handler (#845)

* Add exception handler

* revert content reader changes

* Add test for and fix exception handler

* Fix warning in test

* Readd exception test, improve readme note, don't rethrow errors, remove exception handler response
Nikolas 4 years ago
parent
commit
0542fdb8e4
3 changed files with 90 additions and 23 deletions
  1. 13 0
      README.md
  2. 41 23
      httplib.h
  3. 36 0
      test/test.cc

+ 13 - 0
README.md

@@ -177,6 +177,19 @@ svr.set_error_handler([](const auto& req, auto& res) {
 });
 ```
 
+### Exception handler
+The exception handler gets called if a user routing handler throws an error.
+
+```cpp
+svr.set_exception_handler([](const auto& req, auto& res, std::exception &e) {
+  res.status = 500;
+  auto fmt = "<h1>Error 500</h1><p>%s</p>";
+  char buf[BUFSIZ];
+  snprintf(buf, sizeof(buf), fmt, e.what());
+  res.set_content(buf, "text/html");
+});
+```
+
 ### Pre routing handler
 
 ```cpp

+ 41 - 23
httplib.h

@@ -598,6 +598,9 @@ class Server {
 public:
   using Handler = std::function<void(const Request &, Response &)>;
 
+  using ExceptionHandler =
+      std::function<void(const Request &, Response &, std::exception &e)>;
+
   enum class HandlerResponse {
     Handled,
     Unhandled,
@@ -652,6 +655,7 @@ public:
 
   Server &set_error_handler(HandlerWithResponse handler);
   Server &set_error_handler(Handler handler);
+  Server &set_exception_handler(ExceptionHandler handler);
   Server &set_pre_routing_handler(HandlerWithResponse handler);
   Server &set_post_routing_handler(Handler handler);
 
@@ -762,6 +766,7 @@ private:
   HandlersForContentReader delete_handlers_for_content_reader_;
   Handlers options_handlers_;
   HandlerWithResponse error_handler_;
+  ExceptionHandler exception_handler_;
   HandlerWithResponse pre_routing_handler_;
   Handler post_routing_handler_;
   Logger logger_;
@@ -4281,6 +4286,11 @@ inline Server &Server::set_error_handler(Handler handler) {
   return *this;
 }
 
+inline Server &Server::set_exception_handler(ExceptionHandler handler) {
+  exception_handler_ = std::move(handler);
+  return *this;
+}
+
 inline Server &Server::set_pre_routing_handler(HandlerWithResponse handler) {
   pre_routing_handler_ = std::move(handler);
   return *this;
@@ -4785,26 +4795,26 @@ inline bool Server::routing(Request &req, Response &res, Stream &strm) {
 
       if (req.method == "POST") {
         if (dispatch_request_for_content_reader(
-                req, res, std::move(reader),
-                post_handlers_for_content_reader_)) {
+            req, res, std::move(reader),
+            post_handlers_for_content_reader_)) {
           return true;
         }
       } else if (req.method == "PUT") {
         if (dispatch_request_for_content_reader(
-                req, res, std::move(reader),
-                put_handlers_for_content_reader_)) {
+            req, res, std::move(reader),
+            put_handlers_for_content_reader_)) {
           return true;
         }
       } else if (req.method == "PATCH") {
         if (dispatch_request_for_content_reader(
-                req, res, std::move(reader),
-                patch_handlers_for_content_reader_)) {
+            req, res, std::move(reader),
+            patch_handlers_for_content_reader_)) {
           return true;
         }
       } else if (req.method == "DELETE") {
         if (dispatch_request_for_content_reader(
-                req, res, std::move(reader),
-                delete_handlers_for_content_reader_)) {
+            req, res, std::move(reader),
+            delete_handlers_for_content_reader_)) {
           return true;
         }
       }
@@ -4835,22 +4845,14 @@ inline bool Server::routing(Request &req, Response &res, Stream &strm) {
 
 inline bool Server::dispatch_request(Request &req, Response &res,
                                      const Handlers &handlers) {
-  try {
-    for (const auto &x : handlers) {
-      const auto &pattern = x.first;
-      const auto &handler = x.second;
+  for (const auto &x : handlers) {
+    const auto &pattern = x.first;
+    const auto &handler = x.second;
 
-      if (std::regex_match(req.path, req.matches, pattern)) {
-        handler(req, res);
-        return true;
-      }
+    if (std::regex_match(req.path, req.matches, pattern)) {
+      handler(req, res);
+      return true;
     }
-  } catch (const std::exception &ex) {
-    res.status = 500;
-    res.set_header("EXCEPTION_WHAT", ex.what());
-  } catch (...) {
-    res.status = 500;
-    res.set_header("EXCEPTION_WHAT", "UNKNOWN");
   }
   return false;
 }
@@ -5064,7 +5066,23 @@ Server::process_request(Stream &strm, bool close_connection,
   }
 
   // Rounting
-  if (routing(req, res, strm)) {
+  bool routed = false;
+  try {
+    routed = routing(req, res, strm);
+  } catch (std::exception & e) {
+    if (exception_handler_) {
+      exception_handler_(req, res, e);
+      routed = true;
+    } else {
+      res.status = 500;
+      res.set_header("EXCEPTION_WHAT", e.what());
+    }
+  } catch (...) {
+    res.status = 500;
+    res.set_header("EXCEPTION_WHAT", "UNKNOWN");
+  }
+
+  if (routed) {
     if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; }
     return write_response_with_content(strm, close_connection, req, res);
   } else {

+ 36 - 0
test/test.cc

@@ -6,6 +6,7 @@
 #include <chrono>
 #include <future>
 #include <thread>
+#include <stdexcept>
 
 #define SERVER_CERT_FILE "./cert.pem"
 #define SERVER_CERT2_FILE "./cert2.pem"
@@ -978,6 +979,41 @@ TEST(ErrorHandlerTest, ContentLength) {
   ASSERT_FALSE(svr.is_running());
 }
 
+TEST(ExceptionHandlerTest, ContentLength) {
+  Server svr;
+
+  svr.set_exception_handler([](const Request & /*req*/, Response &res, std::exception & /*e*/) {
+    res.status = 500;
+    res.set_content("abcdefghijklmnopqrstuvwxyz",
+                    "text/html"); // <= Content-Length still 13
+  });
+
+  svr.Get("/hi", [](const Request & /*req*/, Response &res) {
+    res.set_content("Hello World!\n", "text/plain");
+    throw std::runtime_error("abc");
+  });
+
+  auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
+
+  // Give GET time to get a few messages.
+  std::this_thread::sleep_for(std::chrono::seconds(1));
+
+  {
+    Client cli(HOST, PORT);
+
+    auto res = cli.Get("/hi");
+    ASSERT_TRUE(res);
+    EXPECT_EQ(500, res->status);
+    EXPECT_EQ("text/html", res->get_header_value("Content-Type"));
+    EXPECT_EQ("26", res->get_header_value("Content-Length"));
+    EXPECT_EQ("abcdefghijklmnopqrstuvwxyz", res->body);
+  }
+
+  svr.stop();
+  thread.join();
+  ASSERT_FALSE(svr.is_running());
+}
+
 TEST(NoContentTest, ContentLength) {
   Server svr;