فهرست منبع

Fixed TLS layer with OpenSSL

Paul-Louis Ageneau 5 سال پیش
والد
کامیت
009e2e6767
2فایلهای تغییر یافته به همراه75 افزوده شده و 61 حذف شده
  1. 70 58
      src/tlstransport.cpp
  2. 5 3
      src/tlstransport.hpp

+ 70 - 58
src/tlstransport.cpp

@@ -115,10 +115,11 @@ bool TlsTransport::stop() {
 }
 
 bool TlsTransport::send(message_ptr message) {
-	if (!message)
+	if (!message || state() != State::Connected)
 		return false;
 
 	PLOG_VERBOSE << "Send size=" << message->size();
+
 	if(message->size() == 0)
 		return true;
 
@@ -220,8 +221,8 @@ ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_
 			message = *next;
 			if (message->size() > 0)
 				break;
-
-			t->recv(message); // Pass zero-sized messages through
+			else
+				t->recv(message); // Pass zero-sized messages through
 		}
 	}
 
@@ -312,49 +313,62 @@ void TlsTransport::Cleanup() {
 }
 
 TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
-    : Transport(lower, std::move(callback)) {
+    : Transport(lower, std::move(callback)), mHost(std::move(host)) {
 
 	PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
 
-	if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
-		throw std::runtime_error("Failed to create SSL context");
+	try {
+		if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
+			throw std::runtime_error("Failed to create SSL context");
+
+		check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
+		              "Failed to set SSL priorities");
 
-	check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
-	              "Failed to set SSL priorities");
+		SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
+		SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION);
+		SSL_CTX_set_read_ahead(mCtx, 1);
+		SSL_CTX_set_quiet_shutdown(mCtx, 1);
+		SSL_CTX_set_info_callback(mCtx, InfoCallback);
 
-	SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
-	SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION);
-	SSL_CTX_set_read_ahead(mCtx, 1);
-	SSL_CTX_set_quiet_shutdown(mCtx, 1);
-	SSL_CTX_set_info_callback(mCtx, InfoCallback);
+		SSL_CTX_set_default_verify_paths(mCtx);
+		SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER, NULL);
+		SSL_CTX_set_verify_depth(mCtx, 4);
 
-	SSL_CTX_set_default_verify_paths(mCtx);
-	SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER, NULL);
-	SSL_CTX_set_verify_depth(mCtx, 4);
+		if (!(mSsl = SSL_new(mCtx)))
+			throw std::runtime_error("Failed to create SSL instance");
 
-	if (!(mSsl = SSL_new(mCtx)))
-		throw std::runtime_error("Failed to create SSL instance");
+		SSL_set_ex_data(mSsl, TransportExIndex, this);
 
-	SSL_set_ex_data(mSsl, TransportExIndex, this);
+		SSL_set_hostflags(mSsl, 0);
+		check_openssl(SSL_set1_host(mSsl, mHost.c_str()), "Failed to set SSL host");
 
-	PLOG_VERBOSE << "Server Name Indication: " << host;
-	SSL_set_tlsext_host_name(mSsl, host.c_str());
+		PLOG_VERBOSE << "Server Name Indication: " << mHost;
+		SSL_set_tlsext_host_name(mSsl, mHost.c_str());
 
-	SSL_set_connect_state(mSsl);
+		SSL_set_connect_state(mSsl);
 
-	if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem())))
-		throw std::runtime_error("Failed to create BIO");
+		if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem())))
+			throw std::runtime_error("Failed to create BIO");
 
-	BIO_set_mem_eof_return(mInBio, BIO_EOF);
-	BIO_set_mem_eof_return(mOutBio, BIO_EOF);
-	SSL_set_bio(mSsl, mInBio, mOutBio);
+		BIO_set_mem_eof_return(mInBio, BIO_EOF);
+		BIO_set_mem_eof_return(mOutBio, BIO_EOF);
+		SSL_set_bio(mSsl, mInBio, mOutBio);
 
-	auto ecdh = unique_ptr<EC_KEY, decltype(&EC_KEY_free)>(
-	    EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
-	SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
-	SSL_set_tmp_ecdh(mSsl, ecdh.get());
+		auto ecdh = unique_ptr<EC_KEY, decltype(&EC_KEY_free)>(
+		    EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
+		SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
+		SSL_set_tmp_ecdh(mSsl, ecdh.get());
 
-	mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
+		mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
+		registerIncoming();
+
+	} catch (...) {
+		if (mSsl)
+			SSL_free(mSsl);
+		if (mCtx)
+			SSL_CTX_free(mCtx);
+		throw;
+	}
 }
 
 TlsTransport::~TlsTransport() {
@@ -376,17 +390,22 @@ bool TlsTransport::stop() {
 }
 
 bool TlsTransport::send(message_ptr message) {
-	if (!message)
+	if (!message || state() != State::Connected)
 		return false;
 
+	PLOG_VERBOSE << "Send size=" << message->size();
+
+	if (message->size() == 0)
+		return true;
+
 	int ret = SSL_write(mSsl, message->data(), message->size());
 	if (!check_openssl_ret(mSsl, ret))
 		return false;
 
 	const size_t bufferSize = 4096;
 	byte buffer[bufferSize];
-	while (int len = BIO_read(mOutBio, buffer, bufferSize))
-		outgoing(make_message(buffer, buffer + len));
+	while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
+		outgoing(make_message(buffer, buffer + ret));
 
 	return true;
 }
@@ -405,35 +424,17 @@ void TlsTransport::runRecvLoop() {
 	try {
 		changeState(State::Connecting);
 
-		// Initiate the handshake
-		int ret = SSL_do_handshake(mSsl);
-		check_openssl_ret(mSsl, ret, "Handshake failed");
-
 		while (true) {
-			// Output
-			while (int len = BIO_read(mOutBio, buffer, bufferSize))
-				outgoing(make_message(buffer, buffer + len));
-
-			auto next = mIncomingQueue.pop();
-			if (!next)
-				break;
-			message_ptr message = *next;
-
-			if (message->size() == 0) {
-				// Pass zero-sized messages through
-				recv(message);
-				continue;
-			}
-
-			// Input
-			BIO_write(mInBio, message->data(), message->size());
-
 			if (state() == State::Connecting) {
-				// Continue the handshake
+				// Initiate or continue the handshake
 				int ret = SSL_do_handshake(mSsl);
 				if (!check_openssl_ret(mSsl, ret, "Handshake failed"))
 					break;
 
+				// Output
+				while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
+					outgoing(make_message(buffer, buffer + ret));
+
 				if (SSL_is_init_finished(mSsl)) {
 					PLOG_INFO << "TLS handshake finished";
 					changeState(State::Connected);
@@ -446,7 +447,18 @@ void TlsTransport::runRecvLoop() {
 				if (ret > 0)
 					recv(make_message(buffer, buffer + ret));
 			}
+
+			auto next = mIncomingQueue.pop();
+			if (!next)
+				break;
+
+			message_ptr message = *next;
+			if (message->size() > 0)
+				BIO_write(mInBio, message->data(), message->size()); // Input
+			else
+				recv(message); // Pass zero-sized messages through
 		}
+
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "TLS recv: " << e.what();
 	}

+ 5 - 3
src/tlstransport.hpp

@@ -55,15 +55,17 @@ public:
 protected:
 	void runRecvLoop();
 
+	string mHost;
+
 	Queue<message_ptr> mIncomingQueue;
-	message_ptr mIncomingMessage;
-	size_t mIncomingMessagePosition = 0;
 	std::thread mRecvThread;
 
 #if USE_GNUTLS
 	gnutls_session_t mSession;
 	gnutls_certificate_credentials_t mCreds;
-	string mHost;
+
+	message_ptr mIncomingMessage;
+	size_t mIncomingMessagePosition = 0;
 
 	static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
 	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);