2
0
Эх сурвалжийг харах

Fixed TLS layer with GnuTLS

Paul-Louis Ageneau 5 жил өмнө
parent
commit
755b3e9dac

+ 46 - 16
src/tlstransport.cpp

@@ -62,31 +62,38 @@ void TlsTransport::Cleanup() {
 }
 
 TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
-    : Transport(lower, std::move(callback)) {
+    : Transport(lower, std::move(callback)), mHost(std::move(host)) {
 
 	PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
 
+	check_gnutls(gnutls_certificate_allocate_credentials(&mCreds));
 	check_gnutls(gnutls_init(&mSession, GNUTLS_CLIENT));
 
 	try {
+        check_gnutls(gnutls_certificate_set_x509_system_trust(mCreds));
+        check_gnutls(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCreds));
+        gnutls_session_set_verify_cert(mSession, mHost.c_str(), 0);
+
 		const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
 		const char *err_pos = NULL;
 		check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
 		             "Failed to set TLS priorities");
 
-		gnutls_session_set_ptr(mSession, this);
+       	PLOG_VERBOSE << "Server Name Indication: " << mHost;
+		gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, mHost.data(), mHost.size());
+
+ 		gnutls_session_set_ptr(mSession, this);
 		gnutls_transport_set_ptr(mSession, this);
 		gnutls_transport_set_push_function(mSession, WriteCallback);
 		gnutls_transport_set_pull_function(mSession, ReadCallback);
 		gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
 
-		gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size());
-
-		mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
+       	mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
 		registerIncoming();
 
 	} catch (...) {
 		gnutls_deinit(mSession);
+		gnutls_certificate_free_credentials(mCreds);
 		throw;
 	}
 }
@@ -94,6 +101,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 TlsTransport::~TlsTransport() {
 	stop();
 	gnutls_deinit(mSession);
+	gnutls_certificate_free_credentials(mCreds);
 }
 
 bool TlsTransport::stop() {
@@ -111,6 +119,9 @@ bool TlsTransport::send(message_ptr message) {
 		return false;
 
 	PLOG_VERBOSE << "Send size=" << message->size();
+	if(message->size() == 0)
+		return true;
+
 	ssize_t ret;
 	do {
 		ret = gnutls_record_send(mSession, message->data(), message->size());
@@ -196,20 +207,37 @@ ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data
 
 ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
 	TlsTransport *t = static_cast<TlsTransport *>(ptr);
-	while (auto next = t->mIncomingQueue.pop()) {
-		auto message = *next;
-		if (message->size() > 0) {
-			ssize_t len = std::min(maxlen, message->size());
-			std::memcpy(data, message->data(), len);
-			gnutls_transport_set_errno(t->mSession, 0);
-			return len;
+
+	message_ptr &message = t->mIncomingMessage;
+	size_t &position = t->mIncomingMessagePosition;
+
+	if(message && position >= message->size())
+		message.reset();
+
+	if(!message) {
+		position = 0;
+		while (auto next = t->mIncomingQueue.pop()) {
+			message = *next;
+			if (message->size() > 0)
+				break;
+
+			t->recv(message); // Pass zero-sized messages through
 		}
+	}
 
-		t->recv(message); // Pass zero-sized messages through
+	if(message) {
+		size_t available = message->size() - position;
+		ssize_t len = std::min(maxlen, available);
+		std::memcpy(data, message->data() + position, len);
+		position+= len;
+		gnutls_transport_set_errno(t->mSession, 0);
+		return len;
+	}
+	else {
+		// Closed
+		gnutls_transport_set_errno(t->mSession, 0);
+		return 0;
 	}
-	// Closed
-	gnutls_transport_set_errno(t->mSession, 0);
-	return 0;
 }
 
 int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
@@ -308,6 +336,8 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 		throw std::runtime_error("Failed to create SSL instance");
 
 	SSL_set_ex_data(mSsl, TransportExIndex, this);
+
+	PLOG_VERBOSE << "Server Name Indication: " << host;
 	SSL_set_tlsext_host_name(mSsl, host.c_str());
 
 	SSL_set_connect_state(mSsl);

+ 4 - 0
src/tlstransport.hpp

@@ -56,10 +56,14 @@ protected:
 	void runRecvLoop();
 
 	Queue<message_ptr> mIncomingQueue;
+	message_ptr mIncomingMessage;
+	size_t mIncomingMessagePosition = 0;
 	std::thread mRecvThread;
 
 #if USE_GNUTLS
 	gnutls_session_t mSession;
+	gnutls_certificate_credentials_t mCreds;
+	string mHost;
 
 	static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
 	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);