yhirose 5 years ago
parent
commit
24bdb736f0
2 changed files with 46 additions and 18 deletions
  1. 43 16
      httplib.h
  2. 3 2
      test/test.cc

+ 43 - 16
httplib.h

@@ -515,6 +515,26 @@ private:
 
 using Logger = std::function<void(const Request &, const Response &)>;
 
+using SocketOptions = std::function<void(socket_t sock)>;
+
+inline void default_socket_options(socket_t sock) {
+  int yes = 1;
+#ifdef _WIN32
+  setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char *>(&yes),
+             sizeof(yes));
+  setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE,
+             reinterpret_cast<char *>(&yes), sizeof(yes));
+#else
+#ifdef SO_REUSEPORT
+  setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast<void *>(&yes),
+             sizeof(yes));
+#else
+  setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<void *>(&yes),
+             sizeof(yes));
+#endif
+#endif
+}
+
 class Server {
 public:
   using Handler = std::function<void(const Request &, Response &)>;
@@ -549,9 +569,10 @@ public:
   void set_file_request_handler(Handler handler);
 
   void set_error_handler(Handler handler);
+  void set_expect_100_continue_handler(Expect100ContinueHandler handler);
   void set_logger(Logger logger);
 
-  void set_expect_100_continue_handler(Expect100ContinueHandler handler);
+  void set_socket_options(SocketOptions socket_options);
 
   void set_keep_alive_max_count(size_t count);
   void set_read_timeout(time_t sec, time_t usec = 0);
@@ -590,8 +611,8 @@ private:
   using HandlersForContentReader =
       std::vector<std::pair<std::regex, HandlerWithContentReader>>;
 
-  socket_t create_server_socket(const char *host, int port,
-                                int socket_flags) const;
+  socket_t create_server_socket(const char *host, int port, int socket_flags,
+                                SocketOptions socket_options) const;
   int bind_internal(const char *host, int port, int socket_flags);
   bool listen_internal();
 
@@ -639,6 +660,7 @@ private:
   Handler error_handler_;
   Logger logger_;
   Expect100ContinueHandler expect_100_continue_handler_;
+  SocketOptions socket_options_ = default_socket_options;
 };
 
 class Client {
@@ -1873,9 +1895,10 @@ inline int shutdown_socket(socket_t sock) {
 #endif
 }
 
-template <typename Fn>
-socket_t create_socket(const char *host, int port, Fn fn,
-                       int socket_flags = 0) {
+template <typename BindOrConnect>
+socket_t create_socket(const char *host, int port, int socket_flags,
+                       SocketOptions socket_options,
+                       BindOrConnect bind_or_connect) {
   // Get address info
   struct addrinfo hints;
   struct addrinfo *result;
@@ -1923,6 +1946,8 @@ socket_t create_socket(const char *host, int port, Fn fn,
     if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { continue; }
 #endif
 
+    if (socket_options) { socket_options(sock); }
+
     // Make 'reuse address' option available
     int yes = 1;
     setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char *>(&yes),
@@ -1940,7 +1965,7 @@ socket_t create_socket(const char *host, int port, Fn fn,
     }
 
     // bind or connect
-    if (fn(sock, *rp)) {
+    if (bind_or_connect(sock, *rp)) {
       freeaddrinfo(result);
       return sock;
     }
@@ -2017,10 +2042,12 @@ inline std::string if2ip(const std::string &ifn) {
 #endif
 
 inline socket_t create_client_socket(const char *host, int port,
+                                     SocketOptions socket_options,
                                      time_t timeout_sec, time_t timeout_usec,
                                      const std::string &intf) {
   return create_socket(
-      host, port, [&](socket_t sock, struct addrinfo &ai) -> bool {
+      host, port, 0, socket_options,
+      [&](socket_t sock, struct addrinfo &ai) -> bool {
         if (!intf.empty()) {
 #ifndef _WIN32
           auto ip = if2ip(intf);
@@ -3984,10 +4011,11 @@ inline bool Server::handle_file_request(Request &req, Response &res,
   return false;
 }
 
-inline socket_t Server::create_server_socket(const char *host, int port,
-                                             int socket_flags) const {
+inline socket_t
+Server::create_server_socket(const char *host, int port, int socket_flags,
+                             SocketOptions socket_options) const {
   return detail::create_socket(
-      host, port,
+      host, port, socket_flags, socket_options,
       [](socket_t sock, struct addrinfo &ai) -> bool {
         if (::bind(sock, ai.ai_addr, static_cast<socklen_t>(ai.ai_addrlen))) {
           return false;
@@ -3996,14 +4024,13 @@ inline socket_t Server::create_server_socket(const char *host, int port,
           return false;
         }
         return true;
-      },
-      socket_flags);
+      });
 }
 
 inline int Server::bind_internal(const char *host, int port, int socket_flags) {
   if (!is_valid()) { return -1; }
 
-  svr_sock_ = create_server_socket(host, port, socket_flags);
+  svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_);
   if (svr_sock_ == INVALID_SOCKET) { return -1; }
 
   if (port == 0) {
@@ -4293,10 +4320,10 @@ inline bool Client::is_valid() const { return true; }
 inline socket_t Client::create_client_socket() const {
   if (!proxy_host_.empty()) {
     return detail::create_client_socket(proxy_host_.c_str(), proxy_port_,
-                                        connection_timeout_sec_,
+                                        nullptr, connection_timeout_sec_,
                                         connection_timeout_usec_, interface_);
   }
-  return detail::create_client_socket(host_.c_str(), port_,
+  return detail::create_client_socket(host_.c_str(), port_, nullptr,
                                       connection_timeout_sec_,
                                       connection_timeout_usec_, interface_);
 }

+ 3 - 2
test/test.cc

@@ -2294,8 +2294,9 @@ TEST_F(ServerTest, MultipartFormDataGzip) {
 // Sends a raw request to a server listening at HOST:PORT.
 static bool send_request(time_t read_timeout_sec, const std::string &req,
                          std::string *resp = nullptr) {
-  auto client_sock = detail::create_client_socket(HOST, PORT, /*timeout_sec=*/5, 0,
-                                                  std::string());
+  auto client_sock = detail::create_client_socket(
+      HOST, PORT, nullptr,
+      /*timeout_sec=*/5, 0, std::string());
 
   if (client_sock == INVALID_SOCKET) { return false; }