Browse Source

Merge pull request #79 from paullouisageneau/async-certificate

Asynchronous certificate generation
Paul-Louis Ageneau 5 years ago
parent
commit
c18b1738b0
8 changed files with 93 additions and 45 deletions
  1. 18 15
      CMakeLists.txt
  2. 5 1
      include/rtc/peerconnection.hpp
  3. 52 21
      src/certificate.cpp
  4. 7 1
      src/certificate.hpp
  5. 2 3
      src/dtlstransport.cpp
  6. 2 2
      src/dtlstransport.hpp
  7. 2 0
      src/init.cpp
  8. 5 2
      src/peerconnection.cpp

+ 18 - 15
CMakeLists.txt

@@ -71,7 +71,8 @@ set(TESTS_ANSWERER_SOURCES
     ${CMAKE_CURRENT_SOURCE_DIR}/test/p2p/answerer.cpp
 )
 
-set(THREADS_PREFER_PTHREAD_FLAG ON)
+set(CMAKE_THREAD_PREFER_PTHREAD TRUE)
+set(THREADS_PREFER_PTHREAD_FLAG TRUE)
 find_package(Threads REQUIRED)
 
 add_subdirectory(deps/usrsctp EXCLUDE_FROM_ALL)
@@ -92,10 +93,11 @@ set_target_properties(datachannel PROPERTIES
 	CXX_STANDARD 17)
 
 target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
+target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include)
 target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
 target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
-target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include)
-target_link_libraries(datachannel Threads::Threads Usrsctp::UsrsctpStatic)
+target_link_libraries(datachannel PUBLIC Threads::Threads)
+target_link_libraries(datachannel PRIVATE Usrsctp::UsrsctpStatic)
 
 add_library(datachannel-static STATIC EXCLUDE_FROM_ALL ${LIBDATACHANNEL_SOURCES})
 set_target_properties(datachannel-static PROPERTIES
@@ -103,14 +105,15 @@ set_target_properties(datachannel-static PROPERTIES
 	CXX_STANDARD 17)
 
 target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
+target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include)
 target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
 target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
-target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include)
-target_link_libraries(datachannel-static Threads::Threads Usrsctp::UsrsctpStatic)
+target_link_libraries(datachannel-static PUBLIC Threads::Threads)
+target_link_libraries(datachannel-static PRIVATE Usrsctp::UsrsctpStatic)
 
 if(WIN32)
-	target_link_libraries(datachannel "wsock32" "ws2_32") # winsock2
-	target_link_libraries(datachannel-static "wsock32" "ws2_32") # winsock2
+	target_link_libraries(datachannel PRIVATE wsock32 ws2_32) # winsock2
+	target_link_libraries(datachannel-static PRIVATE wsock32 ws2_32) # winsock2
 endif()
 
 if (USE_GNUTLS)
@@ -124,29 +127,29 @@ if (USE_GNUTLS)
 			IMPORTED_LOCATION "${GNUTLS_LIBRARIES}")
 	endif()
 	target_compile_definitions(datachannel PRIVATE USE_GNUTLS=1)
-	target_link_libraries(datachannel GnuTLS::GnuTLS)
+	target_link_libraries(datachannel PRIVATE GnuTLS::GnuTLS)
 	target_compile_definitions(datachannel-static PRIVATE USE_GNUTLS=1)
-	target_link_libraries(datachannel-static GnuTLS::GnuTLS)
+	target_link_libraries(datachannel-static PRIVATE GnuTLS::GnuTLS)
 else()
 	find_package(OpenSSL REQUIRED)
 	target_compile_definitions(datachannel PRIVATE USE_GNUTLS=0)
-	target_link_libraries(datachannel OpenSSL::SSL)
+	target_link_libraries(datachannel PRIVATE OpenSSL::SSL)
 	target_compile_definitions(datachannel-static PRIVATE USE_GNUTLS=0)
-	target_link_libraries(datachannel-static OpenSSL::SSL)
+	target_link_libraries(datachannel-static PRIVATE OpenSSL::SSL)
 endif()
 
 if (USE_JUICE)
 	add_subdirectory(deps/libjuice EXCLUDE_FROM_ALL)
 	target_compile_definitions(datachannel PRIVATE USE_JUICE=1)
-	target_link_libraries(datachannel LibJuice::LibJuiceStatic)
+	target_link_libraries(datachannel PRIVATE LibJuice::LibJuiceStatic)
 	target_compile_definitions(datachannel-static PRIVATE USE_JUICE=1)
-	target_link_libraries(datachannel-static LibJuice::LibJuiceStatic)
+	target_link_libraries(datachannel-static PRIVATE LibJuice::LibJuiceStatic)
 else()
 	find_package(LibNice REQUIRED)
 	target_compile_definitions(datachannel PRIVATE USE_JUICE=0)
-	target_link_libraries(datachannel LibNice::LibNice)
+	target_link_libraries(datachannel PRIVATE LibNice::LibNice)
 	target_compile_definitions(datachannel-static PRIVATE USE_JUICE=0)
-	target_link_libraries(datachannel-static LibNice::LibNice)
+	target_link_libraries(datachannel-static PRIVATE LibNice::LibNice)
 endif()
 
 add_library(LibDataChannel::LibDataChannel ALIAS datachannel)

+ 5 - 1
include/rtc/peerconnection.hpp

@@ -31,6 +31,7 @@
 
 #include <atomic>
 #include <functional>
+#include <future>
 #include <list>
 #include <mutex>
 #include <shared_mutex>
@@ -44,6 +45,9 @@ class IceTransport;
 class DtlsTransport;
 class SctpTransport;
 
+using certificate_ptr = std::shared_ptr<Certificate>;
+using future_certificate_ptr = std::shared_future<certificate_ptr>;
+
 class PeerConnection : public std::enable_shared_from_this<PeerConnection> {
 public:
 	enum class State : int {
@@ -126,7 +130,7 @@ private:
 	void resetCallbacks();
 
 	const Configuration mConfig;
-	const std::shared_ptr<Certificate> mCertificate;
+	const future_certificate_ptr mCertificate;
 
 	std::optional<Description> mLocalDescription, mRemoteDescription;
 	mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex;

+ 52 - 21
src/certificate.cpp

@@ -141,14 +141,9 @@ string make_fingerprint(gnutls_x509_crt_t crt) {
 	return oss.str();
 }
 
-shared_ptr<Certificate> make_certificate(const string &commonName) {
-	static std::unordered_map<string, shared_ptr<Certificate>> cache;
-	static std::mutex cacheMutex;
-
-	std::lock_guard lock(cacheMutex);
-	if (auto it = cache.find(commonName); it != cache.end())
-		return it->second;
+namespace {
 
+certificate_ptr make_certificate_impl(string commonName) {
 	std::unique_ptr<gnutls_x509_crt_t, decltype(&delete_crt)> crt(create_crt(), delete_crt);
 	std::unique_ptr<gnutls_x509_privkey_t, decltype(&delete_privkey)> privkey(create_privkey(),
 	                                                                          delete_privkey);
@@ -174,11 +169,11 @@ shared_ptr<Certificate> make_certificate(const string &commonName) {
 	check_gnutls(gnutls_x509_crt_sign2(*crt, *crt, *privkey, GNUTLS_DIG_SHA256, 0),
 	             "Unable to auto-sign certificate");
 
-	auto certificate = std::make_shared<Certificate>(*crt, *privkey);
-	cache.emplace(std::make_pair(commonName, certificate));
-	return certificate;
+	return std::make_shared<Certificate>(*crt, *privkey);
 }
 
+} // namespace
+
 } // namespace rtc
 
 #else
@@ -236,15 +231,9 @@ string make_fingerprint(X509 *x509) {
 	return oss.str();
 }
 
+namespace {
 
-shared_ptr<Certificate> make_certificate(const string &commonName) {
-	static std::unordered_map<string, shared_ptr<Certificate>> cache;
-	static std::mutex cacheMutex;
-
-	std::lock_guard lock(cacheMutex);
-	if (auto it = cache.find(commonName); it != cache.end())
-		return it->second;
-
+certificate_ptr make_certificate_impl(string commonName) {
 	shared_ptr<X509> x509(X509_new(), X509_free);
 	shared_ptr<EVP_PKEY> pkey(EVP_PKEY_new(), EVP_PKEY_free);
 
@@ -281,12 +270,54 @@ shared_ptr<Certificate> make_certificate(const string &commonName) {
 	if (!X509_sign(x509.get(), pkey.get(), EVP_sha256()))
 		throw std::runtime_error("Unable to auto-sign certificate");
 
-	auto certificate = std::make_shared<Certificate>(x509, pkey);
-	cache.emplace(std::make_pair(commonName, certificate));
-	return certificate;
+	return std::make_shared<Certificate>(x509, pkey);
 }
 
+} // namespace
+
 } // namespace rtc
 
 #endif
 
+// Common for GnuTLS and OpenSSL
+
+namespace rtc {
+
+namespace {
+
+// Helper function roughly equivalent to std::async with policy std::launch::async
+// since std::async might be unreliable on some platforms (e.g. Mingw32 on Windows)
+template <class F, class... Args>
+std::future<std::result_of_t<std::decay_t<F>(std::decay_t<Args>...)>> thread_call(F &&f,
+                                                                                  Args &&... args) {
+	using R = std::result_of_t<std::decay_t<F>(std::decay_t<Args>...)>;
+	std::packaged_task<R()> task(std::bind(f, std::forward<Args>(args)...));
+	std::future<R> future = task.get_future();
+	std::thread t(std::move(task));
+	t.detach();
+	return future;
+}
+
+static std::unordered_map<string, future_certificate_ptr> CertificateCache;
+static std::mutex CertificateCacheMutex;
+
+} // namespace
+
+future_certificate_ptr make_certificate(string commonName) {
+	std::lock_guard lock(CertificateCacheMutex);
+
+	if (auto it = CertificateCache.find(commonName); it != CertificateCache.end())
+		return it->second;
+
+	auto future = thread_call(make_certificate_impl, commonName);
+	auto shared = future.share();
+	CertificateCache.emplace(std::move(commonName), shared);
+	return shared;
+}
+
+void CleanupCertificateCache() {
+	std::lock_guard lock(CertificateCacheMutex);
+	CertificateCache.clear();
+}
+
+} // namespace rtc

+ 7 - 1
src/certificate.hpp

@@ -21,6 +21,7 @@
 
 #include "include.hpp"
 
+#include <future>
 #include <tuple>
 
 #if USE_GNUTLS
@@ -62,7 +63,12 @@ string make_fingerprint(gnutls_x509_crt_t crt);
 string make_fingerprint(X509 *x509);
 #endif
 
-std::shared_ptr<Certificate> make_certificate(const string &commonName);
+using certificate_ptr = std::shared_ptr<Certificate>;
+using future_certificate_ptr = std::shared_future<certificate_ptr>;
+
+future_certificate_ptr make_certificate(string commonName); // cached
+
+void CleanupCertificateCache();
 
 } // namespace rtc
 

+ 2 - 3
src/dtlstransport.cpp

@@ -63,9 +63,8 @@ void DtlsTransport::Cleanup() {
 	// Nothing to do
 }
 
-DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
-                             verifier_callback verifierCallback,
-                             state_callback stateChangeCallback)
+DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr certificate,
+                             verifier_callback verifierCallback, state_callback stateChangeCallback)
     : Transport(lower), mCertificate(certificate), mState(State::Disconnected),
       mVerifierCallback(std::move(verifierCallback)),
       mStateChangeCallback(std::move(stateChangeCallback)) {

+ 2 - 2
src/dtlstransport.hpp

@@ -51,7 +51,7 @@ public:
 	using verifier_callback = std::function<bool(const std::string &fingerprint)>;
 	using state_callback = std::function<void(State state)>;
 
-	DtlsTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate,
+	DtlsTransport(std::shared_ptr<IceTransport> lower, certificate_ptr certificate,
 	              verifier_callback verifierCallback, state_callback stateChangeCallback);
 	~DtlsTransport();
 
@@ -65,7 +65,7 @@ private:
 	void changeState(State state);
 	void runRecvLoop();
 
-	const std::shared_ptr<Certificate> mCertificate;
+	const certificate_ptr mCertificate;
 
 	Queue<message_ptr> mIncomingQueue;
 	std::atomic<State> mState;

+ 2 - 0
src/init.cpp

@@ -18,6 +18,7 @@
 
 #include "init.hpp"
 
+#include "certificate.hpp"
 #include "dtlstransport.hpp"
 #include "sctptransport.hpp"
 
@@ -74,6 +75,7 @@ Init::Init() {
 }
 
 Init::~Init() {
+	CleanupCertificateCache();
 	DtlsTransport::Cleanup();
 	SctpTransport::Cleanup();
 

+ 5 - 2
src/peerconnection.cpp

@@ -269,9 +269,10 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 		if (auto transport = std::atomic_load(&mDtlsTransport))
 			return transport;
 
+		auto certificate = mCertificate.get();
 		auto lower = std::atomic_load(&mIceTransport);
 		auto transport = std::make_shared<DtlsTransport>(
-		    lower, mCertificate, weak_bind_verifier(&PeerConnection::checkFingerprint, this, _1),
+		    lower, certificate, weak_bind_verifier(&PeerConnection::checkFingerprint, this, _1),
 		    [this, weak_this = weak_from_this()](DtlsTransport::State state) {
 			    auto shared_this = weak_this.lock();
 			    if (!shared_this)
@@ -513,9 +514,11 @@ void PeerConnection::processLocalDescription(Description description) {
 	if (auto remote = remoteDescription())
 		remoteSctpPort = remote->sctpPort();
 
+	auto certificate = mCertificate.get(); // wait for certificate if not ready
+
 	std::lock_guard lock(mLocalDescriptionMutex);
 	mLocalDescription.emplace(std::move(description));
-	mLocalDescription->setFingerprint(mCertificate->fingerprint());
+	mLocalDescription->setFingerprint(certificate->fingerprint());
 	mLocalDescription->setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT));
 	mLocalDescription->setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE);