Browse Source

Move MbedTLS checks to TLS helper function

Paul-Louis Ageneau 2 years ago
parent
commit
c31088b6f4
4 changed files with 44 additions and 51 deletions
  1. 17 26
      src/impl/dtlstransport.cpp
  2. 10 1
      src/impl/tls.cpp
  3. 1 1
      src/impl/tls.hpp
  4. 16 23
      src/impl/tlstransport.cpp

+ 17 - 26
src/impl/dtlstransport.cpp

@@ -480,9 +480,7 @@ bool DtlsTransport::send(message_ptr message) {
 		mCurrentDscp = message->dscp;
 		ret = mbedtls_ssl_write(&mSsl, reinterpret_cast<const unsigned char *>(message->data()),
 		                        message->size());
-	} while (ret == MBEDTLS_ERR_SSL_WANT_WRITE);
-
-	mbedtls::check(ret);
+	} while (!mbedtls::check(ret));
 
 	return mOutgoingResult;
 }
@@ -535,7 +533,7 @@ void DtlsTransport::doRecv() {
 					ret = mbedtls_ssl_handshake(&mSsl);
 				}
 
-				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+				if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
 					ThreadPool::Instance().schedule(mTimerSetAt + milliseconds(mFinMs),
 					                                [weak_this = weak_from_this()]() {
 						                                if (auto locked = weak_this.lock())
@@ -544,17 +542,12 @@ void DtlsTransport::doRecv() {
 					return;
 				}
 
-				if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
-				    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
-					continue;
+				if(mbedtls::check(ret, "Handshake failed")) {
+					PLOG_INFO << "DTLS handshake finished";
+					changeState(State::Connected);
+					postHandshake();
+					break;
 				}
-
-				mbedtls::check(ret, "Handshake failed");
-
-				PLOG_INFO << "DTLS handshake finished";
-				changeState(State::Connected);
-				postHandshake();
-				break;
 			}
 		}
 
@@ -567,25 +560,23 @@ void DtlsTransport::doRecv() {
 					                       bufferSize);
 				}
 
-				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+				if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
 					return;
 				}
 
-				if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
-				    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
-					continue;
-				}
-
-				if (ret == 0 || ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
-					// Closed
+				if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
 					PLOG_DEBUG << "DTLS connection cleanly closed";
 					break;
 				}
 
-				mbedtls::check(ret);
-
-				auto *b = reinterpret_cast<byte *>(buffer);
-				recv(make_message(b, b + ret));
+				if(mbedtls::check(ret)) {
+					if(ret == 0) {
+						PLOG_DEBUG << "DTLS connection terminated";
+						break;
+					}
+					auto *b = reinterpret_cast<byte *>(buffer);
+					recv(make_message(b, b + ret));
+				}
 			}
 		}
 	} catch (const std::exception &e) {

+ 10 - 1
src/impl/tls.cpp

@@ -15,6 +15,7 @@
 
 namespace rtc::gnutls {
 
+// Return false on non-fatal error
 bool check(int ret, const string &message) {
 	if (ret < 0) {
 		if (!gnutls_error_is_fatal(ret)) {
@@ -95,13 +96,20 @@ size_t my_strftme(char *buf, size_t size, const char *format, const time_t *t) {
 
 namespace rtc::mbedtls {
 
-void check(int ret, const string &message) {
+// Return false on non-fatal error
+bool check(int ret, const string &message) {
 	if (ret < 0) {
+		if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
+		    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS || ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS ||
+		    ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)
+			return false;
+
 		const size_t bufferSize = 1024;
 		char buffer[bufferSize];
 		mbedtls_strerror(ret, reinterpret_cast<char *>(buffer), bufferSize);
 		throw std::runtime_error(message + ": " + std::string(buffer));
 	}
+	return true;
 }
 
 string format_time(const std::chrono::system_clock::time_point &tp) {
@@ -170,6 +178,7 @@ bool check(int success, const string &message) {
 	throw std::runtime_error(message + ": " + str);
 }
 
+// Return false on EOF
 bool check(SSL *ssl, int ret, const string &message) {
 	unsigned long err = SSL_get_error(ssl, ret);
 	if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {

+ 1 - 1
src/impl/tls.hpp

@@ -52,7 +52,7 @@ gnutls_datum_t make_datum(char *data, size_t size);
 
 namespace rtc::mbedtls {
 
-void check(int ret, const string &message = "MbedTLS error");
+bool check(int ret, const string &message = "MbedTLS error");
 
 string format_time(const std::chrono::system_clock::time_point &tp);
 

+ 16 - 23
src/impl/tlstransport.cpp

@@ -433,21 +433,16 @@ void TlsTransport::doRecv() {
 					ret = mbedtls_ssl_handshake(&mSsl);
 				}
 
-				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+				if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
 					return;
 				}
 
-				if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
-				    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
-					continue;
+				if (mbedtls::check(ret, "Handshake failed")) {
+					PLOG_INFO << "TLS handshake finished";
+					changeState(State::Connected);
+					postHandshake();
+					break;
 				}
-
-				mbedtls::check(ret, "Handshake failed");
-
-				PLOG_INFO << "TLS handshake finished";
-				changeState(State::Connected);
-				postHandshake();
-				break;
 			}
 		}
 
@@ -460,25 +455,23 @@ void TlsTransport::doRecv() {
 					                       bufferSize);
 				}
 
-				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+				if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
 					return;
 				}
 
-				if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
-				    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
-					continue;
-				}
-
-				if (ret == 0 || ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
-					// Closed
+				if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
 					PLOG_DEBUG << "TLS connection cleanly closed";
 					break;
 				}
 
-				mbedtls::check(ret);
-
-				auto *b = reinterpret_cast<byte *>(buffer);
-				recv(make_message(b, b + ret));
+				if (mbedtls::check(ret)) {
+					if (ret == 0) {
+						PLOG_DEBUG << "TLS connection terminated";
+						break;
+					}
+					auto *b = reinterpret_cast<byte *>(buffer);
+					recv(make_message(b, b + ret));
+				}
 			}
 		}
 	} catch (const std::exception &e) {