Yuji Hirose 6 years ago
parent
commit
fd4e1b4112
4 changed files with 234 additions and 6 deletions
  1. 10 5
      README.md
  2. 1 1
      example/Makefile
  3. 181 0
      httplib.h
  4. 42 0
      test/test.cc

+ 10 - 5
README.md

@@ -324,16 +324,21 @@ std::shared_ptr<httplib::Response> res =
 
 This feature was contributed by [underscorediscovery](https://github.com/yhirose/cpp-httplib/pull/23).
 
-### Basic Authentication
+### Authentication
 
 ```cpp
 httplib::Client cli("httplib.org");
+cli.set_auth("user", "pass");
 
-auto res = cli.Get("/basic-auth/hello/world", {
-  httplib::make_basic_authentication_header("hello", "world")
-});
+// Basic
+auto res = cli.Get("/basic-auth/user/pass");
+// res->status should be 200
+// res->body should be "{\n  \"authenticated\": true, \n  \"user\": \"user\"\n}\n".
+
+// Digest
+res = cli.Get("/digest-auth/auth/user/pass/SHA-256");
 // res->status should be 200
-// res->body should be "{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n".
+// res->body should be "{\n  \"authenticated\": true, \n  \"user\": \"user\"\n}\n".
 ```
 
 ### Range

+ 1 - 1
example/Makefile

@@ -33,4 +33,4 @@ pem:
 	openssl req -new -key key.pem | openssl x509 -days 3650 -req -signkey key.pem > cert.pem
 
 clean:
-	rm server client hello simplesvr upload redirect *.pem
+	rm server client hello simplesvr upload redirect benchmark *.pem

+ 181 - 0
httplib.h

@@ -149,9 +149,13 @@ using socket_t = int;
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 #include <openssl/err.h>
+#include <openssl/md5.h>
 #include <openssl/ssl.h>
 #include <openssl/x509v3.h>
 
+#include <iomanip>
+#include <sstream>
+
 // #if OPENSSL_VERSION_NUMBER < 0x1010100fL
 // #error Sorry, OpenSSL versions prior to 1.1.1 are not supported
 // #endif
@@ -756,10 +760,13 @@ public:
             std::vector<Response> &responses);
 
   void set_keep_alive_max_count(size_t count);
+
   void set_read_timeout(time_t sec, time_t usec);
 
   void follow_location(bool on);
 
+  void set_auth(const char *username, const char *password);
+
 protected:
   bool process_request(Stream &strm, const Request &req, Response &res,
                        bool last_connection, bool &connection_close);
@@ -772,6 +779,8 @@ protected:
   time_t read_timeout_sec_;
   time_t read_timeout_usec_;
   size_t follow_location_;
+  std::string username_;
+  std::string password_;
 
 private:
   socket_t create_client_socket() const;
@@ -1439,6 +1448,7 @@ inline const char *status_message(int status) {
   case 303: return "See Other";
   case 304: return "Not Modified";
   case 400: return "Bad Request";
+  case 401: return "Unauthorized";
   case 403: return "Forbidden";
   case 404: return "Not Found";
   case 413: return "Payload Too Large";
@@ -2287,6 +2297,43 @@ inline bool expect_content(const Request &req) {
   return false;
 }
 
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+template <typename CTX, typename Init, typename Update, typename Final>
+inline std::string message_digest(const std::string &s, Init init,
+                                  Update update, Final final,
+                                  size_t digest_length) {
+  using namespace std;
+
+  unsigned char md[digest_length];
+  CTX ctx;
+  init(&ctx);
+  update(&ctx, s.data(), s.size());
+  final(md, &ctx);
+
+  stringstream ss;
+  for (auto c : md) {
+    ss << setfill('0') << setw(2) << hex << (unsigned int)c;
+  }
+  return ss.str();
+}
+
+inline std::string MD5(const std::string &s) {
+  using namespace detail;
+  return message_digest<MD5_CTX>(s, MD5_Init, MD5_Update, MD5_Final,
+                                 MD5_DIGEST_LENGTH);
+}
+
+inline std::string SHA_256(const std::string &s) {
+  return message_digest<SHA256_CTX>(s, SHA256_Init, SHA256_Update, SHA256_Final,
+                                    SHA256_DIGEST_LENGTH);
+}
+
+inline std::string SHA_512(const std::string &s) {
+  return message_digest<SHA512_CTX>(s, SHA512_Init, SHA512_Update, SHA512_Final,
+                                    SHA512_DIGEST_LENGTH);
+}
+#endif
+
 #ifdef _WIN32
 class WSInit {
 public:
@@ -2324,6 +2371,98 @@ make_basic_authentication_header(const std::string &username,
   return std::make_pair("Authorization", field);
 }
 
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+inline std::pair<std::string, std::string> make_digest_authentication_header(
+    const Request &req,
+    const std::map<std::string, std::string> &auth,
+    size_t cnonce_count, const std::string &cnonce,
+    const std::string &username, const std::string &password) {
+  using namespace std;
+
+  string nc;
+  {
+    stringstream ss;
+    ss << setfill('0') << setw(8) << hex << cnonce_count;
+    nc = ss.str();
+  }
+
+  auto qop = auth.at("qop");
+  if (qop.find("auth-int") != std::string::npos) {
+    qop = "auth-int";
+  } else {
+    qop = "auth";
+  }
+
+  string response;
+  {
+    auto algo = auth.at("algorithm");
+
+    auto H = algo == "SHA-256"
+                 ? detail::SHA_256
+                 : algo == "SHA-512" ? detail::SHA_512 : detail::MD5;
+
+    auto A1 = username + ":" + auth.at("realm") + ":" + password;
+
+    auto A2 = req.method + ":" + req.path;
+    if (qop == "auth-int") {
+      A2 += ":" + H(req.body);
+    }
+
+    response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce +
+                 ":" + qop + ":" + H(A2));
+  }
+
+  auto field = "Digest username=\"hello\", realm=\"" + auth.at("realm") +
+               "\", nonce=\"" + auth.at("nonce") + "\", uri=\"" + req.path +
+               "\", algorithm=" + auth.at("algorithm") + ", qop=" + qop + ", nc=\"" +
+               nc + "\", cnonce=\"" + cnonce + "\", response=\"" + response +
+               "\"";
+
+  return make_pair("Authorization", field);
+}
+#endif
+
+inline int parse_www_authenticate(const httplib::Response &res,
+                            std::map<std::string, std::string> &digest_auth) {
+  if (res.has_header("WWW-Authenticate")) {
+    static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*)))))~");
+    auto s = res.get_header_value("WWW-Authenticate");
+    auto pos = s.find(' ');
+    if (pos != std::string::npos) {
+      auto type = s.substr(0, pos);
+      if (type == "Basic") {
+        return 1;
+      } else if (type == "Digest") {
+        s = s.substr(pos + 1);
+        auto beg = std::sregex_iterator(s.begin(), s.end(), re);
+        for (auto i = beg; i != std::sregex_iterator(); ++i) {
+          auto m = *i;
+          auto key = s.substr(m.position(1), m.length(1));
+          auto val = m.length(2) > 0 ? s.substr(m.position(2), m.length(2))
+                                     : s.substr(m.position(3), m.length(3));
+          digest_auth[key] = val;
+        }
+        return 2;
+      }
+    }
+  }
+  return 0;
+}
+
+// https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240
+inline std::string random_string(size_t length) {
+  auto randchar = []() -> char {
+    const char charset[] = "0123456789"
+                           "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+                           "abcdefghijklmnopqrstuvwxyz";
+    const size_t max_index = (sizeof(charset) - 1);
+    return charset[rand() % max_index];
+  };
+  std::string str(length, 0);
+  std::generate_n(str.begin(), length, randchar);
+  return str;
+}
+
 // Request implementation
 inline bool Request::has_header(const char *key) const {
   return detail::has_header(headers, key);
@@ -3244,6 +3383,43 @@ inline bool Client::send(const Request &req, Response &res) {
     ret = redirect(req, res);
   }
 
+  if (ret && !username_.empty() && !password_.empty() && res.status == 401) {
+    int type;
+    std::map<std::string, std::string> digest_auth;
+
+    if ((type = parse_www_authenticate(res, digest_auth)) > 0) {
+      std::pair<std::string, std::string> header;
+
+      if (type == 1) {
+        header = make_basic_authentication_header(username_, password_);
+      } else if (type == 2) {
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+        size_t cnonce_count = 1;
+        auto cnonce = random_string(10);
+
+        header = make_digest_authentication_header(
+            req, digest_auth, cnonce_count, cnonce, username_, password_);
+#endif
+      }
+
+      Request new_req;
+      new_req.method = req.method;
+      new_req.path = req.path;
+      new_req.headers = req.headers;
+      new_req.body = req.body;
+      new_req.response_handler = req.response_handler;
+      new_req.content_receiver = req.content_receiver;
+      new_req.progress = req.progress;
+
+      new_req.headers.insert(header);
+
+      Response new_res;
+      auto ret = send(new_req, new_res);
+      if (ret) { res = new_res; }
+      return ret;
+    }
+  }
+
   return ret;
 }
 
@@ -3810,6 +3986,11 @@ inline void Client::set_read_timeout(time_t sec, time_t usec) {
 
 inline void Client::follow_location(bool on) { follow_location_ = on; }
 
+inline void Client::set_auth(const char *username, const char *password) {
+  username_ = username;
+  password_ = password;
+}
+
 /*
  * SSL Implementation
  */

+ 42 - 0
test/test.cc

@@ -469,7 +469,49 @@ TEST(BaseAuthTest, FromHTTPWatch) {
               "{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n");
     EXPECT_EQ(200, res->status);
   }
+
+  {
+    cli.set_auth("hello", "world");
+    auto res = cli.Get("/basic-auth/hello/world");
+    ASSERT_TRUE(res != nullptr);
+    EXPECT_EQ(res->body,
+              "{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n");
+    EXPECT_EQ(200, res->status);
+  }
+}
+
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+TEST(DigestAuthTest, FromHTTPWatch) {
+  auto host = "httpbin.org";
+  auto port = 443;
+  httplib::SSLClient cli(host, port);
+
+  {
+    auto res = cli.Get("/digest-auth/auth/hello/world");
+    ASSERT_TRUE(res != nullptr);
+    EXPECT_EQ(401, res->status);
+  }
+
+  {
+    std::vector<std::string> paths = {
+      "/digest-auth/auth/hello/world/MD5",
+      "/digest-auth/auth/hello/world/SHA-256",
+      "/digest-auth/auth/hello/world/SHA-512",
+      "/digest-auth/auth-init/hello/world/MD5",
+      "/digest-auth/auth-int/hello/world/MD5",
+    };
+
+    cli.set_auth("hello", "world");
+    for (auto path: paths) {
+      auto res = cli.Get(path.c_str());
+      ASSERT_TRUE(res != nullptr);
+      EXPECT_EQ(res->body,
+                "{\n  \"authenticated\": true, \n  \"user\": \"hello\"\n}\n");
+      EXPECT_EQ(200, res->status);
+    }
+  }
 }
+#endif
 
 TEST(AbsoluteRedirectTest, Redirect) {
   auto host = "httpbin.org";