Browse Source

Merge pull request #895 from paullouisageneau/fix-mbedtls-sync

Fix MbedTLS sync and refactor MbedTLS error checking
Paul-Louis Ageneau 2 years ago
parent
commit
a33f252e97

+ 32 - 34
src/impl/dtlstransport.cpp

@@ -367,7 +367,7 @@ int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int /* m
 
 
 #elif USE_MBEDTLS
 #elif USE_MBEDTLS
 
 
-mbedtls_ssl_srtp_profile srtpSupportedProtectionProfiles[] = {
+const mbedtls_ssl_srtp_profile srtpSupportedProtectionProfiles[] = {
     MBEDTLS_TLS_SRTP_AES128_CM_HMAC_SHA1_80,
     MBEDTLS_TLS_SRTP_AES128_CM_HMAC_SHA1_80,
     MBEDTLS_TLS_SRTP_UNSET,
     MBEDTLS_TLS_SRTP_UNSET,
 };
 };
@@ -473,16 +473,14 @@ bool DtlsTransport::send(message_ptr message) {
 
 
 	int ret;
 	int ret;
 	do {
 	do {
-		std::lock_guard lock(mMutex);
-		mCurrentDscp = message->dscp;
-
+		std::lock_guard lock(mSslMutex);
 		if (message->size() > size_t(mbedtls_ssl_get_max_out_record_payload(&mSsl)))
 		if (message->size() > size_t(mbedtls_ssl_get_max_out_record_payload(&mSsl)))
 			return false;
 			return false;
 
 
+		mCurrentDscp = message->dscp;
 		ret = mbedtls_ssl_write(&mSsl, reinterpret_cast<const unsigned char *>(message->data()),
 		ret = mbedtls_ssl_write(&mSsl, reinterpret_cast<const unsigned char *>(message->data()),
 		                        message->size());
 		                        message->size());
-	} while (ret == MBEDTLS_ERR_SSL_WANT_WRITE);
-	mbedtls::check(ret);
+	} while (!mbedtls::check(ret));
 
 
 	return mOutgoingResult;
 	return mOutgoingResult;
 }
 }
@@ -529,9 +527,13 @@ void DtlsTransport::doRecv() {
 		// Handle handshake if connecting
 		// Handle handshake if connecting
 		if (state() == State::Connecting) {
 		if (state() == State::Connecting) {
 			while (true) {
 			while (true) {
-				auto ret = mbedtls_ssl_handshake(&mSsl);
+				int ret;
+				{
+					std::lock_guard lock(mSslMutex);
+					ret = mbedtls_ssl_handshake(&mSsl);
+				}
 
 
-				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+				if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
 					ThreadPool::Instance().schedule(mTimerSetAt + milliseconds(mFinMs),
 					ThreadPool::Instance().schedule(mTimerSetAt + milliseconds(mFinMs),
 					                                [weak_this = weak_from_this()]() {
 					                                [weak_this = weak_from_this()]() {
 						                                if (auto locked = weak_this.lock())
 						                                if (auto locked = weak_this.lock())
@@ -540,45 +542,41 @@ void DtlsTransport::doRecv() {
 					return;
 					return;
 				}
 				}
 
 
-				if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
-				           ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS)
-					continue;
-
-				mbedtls::check(ret, "Handshake failed");
-
-				PLOG_INFO << "DTLS handshake finished";
-				changeState(State::Connected);
-				postHandshake();
-				break;
+				if(mbedtls::check(ret, "Handshake failed")) {
+					PLOG_INFO << "DTLS handshake finished";
+					changeState(State::Connected);
+					postHandshake();
+					break;
+				}
 			}
 			}
 		}
 		}
 
 
 		if (state() == State::Connected) {
 		if (state() == State::Connected) {
 			while (true) {
 			while (true) {
-				mMutex.lock();
-				auto ret =
-				    mbedtls_ssl_read(&mSsl, reinterpret_cast<unsigned char *>(buffer), bufferSize);
-				mMutex.unlock();
-
-				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
-					return;
+				int ret;
+				{
+					std::lock_guard lock(mSslMutex);
+					ret = mbedtls_ssl_read(&mSsl, reinterpret_cast<unsigned char *>(buffer),
+					                       bufferSize);
 				}
 				}
 
 
-				if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
-				           ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
-					continue;
+				if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
+					return;
 				}
 				}
 
 
-				if (ret == 0 || ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
-					// Closed
+				if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
 					PLOG_DEBUG << "DTLS connection cleanly closed";
 					PLOG_DEBUG << "DTLS connection cleanly closed";
 					break;
 					break;
 				}
 				}
 
 
-				mbedtls::check(ret);
-
-				auto *b = reinterpret_cast<byte *>(buffer);
-				recv(make_message(b, b + ret));
+				if(mbedtls::check(ret)) {
+					if(ret == 0) {
+						PLOG_DEBUG << "DTLS connection terminated";
+						break;
+					}
+					auto *b = reinterpret_cast<byte *>(buffer);
+					recv(make_message(b, b + ret));
+				}
 			}
 			}
 		}
 		}
 	} catch (const std::exception &e) {
 	} catch (const std::exception &e) {

+ 2 - 2
src/impl/dtlstransport.hpp

@@ -71,13 +71,13 @@ protected:
 	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
 	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
 
 
 #elif USE_MBEDTLS
 #elif USE_MBEDTLS
-	std::mutex mMutex;
-
 	mbedtls_entropy_context mEntropy;
 	mbedtls_entropy_context mEntropy;
 	mbedtls_ctr_drbg_context mDrbg;
 	mbedtls_ctr_drbg_context mDrbg;
 	mbedtls_ssl_config mConf;
 	mbedtls_ssl_config mConf;
 	mbedtls_ssl_context mSsl;
 	mbedtls_ssl_context mSsl;
 
 
+	std::mutex mSslMutex;
+
 	uint32_t mFinMs = 0, mIntMs = 0;
 	uint32_t mFinMs = 0, mIntMs = 0;
 	std::chrono::time_point<std::chrono::steady_clock> mTimerSetAt;
 	std::chrono::time_point<std::chrono::steady_clock> mTimerSetAt;
 
 

+ 11 - 9
src/impl/tls.cpp

@@ -8,21 +8,19 @@
 
 
 #include "tls.hpp"
 #include "tls.hpp"
 
 
-#include "internals.hpp"
-
 #include <fstream>
 #include <fstream>
+#include <stdexcept>
 
 
 #if USE_GNUTLS
 #if USE_GNUTLS
 
 
 namespace rtc::gnutls {
 namespace rtc::gnutls {
 
 
+// Return false on non-fatal error
 bool check(int ret, const string &message) {
 bool check(int ret, const string &message) {
 	if (ret < 0) {
 	if (ret < 0) {
 		if (!gnutls_error_is_fatal(ret)) {
 		if (!gnutls_error_is_fatal(ret)) {
-			PLOG_INFO << gnutls_strerror(ret);
 			return false;
 			return false;
 		}
 		}
-		PLOG_ERROR << message << ": " << gnutls_strerror(ret);
 		throw std::runtime_error(message + ": " + gnutls_strerror(ret));
 		throw std::runtime_error(message + ": " + gnutls_strerror(ret));
 	}
 	}
 	return true;
 	return true;
@@ -98,14 +96,20 @@ size_t my_strftme(char *buf, size_t size, const char *format, const time_t *t) {
 
 
 namespace rtc::mbedtls {
 namespace rtc::mbedtls {
 
 
-void check(int ret, const string &message) {
+// Return false on non-fatal error
+bool check(int ret, const string &message) {
 	if (ret < 0) {
 	if (ret < 0) {
+		if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
+		    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS || ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS ||
+		    ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)
+			return false;
+
 		const size_t bufferSize = 1024;
 		const size_t bufferSize = 1024;
 		char buffer[bufferSize];
 		char buffer[bufferSize];
 		mbedtls_strerror(ret, reinterpret_cast<char *>(buffer), bufferSize);
 		mbedtls_strerror(ret, reinterpret_cast<char *>(buffer), bufferSize);
-		PLOG_ERROR << message << ": " << buffer;
 		throw std::runtime_error(message + ": " + std::string(buffer));
 		throw std::runtime_error(message + ": " + std::string(buffer));
 	}
 	}
+	return true;
 }
 }
 
 
 string format_time(const std::chrono::system_clock::time_point &tp) {
 string format_time(const std::chrono::system_clock::time_point &tp) {
@@ -171,21 +175,19 @@ bool check(int success, const string &message) {
 		return true;
 		return true;
 
 
 	string str = error_string(ERR_get_error());
 	string str = error_string(ERR_get_error());
-	PLOG_ERROR << message << ": " << str;
 	throw std::runtime_error(message + ": " + str);
 	throw std::runtime_error(message + ": " + str);
 }
 }
 
 
+// Return false on EOF
 bool check(SSL *ssl, int ret, const string &message) {
 bool check(SSL *ssl, int ret, const string &message) {
 	unsigned long err = SSL_get_error(ssl, ret);
 	unsigned long err = SSL_get_error(ssl, ret);
 	if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
 	if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
 		return true;
 		return true;
 	}
 	}
 	if (err == SSL_ERROR_ZERO_RETURN) {
 	if (err == SSL_ERROR_ZERO_RETURN) {
-		PLOG_DEBUG << "OpenSSL connection cleanly closed";
 		return false;
 		return false;
 	}
 	}
 	string str = error_string(err);
 	string str = error_string(err);
-	PLOG_ERROR << str;
 	throw std::runtime_error(message + ": " + str);
 	throw std::runtime_error(message + ": " + str);
 }
 }
 
 

+ 1 - 1
src/impl/tls.hpp

@@ -52,7 +52,7 @@ gnutls_datum_t make_datum(char *data, size_t size);
 
 
 namespace rtc::mbedtls {
 namespace rtc::mbedtls {
 
 
-void check(int ret, const string &message = "MbedTLS error");
+bool check(int ret, const string &message = "MbedTLS error");
 
 
 string format_time(const std::chrono::system_clock::time_point &tp);
 string format_time(const std::chrono::system_clock::time_point &tp);
 
 

+ 35 - 28
src/impl/tlstransport.cpp

@@ -379,8 +379,14 @@ bool TlsTransport::send(message_ptr message) {
 
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 	PLOG_VERBOSE << "Send size=" << message->size();
 
 
-	mbedtls::check(mbedtls_ssl_write(
-	    &mSsl, reinterpret_cast<const unsigned char *>(message->data()), int(message->size())));
+	int ret;
+	do {
+		std::lock_guard lock(mSslMutex);
+		ret = mbedtls_ssl_write(&mSsl, reinterpret_cast<const unsigned char *>(message->data()),
+		                        int(message->size()));
+	} while (ret == MBEDTLS_ERR_SSL_WANT_WRITE);
+
+	mbedtls::check(ret);
 
 
 	return mOutgoingResult;
 	return mOutgoingResult;
 }
 }
@@ -421,50 +427,51 @@ void TlsTransport::doRecv() {
 		// Handle handshake if connecting
 		// Handle handshake if connecting
 		if (state() == State::Connecting) {
 		if (state() == State::Connecting) {
 			while (true) {
 			while (true) {
-				auto ret = mbedtls_ssl_handshake(&mSsl);
+				int ret;
+				{
+					std::lock_guard lock(mSslMutex);
+					ret = mbedtls_ssl_handshake(&mSsl);
+				}
 
 
-				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+				if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
 					return;
 					return;
 				}
 				}
 
 
-				if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
-				    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
-					continue;
+				if (mbedtls::check(ret, "Handshake failed")) {
+					PLOG_INFO << "TLS handshake finished";
+					changeState(State::Connected);
+					postHandshake();
+					break;
 				}
 				}
-
-				mbedtls::check(ret, "Handshake failed");
-
-				PLOG_INFO << "TLS handshake finished";
-				changeState(State::Connected);
-				postHandshake();
-				break;
 			}
 			}
 		}
 		}
 
 
 		if (state() == State::Connected) {
 		if (state() == State::Connected) {
 			while (true) {
 			while (true) {
-				auto ret =
-				    mbedtls_ssl_read(&mSsl, reinterpret_cast<unsigned char *>(buffer), bufferSize);
-
-				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
-					return;
+				int ret;
+				{
+					std::lock_guard lock(mSslMutex);
+					ret = mbedtls_ssl_read(&mSsl, reinterpret_cast<unsigned char *>(buffer),
+					                       bufferSize);
 				}
 				}
 
 
-				if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
-				    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
-					continue;
+				if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
+					return;
 				}
 				}
 
 
-				if (ret == 0 || ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
-					// Closed
+				if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
 					PLOG_DEBUG << "TLS connection cleanly closed";
 					PLOG_DEBUG << "TLS connection cleanly closed";
 					break;
 					break;
 				}
 				}
 
 
-				mbedtls::check(ret);
-
-				auto *b = reinterpret_cast<byte *>(buffer);
-				recv(make_message(b, b + ret));
+				if (mbedtls::check(ret)) {
+					if (ret == 0) {
+						PLOG_DEBUG << "TLS connection terminated";
+						break;
+					}
+					auto *b = reinterpret_cast<byte *>(buffer);
+					recv(make_message(b, b + ret));
+				}
 			}
 			}
 		}
 		}
 	} catch (const std::exception &e) {
 	} catch (const std::exception &e) {

+ 4 - 3
src/impl/tlstransport.hpp

@@ -65,15 +65,16 @@ protected:
 	static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
 	static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
 	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
 	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
 	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
 	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
-#elif USE_MBEDTLS
-	std::mutex mSendMutex;
-	std::atomic<bool> mOutgoingResult = true;
 
 
+#elif USE_MBEDTLS
 	mbedtls_entropy_context mEntropy;
 	mbedtls_entropy_context mEntropy;
 	mbedtls_ctr_drbg_context mDrbg;
 	mbedtls_ctr_drbg_context mDrbg;
 	mbedtls_ssl_config mConf;
 	mbedtls_ssl_config mConf;
 	mbedtls_ssl_context mSsl;
 	mbedtls_ssl_context mSsl;
 
 
+	std::mutex mSslMutex;
+	std::atomic<bool> mOutgoingResult = true;
+
 	message_ptr mIncomingMessage;
 	message_ptr mIncomingMessage;
 	size_t mIncomingMessagePosition = 0;
 	size_t mIncomingMessagePosition = 0;