Browse Source

Introduced VerifiedTlsTransport

Paul-Louis Ageneau 5 years ago
parent
commit
907e8273c8

+ 1 - 0
CMakeLists.txt

@@ -59,6 +59,7 @@ set(LIBDATACHANNEL_WEBSOCKET_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/verifiedtlstransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.cpp
 )

+ 6 - 1
include/rtc/websocket.hpp

@@ -47,7 +47,11 @@ public:
 		Closed = 3,
 	};
 
-	WebSocket();
+	struct Configuration {
+		bool disableTlsVerification = false;
+	};
+
+	WebSocket(std::optional<Configuration> config = nullopt);
 	~WebSocket();
 
 	State readyState() const;
@@ -82,6 +86,7 @@ private:
 	std::shared_ptr<WsTransport> mWsTransport;
 	std::recursive_mutex mInitMutex;
 
+	const Configuration mConfig;
 	string mScheme, mHost, mHostname, mService, mPath;
 	std::atomic<State> mState = State::Closed;
 

+ 28 - 16
src/tlstransport.cpp

@@ -56,7 +56,6 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 	try {
 		gnutls::check(gnutls_certificate_set_x509_system_trust(mCreds));
 		gnutls::check(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCreds));
-		gnutls_session_set_verify_cert(mSession, mHost.c_str(), 0);
 
 		const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
 		const char *err_pos = NULL;
@@ -72,7 +71,9 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 		gnutls_transport_set_pull_function(mSession, ReadCallback);
 		gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
 
-       	mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
+		postCreation();
+
+		mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
 		registerIncoming();
 
 	} catch (...) {
@@ -123,6 +124,14 @@ void TlsTransport::incoming(message_ptr message) {
 		mIncomingQueue.stop();
 }
 
+void TlsTransport::postCreation() {
+	// Dummy
+}
+
+void TlsTransport::postHandshake() {
+	// Dummy
+}
+
 void TlsTransport::runRecvLoop() {
 	const size_t bufferSize = 4096;
 	char buffer[bufferSize];
@@ -147,6 +156,7 @@ void TlsTransport::runRecvLoop() {
 	try {
 		PLOG_INFO << "TLS handshake finished";
 		changeState(State::Connected);
+		postHandshake();
 
 		while (true) {
 			ssize_t ret;
@@ -263,25 +273,16 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 		openssl::check(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
 		               "Failed to set SSL priorities");
 
+		if (!SSL_CTX_set_default_verify_paths(mCtx)) {
+			PLOG_WARNING << "SSL root CA certificates unavailable";
+		}
+
 		SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
 		SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION);
 		SSL_CTX_set_read_ahead(mCtx, 1);
 		SSL_CTX_set_quiet_shutdown(mCtx, 1);
 		SSL_CTX_set_info_callback(mCtx, InfoCallback);
-
-		// SSL_CTX_set_default_verify_paths() does nothing on Windows
-#ifndef _WIN32
-		if (SSL_CTX_set_default_verify_paths(mCtx)) {
-#else
-		if (false) {
-#endif
-			PLOG_INFO << "SSL root CA certificates available, server verification enabled";
-			SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER, NULL);
-			SSL_CTX_set_verify_depth(mCtx, 4);
-		} else {
-			PLOG_WARNING << "SSL root CA certificates unavailable, server verification disabled";
-			SSL_CTX_set_verify(mCtx, SSL_VERIFY_NONE, NULL);
-		}
+		SSL_CTX_set_verify(mCtx, SSL_VERIFY_NONE, NULL);
 
 		if (!(mSsl = SSL_new(mCtx)))
 			throw std::runtime_error("Failed to create SSL instance");
@@ -308,6 +309,8 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 		SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
 		SSL_set_tmp_ecdh(mSsl, ecdh.get());
 
+		postCreation();
+
 		mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
 		registerIncoming();
 
@@ -366,6 +369,14 @@ void TlsTransport::incoming(message_ptr message) {
 		mIncomingQueue.stop();
 }
 
+void TlsTransport::postCreation() {
+	// Dummy
+}
+
+void TlsTransport::postHandshake() {
+	// Dummy
+}
+
 void TlsTransport::runRecvLoop() {
 	const size_t bufferSize = 4096;
 	byte buffer[bufferSize];
@@ -387,6 +398,7 @@ void TlsTransport::runRecvLoop() {
 				if (SSL_is_init_finished(mSsl)) {
 					PLOG_INFO << "TLS handshake finished";
 					changeState(State::Connected);
+					postHandshake();
 				}
 			} else {
 				int ret = SSL_read(mSsl, buffer, bufferSize);

+ 4 - 3
src/tlstransport.hpp

@@ -38,14 +38,15 @@ public:
 	static void Cleanup();
 
 	TlsTransport(std::shared_ptr<TcpTransport> lower, string host, state_callback callback);
-	~TlsTransport();
+	virtual ~TlsTransport();
 
 	bool stop() override;
 	bool send(message_ptr message) override;
 
-	void incoming(message_ptr message) override;
-
 protected:
+	virtual void incoming(message_ptr message) override;
+	virtual void postCreation();
+	virtual void postHandshake();
 	void runRecvLoop();
 
 	string mHost;

+ 68 - 0
src/verifiedtlstransport.cpp

@@ -0,0 +1,68 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "verifiedtlstransport.hpp"
+#include "include.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+using std::shared_ptr;
+using std::string;
+using std::unique_ptr;
+using std::weak_ptr;
+
+namespace rtc {
+
+#if USE_GNUTLS
+
+VerifiedTlsTransport::VerifiedTlsTransport(shared_ptr<TcpTransport> lower, string host,
+                                           state_callback callback)
+    : TlsTransport(std::move(lower), std::move(host), std::move(callback)) {}
+
+VerifiedTlsTransport::~VerifiedTlsTransport() {}
+
+void VerifiedTlsTransport::postCreation() {
+	gnutls_session_set_verify_cert(mSession, mHost.c_str(), 0);
+}
+
+void VerifiedTlsTransport::postHandshake() {
+	// Nothing to do
+}
+
+#else // USE_GNUTLS==0
+
+VerifiedTlsTransport::VerifiedTlsTransport(shared_ptr<TcpTransport> lower, string host,
+                                           state_callback callback)
+    : TlsTransport(std::move(lower), std::move(host), std::move(callback)) {}
+
+VerifiedTlsTransport::~VerifiedTlsTransport() {}
+
+void VerifiedTlsTransport::postCreation() {
+	SSL_set_verify(mSsl, SSL_VERIFY_PEER, NULL);
+	SSL_set_verify_depth(mSsl, 4);
+}
+
+void VerifiedTlsTransport::postHandshake() {
+	// Nothing to do
+}
+
+#endif
+
+} // namespace rtc
+
+#endif

+ 42 - 0
src/verifiedtlstransport.hpp

@@ -0,0 +1,42 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_VERIFIED_TLS_TRANSPORT_H
+#define RTC_VERIFIED_TLS_TRANSPORT_H
+
+#include "tlstransport.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+namespace rtc {
+
+class VerifiedTlsTransport final : public TlsTransport {
+public:
+	VerifiedTlsTransport(std::shared_ptr<TcpTransport> lower, string host, state_callback callback);
+	~VerifiedTlsTransport();
+
+protected:
+	void postCreation() override;
+	void postHandshake() override;
+};
+
+} // namespace rtc
+
+#endif
+
+#endif

+ 45 - 26
src/websocket.cpp

@@ -24,6 +24,7 @@
 
 #include "tcptransport.hpp"
 #include "tlstransport.hpp"
+#include "verifiedtlstransport.hpp"
 #include "wstransport.hpp"
 
 #include <regex>
@@ -34,7 +35,12 @@
 
 namespace rtc {
 
-WebSocket::WebSocket() { PLOG_VERBOSE << "Creating WebSocket"; }
+using std::shared_ptr;
+
+WebSocket::WebSocket(std::optional<Configuration> config)
+    : mConfig(config ? std::move(*config) : Configuration()) {
+	PLOG_VERBOSE << "Creating WebSocket";
+}
 
 WebSocket::~WebSocket() {
 	PLOG_VERBOSE << "Destroying WebSocket";
@@ -149,7 +155,7 @@ void WebSocket::incoming(message_ptr message) {
 	}
 }
 
-std::shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
+shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
 	using State = TcpTransport::State;
 	try {
 		std::lock_guard lock(mInitMutex);
@@ -194,7 +200,7 @@ std::shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
 	}
 }
 
-std::shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
+shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 	using State = TlsTransport::State;
 	try {
 		std::lock_guard lock(mInitMutex);
@@ -202,27 +208,40 @@ std::shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 			return transport;
 
 		auto lower = std::atomic_load(&mTcpTransport);
-		auto transport = std::make_shared<TlsTransport>(
-		    lower, mHost, [this, weak_this = weak_from_this()](State state) {
-			    auto shared_this = weak_this.lock();
-			    if (!shared_this)
-				    return;
-			    switch (state) {
-			    case State::Connected:
-				    initWsTransport();
-				    break;
-			    case State::Failed:
-				    triggerError("TCP connection failed");
-				    remoteClose();
-				    break;
-			    case State::Disconnected:
-				    remoteClose();
-				    break;
-			    default:
-				    // Ignore
-				    break;
-			    }
-		    });
+		auto stateChangeCallback = [this, weak_this = weak_from_this()](State state) {
+			auto shared_this = weak_this.lock();
+			if (!shared_this)
+				return;
+			switch (state) {
+			case State::Connected:
+				initWsTransport();
+				break;
+			case State::Failed:
+				triggerError("TCP connection failed");
+				remoteClose();
+				break;
+			case State::Disconnected:
+				remoteClose();
+				break;
+			default:
+				// Ignore
+				break;
+			}
+		};
+
+		shared_ptr<TlsTransport> transport;
+#ifdef _WIN32
+		if (!mConfig.disableTlsVerification) {
+			PLOG_WARNING << "TLS certificate verification with root CA is not supported on Windows";
+		}
+		transport = std::make_shared<TlsTransport>(lower, mHost, stateChangeCallback);
+#else
+		if (mConfig.disableTlsVerification)
+			transport = std::make_shared<TlsTransport>(lower, mHost, stateChangeCallback);
+		else
+			transport = std::make_shared<VerifiedTlsTransport>(lower, mHost, stateChangeCallback);
+#endif
+
 		std::atomic_store(&mTlsTransport, transport);
 		if (mState == WebSocket::State::Closed) {
 			mTlsTransport.reset();
@@ -237,14 +256,14 @@ std::shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 	}
 }
 
-std::shared_ptr<WsTransport> WebSocket::initWsTransport() {
+shared_ptr<WsTransport> WebSocket::initWsTransport() {
 	using State = WsTransport::State;
 	try {
 		std::lock_guard lock(mInitMutex);
 		if (auto transport = std::atomic_load(&mWsTransport))
 			return transport;
 
-		std::shared_ptr<Transport> lower = std::atomic_load(&mTlsTransport);
+		shared_ptr<Transport> lower = std::atomic_load(&mTlsTransport);
 		if (!lower)
 			lower = std::atomic_load(&mTcpTransport);
 		auto transport = std::make_shared<WsTransport>(