Browse Source

Implement WebSocketServer SSL support.

Fabio Alessandrelli 5 years ago
parent
commit
c723a8b6aa

+ 39 - 0
modules/websocket/websocket_server.cpp

@@ -49,12 +49,51 @@ void WebSocketServer::_bind_methods() {
 	ClassDB::bind_method(D_METHOD("get_peer_port", "id"), &WebSocketServer::get_peer_port);
 	ClassDB::bind_method(D_METHOD("disconnect_peer", "id", "code", "reason"), &WebSocketServer::disconnect_peer, DEFVAL(1000), DEFVAL(""));
 
+	ClassDB::bind_method(D_METHOD("get_private_key"), &WebSocketServer::get_private_key);
+	ClassDB::bind_method(D_METHOD("set_private_key"), &WebSocketServer::set_private_key);
+	ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "private_key", PROPERTY_HINT_RESOURCE_TYPE, "CryptoKey", 0), "set_private_key", "get_private_key");
+
+	ClassDB::bind_method(D_METHOD("get_ssl_certificate"), &WebSocketServer::get_ssl_certificate);
+	ClassDB::bind_method(D_METHOD("set_ssl_certificate"), &WebSocketServer::set_ssl_certificate);
+	ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "ssl_certificate", PROPERTY_HINT_RESOURCE_TYPE, "X509Certificate", 0), "set_ssl_certificate", "get_ssl_certificate");
+
+	ClassDB::bind_method(D_METHOD("get_ca_chain"), &WebSocketServer::get_ca_chain);
+	ClassDB::bind_method(D_METHOD("set_ca_chain"), &WebSocketServer::set_ca_chain);
+	ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "ca_chain", PROPERTY_HINT_RESOURCE_TYPE, "X509Certificate", 0), "set_ca_chain", "get_ca_chain");
+
 	ADD_SIGNAL(MethodInfo("client_close_request", PropertyInfo(Variant::INT, "id"), PropertyInfo(Variant::INT, "code"), PropertyInfo(Variant::STRING, "reason")));
 	ADD_SIGNAL(MethodInfo("client_disconnected", PropertyInfo(Variant::INT, "id"), PropertyInfo(Variant::BOOL, "was_clean_close")));
 	ADD_SIGNAL(MethodInfo("client_connected", PropertyInfo(Variant::INT, "id"), PropertyInfo(Variant::STRING, "protocol")));
 	ADD_SIGNAL(MethodInfo("data_received", PropertyInfo(Variant::INT, "id")));
 }
 
+Ref<CryptoKey> WebSocketServer::get_private_key() const {
+	return private_key;
+}
+
+void WebSocketServer::set_private_key(Ref<CryptoKey> p_key) {
+	ERR_FAIL_COND(is_listening());
+	private_key = p_key;
+}
+
+Ref<X509Certificate> WebSocketServer::get_ssl_certificate() const {
+	return ssl_cert;
+}
+
+void WebSocketServer::set_ssl_certificate(Ref<X509Certificate> p_cert) {
+	ERR_FAIL_COND(is_listening());
+	ssl_cert = p_cert;
+}
+
+Ref<X509Certificate> WebSocketServer::get_ca_chain() const {
+	return ca_chain;
+}
+
+void WebSocketServer::set_ca_chain(Ref<X509Certificate> p_ca_chain) {
+	ERR_FAIL_COND(is_listening());
+	ca_chain = p_ca_chain;
+}
+
 NetworkedMultiplayerPeer::ConnectionStatus WebSocketServer::get_connection_status() const {
 	if (is_listening())
 		return CONNECTION_CONNECTED;

+ 14 - 0
modules/websocket/websocket_server.h

@@ -31,6 +31,7 @@
 #ifndef WEBSOCKET_H
 #define WEBSOCKET_H
 
+#include "core/crypto/crypto.h"
 #include "core/reference.h"
 #include "websocket_multiplayer_peer.h"
 #include "websocket_peer.h"
@@ -43,6 +44,10 @@ class WebSocketServer : public WebSocketMultiplayerPeer {
 protected:
 	static void _bind_methods();
 
+	Ref<CryptoKey> private_key;
+	Ref<X509Certificate> ssl_cert;
+	Ref<X509Certificate> ca_chain;
+
 public:
 	virtual void poll() = 0;
 	virtual Error listen(int p_port, PoolVector<String> p_protocols = PoolVector<String>(), bool gd_mp_api = false) = 0;
@@ -62,6 +67,15 @@ public:
 	void _on_disconnect(int32_t p_peer_id, bool p_was_clean);
 	void _on_close_request(int32_t p_peer_id, int p_code, String p_reason);
 
+	Ref<CryptoKey> get_private_key() const;
+	void set_private_key(Ref<CryptoKey> p_key);
+
+	Ref<X509Certificate> get_ssl_certificate() const;
+	void set_ssl_certificate(Ref<X509Certificate> p_cert);
+
+	Ref<X509Certificate> get_ca_chain() const;
+	void set_ca_chain(Ref<X509Certificate> p_ca_chain);
+
 	virtual Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) = 0;
 
 	WebSocketServer();

+ 20 - 1
modules/websocket/wsl_server.cpp

@@ -35,6 +35,7 @@
 #include "core/project_settings.h"
 
 WSLServer::PendingPeer::PendingPeer() {
+	use_ssl = false;
 	time = 0;
 	has_request = false;
 	response_sent = 0;
@@ -100,6 +101,16 @@ bool WSLServer::PendingPeer::_parse_request(const PoolStringArray p_protocols) {
 Error WSLServer::PendingPeer::do_handshake(PoolStringArray p_protocols) {
 	if (OS::get_singleton()->get_ticks_msec() - time > WSL_SERVER_TIMEOUT)
 		return ERR_TIMEOUT;
+	if (use_ssl) {
+		Ref<StreamPeerSSL> ssl = static_cast<Ref<StreamPeerSSL> >(connection);
+		if (ssl.is_null())
+			return FAILED;
+		ssl->poll();
+		if (ssl->get_status() == StreamPeerSSL::STATUS_HANDSHAKING)
+			return ERR_BUSY;
+		else if (ssl->get_status() != StreamPeerSSL::STATUS_CONNECTED)
+			return FAILED;
+	}
 	if (!has_request) {
 		int read = 0;
 		while (true) {
@@ -210,7 +221,15 @@ void WSLServer::poll() {
 			continue; // Conn will go out-of-scope and be closed.
 
 		Ref<PendingPeer> peer = memnew(PendingPeer);
-		peer->connection = conn;
+		if (private_key.is_valid() && ssl_cert.is_valid()) {
+			Ref<StreamPeerSSL> ssl = Ref<StreamPeerSSL>(StreamPeerSSL::create());
+			ssl->set_blocking_handshake_enabled(false);
+			ssl->accept_stream(conn, private_key, ssl_cert, ca_chain);
+			peer->connection = ssl;
+			peer->use_ssl = true;
+		} else {
+			peer->connection = conn;
+		}
 		peer->tcp = conn;
 		peer->time = OS::get_singleton()->get_ticks_msec();
 		_pending.push_back(peer);

+ 2 - 0
modules/websocket/wsl_server.h

@@ -36,6 +36,7 @@
 #include "websocket_server.h"
 #include "wsl_peer.h"
 
+#include "core/io/stream_peer_ssl.h"
 #include "core/io/stream_peer_tcp.h"
 #include "core/io/tcp_server.h"
 
@@ -54,6 +55,7 @@ private:
 	public:
 		Ref<StreamPeerTCP> tcp;
 		Ref<StreamPeer> connection;
+		bool use_ssl;
 
 		int time;
 		uint8_t req_buf[WSL_MAX_HEADER_SIZE];