2
0
Эх сурвалжийг харах

Implemented TLS certificate support for WebSocketServer

Paul-Louis Ageneau 4 жил өмнө
parent
commit
bf8b168783

+ 4 - 2
include/rtc/websocketserver.hpp

@@ -21,9 +21,7 @@
 
 #if RTC_ENABLE_WEBSOCKET
 
-#include "channel.hpp"
 #include "common.hpp"
-#include "message.hpp"
 #include "websocket.hpp"
 
 namespace rtc {
@@ -38,6 +36,10 @@ class RTC_CPP_EXPORT WebSocketServer final : private CheshireCat<impl::WebSocket
 public:
 	struct Configuration {
 		uint16_t port = 8080;
+		bool secure = false;
+		optional<string> certificatePemFile;
+		optional<string> keyPemFile;
+		optional<string> keyPemPass;
 	};
 
 	WebSocketServer();

+ 112 - 126
src/impl/certificate.cpp

@@ -28,89 +28,37 @@
 
 namespace rtc::impl {
 
-const string COMMON_NAME = "libdatachannel";
-
 #if USE_GNUTLS
 
 Certificate Certificate::FromString(string crt_pem, string key_pem) {
-	Certificate certificate;
+	PLOG_DEBUG << "Importing certificate from PEM string (GnuTLS)";
 
+	shared_ptr<gnutls_certificate_credentials_t> creds(gnutls::new_credentials(),
+	                                                   gnutls::free_credentials);
 	gnutls_datum_t crt_datum = gnutls::make_datum(crt_pem.data(), crt_pem.size());
 	gnutls_datum_t key_datum = gnutls::make_datum(key_pem.data(), key_pem.size());
-	gnutls::check(gnutls_certificate_set_x509_key_mem(*certificate.mCredentials, &crt_datum,
-	                                                  &key_datum, GNUTLS_X509_FMT_PEM),
-	              "Unable to import PEM certificate and key");
+	gnutls::check(
+	    gnutls_certificate_set_x509_key_mem(*creds, &crt_datum, &key_datum, GNUTLS_X509_FMT_PEM),
+	    "Unable to import PEM certificate and key");
 
-	certificate.computeFingerprint();
-	return certificate;
+	return Certificate(std::move(creds));
 }
 
 Certificate Certificate::FromFile(const string &crt_pem_file, const string &key_pem_file,
                                   const string &pass) {
-	Certificate certificate;
+	PLOG_DEBUG << "Importing certificate from PEM file (GnuTLS): " << crt_pem_file;
 
-	gnutls::check(gnutls_certificate_set_x509_key_file2(*certificate.mCredentials,
-	                                                    crt_pem_file.c_str(), key_pem_file.c_str(),
-	                                                    GNUTLS_X509_FMT_PEM, pass.c_str(), 0),
+	shared_ptr<gnutls_certificate_credentials_t> creds(gnutls::new_credentials(),
+	                                                   gnutls::free_credentials);
+	gnutls::check(gnutls_certificate_set_x509_key_file2(*creds, crt_pem_file.c_str(),
+	                                                    key_pem_file.c_str(), GNUTLS_X509_FMT_PEM,
+	                                                    pass.c_str(), 0),
 	              "Unable to import PEM certificate and key from file");
 
-	certificate.computeFingerprint();
-	return certificate;
-}
-
-Certificate::Certificate() : mCredentials(gnutls::new_credentials(), gnutls::free_credentials) {}
-
-Certificate::Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey)
-    : mCredentials(gnutls::new_credentials(), gnutls::free_credentials),
-      mFingerprint(make_fingerprint(crt)) {
-
-	gnutls::check(gnutls_certificate_set_x509_key(*mCredentials, &crt, 1, privkey),
-	              "Unable to set certificate and key pair in credentials");
-}
-
-gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; }
-
-string Certificate::fingerprint() const { return mFingerprint; }
-
-void Certificate::computeFingerprint() {
-	auto new_crt_list = [this]() -> gnutls_x509_crt_t * {
-		gnutls_x509_crt_t *crt_list = nullptr;
-		unsigned int crt_list_size = 0;
-		gnutls::check(gnutls_certificate_get_x509_crt(*mCredentials, 0, &crt_list, &crt_list_size));
-		assert(crt_list_size == 1);
-		return crt_list;
-	};
-
-	auto free_crt_list = [](gnutls_x509_crt_t *crt_list) {
-		gnutls_x509_crt_deinit(crt_list[0]);
-		gnutls_free(crt_list);
-	};
-
-	unique_ptr<gnutls_x509_crt_t, decltype(free_crt_list)> crt_list(new_crt_list(), free_crt_list);
-
-	mFingerprint = make_fingerprint(*crt_list);
-}
-
-string make_fingerprint(gnutls_x509_crt_t crt) {
-	const size_t size = 32;
-	unsigned char buffer[size];
-	size_t len = size;
-	gnutls::check(gnutls_x509_crt_get_fingerprint(crt, GNUTLS_DIG_SHA256, buffer, &len),
-	              "X509 fingerprint error");
-
-	std::ostringstream oss;
-	oss << std::hex << std::uppercase << std::setfill('0');
-	for (size_t i = 0; i < len; ++i) {
-		if (i)
-			oss << std::setw(1) << ':';
-		oss << std::setw(2) << unsigned(buffer[i]);
-	}
-	return oss.str();
+	return Certificate(std::move(creds));
 }
 
-namespace {
-
-certificate_ptr make_certificate_impl(CertificateType type) {
+Certificate Certificate::Generate(CertificateType type, const string &commonName) {
 	PLOG_DEBUG << "Generating certificate (GnuTLS)";
 
 	using namespace gnutls;
@@ -146,8 +94,8 @@ certificate_ptr make_certificate_impl(CertificateType type) {
 	gnutls_x509_crt_set_expiration_time(*crt, (now + hours(24 * 365)).time_since_epoch().count());
 	gnutls_x509_crt_set_version(*crt, 1);
 	gnutls_x509_crt_set_key(*crt, *privkey);
-	gnutls_x509_crt_set_dn_by_oid(*crt, GNUTLS_OID_X520_COMMON_NAME, 0, COMMON_NAME.data(),
-	                              COMMON_NAME.size());
+	gnutls_x509_crt_set_dn_by_oid(*crt, GNUTLS_OID_X520_COMMON_NAME, 0, commonName.data(),
+	                              commonName.size());
 
 	const size_t serialSize = 16;
 	char serial[serialSize];
@@ -157,10 +105,59 @@ certificate_ptr make_certificate_impl(CertificateType type) {
 	gnutls::check(gnutls_x509_crt_sign2(*crt, *crt, *privkey, GNUTLS_DIG_SHA256, 0),
 	              "Unable to auto-sign certificate");
 
-	return std::make_shared<Certificate>(*crt, *privkey);
+	return Certificate(*crt, *privkey);
 }
 
-} // namespace
+Certificate::Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey)
+    : mCredentials(gnutls::new_credentials(), gnutls::free_credentials),
+      mFingerprint(make_fingerprint(crt)) {
+
+	gnutls::check(gnutls_certificate_set_x509_key(*mCredentials, &crt, 1, privkey),
+	              "Unable to set certificate and key pair in credentials");
+}
+
+Certificate::Certificate(shared_ptr<gnutls_certificate_credentials_t> creds)
+    : mCredentials(std::move(creds)), mFingerprint(make_fingerprint(*mCredentials)) {}
+
+gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; }
+
+string Certificate::fingerprint() const { return mFingerprint; }
+
+string make_fingerprint(gnutls_certificate_credentials_t credentials) {
+	auto new_crt_list = [credentials]() -> gnutls_x509_crt_t * {
+		gnutls_x509_crt_t *crt_list = nullptr;
+		unsigned int crt_list_size = 0;
+		gnutls::check(gnutls_certificate_get_x509_crt(credentials, 0, &crt_list, &crt_list_size));
+		assert(crt_list_size == 1);
+		return crt_list;
+	};
+
+	auto free_crt_list = [](gnutls_x509_crt_t *crt_list) {
+		gnutls_x509_crt_deinit(crt_list[0]);
+		gnutls_free(crt_list);
+	};
+
+	unique_ptr<gnutls_x509_crt_t, decltype(free_crt_list)> crt_list(new_crt_list(), free_crt_list);
+
+	return make_fingerprint(*crt_list);
+}
+
+string make_fingerprint(gnutls_x509_crt_t crt) {
+	const size_t size = 32;
+	unsigned char buffer[size];
+	size_t len = size;
+	gnutls::check(gnutls_x509_crt_get_fingerprint(crt, GNUTLS_DIG_SHA256, buffer, &len),
+	              "X509 fingerprint error");
+
+	std::ostringstream oss;
+	oss << std::hex << std::uppercase << std::setfill('0');
+	for (size_t i = 0; i < len; ++i) {
+		if (i)
+			oss << std::setw(1) << ':';
+		oss << std::setw(2) << unsigned(buffer[i]);
+	}
+	return oss.str();
+}
 
 #else // USE_GNUTLS==0
 
@@ -177,91 +174,54 @@ int dummy_pass_cb(char *buf, int size, int /*rwflag*/, void *u) {
 } // namespace
 
 Certificate Certificate::FromString(string crt_pem, string key_pem) {
-	Certificate certificate;
+	PLOG_DEBUG << "Importing certificate from PEM string (OpenSSL)";
 
 	BIO *bio = BIO_new(BIO_s_mem());
 	BIO_write(bio, crt_pem.data(), int(crt_pem.size()));
-	certificate.mX509 =
-	    shared_ptr<X509>(PEM_read_bio_X509(bio, nullptr, nullptr, nullptr), X509_free);
+	auto x509 = shared_ptr<X509>(PEM_read_bio_X509(bio, nullptr, nullptr, nullptr), X509_free);
 	BIO_free(bio);
-	if (!certificate.mX509)
+	if (!x509)
 		throw std::invalid_argument("Unable to import PEM certificate");
 
 	bio = BIO_new(BIO_s_mem());
 	BIO_write(bio, key_pem.data(), int(key_pem.size()));
-	certificate.mPKey = shared_ptr<EVP_PKEY>(
-	    PEM_read_bio_PrivateKey(bio, nullptr, nullptr, nullptr), EVP_PKEY_free);
+	auto pkey = shared_ptr<EVP_PKEY>(PEM_read_bio_PrivateKey(bio, nullptr, nullptr, nullptr),
+	                                 EVP_PKEY_free);
 	BIO_free(bio);
-	if (!certificate.mPKey)
+	if (!pkey)
 		throw std::invalid_argument("Unable to import PEM key");
 
-	certificate.computeFingerprint();
-	return certificate;
+	return Certificate(x509, pkey);
 }
 
 Certificate Certificate::FromFile(const string &crt_pem_file, const string &key_pem_file,
                                   const string &pass) {
-	Certificate certificate;
+	PLOG_DEBUG << "Importing certificate from PEM file (OpenSSL): " << crt_pem_file;
 
 	FILE *file = fopen(crt_pem_file.c_str(), "r");
 	if (!file)
 		throw std::invalid_argument("Unable to open PEM certificate file");
 
-	certificate.mX509 = shared_ptr<X509>(PEM_read_X509(file, nullptr, nullptr, nullptr), X509_free);
+	auto x509 = shared_ptr<X509>(PEM_read_X509(file, nullptr, nullptr, nullptr), X509_free);
 	fclose(file);
-	if (!certificate.mX509)
+	if (!x509)
 		throw std::invalid_argument("Unable to import PEM certificate from file");
 
 	file = fopen(key_pem_file.c_str(), "r");
 	if (!file)
 		throw std::invalid_argument("Unable to open PEM key file");
 
-	certificate.mPKey = shared_ptr<EVP_PKEY>(
+	auto pkey = shared_ptr<EVP_PKEY>(
 	    PEM_read_PrivateKey(file, nullptr, dummy_pass_cb, const_cast<char *>(pass.c_str())),
 	    EVP_PKEY_free);
 	fclose(file);
-	if (!certificate.mPKey)
+	if (!pkey)
 		throw std::invalid_argument("Unable to import PEM key from file");
 
-	certificate.computeFingerprint();
-	return certificate;
-}
-
-Certificate::Certificate() {}
-
-Certificate::Certificate(shared_ptr<X509> x509, shared_ptr<EVP_PKEY> pkey)
-    : mX509(std::move(x509)), mPKey(std::move(pkey)) {
-	mFingerprint = make_fingerprint(mX509.get());
-}
-
-void Certificate::computeFingerprint() { mFingerprint = make_fingerprint(mX509.get()); }
-
-string Certificate::fingerprint() const { return mFingerprint; }
-
-std::tuple<X509 *, EVP_PKEY *> Certificate::credentials() const {
-	return {mX509.get(), mPKey.get()};
-}
-
-string make_fingerprint(X509 *x509) {
-	const size_t size = 32;
-	unsigned char buffer[size];
-	unsigned int len = size;
-	if (!X509_digest(x509, EVP_sha256(), buffer, &len))
-		throw std::runtime_error("X509 fingerprint error");
-
-	std::ostringstream oss;
-	oss << std::hex << std::uppercase << std::setfill('0');
-	for (size_t i = 0; i < len; ++i) {
-		if (i)
-			oss << std::setw(1) << ':';
-		oss << std::setw(2) << unsigned(buffer[i]);
-	}
-	return oss.str();
+	return Certificate(x509, pkey);
 }
 
-namespace {
-
-certificate_ptr make_certificate_impl(CertificateType type) {
+Certificate Certificate::Generate(CertificateType type, const string &commonName) {
 	PLOG_DEBUG << "Generating certificate (OpenSSL)";
 
 	shared_ptr<X509> x509(X509_new(), X509_free);
@@ -318,7 +278,7 @@ certificate_ptr make_certificate_impl(CertificateType type) {
 
 	const size_t serialSize = 16;
 	auto *commonNameBytes =
-	    reinterpret_cast<unsigned char *>(const_cast<char *>(COMMON_NAME.c_str()));
+	    reinterpret_cast<unsigned char *>(const_cast<char *>(commonName.c_str()));
 
 	if (!X509_set_pubkey(x509.get(), pkey.get()))
 		throw std::runtime_error("Unable to set certificate public key");
@@ -337,17 +297,43 @@ certificate_ptr make_certificate_impl(CertificateType type) {
 	if (!X509_sign(x509.get(), pkey.get(), EVP_sha256()))
 		throw std::runtime_error("Unable to auto-sign certificate");
 
-	return std::make_shared<Certificate>(x509, pkey);
+	return Certificate(x509, pkey);
 }
 
-} // namespace
+Certificate::Certificate(shared_ptr<X509> x509, shared_ptr<EVP_PKEY> pkey)
+    : mX509(std::move(x509)), mPKey(std::move(pkey)), mFingerprint(make_fingerprint(mX509.get())) {}
+
+string Certificate::fingerprint() const { return mFingerprint; }
+
+std::tuple<X509 *, EVP_PKEY *> Certificate::credentials() const {
+	return {mX509.get(), mPKey.get()};
+}
+
+string make_fingerprint(X509 *x509) {
+	const size_t size = 32;
+	unsigned char buffer[size];
+	unsigned int len = size;
+	if (!X509_digest(x509, EVP_sha256(), buffer, &len))
+		throw std::runtime_error("X509 fingerprint error");
+
+	std::ostringstream oss;
+	oss << std::hex << std::uppercase << std::setfill('0');
+	for (size_t i = 0; i < len; ++i) {
+		if (i)
+			oss << std::setw(1) << ':';
+		oss << std::setw(2) << unsigned(buffer[i]);
+	}
+	return oss.str();
+}
 
 #endif
 
 // Common for GnuTLS and OpenSSL
 
 future_certificate_ptr make_certificate(CertificateType type) {
-	return ThreadPool::Instance().enqueue(make_certificate_impl, type);
+	return ThreadPool::Instance().enqueue([type]() {
+		return std::make_shared<Certificate>(Certificate::Generate(type, "libdatachannel"));
+	});
 }
 
 } // namespace rtc::impl

+ 8 - 8
src/impl/certificate.hpp

@@ -32,7 +32,8 @@ class Certificate {
 public:
 	static Certificate FromString(string crt_pem, string key_pem);
 	static Certificate FromFile(const string &crt_pem_file, const string &key_pem_file,
-	                            const string &pass);
+	                            const string &pass = "");
+	static Certificate Generate(CertificateType type, const string &commonName);
 
 #if USE_GNUTLS
 	Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey);
@@ -45,20 +46,19 @@ public:
 	string fingerprint() const;
 
 private:
-	Certificate();
-	void computeFingerprint();
-
 #if USE_GNUTLS
-	shared_ptr<gnutls_certificate_credentials_t> mCredentials;
+	Certificate(shared_ptr<gnutls_certificate_credentials_t> creds);
+	const shared_ptr<gnutls_certificate_credentials_t> mCredentials;
 #else
-	shared_ptr<X509> mX509;
-	shared_ptr<EVP_PKEY> mPKey;
+	const shared_ptr<X509> mX509;
+	const shared_ptr<EVP_PKEY> mPKey;
 #endif
 
-	string mFingerprint;
+	const string mFingerprint;
 };
 
 #if USE_GNUTLS
+string make_fingerprint(gnutls_certificate_credentials_t credentials);
 string make_fingerprint(gnutls_x509_crt_t crt);
 #else
 string make_fingerprint(X509 *x509);

+ 2 - 2
src/impl/tlstransport.cpp

@@ -75,7 +75,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host
 		                                     certificate ? certificate->credentials()
 		                                                 : default_certificate_credentials()));
 
-		if (mHost) {
+		if (mIsClient && mHost) {
 			PLOG_VERBOSE << "Server Name Indication: " << *mHost;
 			gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, mHost->data(), mHost->size());
 		}
@@ -306,7 +306,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host
 
 		SSL_set_ex_data(mSsl, TransportExIndex, this);
 
-		if (mHost) {
+		if (mIsClient && mHost) {
 			SSL_set_hostflags(mSsl, 0);
 			openssl::check(SSL_set1_host(mSsl, mHost->c_str()), "Failed to set SSL host");
 

+ 7 - 9
src/impl/websocket.cpp

@@ -97,6 +97,7 @@ void WebSocket::open(const string &url) {
 	if (string query = m[15]; !query.empty())
 		path += "?" + query;
 
+	mHostname = hostname; // for TLS SNI
 	std::atomic_store(&mWsHandshake, std::make_shared<WsHandshake>(host, path, config.protocols));
 
 	changeState(State::Connecting);
@@ -256,9 +257,7 @@ shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 			}
 		};
 
-		auto handshake = getWsHandshake();
-		auto host = handshake ? make_optional(handshake->host()) : nullopt;
-		bool verify = host.has_value() && !config.disableTlsVerification;
+		bool verify = mHostname.has_value() && !config.disableTlsVerification;
 
 #ifdef _WIN32
 		if (std::exchange(verify, false)) {
@@ -267,11 +266,11 @@ shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 #else
 		shared_ptr<TlsTransport> transport;
 		if (verify)
-			transport = std::make_shared<VerifiedTlsTransport>(lower, host.value(), mCertificate,
+			transport = std::make_shared<VerifiedTlsTransport>(lower, mHostname.value(), mCertificate,
 			                                                   stateChangeCallback);
 		else
 			transport =
-			    std::make_shared<TlsTransport>(lower, host, mCertificate, stateChangeCallback);
+			    std::make_shared<TlsTransport>(lower, mHostname, mCertificate, stateChangeCallback);
 #endif
 
 		std::atomic_store(&mTlsTransport, transport);
@@ -311,7 +310,7 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 			lower = transport;
 		}
 
-		if(!atomic_load(&mWsHandshake))
+		if (!atomic_load(&mWsHandshake))
 			atomic_store(&mWsHandshake, std::make_shared<WsHandshake>());
 
 		auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
@@ -339,9 +338,8 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 			}
 		};
 
-		auto transport = std::make_shared<WsTransport>(lower, mWsHandshake,
-		                                          weak_bind(&WebSocket::incoming, this, _1),
-		                                          stateChangeCallback);
+		auto transport = std::make_shared<WsTransport>(
+		    lower, mWsHandshake, weak_bind(&WebSocket::incoming, this, _1), stateChangeCallback);
 
 		std::atomic_store(&mWsTransport, transport);
 		if (state == WebSocket::State::Closed) {

+ 2 - 0
src/impl/websocket.hpp

@@ -80,6 +80,8 @@ private:
 	const certificate_ptr mCertificate;
 	bool mIsSecure;
 
+	optional<string> mHostname; // for TLS SNI
+
 	shared_ptr<TcpTransport> mTcpTransport;
 	shared_ptr<TlsTransport> mTlsTransport;
 	shared_ptr<WsTransport> mWsTransport;

+ 15 - 0
src/impl/websocketserver.cpp

@@ -31,6 +31,21 @@ WebSocketServer::WebSocketServer(Configuration config_)
     : config(std::move(config_)), tcpServer(std::make_unique<TcpServer>(config.port)),
       mStopped(false) {
 	PLOG_VERBOSE << "Creating WebSocketServer";
+
+	if (config.secure) {
+		if (config.certificatePemFile && config.keyPemFile) {
+			mCertificate = std::make_shared<Certificate>(Certificate::FromFile(
+			    *config.certificatePemFile, *config.keyPemFile, config.keyPemPass.value_or("")));
+
+		} else if (!config.certificatePemFile && !config.keyPemFile) {
+			mCertificate = std::make_shared<Certificate>(
+			    Certificate::Generate(CertificateType::Default, "localhost"));
+		} else {
+			throw std::invalid_argument(
+			    "Either none or both certificate and key PEM files must be specified");
+		}
+	}
+
 	mThread = std::thread(&WebSocketServer::runLoop, this);
 }
 

+ 5 - 2
test/websocketserver.cpp

@@ -32,12 +32,15 @@ using namespace std;
 template <class T> weak_ptr<T> make_weak_ptr(shared_ptr<T> ptr) { return ptr; }
 
 void test_websocketserver() {
-	InitLogger(LogLevel::Debug);
+	InitLogger(LogLevel::Verbose);
 
 	const string myMessage = "Hello world from client";
 
 	WebSocketServer::Configuration serverConfig;
 	serverConfig.port = 48080;
+	serverConfig.secure = true;
+	// serverConfig.certificatePemFile = ...
+	// serverConfig.keyPemFile = ...
 	WebSocketServer server(std::move(serverConfig));
 
 	shared_ptr<WebSocket> client;
@@ -87,7 +90,7 @@ void test_websocketserver() {
 		}
 	});
 
-	ws.open("ws://localhost:48080/");
+	ws.open("wss://localhost:48080/");
 
 	int attempts = 10;
 	while ((!ws.isOpen() || !received) && attempts--)