Browse Source

Merge pull request #162 from PixlRainbow/master

client certificates support
yhirose 6 years ago
parent
commit
92f08b54c4
4 changed files with 137 additions and 7 deletions
  1. 45 6
      httplib.h
  2. 6 1
      test/Makefile
  3. 68 0
      test/test.cc
  4. 18 0
      test/test.rootCA.conf

+ 45 - 6
httplib.h

@@ -376,7 +376,7 @@ private:
 
 class SSLServer : public Server {
 public:
-  SSLServer(const char *cert_path, const char *private_key_path);
+  SSLServer(const char *cert_path, const char *private_key_path, const char *client_CA_cert_path, const char *trusted_cert_path);
 
   virtual ~SSLServer();
 
@@ -387,11 +387,14 @@ private:
 
   SSL_CTX *ctx_;
   std::mutex ctx_mutex_;
+  const char *client_CA_cert_path_;
+  const char *trusted_cert_path_;
 };
 
 class SSLClient : public Client {
 public:
-  SSLClient(const char *host, int port = 443, time_t timeout_sec = 300);
+  SSLClient(const char *host, int port = 443, time_t timeout_sec = 300,
+            const char *client_cert_path = nullptr, const char *client_key_path = nullptr);
 
   virtual ~SSLClient();
 
@@ -2234,7 +2237,9 @@ read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count,
                           // TODO: OpenSSL 1.0.2 occasionally crashes...
                           // The upcoming 1.1.0 is going to be thread safe.
                           SSL_CTX *ctx, std::mutex &ctx_mutex,
-                          U SSL_connect_or_accept, V setup, T callback) {
+                          U SSL_connect_or_accept, V setup, T callback,
+                          const char* client_CA_cert_path = nullptr,
+                          const char* trusted_cert_path = nullptr) {
   SSL *ssl = nullptr;
   {
     std::lock_guard<std::mutex> guard(ctx_mutex);
@@ -2260,9 +2265,24 @@ read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count,
     return false;
   }
 
+  if(client_CA_cert_path){
+    STACK_OF(X509_NAME)* list;
+    //list of client CAs to request from client
+    list = SSL_load_client_CA_file(client_CA_cert_path);
+    SSL_set_client_CA_list(ssl, list);
+    //certificate chain to verify received client certificate against
+    //please run c_rehash in the cert folder first
+    SSL_CTX_load_verify_locations(ctx,client_CA_cert_path,trusted_cert_path);
+  }
+
   bool ret = false;
 
   if (SSL_connect_or_accept(ssl) == 1) {
+    /*
+    auto client_cert = SSL_get_peer_certificate(ssl);
+    if(client_cert)
+      printf("Connected client: %s\n", client_cert->name);
+    */
     if (keep_alive_max_count > 0) {
       auto count = keep_alive_max_count;
       while (count > 0 &&
@@ -2338,7 +2358,11 @@ inline std::string SSLSocketStream::get_remote_addr() const {
 
 // SSL HTTP server implementation
 inline SSLServer::SSLServer(const char *cert_path,
-                            const char *private_key_path) {
+                            const char *private_key_path,
+                            const char *client_CA_cert_path = nullptr,
+                            const char *trusted_cert_path = nullptr)
+  : client_CA_cert_path_(client_CA_cert_path),
+    trusted_cert_path_(trusted_cert_path){
   ctx_ = SSL_CTX_new(SSLv23_server_method());
 
   if (ctx_) {
@@ -2356,6 +2380,11 @@ inline SSLServer::SSLServer(const char *cert_path,
             1) {
       SSL_CTX_free(ctx_);
       ctx_ = nullptr;
+    } else if(client_CA_cert_path_) {
+      SSL_CTX_set_verify(ctx_,
+        SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, //SSL_VERIFY_CLIENT_ONCE,
+        nullptr
+      );
     }
   }
 }
@@ -2372,11 +2401,14 @@ inline bool SSLServer::read_and_close_socket(socket_t sock) {
       [](SSL * /*ssl*/) { return true; },
       [this](Stream &strm, bool last_connection, bool &connection_close) {
         return process_request(strm, last_connection, connection_close);
-      });
+      },
+      client_CA_cert_path_,
+      trusted_cert_path_);
 }
 
 // SSL HTTP client implementation
-inline SSLClient::SSLClient(const char *host, int port, time_t timeout_sec)
+inline SSLClient::SSLClient(const char *host, int port, time_t timeout_sec,
+                            const char *client_cert_path, const char *client_key_path)
     : Client(host, port, timeout_sec) {
   ctx_ = SSL_CTX_new(SSLv23_client_method());
 
@@ -2384,6 +2416,13 @@ inline SSLClient::SSLClient(const char *host, int port, time_t timeout_sec)
                 [&](const char *b, const char *e) {
                   host_components_.emplace_back(std::string(b, e));
                 });
+  if(client_cert_path && client_key_path) {
+    if (SSL_CTX_use_certificate_file(ctx_, client_cert_path, SSL_FILETYPE_PEM) != 1
+      ||SSL_CTX_use_PrivateKey_file(ctx_, client_key_path, SSL_FILETYPE_PEM) != 1) {
+      SSL_CTX_free(ctx_);
+      ctx_ = nullptr;
+    }
+  }
 }
 
 inline SSLClient::~SSLClient() {

+ 6 - 1
test/Makefile

@@ -15,6 +15,11 @@ test : test.cc ../httplib.h Makefile cert.pem
 cert.pem:
 	openssl genrsa 2048 > key.pem
 	openssl req -new -batch -config test.conf -key key.pem | openssl x509 -days 3650 -req -signkey key.pem > cert.pem
+	openssl genrsa 2048 > rootCA.key.pem
+	openssl req -x509 -new -batch -config test.rootCA.conf -key rootCA.key.pem -days 1024 > rootCA.cert.pem
+	openssl genrsa 2048 > client.key.pem
+	openssl req -new -batch -config test.conf -key client.key.pem | openssl x509 -days 370 -req -CA rootCA.cert.pem -CAkey rootCA.key.pem -CAcreateserial > client.cert.pem
+	#c_rehash .
 
 clean:
-	rm -f test *.pem
+	rm -f test *.pem *.0 *.1 *.srl

+ 68 - 0
test/test.cc

@@ -5,6 +5,10 @@
 #define SERVER_CERT_FILE "./cert.pem"
 #define SERVER_PRIVATE_KEY_FILE "./key.pem"
 #define CA_CERT_FILE "./ca-bundle.crt"
+#define CLIENT_CA_CERT_FILE "./rootCA.cert.pem"
+#define CLIENT_CERT_FILE "./client.cert.pem"
+#define CLIENT_PRIVATE_KEY_FILE "./client.key.pem"
+#define TRUST_CERT_DIR "."
 
 #ifdef _WIN32
 #include <process.h>
@@ -1374,6 +1378,70 @@ TEST(SSLClientTest, WildcardHostNameMatch) {
   ASSERT_TRUE(res != nullptr);
   ASSERT_EQ(200, res->status);
 }
+
+TEST(SSLClientServerTest, ClientCertPresent) {
+  SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE, CLIENT_CA_CERT_FILE, TRUST_CERT_DIR);
+  ASSERT_TRUE(svr.is_valid());
+
+  svr.Get("/test", [&](const Request &, Response &res){
+      res.set_content("test", "text/plain");
+      svr.stop();
+  });
+
+  thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); });
+
+  httplib::SSLClient cli(HOST, PORT, 30, CLIENT_CERT_FILE, CLIENT_PRIVATE_KEY_FILE);
+  auto res = cli.Get("/test");
+  ASSERT_TRUE(res != nullptr);
+  ASSERT_EQ(200, res->status);
+
+  t.join();
+}
+
+TEST(SSLClientServerTest, ClientCertMissing) {
+  SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE, CLIENT_CA_CERT_FILE, TRUST_CERT_DIR);
+  ASSERT_TRUE(svr.is_valid());
+
+  svr.Get("/test", [&](const Request &, Response &res){
+      res.set_content("test", "text/plain");
+      svr.stop();
+  });
+
+  thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); });
+
+  httplib::SSLClient cli(HOST, PORT, 30);
+  auto res = cli.Get("/test");
+  ASSERT_TRUE(res == nullptr);
+
+  svr.stop();
+
+  t.join();
+}
+
+TEST(SSLClientServerTest, TrustDirOptional) {
+  SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE, CLIENT_CA_CERT_FILE);
+  ASSERT_TRUE(svr.is_valid());
+
+  svr.Get("/test", [&](const Request &, Response &res){
+      res.set_content("test", "text/plain");
+      svr.stop();
+  });
+
+  thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); });
+
+  httplib::SSLClient cli(HOST, PORT, 30, CLIENT_CERT_FILE, CLIENT_PRIVATE_KEY_FILE);
+  auto res = cli.Get("/test");
+  ASSERT_TRUE(res != nullptr);
+  ASSERT_EQ(200, res->status);
+
+  t.join();
+}
+
+/* Cannot test this case as there is no external access to SSL object to check SSL_get_peer_certificate() == NULL
+TEST(SSLClientServerTest, ClientCAPathRequired) {
+  SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE, nullptr, TRUST_CERT_DIR);
+}
+*/
 #endif
 
 #ifdef _WIN32

+ 18 - 0
test/test.rootCA.conf

@@ -0,0 +1,18 @@
+[req]
+default_bits           = 2048
+distinguished_name     = req_distinguished_name
+attributes             = req_attributes
+prompt                 = no
+output_password        = mypass
+
+[req_distinguished_name]
+C                      = US
+ST                     = Test State or Province
+L                      = Test Locality
+O                      = Organization Name
+OU                     = Organizational Unit Name
+CN                     = Root CA Name
+emailAddress           = [email protected]
+
+[req_attributes]
+challengePassword              = 1234