Browse Source

Synchronize calls to the same OpenSSL object

Paul-Louis Ageneau 2 years ago
parent
commit
d018fdbfe0
4 changed files with 87 additions and 38 deletions
  1. 43 17
      src/impl/dtlstransport.cpp
  2. 1 0
      src/impl/dtlstransport.hpp
  3. 39 19
      src/impl/tlstransport.cpp
  4. 4 2
      src/impl/tlstransport.hpp

+ 43 - 17
src/impl/dtlstransport.cpp

@@ -489,8 +489,13 @@ bool DtlsTransport::send(message_ptr message) {
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 
-	mCurrentDscp = message->dscp;
-	int ret = SSL_write(mSsl, message->data(), int(message->size()));
+	int ret;
+	{
+		std::lock_guard lock(mSslMutex);
+		mCurrentDscp = message->dscp;
+		ret = SSL_write(mSsl, message->data(), int(message->size()));
+	}
+
 	if (!openssl::check(mSsl, ret))
 		return false;
 
@@ -529,13 +534,18 @@ void DtlsTransport::runRecvLoop() {
 	try {
 		changeState(State::Connecting);
 
-		size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
-		SSL_set_mtu(mSsl, static_cast<unsigned int>(mtu));
-		PLOG_VERBOSE << "SSL MTU set to " << mtu;
-
 		// Initiate the handshake
-		int ret = SSL_do_handshake(mSsl);
-		openssl::check(mSsl, ret, "Handshake failed");
+		{
+			std::lock_guard lock(mSslMutex);
+
+			size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
+			SSL_set_mtu(mSsl, static_cast<unsigned int>(mtu));
+			PLOG_VERBOSE << "SSL MTU set to " << mtu;
+
+			int ret = SSL_do_handshake(mSsl);
+
+			openssl::check(mSsl, ret, "Handshake failed");
+		}
 
 		byte buffer[bufferSize];
 		while (mIncomingQueue.running()) {
@@ -549,23 +559,38 @@ void DtlsTransport::runRecvLoop() {
 
 				if (state() == State::Connecting) {
 					// Continue the handshake
-					ret = SSL_do_handshake(mSsl);
-					if (!openssl::check(mSsl, ret, "Handshake failed"))
-						break;
+					bool finished;
+					{
+						std::lock_guard lock(mSslMutex);
+						int ret = SSL_do_handshake(mSsl);
 
-					if (SSL_is_init_finished(mSsl)) {
+						if (!openssl::check(mSsl, ret, "Handshake failed"))
+							break;
+
+						finished = (SSL_is_init_finished(mSsl) != 0);
+					}
+
+					if (finished) {
 						// RFC 8261: DTLS MUST support sending messages larger than the current path
 						// MTU See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
-						SSL_set_mtu(mSsl, bufferSize + 1);
+						{
+							std::lock_guard lock(mSslMutex);
+							SSL_set_mtu(mSsl, bufferSize + 1);
+						}
 
 						PLOG_INFO << "DTLS handshake finished";
 						postHandshake();
 						changeState(State::Connected);
 					}
 				} else {
-					ret = SSL_read(mSsl, buffer, bufferSize);
-					if (!openssl::check(mSsl, ret))
-						break;
+					int ret;
+					{
+						std::lock_guard lock(mSslMutex);
+						ret = SSL_read(mSsl, buffer, bufferSize);
+
+						if (!openssl::check(mSsl, ret))
+							break;
+					}
 
 					if (ret > 0)
 						recv(make_message(buffer, buffer + ret));
@@ -575,8 +600,9 @@ void DtlsTransport::runRecvLoop() {
 			// No more messages pending, retransmit and rearm timeout if connecting
 			optional<milliseconds> duration;
 			if (state() == State::Connecting) {
+				std::lock_guard lock(mSslMutex);
 				// Warning: This function breaks the usual return value convention
-				ret = DTLSv1_handle_timeout(mSsl);
+				int ret = DTLSv1_handle_timeout(mSsl);
 				if (ret < 0) {
 					throw std::runtime_error("Handshake timeout"); // write BIO can't fail
 				} else if (ret > 0) {

+ 1 - 0
src/impl/dtlstransport.hpp

@@ -73,6 +73,7 @@ protected:
 	SSL_CTX *mCtx = NULL;
 	SSL *mSsl = NULL;
 	BIO *mInBio, *mOutBio;
+	std::mutex mSslMutex;
 
 	static BIO_METHOD *BioMethods;
 	static int TransportExIndex;

+ 39 - 19
src/impl/tlstransport.cpp

@@ -403,16 +403,13 @@ bool TlsTransport::send(message_ptr message) {
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 
+	std::lock_guard lock(mSslMutex);
 	int ret = SSL_write(mSsl, message->data(), int(message->size()));
+	bool result = flushOutput();
+
 	if (!openssl::check(mSsl, ret))
 		throw std::runtime_error("TLS send failed");
 
-	const size_t bufferSize = 4096;
-	byte buffer[bufferSize];
-	bool result = true;
-	while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
-		result = outgoing(make_message(buffer, buffer + ret));
-
 	return result;
 }
 
@@ -439,19 +436,22 @@ void TlsTransport::runRecvLoop() {
 	try {
 		changeState(State::Connecting);
 
-		int ret;
 		while (true) {
 			if (state() == State::Connecting) {
 				// Initiate or continue the handshake
-				ret = SSL_do_handshake(mSsl);
-				if (!openssl::check(mSsl, ret, "Handshake failed"))
-					break;
+				bool finished;
+				{
+					std::lock_guard lock(mSslMutex);
+					int ret = SSL_do_handshake(mSsl);
+					flushOutput();
+
+					if (!openssl::check(mSsl, ret, "Handshake failed"))
+						break;
 
-				// Output
-				while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
-					outgoing(make_message(buffer, buffer + ret));
+					finished = (SSL_is_init_finished(mSsl) != 0);
+				}
 
-				if (SSL_is_init_finished(mSsl)) {
+				if (finished) {
 					PLOG_INFO << "TLS handshake finished";
 					changeState(State::Connected);
 					postHandshake();
@@ -459,12 +459,20 @@ void TlsTransport::runRecvLoop() {
 			}
 
 			if (state() == State::Connected) {
-				// Input
-				while ((ret = SSL_read(mSsl, buffer, bufferSize)) > 0)
-					recv(make_message(buffer, buffer + ret));
+				int ret;
+				{
+					std::lock_guard lock(mSslMutex);
+					ret = SSL_read(mSsl, buffer, bufferSize);
+					flushOutput(); // SSL_read() can also cause write operations
+
+					if (!openssl::check(mSsl, ret))
+						break;
+				}
 
-				if (!openssl::check(mSsl, ret))
-					break;
+				if (ret > 0) {
+					recv(make_message(buffer, buffer + ret));
+					continue;
+				}
 			}
 
 			auto next = mIncomingQueue.pop();
@@ -492,6 +500,18 @@ void TlsTransport::runRecvLoop() {
 	}
 }
 
+bool TlsTransport::flushOutput() {
+	// Requires mSslMutex to be locked
+	bool result = true;
+	const size_t bufferSize = 4096;
+	byte buffer[bufferSize];
+	int len;
+	while ((len = BIO_read(mOutBio, buffer, bufferSize)) > 0)
+		result = outgoing(make_message(buffer, buffer + len));
+
+	return result;
+}
+
 void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) {
 	TlsTransport *t =
 	    static_cast<TlsTransport *>(SSL_get_ex_data(ssl, TlsTransport::TransportExIndex));

+ 4 - 2
src/impl/tlstransport.hpp

@@ -67,11 +67,13 @@ protected:
 	SSL_CTX *mCtx;
 	SSL *mSsl;
 	BIO *mInBio, *mOutBio;
+	std::mutex mSslMutex;
 
-	static int TransportExIndex;
+	bool flushOutput();
 
-	static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx);
+	static int TransportExIndex;
 	static void InfoCallback(const SSL *ssl, int where, int ret);
+
 #endif
 };