Browse Source

Merge pull request #910 from paullouisageneau/refactor-openssl-logic

Refactor OpenSSL error handling logic
Paul-Louis Ageneau 2 years ago
parent
commit
175434011f
4 changed files with 79 additions and 66 deletions
  1. 25 19
      src/impl/dtlstransport.cpp
  2. 16 21
      src/impl/tls.cpp
  3. 1 1
      src/impl/tls.hpp
  4. 37 25
      src/impl/tlstransport.cpp

+ 25 - 19
src/impl/dtlstransport.cpp

@@ -819,6 +819,7 @@ void DtlsTransport::start() {
 	registerIncoming();
 	registerIncoming();
 	changeState(State::Connecting);
 	changeState(State::Connecting);
 
 
+	int ret, err;
 	{
 	{
 		std::lock_guard lock(mSslMutex);
 		std::lock_guard lock(mSslMutex);
 
 
@@ -827,11 +828,12 @@ void DtlsTransport::start() {
 		PLOG_VERBOSE << "DTLS MTU set to " << mtu;
 		PLOG_VERBOSE << "DTLS MTU set to " << mtu;
 
 
 		// Initiate the handshake
 		// Initiate the handshake
-		int ret = SSL_do_handshake(mSsl);
-
-		openssl::check(mSsl, ret, "Handshake initiation failed");
+		ret = SSL_do_handshake(mSsl);
+		err = SSL_get_error(mSsl, ret);
 	}
 	}
 
 
+	openssl::check_error(err, "Handshake failed");
+
 	handleTimeout();
 	handleTimeout();
 }
 }
 
 
@@ -848,11 +850,15 @@ bool DtlsTransport::send(message_ptr message) {
 
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 	PLOG_VERBOSE << "Send size=" << message->size();
 
 
-	std::lock_guard lock(mSslMutex);
-	mCurrentDscp = message->dscp;
-	int ret = SSL_write(mSsl, message->data(), int(message->size()));
+	int ret, err;
+	{
+		std::lock_guard lock(mSslMutex);
+		mCurrentDscp = message->dscp;
+		ret = SSL_write(mSsl, message->data(), int(message->size()));
+		err = SSL_get_error(mSsl, ret);
+	}
 
 
-	if (!openssl::check(mSsl, ret))
+	if (!openssl::check_error(err))
 		return false;
 		return false;
 
 
 	return mOutgoingResult;
 	return mOutgoingResult;
@@ -917,17 +923,14 @@ void DtlsTransport::doRecv() {
 
 
 			if (state() == State::Connecting) {
 			if (state() == State::Connecting) {
 				// Continue the handshake
 				// Continue the handshake
-				bool finished;
+				int ret, err;
 				{
 				{
 					std::lock_guard lock(mSslMutex);
 					std::lock_guard lock(mSslMutex);
-					int ret = SSL_do_handshake(mSsl);
-
-					if (!openssl::check(mSsl, ret, "Handshake failed"))
-						break;
-
-					finished = (SSL_is_init_finished(mSsl) != 0);
+					ret = SSL_do_handshake(mSsl);
+					err = SSL_get_error(mSsl, ret);
 				}
 				}
-				if (finished) {
+
+				if (openssl::check_error(err, "Handshake failed")) {
 					// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
 					// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
 					// See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
 					// See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
 					{
 					{
@@ -942,16 +945,19 @@ void DtlsTransport::doRecv() {
 			}
 			}
 
 
 			if (state() == State::Connected) {
 			if (state() == State::Connected) {
-				int ret;
+				int ret, err;
 				{
 				{
 					std::lock_guard lock(mSslMutex);
 					std::lock_guard lock(mSslMutex);
 					ret = SSL_read(mSsl, buffer, bufferSize);
 					ret = SSL_read(mSsl, buffer, bufferSize);
+					err = SSL_get_error(mSsl, ret);
+				}
 
 
-					if (!openssl::check(mSsl, ret))
-						break;
+				if (err == SSL_ERROR_ZERO_RETURN) {
+					PLOG_DEBUG << "TLS connection cleanly closed";
+					break;
 				}
 				}
 
 
-				if (ret > 0)
+				if (openssl::check_error(err))
 					recv(make_message(buffer, buffer + ret));
 					recv(make_message(buffer, buffer + ret));
 			}
 			}
 		}
 		}

+ 16 - 21
src/impl/tls.cpp

@@ -177,34 +177,29 @@ bool check(int success, const string &message) {
 	if (success > 0)
 	if (success > 0)
 		return true;
 		return true;
 
 
-	string str = message;
-	if (last_error != 0)
-		str += ": " + error_string(last_error);
-
-	throw std::runtime_error(str);
+	throw std::runtime_error(message + (last_error != 0 ? ": " + error_string(last_error) : ""));
 }
 }
 
 
-// Return false on EOF
-bool check(SSL *ssl, int ret, const string &message) {
+// Return false on recoverable error
+bool check_error(int err, const string &message) {
 	unsigned long last_error = ERR_peek_last_error();
 	unsigned long last_error = ERR_peek_last_error();
 	ERR_clear_error();
 	ERR_clear_error();
 
 
-	int err = SSL_get_error(ssl, ret);
-	if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
+	if (err == SSL_ERROR_NONE)
 		return true;
 		return true;
-	}
-	if (err == SSL_ERROR_ZERO_RETURN) {
-		return false;
-	}
 
 
-	string str = message;
-	if (err == SSL_ERROR_SYSCALL) {
-		str += ": fatal I/O error";
-	} else if (err == SSL_ERROR_SSL) {
-		if (last_error != 0)
-			str += ": " + error_string(last_error);
-	}
-	throw std::runtime_error(str);
+	if (err == SSL_ERROR_ZERO_RETURN)
+		throw std::runtime_error(message + ": peer closed connection");
+
+	if (err == SSL_ERROR_SYSCALL)
+		throw std::runtime_error(message + ": fatal I/O error");
+
+	if (err == SSL_ERROR_SSL)
+		throw std::runtime_error(message +
+		                         (last_error != 0 ? ": " + error_string(last_error) : ""));
+
+	// SSL_ERROR_WANT_READ and SSL_ERROR_WANT_WRITE end up here
+	return false;
 }
 }
 
 
 BIO *BIO_new_from_file(const string &filename) {
 BIO *BIO_new_from_file(const string &filename) {

+ 1 - 1
src/impl/tls.hpp

@@ -85,7 +85,7 @@ void init();
 string error_string(unsigned long error);
 string error_string(unsigned long error);
 
 
 bool check(int success, const string &message = "OpenSSL error");
 bool check(int success, const string &message = "OpenSSL error");
-bool check(SSL *ssl, int ret, const string &message = "OpenSSL error");
+bool check_error(int err, const string &message = "OpenSSL error");
 
 
 BIO *BIO_new_from_file(const string &filename);
 BIO *BIO_new_from_file(const string &filename);
 
 

+ 37 - 25
src/impl/tlstransport.cpp

@@ -386,7 +386,8 @@ bool TlsTransport::send(message_ptr message) {
 		                        int(message->size()));
 		                        int(message->size()));
 	} while (ret == MBEDTLS_ERR_SSL_WANT_WRITE);
 	} while (ret == MBEDTLS_ERR_SSL_WANT_WRITE);
 
 
-	mbedtls::check(ret);
+	if (!mbedtls::check(ret))
+		throw std::runtime_error("TLS send failed");
 
 
 	return mOutgoingResult;
 	return mOutgoingResult;
 }
 }
@@ -639,10 +640,15 @@ void TlsTransport::start() {
 	changeState(State::Connecting);
 	changeState(State::Connecting);
 
 
 	// Initiate the handshake
 	// Initiate the handshake
-	std::lock_guard lock(mSslMutex);
-	int ret = SSL_do_handshake(mSsl);
-	openssl::check(mSsl, ret, "Handshake initiation failed");
-	flushOutput();
+	int ret, err;
+	{
+		std::lock_guard lock(mSslMutex);
+		ret = SSL_do_handshake(mSsl);
+		err = SSL_get_error(mSsl, ret);
+		flushOutput();
+	}
+
+	openssl::check_error(err, "Handshake failed");
 }
 }
 
 
 void TlsTransport::stop() {
 void TlsTransport::stop() {
@@ -661,12 +667,19 @@ bool TlsTransport::send(message_ptr message) {
 
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 	PLOG_VERBOSE << "Send size=" << message->size();
 
 
-	std::lock_guard lock(mSslMutex);
-	int ret = SSL_write(mSsl, message->data(), int(message->size()));
-	if (!openssl::check(mSsl, ret))
+	int err;
+	bool result;
+	{
+		std::lock_guard lock(mSslMutex);
+		int ret = SSL_write(mSsl, message->data(), int(message->size()));
+		err = SSL_get_error(mSsl, ret);
+		result = flushOutput();
+	}
+
+	if (!openssl::check_error(err))
 		throw std::runtime_error("TLS send failed");
 		throw std::runtime_error("TLS send failed");
 
 
-	return flushOutput();
+	return result;
 }
 }
 
 
 void TlsTransport::incoming(message_ptr message) {
 void TlsTransport::incoming(message_ptr message) {
@@ -698,7 +711,7 @@ void TlsTransport::doRecv() {
 		const size_t bufferSize = 4096;
 		const size_t bufferSize = 4096;
 		byte buffer[bufferSize];
 		byte buffer[bufferSize];
 
 
-		// Process incoming messages
+		// Read incoming messages
 		while (mIncomingQueue.running()) {
 		while (mIncomingQueue.running()) {
 			auto next = mIncomingQueue.pop();
 			auto next = mIncomingQueue.pop();
 			if (!next)
 			if (!next)
@@ -712,18 +725,15 @@ void TlsTransport::doRecv() {
 
 
 			if (state() == State::Connecting) {
 			if (state() == State::Connecting) {
 				// Continue the handshake
 				// Continue the handshake
-				bool finished;
+				int ret, err;
 				{
 				{
 					std::lock_guard lock(mSslMutex);
 					std::lock_guard lock(mSslMutex);
-					int ret = SSL_do_handshake(mSsl);
-					if (!openssl::check(mSsl, ret, "Handshake failed"))
-						break;
-
+					ret = SSL_do_handshake(mSsl);
+					err = SSL_get_error(mSsl, ret);
 					flushOutput();
 					flushOutput();
-					finished = (SSL_is_init_finished(mSsl) != 0);
 				}
 				}
 
 
-				if (finished) {
+				if (openssl::check_error(err, "Handshake failed")) {
 					PLOG_INFO << "TLS handshake finished";
 					PLOG_INFO << "TLS handshake finished";
 					changeState(State::Connected);
 					changeState(State::Connected);
 					postHandshake();
 					postHandshake();
@@ -731,25 +741,27 @@ void TlsTransport::doRecv() {
 			}
 			}
 
 
 			if (state() == State::Connected) {
 			if (state() == State::Connected) {
-				int ret;
+				int ret, err;
 				while (true) {
 				while (true) {
 					{
 					{
 						std::lock_guard lock(mSslMutex);
 						std::lock_guard lock(mSslMutex);
 						ret = SSL_read(mSsl, buffer, bufferSize);
 						ret = SSL_read(mSsl, buffer, bufferSize);
+						err = SSL_get_error(mSsl, ret);
+						flushOutput(); // SSL_read() can also cause write operations
 					}
 					}
 
 
-					if (ret > 0)
+					if (err == SSL_ERROR_ZERO_RETURN)
+						break;
+
+					if (openssl::check_error(err))
 						recv(make_message(buffer, buffer + ret));
 						recv(make_message(buffer, buffer + ret));
 					else
 					else
 						break;
 						break;
 				}
 				}
 
 
-				{
-					std::lock_guard lock(mSslMutex);
-					if (!openssl::check(mSsl, ret))
-						break;
-
-					flushOutput(); // SSL_read() can also cause write operations
+				if (err == SSL_ERROR_ZERO_RETURN) {
+					PLOG_DEBUG << "TLS connection cleanly closed";
+					break; // No more data can be read
 				}
 				}
 			}
 			}
 		}
 		}