yhirose 1 year ago
parent
commit
548dfff0ae
3 changed files with 76 additions and 35 deletions
  1. 24 9
      httplib.h
  2. 3 1
      test/Makefile
  3. 49 25
      test/test.cc

+ 24 - 9
httplib.h

@@ -145,11 +145,11 @@ using ssize_t = long;
 #endif // _MSC_VER
 
 #ifndef S_ISREG
-#define S_ISREG(m) (((m) & S_IFREG) == S_IFREG)
+#define S_ISREG(m) (((m)&S_IFREG) == S_IFREG)
 #endif // S_ISREG
 
 #ifndef S_ISDIR
-#define S_ISDIR(m) (((m) & S_IFDIR) == S_IFDIR)
+#define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR)
 #endif // S_ISDIR
 
 #ifndef NOMINMAX
@@ -1745,10 +1745,12 @@ public:
 
   explicit SSLClient(const std::string &host, int port,
                      const std::string &client_cert_path,
-                     const std::string &client_key_path);
+                     const std::string &client_key_path,
+                     const std::string &private_key_password = std::string());
 
   explicit SSLClient(const std::string &host, int port, X509 *client_cert,
-                     EVP_PKEY *client_key);
+                     EVP_PKEY *client_key,
+                     const std::string &private_key_password = std::string());
 
   ~SSLClient() override;
 
@@ -2700,8 +2702,8 @@ inline bool mmap::open(const char *path) {
   if (!::GetFileSizeEx(hFile_, &size)) { return false; }
   size_ = static_cast<size_t>(size.QuadPart);
 
-  hMapping_ = ::CreateFileMappingFromApp(hFile_, NULL, PAGE_READONLY, size_,
-                                         NULL);
+  hMapping_ =
+      ::CreateFileMappingFromApp(hFile_, NULL, PAGE_READONLY, size_, NULL);
 
   if (hMapping_ == NULL) {
     close();
@@ -8438,7 +8440,6 @@ inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path,
 
     SSL_CTX_set_min_proto_version(ctx_, TLS1_1_VERSION);
 
-    // add default password callback before opening encrypted private key
     if (private_key_password != nullptr && (private_key_password[0] != '\0')) {
       SSL_CTX_set_default_passwd_cb_userdata(
           ctx_,
@@ -8544,7 +8545,8 @@ inline SSLClient::SSLClient(const std::string &host, int port)
 
 inline SSLClient::SSLClient(const std::string &host, int port,
                             const std::string &client_cert_path,
-                            const std::string &client_key_path)
+                            const std::string &client_key_path,
+                            const std::string &private_key_password)
     : ClientImpl(host, port, client_cert_path, client_key_path) {
   ctx_ = SSL_CTX_new(TLS_client_method());
 
@@ -8554,6 +8556,12 @@ inline SSLClient::SSLClient(const std::string &host, int port,
                 });
 
   if (!client_cert_path.empty() && !client_key_path.empty()) {
+    if (!private_key_password.empty()) {
+      SSL_CTX_set_default_passwd_cb_userdata(
+          ctx_, reinterpret_cast<void *>(
+                    const_cast<char *>(private_key_password.c_str())));
+    }
+
     if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(),
                                      SSL_FILETYPE_PEM) != 1 ||
         SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(),
@@ -8565,7 +8573,8 @@ inline SSLClient::SSLClient(const std::string &host, int port,
 }
 
 inline SSLClient::SSLClient(const std::string &host, int port,
-                            X509 *client_cert, EVP_PKEY *client_key)
+                            X509 *client_cert, EVP_PKEY *client_key,
+                            const std::string &private_key_password)
     : ClientImpl(host, port) {
   ctx_ = SSL_CTX_new(TLS_client_method());
 
@@ -8575,6 +8584,12 @@ inline SSLClient::SSLClient(const std::string &host, int port,
                 });
 
   if (client_cert != nullptr && client_key != nullptr) {
+    if (!private_key_password.empty()) {
+      SSL_CTX_set_default_passwd_cb_userdata(
+          ctx_, reinterpret_cast<void *>(
+                    const_cast<char *>(private_key_password.c_str())));
+    }
+
     if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 ||
         SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) {
       SSL_CTX_free(ctx_);

+ 3 - 1
test/Makefile

@@ -70,9 +70,11 @@ cert.pem:
 	openssl genrsa 2048 > rootCA.key.pem
 	openssl req -x509 -new -batch -config test.rootCA.conf -key rootCA.key.pem -days 1024 > rootCA.cert.pem
 	openssl genrsa 2048 > client.key.pem
-	openssl req -new -batch -config test.conf -key client.key.pem | openssl x509 -days 370 -req -CA rootCA.cert.pem -CAkey rootCA.key.pem -CAcreateserial > client.cert.pem
+	openssl req -new -batch -config test.conf -key client.key.pem| openssl x509 -days 370 -req -CA rootCA.cert.pem -CAkey rootCA.key.pem -CAcreateserial > client.cert.pem
 	openssl genrsa -passout pass:test123! 2048 > key_encrypted.pem
 	openssl req -new -batch -config test.conf -key key_encrypted.pem | openssl x509 -days 3650 -req -signkey key_encrypted.pem > cert_encrypted.pem
+	openssl genrsa -aes256 -passout pass:test012! 2048 > client_encrypted.key.pem
+	openssl req -new -batch -config test.conf -key client_encrypted.key.pem -passin pass:test012! | openssl x509 -days 370 -req -CA rootCA.cert.pem -CAkey rootCA.key.pem -CAcreateserial > client_encrypted.cert.pem
 	#c_rehash .
 
 clean:

+ 49 - 25
test/test.cc

@@ -20,6 +20,9 @@
 #define CLIENT_CA_CERT_DIR "."
 #define CLIENT_CERT_FILE "./client.cert.pem"
 #define CLIENT_PRIVATE_KEY_FILE "./client.key.pem"
+#define CLIENT_ENCRYPTED_CERT_FILE "./client_encrypted.cert.pem"
+#define CLIENT_ENCRYPTED_PRIVATE_KEY_FILE "./client_encrypted.key.pem"
+#define CLIENT_ENCRYPTED_PRIVATE_KEY_PASS "test012!"
 #define SERVER_ENCRYPTED_CERT_FILE "./cert_encrypted.pem"
 #define SERVER_ENCRYPTED_PRIVATE_KEY_FILE "./key_encrypted.pem"
 #define SERVER_ENCRYPTED_PRIVATE_KEY_PASS "test123!"
@@ -5109,15 +5112,16 @@ TEST(SSLClientTest, SetInterfaceWithINET6) {
 }
 #endif
 
-TEST(SSLClientServerTest, ClientCertPresent) {
+void ClientCertPresent(
+    const std::string &client_cert_file,
+    const std::string &client_private_key_file,
+    const std::string &client_encrypted_private_key_pass = std::string()) {
   SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE, CLIENT_CA_CERT_FILE,
                 CLIENT_CA_CERT_DIR);
   ASSERT_TRUE(svr.is_valid());
 
   svr.Get("/test", [&](const Request &req, Response &res) {
     res.set_content("test", "text/plain");
-    svr.stop();
-    ASSERT_TRUE(true);
 
     auto peer_cert = SSL_get_peer_certificate(req.ssl);
     ASSERT_TRUE(peer_cert != nullptr);
@@ -5140,13 +5144,15 @@ TEST(SSLClientServerTest, ClientCertPresent) {
 
   thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); });
   auto se = detail::scope_exit([&] {
+    svr.stop();
     t.join();
     ASSERT_FALSE(svr.is_running());
   });
 
   svr.wait_until_ready();
 
-  SSLClient cli(HOST, PORT, CLIENT_CERT_FILE, CLIENT_PRIVATE_KEY_FILE);
+  SSLClient cli(HOST, PORT, client_cert_file, client_private_key_file,
+                client_encrypted_private_key_pass);
   cli.enable_server_certificate_verification(false);
   cli.set_connection_timeout(30);
 
@@ -5155,35 +5161,43 @@ TEST(SSLClientServerTest, ClientCertPresent) {
   ASSERT_EQ(StatusCode::OK_200, res->status);
 }
 
+TEST(SSLClientServerTest, ClientCertPresent) {
+  ClientCertPresent(CLIENT_CERT_FILE, CLIENT_PRIVATE_KEY_FILE);
+}
+
+TEST(SSLClientServerTest, ClientEncryptedCertPresent) {
+  ClientCertPresent(CLIENT_ENCRYPTED_CERT_FILE,
+                    CLIENT_ENCRYPTED_PRIVATE_KEY_FILE,
+                    CLIENT_ENCRYPTED_PRIVATE_KEY_PASS);
+}
+
 #if !defined(_WIN32) || defined(OPENSSL_USE_APPLINK)
-TEST(SSLClientServerTest, MemoryClientCertPresent) {
-  X509 *server_cert;
-  EVP_PKEY *server_private_key;
-  X509_STORE *client_ca_cert_store;
-  X509 *client_cert;
-  EVP_PKEY *client_private_key;
-
-  FILE *f = fopen(SERVER_CERT_FILE, "r+");
-  server_cert = PEM_read_X509(f, nullptr, nullptr, nullptr);
+void MemoryClientCertPresent(
+    const std::string &client_cert_file,
+    const std::string &client_private_key_file,
+    const std::string &client_encrypted_private_key_pass = std::string()) {
+  auto f = fopen(SERVER_CERT_FILE, "r+");
+  auto server_cert = PEM_read_X509(f, nullptr, nullptr, nullptr);
   fclose(f);
 
   f = fopen(SERVER_PRIVATE_KEY_FILE, "r+");
-  server_private_key = PEM_read_PrivateKey(f, nullptr, nullptr, nullptr);
+  auto server_private_key = PEM_read_PrivateKey(f, nullptr, nullptr, nullptr);
   fclose(f);
 
   f = fopen(CLIENT_CA_CERT_FILE, "r+");
-  client_cert = PEM_read_X509(f, nullptr, nullptr, nullptr);
-  client_ca_cert_store = X509_STORE_new();
+  auto client_cert = PEM_read_X509(f, nullptr, nullptr, nullptr);
+  auto client_ca_cert_store = X509_STORE_new();
   X509_STORE_add_cert(client_ca_cert_store, client_cert);
   X509_free(client_cert);
   fclose(f);
 
-  f = fopen(CLIENT_CERT_FILE, "r+");
+  f = fopen(client_cert_file.c_str(), "r+");
   client_cert = PEM_read_X509(f, nullptr, nullptr, nullptr);
   fclose(f);
 
-  f = fopen(CLIENT_PRIVATE_KEY_FILE, "r+");
-  client_private_key = PEM_read_PrivateKey(f, nullptr, nullptr, nullptr);
+  f = fopen(client_private_key_file.c_str(), "r+");
+  auto client_private_key = PEM_read_PrivateKey(
+      f, nullptr, nullptr, (void *)client_encrypted_private_key_pass.c_str());
   fclose(f);
 
   SSLServer svr(server_cert, server_private_key, client_ca_cert_store);
@@ -5191,8 +5205,6 @@ TEST(SSLClientServerTest, MemoryClientCertPresent) {
 
   svr.Get("/test", [&](const Request &req, Response &res) {
     res.set_content("test", "text/plain");
-    svr.stop();
-    ASSERT_TRUE(true);
 
     auto peer_cert = SSL_get_peer_certificate(req.ssl);
     ASSERT_TRUE(peer_cert != nullptr);
@@ -5215,13 +5227,15 @@ TEST(SSLClientServerTest, MemoryClientCertPresent) {
 
   thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); });
   auto se = detail::scope_exit([&] {
+    svr.stop();
     t.join();
     ASSERT_FALSE(svr.is_running());
   });
 
   svr.wait_until_ready();
 
-  SSLClient cli(HOST, PORT, client_cert, client_private_key);
+  SSLClient cli(HOST, PORT, client_cert, client_private_key,
+                client_encrypted_private_key_pass);
   cli.enable_server_certificate_verification(false);
   cli.set_connection_timeout(30);
 
@@ -5234,6 +5248,16 @@ TEST(SSLClientServerTest, MemoryClientCertPresent) {
   X509_free(client_cert);
   EVP_PKEY_free(client_private_key);
 }
+
+TEST(SSLClientServerTest, MemoryClientCertPresent) {
+  MemoryClientCertPresent(CLIENT_CERT_FILE, CLIENT_PRIVATE_KEY_FILE);
+}
+
+TEST(SSLClientServerTest, MemoryClientEncryptedCertPresent) {
+  MemoryClientCertPresent(CLIENT_ENCRYPTED_CERT_FILE,
+                          CLIENT_ENCRYPTED_PRIVATE_KEY_FILE,
+                          CLIENT_ENCRYPTED_PRIVATE_KEY_PASS);
+}
 #endif
 
 TEST(SSLClientServerTest, ClientCertMissing) {
@@ -5265,11 +5289,11 @@ TEST(SSLClientServerTest, TrustDirOptional) {
 
   svr.Get("/test", [&](const Request &, Response &res) {
     res.set_content("test", "text/plain");
-    svr.stop();
   });
 
   thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); });
   auto se = detail::scope_exit([&] {
+    svr.stop();
     t.join();
     ASSERT_FALSE(svr.is_running());
   });
@@ -5361,13 +5385,12 @@ TEST(SSLClientServerTest, CustomizeServerSSLCtx) {
         nullptr);
     return true;
   };
+
   SSLServer svr(setup_ssl_ctx_callback);
   ASSERT_TRUE(svr.is_valid());
 
   svr.Get("/test", [&](const Request &req, Response &res) {
     res.set_content("test", "text/plain");
-    svr.stop();
-    ASSERT_TRUE(true);
 
     auto peer_cert = SSL_get_peer_certificate(req.ssl);
     ASSERT_TRUE(peer_cert != nullptr);
@@ -5390,6 +5413,7 @@ TEST(SSLClientServerTest, CustomizeServerSSLCtx) {
 
   thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); });
   auto se = detail::scope_exit([&] {
+    svr.stop();
     t.join();
     ASSERT_FALSE(svr.is_running());
   });