Browse Source

Changed to use getaddrinfo.

yhirose 11 years ago
parent
commit
efc579b14e
2 changed files with 71 additions and 55 deletions
  1. 66 50
      httplib.h
  2. 5 5
      test/test.cc

+ 66 - 50
httplib.h

@@ -28,6 +28,7 @@
 #include <fcntl.h>
 #include <fcntl.h>
 #include <io.h>
 #include <io.h>
 #include <winsock2.h>
 #include <winsock2.h>
+#include <ws2tcpip.h>
 
 
 typedef SOCKET socket_t;
 typedef SOCKET socket_t;
 #else
 #else
@@ -38,7 +39,6 @@ typedef SOCKET socket_t;
 #include <netinet/in.h>
 #include <netinet/in.h>
 #include <arpa/inet.h>
 #include <arpa/inet.h>
 #include <sys/socket.h>
 #include <sys/socket.h>
-#include <sys/stat.h>
 
 
 typedef int socket_t;
 typedef int socket_t;
 #endif
 #endif
@@ -49,6 +49,7 @@ typedef int socket_t;
 #include <memory>
 #include <memory>
 #include <regex>
 #include <regex>
 #include <string>
 #include <string>
+#include <sys/stat.h>
 #include <assert.h>
 #include <assert.h>
 
 
 namespace httplib
 namespace httplib
@@ -175,6 +176,24 @@ inline void get_flie_pointers(int fd, FILE*& fp_read, FILE*& fp_write)
 #endif
 #endif
 }
 }
 
 
+inline int shutdown_socket(socket_t sock)
+{
+#ifdef _MSC_VER
+    return shutdown(sock, SD_BOTH);
+#else
+    return shutdown(sock, SHUT_RDWR);
+#endif
+}
+
+inline int close_socket(socket_t sock)
+{
+#ifdef _MSC_VER
+    return closesocket(sock);
+#else
+    return close(sock);
+#endif
+}
+
 template <typename Fn>
 template <typename Fn>
 socket_t create_socket(const char* host, int port, Fn fn)
 socket_t create_socket(const char* host, int port, Fn fn)
 {
 {
@@ -183,70 +202,66 @@ socket_t create_socket(const char* host, int port, Fn fn)
     setsockopt(INVALID_SOCKET, SOL_SOCKET, SO_OPENTYPE, (char*)&opt, sizeof(opt));
     setsockopt(INVALID_SOCKET, SOL_SOCKET, SO_OPENTYPE, (char*)&opt, sizeof(opt));
 #endif
 #endif
 
 
-    // Create a socket
-    auto sock = socket(AF_INET, SOCK_STREAM, 0);
-    if (sock == -1) {
-        return -1;
-    }
+    // Get address info
+    struct addrinfo hints;
+    struct addrinfo *result;
 
 
-    // Make 'reuse address' option available
-    int yes = 1;
-    setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char*)&yes, sizeof(yes));
+    memset(&hints, 0, sizeof(struct addrinfo));
+    hints.ai_family = AF_UNSPEC;
+    hints.ai_socktype = SOCK_STREAM;
+    hints.ai_flags = 0;
+    hints.ai_protocol = 0;
 
 
-    // Get a host entry info
-    struct hostent* hp;
-    if (!(hp = gethostbyname(host))) {
+    auto service = std::to_string(port);
+
+    if (getaddrinfo(host, service.c_str(), &hints, &result)) {
         return -1;
         return -1;
     }
     }
 
 
-    // Bind the socket to the given address
-    struct sockaddr_in addr;
-    memset(&addr, 0, sizeof(addr));
-    memcpy(&addr.sin_addr, hp->h_addr, hp->h_length);
-    addr.sin_family = AF_INET;
-    addr.sin_port = htons(port);
+    for (auto rp = result; rp; rp = rp->ai_next) {
+       // Create a socket
+       auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
+       if (sock == -1) {
+          continue;
+       }
+
+       // Make 'reuse address' option available
+       int yes = 1;
+       setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char*)&yes, sizeof(yes));
+
+       // bind or connect
+       if (fn(sock, *rp)) {
+          freeaddrinfo(result);
+          return sock;
+       }
 
 
-    return fn(sock, addr);
+       close_socket(sock);
+    }
+
+    freeaddrinfo(result);
+    return -1;
 }
 }
 
 
 inline socket_t create_server_socket(const char* host, int port)
 inline socket_t create_server_socket(const char* host, int port)
 {
 {
-    return create_socket(host, port, [](socket_t sock, struct sockaddr_in& addr) -> socket_t {
-        if (::bind(sock, (struct sockaddr*)&addr, sizeof(addr))) {
-            return -1;
+    return create_socket(host, port, [](socket_t sock, struct addrinfo& ai) -> socket_t {
+        if (::bind(sock, ai.ai_addr, ai.ai_addrlen)) {
+              return false;
         }
         }
         if (listen(sock, 5)) { // Listen through 5 channels
         if (listen(sock, 5)) { // Listen through 5 channels
-            return -1;
+            return false;
         }
         }
-        return sock;
+        return true;
     });
     });
 }
 }
 
 
-inline int shutdown_socket(socket_t sock)
-{
-#ifdef _MSC_VER
-    return shutdown(sock, SD_BOTH);
-#else
-    return shutdown(sock, SHUT_RDWR);
-#endif
-}
-
-inline int close_socket(socket_t sock)
-{
-#ifdef _MSC_VER
-    return closesocket(sock);
-#else
-    return close(sock);
-#endif
-}
-
 inline socket_t create_client_socket(const char* host, int port)
 inline socket_t create_client_socket(const char* host, int port)
 {
 {
-    return create_socket(host, port, [](socket_t sock, struct sockaddr_in& addr) -> socket_t {
-        if (connect(sock, (struct sockaddr*)&addr, sizeof(struct sockaddr_in))) {
-            return -1;
+    return create_socket(host, port, [](socket_t sock, struct addrinfo& ai) -> socket_t {
+        if (connect(sock, ai.ai_addr, ai.ai_addrlen)) {
+            return false;
         }
         }
-        return sock;
+        return true;
     });
     });
 }
 }
 
 
@@ -268,7 +283,7 @@ inline void read_file(const std::string& path, std::string& out)
 	fs.seekg(0, std::ios_base::end);
 	fs.seekg(0, std::ios_base::end);
 	auto size = fs.tellg();
 	auto size = fs.tellg();
 	fs.seekg(0);
 	fs.seekg(0);
-	out.assign(size, 0);
+   out.resize(size);
 	fs.read(&out[0], size);
 	fs.read(&out[0], size);
 }
 }
 
 
@@ -370,9 +385,10 @@ inline void write_headers(FILE* fp, const T& res)
         }
         }
     }
     }
 
 
+    auto t = get_header_value(res.headers, "Content-Type", "text/plain");
+    fprintf(fp, "Content-Type: %s\r\n", t);
+
     if (!res.body.empty()) {
     if (!res.body.empty()) {
-        auto t = get_header_value(res.headers, "Content-Type", "text/plain");
-        fprintf(fp, "Content-Type: %s\r\n", t);
         fprintf(fp, "Content-Length: %ld\r\n", res.body.size());
         fprintf(fp, "Content-Length: %ld\r\n", res.body.size());
     }
     }
 
 
@@ -447,7 +463,7 @@ inline int from_hex_to_i(const std::string& s, int i, int cnt, int& val)
     return --i;
     return --i;
 }
 }
 
 
-size_t to_utf8(int code, char* buff)
+inline size_t to_utf8(int code, char* buff)
 {
 {
     if (code < 0x0080) {
     if (code < 0x0080) {
         buff[0] = (code & 0x7F);
         buff[0] = (code & 0x7F);

+ 5 - 5
test/test.cc

@@ -105,15 +105,15 @@ protected:
     virtual void SetUp() {
     virtual void SetUp() {
 		svr_.set_base_dir("./www");
 		svr_.set_base_dir("./www");
 
 
-        svr_.get("/hi", [&](const auto& req, auto& res) {
+        svr_.get("/hi", [&](const Request& req, Response& res) {
             res.set_content("Hello World!", "text/plain");
             res.set_content("Hello World!", "text/plain");
         });
         });
 
 
-        svr_.get("/", [&](const auto& req, auto& res) {
+        svr_.get("/", [&](const Request& req, Response& res) {
             res.set_redirect("/hi");
             res.set_redirect("/hi");
         });
         });
 
 
-        svr_.post("/person", [&](const auto& req, auto& res) {
+        svr_.post("/person", [&](const Request& req, Response& res) {
             if (req.has_param("name") && req.has_param("note")) {
             if (req.has_param("name") && req.has_param("note")) {
                 persons_[req.params.at("name")] = req.params.at("note");
                 persons_[req.params.at("name")] = req.params.at("note");
             } else {
             } else {
@@ -121,7 +121,7 @@ protected:
             }
             }
         });
         });
 
 
-        svr_.get("/person/(.*)", [&](const auto& req, auto& res) {
+        svr_.get("/person/(.*)", [&](const Request& req, Response& res) {
             string name = req.matches[1];
             string name = req.matches[1];
             if (persons_.find(name) != persons_.end()) {
             if (persons_.find(name) != persons_.end()) {
                 auto note = persons_[name];
                 auto note = persons_[name];
@@ -131,7 +131,7 @@ protected:
             }
             }
         });
         });
 
 
-        svr_.get("/stop", [&](const auto& req, auto& res) {
+        svr_.get("/stop", [&](const Request& req, Response& res) {
             svr_.stop();
             svr_.stop();
         });
         });