|
|
@@ -149,9 +149,13 @@ using socket_t = int;
|
|
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
|
#include <openssl/err.h>
|
|
|
+#include <openssl/md5.h>
|
|
|
#include <openssl/ssl.h>
|
|
|
#include <openssl/x509v3.h>
|
|
|
|
|
|
+#include <iomanip>
|
|
|
+#include <sstream>
|
|
|
+
|
|
|
// #if OPENSSL_VERSION_NUMBER < 0x1010100fL
|
|
|
// #error Sorry, OpenSSL versions prior to 1.1.1 are not supported
|
|
|
// #endif
|
|
|
@@ -756,10 +760,13 @@ public:
|
|
|
std::vector<Response> &responses);
|
|
|
|
|
|
void set_keep_alive_max_count(size_t count);
|
|
|
+
|
|
|
void set_read_timeout(time_t sec, time_t usec);
|
|
|
|
|
|
void follow_location(bool on);
|
|
|
|
|
|
+ void set_auth(const char *username, const char *password);
|
|
|
+
|
|
|
protected:
|
|
|
bool process_request(Stream &strm, const Request &req, Response &res,
|
|
|
bool last_connection, bool &connection_close);
|
|
|
@@ -772,6 +779,8 @@ protected:
|
|
|
time_t read_timeout_sec_;
|
|
|
time_t read_timeout_usec_;
|
|
|
size_t follow_location_;
|
|
|
+ std::string username_;
|
|
|
+ std::string password_;
|
|
|
|
|
|
private:
|
|
|
socket_t create_client_socket() const;
|
|
|
@@ -1439,6 +1448,7 @@ inline const char *status_message(int status) {
|
|
|
case 303: return "See Other";
|
|
|
case 304: return "Not Modified";
|
|
|
case 400: return "Bad Request";
|
|
|
+ case 401: return "Unauthorized";
|
|
|
case 403: return "Forbidden";
|
|
|
case 404: return "Not Found";
|
|
|
case 413: return "Payload Too Large";
|
|
|
@@ -2287,6 +2297,43 @@ inline bool expect_content(const Request &req) {
|
|
|
return false;
|
|
|
}
|
|
|
|
|
|
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
|
+template <typename CTX, typename Init, typename Update, typename Final>
|
|
|
+inline std::string message_digest(const std::string &s, Init init,
|
|
|
+ Update update, Final final,
|
|
|
+ size_t digest_length) {
|
|
|
+ using namespace std;
|
|
|
+
|
|
|
+ unsigned char md[digest_length];
|
|
|
+ CTX ctx;
|
|
|
+ init(&ctx);
|
|
|
+ update(&ctx, s.data(), s.size());
|
|
|
+ final(md, &ctx);
|
|
|
+
|
|
|
+ stringstream ss;
|
|
|
+ for (auto c : md) {
|
|
|
+ ss << setfill('0') << setw(2) << hex << (unsigned int)c;
|
|
|
+ }
|
|
|
+ return ss.str();
|
|
|
+}
|
|
|
+
|
|
|
+inline std::string MD5(const std::string &s) {
|
|
|
+ using namespace detail;
|
|
|
+ return message_digest<MD5_CTX>(s, MD5_Init, MD5_Update, MD5_Final,
|
|
|
+ MD5_DIGEST_LENGTH);
|
|
|
+}
|
|
|
+
|
|
|
+inline std::string SHA_256(const std::string &s) {
|
|
|
+ return message_digest<SHA256_CTX>(s, SHA256_Init, SHA256_Update, SHA256_Final,
|
|
|
+ SHA256_DIGEST_LENGTH);
|
|
|
+}
|
|
|
+
|
|
|
+inline std::string SHA_512(const std::string &s) {
|
|
|
+ return message_digest<SHA512_CTX>(s, SHA512_Init, SHA512_Update, SHA512_Final,
|
|
|
+ SHA512_DIGEST_LENGTH);
|
|
|
+}
|
|
|
+#endif
|
|
|
+
|
|
|
#ifdef _WIN32
|
|
|
class WSInit {
|
|
|
public:
|
|
|
@@ -2324,6 +2371,98 @@ make_basic_authentication_header(const std::string &username,
|
|
|
return std::make_pair("Authorization", field);
|
|
|
}
|
|
|
|
|
|
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
|
+inline std::pair<std::string, std::string> make_digest_authentication_header(
|
|
|
+ const Request &req,
|
|
|
+ const std::map<std::string, std::string> &auth,
|
|
|
+ size_t cnonce_count, const std::string &cnonce,
|
|
|
+ const std::string &username, const std::string &password) {
|
|
|
+ using namespace std;
|
|
|
+
|
|
|
+ string nc;
|
|
|
+ {
|
|
|
+ stringstream ss;
|
|
|
+ ss << setfill('0') << setw(8) << hex << cnonce_count;
|
|
|
+ nc = ss.str();
|
|
|
+ }
|
|
|
+
|
|
|
+ auto qop = auth.at("qop");
|
|
|
+ if (qop.find("auth-int") != std::string::npos) {
|
|
|
+ qop = "auth-int";
|
|
|
+ } else {
|
|
|
+ qop = "auth";
|
|
|
+ }
|
|
|
+
|
|
|
+ string response;
|
|
|
+ {
|
|
|
+ auto algo = auth.at("algorithm");
|
|
|
+
|
|
|
+ auto H = algo == "SHA-256"
|
|
|
+ ? detail::SHA_256
|
|
|
+ : algo == "SHA-512" ? detail::SHA_512 : detail::MD5;
|
|
|
+
|
|
|
+ auto A1 = username + ":" + auth.at("realm") + ":" + password;
|
|
|
+
|
|
|
+ auto A2 = req.method + ":" + req.path;
|
|
|
+ if (qop == "auth-int") {
|
|
|
+ A2 += ":" + H(req.body);
|
|
|
+ }
|
|
|
+
|
|
|
+ response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce +
|
|
|
+ ":" + qop + ":" + H(A2));
|
|
|
+ }
|
|
|
+
|
|
|
+ auto field = "Digest username=\"hello\", realm=\"" + auth.at("realm") +
|
|
|
+ "\", nonce=\"" + auth.at("nonce") + "\", uri=\"" + req.path +
|
|
|
+ "\", algorithm=" + auth.at("algorithm") + ", qop=" + qop + ", nc=\"" +
|
|
|
+ nc + "\", cnonce=\"" + cnonce + "\", response=\"" + response +
|
|
|
+ "\"";
|
|
|
+
|
|
|
+ return make_pair("Authorization", field);
|
|
|
+}
|
|
|
+#endif
|
|
|
+
|
|
|
+inline int parse_www_authenticate(const httplib::Response &res,
|
|
|
+ std::map<std::string, std::string> &digest_auth) {
|
|
|
+ if (res.has_header("WWW-Authenticate")) {
|
|
|
+ static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*)))))~");
|
|
|
+ auto s = res.get_header_value("WWW-Authenticate");
|
|
|
+ auto pos = s.find(' ');
|
|
|
+ if (pos != std::string::npos) {
|
|
|
+ auto type = s.substr(0, pos);
|
|
|
+ if (type == "Basic") {
|
|
|
+ return 1;
|
|
|
+ } else if (type == "Digest") {
|
|
|
+ s = s.substr(pos + 1);
|
|
|
+ auto beg = std::sregex_iterator(s.begin(), s.end(), re);
|
|
|
+ for (auto i = beg; i != std::sregex_iterator(); ++i) {
|
|
|
+ auto m = *i;
|
|
|
+ auto key = s.substr(m.position(1), m.length(1));
|
|
|
+ auto val = m.length(2) > 0 ? s.substr(m.position(2), m.length(2))
|
|
|
+ : s.substr(m.position(3), m.length(3));
|
|
|
+ digest_auth[key] = val;
|
|
|
+ }
|
|
|
+ return 2;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+// https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240
|
|
|
+inline std::string random_string(size_t length) {
|
|
|
+ auto randchar = []() -> char {
|
|
|
+ const char charset[] = "0123456789"
|
|
|
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
|
|
+ "abcdefghijklmnopqrstuvwxyz";
|
|
|
+ const size_t max_index = (sizeof(charset) - 1);
|
|
|
+ return charset[rand() % max_index];
|
|
|
+ };
|
|
|
+ std::string str(length, 0);
|
|
|
+ std::generate_n(str.begin(), length, randchar);
|
|
|
+ return str;
|
|
|
+}
|
|
|
+
|
|
|
// Request implementation
|
|
|
inline bool Request::has_header(const char *key) const {
|
|
|
return detail::has_header(headers, key);
|
|
|
@@ -3244,6 +3383,43 @@ inline bool Client::send(const Request &req, Response &res) {
|
|
|
ret = redirect(req, res);
|
|
|
}
|
|
|
|
|
|
+ if (ret && !username_.empty() && !password_.empty() && res.status == 401) {
|
|
|
+ int type;
|
|
|
+ std::map<std::string, std::string> digest_auth;
|
|
|
+
|
|
|
+ if ((type = parse_www_authenticate(res, digest_auth)) > 0) {
|
|
|
+ std::pair<std::string, std::string> header;
|
|
|
+
|
|
|
+ if (type == 1) {
|
|
|
+ header = make_basic_authentication_header(username_, password_);
|
|
|
+ } else if (type == 2) {
|
|
|
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
|
+ size_t cnonce_count = 1;
|
|
|
+ auto cnonce = random_string(10);
|
|
|
+
|
|
|
+ header = make_digest_authentication_header(
|
|
|
+ req, digest_auth, cnonce_count, cnonce, username_, password_);
|
|
|
+#endif
|
|
|
+ }
|
|
|
+
|
|
|
+ Request new_req;
|
|
|
+ new_req.method = req.method;
|
|
|
+ new_req.path = req.path;
|
|
|
+ new_req.headers = req.headers;
|
|
|
+ new_req.body = req.body;
|
|
|
+ new_req.response_handler = req.response_handler;
|
|
|
+ new_req.content_receiver = req.content_receiver;
|
|
|
+ new_req.progress = req.progress;
|
|
|
+
|
|
|
+ new_req.headers.insert(header);
|
|
|
+
|
|
|
+ Response new_res;
|
|
|
+ auto ret = send(new_req, new_res);
|
|
|
+ if (ret) { res = new_res; }
|
|
|
+ return ret;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
return ret;
|
|
|
}
|
|
|
|
|
|
@@ -3810,6 +3986,11 @@ inline void Client::set_read_timeout(time_t sec, time_t usec) {
|
|
|
|
|
|
inline void Client::follow_location(bool on) { follow_location_ = on; }
|
|
|
|
|
|
+inline void Client::set_auth(const char *username, const char *password) {
|
|
|
+ username_ = username;
|
|
|
+ password_ = password;
|
|
|
+}
|
|
|
+
|
|
|
/*
|
|
|
* SSL Implementation
|
|
|
*/
|