Browse Source

Clean up and make TLS and DTLS Transports more consistent

Paul-Louis Ageneau 2 years ago
parent
commit
4fc3f92a83
2 changed files with 68 additions and 32 deletions
  1. 34 16
      src/impl/dtlstransport.cpp
  2. 34 16
      src/impl/tlstransport.cpp

+ 34 - 16
src/impl/dtlstransport.cpp

@@ -209,7 +209,7 @@ void DtlsTransport::doRecv() {
 					throw std::runtime_error("MTU is too low");
 				}
 
-			} while (!gnutls::check(ret, "DTLS handshake failed")); // Re-call on non-fatal error
+			} while (!gnutls::check(ret, "Handshake failed")); // Re-call on non-fatal error
 
 			// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
 			// See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
@@ -263,9 +263,14 @@ void DtlsTransport::doRecv() {
 
 	gnutls_bye(mSession, GNUTLS_SHUT_WR);
 
-	PLOG_INFO << "DTLS closed";
-	changeState(State::Disconnected);
-	recv(nullptr);
+	if (state() == State::Connected) {
+		PLOG_INFO << "DTLS closed";
+		changeState(State::Disconnected);
+		recv(nullptr);
+	} else {
+		PLOG_ERROR << "DTLS handshake failed";
+		changeState(State::Failed);
+	}
 }
 
 int DtlsTransport::CertificateCallback(gnutls_session_t session) {
@@ -413,6 +418,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 		mbedtls_ssl_set_export_keys_cb(&mSsl, DtlsTransport::ExportKeysCallback, this);
 		mbedtls_ssl_set_bio(&mSsl, this, WriteCallback, ReadCallback, NULL);
 		mbedtls_ssl_set_timer_cb(&mSsl, this, SetTimerCallback, GetTimerCallback);
+
 	} catch (...) {
 		mbedtls_entropy_free(&mEntropy);
 		mbedtls_ctr_drbg_free(&mDrbg);
@@ -524,6 +530,7 @@ void DtlsTransport::doRecv() {
 		if (state() == State::Connecting) {
 			while (true) {
 				auto ret = mbedtls_ssl_handshake(&mSsl);
+
 				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
 					ThreadPool::Instance().schedule(mTimerSetAt + milliseconds(mFinMs),
 					                                [weak_this = weak_from_this()]() {
@@ -531,12 +538,14 @@ void DtlsTransport::doRecv() {
 							                                locked->doRecv();
 					                                });
 					return;
-				} else if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
-				           ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
-					continue;
 				}
 
-				mbedtls::check(ret);
+				if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
+				           ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS)
+					continue;
+
+				mbedtls::check(ret, "Handshake failed");
+
 				PLOG_INFO << "DTLS handshake finished";
 				changeState(State::Connected);
 				postHandshake();
@@ -551,17 +560,21 @@ void DtlsTransport::doRecv() {
 				    mbedtls_ssl_read(&mSsl, reinterpret_cast<unsigned char *>(buffer), bufferSize);
 				mMutex.unlock();
 
+				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+					return;
+				}
+
+				if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
+				           ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
+					continue;
+				}
+
 				if (ret == 0 || ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
 					// Closed
 					PLOG_DEBUG << "DTLS connection cleanly closed";
 					break;
 				}
 
-				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
-				    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
-				    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
-					return;
-				}
 				mbedtls::check(ret);
 
 				auto *b = reinterpret_cast<byte *>(buffer);
@@ -572,9 +585,14 @@ void DtlsTransport::doRecv() {
 		PLOG_ERROR << "DTLS recv: " << e.what();
 	}
 
-	PLOG_INFO << "DTLS closed";
-	changeState(State::Disconnected);
-	recv(nullptr);
+	if (state() == State::Connected) {
+		PLOG_INFO << "DTLS closed";
+		changeState(State::Disconnected);
+		recv(nullptr);
+	} else {
+		PLOG_ERROR << "DTLS handshake failed";
+		changeState(State::Failed);
+	}
 }
 
 void DtlsTransport::ExportKeysCallback(void *ctx, mbedtls_ssl_key_export_type /*type*/,

+ 34 - 16
src/impl/tlstransport.cpp

@@ -177,7 +177,7 @@ void TlsTransport::doRecv() {
 				if (ret == GNUTLS_E_AGAIN)
 					return;
 
-			} while (!gnutls::check(ret, "TLS handshake failed")); // Re-call on non-fatal error
+			} while (!gnutls::check(ret, "Handshake failed")); // Re-call on non-fatal error
 
 			PLOG_INFO << "TLS handshake finished";
 			changeState(State::Connected);
@@ -214,9 +214,14 @@ void TlsTransport::doRecv() {
 
 	gnutls_bye(mSession, GNUTLS_SHUT_WR);
 
-	PLOG_INFO << "TLS closed";
-	changeState(State::Disconnected);
-	recv(nullptr);
+	if (state() == State::Connected) {
+		PLOG_INFO << "TLS closed";
+		changeState(State::Disconnected);
+		recv(nullptr);
+	} else {
+		PLOG_ERROR << "TLS handshake failed";
+		changeState(State::Failed);
+	}
 }
 
 ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) {
@@ -339,6 +344,7 @@ TlsTransport::TlsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<HttpProx
 
 		mbedtls::check(mbedtls_ssl_setup(&mSsl, &mConf));
 		mbedtls_ssl_set_bio(&mSsl, static_cast<void *>(this), WriteCallback, ReadCallback, NULL);
+
 	} catch (...) {
 		mbedtls_entropy_free(&mEntropy);
 		mbedtls_ctr_drbg_free(&mDrbg);
@@ -416,14 +422,18 @@ void TlsTransport::doRecv() {
 		if (state() == State::Connecting) {
 			while (true) {
 				auto ret = mbedtls_ssl_handshake(&mSsl);
+
 				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
 					return;
-				} else if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
-				           ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
+				}
+
+				if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
+				    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
 					continue;
 				}
 
-				mbedtls::check(ret);
+				mbedtls::check(ret, "Handshake failed");
+
 				PLOG_INFO << "TLS handshake finished";
 				changeState(State::Connected);
 				postHandshake();
@@ -436,18 +446,21 @@ void TlsTransport::doRecv() {
 				auto ret =
 				    mbedtls_ssl_read(&mSsl, reinterpret_cast<unsigned char *>(buffer), bufferSize);
 
+				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+					return;
+				}
+
+				if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
+				    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
+					continue;
+				}
+
 				if (ret == 0 || ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
 					// Closed
 					PLOG_DEBUG << "TLS connection cleanly closed";
 					break;
 				}
 
-				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
-					return;
-				} else if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
-				           ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
-					continue;
-				}
 				mbedtls::check(ret);
 
 				auto *b = reinterpret_cast<byte *>(buffer);
@@ -458,9 +471,14 @@ void TlsTransport::doRecv() {
 		PLOG_ERROR << "TLS recv: " << e.what();
 	}
 
-	PLOG_INFO << "TLS closed";
-	changeState(State::Disconnected);
-	recv(nullptr);
+	if (state() == State::Connected) {
+		PLOG_INFO << "TLS closed";
+		changeState(State::Disconnected);
+		recv(nullptr);
+	} else {
+		PLOG_ERROR << "TLS handshake failed";
+		changeState(State::Failed);
+	}
 }
 
 int TlsTransport::WriteCallback(void *ctx, const unsigned char *buf, size_t len) {