Browse Source

Support remote_addr and remote_port REMOTE_PORT header in client Request (#433)

yhirose 5 years ago
parent
commit
d0b123be26
2 changed files with 58 additions and 36 deletions
  1. 55 36
      httplib.h
  2. 3 0
      test/test.cc

+ 55 - 36
httplib.h

@@ -262,6 +262,9 @@ struct Request {
   Headers headers;
   std::string body;
 
+  std::string remote_addr;
+  int remote_port = -1;
+
   // for server
   std::string version;
   std::string target;
@@ -352,7 +355,7 @@ public:
 
   virtual ssize_t read(char *ptr, size_t size) = 0;
   virtual ssize_t write(const char *ptr, size_t size) = 0;
-  virtual std::string get_remote_addr() const = 0;
+  virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0;
 
   template <typename... Args>
   ssize_t write_format(const char *fmt, const Args &... args);
@@ -1283,7 +1286,7 @@ public:
   bool is_writable() const override;
   ssize_t read(char *ptr, size_t size) override;
   ssize_t write(const char *ptr, size_t size) override;
-  std::string get_remote_addr() const override;
+  void get_remote_ip_and_port(std::string &ip, int &port) const override;
 
 private:
   socket_t sock_;
@@ -1302,7 +1305,7 @@ public:
   bool is_writable() const override;
   ssize_t read(char *ptr, size_t size) override;
   ssize_t write(const char *ptr, size_t size) override;
-  std::string get_remote_addr() const override;
+  void get_remote_ip_and_port(std::string &ip, int &port) const override;
 
 private:
   socket_t sock_;
@@ -1321,7 +1324,7 @@ public:
   bool is_writable() const override;
   ssize_t read(char *ptr, size_t size) override;
   ssize_t write(const char *ptr, size_t size) override;
-  std::string get_remote_addr() const override;
+  void get_remote_ip_and_port(std::string &ip, int &port) const override;
 
   const std::string &get_buffer() const;
 
@@ -1554,21 +1557,32 @@ inline socket_t create_client_socket(const char *host, int port,
       });
 }
 
-inline std::string get_remote_addr(socket_t sock) {
-  struct sockaddr_storage addr;
-  socklen_t len = sizeof(addr);
-
-  if (!getpeername(sock, reinterpret_cast<struct sockaddr *>(&addr), &len)) {
-    std::array<char, NI_MAXHOST> ipstr{};
+inline void get_remote_ip_and_port(const struct sockaddr_storage &addr,
+                                   socklen_t addr_len, std::string &ip,
+                                   int &port) {
+  if (addr.ss_family == AF_INET) {
+    port = ntohs(reinterpret_cast<const struct sockaddr_in *>(&addr)->sin_port);
+  } else if (addr.ss_family == AF_INET6) {
+    port =
+        ntohs(reinterpret_cast<const struct sockaddr_in6 *>(&addr)->sin6_port);
+  }
 
-    if (!getnameinfo(reinterpret_cast<struct sockaddr *>(&addr), len,
-                     ipstr.data(), static_cast<socklen_t>(ipstr.size()),
-                     nullptr, 0, NI_NUMERICHOST)) {
-      return ipstr.data();
-    }
+  std::array<char, NI_MAXHOST> ipstr{};
+  if (!getnameinfo(reinterpret_cast<const struct sockaddr *>(&addr), addr_len,
+                   ipstr.data(), static_cast<socklen_t>(ipstr.size()), nullptr,
+                   0, NI_NUMERICHOST)) {
+    ip = ipstr.data();
   }
+}
 
-  return std::string();
+inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) {
+  struct sockaddr_storage addr;
+  socklen_t addr_len = sizeof(addr);
+
+  if (!getpeername(sock, reinterpret_cast<struct sockaddr *>(&addr),
+                   &addr_len)) {
+    get_remote_ip_and_port(addr, addr_len, ip, port);
+  }
 }
 
 inline const char *
@@ -2910,11 +2924,11 @@ inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec,
 inline SocketStream::~SocketStream() {}
 
 inline bool SocketStream::is_readable() const {
-  return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0;
+  return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0;
 }
 
 inline bool SocketStream::is_writable() const {
-  return detail::select_write(sock_, 0, 0) > 0;
+  return select_write(sock_, 0, 0) > 0;
 }
 
 inline ssize_t SocketStream::read(char *ptr, size_t size) {
@@ -2927,8 +2941,9 @@ inline ssize_t SocketStream::write(const char *ptr, size_t size) {
   return -1;
 }
 
-inline std::string SocketStream::get_remote_addr() const {
-  return detail::get_remote_addr(sock_);
+inline void SocketStream::get_remote_ip_and_port(std::string &ip,
+                                                 int &port) const {
+  return detail::get_remote_ip_and_port(sock_, ip, port);
 }
 
 // Buffer stream implementation
@@ -2951,7 +2966,8 @@ inline ssize_t BufferStream::write(const char *ptr, size_t size) {
   return static_cast<ssize_t>(size);
 }
 
-inline std::string BufferStream::get_remote_addr() const { return ""; }
+inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/,
+                                                 int & /*port*/) const {}
 
 inline const std::string &BufferStream::get_buffer() const { return buffer; }
 
@@ -3431,17 +3447,16 @@ inline int Server::bind_internal(const char *host, int port, int socket_flags) {
   if (svr_sock_ == INVALID_SOCKET) { return -1; }
 
   if (port == 0) {
-    struct sockaddr_storage address;
-    socklen_t len = sizeof(address);
-    if (getsockname(svr_sock_, reinterpret_cast<struct sockaddr *>(&address),
-                    &len) == -1) {
+    struct sockaddr_storage addr;
+    socklen_t addr_len = sizeof(addr);
+    if (getsockname(svr_sock_, reinterpret_cast<struct sockaddr *>(&addr),
+                    &addr_len) == -1) {
       return -1;
     }
-    if (address.ss_family == AF_INET) {
-      return ntohs(reinterpret_cast<struct sockaddr_in *>(&address)->sin_port);
-    } else if (address.ss_family == AF_INET6) {
-      return ntohs(
-          reinterpret_cast<struct sockaddr_in6 *>(&address)->sin6_port);
+    if (addr.ss_family == AF_INET) {
+      return ntohs(reinterpret_cast<struct sockaddr_in *>(&addr)->sin_port);
+    } else if (addr.ss_family == AF_INET6) {
+      return ntohs(reinterpret_cast<struct sockaddr_in6 *>(&addr)->sin6_port);
     } else {
       return -1;
     }
@@ -3646,7 +3661,9 @@ Server::process_request(Stream &strm, bool last_connection,
     connection_close = true;
   }
 
-  req.set_header("REMOTE_ADDR", strm.get_remote_addr());
+  strm.get_remote_ip_and_port(req.remote_addr, req.remote_port);
+  req.set_header("REMOTE_ADDR", req.remote_addr);
+  req.set_header("REMOTE_PORT", std::to_string(req.remote_port));
 
   if (req.has_header("Range")) {
     const auto &range_header_value = req.get_header_value("Range");
@@ -4527,8 +4544,8 @@ inline bool process_and_close_socket_ssl(
       auto count = keep_alive_max_count;
       while (count > 0 &&
              (is_client_request ||
-              detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
-                                  CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
+              select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
+                          CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
         SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec);
         auto last_connection = count == 1;
         auto connection_close = false;
@@ -4639,8 +4656,9 @@ inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
   return -1;
 }
 
-inline std::string SSLSocketStream::get_remote_addr() const {
-  return detail::get_remote_addr(sock_);
+inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip,
+                                                    int &port) const {
+  detail::get_remote_ip_and_port(sock_, ip, port);
 }
 
 static SSLInit sslinit_;
@@ -5020,7 +5038,8 @@ inline std::shared_ptr<Response> Get(const char *url, Options &options) {
     SSLClient cli(next_host.c_str(), next_port, options.client_cert_path,
                   options.client_key_path);
     cli.set_follow_location(options.follow_location);
-    cli.set_ca_cert_path(options.ca_cert_file_path.c_str(), options.ca_cert_dir_path.c_str());
+    cli.set_ca_cert_path(options.ca_cert_file_path.c_str(),
+                         options.ca_cert_dir_path.c_str());
     cli.enable_server_certificate_verification(
         options.server_certificate_verification);
     return cli.Get(next_path.c_str());

+ 3 - 0
test/test.cc

@@ -801,6 +801,9 @@ protected:
         .Get("/remote_addr",
              [&](const Request &req, Response &res) {
                auto remote_addr = req.headers.find("REMOTE_ADDR")->second;
+               EXPECT_TRUE(req.has_header("REMOTE_PORT"));
+               EXPECT_EQ(req.remote_addr, req.get_header_value("REMOTE_ADDR"));
+               EXPECT_EQ(req.remote_port, std::stoi(req.get_header_value("REMOTE_PORT")));
                res.set_content(remote_addr.c_str(), "text/plain");
              })
         .Get("/endwith%",