Browse Source

Merge branch 'v0.18'

Paul-Louis Ageneau 2 years ago
parent
commit
12e00da55b
4 changed files with 87 additions and 41 deletions
  1. 45 17
      src/impl/dtlstransport.cpp
  2. 1 0
      src/impl/dtlstransport.hpp
  3. 40 16
      src/impl/tlstransport.cpp
  4. 1 8
      src/impl/tlstransport.hpp

+ 45 - 17
src/impl/dtlstransport.cpp

@@ -819,13 +819,18 @@ void DtlsTransport::start() {
 	registerIncoming();
 	changeState(State::Connecting);
 
-	size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
-	SSL_set_mtu(mSsl, static_cast<unsigned int>(mtu));
-	PLOG_VERBOSE << "DTLS MTU set to " << mtu;
+	{
+		std::lock_guard lock(mSslMutex);
 
-	// Initiate the handshake
-	int ret = SSL_do_handshake(mSsl);
-	openssl::check(mSsl, ret, "Handshake initiation failed");
+		size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
+		SSL_set_mtu(mSsl, static_cast<unsigned int>(mtu));
+		PLOG_VERBOSE << "DTLS MTU set to " << mtu;
+
+		// Initiate the handshake
+		int ret = SSL_do_handshake(mSsl);
+
+		openssl::check(mSsl, ret, "Handshake initiation failed");
+	}
 
 	handleTimeout();
 }
@@ -843,8 +848,10 @@ bool DtlsTransport::send(message_ptr message) {
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 
+	std::lock_guard lock(mSslMutex);
 	mCurrentDscp = message->dscp;
 	int ret = SSL_write(mSsl, message->data(), int(message->size()));
+
 	if (!openssl::check(mSsl, ret))
 		return false;
 
@@ -910,23 +917,39 @@ void DtlsTransport::doRecv() {
 
 			if (state() == State::Connecting) {
 				// Continue the handshake
-				int ret = SSL_do_handshake(mSsl);
-				if (!openssl::check(mSsl, ret, "Handshake failed"))
-					break;
+				bool finished;
+				{
+					std::lock_guard lock(mSslMutex);
+					int ret = SSL_do_handshake(mSsl);
 
-				if (SSL_is_init_finished(mSsl)) {
+					if (!openssl::check(mSsl, ret, "Handshake failed"))
+						break;
+
+					finished = (SSL_is_init_finished(mSsl) != 0);
+				}
+				if (finished) {
 					// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
 					// See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
-					SSL_set_mtu(mSsl, bufferSize + 1);
+					{
+						std::lock_guard lock(mSslMutex);
+						SSL_set_mtu(mSsl, bufferSize + 1);
+					}
 
 					PLOG_INFO << "DTLS handshake finished";
 					postHandshake();
 					changeState(State::Connected);
 				}
-			} else {
-				int ret = SSL_read(mSsl, buffer, bufferSize);
-				if (!openssl::check(mSsl, ret))
-					break;
+			}
+
+			if (state() == State::Connected) {
+				int ret;
+				{
+					std::lock_guard lock(mSslMutex);
+					ret = SSL_read(mSsl, buffer, bufferSize);
+
+					if (!openssl::check(mSsl, ret))
+						break;
+				}
 
 				if (ret > 0)
 					recv(make_message(buffer, buffer + ret));
@@ -937,8 +960,6 @@ void DtlsTransport::doRecv() {
 		PLOG_ERROR << "DTLS recv: " << e.what();
 	}
 
-	SSL_shutdown(mSsl);
-
 	if (state() == State::Connected) {
 		PLOG_INFO << "DTLS closed";
 		changeState(State::Disconnected);
@@ -947,9 +968,16 @@ void DtlsTransport::doRecv() {
 		PLOG_ERROR << "DTLS handshake failed";
 		changeState(State::Failed);
 	}
+
+	{
+		std::lock_guard lock(mSslMutex);
+		SSL_shutdown(mSsl);
+	}
 }
 
 void DtlsTransport::handleTimeout() {
+	std::lock_guard lock(mSslMutex);
+
 	// Warning: This function breaks the usual return value convention
 	int ret = DTLSv1_handle_timeout(mSsl);
 	if (ret < 0) {

+ 1 - 0
src/impl/dtlstransport.hpp

@@ -99,6 +99,7 @@ protected:
 	SSL_CTX *mCtx = NULL;
 	SSL *mSsl = NULL;
 	BIO *mInBio, *mOutBio;
+	std::mutex mSslMutex;
 
 	void handleTimeout();
 

+ 40 - 16
src/impl/tlstransport.cpp

@@ -639,9 +639,9 @@ void TlsTransport::start() {
 	changeState(State::Connecting);
 
 	// Initiate the handshake
+	std::lock_guard lock(mSslMutex);
 	int ret = SSL_do_handshake(mSsl);
 	openssl::check(mSsl, ret, "Handshake initiation failed");
-
 	flushOutput();
 }
 
@@ -661,6 +661,7 @@ bool TlsTransport::send(message_ptr message) {
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 
+	std::lock_guard lock(mSslMutex);
 	int ret = SSL_write(mSsl, message->data(), int(message->size()));
 	if (!openssl::check(mSsl, ret))
 		throw std::runtime_error("TLS send failed");
@@ -711,13 +712,18 @@ void TlsTransport::doRecv() {
 
 			if (state() == State::Connecting) {
 				// Continue the handshake
-				int ret = SSL_do_handshake(mSsl);
-				if (!openssl::check(mSsl, ret, "Handshake failed"))
-					break;
+				bool finished;
+				{
+					std::lock_guard lock(mSslMutex);
+					int ret = SSL_do_handshake(mSsl);
+					if (!openssl::check(mSsl, ret, "Handshake failed"))
+						break;
 
-				flushOutput();
+					flushOutput();
+					finished = (SSL_is_init_finished(mSsl) != 0);
+				}
 
-				if (SSL_is_init_finished(mSsl)) {
+				if (finished) {
 					PLOG_INFO << "TLS handshake finished";
 					changeState(State::Connected);
 					postHandshake();
@@ -726,11 +732,25 @@ void TlsTransport::doRecv() {
 
 			if (state() == State::Connected) {
 				int ret;
-				while ((ret = SSL_read(mSsl, buffer, bufferSize)) > 0)
-					recv(make_message(buffer, buffer + ret));
+				while (true) {
+					{
+						std::lock_guard lock(mSslMutex);
+						ret = SSL_read(mSsl, buffer, bufferSize);
+					}
 
-				if (!openssl::check(mSsl, ret))
-					break;
+					if (ret > 0)
+						recv(make_message(buffer, buffer + ret));
+					else
+						break;
+				}
+
+				{
+					std::lock_guard lock(mSslMutex);
+					if (!openssl::check(mSsl, ret))
+						break;
+
+					flushOutput(); // SSL_read() can also cause write operations
+				}
 			}
 		}
 
@@ -738,8 +758,6 @@ void TlsTransport::doRecv() {
 		PLOG_ERROR << "TLS recv: " << e.what();
 	}
 
-	SSL_shutdown(mSsl);
-
 	if (state() == State::Connected) {
 		PLOG_INFO << "TLS closed";
 		changeState(State::Disconnected);
@@ -748,15 +766,21 @@ void TlsTransport::doRecv() {
 		PLOG_ERROR << "TLS handshake failed";
 		changeState(State::Failed);
 	}
+
+	{
+		std::lock_guard lock(mSslMutex);
+		SSL_shutdown(mSsl);
+	}
 }
 
 bool TlsTransport::flushOutput() {
+	// Requires mSslMutex to be locked
+	bool result = true;
 	const size_t bufferSize = 4096;
 	byte buffer[bufferSize];
-	int ret;
-	bool result = true;
-	while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
-		result = outgoing(make_message(buffer, buffer + ret));
+	int len;
+	while ((len = BIO_read(mOutBio, buffer, bufferSize)) > 0)
+		result = outgoing(make_message(buffer, buffer + len));
 
 	return result;
 }

+ 1 - 8
src/impl/tlstransport.hpp

@@ -85,20 +85,13 @@ protected:
 	SSL_CTX *mCtx;
 	SSL *mSsl;
 	BIO *mInBio, *mOutBio;
+	std::mutex mSslMutex;
 
 	bool flushOutput();
 
-	static BIO_METHOD *BioMethods;
 	static int TransportExIndex;
-	static std::mutex GlobalMutex;
 
-	static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx);
 	static void InfoCallback(const SSL *ssl, int where, int ret);
-
-	static int BioMethodNew(BIO *bio);
-	static int BioMethodFree(BIO *bio);
-	static int BioMethodWrite(BIO *bio, const char *in, int inl);
-	static long BioMethodCtrl(BIO *bio, int cmd, long num, void *ptr);
 #endif
 };