Browse Source

Allow to specify server IP address (#1067)

* Allow to specify server IP address

* Reimplement in set_hostname_addr_map

* Add tests for set_hostname_addr_map

* Fix tests after implement set_hostname_addr_map

* SpecifyServerIPAddressTest.RealHostname typo
zhenyolka 4 years ago
parent
commit
4f8fcdbaf7
2 changed files with 71 additions and 9 deletions
  1. 37 8
      httplib.h
  2. 34 1
      test/test.cc

+ 37 - 8
httplib.h

@@ -955,6 +955,8 @@ public:
 
   void stop();
 
+  void set_hostname_addr_map(const std::map<std::string, std::string> addr_map);
+
   void set_default_headers(Headers headers);
 
   void set_address_family(int family);
@@ -1058,6 +1060,9 @@ protected:
   std::thread::id socket_requests_are_from_thread_ = std::thread::id();
   bool socket_should_be_closed_when_request_is_done_ = false;
 
+  // Hostname-IP map
+  std::map<std::string, std::string> addr_map_;
+
   // Default headers
   Headers default_headers_;
 
@@ -1285,6 +1290,8 @@ public:
 
   void stop();
 
+  void set_hostname_addr_map(const std::map<std::string, std::string> addr_map);
+
   void set_default_headers(Headers headers);
 
   void set_address_family(int family);
@@ -1656,7 +1663,7 @@ bool process_client_socket(socket_t sock, time_t read_timeout_sec,
                            time_t write_timeout_usec,
                            std::function<bool(Stream &)> callback);
 
-socket_t create_client_socket(const char *host, int port, int address_family,
+socket_t create_client_socket(const char *host, const char *ip, int port, int address_family,
                               bool tcp_nodelay, SocketOptions socket_options,
                               time_t connection_timeout_sec,
                               time_t connection_timeout_usec,
@@ -2453,7 +2460,7 @@ inline int shutdown_socket(socket_t sock) {
 }
 
 template <typename BindOrConnect>
-socket_t create_socket(const char *host, int port, int address_family,
+socket_t create_socket(const char *host, const char *ip, int port, int address_family,
                        int socket_flags, bool tcp_nodelay,
                        SocketOptions socket_options,
                        BindOrConnect bind_or_connect) {
@@ -2467,9 +2474,17 @@ socket_t create_socket(const char *host, int port, int address_family,
   hints.ai_flags = socket_flags;
   hints.ai_protocol = 0;
 
+  // Ask getaddrinfo to convert IP in c-string to address
+  if(ip[0] != '\0') {
+    hints.ai_family = AF_UNSPEC;
+    hints.ai_flags = AI_NUMERICHOST;
+  }
+
   auto service = std::to_string(port);
 
-  if (getaddrinfo(host, service.c_str(), &hints, &result)) {
+  if (ip[0] != '\0' ?
+      getaddrinfo(ip, service.c_str(), &hints, &result) :
+      getaddrinfo(host, service.c_str(), &hints, &result)) {
 #if defined __linux__ && !defined __ANDROID__
     res_init();
 #endif
@@ -2604,13 +2619,13 @@ inline std::string if2ip(const std::string &ifn) {
 #endif
 
 inline socket_t create_client_socket(
-    const char *host, int port, int address_family, bool tcp_nodelay,
+    const char *host, const char *ip, int port, int address_family, bool tcp_nodelay,
     SocketOptions socket_options, time_t connection_timeout_sec,
     time_t connection_timeout_usec, time_t read_timeout_sec,
     time_t read_timeout_usec, time_t write_timeout_sec,
     time_t write_timeout_usec, const std::string &intf, Error &error) {
   auto sock = create_socket(
-      host, port, address_family, 0, tcp_nodelay, std::move(socket_options),
+      host, ip, port, address_family, 0, tcp_nodelay, std::move(socket_options),
       [&](socket_t sock2, struct addrinfo &ai) -> bool {
         if (!intf.empty()) {
 #ifdef USE_IF2IP
@@ -5079,7 +5094,7 @@ inline socket_t
 Server::create_server_socket(const char *host, int port, int socket_flags,
                              SocketOptions socket_options) const {
   return detail::create_socket(
-      host, port, address_family_, socket_flags, tcp_nodelay_,
+      host, "", port, address_family_, socket_flags, tcp_nodelay_,
       std::move(socket_options),
       [](socket_t sock, struct addrinfo &ai) -> bool {
         if (::bind(sock, ai.ai_addr, static_cast<socklen_t>(ai.ai_addrlen))) {
@@ -5598,13 +5613,19 @@ inline void ClientImpl::copy_settings(const ClientImpl &rhs) {
 inline socket_t ClientImpl::create_client_socket(Error &error) const {
   if (!proxy_host_.empty() && proxy_port_ != -1) {
     return detail::create_client_socket(
-        proxy_host_.c_str(), proxy_port_, address_family_, tcp_nodelay_,
+        proxy_host_.c_str(), "", proxy_port_, address_family_, tcp_nodelay_,
         socket_options_, connection_timeout_sec_, connection_timeout_usec_,
         read_timeout_sec_, read_timeout_usec_, write_timeout_sec_,
         write_timeout_usec_, interface_, error);
   }
+  // Check is custom IP specified for host_
+  std::string ip;
+  auto it = addr_map_.find(host_);
+  if(it != addr_map_.end())
+    ip = it->second;
+
   return detail::create_client_socket(
-      host_.c_str(), port_, address_family_, tcp_nodelay_, socket_options_,
+      host_.c_str(), ip.c_str(), port_, address_family_, tcp_nodelay_, socket_options_,
       connection_timeout_sec_, connection_timeout_usec_, read_timeout_sec_,
       read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, interface_,
       error);
@@ -6732,6 +6753,10 @@ inline void ClientImpl::set_follow_location(bool on) { follow_location_ = on; }
 
 inline void ClientImpl::set_url_encode(bool on) { url_encode_ = on; }
 
+inline void ClientImpl::set_hostname_addr_map(const std::map<std::string, std::string> addr_map) {
+  addr_map_ = std::move(addr_map);
+}
+
 inline void ClientImpl::set_default_headers(Headers headers) {
   default_headers_ = std::move(headers);
 }
@@ -7855,6 +7880,10 @@ inline size_t Client::is_socket_open() const { return cli_->is_socket_open(); }
 
 inline void Client::stop() { cli_->stop(); }
 
+inline void Client::set_hostname_addr_map(const std::map<std::string, std::string> addr_map) {
+  cli_->set_hostname_addr_map(std::move(addr_map));
+}
+
 inline void Client::set_default_headers(Headers headers) {
   cli_->set_default_headers(std::move(headers));
 }

+ 34 - 1
test/test.cc

@@ -736,6 +736,39 @@ TEST(DigestAuthTest, FromHTTPWatch_Online) {
 }
 #endif
 
+TEST(SpecifyServerIPAddressTest, AnotherHostname) {
+  auto host = "google.com";
+  auto another_host = "example.com";
+  auto wrong_ip = "0.0.0.0";
+  
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+  SSLClient cli(host);
+#else
+  Client cli(host);
+#endif
+
+  cli.set_hostname_addr_map({{another_host, wrong_ip}});
+  auto res = cli.Get("/");
+  ASSERT_TRUE(res);
+  ASSERT_EQ(301, res->status);
+}
+
+TEST(SpecifyServerIPAddressTest, RealHostname) {
+  auto host = "google.com";
+  auto wrong_ip = "0.0.0.0";
+  
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+  SSLClient cli(host);
+#else
+  Client cli(host);
+#endif
+
+  cli.set_hostname_addr_map({{host, wrong_ip}});
+  auto res = cli.Get("/");
+  ASSERT_TRUE(!res);
+  EXPECT_EQ(Error::Connection, res.error());
+}
+
 TEST(AbsoluteRedirectTest, Redirect_Online) {
   auto host = "nghttp2.org";
 
@@ -3321,7 +3354,7 @@ static bool send_request(time_t read_timeout_sec, const std::string &req,
   auto error = Error::Success;
 
   auto client_sock = detail::create_client_socket(
-      HOST, PORT, AF_UNSPEC, false, nullptr,
+      HOST, "", PORT, AF_UNSPEC, false, nullptr,
       /*connection_timeout_sec=*/5, 0,
       /*read_timeout_sec=*/5, 0,
       /*write_timeout_sec=*/5, 0, std::string(), error);