Sfoglia il codice sorgente

modules/websocket: tidied up some of the WS connection code

Peter Dunkley 13 anni fa
parent
commit
3a64dffb99

+ 39 - 45
modules/websocket/ws_conn.c

@@ -110,7 +110,7 @@ void wsconn_destroy(void)
 			ws_connection_t *wsc = wsconn_hash[h];
 			ws_connection_t *wsc = wsconn_hash[h];
 			while (wsc)
 			while (wsc)
 			{
 			{
-				ws_connection_t *next = wsc->next;
+				ws_connection_t *next = wsc->id_next;
 				_wsconn_rm(wsc);
 				_wsconn_rm(wsc);
 				wsc = next;
 				wsc = next;
 			}
 			}
@@ -136,17 +136,13 @@ void wsconn_destroy(void)
 	}
 	}
 }
 }
 
 
-int wsconn_add(struct tcp_connection *con)
+int wsconn_add(int id)
 {
 {
 	int cur_cons, max_cons;
 	int cur_cons, max_cons;
+	int id_hash = tcp_id_hash(id);
 	ws_connection_t *wsc;
 	ws_connection_t *wsc;
 
 
-	if (!con)
-	{
-		LM_ERR("wsconn_add: null pointer\n");
-		return -1;
-	}
-
+	/* Allocate and fill in new WebSocket connection */
 	wsc = shm_malloc(sizeof(ws_connection_t));
 	wsc = shm_malloc(sizeof(ws_connection_t));
 	if (wsc == NULL)
 	if (wsc == NULL)
 	{
 	{
@@ -155,20 +151,14 @@ int wsconn_add(struct tcp_connection *con)
 	}
 	}
 	memset(wsc, 0, sizeof(ws_connection_t));
 	memset(wsc, 0, sizeof(ws_connection_t));
 
 
-	wsc->con = con;
-	wsc->id_hash = con->id_hash;
+	wsc->id = id;
+	wsc->id_hash = id_hash;
 	wsc->last_used = (int)time(NULL);
 	wsc->last_used = (int)time(NULL);
 	wsc->state = WS_S_OPEN;
 	wsc->state = WS_S_OPEN;
 
 
-	/* Make sure Kamailio core sends future messages on this connection
-	   directly to this module */
-	con->flags |= F_CONN_WS;
-
+	/* Add to WebSocket connection table */
 	lock_get(wsconn_lock);
 	lock_get(wsconn_lock);
-	wsc->next = wsconn_hash[wsc->id_hash];
-	wsc->prev = NULL;
-	if (wsconn_hash[wsc->id_hash]) wsconn_hash[wsc->id_hash]->prev = wsc;
-	wsconn_hash[wsc->id_hash] = wsc;
+	wsconn_listadd(wsconn_hash[wsc->id_hash], wsc, id_next, id_prev);
 	lock_release(wsconn_lock);
 	lock_release(wsconn_lock);
 
 
 	/* Update connection statistics */
 	/* Update connection statistics */
@@ -185,10 +175,7 @@ int wsconn_add(struct tcp_connection *con)
 
 
 static inline void _wsconn_rm(ws_connection_t *wsc)
 static inline void _wsconn_rm(ws_connection_t *wsc)
 {
 {
-	if (wsconn_hash[wsc->id_hash] == wsc)
-		wsconn_hash[wsc->id_hash] = wsc->next;
-	if (wsc->next) wsc->next->prev = wsc->prev;
-	if (wsc->prev) wsc->prev->next = wsc->next;
+	wsconn_listrm(wsconn_hash[wsc->id_hash], wsc, id_next, id_prev);
 	shm_free(wsc);
 	shm_free(wsc);
 	wsc = NULL;
 	wsc = NULL;
 	update_stat(ws_current_connections, -1);
 	update_stat(ws_current_connections, -1);
@@ -223,27 +210,31 @@ int wsconn_update(ws_connection_t *wsc)
 
 
 void wsconn_close_now(ws_connection_t *wsc)
 void wsconn_close_now(ws_connection_t *wsc)
 {
 {
-	wsc->con->send_flags.f |= SND_F_CON_CLOSE;
-	wsc->con->state = S_CONN_BAD;
-	wsc->con->timeout = get_ticks_raw();
+	struct tcp_connection *con = tcpconn_get(wsc->id, 0, 0, 0, 0);
+
+	if (con == NULL)
+	{
+		LM_ERR("getting TCP/TLS connection\n");
+		return;
+	}
+
+	con->send_flags.f |= SND_F_CON_CLOSE;
+	con->state = S_CONN_BAD;
+	con->timeout = get_ticks_raw();
+
 	if (wsconn_rm(wsc) < 0)
 	if (wsconn_rm(wsc) < 0)
 		LM_ERR("removing WebSocket connection\n");
 		LM_ERR("removing WebSocket connection\n");
 }
 }
 
 
-ws_connection_t *wsconn_find(struct tcp_connection *con)
+ws_connection_t *wsconn_get(int id)
 {
 {
+	int id_hash = tcp_id_hash(id);
 	ws_connection_t *wsc;
 	ws_connection_t *wsc;
 
 
-	if (!con)
-	{
-		LM_ERR("wsconn_find: null pointer\n");
-		return NULL;
-	}
-
 	lock_get(wsconn_lock);
 	lock_get(wsconn_lock);
-	for (wsc = wsconn_hash[con->id_hash]; wsc; wsc = wsc->next)
+	for (wsc = wsconn_hash[id_hash]; wsc; wsc = wsc->id_next)
 	{
 	{
-		if (wsc->con->id == con->id)
+		if (wsc->id == id)
 		{
 		{
 			lock_release(wsconn_lock);
 			lock_release(wsconn_lock);
 			return wsc;
 			return wsc;
@@ -271,18 +262,21 @@ struct mi_root *ws_mi_dump(struct mi_root *cmd, void *param)
 		wsc = wsconn_hash[h];
 		wsc = wsconn_hash[h];
 		while(wsc)
 		while(wsc)
 		{
 		{
-			if (wsc->con)
+			struct tcp_connection *con =
+					tcpconn_get(wsc->id, 0, 0, 0, 0);
+
+			if (con)
 			{
 			{
-				src_proto = (wsc->con->rcv.proto== PROTO_TCP)
-						? "tcp" : "tls";
+				src_proto = (con->rcv.proto== PROTO_TCP)
+						? "ws" : "wss";
 				memset(src_ip, 0, IP6_MAX_STR_SIZE + 1);
 				memset(src_ip, 0, IP6_MAX_STR_SIZE + 1);
-				ip_addr2sbuf(&wsc->con->rcv.src_ip, src_ip,
+				ip_addr2sbuf(&con->rcv.src_ip, src_ip,
 						IP6_MAX_STR_SIZE);
 						IP6_MAX_STR_SIZE);
 
 
-				dst_proto = (wsc->con->rcv.proto == PROTO_TCP)
-						? "tcp" : "tls";
+				dst_proto = (con->rcv.proto == PROTO_TCP)
+						? "ws" : "wss";
 				memset(dst_ip, 0, IP6_MAX_STR_SIZE + 1);
 				memset(dst_ip, 0, IP6_MAX_STR_SIZE + 1);
-				ip_addr2sbuf(&wsc->con->rcv.dst_ip, src_ip,
+				ip_addr2sbuf(&con->rcv.dst_ip, src_ip,
 						IP6_MAX_STR_SIZE);
 						IP6_MAX_STR_SIZE);
 
 
 				interval = (int)time(NULL) - wsc->last_used;
 				interval = (int)time(NULL) - wsc->last_used;
@@ -291,13 +285,13 @@ struct mi_root *ws_mi_dump(struct mi_root *cmd, void *param)
 						"%d: %s:%s:%hu -> %s:%s:%hu "
 						"%d: %s:%s:%hu -> %s:%s:%hu "
 						"(state: %s, "
 						"(state: %s, "
 						"last used %ds ago)",
 						"last used %ds ago)",
-						wsc->con->id,
+						wsc->id,
 						src_proto,
 						src_proto,
 						strlen(src_ip) ? src_ip : "*",
 						strlen(src_ip) ? src_ip : "*",
-						wsc->con->rcv.src_port,
+						con->rcv.src_port,
 						dst_proto,
 						dst_proto,
 						strlen(dst_ip) ? dst_ip : "*",
 						strlen(dst_ip) ? dst_ip : "*",
-						wsc->con->rcv.dst_port,
+						con->rcv.dst_port,
 						wsconn_state_str[wsc->state],
 						wsconn_state_str[wsc->state],
 						interval) == 0)
 						interval) == 0)
 					return 0;
 					return 0;
@@ -309,7 +303,7 @@ struct mi_root *ws_mi_dump(struct mi_root *cmd, void *param)
 				}
 				}
 			}
 			}
 
 
-			wsc = wsc->next;
+			wsc = wsc->id_next;
 		}
 		}
 	}
 	}
 	lock_release(wsconn_lock);
 	lock_release(wsconn_lock);

+ 9 - 9
modules/websocket/ws_conn.h

@@ -25,7 +25,6 @@
 #define _WS_CONN_H
 #define _WS_CONN_H
 
 
 #include "../../locking.h"
 #include "../../locking.h"
-#include "../../tcp_conn.h"
 #include "../../lib/kmi/tree.h"
 #include "../../lib/kmi/tree.h"
 
 
 typedef enum
 typedef enum
@@ -38,17 +37,18 @@ typedef enum
 
 
 typedef struct ws_connection
 typedef struct ws_connection
 {
 {
-	struct tcp_connection *con;
-
 	ws_conn_state_t state;
 	ws_conn_state_t state;
-	int id;
-	unsigned id_hash;
 	int last_used;
 	int last_used;
 
 
-	struct ws_connection *prev;
-	struct ws_connection *next;
+	int id;			/* id and id_hash are identical to the values */
+	unsigned id_hash;	/* for the corresponding TCP/TLS connection */
+	struct ws_connection *id_prev;
+	struct ws_connection *id_next;
 } ws_connection_t;
 } ws_connection_t;
 
 
+#define wsconn_listadd	tcpconn_listadd
+#define wsconn_listrm	tcpconn_listrm
+
 extern char *wsconn_state_str[];
 extern char *wsconn_state_str[];
 
 
 extern stat_var *ws_current_connections;
 extern stat_var *ws_current_connections;
@@ -56,11 +56,11 @@ extern stat_var *ws_max_concurrent_connections;
 
 
 int wsconn_init(void);
 int wsconn_init(void);
 void wsconn_destroy(void);
 void wsconn_destroy(void);
-int wsconn_add(struct tcp_connection *con);
+int wsconn_add(int id);
 int wsconn_rm(ws_connection_t *wsc);
 int wsconn_rm(ws_connection_t *wsc);
 int wsconn_update(ws_connection_t *wsc);
 int wsconn_update(ws_connection_t *wsc);
 void wsconn_close_now(ws_connection_t *wsc);
 void wsconn_close_now(ws_connection_t *wsc);
-ws_connection_t *wsconn_find(struct tcp_connection *con);
+ws_connection_t *wsconn_get(int id);
 struct mi_root *ws_mi_dump(struct mi_root *cmd, void *param);
 struct mi_root *ws_mi_dump(struct mi_root *cmd, void *param);
 
 
 #endif /* _WS_CONN_H */
 #endif /* _WS_CONN_H */

+ 195 - 192
modules/websocket/ws_frame.c

@@ -87,9 +87,6 @@ typedef enum
 #define OPCODE_PONG		(0xa)
 #define OPCODE_PONG		(0xa)
 /* 0xb - 0xf are reserved for further control frames */
 /* 0xb - 0xf are reserved for further control frames */
 
 
-static int close_connection(ws_connection_t *wsc, ws_close_type_t type,
-				short int status, str reason);
-
 stat_var *ws_failed_connections;
 stat_var *ws_failed_connections;
 stat_var *ws_local_closed_connections;
 stat_var *ws_local_closed_connections;
 stat_var *ws_received_frames;
 stat_var *ws_received_frames;
@@ -109,177 +106,12 @@ static str str_status_bad_param = str_init("Bad connection ID parameter");
 static str str_status_error_closing = str_init("Error closing connection");
 static str str_status_error_closing = str_init("Error closing connection");
 static str str_status_error_sending = str_init("Error sending frame");
 static str str_status_error_sending = str_init("Error sending frame");
 
 
-static int decode_and_validate_ws_frame(ws_frame_t *frame,
-					tcp_event_info_t *tcpinfo)
-{
-	unsigned int i, len = tcpinfo->len;
-	int mask_start, j;
-	char *buf = tcpinfo->buf;
-
-	LM_INFO("decoding WebSocket frame\n");
-
-	if ((frame->wsc = wsconn_find(tcpinfo->con)) == NULL)
-	{
-		LM_WARN("WebSocket connection not found\n");
-		return -1;
-	}
-
-	wsconn_update(frame->wsc);
-
-	/* Decode and validate first 9 bits */
-	if (len < 2)
-	{
-		LM_WARN("message is too short\n");
-		if (close_connection(frame->wsc, LOCAL_CLOSE, 1002,
-					str_status_protocol_error) < 0)
-			LM_ERR("closing connection\n");
-		return -1;
-	}
-	frame->fin = (buf[0] & 0xff) & BYTE0_MASK_FIN;
-	frame->rsv1 = (buf[0] & 0xff) & BYTE0_MASK_RSV1;
-	frame->rsv2 = (buf[0] & 0xff) & BYTE0_MASK_RSV2;
-	frame->rsv3 = (buf[0] & 0xff) & BYTE0_MASK_RSV3;
-	frame->opcode = (buf[0] & 0xff) & BYTE0_MASK_OPCODE;
-	frame->mask = (buf[1] & 0xff) & BYTE1_MASK_MASK;
-	
-	if (!frame->fin)
-	{
-		LM_WARN("WebSocket fragmentation not supported in the sip "
-			"sub-protocol\n");
-		if (close_connection(frame->wsc, LOCAL_CLOSE, 1002,
-					str_status_protocol_error) < 0)
-			LM_ERR("closing connection\n");
-		return -1;
-	}
-
-	if (frame->rsv1 || frame->rsv2 || frame->rsv3)
-	{
-		LM_WARN("WebSocket reserved fields with non-zero values\n");
-		if (close_connection(frame->wsc, LOCAL_CLOSE, 1002,
-					str_status_protocol_error) < 0)
-			LM_ERR("closing connection\n");
-		return -1;
-	}
-
-	switch(frame->opcode)
-	{
-	case OPCODE_TEXT_FRAME:
-	case OPCODE_BINARY_FRAME:
-		LM_INFO("supported non-control frame: 0x%x\n",
-			(unsigned char) frame->opcode);
-		break;
-
-	case OPCODE_CLOSE:
-	case OPCODE_PING:
-	case OPCODE_PONG:
-		LM_INFO("supported control frame: 0x%x\n",
-			(unsigned char) frame->opcode);
-		break;
-
-	default:
-		LM_WARN("unsupported opcode: 0x%x\n",
-			(unsigned char) frame->opcode);
-		if (close_connection(frame->wsc, LOCAL_CLOSE, 1008,
-					str_status_unsupported_opcode) < 0)
-			LM_ERR("closing connection\n");
-		return -1;
-	}
-
-	if (!frame->mask)
-	{
-		LM_WARN("this is a server - all received messages must be "
-			"masked\n");
-		if (close_connection(frame->wsc, LOCAL_CLOSE, 1002,
-					str_status_protocol_error) < 0)
-			LM_ERR("closing connection\n");
-		return -1;
-	}
-
-	/* Decode and validate length */
-	frame->payload_len = (buf[1] & 0xff) & BYTE1_MASK_PAYLOAD_LEN;
-	if (frame->payload_len == 126)
-	{
-		if (len < 4)
-		{
-			LM_WARN("message is too short\n");
-			if (close_connection(frame->wsc, LOCAL_CLOSE, 1002,
-						str_status_protocol_error) < 0)
-				LM_ERR("closing connection\n");
-			return -1;
-		}
-		mask_start = 4;
-
-		frame->payload_len = 	  ((buf[2] & 0xff) <<  8)
-					| ((buf[3] & 0xff) <<  0);
-	}
-	else if (frame->payload_len == 127)
-	{
-		if (len < 10)
-		{
-			LM_WARN("message is too short\n");
-			if (close_connection(frame->wsc, LOCAL_CLOSE, 1002,
-						str_status_protocol_error) < 0)
-				LM_ERR("closing connection\n");
-			return -1;
-		}
-		mask_start = 10;
-
-		if ((buf[2] & 0xff) != 0 || (buf[3] & 0xff) != 0
-			|| (buf[4] & 0xff) != 0 || (buf[5] & 0xff) != 0)
-		{
-			LM_WARN("message is too long\n");
-			if (close_connection(frame->wsc, LOCAL_CLOSE, 1009,
-						str_status_message_too_big) < 0)
-				LM_ERR("closing connection\n");
-			return -1;
-		}
-
-		/* Only decoding the last four bytes of the length...
-		   This limits the size of WebSocket messages that can be
-		   handled to 2^32 = which should be plenty for SIP! */
-	 	frame->payload_len =	  ((buf[6] & 0xff) << 24)
-					| ((buf[7] & 0xff) << 16)
-					| ((buf[8] & 0xff) <<  8)
-					| ((buf[9] & 0xff) <<  0);
-	}
-	else
-		mask_start = 2;
-
-	/* Decode mask */
-	frame->masking_key[0] = (buf[mask_start + 0] & 0xff);
-	frame->masking_key[1] = (buf[mask_start + 1] & 0xff);
-	frame->masking_key[2] = (buf[mask_start + 2] & 0xff);
-	frame->masking_key[3] = (buf[mask_start + 3] & 0xff);
-
-	/* Decode and unmask payload */
-	if (len != frame->payload_len + mask_start + 4)
-	{
-		LM_WARN("message not complete frame size %u but received %u\n",
-			frame->payload_len + mask_start + 4, len);
-		if (close_connection(frame->wsc, LOCAL_CLOSE, 1002,
-					str_status_protocol_error) < 0)
-			LM_ERR("closing connection\n");
-		return -1;
-	}
-	frame->payload_data = &buf[mask_start + 4];
-	for (i = 0; i < frame->payload_len; i++)
-	{
-		j = i % 4;
-		frame->payload_data[i]
-			= frame->payload_data[i] ^ frame->masking_key[j];
-	}
-
-	LM_INFO("Rx (decoded): %.*s\n",
-		(int) frame->payload_len, frame->payload_data);
-
-	return frame->opcode;
-}
-
 static int encode_and_send_ws_frame(ws_frame_t *frame, conn_close_t conn_close)
 static int encode_and_send_ws_frame(ws_frame_t *frame, conn_close_t conn_close)
 {
 {
 	int pos = 0, extended_length;
 	int pos = 0, extended_length;
 	unsigned int frame_length;
 	unsigned int frame_length;
 	char *send_buf;
 	char *send_buf;
+	struct tcp_connection *con;
 	struct dest_info dst;
 	struct dest_info dst;
 
 
 	LM_INFO("encoding WebSocket frame\n");
 	LM_INFO("encoding WebSocket frame\n");
@@ -372,7 +204,12 @@ static int encode_and_send_ws_frame(ws_frame_t *frame, conn_close_t conn_close)
 	}
 	}
 	memcpy(&send_buf[pos], frame->payload_data, frame->payload_len);
 	memcpy(&send_buf[pos], frame->payload_data, frame->payload_len);
 
 
-	init_dst_from_rcv(&dst, &frame->wsc->con->rcv);
+	if ((con = tcpconn_get(frame->wsc->id, 0, 0, 0, 0)) == NULL)
+	{
+		LM_ERR("getting TCP/TLS connection\n");
+		return -1;
+	}
+	init_dst_from_rcv(&dst, &con->rcv);
 	if (conn_close == CONN_CLOSE_DO)
 	if (conn_close == CONN_CLOSE_DO)
 	{
 	{
 		dst.send_flags.f |= SND_F_CON_CLOSE;
 		dst.send_flags.f |= SND_F_CON_CLOSE;
@@ -450,6 +287,172 @@ static int close_connection(ws_connection_t *wsc, ws_close_type_t type,
 	return 0;
 	return 0;
 }
 }
 
 
+static int decode_and_validate_ws_frame(ws_frame_t *frame,
+					tcp_event_info_t *tcpinfo)
+{
+	unsigned int i, len = tcpinfo->len;
+	int mask_start, j;
+	char *buf = tcpinfo->buf;
+
+	LM_INFO("decoding WebSocket frame\n");
+
+	if ((frame->wsc = wsconn_get(tcpinfo->con->id)) == NULL)
+	{
+		LM_WARN("WebSocket connection not found\n");
+		return -1;
+	}
+
+	wsconn_update(frame->wsc);
+
+	/* Decode and validate first 9 bits */
+	if (len < 2)
+	{
+		LM_WARN("message is too short\n");
+		if (close_connection(frame->wsc, LOCAL_CLOSE, 1002,
+					str_status_protocol_error) < 0)
+			LM_ERR("closing connection\n");
+		return -1;
+	}
+	frame->fin = (buf[0] & 0xff) & BYTE0_MASK_FIN;
+	frame->rsv1 = (buf[0] & 0xff) & BYTE0_MASK_RSV1;
+	frame->rsv2 = (buf[0] & 0xff) & BYTE0_MASK_RSV2;
+	frame->rsv3 = (buf[0] & 0xff) & BYTE0_MASK_RSV3;
+	frame->opcode = (buf[0] & 0xff) & BYTE0_MASK_OPCODE;
+	frame->mask = (buf[1] & 0xff) & BYTE1_MASK_MASK;
+	
+	if (!frame->fin)
+	{
+		LM_WARN("WebSocket fragmentation not supported in the sip "
+			"sub-protocol\n");
+		if (close_connection(frame->wsc, LOCAL_CLOSE, 1002,
+					str_status_protocol_error) < 0)
+			LM_ERR("closing connection\n");
+		return -1;
+	}
+
+	if (frame->rsv1 || frame->rsv2 || frame->rsv3)
+	{
+		LM_WARN("WebSocket reserved fields with non-zero values\n");
+		if (close_connection(frame->wsc, LOCAL_CLOSE, 1002,
+					str_status_protocol_error) < 0)
+			LM_ERR("closing connection\n");
+		return -1;
+	}
+
+	switch(frame->opcode)
+	{
+	case OPCODE_TEXT_FRAME:
+	case OPCODE_BINARY_FRAME:
+		LM_INFO("supported non-control frame: 0x%x\n",
+			(unsigned char) frame->opcode);
+		break;
+
+	case OPCODE_CLOSE:
+	case OPCODE_PING:
+	case OPCODE_PONG:
+		LM_INFO("supported control frame: 0x%x\n",
+			(unsigned char) frame->opcode);
+		break;
+
+	default:
+		LM_WARN("unsupported opcode: 0x%x\n",
+			(unsigned char) frame->opcode);
+		if (close_connection(frame->wsc, LOCAL_CLOSE, 1008,
+					str_status_unsupported_opcode) < 0)
+			LM_ERR("closing connection\n");
+		return -1;
+	}
+
+	if (!frame->mask)
+	{
+		LM_WARN("this is a server - all received messages must be "
+			"masked\n");
+		if (close_connection(frame->wsc, LOCAL_CLOSE, 1002,
+					str_status_protocol_error) < 0)
+			LM_ERR("closing connection\n");
+		return -1;
+	}
+
+	/* Decode and validate length */
+	frame->payload_len = (buf[1] & 0xff) & BYTE1_MASK_PAYLOAD_LEN;
+	if (frame->payload_len == 126)
+	{
+		if (len < 4)
+		{
+			LM_WARN("message is too short\n");
+			if (close_connection(frame->wsc, LOCAL_CLOSE, 1002,
+						str_status_protocol_error) < 0)
+				LM_ERR("closing connection\n");
+			return -1;
+		}
+		mask_start = 4;
+
+		frame->payload_len = 	  ((buf[2] & 0xff) <<  8)
+					| ((buf[3] & 0xff) <<  0);
+	}
+	else if (frame->payload_len == 127)
+	{
+		if (len < 10)
+		{
+			LM_WARN("message is too short\n");
+			if (close_connection(frame->wsc, LOCAL_CLOSE, 1002,
+						str_status_protocol_error) < 0)
+				LM_ERR("closing connection\n");
+			return -1;
+		}
+		mask_start = 10;
+
+		if ((buf[2] & 0xff) != 0 || (buf[3] & 0xff) != 0
+			|| (buf[4] & 0xff) != 0 || (buf[5] & 0xff) != 0)
+		{
+			LM_WARN("message is too long\n");
+			if (close_connection(frame->wsc, LOCAL_CLOSE, 1009,
+						str_status_message_too_big) < 0)
+				LM_ERR("closing connection\n");
+			return -1;
+		}
+
+		/* Only decoding the last four bytes of the length...
+		   This limits the size of WebSocket messages that can be
+		   handled to 2^32 = which should be plenty for SIP! */
+	 	frame->payload_len =	  ((buf[6] & 0xff) << 24)
+					| ((buf[7] & 0xff) << 16)
+					| ((buf[8] & 0xff) <<  8)
+					| ((buf[9] & 0xff) <<  0);
+	}
+	else
+		mask_start = 2;
+
+	/* Decode mask */
+	frame->masking_key[0] = (buf[mask_start + 0] & 0xff);
+	frame->masking_key[1] = (buf[mask_start + 1] & 0xff);
+	frame->masking_key[2] = (buf[mask_start + 2] & 0xff);
+	frame->masking_key[3] = (buf[mask_start + 3] & 0xff);
+
+	/* Decode and unmask payload */
+	if (len != frame->payload_len + mask_start + 4)
+	{
+		LM_WARN("message not complete frame size %u but received %u\n",
+			frame->payload_len + mask_start + 4, len);
+		if (close_connection(frame->wsc, LOCAL_CLOSE, 1002,
+					str_status_protocol_error) < 0)
+			LM_ERR("closing connection\n");
+		return -1;
+	}
+	frame->payload_data = &buf[mask_start + 4];
+	for (i = 0; i < frame->payload_len; i++)
+	{
+		j = i % 4;
+		frame->payload_data[i]
+			= frame->payload_data[i] ^ frame->masking_key[j];
+	}
+
+	LM_INFO("Rx (decoded): %.*s\n",
+		(int) frame->payload_len, frame->payload_data);
+
+	return frame->opcode;
+}
+
 static int handle_sip_message(ws_frame_t *frame)
 static int handle_sip_message(ws_frame_t *frame)
 {
 {
 	LM_INFO("Received SIP message\n");
 	LM_INFO("Received SIP message\n");
@@ -562,6 +565,26 @@ int ws_frame_received(void *data)
 	return 0;
 	return 0;
 }
 }
 
 
+static int ping_pong(ws_connection_t *wsc, int opcode)
+{
+	ws_frame_t frame;
+
+	memset(&frame, 0, sizeof(frame));
+	frame.fin = 1;
+	frame.opcode = opcode;
+	frame.payload_len = server_hdr.len;
+	frame.payload_data = server_hdr.s;
+	frame.wsc = wsc;
+
+	if (encode_and_send_ws_frame(&frame, CONN_CLOSE_DONT) < 0)
+	{	
+		LM_ERR("closing connection\n");
+		return -1;
+	}
+
+	return 0;
+}
+
 struct mi_root *ws_mi_close(struct mi_root *cmd, void *param)
 struct mi_root *ws_mi_close(struct mi_root *cmd, void *param)
 {
 {
 	unsigned int id;
 	unsigned int id;
@@ -589,7 +612,7 @@ struct mi_root *ws_mi_close(struct mi_root *cmd, void *param)
 					str_status_too_many_params.len);
 					str_status_too_many_params.len);
 	}
 	}
 
 
-	if ((wsc = wsconn_find(tcpconn_get(id, 0, 0, 0, 0))) == NULL)
+	if ((wsc = wsconn_get(id)) == NULL)
 	{
 	{
 		LM_ERR("bad connection ID parameter\n");
 		LM_ERR("bad connection ID parameter\n");
 		return init_mi_tree(400, str_status_bad_param.s,
 		return init_mi_tree(400, str_status_bad_param.s,
@@ -607,26 +630,6 @@ struct mi_root *ws_mi_close(struct mi_root *cmd, void *param)
 	return init_mi_tree(200, MI_OK_S, MI_OK_LEN);
 	return init_mi_tree(200, MI_OK_S, MI_OK_LEN);
 }
 }
 
 
-static int ping_pong(ws_connection_t *wsc, int opcode)
-{
-	ws_frame_t frame;
-
-	memset(&frame, 0, sizeof(frame));
-	frame.fin = 1;
-	frame.opcode = opcode;
-	frame.payload_len = server_hdr.len;
-	frame.payload_data = server_hdr.s;
-	frame.wsc = wsc;
-
-	if (encode_and_send_ws_frame(&frame, CONN_CLOSE_DONT) < 0)
-	{	
-		LM_ERR("closing connection\n");
-		return -1;
-	}
-
-	return 0;
-}
-
 static struct mi_root *mi_ping_pong(struct mi_root *cmd, void *param,
 static struct mi_root *mi_ping_pong(struct mi_root *cmd, void *param,
 					int opcode)
 					int opcode)
 {
 {
@@ -655,7 +658,7 @@ static struct mi_root *mi_ping_pong(struct mi_root *cmd, void *param,
 					str_status_too_many_params.len);
 					str_status_too_many_params.len);
 	}
 	}
 
 
-	if ((wsc = wsconn_find(tcpconn_get(id, 0, 0, 0, 0))) == NULL)
+	if ((wsc = wsconn_get(id)) == NULL)
 	{
 	{
 		LM_ERR("bad connection ID parameter\n");
 		LM_ERR("bad connection ID parameter\n");
 		return init_mi_tree(400, str_status_bad_param.s,
 		return init_mi_tree(400, str_status_bad_param.s,

+ 5 - 1
modules/websocket/ws_handshake.c

@@ -299,7 +299,11 @@ int ws_handle_handshake(struct sip_msg *msg)
 		return 0;
 		return 0;
 
 
 	/* Add the connection to the WebSocket connection table */
 	/* Add the connection to the WebSocket connection table */
-	wsconn_add(con);
+	wsconn_add(con->id);
+
+	/* Make sure Kamailio core sends future messages on this connection
+	   directly to this module */
+	con->flags |= F_CONN_WS;
 
 
 	return 0;
 	return 0;
 }
 }