Răsfoiți Sursa

websocket: use the list with ids for ws connections to do ping-pong

- avoid using pointers to ws connections, they can get closed
Daniel-Constantin Mierla 6 ani în urmă
părinte
comite
b5253b6209

+ 121 - 4
src/modules/websocket/ws_conn.c

@@ -382,8 +382,8 @@ void wsconn_close_now(ws_connection_t *wsc)
 	con->timeout = get_ticks_raw();
 }
 
-/* must be called with unlocked WSCONN_LOCK */
-int wsconn_put(ws_connection_t *wsc)
+/* mode controls if lock needs to be aquired */
+int wsconn_put_mode(ws_connection_t *wsc, int mode)
 {
 	int destroy = 0;
 
@@ -393,7 +393,9 @@ int wsconn_put(ws_connection_t *wsc)
 	if(!wsc)
 		return -1;
 
-	WSCONN_LOCK;
+	if(mode) {
+		WSCONN_LOCK;
+	}
 	/* refcnt == 0*/
 	if(wsconn_unref(wsc)) {
 		/* Remove from the WebSocket used list */
@@ -418,7 +420,9 @@ int wsconn_put(ws_connection_t *wsc)
 
 		destroy = 1;
 	}
-	WSCONN_UNLOCK;
+	if(mode) {
+		WSCONN_UNLOCK;
+	}
 
 	LM_DBG("wsconn_put end for [%p] refcnt [%d]\n", wsc,
 			atomic_get(&wsc->refcnt));
@@ -430,6 +434,12 @@ int wsconn_put(ws_connection_t *wsc)
 	return 0;
 }
 
+/* must be called with unlocked WSCONN_LOCK */
+int wsconn_put(ws_connection_t *wsc)
+{
+	return wsconn_put_mode(wsc, 1);
+}
+
 ws_connection_t *wsconn_get(int id)
 {
 	int id_hash = tcp_id_hash(id);
@@ -454,6 +464,30 @@ ws_connection_t *wsconn_get(int id)
 	return NULL;
 }
 
+int wsconn_put_id(int id)
+{
+	int id_hash = tcp_id_hash(id);
+	ws_connection_t *wsc;
+
+	LM_DBG("wsconn put id [%d]\n", id);
+
+	WSCONN_LOCK;
+	for(wsc = wsconn_id_hash[id_hash]; wsc; wsc = wsc->id_next) {
+		if(wsc->id == id) {
+			LM_DBG("wsc [%p] refcnt [%d]\n", wsc,
+					atomic_get(&wsc->refcnt));
+			wsconn_put_mode(wsc, 0);
+
+			WSCONN_UNLOCK;
+
+			return 1;
+		}
+	}
+	WSCONN_UNLOCK;
+
+	return 0;
+}
+
 ws_connection_t **wsconn_get_list(void)
 {
 	ws_connection_t **list = NULL;
@@ -539,6 +573,89 @@ int wsconn_put_list(ws_connection_t **list_head)
 }
 
 
+ws_connection_id_t *wsconn_get_list_ids(void)
+{
+	ws_connection_id_t *list = NULL;
+	ws_connection_t *wsc = NULL;
+	size_t list_size = 0;
+	size_t list_len = 0;
+	size_t i = 0;
+
+	if(ws_verbose_list)
+		LM_DBG("wsconn get list ids - starting\n");
+
+	WSCONN_LOCK;
+
+	/* get the number of used connections */
+	wsc = wsconn_used_list->head;
+	while(wsc) {
+		if(ws_verbose_list)
+			LM_DBG("counter wsc [%p] prev => [%p] next => [%p]\n", wsc,
+					wsc->used_prev, wsc->used_next);
+		list_len++;
+		wsc = wsc->used_next;
+	}
+
+	if(!list_len)
+		goto end;
+
+	/* allocate a NULL terminated list of wsconn pointers */
+	list_size = (list_len + 1) * sizeof(ws_connection_id_t);
+	list = pkg_malloc(list_size);
+	if(!list)
+		goto end;
+
+	memset(list, 0, list_size);
+
+	/* copy */
+	wsc = wsconn_used_list->head;
+	for(i = 0; i < list_len; i++) {
+		if(!wsc) {
+			LM_ERR("Wrong list length\n");
+			break;
+		}
+
+		list[i].id = wsc->id;
+		wsconn_ref(wsc);
+		if(ws_verbose_list)
+			LM_DBG("wsc [%p] id [%d] ref++\n", wsc, wsc->id);
+
+		wsc = wsc->used_next;
+	}
+	list[i].id = -1; /* explicit -1 termination */
+
+end:
+	WSCONN_UNLOCK;
+
+	if(ws_verbose_list)
+		LM_DBG("wsconn get list id returns list [%p]"
+			   " with [%d] members\n",
+				list, (int)list_len);
+
+	return list;
+}
+
+int wsconn_put_list_ids(ws_connection_id_t *list_head)
+{
+	ws_connection_id_t *list = NULL;
+	int i;
+
+	LM_DBG("wsconn put list id [%p]\n", list_head);
+
+	if(!list_head)
+		return -1;
+
+	list = list_head;
+	for(i=0; list[i].id!=-1; i++) {
+		wsconn_put_id(list[i].id);
+	}
+
+	pkg_free(list_head);
+
+	return 0;
+}
+
+
 static int ws_rpc_add_node(
 		rpc_t *rpc, void *ctx, void *ih, ws_connection_t *wsc)
 {

+ 8 - 0
src/modules/websocket/ws_conn.h

@@ -63,6 +63,11 @@ typedef struct ws_connection
 	str frag_buf;
 } ws_connection_t;
 
+typedef struct ws_connection_id
+{
+	int id;
+} ws_connection_id_t;
+
 typedef struct
 {
 	ws_connection_t *head;
@@ -95,5 +100,8 @@ ws_connection_t *wsconn_get(int id);
 int wsconn_put(ws_connection_t *wsc);
 ws_connection_t **wsconn_get_list(void);
 int wsconn_put_list(ws_connection_t **list);
+ws_connection_id_t *wsconn_get_list_ids(void);
+int wsconn_put_list_ids(ws_connection_id_t *list);
+int wsconn_put_id(int id);
 void ws_rpc_dump(rpc_t *rpc, void *ctx);
 #endif /* _WS_CONN_H */

+ 18 - 13
src/modules/websocket/ws_frame.c

@@ -796,31 +796,36 @@ void ws_keepalive(unsigned int ticks, void *param)
 	int check_time =
 			(int)time(NULL) - cfg_get(websocket, ws_cfg, keepalive_timeout);
 
-	ws_connection_t **list = NULL, **list_head = NULL;
+	ws_connection_id_t *list_head = NULL;
 	ws_connection_t *wsc = NULL;
+	int i = 0;
 
 	/* get an array of pointer to all ws connection */
-	list_head = wsconn_get_list();
+	list_head = wsconn_get_list_ids();
 	if(!list_head)
 		return;
 
-	list = list_head;
-	wsc = *list_head;
-	while(wsc && wsc->last_used < check_time) {
-		if(wsc->state == WS_S_CLOSING || wsc->awaiting_pong) {
-			LM_WARN("forcibly closing connection\n");
-			wsconn_close_now(wsc);
-		} else {
-			int opcode = (ws_keepalive_mechanism == KEEPALIVE_MECHANISM_PING)
+	while(list_head[i].id!=-1) {
+		wsc = wsconn_get(list_head[i].id);
+		if(wsc && wsc->last_used < check_time) {
+			if(wsc->state == WS_S_CLOSING || wsc->awaiting_pong) {
+				LM_WARN("forcibly closing connection\n");
+				wsconn_close_now(wsc);
+			} else {
+				int opcode = (ws_keepalive_mechanism == KEEPALIVE_MECHANISM_PING)
 								 ? OPCODE_PING
 								 : OPCODE_PONG;
-			ping_pong(wsc, opcode);
+				ping_pong(wsc, opcode);
+			}
+		}
+		if(wsc) {
+			wsconn_get(list_head[i].id);
 		}
+		i++;
 
-		wsc = *(++list);
 	}
 
-	wsconn_put_list(list_head);
+	wsconn_put_list_ids(list_head);
 }
 
 int ws_close(sip_msg_t *msg)