Ver Fonte

websocket: Fix crash in websocket module

- Avoid race condition by maintaining a connection reference count
- Fixes FS#406
Vitaliy Aleksandrov há 11 anos atrás
pai
commit
4460dce0e2

+ 161 - 19
modules/websocket/ws_conn.c

@@ -47,6 +47,9 @@ gen_lock_t *wsconn_lock = NULL;
 #define WSCONN_LOCK	lock_get(wsconn_lock)
 #define WSCONN_UNLOCK	lock_release(wsconn_lock)
 
+#define wsconn_ref(c)   atomic_inc(&((c)->refcnt))
+#define wsconn_unref(c) atomic_dec_and_test(&((c)->refcnt))
+
 gen_lock_t *wsstat_lock = NULL;
 
 ws_connection_used_list_t *wsconn_used_list = NULL;
@@ -197,6 +200,8 @@ int wsconn_add(struct receive_info rcv, unsigned int sub_protocol)
 	int id_hash = tcp_id_hash(id);
 	ws_connection_t *wsc;
 
+	LM_DBG("wsconn_add id [%d]\n", id);
+
 	/* Allocate and fill in new WebSocket connection */
 	wsc = shm_malloc(sizeof(ws_connection_t));
 	if (wsc == NULL)
@@ -210,6 +215,10 @@ int wsconn_add(struct receive_info rcv, unsigned int sub_protocol)
 	wsc->state = WS_S_OPEN;
 	wsc->rcv = rcv;
 	wsc->sub_protocol = sub_protocol;
+	wsc->run_event = 0;
+	atomic_set(&wsc->refcnt, 0);
+
+	LM_DBG("wsconn_add new wsc => [%p], ref => [%d]\n", wsc, atomic_get(&wsc->refcnt));
 
 	WSCONN_LOCK;
 	/* Add to WebSocket connection table */
@@ -225,8 +234,12 @@ int wsconn_add(struct receive_info rcv, unsigned int sub_protocol)
 		wsconn_used_list->tail->used_next = wsc;
 		wsconn_used_list->tail = wsc;
 	}
+	wsconn_ref(wsc);
+
 	WSCONN_UNLOCK;
 
+	LM_DBG("wsconn_add added to conn_table wsc => [%p], ref => [%d]\n", wsc, atomic_get(&wsc->refcnt));
+
 	/* Update connection statistics */
 	lock_get(wsstat_lock);
 
@@ -290,32 +303,29 @@ static void wsconn_run_route(ws_connection_t *wsc)
 	set_route_type(backup_rt);
 }
 
-int wsconn_rm(ws_connection_t *wsc, ws_conn_eventroute_t run_event_route)
+static void wsconn_dtor(ws_connection_t *wsc)
 {
 	if (!wsc)
-	{
-		LM_ERR("wsconn_rm: null pointer\n");
-		return -1;
-	}
+		return;
 
-	if (run_event_route == WSCONN_EVENTROUTE_YES)
+	LM_DBG("wsconn_dtor for [%p] refcnt [%d]\n", wsc, atomic_get(&wsc->refcnt));
+
+	if (wsc->run_event)
 		wsconn_run_route(wsc);
 
-	WSCONN_LOCK;
-	/* Remove from the WebSocket used list */
-	if (wsconn_used_list->head == wsc)
-		wsconn_used_list->head = wsc->used_next;
-	if (wsconn_used_list->tail == wsc)
-		wsconn_used_list->tail = wsc->used_prev;
-	if (wsc->used_prev)
-		wsc->used_prev->used_next = wsc->used_next;
-	if (wsc->used_next)
-		wsc->used_next->used_prev = wsc->used_prev;
+	shm_free(wsc);
 
-	_wsconn_rm(wsc);
-	WSCONN_UNLOCK;
+	LM_DBG("wsconn_dtor for [%p] destroyed\n", wsc);
+}
 
-	return 0;
+int wsconn_rm(ws_connection_t *wsc, ws_conn_eventroute_t run_event_route)
+{
+	LM_DBG("wsconn_rm for [%p] refcnt [%d]\n", wsc, atomic_get(&wsc->refcnt));
+
+	if (run_event_route == WSCONN_EVENTROUTE_YES)
+		wsc->run_event = 1;
+
+	return wsconn_put(wsc);
 }
 
 int wsconn_update(ws_connection_t *wsc)
@@ -366,17 +376,70 @@ void wsconn_close_now(ws_connection_t *wsc)
 	con->timeout = get_ticks_raw();
 }
 
+/* must be called with unlocked WSCONN_LOCK */
+int wsconn_put(ws_connection_t *wsc)
+{
+	int destroy = 0;
+
+	LM_DBG("wsconn_put start for [%p] refcnt [%d]\n", wsc, atomic_get(&wsc->refcnt));
+
+	if (!wsc)
+		return -1;
+
+	WSCONN_LOCK;
+	/* refcnt == 0*/
+	if (wsconn_unref(wsc))
+	{
+		/* Remove from the WebSocket used list */
+		if (wsconn_used_list->head == wsc)
+			wsconn_used_list->head = wsc->used_next;
+		if (wsconn_used_list->tail == wsc)
+			wsconn_used_list->tail = wsc->used_prev;
+		if (wsc->used_prev)
+			wsc->used_prev->used_next = wsc->used_next;
+		if (wsc->used_next)
+			wsc->used_next->used_prev = wsc->used_prev;
+
+		/* remove from wsconn_id_hash */
+		wsconn_listrm(wsconn_id_hash[wsc->id_hash], wsc, id_next, id_prev);
+
+		/* stat */
+		update_stat(ws_current_connections, -1);
+		if (wsc->sub_protocol == SUB_PROTOCOL_SIP)
+			update_stat(ws_sip_current_connections, -1);
+		else if (wsc->sub_protocol == SUB_PROTOCOL_MSRP)
+			update_stat(ws_msrp_current_connections, -1);
+
+		destroy = 1;
+	}
+	WSCONN_UNLOCK;
+
+	LM_DBG("wsconn_put end for [%p] refcnt [%d]\n", wsc, atomic_get(&wsc->refcnt));
+
+	/* wsc is removed from all lists and can be destroyed safely */
+	if (destroy)
+		wsconn_dtor(wsc);
+
+	return 0;
+}
+
 ws_connection_t *wsconn_get(int id)
 {
 	int id_hash = tcp_id_hash(id);
 	ws_connection_t *wsc;
 
+	LM_DBG("wsconn_get for id [%d]\n", id);
+
 	WSCONN_LOCK;
 	for (wsc = wsconn_id_hash[id_hash]; wsc; wsc = wsc->id_next)
 	{
 		if (wsc->id == id)
 		{
+			wsconn_ref(wsc);
+			LM_DBG("wsconn_get returns wsc [%p] refcnt [%d]\n", wsc, atomic_get(&wsc->refcnt));
+
 			WSCONN_UNLOCK;
+
 			return wsc;
 		}
 	}
@@ -385,6 +448,85 @@ ws_connection_t *wsconn_get(int id)
 	return NULL;
 }
 
+ws_connection_t **wsconn_get_list(void)
+{
+	ws_connection_t **list = NULL;
+	ws_connection_t *wsc   = NULL;
+	size_t list_size = 0;
+	size_t list_len  = 0;
+	size_t i = 0;
+
+	LM_DBG("wsconn_get_list\n");
+
+	WSCONN_LOCK;
+
+	/* get the number of used connections */
+	wsc = wsconn_used_list->head;
+	while (wsc)
+	{
+		LM_DBG("counter wsc [%p] prev => [%p] next => [%p]\n", wsc, wsc->used_prev, wsc->used_next);
+		list_len++;
+		wsc = wsc->used_next;
+	}
+
+	if (!list_len)
+		goto end;
+
+	/* allocate a NULL terminated list of wsconn pointers */
+	list_size = (list_len + 1) * sizeof(ws_connection_t *);
+	list = pkg_malloc(list_size);
+	if (!list)
+		goto end;
+
+	memset(list, 0, list_size);
+
+	/* copy */
+	wsc = wsconn_used_list->head;
+	for(i = 0; i < list_len; i++)
+	{
+		if (!wsc) {
+			LM_ERR("Wrong list length\n");
+		}
+
+		list[i] = wsc;
+		wsconn_ref(wsc);
+		LM_DBG("wsc [%p] id [%d] ref++\n", wsc, wsc->id);
+
+		wsc = wsc->used_next;
+	}
+	list[list_len] = NULL; /* explicit NULL termination */
+
+end:
+	WSCONN_UNLOCK;
+
+	LM_DBG("wsconn_get_list returns list [%p] with [%d] members\n", list, (int)list_len);
+
+	return list;
+}
+
+int wsconn_put_list(ws_connection_t **list_head)
+{
+	ws_connection_t **list = NULL;
+	ws_connection_t *wsc   = NULL;
+
+	LM_DBG("wsconn_put_list [%p]\n", list_head);
+
+	if (!list_head)
+		return -1;
+
+	list =  list_head;
+	wsc  = *list_head;
+	while (wsc)
+	{
+		wsconn_put(wsc);
+		wsc = *(++list);
+	}
+
+	pkg_free(list_head);
+
+	return 0;
+}
+
 static int add_node(struct mi_root *tree, ws_connection_t *wsc)
 {
 	int interval;

+ 8 - 0
modules/websocket/ws_conn.h

@@ -29,6 +29,8 @@
 #ifndef _WS_CONN_H
 #define _WS_CONN_H
 
+#include "../../atomic_ops.h"
+
 #include "../../lib/kcore/kstats_wrapper.h"
 #include "../../lib/kmi/tree.h"
 
@@ -57,6 +59,9 @@ typedef struct ws_connection
 	struct receive_info rcv;
 
 	unsigned int sub_protocol;
+
+	atomic_t refcnt;
+	int      run_event;
 } ws_connection_t;
 
 typedef struct
@@ -89,6 +94,9 @@ int wsconn_rm(ws_connection_t *wsc, ws_conn_eventroute_t run_event_route);
 int wsconn_update(ws_connection_t *wsc);
 void wsconn_close_now(ws_connection_t *wsc);
 ws_connection_t *wsconn_get(int id);
+int wsconn_put(ws_connection_t *wsc);
+ws_connection_t **wsconn_get_list(void);
+int wsconn_put_list(ws_connection_t **list);
 struct mi_root *ws_mi_dump(struct mi_root *cmd, void *param);
 
 #endif /* _WS_CONN_H */

+ 122 - 58
modules/websocket/ws_frame.c

@@ -240,7 +240,6 @@ static int encode_and_send_ws_frame(ws_frame_t *frame, conn_close_t conn_close)
 		pkg_free(send_buf);
 		if (wsconn_rm(frame->wsc, WSCONN_EVENTROUTE_YES) < 0)
 			LM_ERR("removing WebSocket connection\n");
-		frame->wsc = NULL;
 		return -1;
 	}
 	init_dst_from_rcv(&dst, &con->rcv);
@@ -252,10 +251,8 @@ static int encode_and_send_ws_frame(ws_frame_t *frame, conn_close_t conn_close)
 			LM_ERR("removing WebSocket connection\n");
 			tcpconn_put(con);
 			pkg_free(send_buf);
-			frame->wsc = NULL;
 			return -1;
 		}
-		frame->wsc = NULL;
 	}
 
 	if (dst.proto == PROTO_WS)
@@ -308,7 +305,6 @@ static int encode_and_send_ws_frame(ws_frame_t *frame, conn_close_t conn_close)
 			update_stat(ws_msrp_failed_connections, 1);
 		if (wsconn_rm(frame->wsc, WSCONN_EVENTROUTE_YES) < 0)
 			LM_ERR("removing WebSocket connection\n");
-		frame->wsc = NULL;
 		tcpconn_put(con);
 		return -1;
 	}
@@ -394,20 +390,19 @@ static int close_connection(ws_connection_t **p_wsc, ws_close_type_t type,
 			else if (sub_proto == SUB_PROTOCOL_MSRP)
 				update_stat(ws_msrp_remote_closed_connections,
 						1);
-			*p_wsc = NULL;
 		}
 	}
 	else /* if (frame->wsc->state == WS_S_CLOSING) */
 	{
 		wsconn_close_now(wsc);
-		*p_wsc = NULL;
 	}
 
 	return 0;
 }
 
 static int decode_and_validate_ws_frame(ws_frame_t *frame,
-					tcp_event_info_t *tcpinfo)
+                                        tcp_event_info_t *tcpinfo,
+                                        short *err_code, str *err_text)
 {
 	unsigned int i, len = tcpinfo->len;
 	int mask_start, j;
@@ -415,21 +410,14 @@ static int decode_and_validate_ws_frame(ws_frame_t *frame,
 
 	LM_DBG("decoding WebSocket frame\n");
 
-	if ((frame->wsc = wsconn_get(tcpinfo->con->id)) == NULL)
-	{
-		LM_ERR("WebSocket connection not found\n");
-		return -1;
-	}
-
 	wsconn_update(frame->wsc);
 
 	/* Decode and validate first 9 bits */
 	if (len < 2)
 	{
 		LM_WARN("message is too short\n");
-		if (close_connection(&frame->wsc, LOCAL_CLOSE, 1002,
-					str_status_protocol_error) < 0)
-			LM_ERR("closing connection\n");
+		*err_code = 1002;
+		*err_text = str_status_protocol_error;
 		return -1;
 	}
 	frame->fin = (buf[0] & 0xff) & BYTE0_MASK_FIN;
@@ -443,18 +431,16 @@ static int decode_and_validate_ws_frame(ws_frame_t *frame,
 	{
 		LM_WARN("WebSocket fragmentation not supported in the sip "
 			"sub-protocol\n");
-		if (close_connection(&frame->wsc, LOCAL_CLOSE, 1002,
-					str_status_protocol_error) < 0)
-			LM_ERR("closing connection\n");
+		*err_code = 1002;
+		*err_text = str_status_protocol_error;
 		return -1;
 	}
 
 	if (frame->rsv1 || frame->rsv2 || frame->rsv3)
 	{
 		LM_WARN("WebSocket reserved fields with non-zero values\n");
-		if (close_connection(&frame->wsc, LOCAL_CLOSE, 1002,
-					str_status_protocol_error) < 0)
-			LM_ERR("closing connection\n");
+		*err_code = 1002;
+		*err_text = str_status_protocol_error;
 		return -1;
 	}
 
@@ -476,9 +462,8 @@ static int decode_and_validate_ws_frame(ws_frame_t *frame,
 	default:
 		LM_WARN("unsupported opcode: 0x%x\n",
 			(unsigned char) frame->opcode);
-		if (close_connection(&frame->wsc, LOCAL_CLOSE, 1008,
-					str_status_unsupported_opcode) < 0)
-			LM_ERR("closing connection\n");
+		*err_code = 1008;
+		*err_text = str_status_unsupported_opcode;
 		return -1;
 	}
 
@@ -486,9 +471,8 @@ static int decode_and_validate_ws_frame(ws_frame_t *frame,
 	{
 		LM_WARN("this is a server - all received messages must be "
 			"masked\n");
-		if (close_connection(&frame->wsc, LOCAL_CLOSE, 1002,
-					str_status_protocol_error) < 0)
-			LM_ERR("closing connection\n");
+		*err_code = 1002;
+		*err_text = str_status_protocol_error;
 		return -1;
 	}
 
@@ -499,9 +483,8 @@ static int decode_and_validate_ws_frame(ws_frame_t *frame,
 		if (len < 4)
 		{
 			LM_WARN("message is too short\n");
-			if (close_connection(&frame->wsc, LOCAL_CLOSE, 1002,
-						str_status_protocol_error) < 0)
-				LM_ERR("closing connection\n");
+			*err_code = 1002;
+			*err_text = str_status_protocol_error;
 			return -1;
 		}
 		mask_start = 4;
@@ -514,9 +497,8 @@ static int decode_and_validate_ws_frame(ws_frame_t *frame,
 		if (len < 10)
 		{
 			LM_WARN("message is too short\n");
-			if (close_connection(&frame->wsc, LOCAL_CLOSE, 1002,
-						str_status_protocol_error) < 0)
-				LM_ERR("closing connection\n");
+			*err_code = 1002;
+			*err_text = str_status_protocol_error;
 			return -1;
 		}
 		mask_start = 10;
@@ -525,9 +507,8 @@ static int decode_and_validate_ws_frame(ws_frame_t *frame,
 			|| (buf[4] & 0xff) != 0 || (buf[5] & 0xff) != 0)
 		{
 			LM_WARN("message is too long\n");
-			if (close_connection(&frame->wsc, LOCAL_CLOSE, 1009,
-						str_status_message_too_big) < 0)
-				LM_ERR("closing connection\n");
+			*err_code = 1009;
+			*err_text = str_status_message_too_big;
 			return -1;
 		}
 
@@ -553,9 +534,8 @@ static int decode_and_validate_ws_frame(ws_frame_t *frame,
 	{
 		LM_WARN("message not complete frame size %u but received %u\n",
 			frame->payload_len + mask_start + 4, len);
-		if (close_connection(&frame->wsc, LOCAL_CLOSE, 1002,
-					str_status_protocol_error) < 0)
-			LM_ERR("closing connection\n");
+		*err_code = 1002;
+		*err_text = str_status_protocol_error;
 		return -1;
 	}
 	frame->payload_data = &buf[mask_start + 4];
@@ -632,6 +612,11 @@ int ws_frame_receive(void *data)
 	ws_frame_t frame;
 	tcp_event_info_t *tcpinfo = (tcp_event_info_t *) data;
 
+	int opcode      = -1;
+	int ret         = 0;
+	short err_code  = 0;
+	str   err_text  = {NULL, 0};
+
 	update_stat(ws_received_frames, 1);
 
 	if (tcpinfo == NULL || tcpinfo->buf == NULL || tcpinfo->len <= 0)
@@ -640,7 +625,26 @@ int ws_frame_receive(void *data)
 		return -1;
 	}
 
-	switch(decode_and_validate_ws_frame(&frame, tcpinfo))
+	/* wsc refcnt++ */
+	frame.wsc = wsconn_get(tcpinfo->con->id);
+	if (frame.wsc == NULL)
+	{
+		LM_ERR("WebSocket connection not found\n");
+		return -1;
+	}
+
+	opcode = decode_and_validate_ws_frame(&frame, tcpinfo, &err_code, &err_text);
+	if (opcode < 0)
+	{
+		if (close_connection(&frame.wsc, LOCAL_CLOSE, err_code, err_text) < 0)
+			LM_ERR("closing connection\n");
+
+		wsconn_put(frame.wsc);
+
+		return -1;
+	}
+
+	switch(opcode)
 	{
 	case OPCODE_TEXT_FRAME:
 	case OPCODE_BINARY_FRAME:
@@ -649,6 +653,9 @@ int ws_frame_receive(void *data)
 			LM_DBG("Rx SIP message:\n%.*s\n", frame.payload_len,
 				frame.payload_data);
 			update_stat(ws_sip_received_frames, 1);
+
+			wsconn_put(frame.wsc);
+
 			return receive_msg(frame.payload_data,
 						frame.payload_len,
 						tcpinfo->rcv);
@@ -667,30 +674,46 @@ int ws_frame_receive(void *data)
 				tev.len = frame.payload_len;
 				tev.rcv = tcpinfo->rcv;
 				tev.con = tcpinfo->con;
+
+				wsconn_put(frame.wsc);
+
 				return sr_event_exec(SREV_TCP_MSRP_FRAME,
 							(void *) &tev);
 			}
 			else
 			{
 				LM_ERR("no callback registered for MSRP\n");
+
+				wsconn_put(frame.wsc);
+
 				return -1;
 			}
 		}
 
 	case OPCODE_CLOSE:
-		return handle_close(&frame);
+		ret = handle_close(&frame);
+		if (frame.wsc) wsconn_put(frame.wsc);
+		return ret;
 
 	case OPCODE_PING:
-		return handle_ping(&frame);
+		ret = handle_ping(&frame);
+		if (frame.wsc) wsconn_put(frame.wsc);
+		return ret;
 
 	case OPCODE_PONG:
-		return handle_pong(&frame);
+		ret = handle_pong(&frame);
+		if (frame.wsc) wsconn_put(frame.wsc);
+		return ret;
 
 	default:
 		LM_WARN("received bad frame\n");
+		wsconn_put(frame.wsc);
 		return -1;
 	}
 
+	/* how can we get here ? */
+	wsconn_put(frame.wsc);
+
 	return 0;
 }
 
@@ -715,9 +738,14 @@ int ws_frame_transmit(void *data)
 	if (encode_and_send_ws_frame(&frame, CONN_CLOSE_DONT) < 0)
 	{	
 		LM_ERR("sending message\n");
+
+		wsconn_put(frame.wsc);
+
 		return -1;
 	}
 
+	wsconn_put(frame.wsc);
+
 	return 0;
 }
 
@@ -783,8 +811,11 @@ struct mi_root *ws_mi_close(struct mi_root *cmd, void *param)
 					str_status_bad_param.len);
 	}
 
-	if (close_connection(&wsc, LOCAL_CLOSE, 1000,
-				str_status_normal_closure) < 0)
+	int ret = close_connection(&wsc, LOCAL_CLOSE, 1000, str_status_normal_closure);
+
+	wsconn_put(wsc);
+
+	if (ret < 0)
 	{
 		LM_WARN("closing connection\n");
 		return init_mi_tree(500, str_status_error_closing.s,
@@ -834,7 +865,11 @@ static struct mi_root *mi_ping_pong(struct mi_root *cmd, void *param,
 					str_status_bad_param.len);
 	}
 
-	if (ping_pong(wsc, opcode) < 0)
+	int ret = ping_pong(wsc, opcode);
+
+	wsconn_put(wsc);
+
+	if (ret < 0)
 	{
 		LM_WARN("sending %s\n", OPCODE_PING ? "Ping" : "Pong");
 		return init_mi_tree(500, str_status_error_sending.s,
@@ -858,36 +893,55 @@ void ws_keepalive(unsigned int ticks, void *param)
 {
 	int check_time = (int) time(NULL)
 		- cfg_get(websocket, ws_cfg, keepalive_timeout);
-	ws_connection_t *wsc = wsconn_used_list->head;
 
+	ws_connection_t **list      = NULL,
+	                **list_head = NULL;
+	ws_connection_t *wsc   = NULL;
+
+	/* get an array of pointer to all ws connection */
+	list_head = wsconn_get_list();
+	if (!list_head)
+		return;
+
+	list =  list_head;
+	wsc  = *list_head;
 	while (wsc && wsc->last_used < check_time)
 	{
-		if (wsc->state == WS_S_CLOSING
-			|| wsc->awaiting_pong)
+		if (wsc->state == WS_S_CLOSING || wsc->awaiting_pong)
 		{
 			LM_WARN("forcibly closing connection\n");
 			wsconn_close_now(wsc);
 		}
 		else
-			ping_pong(wsconn_used_list->head,
-			  ws_keepalive_mechanism == KEEPALIVE_MECHANISM_PING
-					? OPCODE_PING : OPCODE_PONG);
-		wsc = wsconn_used_list->head;
+		{
+			int opcode = (ws_keepalive_mechanism == KEEPALIVE_MECHANISM_PING)
+			             ? OPCODE_PING
+			             : OPCODE_PONG;
+			ping_pong(wsc, opcode);
+		}
+
+		wsc = *(++list);
 	}
-	
+
+	wsconn_put_list(list_head);
 }
 
 int ws_close(sip_msg_t *msg)
 {
 	ws_connection_t *wsc;
+	int ret;
 
 	if ((wsc = wsconn_get(msg->rcv.proto_reserved1)) == NULL) {
 		LM_ERR("failed to retrieve WebSocket connection\n");
 		return -1;
 	}
 
-	return (close_connection(&wsc, LOCAL_CLOSE, 1000,
+	ret = (close_connection(&wsc, LOCAL_CLOSE, 1000,
 				 str_status_normal_closure) == 0) ? 1: 0;
+
+	wsconn_put(wsc);
+
+	return ret;
 }
 
 int ws_close2(sip_msg_t *msg, char *_status, char *_reason)
@@ -895,6 +949,7 @@ int ws_close2(sip_msg_t *msg, char *_status, char *_reason)
 	int status;
 	str reason;
 	ws_connection_t *wsc;
+	int ret;
 
 	if (get_int_fparam(&status, msg, (fparam_t *) _status) < 0) {
 		LM_ERR("failed to get status code\n");
@@ -911,7 +966,11 @@ int ws_close2(sip_msg_t *msg, char *_status, char *_reason)
 		return -1;
 	}
 
-	return (close_connection(&wsc, LOCAL_CLOSE, status, reason) == 0) ? 1: 0;
+	ret = (close_connection(&wsc, LOCAL_CLOSE, status, reason) == 0) ? 1: 0;
+
+	wsconn_put(wsc);
+
+	return ret;
 }
 
 int ws_close3(sip_msg_t *msg, char *_status, char *_reason, char *_con)
@@ -920,6 +979,7 @@ int ws_close3(sip_msg_t *msg, char *_status, char *_reason, char *_con)
 	str reason;
 	int con;
 	ws_connection_t *wsc;
+	int ret;
 
 	if (get_int_fparam(&status, msg, (fparam_t *) _status) < 0) {
 		LM_ERR("failed to get status code\n");
@@ -941,5 +1001,9 @@ int ws_close3(sip_msg_t *msg, char *_status, char *_reason, char *_con)
 		return -1;
 	}
 
-	return (close_connection(&wsc, LOCAL_CLOSE, status, reason) == 0) ? 1: 0;
+	ret = (close_connection(&wsc, LOCAL_CLOSE, status, reason) == 0) ? 1: 0;
+
+	wsconn_put(wsc);
+
+	return ret;
 }

+ 3 - 1
modules/websocket/ws_handshake.c

@@ -427,8 +427,10 @@ int ws_handle_handshake(struct sip_msg *msg)
 				&headers) < 0)
 	{
 		if ((wsc = wsconn_get(msg->rcv.proto_reserved1)) != NULL)
+		{
 			wsconn_rm(wsc, WSCONN_EVENTROUTE_NO);
-
+			wsconn_put(wsc);
+		}
 		goto end;
 	}
 	else