Browse Source

Added TLS transport for GnuTLS

Paul-Louis Ageneau 5 years ago
parent
commit
bc9df9ba72
3 changed files with 289 additions and 1 deletions
  1. 1 1
      src/tcptransport.hpp
  2. 209 0
      src/tlstransport.cpp
  3. 79 0
      src/tlstransport.hpp

+ 1 - 1
src/tcptransport.hpp

@@ -19,7 +19,7 @@
 #ifndef RTC_TCP_TRANSPORT_H
 #define RTC_TCP_TRANSPORT_H
 
-#ifdef ENABLE_WEBSOCKET
+#if ENABLE_WEBSOCKET
 
 #include "include.hpp"
 #include "queue.hpp"

+ 209 - 0
src/tlstransport.cpp

@@ -0,0 +1,209 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "tlstransport.hpp"
+#include "tcptransport.hpp"
+
+#include <chrono>
+#include <cstring>
+#include <exception>
+#include <iostream>
+
+using namespace std::chrono;
+
+using std::shared_ptr;
+using std::string;
+using std::unique_ptr;
+using std::weak_ptr;
+
+#if USE_GNUTLS
+
+namespace {
+
+static bool check_gnutls(int ret, const string &message = "GnuTLS error") {
+	if (ret < 0) {
+		if (!gnutls_error_is_fatal(ret)) {
+			PLOG_INFO << gnutls_strerror(ret);
+			return false;
+		}
+		PLOG_ERROR << message << ": " << gnutls_strerror(ret);
+		throw std::runtime_error(message + ": " + gnutls_strerror(ret));
+	}
+	return true;
+}
+
+} // namespace
+
+namespace rtc {
+
+TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, const string &host)
+    : Transport(lower), mHost(host) {
+
+	PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
+
+	check_gnutls(gnutls_init(&mSession, GNUTLS_CLIENT));
+
+	try {
+		const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
+		const char *err_pos = NULL;
+		check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
+		             "Unable to set TLS priorities");
+
+		gnutls_session_set_ptr(mSession, this);
+		gnutls_transport_set_ptr(mSession, this);
+		gnutls_transport_set_push_function(mSession, WriteCallback);
+		gnutls_transport_set_pull_function(mSession, ReadCallback);
+		gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
+
+		gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size());
+
+		mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
+
+	} catch (...) {
+
+		gnutls_deinit(mSession);
+		throw;
+	}
+}
+
+TlsTransport::~DtlsTransport() {
+	stop();
+	gnutls_deinit(mSession);
+}
+
+bool DtlsTransport::stop() {
+	if (!Transport::stop())
+		return false;
+
+	PLOG_DEBUG << "Stopping TLS recv thread";
+	mIncomingQueue.stop();
+	mRecvThread.join();
+	return true;
+}
+
+bool DtlsTransport::send(message_ptr message) {
+	if (!message)
+		return false;
+
+	ssize_t ret;
+	do {
+		ret = gnutls_record_send(mSession, message->data(), message->size());
+	} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
+
+	return check_gnutls(ret);
+}
+
+void DtlsTransport::incoming(message_ptr message) {
+	if (message)
+		mIncomingQueue.push(message);
+	else
+		mIncomingQueue.stop();
+}
+
+void TlsTransport::runRecvLoop() {
+	const size_t bufferSize = 4096;
+
+	// Handshake loop
+	try {
+		int ret;
+		do {
+			ret = gnutls_handshake(mSession);
+		} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
+		         !check_gnutls(ret, "TLS handshake failed"));
+
+	} catch (const std::exception &e) {
+		PLOG_ERROR << "TLS handshake: " << e.what();
+		changeState(State::Failed);
+		return;
+	}
+
+	// Receive loop
+	try {
+		while (true) {
+			char buffer[bufferSize];
+			ssize_t ret;
+			do {
+				ret = gnutls_record_recv(mSession, buffer, bufferSize);
+			} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
+
+			// Consider premature termination as remote closing
+			if (ret == GNUTLS_E_PREMATURE_TERMINATION) {
+				PLOG_DEBUG << "TLS connection terminated";
+				break;
+			}
+
+			if (check_gnutls(ret)) {
+				if (ret == 0) {
+					// Closed
+					PLOG_DEBUG << "TLS connection cleanly closed";
+					break;
+				}
+				auto *b = reinterpret_cast<byte *>(buffer);
+				recv(make_message(b, b + ret));
+			}
+		}
+
+	} catch (const std::exception &e) {
+		PLOG_ERROR << "TLS recv: " << e.what();
+	}
+
+	gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
+
+	PLOG_INFO << "TLS disconnected";
+	recv(nullptr);
+}
+
+ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) {
+	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
+	if (len > 0) {
+		auto b = reinterpret_cast<const byte *>(data);
+		t->outgoing(make_message(b, b + len));
+	}
+	gnutls_transport_set_errno(t->mSession, 0);
+	return ssize_t(len);
+}
+
+ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
+	TlsTransport *t = static_cast<DtlsTransport *>(ptr);
+	if (auto next = t->mIncomingQueue.pop()) {
+		auto message = *next;
+		ssize_t len = std::min(maxlen, message->size());
+		std::memcpy(data, message->data(), len);
+		gnutls_transport_set_errno(t->mSession, 0);
+		return len;
+	}
+	// Closed
+	gnutls_transport_set_errno(t->mSession, 0);
+	return 0;
+}
+
+int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
+	TlsTransport *t = static_cast<DtlsTransport *>(ptr);
+	if (ms != GNUTLS_INDEFINITE_TIMEOUT)
+		t->mIncomingQueue.wait(milliseconds(ms));
+	else
+		t->mIncomingQueue.wait();
+	return !t->mIncomingQueue.empty() ? 1 : 0;
+}
+
+} // namespace rtc
+
+#else // USE_GNUTLS==0
+// TODO
+#endif
+

+ 79 - 0
src/tlstransport.hpp

@@ -0,0 +1,79 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_TLS_TRANSPORT_H
+#define RTC_TLS_TRANSPORT_H
+
+#if ENABLE_WEBSOCKET
+
+#include "include.hpp"
+#include "queue.hpp"
+#include "transport.hpp"
+
+#include <memory>
+#include <mutex>
+#include <thread>
+
+#if USE_GNUTLS
+#include <gnutls/gnutls.h>
+#else
+#include <openssl/ssl.h>
+#endif
+
+namespace rtc {
+
+class TcpTransport;
+
+class TlsTransport : public Transport {
+public:
+	TlsTransport(std::shared_ptr<TcpTransport> lower, const string &host);
+	virtual ~TlsTransport();
+
+	bool stop() override;
+	bool send(message_ptr message) override;
+
+	void incoming(message_ptr message) override;
+	bool outgoing(message_ptr message) override;
+
+protected:
+	void runRecvLoop();
+
+	Queue<message_ptr> mIncomingQueue;
+	std::thread mRecvThread;
+
+#if USE_GNUTLS
+	gnutls_session_t mSession;
+
+	static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
+	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
+	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
+#else
+	SSL_CTX *mCtx;
+	SSL *mSsl;
+	BIO *mInBio, *mOutBio;
+
+	static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx);
+	static void InfoCallback(const SSL *ssl, int where, int ret);
+#endif
+};
+
+} // namespace rtc
+
+#endif
+
+#endif