Browse Source

Merge pull request #40 from yhirose/connection-timeout

Connection timeout support on Client (Fixed #34)
yhirose 8 years ago
parent
commit
25aa0b34c3
3 changed files with 109 additions and 53 deletions
  1. 5 0
      README.md
  2. 96 30
      httplib.h
  3. 8 23
      test/test.cc

+ 5 - 0
README.md

@@ -118,6 +118,11 @@ params["note"] = "coder";
 auto res = cli.post("/post", params);
 ```
 
+### Connection Timeout
+
+```c++
+httplib::Client cli("localhost", 8080, 5); // timeouts in 5 seconds
+```
 ### With Progress Callback
 
 ```cpp

+ 96 - 30
httplib.h

@@ -27,7 +27,6 @@
 #define S_ISDIR(m)  (((m)&S_IFDIR)==S_IFDIR)
 #endif
 
-#include <fcntl.h>
 #include <io.h>
 #include <winsock2.h>
 #include <ws2tcpip.h>
@@ -57,6 +56,7 @@ typedef int socket_t;
 #include <string>
 #include <thread>
 #include <sys/stat.h>
+#include <fcntl.h>
 #include <assert.h>
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
@@ -207,6 +207,8 @@ protected:
 private:
     typedef std::vector<std::pair<std::regex, Handler>> Handlers;
 
+    socket_t create_server_socket(const char* host, int port, int socket_flags) const;
+
     bool routing(Request& req, Response& res);
     bool handle_file_request(Request& req, Response& res);
     bool dispatch_request(Request& req, Response& res, Handlers& handlers);
@@ -226,7 +228,12 @@ private:
 
 class Client {
 public:
-    Client(const char* host, int port, HttpVersion http_version = HttpVersion::v1_0);
+    Client(
+        const char* host,
+        int port = 80,
+        size_t timeout_sec = 300,
+        HttpVersion http_version = HttpVersion::v1_0);
+
     virtual ~Client();
 
     virtual bool is_valid() const;
@@ -250,10 +257,12 @@ protected:
 
     const std::string host_;
     const int         port_;
+    size_t            timeout_sec_;
     const HttpVersion http_version_;
     const std::string host_and_port_;
 
 private:
+    socket_t create_client_socket() const;
     bool read_response_line(Stream& strm, Response& res);
     void write_request(Stream& strm, Request& req);
 
@@ -292,7 +301,12 @@ private:
 
 class SSLClient : public Client {
 public:
-    SSLClient(const char* host, int port, HttpVersion http_version = HttpVersion::v1_0);
+    SSLClient(
+        const char* host,
+        int port = 80,
+        size_t timeout_sec = 300,
+        HttpVersion http_version = HttpVersion::v1_0);
+
     virtual ~SSLClient();
 
     virtual bool is_valid() const;
@@ -406,7 +420,7 @@ inline int close_socket(socket_t sock)
 #endif
 }
 
-inline int select(socket_t sock, size_t sec, size_t usec)
+inline int select_read(socket_t sock, size_t sec, size_t usec)
 {
     fd_set fds;
     FD_ZERO(&fds);
@@ -416,7 +430,28 @@ inline int select(socket_t sock, size_t sec, size_t usec)
     tv.tv_sec = sec;
     tv.tv_usec = usec;
 
-    return ::select(sock + 1, &fds, NULL, NULL, &tv);
+    return select(sock + 1, &fds, NULL, NULL, &tv);
+}
+
+inline bool is_socket_writable(socket_t sock, size_t sec, size_t usec)
+{
+    fd_set fdsw;
+    FD_ZERO(&fdsw);
+    FD_SET(sock, &fdsw);
+
+    fd_set fdse;
+    FD_ZERO(&fdse);
+    FD_SET(sock, &fdse);
+
+    timeval tv;
+    tv.tv_sec = sec;
+    tv.tv_usec = usec;
+
+    if (select(sock + 1, NULL, &fdsw, &fdse, &tv) <= 0) {
+        return false;
+    }
+
+    return FD_ISSET(sock, &fdsw) != 0;
 }
 
 template <typename T>
@@ -427,7 +462,7 @@ inline bool read_and_close_socket(socket_t sock, bool keep_alive, T callback)
     if (keep_alive) {
         auto count = CPPHTTPLIB_KEEPALIVE_MAX_COUNT;
         while (count > 0 &&
-               detail::select(sock,
+               detail::select_read(sock,
                    CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
                    CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) {
             auto last_connection = count == 1;
@@ -507,27 +542,24 @@ socket_t create_socket(const char* host, int port, Fn fn, int socket_flags = 0)
     return -1;
 }
 
-inline socket_t create_server_socket(const char* host, int port, int socket_flags)
+inline void set_nonblocking(socket_t sock, bool nonblocking)
 {
-    return create_socket(host, port, [](socket_t sock, struct addrinfo& ai) -> socket_t {
-        if (::bind(sock, ai.ai_addr, ai.ai_addrlen)) {
-              return false;
-        }
-        if (listen(sock, 5)) { // Listen through 5 channels
-            return false;
-        }
-        return true;
-    }, socket_flags);
+#ifdef _WIN32
+    auto flags = nonblocking ? 1UL : 0UL;
+    ioctlsocket(sock, FIONBIO, &flags);
+#else
+    auto flags = fcntl(sock, F_GETFL, 0);
+    fcntl(sock, F_SETFL, nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK)));
+#endif
 }
 
-inline socket_t create_client_socket(const char* host, int port)
+inline bool is_connection_error()
 {
-    return create_socket(host, port, [](socket_t sock, struct addrinfo& ai) -> socket_t {
-        if (connect(sock, ai.ai_addr, ai.ai_addrlen)) {
-            return false;
-        }
-        return true;
-    });
+#ifdef _WIN32
+    return WSAGetLastError() != WSAEWOULDBLOCK;
+#else
+    return errno != EINPROGRESS;
+#endif
 }
 
 inline bool is_file(const std::string& path)
@@ -1339,7 +1371,7 @@ inline bool Server::listen(const char* host, int port, int socket_flags)
         return false;
     }
 
-    svr_sock_ = detail::create_server_socket(host, port, socket_flags);
+    svr_sock_ = create_server_socket(host, port, socket_flags);
     if (svr_sock_ == -1) {
         return false;
     }
@@ -1347,7 +1379,7 @@ inline bool Server::listen(const char* host, int port, int socket_flags)
     auto ret = true;
 
     for (;;) {
-        auto val = detail::select(svr_sock_, 0, 100000);
+        auto val = detail::select_read(svr_sock_, 0, 100000);
 
         if (val == 0) { // Timeout
             if (svr_sock_ == -1) {
@@ -1480,6 +1512,20 @@ inline bool Server::handle_file_request(Request& req, Response& res)
     return false;
 }
 
+inline socket_t Server::create_server_socket(const char* host, int port, int socket_flags) const
+{
+    return detail::create_socket(host, port,
+        [](socket_t sock, struct addrinfo& ai) -> bool {
+            if (::bind(sock, ai.ai_addr, ai.ai_addrlen)) {
+                  return false;
+            }
+            if (::listen(sock, 5)) { // Listen through 5 channels
+                return false;
+            }
+            return true;
+        }, socket_flags);
+}
+
 inline bool Server::routing(Request& req, Response& res)
 {
     if (req.method == "GET" && handle_file_request(req, res)) {
@@ -1590,9 +1636,11 @@ inline bool Server::read_and_close_socket(socket_t sock)
 }
 
 // HTTP client implementation
-inline Client::Client(const char* host, int port, HttpVersion http_version)
+inline Client::Client(
+    const char* host, int port, size_t timeout_sec, HttpVersion http_version)
     : host_(host)
     , port_(port)
+    , timeout_sec_(timeout_sec)
     , http_version_(http_version)
     , host_and_port_(host_ + ":" + std::to_string(port_))
 {
@@ -1607,6 +1655,23 @@ inline bool Client::is_valid() const
     return true;
 }
 
+inline socket_t Client::create_client_socket() const
+{
+    return detail::create_socket(host_.c_str(), port_,
+        [=](socket_t sock, struct addrinfo& ai) -> bool {
+            detail::set_nonblocking(sock, true);
+
+            auto ret = connect(sock, ai.ai_addr, ai.ai_addrlen);
+            if (ret == -1 && detail::is_connection_error()) {
+                return false;
+            }
+
+            detail::set_nonblocking(sock, false);
+
+            return detail::is_socket_writable(sock, timeout_sec_, 0);
+        });
+}
+
 inline bool Client::read_response_line(Stream& strm, Response& res)
 {
     const auto bufsiz = 2048;
@@ -1634,7 +1699,7 @@ inline bool Client::send(Request& req, Response& res)
         return false;
     }
 
-    auto sock = detail::create_client_socket(host_.c_str(), port_);
+    auto sock = create_client_socket();
     if (sock == -1) {
         return false;
     }
@@ -1826,7 +1891,7 @@ inline bool read_and_close_socket_ssl(
     if (keep_alive) {
         auto count = CPPHTTPLIB_KEEPALIVE_MAX_COUNT;
         while (count > 0 &&
-               detail::select(sock,
+               detail::select_read(sock,
                    CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
                    CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) {
             auto last_connection = count == 1;
@@ -1936,8 +2001,9 @@ inline bool SSLServer::read_and_close_socket(socket_t sock)
 }
 
 // SSL HTTP client implementation
-inline SSLClient::SSLClient(const char* host, int port, HttpVersion http_version)
-    : Client(host, port, http_version)
+inline SSLClient::SSLClient(
+    const char* host, int port, size_t timeout_sec, HttpVersion http_version)
+    : Client(host, port, timeout_sec, http_version)
 {
     ctx_ = SSL_CTX_new(SSLv23_client_method());
 }

+ 8 - 23
test/test.cc

@@ -63,24 +63,6 @@ TEST(ParseQueryTest, ParseQueryString)
     EXPECT_EQ("val3", dic.find("key3")->second);
 }
 
-TEST(SocketTest, OpenClose)
-{
-    socket_t sock = detail::create_server_socket(HOST, PORT, 0);
-    ASSERT_NE(-1, sock);
-
-    auto ret = detail::close_socket(sock);
-    EXPECT_EQ(0, ret);
-}
-
-TEST(SocketTest, OpenCloseWithAI_PASSIVE)
-{
-    socket_t sock = detail::create_server_socket(nullptr, PORT, AI_PASSIVE);
-    ASSERT_NE(-1, sock);
-
-    auto ret = detail::close_socket(sock);
-    EXPECT_EQ(0, ret);
-}
-
 TEST(GetHeaderValueTest, DefaultValue)
 {
     Headers headers = {{"Dummy","Dummy"}};
@@ -139,13 +121,14 @@ TEST(GetHeaderValueTest, Range)
 void testChunkedEncoding(httplib::HttpVersion ver)
 {
     auto host = "www.httpwatch.com";
+    auto sec = 5;
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
     auto port = 443;
-    httplib::SSLClient cli(host, port, ver);
+    httplib::SSLClient cli(host, port, sec, ver);
 #else
     auto port = 80;
-    httplib::Client cli(host, port, ver);
+    httplib::Client cli(host, port, sec, ver);
 #endif
 
     auto res = cli.get("/httpgallery/chunked/chunkedimage.aspx?0.4153841143030137");
@@ -167,13 +150,15 @@ TEST(ChunkedEncodingTest, FromHTTPWatch)
 TEST(RangeTest, FromHTTPBin)
 {
     auto host = "httpbin.org";
+    auto sec = 5;
+    auto ver = httplib::HttpVersion::v1_1;
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
     auto port = 443;
-    httplib::SSLClient cli(host, port, httplib::HttpVersion::v1_1);
+    httplib::SSLClient cli(host, port, sec, ver);
 #else
     auto port = 80;
-    httplib::Client cli(host, port, httplib::HttpVersion::v1_1);
+    httplib::Client cli(host, port, sec, ver);
 #endif
 
     {
@@ -631,7 +616,7 @@ protected:
             res.set_content("Hello World!", "text/plain");
         });
 
-        t_ = thread([&](){
+        t_ = thread([&]() {
             svr_.listen(nullptr, PORT, AI_PASSIVE);
         });