Browse Source

Fix and refactor MbedTLS sync

Paul-Louis Ageneau 2 years ago
parent
commit
4129d5f0f6
4 changed files with 43 additions and 21 deletions
  1. 18 11
      src/impl/dtlstransport.cpp
  2. 2 2
      src/impl/dtlstransport.hpp
  3. 19 5
      src/impl/tlstransport.cpp
  4. 4 3
      src/impl/tlstransport.hpp

+ 18 - 11
src/impl/dtlstransport.cpp

@@ -367,7 +367,7 @@ int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int /* m
 
 #elif USE_MBEDTLS
 
-mbedtls_ssl_srtp_profile srtpSupportedProtectionProfiles[] = {
+const mbedtls_ssl_srtp_profile srtpSupportedProtectionProfiles[] = {
     MBEDTLS_TLS_SRTP_AES128_CM_HMAC_SHA1_80,
     MBEDTLS_TLS_SRTP_UNSET,
 };
@@ -473,15 +473,15 @@ bool DtlsTransport::send(message_ptr message) {
 
 	int ret;
 	do {
-		std::lock_guard lock(mMutex);
-		mCurrentDscp = message->dscp;
-
+		std::lock_guard lock(mSslMutex);
 		if (message->size() > size_t(mbedtls_ssl_get_max_out_record_payload(&mSsl)))
 			return false;
 
+		mCurrentDscp = message->dscp;
 		ret = mbedtls_ssl_write(&mSsl, reinterpret_cast<const unsigned char *>(message->data()),
 		                        message->size());
 	} while (ret == MBEDTLS_ERR_SSL_WANT_WRITE);
+
 	mbedtls::check(ret);
 
 	return mOutgoingResult;
@@ -529,7 +529,11 @@ void DtlsTransport::doRecv() {
 		// Handle handshake if connecting
 		if (state() == State::Connecting) {
 			while (true) {
-				auto ret = mbedtls_ssl_handshake(&mSsl);
+				int ret;
+				{
+					std::lock_guard lock(mSslMutex);
+					ret = mbedtls_ssl_handshake(&mSsl);
+				}
 
 				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
 					ThreadPool::Instance().schedule(mTimerSetAt + milliseconds(mFinMs),
@@ -541,8 +545,9 @@ void DtlsTransport::doRecv() {
 				}
 
 				if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
-				           ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS)
+				    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
 					continue;
+				}
 
 				mbedtls::check(ret, "Handshake failed");
 
@@ -555,17 +560,19 @@ void DtlsTransport::doRecv() {
 
 		if (state() == State::Connected) {
 			while (true) {
-				mMutex.lock();
-				auto ret =
-				    mbedtls_ssl_read(&mSsl, reinterpret_cast<unsigned char *>(buffer), bufferSize);
-				mMutex.unlock();
+				int ret;
+				{
+					std::lock_guard lock(mSslMutex);
+					ret = mbedtls_ssl_read(&mSsl, reinterpret_cast<unsigned char *>(buffer),
+					                       bufferSize);
+				}
 
 				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
 					return;
 				}
 
 				if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
-				           ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
+				    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
 					continue;
 				}
 

+ 2 - 2
src/impl/dtlstransport.hpp

@@ -71,13 +71,13 @@ protected:
 	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
 
 #elif USE_MBEDTLS
-	std::mutex mMutex;
-
 	mbedtls_entropy_context mEntropy;
 	mbedtls_ctr_drbg_context mDrbg;
 	mbedtls_ssl_config mConf;
 	mbedtls_ssl_context mSsl;
 
+	std::mutex mSslMutex;
+
 	uint32_t mFinMs = 0, mIntMs = 0;
 	std::chrono::time_point<std::chrono::steady_clock> mTimerSetAt;
 

+ 19 - 5
src/impl/tlstransport.cpp

@@ -379,8 +379,14 @@ bool TlsTransport::send(message_ptr message) {
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 
-	mbedtls::check(mbedtls_ssl_write(
-	    &mSsl, reinterpret_cast<const unsigned char *>(message->data()), int(message->size())));
+	int ret;
+	do {
+		std::lock_guard lock(mSslMutex);
+		ret = mbedtls_ssl_write(&mSsl, reinterpret_cast<const unsigned char *>(message->data()),
+		                        int(message->size()));
+	} while (ret == MBEDTLS_ERR_SSL_WANT_WRITE);
+
+	mbedtls::check(ret);
 
 	return mOutgoingResult;
 }
@@ -421,7 +427,11 @@ void TlsTransport::doRecv() {
 		// Handle handshake if connecting
 		if (state() == State::Connecting) {
 			while (true) {
-				auto ret = mbedtls_ssl_handshake(&mSsl);
+				int ret;
+				{
+					std::lock_guard lock(mSslMutex);
+					ret = mbedtls_ssl_handshake(&mSsl);
+				}
 
 				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
 					return;
@@ -443,8 +453,12 @@ void TlsTransport::doRecv() {
 
 		if (state() == State::Connected) {
 			while (true) {
-				auto ret =
-				    mbedtls_ssl_read(&mSsl, reinterpret_cast<unsigned char *>(buffer), bufferSize);
+				int ret;
+				{
+					std::lock_guard lock(mSslMutex);
+					ret = mbedtls_ssl_read(&mSsl, reinterpret_cast<unsigned char *>(buffer),
+					                       bufferSize);
+				}
 
 				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
 					return;

+ 4 - 3
src/impl/tlstransport.hpp

@@ -65,15 +65,16 @@ protected:
 	static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
 	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
 	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
-#elif USE_MBEDTLS
-	std::mutex mSendMutex;
-	std::atomic<bool> mOutgoingResult = true;
 
+#elif USE_MBEDTLS
 	mbedtls_entropy_context mEntropy;
 	mbedtls_ctr_drbg_context mDrbg;
 	mbedtls_ssl_config mConf;
 	mbedtls_ssl_context mSsl;
 
+	std::mutex mSslMutex;
+	std::atomic<bool> mOutgoingResult = true;
+
 	message_ptr mIncomingMessage;
 	size_t mIncomingMessagePosition = 0;