Parcourir la source

Merge branch 'pd/pua_fix'

* pd/pua_fix:
  modules/tm, modules_k/pua: Fix for concurrency issue in PUA module
pd il y a 14 ans
Parent
commit
f2086ceaf8
6 fichiers modifiés avec 191 ajouts et 37 suppressions
  1. 10 1
      modules/tm/uac.c
  2. 2 2
      modules/tm/uac.h
  3. 58 17
      modules_k/pua/hash.c
  4. 1 0
      modules_k/pua/hash.h
  5. 47 15
      modules_k/pua/pua.c
  6. 73 2
      modules_k/pua/send_subscribe.c

+ 10 - 1
modules/tm/uac.c

@@ -707,7 +707,7 @@ int req_within(uac_req_t *uac_r)
  * Send an initial request that will start a dialog
  * WARNING: writes uac_r->dialog
  */
-int req_outside(uac_req_t *uac_r, str* to, str* from)
+int req_outside(uac_req_t *uac_r, str* ruri, str* to, str* from, str *next_hop)
 {
 	str callid, fromtag;
 
@@ -721,6 +721,15 @@ int req_outside(uac_req_t *uac_r, str* to, str* from)
 		goto err;
 	}
 
+	if (ruri) {
+		uac_r->dialog->rem_target.s = ruri->s;
+		uac_r->dialog->rem_target.len = ruri->len;
+		/* hooks will be set from w_calculate_hooks */
+	}
+
+	if (next_hop) uac_r->dialog->dst_uri = *next_hop;
+	w_calculate_hooks(uac_r->dialog);
+
 	return t_uac(uac_r);
 
  err:

+ 2 - 2
modules/tm/uac.h

@@ -83,7 +83,7 @@ extern int goto_on_local_req;
  * Function prototypes
  */
 typedef int (*reqwith_t)(uac_req_t *uac_r);
-typedef int (*reqout_t)(uac_req_t *uac_r, str* to, str* from);
+typedef int (*reqout_t)(uac_req_t *uac_r, str* ruri, str* to, str* from, str *next_hop);
 typedef int (*req_t)(uac_req_t *uac_r, str* ruri, str* to, str* from, str *next_hop);
 typedef int (*t_uac_t)(uac_req_t *uac_r);
 typedef int (*t_uac_with_ids_t)(uac_req_t *uac_r,
@@ -128,7 +128,7 @@ int req_within(uac_req_t *uac_r);
 /*
  * Send an initial request that will start a dialog
  */
-int req_outside(uac_req_t *uac_r, str* to, str* from);
+int req_outside(uac_req_t *uac_r, str* ruri, str* to, str* from, str* next_hop);
 
 
 #ifdef WITH_AS_SUPPORT

+ 58 - 17
modules_k/pua/hash.c

@@ -244,28 +244,31 @@ void insert_htable(ua_pres_t* presentity)
 
 }
 
+/* This function used to perform a search to find the hash table
+   entry that matches the presentity it is passed.  However,
+   everywhere it is used it is passed a pointer to the correct
+   hash table entry already...  so let's just delete that */
 void delete_htable(ua_pres_t* presentity, unsigned int hash_code)
 { 
-	ua_pres_t* p= NULL, *q= NULL;
+	ua_pres_t *q = NULL;
 
-	p= search_htable(presentity, hash_code);
-	if(p== NULL)
+	if (presentity == NULL)
 		return;
 
-	q=HashT->p_records[hash_code].entity;
+	q = HashT->p_records[hash_code].entity;
 
-	while(q->next!=p)
-		q= q->next;
-	q->next=p->next;
+	while (q->next != presentity)
+		q = q->next;
+	q->next = presentity->next;
 	
-	if(p->etag.s)
-		shm_free(p->etag.s);
+	if(presentity->etag.s)
+		shm_free(presentity->etag.s);
 	else
-		if(p->remote_contact.s)
-			shm_free(p->remote_contact.s);
+		if(presentity->remote_contact.s)
+			shm_free(presentity->remote_contact.s);
 
-	shm_free(p);
-	p= NULL;
+	shm_free(presentity);
+	presentity = NULL;
 
 }
 	
@@ -323,7 +326,7 @@ ua_pres_t* get_dialog(ua_pres_t* dialog, unsigned int hash_code)
 			if((p->pres_uri->len== dialog->pres_uri->len) &&
 				(strncmp(p->pres_uri->s, dialog->pres_uri->s,p->pres_uri->len)==0)&&
 				(p->watcher_uri->len== dialog->watcher_uri->len) &&
- 	    		(strncmp(p->watcher_uri->s,dialog->watcher_uri->s,p->watcher_uri->len )==0)&&
+				(strncmp(p->watcher_uri->s,dialog->watcher_uri->s,p->watcher_uri->len )==0)&&
 				(strncmp(p->call_id.s, dialog->call_id.s, p->call_id.len)== 0) &&
 				(strncmp(p->to_tag.s, dialog->to_tag.s, p->to_tag.len)== 0) &&
 				(strncmp(p->from_tag.s, dialog->from_tag.s, p->from_tag.len)== 0) )
@@ -338,6 +341,39 @@ ua_pres_t* get_dialog(ua_pres_t* dialog, unsigned int hash_code)
 	return p;
 }
 
+/* must lock the record line before calling this function*/
+ua_pres_t* get_temporary_dialog(ua_pres_t* dialog, unsigned int hash_code)
+{
+	ua_pres_t* p= NULL, *L;
+	LM_DBG("core_hash= %u\n", hash_code);
+
+	L= HashT->p_records[hash_code].entity;
+	for(p= L->next; p; p=p->next)
+	{
+		LM_DBG("pres_uri= %.*s\twatcher_uri=%.*s\n\t"
+				"callid= %.*s\tfrom_tag= %.*s\n",
+			p->pres_uri->len, p->pres_uri->s, p->watcher_uri->len,
+			p->watcher_uri->s,p->call_id.len, p->call_id.s,
+			p->from_tag.len, p->from_tag.s);
+
+		if((p->pres_uri->len== dialog->pres_uri->len) &&
+			(strncmp(p->pres_uri->s, dialog->pres_uri->s,p->pres_uri->len)==0)&&
+			(p->watcher_uri->len== dialog->watcher_uri->len) &&
+			(strncmp(p->watcher_uri->s,dialog->watcher_uri->s,p->watcher_uri->len )==0)&&
+			(p->call_id.len == dialog->call_id.len) &&
+			(strncmp(p->call_id.s, dialog->call_id.s, p->call_id.len)== 0) &&
+			(p->from_tag.len == dialog->from_tag.len) &&
+			(strncmp(p->from_tag.s, dialog->from_tag.s, p->from_tag.len)== 0) &&
+			p->to_tag.len == 0)
+			{
+				LM_DBG("FOUND temporary dialog\n");
+				break;
+			}
+	}
+
+	return p;
+}
+
 int get_record_id(ua_pres_t* dialog, str** rec_id)
 {
 	unsigned int hash_code;
@@ -352,9 +388,14 @@ int get_record_id(ua_pres_t* dialog, str** rec_id)
 	rec= get_dialog(dialog, hash_code);
 	if(rec== NULL)
 	{
-		LM_DBG("Record not found\n");
-		lock_release(&HashT->p_records[hash_code].lock);
-		return 0;
+		LM_DBG("Record not found - looking for temporary\n");
+		rec = get_temporary_dialog(dialog, hash_code);
+		if (rec == NULL)
+		{
+			LM_DBG("Temporary record not found\n");
+			lock_release(&HashT->p_records[hash_code].lock);
+			return 0;
+		}
 	}
 	id= (str*)pkg_malloc(sizeof(str));
 	if(id== NULL)

+ 1 - 0
modules_k/pua/hash.h

@@ -125,6 +125,7 @@ void destroy_htable(void);
 int is_dialog(ua_pres_t* dialog);
 
 ua_pres_t* get_dialog(ua_pres_t* dialog, unsigned int hash_code);
+ua_pres_t* get_temporary_dialog(ua_pres_t* dialog, unsigned int hash_code);
 
 int get_record_id(ua_pres_t* dialog, str** rec_id);
 typedef int (*get_record_id_t)(ua_pres_t* dialog, str** rec_id);

+ 47 - 15
modules_k/pua/pua.c

@@ -749,14 +749,14 @@ static void db_update(unsigned int ticks,void *param)
 	db_key_t db_cols[5];
 	db_val_t q_vals[20], db_vals[5];
 	db_op_t  db_ops[1] ;
-	int n_query_cols= 0, n_query_update= 0;
+	int n_query_cols= 0, n_query_update= 0, n_actual_query_cols= 0;
 	int n_update_cols= 0;
 	int i;
 	int puri_col,pid_col,expires_col,flag_col,etag_col,tuple_col,event_col;
 	int watcher_col,callid_col,totag_col,fromtag_col,record_route_col,cseq_col;
 	int no_lock= 0, contact_col, desired_expires_col, extra_headers_col;
 	int remote_contact_col, version_col;
-	
+
 	if(ticks== 0 && param == NULL)
 		no_lock= 1;
 
@@ -765,7 +765,7 @@ static void db_update(unsigned int ticks,void *param)
 	q_vals[puri_col].type = DB1_STR;
 	q_vals[puri_col].nul = 0;
 	n_query_cols++;
-	
+
 	q_cols[pid_col= n_query_cols] = &str_pres_id_col;	
 	q_vals[pid_col].type = DB1_STR;
 	q_vals[pid_col].nul = 0;
@@ -1003,21 +1003,43 @@ static void db_update(unsigned int ticks,void *param)
 					q_vals[puri_col].val.str_val = *(p->pres_uri);
 					q_vals[pid_col].val.str_val = p->id;
 					q_vals[flag_col].val.int_val = p->flag;
-					if((p->watcher_uri))
-						q_vals[watcher_col].val.str_val = *(p->watcher_uri);
-					else
-						memset(& q_vals[watcher_col].val.str_val ,0, sizeof(str));
-					q_vals[tuple_col].val.str_val = p->tuple_id;
-					q_vals[etag_col].val.str_val = p->etag;
 					q_vals[callid_col].val.str_val = p->call_id;
-					q_vals[totag_col].val.str_val = p->to_tag;
 					q_vals[fromtag_col].val.str_val = p->from_tag;
 					q_vals[cseq_col].val.int_val= p->cseq;
 					q_vals[expires_col].val.int_val = p->expires;
 					q_vals[desired_expires_col].val.int_val = p->desired_expires;
 					q_vals[event_col].val.int_val = p->event;
 					q_vals[version_col].val.int_val = p->version;
-					
+
+					if((p->watcher_uri))
+						q_vals[watcher_col].val.str_val = *(p->watcher_uri);
+					else
+						memset(& q_vals[watcher_col].val.str_val ,0, sizeof(str));
+
+					if(p->tuple_id.s == NULL)
+					{
+						q_vals[tuple_col].val.str_val.s="";
+						q_vals[tuple_col].val.str_val.len=0;
+					}
+					else
+						q_vals[tuple_col].val.str_val = p->tuple_id;
+
+					if(p->etag.s == NULL)
+					{
+						q_vals[etag_col].val.str_val.s="";
+						q_vals[etag_col].val.str_val.len=0;
+					}
+					else
+						q_vals[etag_col].val.str_val = p->etag;
+
+					if (p->to_tag.s == NULL)
+					{
+						q_vals[totag_col].val.str_val.s="";
+						q_vals[totag_col].val.str_val.len=0;
+					}
+					else
+						q_vals[totag_col].val.str_val = p->to_tag;
+
 					if(p->record_route.s== NULL)
 					{
 						q_vals[record_route_col].val.str_val.s= "";
@@ -1025,8 +1047,15 @@ static void db_update(unsigned int ticks,void *param)
 					}
 					else
 						q_vals[record_route_col].val.str_val = p->record_route;
-					
-					q_vals[contact_col].val.str_val = p->contact;
+
+					if(p->contact.s == NULL)
+					{
+						q_vals[contact_col].val.str_val.s = "";
+						q_vals[contact_col].val.str_val.len = 0;
+					}
+					else
+						q_vals[contact_col].val.str_val = p->contact;
+
 					if(p->remote_contact.s)
 					{
 						q_vals[remote_contact_col].val.str_val = p->remote_contact;
@@ -1039,11 +1068,14 @@ static void db_update(unsigned int ticks,void *param)
 					}
 
 					if(p->extra_headers)
+					{
+						n_actual_query_cols = n_query_cols;
 						q_vals[extra_headers_col].val.str_val = *(p->extra_headers);
+					}
 					else
-						n_query_cols--;
+						n_actual_query_cols = n_query_cols - 1;
 						
-					if(pua_dbf.insert(pua_db, q_cols, q_vals,n_query_cols )<0)
+					if(pua_dbf.insert(pua_db, q_cols, q_vals,n_actual_query_cols )<0)
 					{
 						LM_ERR("while inserting in db table pua\n");
 						if(!no_lock)

+ 73 - 2
modules_k/pua/send_subscribe.c

@@ -344,7 +344,6 @@ void subs_cback_func(struct cell *t, int cb_type, struct tmcb_params *ps)
 		hentity->call_id=  msg->callid->body;
 		hentity->to_tag= pto->tag_value;
 		hentity->from_tag= pfrom->tag_value;
-	
 	}
 
 	/* extract the other necesary information for inserting a new record */		
@@ -608,6 +607,12 @@ done:
 		run_pua_callbacks( hentity, msg);
 	}
 error:	
+	lock_get(&HashT->p_records[hash_code].lock);
+	presentity = get_temporary_dialog(hentity, hash_code);
+	if (presentity!=NULL)
+		delete_htable(presentity, hash_code);
+	lock_release(&HashT->p_records[hash_code].lock);
+
 	if(hentity)
 	{	
 		shm_free(hentity);
@@ -858,6 +863,7 @@ int send_subscribe(subs_info_t* subs)
 	
 	if(presentity== NULL )
 	{
+		int size;
 insert:
 		lock_release(&HashT->p_records[hash_code].lock); 
 		if(subs->flag & UPDATE_TYPE)
@@ -887,7 +893,7 @@ insert:
 
 		set_uac_req(&uac_r, &met, str_hdr, 0, 0, TMCB_LOCAL_COMPLETED,
 				subs_cback_func, (void*)hentity);
-		result= tmb.t_request
+		result= tmb.t_request_outside
 			(&uac_r,						  /* Type of the message */
 		subs->remote_target?subs->remote_target:subs->pres_uri,/* Request-URI*/
 			subs->pres_uri,				  /* To */
@@ -897,9 +903,74 @@ insert:
 		if(result< 0)
 		{
 			LM_ERR("while sending request with t_request\n");
+			if (uac_r.dialog != NULL)
+			{
+				uac_r.dialog->rem_target.s = 0;
+				uac_r.dialog->dst_uri.s = 0;
+				tmb.free_dlg(uac_r.dialog);
+				uac_r.dialog = 0;
+			}
 			shm_free(hentity);
 			goto  done;
 		}
+
+		/* Now create a temporary hash table entry.
+		   This is needed to deal with the race-hazard when NOTIFYs
+		   arrive before the 2xx response to the SUBSCRIBE. */
+		size = sizeof(ua_pres_t)+ 2 * sizeof(str) + (
+			subs->pres_uri->len +
+			subs->watcher_uri->len +
+			uac_r.dialog->id.loc_tag.len +
+			uac_r.dialog->id.call_id.len +
+			subs->id.len) * sizeof(char);
+
+		presentity= (ua_pres_t*)shm_malloc(size);
+		if(presentity== NULL)
+		{
+			LM_ERR("no more share memory\n");
+			goto done;
+		}
+		memset(presentity, 0, size);
+		size= sizeof(ua_pres_t);
+
+		presentity->pres_uri = (str *) ((char *) presentity + size);
+		size += sizeof(str);
+		presentity->pres_uri->s= (char *) presentity + size;
+		memcpy(presentity->pres_uri->s, subs->pres_uri->s, subs->pres_uri->len);
+		presentity->pres_uri->len= subs->pres_uri->len;
+		size+= subs->pres_uri->len;
+
+		presentity->watcher_uri= (str *) ((char *) presentity + size);
+		size += sizeof(str);
+		presentity->watcher_uri->s= (char *) presentity + size;
+		memcpy(presentity->watcher_uri->s, subs->watcher_uri->s, subs->watcher_uri->len);
+		presentity->watcher_uri->len = subs->watcher_uri->len;
+		size += subs->watcher_uri->len;
+
+		presentity->call_id.s = (char *) presentity + size;
+		memcpy(presentity->call_id.s, uac_r.dialog->id.call_id.s, uac_r.dialog->id.call_id.len);
+		presentity->call_id.len = uac_r.dialog->id.call_id.len;
+		size += uac_r.dialog->id.call_id.len;
+
+		presentity->from_tag.s = (char *) presentity + size;
+		memcpy(presentity->from_tag.s, uac_r.dialog->id.loc_tag.s, uac_r.dialog->id.loc_tag.len);
+		presentity->from_tag.len= uac_r.dialog->id.loc_tag.len;
+		size += uac_r.dialog->id.loc_tag.len;
+
+		presentity->id.s = (char *) presentity+ size;
+		memcpy(presentity->id.s, subs->id.s, subs->id.len);
+		presentity->id.len = subs->id.len;
+		size += subs->id.len;
+
+		/* Set the temporary record expiry for 2 * 64T1 seconds from now */
+		presentity->expires= (int)time(NULL) + 64;
+
+		insert_htable(presentity);
+
+		uac_r.dialog->rem_target.s = 0;
+		uac_r.dialog->dst_uri.s = 0;
+		tmb.free_dlg(uac_r.dialog);
+		uac_r.dialog = 0;
 	}
 	else
 	{