Browse Source

Add SSL object on Request

yhirose 6 years ago
parent
commit
1981e0ccad
2 changed files with 41 additions and 13 deletions
  1. 20 8
      httplib.h
  2. 21 5
      test/test.cc

+ 20 - 8
httplib.h

@@ -145,6 +145,10 @@ struct Request {
 
 
   Progress progress;
   Progress progress;
 
 
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+  const SSL *ssl;
+#endif
+
   bool has_header(const char *key) const;
   bool has_header(const char *key) const;
   std::string get_header_value(const char *key, size_t id = 0) const;
   std::string get_header_value(const char *key, size_t id = 0) const;
   size_t get_header_value_count(const char *key) const;
   size_t get_header_value_count(const char *key) const;
@@ -256,7 +260,8 @@ public:
 
 
 protected:
 protected:
   bool process_request(Stream &strm, bool last_connection,
   bool process_request(Stream &strm, bool last_connection,
-                       bool &connection_close);
+                       bool &connection_close,
+                       std::function<void(Request &)> setup_request = nullptr);
 
 
   size_t keep_alive_max_count_;
   size_t keep_alive_max_count_;
   size_t payload_max_length_;
   size_t payload_max_length_;
@@ -1828,8 +1833,10 @@ inline bool Server::dispatch_request(Request &req, Response &res,
   return false;
   return false;
 }
 }
 
 
-inline bool Server::process_request(Stream &strm, bool last_connection,
-                                    bool &connection_close) {
+inline bool
+Server::process_request(Stream &strm, bool last_connection,
+                        bool &connection_close,
+                        std::function<void(Request &)> setup_request) {
   const auto bufsiz = 2048;
   const auto bufsiz = 2048;
   char buf[bufsiz];
   char buf[bufsiz];
 
 
@@ -1899,6 +1906,9 @@ inline bool Server::process_request(Stream &strm, bool last_connection,
     }
     }
   }
   }
 
 
+  // TODO: Add additional request info
+  if (setup_request) { setup_request(req); }
+
   if (routing(req, res)) {
   if (routing(req, res)) {
     if (res.status == -1) { res.status = 200; }
     if (res.status == -1) { res.status = 200; }
   } else {
   } else {
@@ -2293,7 +2303,7 @@ read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count,
         auto last_connection = count == 1;
         auto last_connection = count == 1;
         auto connection_close = false;
         auto connection_close = false;
 
 
-        ret = callback(strm, last_connection, connection_close);
+        ret = callback(ssl, strm, last_connection, connection_close);
         if (!ret || connection_close) { break; }
         if (!ret || connection_close) { break; }
 
 
         count--;
         count--;
@@ -2301,7 +2311,7 @@ read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count,
     } else {
     } else {
       SSLSocketStream strm(sock, ssl);
       SSLSocketStream strm(sock, ssl);
       auto dummy_connection_close = false;
       auto dummy_connection_close = false;
-      ret = callback(strm, true, dummy_connection_close);
+      ret = callback(ssl, strm, true, dummy_connection_close);
     }
     }
   }
   }
 
 
@@ -2406,8 +2416,10 @@ inline bool SSLServer::read_and_close_socket(socket_t sock) {
   return detail::read_and_close_socket_ssl(
   return detail::read_and_close_socket_ssl(
       sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept,
       sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept,
       [](SSL * /*ssl*/) { return true; },
       [](SSL * /*ssl*/) { return true; },
-      [this](Stream &strm, bool last_connection, bool &connection_close) {
-        return process_request(strm, last_connection, connection_close);
+      [this](SSL *ssl, Stream &strm, bool last_connection,
+             bool &connection_close) {
+        return process_request(strm, last_connection, connection_close,
+                               [&](Request &req) { req.ssl = ssl; });
       });
       });
 }
 }
 
 
@@ -2494,7 +2506,7 @@ inline bool SSLClient::read_and_close_socket(socket_t sock, Request &req,
                SSL_set_tlsext_host_name(ssl, host_.c_str());
                SSL_set_tlsext_host_name(ssl, host_.c_str());
                return true;
                return true;
              },
              },
-             [&](Stream &strm, bool /*last_connection*/,
+             [&](SSL * /*ssl*/, Stream &strm, bool /*last_connection*/,
                  bool &connection_close) {
                  bool &connection_close) {
                return process_request(strm, req, res, connection_close);
                return process_request(strm, req, res, connection_close);
              });
              });

+ 21 - 5
test/test.cc

@@ -1384,9 +1384,28 @@ TEST(SSLClientServerTest, ClientCertPresent) {
                 CLIENT_CA_CERT_DIR);
                 CLIENT_CA_CERT_DIR);
   ASSERT_TRUE(svr.is_valid());
   ASSERT_TRUE(svr.is_valid());
 
 
-  svr.Get("/test", [&](const Request &, Response &res) {
+  svr.Get("/test", [&](const Request &req, Response &res) {
     res.set_content("test", "text/plain");
     res.set_content("test", "text/plain");
     svr.stop();
     svr.stop();
+    ASSERT_TRUE(true);
+
+    auto peer_cert = SSL_get_peer_certificate(req.ssl);
+    ASSERT_TRUE(peer_cert != nullptr);
+
+    auto subject_name = X509_get_subject_name(peer_cert);
+    ASSERT_TRUE(subject_name != nullptr);
+
+    std::string common_name;
+    {
+      char name[BUFSIZ];
+      auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName,
+                                                name, sizeof(name));
+      common_name.assign(name, name_len);
+    }
+
+    EXPECT_EQ("Common Name", common_name);
+
+    X509_free(peer_cert);
   });
   });
 
 
   thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); });
   thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); });
@@ -1405,10 +1424,7 @@ TEST(SSLClientServerTest, ClientCertMissing) {
                 CLIENT_CA_CERT_DIR);
                 CLIENT_CA_CERT_DIR);
   ASSERT_TRUE(svr.is_valid());
   ASSERT_TRUE(svr.is_valid());
 
 
-  svr.Get("/test", [&](const Request &, Response &res) {
-    res.set_content("test", "text/plain");
-    svr.stop();
-  });
+  svr.Get("/test", [&](const Request &, Response &) { ASSERT_TRUE(false); });
 
 
   thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); });
   thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); });