Browse Source

Allow setting CA certificate during TLS connection (Fixed #1007)

melpon 1 year ago
parent
commit
2a429d7666

+ 1 - 0
include/rtc/websocket.hpp

@@ -39,6 +39,7 @@ public:
 		optional<std::chrono::milliseconds> connectionTimeout; // zero to disable
 		optional<std::chrono::milliseconds> connectionTimeout; // zero to disable
 		optional<std::chrono::milliseconds> pingInterval;      // zero to disable
 		optional<std::chrono::milliseconds> pingInterval;      // zero to disable
 		optional<int> maxOutstandingPings;
 		optional<int> maxOutstandingPings;
+		optional<string> caCertificatePemFile;
 	};
 	};
 
 
 	WebSocket();
 	WebSocket();

+ 27 - 2
src/impl/verifiedtlstransport.cpp

@@ -13,9 +13,11 @@
 
 
 namespace rtc::impl {
 namespace rtc::impl {
 
 
+static const string PemBeginCertificateTag = "-----BEGIN CERTIFICATE-----";
+
 VerifiedTlsTransport::VerifiedTlsTransport(
 VerifiedTlsTransport::VerifiedTlsTransport(
     variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>> lower, string host,
     variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>> lower, string host,
-    certificate_ptr certificate, state_callback callback)
+    certificate_ptr certificate, state_callback callback, [[maybe_unused]] optional<string> cacert)
     : TlsTransport(std::move(lower), std::move(host), std::move(certificate), std::move(callback)) {
     : TlsTransport(std::move(lower), std::move(host), std::move(certificate), std::move(callback)) {
 
 
 	PLOG_DEBUG << "Setting up TLS certificate verification";
 	PLOG_DEBUG << "Setting up TLS certificate verification";
@@ -24,13 +26,36 @@ VerifiedTlsTransport::VerifiedTlsTransport(
 	gnutls_session_set_verify_cert(mSession, mHost->c_str(), 0);
 	gnutls_session_set_verify_cert(mSession, mHost->c_str(), 0);
 #elif USE_MBEDTLS
 #elif USE_MBEDTLS
 	mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_REQUIRED);
 	mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_REQUIRED);
+	mbedtls_x509_crt_init(&mCaCert);
+	try {
+		if (cacert) {
+			if (cacert->find(PemBeginCertificateTag) == string::npos) {
+				// *cacert is a file path
+				mbedtls::check(mbedtls_x509_crt_parse_file(&mCaCert, cacert->c_str()));
+			} else {
+				// *cacert is a PEM content
+				mbedtls::check(mbedtls_x509_crt_parse(
+				    &mCaCert, reinterpret_cast<const unsigned char *>(cacert->c_str()),
+				    cacert->size()));
+			}
+			mbedtls_ssl_conf_ca_chain(&mConf, &mCaCert, NULL);
+		}
+	} catch (...) {
+		mbedtls_x509_crt_free(&mCaCert);
+		throw;
+	}
 #else
 #else
 	SSL_set_verify(mSsl, SSL_VERIFY_PEER, NULL);
 	SSL_set_verify(mSsl, SSL_VERIFY_PEER, NULL);
 	SSL_set_verify_depth(mSsl, 4);
 	SSL_set_verify_depth(mSsl, 4);
 #endif
 #endif
 }
 }
 
 
-VerifiedTlsTransport::~VerifiedTlsTransport() { stop(); }
+VerifiedTlsTransport::~VerifiedTlsTransport() {
+	stop();
+#if USE_MBEDTLS
+	mbedtls_x509_crt_free(&mCaCert);
+#endif
+}
 
 
 } // namespace rtc::impl
 } // namespace rtc::impl
 
 

+ 7 - 1
src/impl/verifiedtlstransport.hpp

@@ -18,8 +18,14 @@ namespace rtc::impl {
 class VerifiedTlsTransport final : public TlsTransport {
 class VerifiedTlsTransport final : public TlsTransport {
 public:
 public:
 	VerifiedTlsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>> lower,
 	VerifiedTlsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>> lower,
-	                     string host, certificate_ptr certificate, state_callback callback);
+	                     string host, certificate_ptr certificate, state_callback callback,
+	                     optional<string> cacert);
 	~VerifiedTlsTransport();
 	~VerifiedTlsTransport();
+
+private:
+#if USE_MBEDTLS
+	mbedtls_x509_crt mCaCert;
+#endif
 };
 };
 
 
 } // namespace rtc::impl
 } // namespace rtc::impl

+ 2 - 1
src/impl/websocket.cpp

@@ -358,7 +358,8 @@ shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 		shared_ptr<TlsTransport> transport;
 		shared_ptr<TlsTransport> transport;
 		if (verify)
 		if (verify)
 			transport = std::make_shared<VerifiedTlsTransport>(lower, mHostname.value(),
 			transport = std::make_shared<VerifiedTlsTransport>(lower, mHostname.value(),
-			                                                   mCertificate, stateChangeCallback);
+			                                                   mCertificate, stateChangeCallback,
+			                                                   config.caCertificatePemFile);
 		else
 		else
 			transport =
 			transport =
 			    std::make_shared<TlsTransport>(lower, mHostname, mCertificate, stateChangeCallback);
 			    std::make_shared<TlsTransport>(lower, mHostname, mCertificate, stateChangeCallback);