Browse Source

Fix multiple threading bugs including #699 and #697

David Wu 5 years ago
parent
commit
02d3cd5909
3 changed files with 175 additions and 46 deletions
  1. 6 0
      README.md
  2. 167 45
      httplib.h
  3. 2 1
      test/test.cc

+ 6 - 0
README.md

@@ -645,6 +645,12 @@ cli.set_ca_cert_path("./ca-bundle.crt");
 cli.enable_server_certificate_verification(true);
 cli.enable_server_certificate_verification(true);
 ```
 ```
 
 
+Note: When using SSL, it seems impossible to avoid SIGPIPE in all cases, since on some operating systems, SIGPIPE
+can only be suppressed on a per-message basis, but there is no way to make the OpenSSL library do so for its
+internal communications. If your program needs to avoid being terminated on SIGPIPE, the only fully general way might
+be to set up a signal handler for SIGPIPE to handle or ignore it yourself.
+
+
 Compression
 Compression
 -----------
 -----------
 
 

+ 167 - 45
httplib.h

@@ -932,7 +932,21 @@ protected:
   };
   };
 
 
   virtual bool create_and_connect_socket(Socket &socket);
   virtual bool create_and_connect_socket(Socket &socket);
-  virtual void close_socket(Socket &socket, bool process_socket_ret);
+
+  // All of:
+  //   shutdown_ssl
+  //   shutdown_socket
+  //   close_socket
+  // should ONLY be called when socket_mutex_ is locked.
+  // Also, shutdown_ssl and close_socket should also NOT be called concurrently
+  // with a DIFFERENT thread sending requests using that socket.
+  virtual void shutdown_ssl(Socket &socket, bool shutdown_gracefully);
+  void shutdown_socket(Socket &socket);
+  void close_socket(Socket &socket);
+
+  // Similar to shutdown_ssl and close_socket, this should NOT be called
+  // concurrently with a DIFFERENT thread sending requests from the socket
+  void lock_socket_and_shutdown_and_close();
 
 
   bool process_request(Stream &strm, const Request &req, Response &res,
   bool process_request(Stream &strm, const Request &req, Response &res,
                        bool close_connection);
                        bool close_connection);
@@ -943,7 +957,7 @@ protected:
   void copy_settings(const ClientImpl &rhs);
   void copy_settings(const ClientImpl &rhs);
 
 
   // Error state
   // Error state
-  mutable Error error_ = Error::Success;
+  mutable std::atomic<Error> error_;
 
 
   // Socket endoint information
   // Socket endoint information
   const std::string host_;
   const std::string host_;
@@ -955,6 +969,11 @@ protected:
   mutable std::mutex socket_mutex_;
   mutable std::mutex socket_mutex_;
   std::recursive_mutex request_mutex_;
   std::recursive_mutex request_mutex_;
 
 
+  // These are all protected under socket_mutex
+  int socket_requests_in_flight_ = 0;
+  std::thread::id socket_requests_are_from_thread_ = std::thread::id();
+  bool socket_should_be_closed_when_request_is_done_ = false;
+
   // Default headers
   // Default headers
   Headers default_headers_;
   Headers default_headers_;
 
 
@@ -1012,7 +1031,6 @@ private:
   bool redirect(const Request &req, Response &res);
   bool redirect(const Request &req, Response &res);
   bool handle_request(Stream &strm, const Request &req, Response &res,
   bool handle_request(Stream &strm, const Request &req, Response &res,
                       bool close_connection);
                       bool close_connection);
-  void stop_core();
   std::unique_ptr<Response> send_with_content_provider(
   std::unique_ptr<Response> send_with_content_provider(
       const char *method, const char *path, const Headers &headers,
       const char *method, const char *path, const Headers &headers,
       const std::string &body, size_t content_length,
       const std::string &body, size_t content_length,
@@ -1020,7 +1038,8 @@ private:
       ContentProviderWithoutLength content_provider_without_length,
       ContentProviderWithoutLength content_provider_without_length,
       const char *content_type);
       const char *content_type);
 
 
-  virtual bool process_socket(Socket &socket,
+  // socket is const because this function is called when socket_mutex_ is not locked
+  virtual bool process_socket(const Socket &socket,
                               std::function<bool(Stream &strm)> callback);
                               std::function<bool(Stream &strm)> callback);
   virtual bool is_ssl() const;
   virtual bool is_ssl() const;
 };
 };
@@ -1243,9 +1262,9 @@ public:
 
 
 private:
 private:
   bool create_and_connect_socket(Socket &socket) override;
   bool create_and_connect_socket(Socket &socket) override;
-  void close_socket(Socket &socket, bool process_socket_ret) override;
+  void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override;
 
 
-  bool process_socket(Socket &socket,
+  bool process_socket(const Socket &socket,
                       std::function<bool(Stream &strm)> callback) override;
                       std::function<bool(Stream &strm)> callback) override;
   bool is_ssl() const override;
   bool is_ssl() const override;
 
 
@@ -2046,7 +2065,7 @@ inline socket_t create_client_socket(const char *host, int port,
                                      bool tcp_nodelay,
                                      bool tcp_nodelay,
                                      SocketOptions socket_options,
                                      SocketOptions socket_options,
                                      time_t timeout_sec, time_t timeout_usec,
                                      time_t timeout_sec, time_t timeout_usec,
-                                     const std::string &intf, Error &error) {
+                                     const std::string &intf, std::atomic<Error> &error) {
   auto sock = create_socket(
   auto sock = create_socket(
       host, port, 0, tcp_nodelay, std::move(socket_options),
       host, port, 0, tcp_nodelay, std::move(socket_options),
       [&](socket_t sock, struct addrinfo &ai) -> bool {
       [&](socket_t sock, struct addrinfo &ai) -> bool {
@@ -4793,11 +4812,11 @@ inline ClientImpl::ClientImpl(const std::string &host, int port)
 inline ClientImpl::ClientImpl(const std::string &host, int port,
 inline ClientImpl::ClientImpl(const std::string &host, int port,
                               const std::string &client_cert_path,
                               const std::string &client_cert_path,
                               const std::string &client_key_path)
                               const std::string &client_key_path)
-    : host_(host), port_(port),
+    : error_(Error::Success), host_(host), port_(port),
       host_and_port_(host_ + ":" + std::to_string(port_)),
       host_and_port_(host_ + ":" + std::to_string(port_)),
       client_cert_path_(client_cert_path), client_key_path_(client_key_path) {}
       client_cert_path_(client_cert_path), client_key_path_(client_key_path) {}
 
 
-inline ClientImpl::~ClientImpl() { stop_core(); }
+inline ClientImpl::~ClientImpl() { lock_socket_and_shutdown_and_close(); }
 
 
 inline bool ClientImpl::is_valid() const { return true; }
 inline bool ClientImpl::is_valid() const { return true; }
 
 
@@ -4858,15 +4877,47 @@ inline bool ClientImpl::create_and_connect_socket(Socket &socket) {
   return true;
   return true;
 }
 }
 
 
-inline void ClientImpl::close_socket(Socket &socket,
-                                     bool /*process_socket_ret*/) {
-  detail::close_socket(socket.sock);
-  socket_.sock = INVALID_SOCKET;
+inline void ClientImpl::shutdown_ssl(Socket &socket, bool shutdown_gracefully) {
+  (void)socket;
+  (void)shutdown_gracefully;
+  //If there are any requests in flight from threads other than us, then it's
+  //a thread-unsafe race because individual ssl* objects are not thread-safe. 
+  assert(socket_requests_in_flight_ == 0 ||
+         socket_requests_are_from_thread_ == std::this_thread::get_id());
+}
+
+inline void ClientImpl::shutdown_socket(Socket &socket) {
+  if (socket.sock == INVALID_SOCKET)
+    return;  
+  detail::shutdown_socket(socket.sock);
+}
+ 
+inline void ClientImpl::close_socket(Socket &socket) {
+  // If there are requests in flight in another thread, usually closing
+  // the socket will be fine and they will simply receive an error when
+  // using the closed socket, but it is still a bug since rarely the OS
+  // may reassign the socket id to be used for a new socket, and then
+  // suddenly they will be operating on a live socket that is different
+  // than the one they intended!
+  assert(socket_requests_in_flight_ == 0 ||
+         socket_requests_are_from_thread_ == std::this_thread::get_id());
+  // It is also a bug if this happens while SSL is still active
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
-  socket_.ssl = nullptr;
+  assert(socket.ssl == nullptr);
 #endif
 #endif
+  if (socket.sock == INVALID_SOCKET)
+    return;
+  detail::close_socket(socket.sock);
+  socket.sock = INVALID_SOCKET;
 }
 }
 
 
+inline void ClientImpl::lock_socket_and_shutdown_and_close() {
+  std::lock_guard<std::mutex> guard(socket_mutex_);
+  shutdown_ssl(socket_, true);
+  shutdown_socket(socket_);
+  close_socket(socket_);
+}
+ 
 inline bool ClientImpl::read_response_line(Stream &strm, Response &res) {
 inline bool ClientImpl::read_response_line(Stream &strm, Response &res) {
   std::array<char, 2048> buf;
   std::array<char, 2048> buf;
 
 
@@ -4901,11 +4952,23 @@ inline bool ClientImpl::send(const Request &req, Response &res) {
 
 
   {
   {
     std::lock_guard<std::mutex> guard(socket_mutex_);
     std::lock_guard<std::mutex> guard(socket_mutex_);
+    // Set this to false immediately - if it ever gets set to true by the end of the
+    // request, we know another thread instructed us to close the socket.
+    socket_should_be_closed_when_request_is_done_ = false;
 
 
     auto is_alive = false;
     auto is_alive = false;
     if (socket_.is_open()) {
     if (socket_.is_open()) {
       is_alive = detail::select_write(socket_.sock, 0, 0) > 0;
       is_alive = detail::select_write(socket_.sock, 0, 0) > 0;
-      if (!is_alive) { close_socket(socket_, false); }
+      if (!is_alive) {
+        // Attempt to avoid sigpipe by shutting down nongracefully if it seems like
+        // the other side has already closed the connection
+        // Also, there cannot be any requests in flight from other threads since we locked
+        // request_mutex_, so safe to close everything immediately
+        const bool shutdown_gracefully = false;
+        shutdown_ssl(socket_, shutdown_gracefully);
+        shutdown_socket(socket_);
+        close_socket(socket_);
+      }
     }
     }
 
 
     if (!is_alive) {
     if (!is_alive) {
@@ -4926,15 +4989,38 @@ inline bool ClientImpl::send(const Request &req, Response &res) {
       }
       }
 #endif
 #endif
     }
     }
+
+    // Mark the current socket as being in use so that it cannot be closed by anyone
+    // else while this request is ongoing, even though we will be releasing the mutex.
+    if (socket_requests_in_flight_ > 1) {
+      assert(socket_requests_are_from_thread_ == std::this_thread::get_id());
+    }
+    socket_requests_in_flight_ += 1;
+    socket_requests_are_from_thread_ = std::this_thread::get_id();
   }
   }
 
 
   auto close_connection = !keep_alive_;
   auto close_connection = !keep_alive_;
-
   auto ret = process_socket(socket_, [&](Stream &strm) {
   auto ret = process_socket(socket_, [&](Stream &strm) {
     return handle_request(strm, req, res, close_connection);
     return handle_request(strm, req, res, close_connection);
   });
   });
 
 
-  if (close_connection || !ret) { stop_core(); }
+  //Briefly lock mutex in order to mark that a request is no longer ongoing
+  {
+    std::lock_guard<std::mutex> guard(socket_mutex_);
+    socket_requests_in_flight_ -= 1;
+    if (socket_requests_in_flight_ <= 0) {
+      assert(socket_requests_in_flight_ == 0);
+      socket_requests_are_from_thread_ = std::thread::id();
+    }
+
+    if (socket_should_be_closed_when_request_is_done_ ||
+        close_connection ||
+        !ret ) {
+      shutdown_ssl(socket_, true);
+      shutdown_socket(socket_);
+      close_socket(socket_);
+    }
+  }
 
 
   if (!ret) {
   if (!ret) {
     if (error_ == Error::Success) { error_ = Error::Unknown; }
     if (error_ == Error::Success) { error_ = Error::Unknown; }
@@ -5320,7 +5406,16 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req,
 
 
   if (res.get_header_value("Connection") == "close" ||
   if (res.get_header_value("Connection") == "close" ||
       (res.version == "HTTP/1.0" && res.reason != "Connection established")) {
       (res.version == "HTTP/1.0" && res.reason != "Connection established")) {
-    stop_core();
+    // TODO this requires a not-entirely-obvious chain of calls to be correct
+    // for this to be safe. Maybe a code refactor (such as moving this out to
+    // the send function and getting rid of the recursiveness of the mutex)
+    // could make this more obvious.
+    
+    // This is safe to call because process_request is only called by handle_request
+    // which is only called by send, which locks the request mutex during the process.
+    // It would be a bug to call it from a different thread since it's a thread-safety
+    // issue to do these things to the socket if another thread is using the socket.
+    lock_socket_and_shutdown_and_close();
   }
   }
 
 
   // Log
   // Log
@@ -5330,7 +5425,7 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req,
 }
 }
 
 
 inline bool
 inline bool
-ClientImpl::process_socket(Socket &socket,
+ClientImpl::process_socket(const Socket &socket,
                            std::function<bool(Stream &strm)> callback) {
                            std::function<bool(Stream &strm)> callback) {
   return detail::process_client_socket(
   return detail::process_client_socket(
       socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_,
       socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_,
@@ -5706,18 +5801,27 @@ inline size_t ClientImpl::is_socket_open() const {
 }
 }
 
 
 inline void ClientImpl::stop() {
 inline void ClientImpl::stop() {
-  stop_core();
-  error_ = Error::Canceled;
-}
-
-inline void ClientImpl::stop_core() {
   std::lock_guard<std::mutex> guard(socket_mutex_);
   std::lock_guard<std::mutex> guard(socket_mutex_);
-  if (socket_.is_open()) {
-    detail::shutdown_socket(socket_.sock);
-    std::this_thread::sleep_for(std::chrono::milliseconds(1));
-    close_socket(socket_, true);
-    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+  // There is no guarantee that this doesn't get overwritten later, but set it so that
+  // there is a good chance that any threads stopping as a result pick up this error.
+  error_ = Error::Canceled;
+  
+  // If there is anything ongoing right now, the ONLY thread-safe thing we can do
+  // is to shutdown_socket, so that threads using this socket suddenly discover
+  // they can't read/write any more and error out.
+  // Everything else (closing the socket, shutting ssl down) is unsafe because these
+  // actions are not thread-safe.
+  if (socket_requests_in_flight_ > 0) {
+    shutdown_socket(socket_);
+    // Aside from that, we set a flag for the socket to be closed when we're done.
+    socket_should_be_closed_when_request_is_done_ = true;
+    return;
   }
   }
+
+  //Otherwise, sitll holding the mutex, we can shut everything down ourselves
+  shutdown_ssl(socket_, true);
+  shutdown_socket(socket_);
+  close_socket(socket_);
 }
 }
 
 
 inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) {
 inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) {
@@ -5844,9 +5948,12 @@ inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex,
 }
 }
 
 
 inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl,
 inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl,
-                       bool process_socket_ret) {
-  if (process_socket_ret) {
-    SSL_shutdown(ssl); // shutdown only if not already closed by remote
+                       bool shutdown_gracefully) {
+  // sometimes we may want to skip this to try to avoid SIGPIPE if we know
+  // the remote has closed the network connection
+  // Note that it is not always possible to avoid SIGPIPE, this is merely a best-efforts.
+  if (shutdown_gracefully) {
+    SSL_shutdown(ssl); 
   }
   }
 
 
   std::lock_guard<std::mutex> guard(ctx_mutex);
   std::lock_guard<std::mutex> guard(ctx_mutex);
@@ -6108,9 +6215,10 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) {
                                  [&](Request &req) { req.ssl = ssl; });
                                  [&](Request &req) { req.ssl = ssl; });
         });
         });
 
 
-    detail::ssl_delete(ctx_mutex_, ssl, ret);
-    detail::shutdown_socket(sock);
-    detail::close_socket(sock);
+    // Shutdown gracefully if the result seemed successful, non-gracefully if the
+    // connection appeared to be closed.
+    const bool shutdown_gracefully = ret;
+    detail::ssl_delete(ctx_mutex_, ssl, shutdown_gracefully);
     return ret;
     return ret;
   }
   }
 
 
@@ -6167,6 +6275,10 @@ inline SSLClient::SSLClient(const std::string &host, int port,
 
 
 inline SSLClient::~SSLClient() {
 inline SSLClient::~SSLClient() {
   if (ctx_) { SSL_CTX_free(ctx_); }
   if (ctx_) { SSL_CTX_free(ctx_); }
+  // Make sure to shut down SSL since shutdown_ssl will resolve to the
+  // base function rather than the derived function once we get to the
+  // base class destructor, and won't free the SSL (causing a leak).
+  SSLClient::shutdown_ssl(socket_, true);
 }
 }
 
 
 inline bool SSLClient::is_valid() const { return ctx_; }
 inline bool SSLClient::is_valid() const { return ctx_; }
@@ -6200,11 +6312,11 @@ inline bool SSLClient::create_and_connect_socket(Socket &socket) {
   return is_valid() && ClientImpl::create_and_connect_socket(socket);
   return is_valid() && ClientImpl::create_and_connect_socket(socket);
 }
 }
 
 
+// Assumes that socket_mutex_ is locked and that there are no requests in flight
 inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res,
 inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res,
                                           bool &success) {
                                           bool &success) {
   success = true;
   success = true;
   Response res2;
   Response res2;
-
   if (!detail::process_client_socket(
   if (!detail::process_client_socket(
           socket.sock, read_timeout_sec_, read_timeout_usec_,
           socket.sock, read_timeout_sec_, read_timeout_usec_,
           write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) {
           write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) {
@@ -6213,7 +6325,10 @@ inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res,
             req2.path = host_and_port_;
             req2.path = host_and_port_;
             return process_request(strm, req2, res2, false);
             return process_request(strm, req2, res2, false);
           })) {
           })) {
-    close_socket(socket, true);
+    // Thread-safe to close everything because we are assuming there are no requests in flight
+    shutdown_ssl(socket, true);
+    shutdown_socket(socket);
+    close_socket(socket);
     success = false;
     success = false;
     return false;
     return false;
   }
   }
@@ -6236,7 +6351,10 @@ inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res,
                       true));
                       true));
                   return process_request(strm, req3, res3, false);
                   return process_request(strm, req3, res3, false);
                 })) {
                 })) {
-          close_socket(socket, true);
+          // Thread-safe to close everything because we are assuming there are no requests in flight
+          shutdown_ssl(socket, true);
+          shutdown_socket(socket);
+          close_socket(socket);
           success = false;
           success = false;
           return false;
           return false;
         }
         }
@@ -6331,21 +6449,25 @@ inline bool SSLClient::initialize_ssl(Socket &socket) {
     return true;
     return true;
   }
   }
 
 
-  close_socket(socket, false);
+  shutdown_socket(socket);
+  close_socket(socket);
   return false;
   return false;
 }
 }
 
 
-inline void SSLClient::close_socket(Socket &socket, bool process_socket_ret) {
-  detail::close_socket(socket.sock);
-  socket_.sock = INVALID_SOCKET;
+inline void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) {
+  if (socket.sock == INVALID_SOCKET) {
+    assert(socket.ssl == nullptr);
+    return;
+  }
   if (socket.ssl) {
   if (socket.ssl) {
-    detail::ssl_delete(ctx_mutex_, socket.ssl, process_socket_ret);
-    socket_.ssl = nullptr;
+    detail::ssl_delete(ctx_mutex_, socket.ssl, shutdown_gracefully);
+    socket.ssl = nullptr;
   }
   }
+  assert(socket.ssl == nullptr);
 }
 }
 
 
 inline bool
 inline bool
-SSLClient::process_socket(Socket &socket,
+SSLClient::process_socket(const Socket &socket,
                           std::function<bool(Stream &strm)> callback) {
                           std::function<bool(Stream &strm)> callback) {
   assert(socket.ssl);
   assert(socket.ssl);
   return detail::process_client_socket_ssl(
   return detail::process_client_socket_ssl(

+ 2 - 1
test/test.cc

@@ -5,6 +5,7 @@
 #include <chrono>
 #include <chrono>
 #include <future>
 #include <future>
 #include <thread>
 #include <thread>
+#include <atomic>
 
 
 #define SERVER_CERT_FILE "./cert.pem"
 #define SERVER_CERT_FILE "./cert.pem"
 #define SERVER_CERT2_FILE "./cert2.pem"
 #define SERVER_CERT2_FILE "./cert2.pem"
@@ -2761,7 +2762,7 @@ TEST_F(ServerTest, Brotli) {
 // Sends a raw request to a server listening at HOST:PORT.
 // Sends a raw request to a server listening at HOST:PORT.
 static bool send_request(time_t read_timeout_sec, const std::string &req,
 static bool send_request(time_t read_timeout_sec, const std::string &req,
                          std::string *resp = nullptr) {
                          std::string *resp = nullptr) {
-  Error error = Error::Success;
+  std::atomic<Error> error(Error::Success);
 
 
   auto client_sock =
   auto client_sock =
       detail::create_client_socket(HOST, PORT, false, nullptr,
       detail::create_client_socket(HOST, PORT, false, nullptr,