Browse Source

Support LOCAL_ADDR and LOCAL_PORT header in client Request (#1450)

Having the local address/port is useful if the server is bound to
all interfaces, e.g. to serve different content for developers
on localhost only.
Ingo Bauersachs 3 years ago
parent
commit
8f32271e8c
3 changed files with 63 additions and 8 deletions
  1. 38 6
      httplib.h
  2. 5 2
      test/fuzzing/server_fuzzer.cc
  3. 20 0
      test/test.cc

+ 38 - 6
httplib.h

@@ -413,6 +413,8 @@ struct Request {
 
   std::string remote_addr;
   int remote_port = -1;
+  std::string local_addr;
+  int local_port = -1;
 
   // for server
   std::string version;
@@ -514,6 +516,7 @@ public:
   virtual ssize_t read(char *ptr, size_t size) = 0;
   virtual ssize_t write(const char *ptr, size_t size) = 0;
   virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0;
+  virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0;
   virtual socket_t socket() const = 0;
 
   template <typename... Args>
@@ -1778,6 +1781,7 @@ public:
   ssize_t read(char *ptr, size_t size) override;
   ssize_t write(const char *ptr, size_t size) override;
   void get_remote_ip_and_port(std::string &ip, int &port) const override;
+  void get_local_ip_and_port(std::string &ip, int &port) const override;
   socket_t socket() const override;
 
   const std::string &get_buffer() const;
@@ -2446,6 +2450,7 @@ public:
   ssize_t read(char *ptr, size_t size) override;
   ssize_t write(const char *ptr, size_t size) override;
   void get_remote_ip_and_port(std::string &ip, int &port) const override;
+  void get_local_ip_and_port(std::string &ip, int &port) const override;
   socket_t socket() const override;
 
 private:
@@ -2475,6 +2480,7 @@ public:
   ssize_t read(char *ptr, size_t size) override;
   ssize_t write(const char *ptr, size_t size) override;
   void get_remote_ip_and_port(std::string &ip, int &port) const override;
+  void get_local_ip_and_port(std::string &ip, int &port) const override;
   socket_t socket() const override;
 
 private:
@@ -2843,9 +2849,9 @@ inline socket_t create_client_socket(
   return sock;
 }
 
-inline bool get_remote_ip_and_port(const struct sockaddr_storage &addr,
-                                   socklen_t addr_len, std::string &ip,
-                                   int &port) {
+inline bool get_ip_and_port(const struct sockaddr_storage &addr,
+                            socklen_t addr_len, std::string &ip,
+                            int &port) {
   if (addr.ss_family == AF_INET) {
     port = ntohs(reinterpret_cast<const struct sockaddr_in *>(&addr)->sin_port);
   } else if (addr.ss_family == AF_INET6) {
@@ -2866,6 +2872,15 @@ inline bool get_remote_ip_and_port(const struct sockaddr_storage &addr,
   return true;
 }
 
+inline void get_local_ip_and_port(socket_t sock, std::string &ip, int &port) {
+  struct sockaddr_storage addr;
+  socklen_t addr_len = sizeof(addr);
+  if (!getsockname(sock, reinterpret_cast<struct sockaddr *>(&addr),
+                   &addr_len)) {
+    get_ip_and_port(addr, addr_len, ip, port);
+  }
+}
+
 inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) {
   struct sockaddr_storage addr;
   socklen_t addr_len = sizeof(addr);
@@ -2890,7 +2905,7 @@ inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) {
       return;
     }
 #endif
-    get_remote_ip_and_port(addr, addr_len, ip, port);
+    get_ip_and_port(addr, addr_len, ip, port);
   }
 }
 
@@ -4517,8 +4532,8 @@ inline void hosted_at(const std::string &hostname,
         *reinterpret_cast<struct sockaddr_storage *>(rp->ai_addr);
     std::string ip;
     int dummy = -1;
-    if (detail::get_remote_ip_and_port(addr, sizeof(struct sockaddr_storage),
-                                       ip, dummy)) {
+    if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage),
+                                ip, dummy)) {
       addrs.push_back(ip);
     }
   }
@@ -4808,6 +4823,11 @@ inline void SocketStream::get_remote_ip_and_port(std::string &ip,
   return detail::get_remote_ip_and_port(sock_, ip, port);
 }
 
+inline void SocketStream::get_local_ip_and_port(std::string &ip,
+                                                int &port) const {
+  return detail::get_local_ip_and_port(sock_, ip, port);
+}
+
 inline socket_t SocketStream::socket() const { return sock_; }
 
 // Buffer stream implementation
@@ -4833,6 +4853,9 @@ inline ssize_t BufferStream::write(const char *ptr, size_t size) {
 inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/,
                                                  int & /*port*/) const {}
 
+inline void BufferStream::get_local_ip_and_port(std::string & /*ip*/,
+                                                int & /*port*/) const {}
+
 inline socket_t BufferStream::socket() const { return 0; }
 
 inline const std::string &BufferStream::get_buffer() const { return buffer; }
@@ -5812,6 +5835,10 @@ Server::process_request(Stream &strm, bool close_connection,
   req.set_header("REMOTE_ADDR", req.remote_addr);
   req.set_header("REMOTE_PORT", std::to_string(req.remote_port));
 
+  strm.get_local_ip_and_port(req.local_addr, req.local_port);
+  req.set_header("LOCAL_ADDR", req.local_addr);
+  req.set_header("LOCAL_PORT", std::to_string(req.local_port));
+
   if (req.has_header("Range")) {
     const auto &range_header_value = req.get_header_value("Range");
     if (!detail::parse_range_header(range_header_value, req.ranges)) {
@@ -7409,6 +7436,11 @@ inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip,
   detail::get_remote_ip_and_port(sock_, ip, port);
 }
 
+inline void SSLSocketStream::get_local_ip_and_port(std::string &ip,
+                                                    int &port) const {
+  detail::get_local_ip_and_port(sock_, ip, port);
+}
+
 inline socket_t SSLSocketStream::socket() const { return sock_; }
 
 static SSLInit sslinit_;

+ 5 - 2
test/fuzzing/server_fuzzer.cc

@@ -22,8 +22,6 @@ public:
 
   ssize_t write(const std::string &s) { return write(s.data(), s.size()); }
 
-  std::string get_remote_addr() const { return ""; }
-
   bool is_readable() const override { return true; }
 
   bool is_writable() const override { return true; }
@@ -33,6 +31,11 @@ public:
     port = 8080;
   }
 
+  void get_local_ip_and_port(std::string &ip, int &port) const override {
+    ip = "127.0.0.1";
+    port = 8080;
+  }
+
   socket_t socket() const override { return 0; }
 
 private:

+ 20 - 0
test/test.cc

@@ -1521,6 +1521,17 @@ protected:
                          std::stoi(req.get_header_value("REMOTE_PORT")));
                res.set_content(remote_addr.c_str(), "text/plain");
              })
+        .Get("/local_addr",
+             [&](const Request &req, Response &res) {
+               EXPECT_TRUE(req.has_header("LOCAL_PORT"));
+               EXPECT_TRUE(req.has_header("LOCAL_ADDR"));
+               auto local_addr = req.get_header_value("LOCAL_ADDR");
+               auto local_port = req.get_header_value("LOCAL_PORT");
+               EXPECT_EQ(req.local_addr, local_addr);
+               EXPECT_EQ(req.local_port, std::stoi(local_port));
+               res.set_content(local_addr.append(":").append(local_port),
+                               "text/plain");
+             })
         .Get("/endwith%",
              [&](const Request & /*req*/, Response &res) {
                res.set_content("Hello World!", "text/plain");
@@ -2810,6 +2821,15 @@ TEST_F(ServerTest, GetMethodRemoteAddr) {
   EXPECT_TRUE(res->body == "::1" || res->body == "127.0.0.1");
 }
 
+TEST_F(ServerTest, GetMethodLocalAddr) {
+  auto res = cli_.Get("/local_addr");
+  ASSERT_TRUE(res);
+  EXPECT_EQ(200, res->status);
+  EXPECT_EQ("text/plain", res->get_header_value("Content-Type"));
+  EXPECT_TRUE(res->body == std::string("::1:").append(to_string(PORT)) ||
+              res->body == std::string("127.0.0.1:").append(to_string(PORT)));
+}
+
 TEST_F(ServerTest, HTTPResponseSplitting) {
   auto res = cli_.Get("/http_response_splitting");
   ASSERT_TRUE(res);