Browse Source

Merge pull request #30419 from Faless/ws/wslay_server_proto

(Re-)Implement subprotocols in websocket server.
Rémi Verschelde 6 years ago
parent
commit
9da5fdc955
2 changed files with 31 additions and 8 deletions
  1. 27 6
      modules/websocket/wsl_server.cpp
  2. 4 2
      modules/websocket/wsl_server.h

+ 27 - 6
modules/websocket/wsl_server.cpp

@@ -42,7 +42,7 @@ WSLServer::PendingPeer::PendingPeer() {
 	memset(req_buf, 0, sizeof(req_buf));
 }
 
-bool WSLServer::PendingPeer::_parse_request(String &r_key) {
+bool WSLServer::PendingPeer::_parse_request(const PoolStringArray p_protocols) {
 	Vector<String> psa = String((char *)req_buf).split("\r\n");
 	int len = psa.size();
 	if (len < 4) {
@@ -87,11 +87,29 @@ bool WSLServer::PendingPeer::_parse_request(String &r_key) {
 	_WLS_CHECK_EX("connection");
 #undef _WLS_CHECK_EX
 #undef _WLS_CHECK
-	r_key = headers["sec-websocket-key"];
+	key = headers["sec-websocket-key"];
+	if (headers.has("sec-websocket-protocol")) {
+		Vector<String> protos = headers["sec-websocket-protocol"].split(",");
+		for (int i = 0; i < protos.size(); i++) {
+			// Check if we have the given protocol
+			for (int j = 0; j < p_protocols.size(); j++) {
+				if (protos[i] != p_protocols[j])
+					continue;
+				protocol = protos[i];
+				break;
+			}
+			// Found a protocol
+			if (protocol != "")
+				break;
+		}
+		if (protocol == "") // Invalid protocol(s) requested
+			return false;
+	} else if (p_protocols.size() > 0) // No protocol requested, but we need one
+		return false;
 	return true;
 }
 
-Error WSLServer::PendingPeer::do_handshake() {
+Error WSLServer::PendingPeer::do_handshake(PoolStringArray p_protocols) {
 	if (OS::get_singleton()->get_ticks_msec() - time > WSL_SERVER_TIMEOUT)
 		return ERR_TIMEOUT;
 	if (!has_request) {
@@ -111,13 +129,15 @@ Error WSLServer::PendingPeer::do_handshake() {
 			int l = req_pos;
 			if (l > 3 && r[l] == '\n' && r[l - 1] == '\r' && r[l - 2] == '\n' && r[l - 3] == '\r') {
 				r[l - 3] = '\0';
-				if (!_parse_request(key)) {
+				if (!_parse_request(p_protocols)) {
 					return FAILED;
 				}
 				String s = "HTTP/1.1 101 Switching Protocols\r\n";
 				s += "Upgrade: websocket\r\n";
 				s += "Connection: Upgrade\r\n";
 				s += "Sec-WebSocket-Accept: " + WSLPeer::compute_key_response(key) + "\r\n";
+				if (protocol != "")
+					s += "Sec-WebSocket-Protocol: " + protocol + "\r\n";
 				s += "\r\n";
 				response = s.utf8();
 				has_request = true;
@@ -143,6 +163,7 @@ Error WSLServer::listen(int p_port, PoolVector<String> p_protocols, bool gd_mp_a
 	ERR_FAIL_COND_V(is_listening(), ERR_ALREADY_IN_USE);
 
 	_is_multiplayer = gd_mp_api;
+	_protocols = p_protocols;
 	_server->listen(p_port);
 
 	return OK;
@@ -167,7 +188,7 @@ void WSLServer::poll() {
 	List<Ref<PendingPeer> > remove_peers;
 	for (List<Ref<PendingPeer> >::Element *E = _pending.front(); E; E = E->next()) {
 		Ref<PendingPeer> ppeer = E->get();
-		Error err = ppeer->do_handshake();
+		Error err = ppeer->do_handshake(_protocols);
 		if (err == ERR_BUSY) {
 			continue;
 		} else if (err != OK) {
@@ -188,7 +209,7 @@ void WSLServer::poll() {
 
 		_peer_map[id] = ws_peer;
 		remove_peers.push_back(ppeer);
-		_on_connect(id, "");
+		_on_connect(id, ppeer->protocol);
 	}
 	for (List<Ref<PendingPeer> >::Element *E = remove_peers.front(); E; E = E->next()) {
 		_pending.erase(E->get());

+ 4 - 2
modules/websocket/wsl_server.h

@@ -49,7 +49,7 @@ private:
 	class PendingPeer : public Reference {
 
 	private:
-		bool _parse_request(String &r_key);
+		bool _parse_request(const PoolStringArray p_protocols);
 
 	public:
 		Ref<StreamPeer> connection;
@@ -58,13 +58,14 @@ private:
 		uint8_t req_buf[WSL_MAX_HEADER_SIZE];
 		int req_pos;
 		String key;
+		String protocol;
 		bool has_request;
 		CharString response;
 		int response_sent;
 
 		PendingPeer();
 
-		Error do_handshake();
+		Error do_handshake(const PoolStringArray p_protocols);
 	};
 
 	int _in_buf_size;
@@ -74,6 +75,7 @@ private:
 
 	List<Ref<PendingPeer> > _pending;
 	Ref<TCP_Server> _server;
+	PoolStringArray _protocols;
 
 public:
 	Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets);