Browse Source

Catch exceptions in all GnuTLS callbacks for safety

Paul-Louis Ageneau 4 years ago
parent
commit
c112ae77c2
2 changed files with 114 additions and 69 deletions
  1. 63 37
      src/impl/dtlstransport.cpp
  2. 51 32
      src/impl/tlstransport.cpp

+ 63 - 37
src/impl/dtlstransport.cpp

@@ -17,8 +17,8 @@
  */
 
 #include "dtlstransport.hpp"
-#include "internals.hpp"
 #include "icetransport.hpp"
+#include "internals.hpp"
 
 #include <chrono>
 #include <cstring>
@@ -54,7 +54,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 
 	PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
 
-	if(!mCertificate)
+	if (!mCertificate)
 		throw std::invalid_argument("DTLS certificate is null");
 
 	gnutls_certificate_credentials_t creds = mCertificate->credentials();
@@ -247,61 +247,87 @@ void DtlsTransport::runRecvLoop() {
 
 int DtlsTransport::CertificateCallback(gnutls_session_t session) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(gnutls_session_get_ptr(session));
+	try {
+		if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509) {
+			return GNUTLS_E_CERTIFICATE_ERROR;
+		}
 
-	if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509) {
-		return GNUTLS_E_CERTIFICATE_ERROR;
-	}
+		unsigned int count = 0;
+		const gnutls_datum_t *array = gnutls_certificate_get_peers(session, &count);
+		if (!array || count == 0) {
+			return GNUTLS_E_CERTIFICATE_ERROR;
+		}
 
-	unsigned int count = 0;
-	const gnutls_datum_t *array = gnutls_certificate_get_peers(session, &count);
-	if (!array || count == 0) {
-		return GNUTLS_E_CERTIFICATE_ERROR;
-	}
+		gnutls_x509_crt_t crt;
+		gnutls::check(gnutls_x509_crt_init(&crt));
+		int ret = gnutls_x509_crt_import(crt, &array[0], GNUTLS_X509_FMT_DER);
+		if (ret != GNUTLS_E_SUCCESS) {
+			gnutls_x509_crt_deinit(crt);
+			return GNUTLS_E_CERTIFICATE_ERROR;
+		}
 
-	gnutls_x509_crt_t crt;
-	gnutls::check(gnutls_x509_crt_init(&crt));
-	int ret = gnutls_x509_crt_import(crt, &array[0], GNUTLS_X509_FMT_DER);
-	if (ret != GNUTLS_E_SUCCESS) {
+		string fingerprint = make_fingerprint(crt);
 		gnutls_x509_crt_deinit(crt);
-		return GNUTLS_E_CERTIFICATE_ERROR;
-	}
 
-	string fingerprint = make_fingerprint(crt);
-	gnutls_x509_crt_deinit(crt);
+		bool success = t->mVerifierCallback(fingerprint);
+		return success ? GNUTLS_E_SUCCESS : GNUTLS_E_CERTIFICATE_ERROR;
 
-	bool success = t->mVerifierCallback(fingerprint);
-	return success ? GNUTLS_E_SUCCESS : GNUTLS_E_CERTIFICATE_ERROR;
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		return GNUTLS_E_CERTIFICATE_ERROR;
+	}
 }
 
 ssize_t DtlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
-	if (len > 0) {
-		auto b = reinterpret_cast<const byte *>(data);
-		t->outgoing(make_message(b, b + len));
+	try {
+		if (len > 0) {
+			auto b = reinterpret_cast<const byte *>(data);
+			t->outgoing(make_message(b, b + len));
+		}
+		gnutls_transport_set_errno(t->mSession, 0);
+		return ssize_t(len);
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		gnutls_transport_set_errno(t->mSession, ECONNRESET);
+		return -1;
 	}
-	gnutls_transport_set_errno(t->mSession, 0);
-	return ssize_t(len);
 }
 
 ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
-	if (auto next = t->mIncomingQueue.pop()) {
-		message_ptr message = std::move(*next);
-		ssize_t len = std::min(maxlen, message->size());
-		std::memcpy(data, message->data(), len);
+	try {
+		if (auto next = t->mIncomingQueue.pop()) {
+			message_ptr message = std::move(*next);
+			ssize_t len = std::min(maxlen, message->size());
+			std::memcpy(data, message->data(), len);
+			gnutls_transport_set_errno(t->mSession, 0);
+			return len;
+		}
+
+		// Closed
 		gnutls_transport_set_errno(t->mSession, 0);
-		return len;
+		return 0;
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		gnutls_transport_set_errno(t->mSession, ECONNRESET);
+		return -1;
 	}
-	// Closed
-	gnutls_transport_set_errno(t->mSession, 0);
-	return 0;
 }
 
 int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
-	bool notEmpty = t->mIncomingQueue.wait(
-	    ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms)) : nullopt);
-	return notEmpty ? 1 : 0;
+	try {
+		bool notEmpty = t->mIncomingQueue.wait(
+		    ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms)) : nullopt);
+		return notEmpty ? 1 : 0;
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		return 1;
+	}
 }
 
 #else // USE_GNUTLS==0
@@ -341,7 +367,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
       mIsClient(lower->role() == Description::Role::Active), mCurrentDscp(0) {
 	PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
 
-	if(!mCertificate)
+	if (!mCertificate)
 		throw std::invalid_argument("DTLS certificate is null");
 
 	try {

+ 51 - 32
src/impl/tlstransport.cpp

@@ -209,53 +209,72 @@ void TlsTransport::runRecvLoop() {
 
 ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) {
 	TlsTransport *t = static_cast<TlsTransport *>(ptr);
-	if (len > 0) {
-		auto b = reinterpret_cast<const byte *>(data);
-		t->outgoing(make_message(b, b + len));
+	try {
+		if (len > 0) {
+			auto b = reinterpret_cast<const byte *>(data);
+			t->outgoing(make_message(b, b + len));
+		}
+		gnutls_transport_set_errno(t->mSession, 0);
+		return ssize_t(len);
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		gnutls_transport_set_errno(t->mSession, ECONNRESET);
+		return -1;
 	}
-	gnutls_transport_set_errno(t->mSession, 0);
-	return ssize_t(len);
 }
 
 ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
 	TlsTransport *t = static_cast<TlsTransport *>(ptr);
+	try {
+		message_ptr &message = t->mIncomingMessage;
+		size_t &position = t->mIncomingMessagePosition;
 
-	message_ptr &message = t->mIncomingMessage;
-	size_t &position = t->mIncomingMessagePosition;
+		if (message && position >= message->size())
+			message.reset();
 
-	if (message && position >= message->size())
-		message.reset();
+		if (!message) {
+			position = 0;
+			while (auto next = t->mIncomingQueue.pop()) {
+				message = *next;
+				if (message->size() > 0)
+					break;
+				else
+					t->recv(message); // Pass zero-sized messages through
+			}
+		}
 
-	if (!message) {
-		position = 0;
-		while (auto next = t->mIncomingQueue.pop()) {
-			message = *next;
-			if (message->size() > 0)
-				break;
-			else
-				t->recv(message); // Pass zero-sized messages through
+		if (message) {
+			size_t available = message->size() - position;
+			ssize_t len = std::min(maxlen, available);
+			std::memcpy(data, message->data() + position, len);
+			position += len;
+			gnutls_transport_set_errno(t->mSession, 0);
+			return len;
+		} else {
+			// Closed
+			gnutls_transport_set_errno(t->mSession, 0);
+			return 0;
 		}
-	}
 
-	if (message) {
-		size_t available = message->size() - position;
-		ssize_t len = std::min(maxlen, available);
-		std::memcpy(data, message->data() + position, len);
-		position += len;
-		gnutls_transport_set_errno(t->mSession, 0);
-		return len;
-	} else {
-		// Closed
-		gnutls_transport_set_errno(t->mSession, 0);
-		return 0;
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		gnutls_transport_set_errno(t->mSession, ECONNRESET);
+		return -1;
 	}
 }
 
 int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
 	TlsTransport *t = static_cast<TlsTransport *>(ptr);
-	bool notEmpty = t->mIncomingQueue.wait(
-	    ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms)) : nullopt);
-	return notEmpty ? 1 : 0;
+	try {
+		bool notEmpty = t->mIncomingQueue.wait(
+		    ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms)) : nullopt);
+		return notEmpty ? 1 : 0;
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		return 1;
+	}
 }
 
 #else // USE_GNUTLS==0