Browse Source

Refactoring.

yhirose 13 years ago
parent
commit
ffde8b7e4b
3 changed files with 84 additions and 54 deletions
  1. 1 2
      example/server.cc
  2. 63 41
      httplib.h
  3. 20 11
      test/test.cc

+ 1 - 2
example/server.cc

@@ -89,8 +89,7 @@ int main(void)
     svr.set_error_handler([](httplib::Connection& c) {
         char buf[BUFSIZ];
         snprintf(buf, sizeof(buf), "<p>Error Status: <span style='color:red;'>%d</span></p>", c.response.status);
-        c.response.body = buf;
-        c.response.set_header("Content-Type", "text/html");
+        c.response.set_content(buf, "text/html");
     });
 
     svr.set_logger([](const Connection& c) {

+ 63 - 41
httplib.h

@@ -89,25 +89,27 @@ public:
     void post(const char* pattern, Handler handler);
 
     void set_error_handler(Handler handler);
-    void set_logger(std::function<void (const Connection&)> logger);
+    void set_logger(Handler logger);
 
     bool run();
     void stop();
 
 private:
-    void process_request(FILE* fp_read, FILE* fp_write);
+    typedef std::vector<std::pair<std::regex, Handler>> Handlers;
 
+    void process_request(FILE* fp_read, FILE* fp_write);
     bool read_request_line(FILE* fp, Request& req);
     void write_response(FILE* fp, const Response& res);
+    void dispatch_request(Connection& c, Handlers& handlers);
 
     const std::string host_;
     const int         port_;
     socket_t          sock_;
 
-    std::vector<std::pair<std::regex, Handler>>  get_handlers_;
-    std::vector<std::pair<std::string, Handler>> post_handlers_;
-    Handler                                      error_handler_;
-    std::function<void (const Connection&)>      logger_;
+    Handlers get_handlers_;
+    Handlers post_handlers_;
+    Handler  error_handler_;
+    Handler  logger_;
 };
 
 class Client {
@@ -279,14 +281,20 @@ inline int get_header_value_int(const MultiMap& map, const char* key, int def)
     return def;
 }
 
-inline void read_headers(FILE* fp, MultiMap& headers)
+inline bool read_headers(FILE* fp, MultiMap& headers)
 {
     static std::regex re("(.+?): (.+?)\r\n");
 
     const size_t BUFSIZ_HEADER = 2048;
     char buf[BUFSIZ_HEADER];
 
-    while (fgets(buf, BUFSIZ_HEADER, fp) && strcmp(buf, "\r\n")) {
+    for (;;) {
+        if (!fgets(buf, BUFSIZ_HEADER, fp)) {
+            return false;
+        }
+        if (!strcmp(buf, "\r\n")) {
+            break;
+        }
         std::cmatch m;
         if (std::regex_match(buf, m, re)) {
             auto key = std::string(m[1]);
@@ -294,6 +302,21 @@ inline void read_headers(FILE* fp, MultiMap& headers)
             headers.insert(std::make_pair(key, val));
         }
     }
+
+    return true;
+}
+
+template <typename T>
+bool read_content(T& x, FILE* fp)
+{
+    auto len = get_header_value_int(x.headers, "Content-Length", 0);
+    if (len) {
+        x.body.assign(len, 0);
+        if (!fgets(&x.body[0], x.body.size() + 1, fp)) {
+            return false;
+        }
+    }
+    return true;
 }
 
 // HTTP server implementation
@@ -332,7 +355,6 @@ inline void Response::set_content(const std::string& s, const char* content_type
 {
     body = s;
     set_header("Content-Type", content_type);
-    status = 200;
 }
 
 inline Server::Server(const char* host, int port)
@@ -368,7 +390,7 @@ inline void Server::set_error_handler(Handler handler)
     error_handler_ = handler;
 }
 
-inline void Server::set_logger(std::function<void (const Connection&)> logger)
+inline void Server::set_logger(Handler logger)
 {
     logger_ = logger;
 }
@@ -474,48 +496,57 @@ inline void Server::write_response(FILE* fp, const Response& res)
     }
 }
 
+inline void Server::dispatch_request(Connection& c, Handlers& handlers)
+{
+    for (auto it = handlers.begin(); it != handlers.end(); ++it) {
+        const auto& pattern = it->first;
+        const auto& handler = it->second;
+
+        if (std::regex_match(c.request.url, c.request.match, pattern)) {
+            handler(c);
+
+            if (!c.response.status) {
+                c.response.status = 200;
+            }
+            break;
+        }
+    }
+}
+
 inline void Server::process_request(FILE* fp_read, FILE* fp_write)
 {
     Connection c;
 
-    if (!read_request_line(fp_read, c.request)) {
+    if (!read_request_line(fp_read, c.request) ||
+        !read_headers(fp_read, c.request.headers)) {
         return;
     }
-
-    read_headers(fp_read, c.request.headers);
     
     // Routing
     c.response.status = 0;
 
     if (c.request.method == "GET") {
-        for (auto it = get_handlers_.begin(); it != get_handlers_.end(); ++it) {
-            const auto& pattern = it->first;
-            const auto& handler = it->second;
-            
-            if (std::regex_match(c.request.url, c.request.match, pattern)) {
-                handler(c);
-                break;
-            }
-        }
+        dispatch_request(c, get_handlers_);
     } else if (c.request.method == "POST") {
-        // TODO: parse body
+        if (!read_content(c.request, fp_read)) {
+            return;
+        }
+        dispatch_request(c, post_handlers_);
     }
 
     if (!c.response.status) {
         c.response.status = 404;
     }
 
-    if (400 <= c.response.status) {
-        if (error_handler_) {
-            error_handler_(c);
-        }
+    if (400 <= c.response.status && error_handler_) {
+        error_handler_(c);
     }
 
+    write_response(fp_write, c.response);
+
     if (logger_) {
         logger_(c);
     }
-
-    write_response(fp_write, c.response);
 }
 
 // HTTP client implementation
@@ -569,21 +600,12 @@ inline bool Client::get(const char* url, Response& res)
     fprintf(fp_write, "GET %s HTTP/1.0\r\n\r\n", url);
     fflush(fp_write);
 
-    if (!read_response_line(fp_read, res)) {
+    if (!read_response_line(fp_read, res) ||
+        !read_headers(fp_read, res.headers) ||
+        !read_content(res, fp_read)) {
         return false;
     }
 
-    read_headers(fp_read, res.headers);
-
-    // Read content body
-    auto len = get_header_value_int(res.headers, "Content-Length", 0);
-    if (len) {
-        res.body.assign(len, 0);
-        if (!fgets(&res.body[0], res.body.size() + 1, fp_read)) {
-            return false;
-        }
-    }
-
     close_client_socket(sock);
 
     return true;

+ 20 - 11
test/test.cc

@@ -72,8 +72,11 @@ protected:
     }
 
     virtual void SetUp() {
-        svr_.get(url_, [&](httplib::Connection& c) {
-            c.response.set_content(content_, mime_);
+        svr_.get("/hi", [&](httplib::Connection& c) {
+            c.response.set_content("Hello World!", "text/plain");
+        });
+        svr_.get("/", [&](httplib::Connection& c) {
+            c.response.set_redirect("/hi");
         });
         f_ = async([&](){ svr_.run(); });
     }
@@ -83,12 +86,8 @@ protected:
         f_.get();
     }
 
-    const char* host_ = "localhost";
-    int         port_ = 1914;
-    const char* url_ = "/hi";
-    const char* content_ = "Hello World!";
-    const char* mime_ = "text/plain";
-
+    const char*       host_ = "localhost";
+    int               port_ = 1914;
     Server            svr_;
     std::future<void> f_;
 };
@@ -96,17 +95,27 @@ protected:
 TEST_F(ServerTest, GetMethod200)
 {
     Response res;
-    bool ret = Client(host_, port_).get(url_, res);
+    bool ret = Client(host_, port_).get("/hi", res);
     ASSERT_EQ(true, ret);
     ASSERT_EQ(200, res.status);
-    ASSERT_EQ(content_, res.body);
+    ASSERT_EQ("text/plain", res.get_header_value("Content-Type"));
+    ASSERT_EQ("Hello World!", res.body);
+}
+
+TEST_F(ServerTest, GetMethod302)
+{
+    Response res;
+    bool ret = Client(host_, port_).get("/", res);
+    ASSERT_EQ(true, ret);
+    ASSERT_EQ(302, res.status);
+    ASSERT_EQ("/hi", res.get_header_value("Location"));
 }
 
 TEST_F(ServerTest, GetMethod404)
 {
     Response res;
     bool ret = Client(host_, port_).get("/invalid", res);
-    ASSERT_EQ(false, ret);
+    ASSERT_EQ(true, ret);
     ASSERT_EQ(404, res.status);
 }