Explorar o código

Implement non blocking-handshake for StreamPeerSSL

Fabio Alessandrelli %!s(int64=7) %!d(string=hai) anos
pai
achega
99d0b7ba14

+ 15 - 2
core/io/stream_peer_ssl.cpp

@@ -52,6 +52,14 @@ bool StreamPeerSSL::is_available() {
 	return available;
 }
 
+void StreamPeerSSL::set_blocking_handshake_enabled(bool p_enabled) {
+	blocking_handshake = p_enabled;
+}
+
+bool StreamPeerSSL::is_blocking_handshake_enabled() const {
+	return blocking_handshake;
+}
+
 PoolByteArray StreamPeerSSL::get_project_cert_array() {
 
 	PoolByteArray out;
@@ -84,16 +92,21 @@ PoolByteArray StreamPeerSSL::get_project_cert_array() {
 void StreamPeerSSL::_bind_methods() {
 
 	ClassDB::bind_method(D_METHOD("poll"), &StreamPeerSSL::poll);
-	ClassDB::bind_method(D_METHOD("accept_stream", "stream"), &StreamPeerSSL::accept_stream);
+	ClassDB::bind_method(D_METHOD("accept_stream"), &StreamPeerSSL::accept_stream);
 	ClassDB::bind_method(D_METHOD("connect_to_stream", "stream", "validate_certs", "for_hostname"), &StreamPeerSSL::connect_to_stream, DEFVAL(false), DEFVAL(String()));
 	ClassDB::bind_method(D_METHOD("get_status"), &StreamPeerSSL::get_status);
 	ClassDB::bind_method(D_METHOD("disconnect_from_stream"), &StreamPeerSSL::disconnect_from_stream);
+	ClassDB::bind_method(D_METHOD("set_blocking_handshake_enabled", "enabled"), &StreamPeerSSL::set_blocking_handshake_enabled);
+	ClassDB::bind_method(D_METHOD("is_blocking_handshake_enabled"), &StreamPeerSSL::is_blocking_handshake_enabled);
+
+	ADD_PROPERTY(PropertyInfo(Variant::BOOL, "blocking_handshake"), "set_blocking_handshake_enabled", "is_blocking_handshake_enabled");
 
 	BIND_ENUM_CONSTANT(STATUS_DISCONNECTED);
 	BIND_ENUM_CONSTANT(STATUS_CONNECTED);
-	BIND_ENUM_CONSTANT(STATUS_ERROR_NO_CERTIFICATE);
+	BIND_ENUM_CONSTANT(STATUS_ERROR);
 	BIND_ENUM_CONSTANT(STATUS_ERROR_HOSTNAME_MISMATCH);
 }
 
 StreamPeerSSL::StreamPeerSSL() {
+	blocking_handshake = true;
 }

+ 7 - 1
core/io/stream_peer_ssl.h

@@ -49,14 +49,20 @@ protected:
 	friend class Main;
 	static bool initialize_certs;
 
+	bool blocking_handshake;
+
 public:
 	enum Status {
 		STATUS_DISCONNECTED,
+		STATUS_HANDSHAKING,
 		STATUS_CONNECTED,
-		STATUS_ERROR_NO_CERTIFICATE,
+		STATUS_ERROR,
 		STATUS_ERROR_HOSTNAME_MISMATCH
 	};
 
+	void set_blocking_handshake_enabled(bool p_enabled);
+	bool is_blocking_handshake_enabled() const;
+
 	virtual void poll() = 0;
 	virtual Error accept_stream(Ref<StreamPeer> p_base) = 0;
 	virtual Error connect_to_stream(Ref<StreamPeer> p_base, bool p_validate_certs = false, const String &p_for_hostname = String()) = 0;

+ 52 - 28
modules/mbedtls/stream_peer_mbed_tls.cpp

@@ -29,6 +29,8 @@
 /*************************************************************************/
 
 #include "stream_peer_mbed_tls.h"
+#include "mbedtls/platform_util.h"
+#include "os/file_access.h"
 
 static void my_debug(void *ctx, int level,
 		const char *file, int line,
@@ -81,6 +83,36 @@ int StreamPeerMbedTLS::bio_recv(void *ctx, unsigned char *buf, size_t len) {
 	return got;
 }
 
+void StreamPeerMbedTLS::_cleanup() {
+
+	mbedtls_ssl_free(&ssl);
+	mbedtls_ssl_config_free(&conf);
+	mbedtls_ctr_drbg_free(&ctr_drbg);
+	mbedtls_entropy_free(&entropy);
+
+	base = Ref<StreamPeer>();
+	status = STATUS_DISCONNECTED;
+}
+
+Error StreamPeerMbedTLS::_do_handshake() {
+	int ret = 0;
+	while ((ret = mbedtls_ssl_handshake(&ssl)) != 0) {
+		if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
+			ERR_PRINTS("TLS handshake error: " + itos(ret));
+			_print_error(ret);
+			disconnect_from_stream();
+			status = STATUS_ERROR;
+			return FAILED;
+		} else if (!blocking_handshake) {
+			// Will retry via poll later
+			return OK;
+		}
+	}
+
+	status = STATUS_CONNECTED;
+	return OK;
+}
+
 Error StreamPeerMbedTLS::connect_to_stream(Ref<StreamPeer> p_base, bool p_validate_certs, const String &p_for_hostname) {
 
 	base = p_base;
@@ -95,6 +127,7 @@ Error StreamPeerMbedTLS::connect_to_stream(Ref<StreamPeer> p_base, bool p_valida
 	ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0);
 	if (ret != 0) {
 		ERR_PRINTS(" failed\n  ! mbedtls_ctr_drbg_seed returned an error" + itos(ret));
+		_cleanup();
 		return FAILED;
 	}
 
@@ -112,29 +145,24 @@ Error StreamPeerMbedTLS::connect_to_stream(Ref<StreamPeer> p_base, bool p_valida
 
 	mbedtls_ssl_set_bio(&ssl, this, bio_send, bio_recv, NULL);
 
-	while ((ret = mbedtls_ssl_handshake(&ssl)) != 0) {
-		if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
-			ERR_PRINTS("TLS handshake error: " + itos(ret));
-			_print_error(ret);
-			status = STATUS_ERROR_HOSTNAME_MISMATCH;
-			return FAILED;
-		}
-	}
+	status = STATUS_HANDSHAKING;
 
-	connected = true;
-	status = STATUS_CONNECTED;
+	if ((ret = _do_handshake()) != OK) {
+		status = STATUS_ERROR_HOSTNAME_MISMATCH;
+		return FAILED;
+	}
 
 	return OK;
 }
 
 Error StreamPeerMbedTLS::accept_stream(Ref<StreamPeer> p_base) {
 
-	return ERR_UNAVAILABLE;
+	return OK;
 }
 
 Error StreamPeerMbedTLS::put_data(const uint8_t *p_data, int p_bytes) {
 
-	ERR_FAIL_COND_V(!connected, ERR_UNCONFIGURED);
+	ERR_FAIL_COND_V(status != STATUS_CONNECTED, ERR_UNCONFIGURED);
 
 	Error err;
 	int sent = 0;
@@ -155,7 +183,7 @@ Error StreamPeerMbedTLS::put_data(const uint8_t *p_data, int p_bytes) {
 
 Error StreamPeerMbedTLS::put_partial_data(const uint8_t *p_data, int p_bytes, int &r_sent) {
 
-	ERR_FAIL_COND_V(!connected, ERR_UNCONFIGURED);
+	ERR_FAIL_COND_V(status != STATUS_CONNECTED, ERR_UNCONFIGURED);
 
 	r_sent = 0;
 
@@ -177,7 +205,7 @@ Error StreamPeerMbedTLS::put_partial_data(const uint8_t *p_data, int p_bytes, in
 
 Error StreamPeerMbedTLS::get_data(uint8_t *p_buffer, int p_bytes) {
 
-	ERR_FAIL_COND_V(!connected, ERR_UNCONFIGURED);
+	ERR_FAIL_COND_V(status != STATUS_CONNECTED, ERR_UNCONFIGURED);
 
 	Error err;
 
@@ -199,7 +227,7 @@ Error StreamPeerMbedTLS::get_data(uint8_t *p_buffer, int p_bytes) {
 
 Error StreamPeerMbedTLS::get_partial_data(uint8_t *p_buffer, int p_bytes, int &r_received) {
 
-	ERR_FAIL_COND_V(!connected, ERR_UNCONFIGURED);
+	ERR_FAIL_COND_V(status != STATUS_CONNECTED, ERR_UNCONFIGURED);
 
 	r_received = 0;
 
@@ -218,27 +246,30 @@ Error StreamPeerMbedTLS::get_partial_data(uint8_t *p_buffer, int p_bytes, int &r
 
 void StreamPeerMbedTLS::poll() {
 
-	ERR_FAIL_COND(!connected);
+	ERR_FAIL_COND(status != STATUS_CONNECTED && status != STATUS_HANDSHAKING);
 	ERR_FAIL_COND(!base.is_valid());
 
+	if (status == STATUS_HANDSHAKING) {
+		_do_handshake();
+		return;
+	}
+
 	int ret = mbedtls_ssl_read(&ssl, NULL, 0);
 
 	if (ret < 0 && ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
 		_print_error(ret);
 		disconnect_from_stream();
-		return;
 	}
 }
 
 int StreamPeerMbedTLS::get_available_bytes() const {
 
-	ERR_FAIL_COND_V(!connected, 0);
+	ERR_FAIL_COND_V(status != STATUS_CONNECTED, 0);
 
 	return mbedtls_ssl_get_bytes_avail(&ssl);
 }
 StreamPeerMbedTLS::StreamPeerMbedTLS() {
 
-	connected = false;
 	status = STATUS_DISCONNECTED;
 }
 
@@ -248,17 +279,10 @@ StreamPeerMbedTLS::~StreamPeerMbedTLS() {
 
 void StreamPeerMbedTLS::disconnect_from_stream() {
 
-	if (!connected)
+	if (status != STATUS_CONNECTED && status != STATUS_HANDSHAKING)
 		return;
 
-	mbedtls_ssl_free(&ssl);
-	mbedtls_ssl_config_free(&conf);
-	mbedtls_ctr_drbg_free(&ctr_drbg);
-	mbedtls_entropy_free(&entropy);
-
-	base = Ref<StreamPeer>();
-	connected = false;
-	status = STATUS_DISCONNECTED;
+	_cleanup();
 }
 
 StreamPeerMbedTLS::Status StreamPeerMbedTLS::get_status() const {

+ 4 - 2
modules/mbedtls/stream_peer_mbed_tls.h

@@ -48,8 +48,6 @@ private:
 	Status status;
 	String hostname;
 
-	bool connected;
-
 	Ref<StreamPeer> base;
 
 	static StreamPeerSSL *_create_func();
@@ -57,9 +55,11 @@ private:
 
 	static int bio_recv(void *ctx, unsigned char *buf, size_t len);
 	static int bio_send(void *ctx, const unsigned char *buf, size_t len);
+	void _cleanup();
 
 protected:
 	static mbedtls_x509_crt cacert;
+
 	mbedtls_entropy_context entropy;
 	mbedtls_ctr_drbg_context ctr_drbg;
 	mbedtls_ssl_context ssl;
@@ -67,6 +67,8 @@ protected:
 
 	static void _bind_methods();
 
+	Error _do_handshake();
+
 public:
 	virtual void poll();
 	virtual Error accept_stream(Ref<StreamPeer> p_base);