Browse Source

UNIX domain socket support (#1346)

* Add support UNIX domain socket

* `set_address_family(AF_UNIX)` is required

* add unittest for UNIX domain socket

* add support UNIX domain socket with abstract address

Abstract address of AF_UNIX begins with null(0x00) which can't be
delivered via .c_str() method.

* add unittest for UNIX domain socket with abstract address

Co-authored-by: Changbin Park <[email protected]>
Changbin Park 3 years ago
parent
commit
362d064afa
2 changed files with 95 additions and 2 deletions
  1. 27 2
      httplib.h
  2. 68 0
      test/test.cc

+ 27 - 2
httplib.h

@@ -183,6 +183,7 @@ using socket_t = SOCKET;
 #include <pthread.h>
 #include <sys/select.h>
 #include <sys/socket.h>
+#include <sys/un.h>
 #include <unistd.h>
 
 using socket_t = int;
@@ -2570,6 +2571,30 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port,
     hints.ai_flags = socket_flags;
   }
 
+#ifndef _WIN32
+  if (hints.ai_family == AF_UNIX) {
+    const auto addrlen = host.length();
+    if (addrlen > sizeof(sockaddr_un::sun_path)) return INVALID_SOCKET;
+
+    auto sock = socket(hints.ai_family, hints.ai_socktype, hints.ai_protocol);
+    if (sock != INVALID_SOCKET) {
+      sockaddr_un addr;
+      addr.sun_family = AF_UNIX;
+      std::copy(host.begin(), host.end(), addr.sun_path);
+
+      hints.ai_addr = reinterpret_cast<sockaddr*>(&addr);
+      hints.ai_addrlen = static_cast<socklen_t>(
+          sizeof(addr) - sizeof(addr.sun_path) + addrlen);
+
+      if (!bind_or_connect(sock, hints)) {
+        close_socket(sock);
+        sock = INVALID_SOCKET;
+      }
+    }
+    return sock;
+  }
+#endif
+
   auto service = std::to_string(port);
 
   if (getaddrinfo(node, service.c_str(), &hints, &result)) {
@@ -7858,12 +7883,12 @@ inline Client::Client(const std::string &scheme_host_port,
 
     if (is_ssl) {
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
-      cli_ = detail::make_unique<SSLClient>(host.c_str(), port,
+      cli_ = detail::make_unique<SSLClient>(host, port,
                                             client_cert_path, client_key_path);
       is_ssl_ = is_ssl;
 #endif
     } else {
-      cli_ = detail::make_unique<ClientImpl>(host.c_str(), port,
+      cli_ = detail::make_unique<ClientImpl>(host, port,
                                              client_cert_path, client_key_path);
     }
   } else {

+ 68 - 0
test/test.cc

@@ -5064,3 +5064,71 @@ TEST(MultipartFormDataTest, WithPreamble) {
 
 #endif
 
+#ifndef _WIN32
+class UnixSocketTest : public ::testing::Test {
+protected:
+  void TearDown() override {
+    std::remove(pathname_.c_str());
+  }
+
+  void client_GET(const std::string &addr) {
+    httplib::Client cli{addr};
+    cli.set_address_family(AF_UNIX);
+    ASSERT_TRUE(cli.is_valid());
+
+    const auto &result = cli.Get(pattern_);
+    ASSERT_TRUE(result) << "error: " << result.error();
+
+    const auto &resp = result.value();
+    EXPECT_EQ(resp.status, 200);
+    EXPECT_EQ(resp.body, content_);
+  }
+
+  const std::string pathname_ {"./httplib-server.sock"};
+  const std::string pattern_ {"/hi"};
+  const std::string content_ {"Hello World!"};
+};
+
+TEST_F(UnixSocketTest, pathname) {
+  httplib::Server svr;
+  svr.Get(pattern_, [&](const httplib::Request &, httplib::Response &res) {
+    res.set_content(content_, "text/plain");
+  });
+
+  std::thread t {[&] {
+    ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(pathname_, 80)); }};
+  while (!svr.is_running()) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+  }
+  ASSERT_TRUE(svr.is_running());
+
+  client_GET(pathname_);
+
+  svr.stop();
+  t.join();
+}
+
+#ifdef __linux__
+TEST_F(UnixSocketTest, abstract) {
+  constexpr char svr_path[] {"\x00httplib-server.sock"};
+  const std::string abstract_addr {svr_path, sizeof(svr_path) - 1};
+
+  httplib::Server svr;
+  svr.Get(pattern_, [&](const httplib::Request &, httplib::Response &res) {
+    res.set_content(content_, "text/plain");
+  });
+
+  std::thread t {[&] {
+    ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(abstract_addr, 80)); }};
+  while (!svr.is_running()) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+  }
+  ASSERT_TRUE(svr.is_running());
+
+  client_GET(abstract_addr);
+
+  svr.stop();
+  t.join();
+}
+#endif
+#endif  // #ifndef _WIN32