Browse Source

fix a few memory leaks in test_upgrade logic

Christian Grothoff 8 năm trước cách đây
mục cha
commit
46bbae5663
1 tập tin đã thay đổi với 61 bổ sung55 xóa
  1. 61 55
      src/microhttpd/test_upgrade.c

+ 61 - 55
src/microhttpd/test_upgrade.c

@@ -143,7 +143,7 @@ gnutlscli_connect (int *sock,
 /**
 /**
  * Wrapper structure for plain&TLS sockets
  * Wrapper structure for plain&TLS sockets
  */
  */
-struct wr_socket_strc
+struct wr_socket
 {
 {
   /**
   /**
    * Real network socket
    * Real network socket
@@ -178,18 +178,6 @@ struct wr_socket_strc
 };
 };
 
 
 
 
-/**
- * Pseudo type for plain&TLS sockets
- */
-typedef struct wr_socket_strc* wr_socket;
-
-
-/**
- * Invalid value of wr_socket
- */
-#define WR_BAD (NULL)
-
-
 /**
 /**
  * Get underlying real socket.
  * Get underlying real socket.
  * @return FD of real socket
  * @return FD of real socket
@@ -199,34 +187,34 @@ typedef struct wr_socket_strc* wr_socket;
 
 
 /**
 /**
  * Create wr_socket with plain TCP underlying socket
  * Create wr_socket with plain TCP underlying socket
- * @return created socket on success, WR_BAD otherwise
+ * @return created socket on success, NULL otherwise
  */
  */
-static wr_socket
+static struct wr_socket *
 wr_create_plain_sckt(void)
 wr_create_plain_sckt(void)
 {
 {
-  wr_socket s = (wr_socket)malloc(sizeof(struct wr_socket_strc));
-  if (WR_BAD == s)
-    return WR_BAD;
+  struct wr_socket *s = malloc(sizeof(struct wr_socket));
+  if (NULL == s)
+    return NULL;
   s->t = wr_plain;
   s->t = wr_plain;
   s->fd = socket (AF_INET, SOCK_STREAM, IPPROTO_TCP);
   s->fd = socket (AF_INET, SOCK_STREAM, IPPROTO_TCP);
   if (MHD_INVALID_SOCKET != s->fd)
   if (MHD_INVALID_SOCKET != s->fd)
     return s;
     return s;
   free(s);
   free(s);
-  return WR_BAD;
+  return NULL;
 }
 }
 
 
 
 
 /**
 /**
  * Create wr_socket with TLS TCP underlying socket
  * Create wr_socket with TLS TCP underlying socket
- * @return created socket on success, WR_BAD otherwise
+ * @return created socket on success, NULL otherwise
  */
  */
-static wr_socket
+static struct wr_socket *
 wr_create_tls_sckt(void)
 wr_create_tls_sckt(void)
 {
 {
 #ifdef HTTPS_SUPPORT
 #ifdef HTTPS_SUPPORT
-  wr_socket s = (wr_socket)malloc(sizeof(struct wr_socket_strc));
-  if (WR_BAD == s)
-    return WR_BAD;
+  struct wr_socket *s = malloc(sizeof(struct wr_socket));
+  if (NULL == s)
+    return NULL;
   s->t = wr_tls;
   s->t = wr_tls;
   s->tls_connected = 0;
   s->tls_connected = 0;
   s->fd = socket (AF_INET, SOCK_STREAM, IPPROTO_TCP);
   s->fd = socket (AF_INET, SOCK_STREAM, IPPROTO_TCP);
@@ -256,7 +244,7 @@ wr_create_tls_sckt(void)
     }
     }
   free(s);
   free(s);
 #endif /* HTTPS_SUPPORT */
 #endif /* HTTPS_SUPPORT */
-  return WR_BAD;
+  return NULL;
 }
 }
 
 
 
 
@@ -264,15 +252,15 @@ wr_create_tls_sckt(void)
  * Create wr_socket with plain TCP underlying socket
  * Create wr_socket with plain TCP underlying socket
  * from already created TCP socket.
  * from already created TCP socket.
  * @param plain_sk real TCP socket
  * @param plain_sk real TCP socket
- * @return created socket on success, WR_BAD otherwise
+ * @return created socket on success, NULL otherwise
  */
  */
-static wr_socket
+static struct wr_socket *
 wr_create_from_plain_sckt(MHD_socket plain_sk)
 wr_create_from_plain_sckt(MHD_socket plain_sk)
 {
 {
-  wr_socket s = (wr_socket)malloc(sizeof(struct wr_socket_strc));
+  struct wr_socket *s = malloc(sizeof(struct wr_socket));
 
 
-  if (WR_BAD == s)
-    return WR_BAD;
+  if (NULL == s)
+    return NULL;
   s->t = wr_plain;
   s->t = wr_plain;
   s->fd = plain_sk;
   s->fd = plain_sk;
   return s;
   return s;
@@ -287,7 +275,7 @@ wr_create_from_plain_sckt(MHD_socket plain_sk)
  * @return zero on success, -1 otherwise.
  * @return zero on success, -1 otherwise.
  */
  */
 static int
 static int
-wr_connect(wr_socket s,
+wr_connect(struct wr_socket *s,
            const struct sockaddr *addr,
            const struct sockaddr *addr,
            int length)
            int length)
 {
 {
@@ -312,7 +300,8 @@ wr_connect(wr_socket s,
 
 
 #ifdef HTTPS_SUPPORT
 #ifdef HTTPS_SUPPORT
 /* Only to be called from wr_send() and wr_recv() ! */
 /* Only to be called from wr_send() and wr_recv() ! */
-static bool wr_handshake(wr_socket s)
+static bool
+wr_handshake(struct wr_socket *s)
 {
 {
   int res = gnutls_handshake (s->tls_s);
   int res = gnutls_handshake (s->tls_s);
   if (GNUTLS_E_SUCCESS == res)
   if (GNUTLS_E_SUCCESS == res)
@@ -336,7 +325,7 @@ static bool wr_handshake(wr_socket s)
  *         to get socket error.
  *         to get socket error.
  */
  */
 static ssize_t
 static ssize_t
-wr_send (wr_socket s,
+wr_send (struct wr_socket *s,
          const void *buf,
          const void *buf,
          size_t len)
          size_t len)
 {
 {
@@ -372,7 +361,7 @@ wr_send (wr_socket s,
  *         to get socket error.
  *         to get socket error.
  */
  */
 static ssize_t
 static ssize_t
-wr_recv (wr_socket s,
+wr_recv (struct wr_socket *s,
          void *buf,
          void *buf,
          size_t len)
          size_t len)
 {
 {
@@ -404,7 +393,7 @@ wr_recv (wr_socket s,
  * @return zero on succeed, -1 otherwise
  * @return zero on succeed, -1 otherwise
  */
  */
 static int
 static int
-wr_close (wr_socket s)
+wr_close (struct wr_socket *s)
 {
 {
   int ret = (MHD_socket_close_(s->fd)) ? 0 : -1;
   int ret = (MHD_socket_close_(s->fd)) ? 0 : -1;
 #ifdef HTTPS_SUPPORT
 #ifdef HTTPS_SUPPORT
@@ -414,7 +403,7 @@ wr_close (wr_socket s)
       gnutls_certificate_free_credentials (s->tls_crd);
       gnutls_certificate_free_credentials (s->tls_crd);
     }
     }
 #endif /* HTTPS_SUPPORT */
 #endif /* HTTPS_SUPPORT */
-  free(s);
+  free (s);
   return ret;
   return ret;
 }
 }
 
 
@@ -427,7 +416,7 @@ static pthread_t pt;
 /**
 /**
  * Will be set to the upgraded socket.
  * Will be set to the upgraded socket.
  */
  */
-static wr_socket usock;
+static struct wr_socket *usock;
 
 
 /**
 /**
  * Thread we use to run the interaction with the upgraded socket.
  * Thread we use to run the interaction with the upgraded socket.
@@ -440,20 +429,34 @@ static pthread_t pt_client;
 static volatile bool done;
 static volatile bool done;
 
 
 
 
+/**
+ * Callback used by MHD to notify the application about completed
+ * requests.  Frees memory.
+ *
+ * @param cls client-defined closure
+ * @param connection connection handle
+ * @param con_cls value as set by the last call to
+ *        the #MHD_AccessHandlerCallback
+ * @param toe reason for request termination
+ */
 static void
 static void
 notify_completed_cb (void *cls,
 notify_completed_cb (void *cls,
                      struct MHD_Connection *connection,
                      struct MHD_Connection *connection,
                      void **con_cls,
                      void **con_cls,
                      enum MHD_RequestTerminationCode toe)
                      enum MHD_RequestTerminationCode toe)
 {
 {
+  pthread_t* ppth = *con_cls;
+
   (void)cls; (void)connection;  /* Unused. Silent compiler warning. */
   (void)cls; (void)connection;  /* Unused. Silent compiler warning. */
   if ( (toe != MHD_REQUEST_TERMINATED_COMPLETED_OK) &&
   if ( (toe != MHD_REQUEST_TERMINATED_COMPLETED_OK) &&
        (toe != MHD_REQUEST_TERMINATED_CLIENT_ABORT) &&
        (toe != MHD_REQUEST_TERMINATED_CLIENT_ABORT) &&
        (toe != MHD_REQUEST_TERMINATED_DAEMON_SHUTDOWN) )
        (toe != MHD_REQUEST_TERMINATED_DAEMON_SHUTDOWN) )
     abort ();
     abort ();
-  if (! pthread_equal (**((pthread_t**)con_cls), pthread_self ()))
+  if (! pthread_equal (**((pthread_t**)con_cls),
+                       pthread_self ()))
     abort ();
     abort ();
-  free (*con_cls);
+  if (NULL != ppth)
+    free (*con_cls);
   *con_cls = NULL;
   *con_cls = NULL;
 }
 }
 
 
@@ -561,7 +564,7 @@ make_blocking (MHD_socket fd)
 
 
 
 
 static void
 static void
-send_all (wr_socket sock,
+send_all (struct wr_socket *sock,
           const char *text)
           const char *text)
 {
 {
   size_t len = strlen (text);
   size_t len = strlen (text);
@@ -572,8 +575,8 @@ send_all (wr_socket sock,
   for (off = 0; off < len; off += ret)
   for (off = 0; off < len; off += ret)
     {
     {
       ret = wr_send (sock,
       ret = wr_send (sock,
-                       &text[off],
-                       len - off);
+                     &text[off],
+                     len - off);
       if (0 > ret)
       if (0 > ret)
         {
         {
           if (MHD_SCKT_ERR_IS_EAGAIN_ (MHD_socket_get_error_ ()))
           if (MHD_SCKT_ERR_IS_EAGAIN_ (MHD_socket_get_error_ ()))
@@ -592,7 +595,7 @@ send_all (wr_socket sock,
  * get '\r\n\r\n'.
  * get '\r\n\r\n'.
  */
  */
 static void
 static void
-recv_hdr (wr_socket sock)
+recv_hdr (struct wr_socket *sock)
 {
 {
   unsigned int i;
   unsigned int i;
   char next;
   char next;
@@ -637,7 +640,7 @@ recv_hdr (wr_socket sock)
 
 
 
 
 static void
 static void
-recv_all (wr_socket sock,
+recv_all (struct wr_socket *sock,
           const char *text)
           const char *text)
 {
 {
   size_t len = strlen (text);
   size_t len = strlen (text);
@@ -685,6 +688,8 @@ run_usock (void *cls)
             "Finished");
             "Finished");
   MHD_upgrade_action (urh,
   MHD_upgrade_action (urh,
                       MHD_UPGRADE_ACTION_CLOSE);
                       MHD_UPGRADE_ACTION_CLOSE);
+  free (usock);
+  usock = NULL;
   return NULL;
   return NULL;
 }
 }
 
 
@@ -698,18 +703,18 @@ run_usock (void *cls)
 static void *
 static void *
 run_usock_client (void *cls)
 run_usock_client (void *cls)
 {
 {
-  wr_socket *sock = cls;
+  struct wr_socket *sock = cls;
 
 
-  send_all (*sock,
+  send_all (sock,
             "GET / HTTP/1.1\r\nConnection: Upgrade\r\n\r\n");
             "GET / HTTP/1.1\r\nConnection: Upgrade\r\n\r\n");
-  recv_hdr (*sock);
-  recv_all (*sock,
+  recv_hdr (sock);
+  recv_all (sock,
             "Hello");
             "Hello");
-  send_all (*sock,
+  send_all (sock,
             "World");
             "World");
-  recv_all (*sock,
+  recv_all (sock,
             "Finished");
             "Finished");
-  wr_close (*sock);
+  wr_close (sock);
   done = true;
   done = true;
   return NULL;
   return NULL;
 }
 }
@@ -979,6 +984,7 @@ run_mhd_loop (struct MHD_Daemon *daemon,
     abort ();
     abort ();
 }
 }
 
 
+
 static bool test_tls;
 static bool test_tls;
 
 
 /**
 /**
@@ -992,7 +998,7 @@ test_upgrade (int flags,
               unsigned int pool)
               unsigned int pool)
 {
 {
   struct MHD_Daemon *d = NULL;
   struct MHD_Daemon *d = NULL;
-  wr_socket sock;
+  struct wr_socket *sock;
   struct sockaddr_in sa;
   struct sockaddr_in sa;
   const union MHD_DaemonInfo *real_flags;
   const union MHD_DaemonInfo *real_flags;
   const union MHD_DaemonInfo *dinfo;
   const union MHD_DaemonInfo *dinfo;
@@ -1039,7 +1045,7 @@ test_upgrade (int flags,
   if (!test_tls || TLS_LIB_GNUTLS == use_tls_tool)
   if (!test_tls || TLS_LIB_GNUTLS == use_tls_tool)
     {
     {
       sock = test_tls ? wr_create_tls_sckt () : wr_create_plain_sckt ();
       sock = test_tls ? wr_create_tls_sckt () : wr_create_plain_sckt ();
-      if (WR_BAD == sock)
+      if (NULL == sock)
         abort ();
         abort ();
       sa.sin_family = AF_INET;
       sa.sin_family = AF_INET;
       sa.sin_port = htons (dinfo->port);
       sa.sin_port = htons (dinfo->port);
@@ -1059,7 +1065,7 @@ test_upgrade (int flags,
           return 4;
           return 4;
         }
         }
       sock =  wr_create_from_plain_sckt (tls_fork_sock);
       sock =  wr_create_from_plain_sckt (tls_fork_sock);
-      if (WR_BAD == sock)
+      if (NULL == sock)
         abort ();
         abort ();
 #else  /* !HTTPS_SUPPORT || !HAVE_FORK || !HAVE_WAITPID */
 #else  /* !HTTPS_SUPPORT || !HAVE_FORK || !HAVE_WAITPID */
       abort ();
       abort ();
@@ -1069,7 +1075,7 @@ test_upgrade (int flags,
   if (0 != pthread_create (&pt_client,
   if (0 != pthread_create (&pt_client,
                            NULL,
                            NULL,
                            &run_usock_client,
                            &run_usock_client,
-                           &sock))
+                           sock))
     abort ();
     abort ();
   if (0 == (flags & MHD_USE_INTERNAL_POLLING_THREAD) )
   if (0 == (flags & MHD_USE_INTERNAL_POLLING_THREAD) )
     run_mhd_loop (d, real_flags->flags);
     run_mhd_loop (d, real_flags->flags);