Browse Source

Added TLS transport for OpenSSL

Paul-Louis Ageneau 5 years ago
parent
commit
2ce7138ab5
5 changed files with 240 additions and 27 deletions
  1. 7 9
      src/dtlstransport.cpp
  2. 12 2
      src/init.cpp
  3. 1 1
      src/tcptransport.hpp
  4. 216 12
      src/tlstransport.cpp
  5. 4 3
      src/tlstransport.hpp

+ 7 - 9
src/dtlstransport.cpp

@@ -81,7 +81,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 		const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
 		const char *err_pos = NULL;
 		check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
-		             "Unable to set TLS priorities");
+		             "Failed to set TLS priorities");
 
 		gnutls_certificate_set_verify_function(mCertificate->credentials(), CertificateCallback);
 		check_gnutls(
@@ -345,7 +345,7 @@ void DtlsTransport::Init() {
 	if (!BioMethods) {
 		BioMethods = BIO_meth_new(BIO_TYPE_BIO, "DTLS writer");
 		if (!BioMethods)
-			throw std::runtime_error("Unable to BIO methods for DTLS writer");
+			throw std::runtime_error("Failed to create BIO methods for DTLS writer");
 		BIO_meth_set_create(BioMethods, BioMethodNew);
 		BIO_meth_set_destroy(BioMethods, BioMethodFree);
 		BIO_meth_set_write(BioMethods, BioMethodWrite);
@@ -370,10 +370,10 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 
 	try {
 		if (!(mCtx = SSL_CTX_new(DTLS_method())))
-			throw std::runtime_error("Unable to create SSL context");
+			throw std::runtime_error("Failed to create SSL context");
 
 		check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
-		              "Unable to set SSL priorities");
+		              "Failed to set SSL priorities");
 
 		// RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
 		// Therefore, the DTLS layer MUST NOT use any compression algorithm.
@@ -394,7 +394,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 		check_openssl(SSL_CTX_check_private_key(mCtx), "SSL local private key check failed");
 
 		if (!(mSsl = SSL_new(mCtx)))
-			throw std::runtime_error("Unable to create SSL instance");
+			throw std::runtime_error("Failed to create SSL instance");
 
 		SSL_set_ex_data(mSsl, TransportExIndex, this);
 
@@ -404,7 +404,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 			SSL_set_accept_state(mSsl);
 
 		if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BioMethods)))
-			throw std::runtime_error("Unable to create BIO");
+			throw std::runtime_error("Failed to create BIO");
 
 		BIO_set_mem_eof_return(mInBio, BIO_EOF);
 		BIO_set_data(mOutBio, this);
@@ -454,9 +454,7 @@ bool DtlsTransport::send(message_ptr message) {
 	PLOG_VERBOSE << "Send size=" << message->size();
 
 	int ret = SSL_write(mSsl, message->data(), message->size());
-	if (!check_openssl_ret(mSsl, ret))
-		return false;
-	return true;
+	return check_openssl_ret(mSsl, ret);
 }
 
 void DtlsTransport::incoming(message_ptr message) {

+ 12 - 2
src/init.cpp

@@ -21,6 +21,10 @@
 #include "dtlstransport.hpp"
 #include "sctptransport.hpp"
 
+#if RTC_ENABLE_WEBSOCKET
+#include "tlstransport.hpp"
+#endif
+
 #ifdef _WIN32
 #include <winsock2.h>
 #endif
@@ -69,13 +73,19 @@ Init::Init() {
 	ERR_load_crypto_strings();
 #endif
 
-	DtlsTransport::Init();
 	SctpTransport::Init();
+	DtlsTransport::Init();
+#if RTC_ENABLE_WEBSOCKET
+	TlsTransport::Cleanup();
+#endif
 }
 
 Init::~Init() {
-	DtlsTransport::Cleanup();
 	SctpTransport::Cleanup();
+	DtlsTransport::Cleanup();
+#if RTC_ENABLE_WEBSOCKET
+	TlsTransport::Cleanup();
+#endif
 
 #ifdef _WIN32
 	WSACleanup();

+ 1 - 1
src/tcptransport.hpp

@@ -35,7 +35,7 @@ namespace rtc {
 class TcpTransport : public Transport {
 public:
 	TcpTransport(const string &hostname, const string &service);
-	virtual ~TcpTransport();
+	~TcpTransport();
 
 	bool stop() override;
 	bool send(message_ptr message) override;

+ 216 - 12
src/tlstransport.cpp

@@ -16,6 +16,8 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  */
 
+#if ENABLE_WEBSOCKET
+
 #include "tlstransport.hpp"
 #include "tcptransport.hpp"
 
@@ -51,8 +53,15 @@ static bool check_gnutls(int ret, const string &message = "GnuTLS error") {
 
 namespace rtc {
 
-TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, const string &host)
-    : Transport(lower), mHost(host) {
+void TlsTransport::Init() {
+	// Nothing to do
+}
+
+void TlsTransport::Cleanup() {
+	// Nothing to do
+}
+
+TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transport(lower) {
 
 	PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
 
@@ -62,7 +71,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, const string &host)
 		const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
 		const char *err_pos = NULL;
 		check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
-		             "Unable to set TLS priorities");
+		             "Failed to set TLS priorities");
 
 		gnutls_session_set_ptr(mSession, this);
 		gnutls_transport_set_ptr(mSession, this);
@@ -72,7 +81,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, const string &host)
 
 		gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size());
 
-		mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
+		mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
 
 	} catch (...) {
 
@@ -81,7 +90,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, const string &host)
 	}
 }
 
-TlsTransport::~DtlsTransport() {
+TlsTransport::~TlsTransport() {
 	stop();
 	gnutls_deinit(mSession);
 }
@@ -96,7 +105,7 @@ bool DtlsTransport::stop() {
 	return true;
 }
 
-bool DtlsTransport::send(message_ptr message) {
+bool TlsTransport::send(message_ptr message) {
 	if (!message)
 		return false;
 
@@ -108,7 +117,7 @@ bool DtlsTransport::send(message_ptr message) {
 	return check_gnutls(ret);
 }
 
-void DtlsTransport::incoming(message_ptr message) {
+void TlsTransport::incoming(message_ptr message) {
 	if (message)
 		mIncomingQueue.push(message);
 	else
@@ -128,7 +137,6 @@ void TlsTransport::runRecvLoop() {
 
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "TLS handshake: " << e.what();
-		changeState(State::Failed);
 		return;
 	}
 
@@ -169,7 +177,7 @@ void TlsTransport::runRecvLoop() {
 }
 
 ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) {
-	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
+	TlsTransport *t = static_cast<TlsTransport *>(ptr);
 	if (len > 0) {
 		auto b = reinterpret_cast<const byte *>(data);
 		t->outgoing(make_message(b, b + len));
@@ -179,7 +187,7 @@ ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data
 }
 
 ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
-	TlsTransport *t = static_cast<DtlsTransport *>(ptr);
+	TlsTransport *t = static_cast<TlsTransport *>(ptr);
 	if (auto next = t->mIncomingQueue.pop()) {
 		auto message = *next;
 		ssize_t len = std::min(maxlen, message->size());
@@ -193,7 +201,7 @@ ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_
 }
 
 int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
-	TlsTransport *t = static_cast<DtlsTransport *>(ptr);
+	TlsTransport *t = static_cast<TlsTransport *>(ptr);
 	if (ms != GNUTLS_INDEFINITE_TIMEOUT)
 		t->mIncomingQueue.wait(milliseconds(ms));
 	else
@@ -204,6 +212,202 @@ int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
 } // namespace rtc
 
 #else // USE_GNUTLS==0
-// TODO
+
+#include <openssl/bio.h>
+#include <openssl/ec.h>
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+
+namespace {
+
+const int BIO_EOF = -1;
+
+string openssl_error_string(unsigned long err) {
+	const size_t bufferSize = 256;
+	char buffer[bufferSize];
+	ERR_error_string_n(err, buffer, bufferSize);
+	return string(buffer);
+}
+
+bool check_openssl(int success, const string &message = "OpenSSL error") {
+	if (success)
+		return true;
+
+	string str = openssl_error_string(ERR_get_error());
+	PLOG_ERROR << message << ": " << str;
+	throw std::runtime_error(message + ": " + str);
+}
+
+bool check_openssl_ret(SSL *ssl, int ret, const string &message = "OpenSSL error") {
+	if (ret == BIO_EOF)
+		return true;
+
+	unsigned long err = SSL_get_error(ssl, ret);
+	if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
+		return true;
+	}
+	if (err == SSL_ERROR_ZERO_RETURN) {
+		PLOG_DEBUG << "TLS connection cleanly closed";
+		return false;
+	}
+	string str = openssl_error_string(err);
+	PLOG_ERROR << str;
+	throw std::runtime_error(message + ": " + str);
+}
+
+} // namespace
+
+namespace rtc {
+
+int TlsTransport::TransportExIndex = -1;
+
+void TlsTransport::Init() {
+	if (TransportExIndex < 0) {
+		TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
+	}
+}
+
+void TlsTransport::Cleanup() {
+	// Nothing to do
+}
+
+TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transport(lower) {
+
+	PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
+	GlobalInit();
+
+	if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
+		throw std::runtime_error("Failed to create SSL context");
+
+	check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
+	              "Failed to set SSL priorities");
+
+	SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
+	SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION);
+	SSL_CTX_set_read_ahead(mCtx, 1);
+	SSL_CTX_set_quiet_shutdown(mCtx, 1);
+	SSL_CTX_set_info_callback(mCtx, InfoCallback);
+
+	SSL_CTX_set_default_verify_paths(mCtx);
+	SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER, NULL);
+	SSL_CTX_set_verify_depth(mCtx, 4);
+
+	if (!(mSsl = SSL_new(mCtx)))
+		throw std::runtime_error("Failed to create SSL instance");
+
+	SSL_set_ex_data(mSsl, TransportExIndex, this);
+	SSL_set_tlsext_host_name(mSsl, host.c_str());
+
+	SSL_set_connect_state(mSsl);
+
+	if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(Bio_s_mem())))
+		throw std::runtime_error("Failed to create BIO");
+
+	BIO_set_mem_eof_return(mInBio, BIO_EOF);
+	BIO_set_mem_eof_return(mOutBio, BIO_EOF);
+	SSL_set_bio(mSsl, mInBio, mOutBio);
+
+	auto ecdh = unique_ptr<EC_KEY, decltype(&EC_KEY_free)>(
+	    EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
+	SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
+	SSL_set_tmp_ecdh(mSsl, ecdh.get());
+
+	mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
+}
+
+TlsTransport::~TlsTransport() {
+	stop();
+
+	SSL_free(mSsl);
+	SSL_CTX_free(mCtx);
+}
+
+bool TlsTransport::stop() {
+	if (!Transport::stop())
+		return false;
+
+	PLOG_DEBUG << "Stopping TLS recv thread";
+	mIncomingQueue.stop();
+	mRecvThread.join();
+	SSL_shutdown(mSsl);
+	return true;
+}
+
+bool TlsTransport::send(message_ptr message) {
+	if (!message)
+		return false;
+
+	int ret = SSL_write(mSsl, message->data(), message->size());
+	if(!check_openssl_ret(mSsl, ret)
+			return false;
+
+	while (int len = BIO_read(mOutBio, buffer, bufferSize); len > 0)
+		outgoing(make_message(buffer, buffer + len));
+
+	return true;
+}
+
+void TlsTransport::incoming(message_ptr message) {
+	if (message)
+		mIncomingQueue.push(message);
+	else
+		mIncomingQueue.stop();
+}
+
+void TlsTransport::runRecvLoop() {
+	const size_t bufferSize = 4096;
+
+	byte buffer[bufferSize];
+	bool initFinished = false;
+	try {
+		SSL_do_handshake(mSsl);
+		while (int len = BIO_read(mOutBio, buffer, bufferSize); len > 0)
+			outgoing(make_message(buffer, buffer + len));
+
+		while (auto next = mIncomingQueue.pop()) {
+			auto message = *next;
+			BIO_write(mInBio, message->data(), message->size());
+			int ret = SSL_read(mSsl, buffer, bufferSize);
+			if (!check_openssl_ret(mSsl, ret))
+				break;
+
+			auto received = ret > 0 ? make_message(buffer, buffer + ret) : nullptr;
+
+			while (int len = BIO_read(mOutBio, buffer, bufferSize); len > 0)
+				outgoing(make_message(buffer, buffer + len));
+
+			if (!initFinished && SSL_is_init_finished(mSsl))
+				initFinished = true;
+
+			if (received)
+				recv(received);
+		}
+	} catch (const std::exception &e) {
+		PLOG_ERROR << "TLS recv: " << e.what();
+	}
+
+	if (initFinished) {
+		PLOG_INFO << "TLS disconnected";
+		recv(nullptr);
+	} else {
+		PLOG_ERROR << "TLS handshake failed";
+	}
+}
+
+void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) {
+	TlsTransport *t =
+	    static_cast<TlsTransport *>(SSL_get_ex_data(ssl, TlsTransport::TransportExIndex));
+
+	if (where & SSL_CB_ALERT) {
+		if (ret != 256) // Close Notify
+			PLOG_ERROR << "TLS alert: " << SSL_alert_desc_string_long(ret);
+
+		t->mIncomingQueue.stop(); // Close the connection
+	}
+}
+
+} // namespace rtc
+
 #endif
 
+#endif

+ 4 - 3
src/tlstransport.hpp

@@ -41,14 +41,13 @@ class TcpTransport;
 
 class TlsTransport : public Transport {
 public:
-	TlsTransport(std::shared_ptr<TcpTransport> lower, const string &host);
-	virtual ~TlsTransport();
+	TlsTransport(std::shared_ptr<TcpTransport> lower, string host);
+	~TlsTransport();
 
 	bool stop() override;
 	bool send(message_ptr message) override;
 
 	void incoming(message_ptr message) override;
-	bool outgoing(message_ptr message) override;
 
 protected:
 	void runRecvLoop();
@@ -67,6 +66,8 @@ protected:
 	SSL *mSsl;
 	BIO *mInBio, *mOutBio;
 
+	static int TransportExIndex;
+
 	static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx);
 	static void InfoCallback(const SSL *ssl, int where, int ret);
 #endif