Browse Source

Moved OpenSSL DTLS and TLS transport implementations to thread pool

Paul-Louis Ageneau 2 years ago
parent
commit
3a823cb476
4 changed files with 169 additions and 130 deletions
  1. 106 89
      src/impl/dtlstransport.cpp
  2. 2 0
      src/impl/dtlstransport.hpp
  3. 52 41
      src/impl/tlstransport.cpp
  4. 9 0
      src/impl/tlstransport.hpp

+ 106 - 89
src/impl/dtlstransport.cpp

@@ -92,10 +92,6 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 		gnutls_transport_set_pull_function(mSession, ReadCallback);
 		gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
 
-		size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
-		gnutls_dtls_set_mtu(mSession, static_cast<unsigned int>(mtu));
-		PLOG_VERBOSE << "DTLS MTU set to " << mtu;
-
 	} catch (...) {
 		gnutls_deinit(mSession);
 		throw;
@@ -117,6 +113,11 @@ void DtlsTransport::start() {
 	PLOG_DEBUG << "Starting DTLS transport";
 	registerIncoming();
 	changeState(State::Connecting);
+
+	size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
+	gnutls_dtls_set_mtu(mSession, static_cast<unsigned int>(mtu));
+	PLOG_VERBOSE << "DTLS MTU set to " << mtu;
+
 	enqueueRecv(); // to initiate the handshake
 }
 
@@ -181,10 +182,13 @@ void DtlsTransport::doRecv() {
 	std::lock_guard lock(mRecvMutex);
 	--mPendingRecvCount;
 
-	const size_t bufferSize = 4096;
-	char buffer[bufferSize];
+	if (state() != State::Connecting && state() != State::Connected)
+		return;
 
 	try {
+		const size_t bufferSize = 4096;
+		char buffer[bufferSize];
+
 		// Handle handshake if connecting
 		if (state() == State::Connecting) {
 			int ret;
@@ -193,7 +197,7 @@ void DtlsTransport::doRecv() {
 
 				if (ret == GNUTLS_E_AGAIN) {
 					// Schedule next call on timeout and return
-					duration timeout = milliseconds(gnutls_dtls_get_timeout(mSession));
+					auto timeout = milliseconds(gnutls_dtls_get_timeout(mSession));
 					ThreadPool::Instance().schedule(timeout, [weak_this = weak_from_this()]() {
 						if (auto locked = weak_this.lock())
 							locked->doRecv();
@@ -317,7 +321,13 @@ ssize_t DtlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *dat
 ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
 	try {
-		while (auto next = t->mIncomingQueue.pop()) {
+		while (t->mIncomingQueue.running()) {
+			auto next = t->mIncomingQueue.pop();
+			if (!next) {
+				gnutls_transport_set_errno(t->mSession, EAGAIN);
+				return -1;
+			}
+
 			message_ptr message = std::move(*next);
 			if (t->demuxMessage(message))
 				continue;
@@ -328,14 +338,9 @@ ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size
 			return len;
 		}
 
-		if (t->mIncomingQueue.running()) {
-			gnutls_transport_set_errno(t->mSession, EAGAIN);
-			return -1;
-		} else {
-			// Closed
-			gnutls_transport_set_errno(t->mSession, 0);
-			return 0;
-		}
+		// Closed
+		gnutls_transport_set_errno(t->mSession, 0);
+		return 0;
 
 	} catch (const std::exception &e) {
 		PLOG_WARNING << e.what();
@@ -461,6 +466,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 				throw std::runtime_error("Failed to set SRTP profile: " +
 				                         openssl::error_string(ERR_get_error()));
 		}
+
 	} catch (...) {
 		if (mSsl)
 			SSL_free(mSsl);
@@ -483,23 +489,26 @@ DtlsTransport::~DtlsTransport() {
 }
 
 void DtlsTransport::start() {
-	if (mStarted.exchange(true))
-		return;
-
-	PLOG_DEBUG << "Starting DTLS recv thread";
+	PLOG_DEBUG << "Starting DTLS transport";
 	registerIncoming();
-	mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
+	changeState(State::Connecting);
+
+	size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
+	SSL_set_mtu(mSsl, static_cast<unsigned int>(mtu));
+	PLOG_VERBOSE << "DTLS MTU set to " << mtu;
+
+	// Initiate the handshake
+	int ret = SSL_do_handshake(mSsl);
+	openssl::check(mSsl, ret, "Handshake initiation failed");
+
+	handleTimeout();
 }
 
 void DtlsTransport::stop() {
-	if (!mStarted.exchange(false))
-		return;
-
-	PLOG_DEBUG << "Stopping DTLS recv thread";
+	PLOG_DEBUG << "Stopping DTLS transport";
 	unregisterIncoming();
 	mIncomingQueue.stop();
-	mRecvThread.join();
-	SSL_shutdown(mSsl);
+	enqueueRecv();
 }
 
 bool DtlsTransport::send(message_ptr message) {
@@ -519,11 +528,13 @@ bool DtlsTransport::send(message_ptr message) {
 void DtlsTransport::incoming(message_ptr message) {
 	if (!message) {
 		mIncomingQueue.stop();
+		enqueueRecv();
 		return;
 	}
 
 	PLOG_VERBOSE << "Incoming size=" << message->size();
 	mIncomingQueue.push(message);
+	enqueueRecv();
 }
 
 bool DtlsTransport::outgoing(message_ptr message) {
@@ -543,86 +554,65 @@ void DtlsTransport::postHandshake() {
 	// Dummy
 }
 
-void DtlsTransport::runRecvLoop() {
-	const size_t bufferSize = 4096;
-	try {
-		changeState(State::Connecting);
-
-		size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
-		SSL_set_mtu(mSsl, static_cast<unsigned int>(mtu));
-		PLOG_VERBOSE << "SSL MTU set to " << mtu;
+void DtlsTransport::doRecv() {
+	std::lock_guard lock(mRecvMutex);
+	--mPendingRecvCount;
 
-		// Initiate the handshake
-		int ret = SSL_do_handshake(mSsl);
-		openssl::check(mSsl, ret, "Handshake failed");
+	if (state() != State::Connecting && state() != State::Connected)
+		return;
 
+	try {
+		const size_t bufferSize = 4096;
 		byte buffer[bufferSize];
+
+		// Process pending messages
 		while (mIncomingQueue.running()) {
-			// Process pending messages
-			while (auto next = mIncomingQueue.pop()) {
-				message_ptr message = std::move(*next);
-				if (demuxMessage(message))
-					continue;
+			auto next = mIncomingQueue.pop();
+			if (!next) {
+				// No more messages pending, handle timeout if connecting
+				if (state() == State::Connecting)
+					handleTimeout();
 
-				BIO_write(mInBio, message->data(), int(message->size()));
+				return;
+			}
 
-				if (state() == State::Connecting) {
-					// Continue the handshake
-					ret = SSL_do_handshake(mSsl);
-					if (!openssl::check(mSsl, ret, "Handshake failed"))
-						break;
+			message_ptr message = std::move(*next);
+			if (demuxMessage(message))
+				continue;
 
-					if (SSL_is_init_finished(mSsl)) {
-						// RFC 8261: DTLS MUST support sending messages larger than the current path
-						// MTU See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
-						SSL_set_mtu(mSsl, bufferSize + 1);
+			BIO_write(mInBio, message->data(), int(message->size()));
 
-						PLOG_INFO << "DTLS handshake finished";
-						postHandshake();
-						changeState(State::Connected);
-					}
-				} else {
-					ret = SSL_read(mSsl, buffer, bufferSize);
-					if (!openssl::check(mSsl, ret))
-						break;
+			if (state() == State::Connecting) {
+				// Continue the handshake
+				int ret = SSL_do_handshake(mSsl);
+				if (!openssl::check(mSsl, ret, "Handshake failed"))
+					break;
 
-					if (ret > 0)
-						recv(make_message(buffer, buffer + ret));
-				}
-			}
+				if (SSL_is_init_finished(mSsl)) {
+					// RFC 8261: DTLS MUST support sending messages larger than the current path
+					// MTU See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
+					SSL_set_mtu(mSsl, bufferSize + 1);
 
-			// No more messages pending, retransmit and rearm timeout if connecting
-			optional<milliseconds> duration;
-			if (state() == State::Connecting) {
-				// Warning: This function breaks the usual return value convention
-				ret = DTLSv1_handle_timeout(mSsl);
-				if (ret < 0) {
-					throw std::runtime_error("Handshake timeout"); // write BIO can't fail
-				} else if (ret > 0) {
-					LOG_VERBOSE << "OpenSSL did DTLS retransmit";
+					PLOG_INFO << "DTLS handshake finished";
+					postHandshake();
+					changeState(State::Connected);
 				}
+			} else {
+				int ret = SSL_read(mSsl, buffer, bufferSize);
+				if (!openssl::check(mSsl, ret))
+					break;
 
-				struct timeval timeout = {};
-				if (state() == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
-					duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
-					// Also handle handshake timeout manually because OpenSSL actually doesn't...
-					// OpenSSL backs off exponentially in base 2 starting from the recommended 1s
-					// so this allows for 5 retransmissions and fails after roughly 30s.
-					if (duration > 30s) {
-						throw std::runtime_error("Handshake timeout");
-					} else {
-						LOG_VERBOSE << "OpenSSL DTLS retransmit timeout is " << duration->count()
-						            << "ms";
-					}
-				}
+				if (ret > 0)
+					recv(make_message(buffer, buffer + ret));
 			}
-
-			mIncomingQueue.wait(duration); // TODO
 		}
+
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "DTLS recv: " << e.what();
 	}
 
+	SSL_shutdown(mSsl);
+
 	if (state() == State::Connected) {
 		PLOG_INFO << "DTLS closed";
 		changeState(State::Disconnected);
@@ -633,6 +623,33 @@ void DtlsTransport::runRecvLoop() {
 	}
 }
 
+void DtlsTransport::handleTimeout() {
+	// Warning: This function breaks the usual return value convention
+	int ret = DTLSv1_handle_timeout(mSsl);
+	if (ret < 0) {
+		throw std::runtime_error("Handshake timeout"); // write BIO can't fail
+	} else if (ret > 0) {
+		LOG_VERBOSE << "DTLS retransmit done";
+	}
+
+	struct timeval tv = {};
+	if (DTLSv1_get_timeout(mSsl, &tv)) {
+		auto timeout = milliseconds(tv.tv_sec * 1000 + tv.tv_usec / 1000);
+		// Also handle handshake timeout manually because OpenSSL actually
+		// doesn't... OpenSSL backs off exponentially in base 2 starting from the
+		// recommended 1s so this allows for 5 retransmissions and fails after
+		// roughly 30s.
+		if (timeout > 30s)
+			throw std::runtime_error("Handshake timeout");
+
+		LOG_VERBOSE << "DTLS retransmit timeout is " << timeout.count() << "ms";
+		ThreadPool::Instance().schedule(timeout, [weak_this = weak_from_this()]() {
+			if (auto locked = weak_this.lock())
+				locked->doRecv();
+		});
+	}
+}
+
 int DtlsTransport::CertificateCallback(int /*preverify_ok*/, X509_STORE_CTX *ctx) {
 	SSL *ssl =
 	    static_cast<SSL *>(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));

+ 2 - 0
src/impl/dtlstransport.hpp

@@ -75,6 +75,8 @@ protected:
 	SSL *mSsl = NULL;
 	BIO *mInBio, *mOutBio;
 
+	void handleTimeout();
+
 	static BIO_METHOD *BioMethods;
 	static int TransportExIndex;
 	static std::mutex GlobalMutex;

+ 52 - 41
src/impl/tlstransport.cpp

@@ -137,6 +137,7 @@ bool TlsTransport::send(message_ptr message) {
 void TlsTransport::incoming(message_ptr message) {
 	if (!message) {
 		mIncomingQueue.stop();
+		enqueueRecv();
 		return;
 	}
 
@@ -280,7 +281,7 @@ int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int /* ms
 		message_ptr &message = t->mIncomingMessage;
 		size_t &position = t->mIncomingMessagePosition;
 
-		if(message && position < message->size())
+		if (message && position < message->size())
 			return 1;
 
 		return !t->mIncomingQueue.empty() ? 1 : 0;
@@ -384,23 +385,22 @@ TlsTransport::~TlsTransport() {
 }
 
 void TlsTransport::start() {
-	if (mStarted.exchange(true))
-		return;
-
-	PLOG_DEBUG << "Starting TLS recv thread";
+	PLOG_DEBUG << "Starting TLS transport";
 	registerIncoming();
-	mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
+	changeState(State::Connecting);
+
+	// Initiate the handshake
+	int ret = SSL_do_handshake(mSsl);
+	openssl::check(mSsl, ret, "Handshake initiation failed");
+
+	flushOutput();
 }
 
 void TlsTransport::stop() {
-	if (!mStarted.exchange(false))
-		return;
-
-	PLOG_DEBUG << "Stopping TLS recv thread";
+	PLOG_DEBUG << "Stopping TLS transport";
 	unregisterIncoming();
 	mIncomingQueue.stop();
-	mRecvThread.join();
-	SSL_shutdown(mSsl);
+	enqueueRecv();
 }
 
 bool TlsTransport::send(message_ptr message) {
@@ -416,23 +416,19 @@ bool TlsTransport::send(message_ptr message) {
 	if (!openssl::check(mSsl, ret))
 		throw std::runtime_error("TLS send failed");
 
-	const size_t bufferSize = 4096;
-	byte buffer[bufferSize];
-	bool result = true;
-	while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
-		result = outgoing(make_message(buffer, buffer + ret));
-
-	return result;
+	return flushOutput();
 }
 
 void TlsTransport::incoming(message_ptr message) {
 	if (!message) {
 		mIncomingQueue.stop();
+		enqueueRecv();
 		return;
 	}
 
 	PLOG_VERBOSE << "Incoming size=" << message->size();
 	mIncomingQueue.push(message);
+	enqueueRecv();
 }
 
 bool TlsTransport::outgoing(message_ptr message) { return Transport::outgoing(std::move(message)); }
@@ -441,24 +437,36 @@ void TlsTransport::postHandshake() {
 	// Dummy
 }
 
-void TlsTransport::runRecvLoop() {
-	const size_t bufferSize = 4096;
-	byte buffer[bufferSize];
+void TlsTransport::doRecv() {
+	std::lock_guard lock(mRecvMutex);
+	--mPendingRecvCount;
+
+	if (state() != State::Connecting && state() != State::Connected)
+		return;
 
 	try {
-		changeState(State::Connecting);
+		const size_t bufferSize = 4096;
+		byte buffer[bufferSize];
+
+		// Process incoming messages
+		while (mIncomingQueue.running()) {
+			auto next = mIncomingQueue.pop();
+			if (!next)
+				return;
+
+			message_ptr message = std::move(*next);
+			if (message->size() > 0)
+				BIO_write(mInBio, message->data(), int(message->size())); // Input
+			else
+				recv(message); // Pass zero-sized messages through
 
-		int ret;
-		while (true) {
 			if (state() == State::Connecting) {
-				// Initiate or continue the handshake
-				ret = SSL_do_handshake(mSsl);
+				// Continue the handshake
+				int ret = SSL_do_handshake(mSsl);
 				if (!openssl::check(mSsl, ret, "Handshake failed"))
 					break;
 
-				// Output
-				while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
-					outgoing(make_message(buffer, buffer + ret));
+				flushOutput();
 
 				if (SSL_is_init_finished(mSsl)) {
 					PLOG_INFO << "TLS handshake finished";
@@ -468,29 +476,21 @@ void TlsTransport::runRecvLoop() {
 			}
 
 			if (state() == State::Connected) {
-				// Input
+				int ret;
 				while ((ret = SSL_read(mSsl, buffer, bufferSize)) > 0)
 					recv(make_message(buffer, buffer + ret));
 
 				if (!openssl::check(mSsl, ret))
 					break;
 			}
-
-			auto next = mIncomingQueue.pop();
-			if (!next)
-				break;
-
-			message_ptr message = std::move(*next);
-			if (message->size() > 0)
-				BIO_write(mInBio, message->data(), int(message->size())); // Input
-			else
-				recv(message); // Pass zero-sized messages through
 		}
 
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "TLS recv: " << e.what();
 	}
 
+	SSL_shutdown(mSsl);
+
 	if (state() == State::Connected) {
 		PLOG_INFO << "TLS closed";
 		recv(nullptr);
@@ -499,6 +499,17 @@ void TlsTransport::runRecvLoop() {
 	}
 }
 
+bool TlsTransport::flushOutput() {
+	const size_t bufferSize = 4096;
+	byte buffer[bufferSize];
+	int ret;
+	bool result = true;
+	while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
+		result = outgoing(make_message(buffer, buffer + ret));
+
+	return result;
+}
+
 void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) {
 	TlsTransport *t =
 	    static_cast<TlsTransport *>(SSL_get_ex_data(ssl, TlsTransport::TransportExIndex));

+ 9 - 0
src/impl/tlstransport.hpp

@@ -69,10 +69,19 @@ protected:
 	SSL *mSsl;
 	BIO *mInBio, *mOutBio;
 
+	bool flushOutput();
+
+	static BIO_METHOD *BioMethods;
 	static int TransportExIndex;
+	static std::mutex GlobalMutex;
 
 	static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx);
 	static void InfoCallback(const SSL *ssl, int where, int ret);
+
+	static int BioMethodNew(BIO *bio);
+	static int BioMethodFree(BIO *bio);
+	static int BioMethodWrite(BIO *bio, const char *in, int inl);
+	static long BioMethodCtrl(BIO *bio, int cmd, long num, void *ptr);
 #endif
 };