Browse Source

CONNECT method support on client

yhirose 6 years ago
parent
commit
eb4fcb5003
2 changed files with 291 additions and 108 deletions
  1. 261 106
      httplib.h
  2. 30 2
      test/test.cc

+ 261 - 106
httplib.h

@@ -505,12 +505,13 @@ public:
 };
 };
 #endif
 #endif
 
 
+using Logger = std::function<void(const Request &, const Response &)>;
+
 class Server {
 class Server {
 public:
 public:
   using Handler = std::function<void(const Request &, Response &)>;
   using Handler = std::function<void(const Request &, Response &)>;
   using HandlerWithContentReader = std::function<void(
   using HandlerWithContentReader = std::function<void(
       const Request &, Response &, const ContentReader &content_reader)>;
       const Request &, Response &, const ContentReader &content_reader)>;
-  using Logger = std::function<void(const Request &, const Response &)>;
 
 
   Server();
   Server();
 
 
@@ -614,7 +615,9 @@ private:
 
 
 class Client {
 class Client {
 public:
 public:
-  explicit Client(const char *host, int port = 80);
+  explicit Client(const std::string &host, int port = 80,
+                  const std::string &client_cert_path = std::string(),
+                  const std::string &client_key_path = std::string());
 
 
   virtual ~Client();
   virtual ~Client();
 
 
@@ -736,11 +739,13 @@ public:
 
 
   void set_timeout_sec(time_t timeout_sec);
   void set_timeout_sec(time_t timeout_sec);
 
 
+  void set_read_timeout(time_t sec, time_t usec);
+
   void set_keep_alive_max_count(size_t count);
   void set_keep_alive_max_count(size_t count);
 
 
-  void set_read_timeout(time_t sec, time_t usec);
+  void set_basic_auth(const char *username, const char *password);
 
 
-  void set_auth(const char *username, const char *password);
+  void set_digest_auth(const char *username, const char *password);
 
 
   void set_follow_location(bool on);
   void set_follow_location(bool on);
 
 
@@ -748,6 +753,14 @@ public:
 
 
   void set_interface(const char *intf);
   void set_interface(const char *intf);
 
 
+  void set_proxy(const char *host, int port);
+
+  void set_proxy_basic_auth(const char *username, const char *password);
+
+  void set_proxy_digest_auth(const char *username, const char *password);
+
+  void set_logger(Logger logger);
+
 protected:
 protected:
   bool process_request(Stream &strm, const Request &req, Response &res,
   bool process_request(Stream &strm, const Request &req, Response &res,
                        bool last_connection, bool &connection_close);
                        bool last_connection, bool &connection_close);
@@ -756,17 +769,60 @@ protected:
   const int port_;
   const int port_;
   const std::string host_and_port_;
   const std::string host_and_port_;
 
 
-  // Options
+  // Settings
+  std::string client_cert_path_;
+  std::string client_key_path_;
+
   time_t timeout_sec_ = 300;
   time_t timeout_sec_ = 300;
-  size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT;
   time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND;
   time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND;
   time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND;
   time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND;
-  std::string username_;
-  std::string password_;
+
+  size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT;
+
+  std::string basic_auth_username_;
+  std::string basic_auth_password_;
+  std::string digest_auth_username_;
+  std::string digest_auth_password_;
+
   bool follow_location_ = false;
   bool follow_location_ = false;
+
   bool compress_ = false;
   bool compress_ = false;
+
   std::string interface_;
   std::string interface_;
 
 
+  std::string proxy_host_;
+  int proxy_port_;
+
+  std::string proxy_basic_auth_username_;
+  std::string proxy_basic_auth_password_;
+  std::string proxy_digest_auth_username_;
+  std::string proxy_digest_auth_password_;
+
+  Logger logger_;
+
+  void copy_settings(const Client &rhs) {
+    client_cert_path_ = rhs.client_cert_path_;
+    client_key_path_ = rhs.client_key_path_;
+    timeout_sec_ = rhs.timeout_sec_;
+    read_timeout_sec_ = rhs.read_timeout_sec_;
+    read_timeout_usec_ = rhs.read_timeout_usec_;
+    keep_alive_max_count_ = rhs.keep_alive_max_count_;
+    basic_auth_username_ = rhs.basic_auth_username_;
+    basic_auth_password_ = rhs.basic_auth_password_;
+    digest_auth_username_ = rhs.digest_auth_username_;
+    digest_auth_password_ = rhs.digest_auth_password_;
+    follow_location_ = rhs.follow_location_;
+    compress_ = rhs.compress_;
+    interface_ = rhs.interface_;
+    proxy_host_ = rhs.proxy_host_;
+    proxy_port_ = rhs.proxy_port_;
+    proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_;
+    proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_;
+    proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_;
+    proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_;
+    logger_ = rhs.logger_;
+  }
+
 private:
 private:
   socket_t create_client_socket() const;
   socket_t create_client_socket() const;
   bool read_response_line(Stream &strm, Response &res);
   bool read_response_line(Stream &strm, Response &res);
@@ -856,9 +912,9 @@ private:
 
 
 class SSLClient : public Client {
 class SSLClient : public Client {
 public:
 public:
-  SSLClient(const char *host, int port = 443,
-            const char *client_cert_path = nullptr,
-            const char *client_key_path = nullptr);
+  SSLClient(const std::string &host, int port = 443,
+            const std::string &client_cert_path = std::string(),
+            const std::string &client_key_path = std::string());
 
 
   virtual ~SSLClient();
   virtual ~SSLClient();
 
 
@@ -866,6 +922,7 @@ public:
 
 
   void set_ca_cert_path(const char *ca_ceert_file_path,
   void set_ca_cert_path(const char *ca_ceert_file_path,
                         const char *ca_cert_dir_path = nullptr);
                         const char *ca_cert_dir_path = nullptr);
+
   void enable_server_certificate_verification(bool enabled);
   void enable_server_certificate_verification(bool enabled);
 
 
   long get_openssl_verify_result() const;
   long get_openssl_verify_result() const;
@@ -889,7 +946,6 @@ private:
   std::mutex ctx_mutex_;
   std::mutex ctx_mutex_;
   std::vector<std::string> host_components_;
   std::vector<std::string> host_components_;
 
 
-  // Options
   std::string ca_cert_file_path_;
   std::string ca_cert_file_path_;
   std::string ca_cert_dir_path_;
   std::string ca_cert_dir_path_;
   bool server_certificate_verification_ = false;
   bool server_certificate_verification_ = false;
@@ -1234,10 +1290,9 @@ inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) {
 }
 }
 
 
 template <typename T>
 template <typename T>
-inline bool process_and_close_socket(bool is_client_request, socket_t sock,
-                                     size_t keep_alive_max_count,
-                                     time_t read_timeout_sec,
-                                     time_t read_timeout_usec, T callback) {
+inline bool process_socket(bool is_client_request, socket_t sock,
+                           size_t keep_alive_max_count, time_t read_timeout_sec,
+                           time_t read_timeout_usec, T callback) {
   assert(keep_alive_max_count > 0);
   assert(keep_alive_max_count > 0);
 
 
   bool ret = false;
   bool ret = false;
@@ -1263,6 +1318,16 @@ inline bool process_and_close_socket(bool is_client_request, socket_t sock,
     ret = callback(strm, true, dummy_connection_close);
     ret = callback(strm, true, dummy_connection_close);
   }
   }
 
 
+  return ret;
+}
+
+template <typename T>
+inline bool process_and_close_socket(bool is_client_request, socket_t sock,
+                                     size_t keep_alive_max_count,
+                                     time_t read_timeout_sec,
+                                     time_t read_timeout_usec, T callback) {
+  auto ret = process_socket(is_client_request, sock, keep_alive_max_count,
+                            read_timeout_sec, read_timeout_usec, callback);
   close_socket(sock);
   close_socket(sock);
   return ret;
   return ret;
 }
 }
@@ -1309,8 +1374,8 @@ socket_t create_socket(const char *host, int port, Fn fn,
     auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol,
     auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol,
                            nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT);
                            nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT);
     /**
     /**
-     * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 and above
-     * the socket creation fails on older Windows Systems.
+     * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1
+     * and above the socket creation fails on older Windows Systems.
      *
      *
      * Let's try to create a socket the old way in this case.
      * Let's try to create a socket the old way in this case.
      *
      *
@@ -1318,11 +1383,12 @@ socket_t create_socket(const char *host, int port, Fn fn,
      * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa
      * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa
      *
      *
      * WSA_FLAG_NO_HANDLE_INHERIT:
      * WSA_FLAG_NO_HANDLE_INHERIT:
-     * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with SP1, and later
+     * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with
+     * SP1, and later
      *
      *
      */
      */
     if (sock == INVALID_SOCKET) {
     if (sock == INVALID_SOCKET) {
-        sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
+      sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
     }
     }
 #else
 #else
     auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
     auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
@@ -1880,17 +1946,12 @@ write_content_chunked(Stream &strm,
 template <typename T>
 template <typename T>
 inline bool redirect(T &cli, const Request &req, Response &res,
 inline bool redirect(T &cli, const Request &req, Response &res,
                      const std::string &path) {
                      const std::string &path) {
-  Request new_req;
-  new_req.method = req.method;
+  Request new_req = req;
   new_req.path = path;
   new_req.path = path;
-  new_req.headers = req.headers;
-  new_req.body = req.body;
-  new_req.redirect_count = req.redirect_count - 1;
-  new_req.response_handler = req.response_handler;
-  new_req.content_receiver = req.content_receiver;
-  new_req.progress = req.progress;
+  new_req.redirect_count -= 1;
 
 
   Response new_res;
   Response new_res;
+
   auto ret = cli.send(new_req, new_res);
   auto ret = cli.send(new_req, new_res);
   if (ret) { res = new_res; }
   if (ret) { res = new_res; }
   return ret;
   return ret;
@@ -2416,16 +2477,17 @@ inline std::pair<std::string, std::string> make_range_header(Ranges ranges) {
 
 
 inline std::pair<std::string, std::string>
 inline std::pair<std::string, std::string>
 make_basic_authentication_header(const std::string &username,
 make_basic_authentication_header(const std::string &username,
-                                 const std::string &password) {
+                                 const std::string &password, bool proxy = false) {
   auto field = "Basic " + detail::base64_encode(username + ":" + password);
   auto field = "Basic " + detail::base64_encode(username + ":" + password);
-  return std::make_pair("Authorization", field);
+  auto key = proxy ? "Proxy-Authorization" : "Authorization";
+  return std::make_pair(key, field);
 }
 }
 
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 inline std::pair<std::string, std::string> make_digest_authentication_header(
 inline std::pair<std::string, std::string> make_digest_authentication_header(
     const Request &req, const std::map<std::string, std::string> &auth,
     const Request &req, const std::map<std::string, std::string> &auth,
     size_t cnonce_count, const std::string &cnonce, const std::string &username,
     size_t cnonce_count, const std::string &cnonce, const std::string &username,
-    const std::string &password) {
+    const std::string &password, bool proxy = false) {
   using namespace std;
   using namespace std;
 
 
   string nc;
   string nc;
@@ -2442,10 +2504,11 @@ inline std::pair<std::string, std::string> make_digest_authentication_header(
     qop = "auth";
     qop = "auth";
   }
   }
 
 
+  std::string algo = "MD5";
+  if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); }
+
   string response;
   string response;
   {
   {
-    auto algo = auth.at("algorithm");
-
     auto H = algo == "SHA-256"
     auto H = algo == "SHA-256"
                  ? detail::SHA_256
                  ? detail::SHA_256
                  : algo == "SHA-512" ? detail::SHA_512 : detail::MD5;
                  : algo == "SHA-512" ? detail::SHA_512 : detail::MD5;
@@ -2461,25 +2524,26 @@ inline std::pair<std::string, std::string> make_digest_authentication_header(
 
 
   auto field = "Digest username=\"hello\", realm=\"" + auth.at("realm") +
   auto field = "Digest username=\"hello\", realm=\"" + auth.at("realm") +
                "\", nonce=\"" + auth.at("nonce") + "\", uri=\"" + req.path +
                "\", nonce=\"" + auth.at("nonce") + "\", uri=\"" + req.path +
-               "\", algorithm=" + auth.at("algorithm") + ", qop=" + qop +
-               ", nc=\"" + nc + "\", cnonce=\"" + cnonce + "\", response=\"" +
-               response + "\"";
+               "\", algorithm=" + algo + ", qop=" + qop + ", nc=\"" + nc +
+               "\", cnonce=\"" + cnonce + "\", response=\"" + response + "\"";
 
 
-  return make_pair("Authorization", field);
+  auto key = proxy ? "Proxy-Authorization" : "Authorization";
+  return std::make_pair(key, field);
 }
 }
 #endif
 #endif
 
 
-inline int
-parse_www_authenticate(const httplib::Response &res,
-                       std::map<std::string, std::string> &digest_auth) {
-  if (res.has_header("WWW-Authenticate")) {
+inline bool parse_www_authenticate(const httplib::Response &res,
+                                   std::map<std::string, std::string> &auth,
+                                   bool proxy) {
+  auto key = proxy ? "Proxy-Authenticate" : "WWW-Authenticate";
+  if (res.has_header(key)) {
     static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~");
     static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~");
-    auto s = res.get_header_value("WWW-Authenticate");
+    auto s = res.get_header_value(key);
     auto pos = s.find(' ');
     auto pos = s.find(' ');
     if (pos != std::string::npos) {
     if (pos != std::string::npos) {
       auto type = s.substr(0, pos);
       auto type = s.substr(0, pos);
       if (type == "Basic") {
       if (type == "Basic") {
-        return 1;
+        return false;
       } else if (type == "Digest") {
       } else if (type == "Digest") {
         s = s.substr(pos + 1);
         s = s.substr(pos + 1);
         auto beg = std::sregex_iterator(s.begin(), s.end(), re);
         auto beg = std::sregex_iterator(s.begin(), s.end(), re);
@@ -2488,13 +2552,13 @@ parse_www_authenticate(const httplib::Response &res,
           auto key = s.substr(m.position(1), m.length(1));
           auto key = s.substr(m.position(1), m.length(1));
           auto val = m.length(2) > 0 ? s.substr(m.position(2), m.length(2))
           auto val = m.length(2) > 0 ? s.substr(m.position(2), m.length(2))
                                      : s.substr(m.position(3), m.length(3));
                                      : s.substr(m.position(3), m.length(3));
-          digest_auth[key] = val;
+          auth[key] = val;
         }
         }
-        return 2;
+        return true;
       }
       }
     }
     }
   }
   }
-  return 0;
+  return false;
 }
 }
 
 
 // https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240
 // https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240
@@ -3377,15 +3441,22 @@ inline bool Server::process_and_close_socket(socket_t sock) {
 }
 }
 
 
 // HTTP client implementation
 // HTTP client implementation
-inline Client::Client(const char *host, int port)
+inline Client::Client(const std::string &host, int port,
+                      const std::string &client_cert_path,
+                      const std::string &client_key_path)
     : host_(host), port_(port),
     : host_(host), port_(port),
-      host_and_port_(host_ + ":" + std::to_string(port_)) {}
+      host_and_port_(host_ + ":" + std::to_string(port_)),
+      client_cert_path_(client_cert_path), client_key_path_(client_key_path) {}
 
 
 inline Client::~Client() {}
 inline Client::~Client() {}
 
 
 inline bool Client::is_valid() const { return true; }
 inline bool Client::is_valid() const { return true; }
 
 
 inline socket_t Client::create_client_socket() const {
 inline socket_t Client::create_client_socket() const {
+  if (!proxy_host_.empty()) {
+    return detail::create_client_socket(proxy_host_.c_str(), proxy_port_,
+                                        timeout_sec_, interface_);
+  }
   return detail::create_client_socket(host_.c_str(), port_, timeout_sec_,
   return detail::create_client_socket(host_.c_str(), port_, timeout_sec_,
                                       interface_);
                                       interface_);
 }
 }
@@ -3414,54 +3485,97 @@ inline bool Client::send(const Request &req, Response &res) {
   auto sock = create_client_socket();
   auto sock = create_client_socket();
   if (sock == INVALID_SOCKET) { return false; }
   if (sock == INVALID_SOCKET) { return false; }
 
 
-  auto ret = process_and_close_socket(
-      sock, 1, [&](Stream &strm, bool last_connection, bool &connection_close) {
-        return process_request(strm, req, res, last_connection,
-                               connection_close);
-      });
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+  // CONNECT
+  if (is_ssl() && !proxy_host_.empty()) {
+    Response res2;
+    if (!detail::process_socket(
+            true, sock, 1, read_timeout_sec_, read_timeout_usec_,
+            [&](Stream &strm, bool /*last_connection*/,
+                bool &connection_close) {
+              Request req2;
+              req2.method = "CONNECT";
+              req2.path = host_and_port_;
+              return process_request(strm, req2, res2, false, connection_close);
+            })) {
+      return false;
+    }
 
 
-  if (ret && follow_location_ && (300 < res.status && res.status < 400)) {
-    ret = redirect(req, res);
+    if (res2.status == 407 && !proxy_digest_auth_username_.empty() &&
+        !proxy_digest_auth_password_.empty()) {
+      std::map<std::string, std::string> auth;
+      if (parse_www_authenticate(res2, auth, true)) {
+        detail::close_socket(sock);
+        sock = create_client_socket();
+        if (sock == INVALID_SOCKET) { return false; }
+
+        Response res2;
+        if (!detail::process_socket(
+                true, sock, 1, read_timeout_sec_, read_timeout_usec_,
+                [&](Stream &strm, bool /*last_connection*/,
+                    bool &connection_close) {
+                  Request req2;
+                  req2.method = "CONNECT";
+                  req2.path = host_and_port_;
+                  req2.headers.insert(make_digest_authentication_header(
+                      req2, auth, 1, random_string(10),
+                      proxy_digest_auth_username_, proxy_digest_auth_password_,
+                      true));
+                  return process_request(strm, req2, res2, false,
+                                         connection_close);
+                })) {
+          return false;
+        }
+      }
+    }
   }
   }
+#endif
 
 
-  if (ret && !username_.empty() && !password_.empty() && res.status == 401) {
-    int type;
-    std::map<std::string, std::string> digest_auth;
+  if (!process_and_close_socket(
+          sock, 1,
+          [&](Stream &strm, bool last_connection, bool &connection_close) {
+            if (!is_ssl() && !proxy_host_.empty()) {
+              auto req2 = req;
+              req2.path = "http://" + host_and_port_ + req.path;
+              return process_request(strm, req2, res, last_connection,
+                                     connection_close);
+            }
+            return process_request(strm, req, res, last_connection,
+                                   connection_close);
+          })) {
+    return false;
+  }
 
 
-    if ((type = parse_www_authenticate(res, digest_auth)) > 0) {
-      std::pair<std::string, std::string> header;
+  if (300 < res.status && res.status < 400 && follow_location_) {
+    return redirect(req, res);
+  }
 
 
-      if (type == 1) {
-        header = make_basic_authentication_header(username_, password_);
-      } else if (type == 2) {
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
-        size_t cnonce_count = 1;
-        auto cnonce = random_string(10);
-
-        header = make_digest_authentication_header(
-            req, digest_auth, cnonce_count, cnonce, username_, password_);
-#endif
+  if (res.status == 401 || res.status == 407) {
+    auto is_proxy = res.status == 407;
+    const auto &username =
+        is_proxy ? proxy_digest_auth_username_ : digest_auth_username_;
+    const auto &password =
+        is_proxy ? proxy_digest_auth_password_ : digest_auth_password_;
+
+    if (!username.empty() && !password.empty()) {
+      std::map<std::string, std::string> auth;
+      if (parse_www_authenticate(res, auth, is_proxy)) {
+        Request new_req = req;
+        new_req.headers.insert(make_digest_authentication_header(
+            req, auth, 1, random_string(10), username, password, is_proxy));
+
+        Response new_res;
+
+        auto ret = send(new_req, new_res);
+        if (ret) { res = new_res; }
+        return ret;
       }
       }
-
-      Request new_req;
-      new_req.method = req.method;
-      new_req.path = req.path;
-      new_req.headers = req.headers;
-      new_req.body = req.body;
-      new_req.response_handler = req.response_handler;
-      new_req.content_receiver = req.content_receiver;
-      new_req.progress = req.progress;
-
-      new_req.headers.insert(header);
-
-      Response new_res;
-      auto ret = send(new_req, new_res);
-      if (ret) { res = new_res; }
-      return ret;
     }
     }
   }
   }
+#endif
 
 
-  return ret;
+  return true;
 }
 }
 
 
 inline bool Client::send(const std::vector<Request> &requests,
 inline bool Client::send(const std::vector<Request> &requests,
@@ -3511,28 +3625,30 @@ inline bool Client::redirect(const Request &req, Response &res) {
   std::smatch m;
   std::smatch m;
   if (!regex_match(location, m, re)) { return false; }
   if (!regex_match(location, m, re)) { return false; }
 
 
+  auto scheme = is_ssl() ? "https" : "http";
+
   auto next_scheme = m[1].str();
   auto next_scheme = m[1].str();
   auto next_host = m[2].str();
   auto next_host = m[2].str();
   auto next_path = m[3].str();
   auto next_path = m[3].str();
+  if (next_scheme.empty()) { next_scheme = scheme; }
+  if (next_scheme.empty()) { next_scheme = scheme; }
   if (next_host.empty()) { next_host = host_; }
   if (next_host.empty()) { next_host = host_; }
   if (next_path.empty()) { next_path = "/"; }
   if (next_path.empty()) { next_path = "/"; }
 
 
-  auto scheme = is_ssl() ? "https" : "http";
-
   if (next_scheme == scheme && next_host == host_) {
   if (next_scheme == scheme && next_host == host_) {
     return detail::redirect(*this, req, res, next_path);
     return detail::redirect(*this, req, res, next_path);
   } else {
   } else {
     if (next_scheme == "https") {
     if (next_scheme == "https") {
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
       SSLClient cli(next_host.c_str());
       SSLClient cli(next_host.c_str());
-      cli.set_follow_location(true);
+      cli.copy_settings(*this);
       return detail::redirect(cli, req, res, next_path);
       return detail::redirect(cli, req, res, next_path);
 #else
 #else
       return false;
       return false;
 #endif
 #endif
     } else {
     } else {
       Client cli(next_host.c_str());
       Client cli(next_host.c_str());
-      cli.set_follow_location(true);
+      cli.copy_settings(*this);
       return detail::redirect(cli, req, res, next_path);
       return detail::redirect(cli, req, res, next_path);
     }
     }
   }
   }
@@ -3544,7 +3660,7 @@ inline bool Client::write_request(Stream &strm, const Request &req,
 
 
   // Request line
   // Request line
   const static std::regex re(
   const static std::regex re(
-      R"(^([^:/?#]+://[^/?#]*)?([^?#]*(?:\?[^#]*)?(?:#.*)?))");
+      R"(^((?:[^:/?#]+://)?(?:[^/?#]*)?)?([^?#]*(?:\?[^#]*)?(?:#.*)?))");
 
 
   std::smatch m;
   std::smatch m;
   if (!regex_match(req.path, m, re)) { return false; }
   if (!regex_match(req.path, m, re)) { return false; }
@@ -3597,6 +3713,17 @@ inline bool Client::write_request(Stream &strm, const Request &req,
     }
     }
   }
   }
 
 
+  if (!basic_auth_username_.empty() && !basic_auth_password_.empty()) {
+    headers.insert(make_basic_authentication_header(
+        basic_auth_username_, basic_auth_password_, false));
+  }
+
+  if (!proxy_basic_auth_username_.empty() &&
+      !proxy_basic_auth_password_.empty()) {
+    headers.insert(make_basic_authentication_header(
+        proxy_basic_auth_username_, proxy_basic_auth_password_, true));
+  }
+
   detail::write_headers(bstrm, req, headers);
   detail::write_headers(bstrm, req, headers);
 
 
   // Flush buffer
   // Flush buffer
@@ -3689,7 +3816,7 @@ inline bool Client::process_request(Stream &strm, const Request &req,
   }
   }
 
 
   // Body
   // Body
-  if (req.method != "HEAD") {
+  if (req.method != "HEAD" && req.method != "CONNECT") {
     ContentReceiver out = [&](const char *buf, size_t n) {
     ContentReceiver out = [&](const char *buf, size_t n) {
       if (res.body.size() + n > res.body.max_size()) { return false; }
       if (res.body.size() + n > res.body.max_size()) { return false; }
       res.body.append(buf, n);
       res.body.append(buf, n);
@@ -3709,6 +3836,9 @@ inline bool Client::process_request(Stream &strm, const Request &req,
     }
     }
   }
   }
 
 
+  // Log
+  if (logger_) { logger_(req, res); }
+
   return true;
   return true;
 }
 }
 
 
@@ -4010,18 +4140,24 @@ inline void Client::set_timeout_sec(time_t timeout_sec) {
   timeout_sec_ = timeout_sec;
   timeout_sec_ = timeout_sec;
 }
 }
 
 
+inline void Client::set_read_timeout(time_t sec, time_t usec) {
+  read_timeout_sec_ = sec;
+  read_timeout_usec_ = usec;
+}
+
 inline void Client::set_keep_alive_max_count(size_t count) {
 inline void Client::set_keep_alive_max_count(size_t count) {
   keep_alive_max_count_ = count;
   keep_alive_max_count_ = count;
 }
 }
 
 
-inline void Client::set_read_timeout(time_t sec, time_t usec) {
-  read_timeout_sec_ = sec;
-  read_timeout_usec_ = usec;
+inline void Client::set_basic_auth(const char *username, const char *password) {
+  basic_auth_username_ = username;
+  basic_auth_password_ = password;
 }
 }
 
 
-inline void Client::set_auth(const char *username, const char *password) {
-  username_ = username;
-  password_ = password;
+inline void Client::set_digest_auth(const char *username,
+                                    const char *password) {
+  digest_auth_username_ = username;
+  digest_auth_password_ = password;
 }
 }
 
 
 inline void Client::set_follow_location(bool on) { follow_location_ = on; }
 inline void Client::set_follow_location(bool on) { follow_location_ = on; }
@@ -4030,6 +4166,25 @@ inline void Client::set_compress(bool on) { compress_ = on; }
 
 
 inline void Client::set_interface(const char *intf) { interface_ = intf; }
 inline void Client::set_interface(const char *intf) { interface_ = intf; }
 
 
+inline void Client::set_proxy(const char *host, int port) {
+  proxy_host_ = host;
+  proxy_port_ = port;
+}
+
+inline void Client::set_proxy_basic_auth(const char *username,
+                                         const char *password) {
+  proxy_basic_auth_username_ = username;
+  proxy_basic_auth_password_ = password;
+}
+
+inline void Client::set_proxy_digest_auth(const char *username,
+                                          const char *password) {
+  proxy_digest_auth_username_ = username;
+  proxy_digest_auth_password_ = password;
+}
+
+inline void Client::set_logger(Logger logger) { logger_ = std::move(logger); }
+
 /*
 /*
  * SSL Implementation
  * SSL Implementation
  */
  */
@@ -4249,21 +4404,21 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) {
 }
 }
 
 
 // SSL HTTP client implementation
 // SSL HTTP client implementation
-inline SSLClient::SSLClient(const char *host, int port,
-                            const char *client_cert_path,
-                            const char *client_key_path)
-    : Client(host, port) {
+inline SSLClient::SSLClient(const std::string &host, int port,
+                            const std::string &client_cert_path,
+                            const std::string &client_key_path)
+    : Client(host, port, client_cert_path, client_key_path) {
   ctx_ = SSL_CTX_new(SSLv23_client_method());
   ctx_ = SSL_CTX_new(SSLv23_client_method());
 
 
   detail::split(&host_[0], &host_[host_.size()], '.',
   detail::split(&host_[0], &host_[host_.size()], '.',
                 [&](const char *b, const char *e) {
                 [&](const char *b, const char *e) {
                   host_components_.emplace_back(std::string(b, e));
                   host_components_.emplace_back(std::string(b, e));
                 });
                 });
-  if (client_cert_path && client_key_path) {
-    if (SSL_CTX_use_certificate_file(ctx_, client_cert_path,
+  if (!client_cert_path.empty() && !client_key_path.empty()) {
+    if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(),
                                      SSL_FILETYPE_PEM) != 1 ||
                                      SSL_FILETYPE_PEM) != 1 ||
-        SSL_CTX_use_PrivateKey_file(ctx_, client_key_path, SSL_FILETYPE_PEM) !=
-            1) {
+        SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(),
+                                    SSL_FILETYPE_PEM) != 1) {
       SSL_CTX_free(ctx_);
       SSL_CTX_free(ctx_);
       ctx_ = nullptr;
       ctx_ = nullptr;
     }
     }

+ 30 - 2
test/test.cc

@@ -474,13 +474,27 @@ TEST(BaseAuthTest, FromHTTPWatch) {
   }
   }
 
 
   {
   {
-    cli.set_auth("hello", "world");
+    cli.set_basic_auth("hello", "world");
     auto res = cli.Get("/basic-auth/hello/world");
     auto res = cli.Get("/basic-auth/hello/world");
     ASSERT_TRUE(res != nullptr);
     ASSERT_TRUE(res != nullptr);
     EXPECT_EQ(res->body,
     EXPECT_EQ(res->body,
               "{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n");
               "{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n");
     EXPECT_EQ(200, res->status);
     EXPECT_EQ(200, res->status);
   }
   }
+
+  {
+    cli.set_basic_auth("hello", "bad");
+    auto res = cli.Get("/basic-auth/hello/world");
+    ASSERT_TRUE(res != nullptr);
+    EXPECT_EQ(401, res->status);
+  }
+
+  {
+    cli.set_basic_auth("bad", "world");
+    auto res = cli.Get("/basic-auth/hello/world");
+    ASSERT_TRUE(res != nullptr);
+    EXPECT_EQ(401, res->status);
+  }
 }
 }
 
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
@@ -504,7 +518,7 @@ TEST(DigestAuthTest, FromHTTPWatch) {
         "/digest-auth/auth-int/hello/world/MD5",
         "/digest-auth/auth-int/hello/world/MD5",
     };
     };
 
 
-    cli.set_auth("hello", "world");
+    cli.set_digest_auth("hello", "world");
     for (auto path : paths) {
     for (auto path : paths) {
       auto res = cli.Get(path.c_str());
       auto res = cli.Get(path.c_str());
       ASSERT_TRUE(res != nullptr);
       ASSERT_TRUE(res != nullptr);
@@ -512,6 +526,20 @@ TEST(DigestAuthTest, FromHTTPWatch) {
                 "{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n");
                 "{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n");
       EXPECT_EQ(200, res->status);
       EXPECT_EQ(200, res->status);
     }
     }
+
+    cli.set_digest_auth("hello", "bad");
+    for (auto path : paths) {
+      auto res = cli.Get(path.c_str());
+      ASSERT_TRUE(res != nullptr);
+      EXPECT_EQ(400, res->status);
+    }
+
+    cli.set_digest_auth("bad", "world");
+    for (auto path : paths) {
+      auto res = cli.Get(path.c_str());
+      ASSERT_TRUE(res != nullptr);
+      EXPECT_EQ(400, res->status);
+    }
   }
   }
 }
 }
 #endif
 #endif