Browse Source

Made certificate generation async to reduce init delay

Paul-Louis Ageneau 5 years ago
parent
commit
5e59186757

+ 5 - 1
include/rtc/peerconnection.hpp

@@ -31,6 +31,7 @@
 
 #include <atomic>
 #include <functional>
+#include <future>
 #include <list>
 #include <mutex>
 #include <shared_mutex>
@@ -44,6 +45,9 @@ class IceTransport;
 class DtlsTransport;
 class SctpTransport;
 
+using certificate_ptr = std::shared_ptr<Certificate>;
+using future_certificate_ptr = std::shared_future<certificate_ptr>;
+
 class PeerConnection : public std::enable_shared_from_this<PeerConnection> {
 public:
 	enum class State : int {
@@ -126,7 +130,7 @@ private:
 	void resetCallbacks();
 
 	const Configuration mConfig;
-	const std::shared_ptr<Certificate> mCertificate;
+	const future_certificate_ptr mCertificate;
 
 	std::optional<Description> mLocalDescription, mRemoteDescription;
 	mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex;

+ 30 - 21
src/certificate.cpp

@@ -141,14 +141,9 @@ string make_fingerprint(gnutls_x509_crt_t crt) {
 	return oss.str();
 }
 
-shared_ptr<Certificate> make_certificate(const string &commonName) {
-	static std::unordered_map<string, shared_ptr<Certificate>> cache;
-	static std::mutex cacheMutex;
-
-	std::lock_guard lock(cacheMutex);
-	if (auto it = cache.find(commonName); it != cache.end())
-		return it->second;
+namespace {
 
+certificate_ptr make_certificate_impl(string commonName) {
 	std::unique_ptr<gnutls_x509_crt_t, decltype(&delete_crt)> crt(create_crt(), delete_crt);
 	std::unique_ptr<gnutls_x509_privkey_t, decltype(&delete_privkey)> privkey(create_privkey(),
 	                                                                          delete_privkey);
@@ -174,11 +169,11 @@ shared_ptr<Certificate> make_certificate(const string &commonName) {
 	check_gnutls(gnutls_x509_crt_sign2(*crt, *crt, *privkey, GNUTLS_DIG_SHA256, 0),
 	             "Unable to auto-sign certificate");
 
-	auto certificate = std::make_shared<Certificate>(*crt, *privkey);
-	cache.emplace(std::make_pair(commonName, certificate));
-	return certificate;
+	return std::make_shared<Certificate>(*crt, *privkey);
 }
 
+} // namespace
+
 } // namespace rtc
 
 #else
@@ -236,15 +231,9 @@ string make_fingerprint(X509 *x509) {
 	return oss.str();
 }
 
+namespace {
 
-shared_ptr<Certificate> make_certificate(const string &commonName) {
-	static std::unordered_map<string, shared_ptr<Certificate>> cache;
-	static std::mutex cacheMutex;
-
-	std::lock_guard lock(cacheMutex);
-	if (auto it = cache.find(commonName); it != cache.end())
-		return it->second;
-
+certificate_ptr make_certificate_impl(string commonName) {
 	shared_ptr<X509> x509(X509_new(), X509_free);
 	shared_ptr<EVP_PKEY> pkey(EVP_PKEY_new(), EVP_PKEY_free);
 
@@ -281,12 +270,32 @@ shared_ptr<Certificate> make_certificate(const string &commonName) {
 	if (!X509_sign(x509.get(), pkey.get(), EVP_sha256()))
 		throw std::runtime_error("Unable to auto-sign certificate");
 
-	auto certificate = std::make_shared<Certificate>(x509, pkey);
-	cache.emplace(std::make_pair(commonName, certificate));
-	return certificate;
+	return std::make_shared<Certificate>(x509, pkey);
 }
 
+} // namespace
+
 } // namespace rtc
 
 #endif
 
+// Common for GnuTLS and OpenSSL
+
+namespace rtc {
+
+future_certificate_ptr make_certificate(string commonName) {
+	static std::unordered_map<string, future_certificate_ptr> cache;
+	static std::mutex cacheMutex;
+
+	std::lock_guard lock(cacheMutex);
+
+	if (auto it = cache.find(commonName); it != cache.end())
+		return it->second;
+
+	auto future = std::async(make_certificate_impl, commonName);
+	auto shared = future.share();
+	cache.emplace(std::move(commonName), shared);
+	return shared;
+}
+
+} // namespace rtc

+ 5 - 1
src/certificate.hpp

@@ -21,6 +21,7 @@
 
 #include "include.hpp"
 
+#include <future>
 #include <tuple>
 
 #if USE_GNUTLS
@@ -62,7 +63,10 @@ string make_fingerprint(gnutls_x509_crt_t crt);
 string make_fingerprint(X509 *x509);
 #endif
 
-std::shared_ptr<Certificate> make_certificate(const string &commonName);
+using certificate_ptr = std::shared_ptr<Certificate>;
+using future_certificate_ptr = std::shared_future<certificate_ptr>;
+
+future_certificate_ptr make_certificate(string commonName);
 
 } // namespace rtc
 

+ 2 - 3
src/dtlstransport.cpp

@@ -63,9 +63,8 @@ void DtlsTransport::Cleanup() {
 	// Nothing to do
 }
 
-DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
-                             verifier_callback verifierCallback,
-                             state_callback stateChangeCallback)
+DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr certificate,
+                             verifier_callback verifierCallback, state_callback stateChangeCallback)
     : Transport(lower), mCertificate(certificate), mState(State::Disconnected),
       mVerifierCallback(std::move(verifierCallback)),
       mStateChangeCallback(std::move(stateChangeCallback)) {

+ 2 - 2
src/dtlstransport.hpp

@@ -51,7 +51,7 @@ public:
 	using verifier_callback = std::function<bool(const std::string &fingerprint)>;
 	using state_callback = std::function<void(State state)>;
 
-	DtlsTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate,
+	DtlsTransport(std::shared_ptr<IceTransport> lower, certificate_ptr certificate,
 	              verifier_callback verifierCallback, state_callback stateChangeCallback);
 	~DtlsTransport();
 
@@ -65,7 +65,7 @@ private:
 	void changeState(State state);
 	void runRecvLoop();
 
-	const std::shared_ptr<Certificate> mCertificate;
+	const certificate_ptr mCertificate;
 
 	Queue<message_ptr> mIncomingQueue;
 	std::atomic<State> mState;

+ 5 - 2
src/peerconnection.cpp

@@ -269,9 +269,10 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 		if (auto transport = std::atomic_load(&mDtlsTransport))
 			return transport;
 
+		auto certificate = mCertificate.get();
 		auto lower = std::atomic_load(&mIceTransport);
 		auto transport = std::make_shared<DtlsTransport>(
-		    lower, mCertificate, weak_bind_verifier(&PeerConnection::checkFingerprint, this, _1),
+		    lower, certificate, weak_bind_verifier(&PeerConnection::checkFingerprint, this, _1),
 		    [this, weak_this = weak_from_this()](DtlsTransport::State state) {
 			    auto shared_this = weak_this.lock();
 			    if (!shared_this)
@@ -513,9 +514,11 @@ void PeerConnection::processLocalDescription(Description description) {
 	if (auto remote = remoteDescription())
 		remoteSctpPort = remote->sctpPort();
 
+	auto certificate = mCertificate.get(); // wait for certificate if not ready
+
 	std::lock_guard lock(mLocalDescriptionMutex);
 	mLocalDescription.emplace(std::move(description));
-	mLocalDescription->setFingerprint(mCertificate->fingerprint());
+	mLocalDescription->setFingerprint(certificate->fingerprint());
 	mLocalDescription->setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT));
 	mLocalDescription->setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE);