yhirose 5 years ago
parent
commit
9ca1fa8b18
2 changed files with 111 additions and 76 deletions
  1. 97 74
      httplib.h
  2. 14 2
      test/test.cc

+ 97 - 74
httplib.h

@@ -349,6 +349,8 @@ struct Request {
 
   bool has_header(const char *key) const;
   std::string get_header_value(const char *key, size_t id = 0) const;
+  template <typename T>
+  T get_header_value(const char *key, size_t id = 0) const;
   size_t get_header_value_count(const char *key) const;
   void set_header(const char *key, const char *val);
   void set_header(const char *key, const std::string &val);
@@ -374,6 +376,8 @@ struct Response {
 
   bool has_header(const char *key) const;
   std::string get_header_value(const char *key, size_t id = 0) const;
+  template <typename T>
+  T get_header_value(const char *key, size_t id = 0) const;
   size_t get_header_value_count(const char *key) const;
   void set_header(const char *key, const char *val);
   void set_header(const char *key, const std::string &val);
@@ -1580,6 +1584,74 @@ inline bool is_valid_path(const std::string &path) {
   return true;
 }
 
+inline std::string encode_url(const std::string &s) {
+  std::string result;
+
+  for (size_t i = 0; s[i]; i++) {
+    switch (s[i]) {
+    case ' ': result += "%20"; break;
+    case '+': result += "%2B"; break;
+    case '\r': result += "%0D"; break;
+    case '\n': result += "%0A"; break;
+    case '\'': result += "%27"; break;
+    case ',': result += "%2C"; break;
+    // case ':': result += "%3A"; break; // ok? probably...
+    case ';': result += "%3B"; break;
+    default:
+      auto c = static_cast<uint8_t>(s[i]);
+      if (c >= 0x80) {
+        result += '%';
+        char hex[4];
+        auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c);
+        assert(len == 2);
+        result.append(hex, static_cast<size_t>(len));
+      } else {
+        result += s[i];
+      }
+      break;
+    }
+  }
+
+  return result;
+}
+
+inline std::string decode_url(const std::string &s,
+                              bool convert_plus_to_space) {
+  std::string result;
+
+  for (size_t i = 0; i < s.size(); i++) {
+    if (s[i] == '%' && i + 1 < s.size()) {
+      if (s[i + 1] == 'u') {
+        int val = 0;
+        if (from_hex_to_i(s, i + 2, 4, val)) {
+          // 4 digits Unicode codes
+          char buff[4];
+          size_t len = to_utf8(val, buff);
+          if (len > 0) { result.append(buff, len); }
+          i += 5; // 'u0000'
+        } else {
+          result += s[i];
+        }
+      } else {
+        int val = 0;
+        if (from_hex_to_i(s, i + 1, 2, val)) {
+          // 2 digits hex codes
+          result += static_cast<char>(val);
+          i += 2; // '00'
+        } else {
+          result += s[i];
+        }
+      }
+    } else if (convert_plus_to_space && s[i] == '+') {
+      result += ' ';
+    } else {
+      result += s[i];
+    }
+  }
+
+  return result;
+}
+
 inline void read_file(const std::string &path, std::string &out) {
   std::ifstream fs(path, std::ios_base::binary);
   fs.seekg(0, std::ios_base::end);
@@ -2379,10 +2451,18 @@ inline const char *get_header_value(const Headers &headers, const char *key,
   return def;
 }
 
-inline uint64_t get_header_value_uint64(const Headers &headers, const char *key,
-                                        uint64_t def = 0) {
-  auto it = headers.find(key);
-  if (it != headers.end()) {
+template <typename T>
+inline T get_header_value(const Headers & /*headers*/, const char * /*key*/,
+                          size_t /*id*/ = 0, uint64_t /*def*/ = 0) {}
+
+template <>
+inline uint64_t get_header_value<uint64_t>(const Headers &headers,
+                                           const char *key, size_t id,
+                                           uint64_t def) {
+  auto rng = headers.equal_range(key);
+  auto it = rng.first;
+  std::advance(it, static_cast<ssize_t>(id));
+  if (it != rng.second) {
     return std::strtoull(it->second.data(), nullptr, 10);
   }
   return def;
@@ -2404,7 +2484,8 @@ inline void parse_header(const char *beg, const char *end, Headers &headers) {
       while (p < end) {
         p++;
       }
-      headers.emplace(std::string(beg, key_end), std::string(val_begin, end));
+      headers.emplace(std::string(beg, key_end),
+                      decode_url(std::string(val_begin, end), true));
     }
   }
 }
@@ -2574,7 +2655,7 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status,
   } else if (!has_header(x.headers, "Content-Length")) {
     ret = read_content_without_length(strm, out);
   } else {
-    auto len = get_header_value_uint64(x.headers, "Content-Length", 0);
+    auto len = get_header_value<uint64_t>(x.headers, "Content-Length");
     if (len > payload_max_length) {
       exceed_payload_max_length = true;
       skip_content_with_length(strm, len);
@@ -2765,74 +2846,6 @@ inline bool redirect(T &cli, const Request &req, Response &res,
   return ret;
 }
 
-inline std::string encode_url(const std::string &s) {
-  std::string result;
-
-  for (size_t i = 0; s[i]; i++) {
-    switch (s[i]) {
-    case ' ': result += "%20"; break;
-    case '+': result += "%2B"; break;
-    case '\r': result += "%0D"; break;
-    case '\n': result += "%0A"; break;
-    case '\'': result += "%27"; break;
-    case ',': result += "%2C"; break;
-    // case ':': result += "%3A"; break; // ok? probably...
-    case ';': result += "%3B"; break;
-    default:
-      auto c = static_cast<uint8_t>(s[i]);
-      if (c >= 0x80) {
-        result += '%';
-        char hex[4];
-        auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c);
-        assert(len == 2);
-        result.append(hex, static_cast<size_t>(len));
-      } else {
-        result += s[i];
-      }
-      break;
-    }
-  }
-
-  return result;
-}
-
-inline std::string decode_url(const std::string &s,
-                              bool convert_plus_to_space) {
-  std::string result;
-
-  for (size_t i = 0; i < s.size(); i++) {
-    if (s[i] == '%' && i + 1 < s.size()) {
-      if (s[i + 1] == 'u') {
-        int val = 0;
-        if (from_hex_to_i(s, i + 2, 4, val)) {
-          // 4 digits Unicode codes
-          char buff[4];
-          size_t len = to_utf8(val, buff);
-          if (len > 0) { result.append(buff, len); }
-          i += 5; // 'u0000'
-        } else {
-          result += s[i];
-        }
-      } else {
-        int val = 0;
-        if (from_hex_to_i(s, i + 1, 2, val)) {
-          // 2 digits hex codes
-          result += static_cast<char>(val);
-          i += 2; // '00'
-        } else {
-          result += s[i];
-        }
-      }
-    } else if (convert_plus_to_space && s[i] == '+') {
-      result += ' ';
-    } else {
-      result += s[i];
-    }
-  }
-
-  return result;
-}
-
 inline std::string params_to_query_str(const Params &params) {
   std::string query;
 
@@ -3458,6 +3471,11 @@ inline std::string Request::get_header_value(const char *key, size_t id) const {
   return detail::get_header_value(headers, key, id, "");
 }
 
+template <typename T>
+inline T Request::get_header_value(const char *key, size_t id) const {
+  return detail::get_header_value<T>(headers, key, id, 0);
+}
+
 inline size_t Request::get_header_value_count(const char *key) const {
   auto r = headers.equal_range(key);
   return static_cast<size_t>(std::distance(r.first, r.second));
@@ -3517,6 +3535,11 @@ inline std::string Response::get_header_value(const char *key,
   return detail::get_header_value(headers, key, id, "");
 }
 
+template <typename T>
+inline T Response::get_header_value(const char *key, size_t id) const {
+  return detail::get_header_value<T>(headers, key, id, 0);
+}
+
 inline size_t Response::get_header_value_count(const char *key) const {
   auto r = headers.equal_range(key);
   return static_cast<size_t>(std::distance(r.first, r.second));

+ 14 - 2
test/test.cc

@@ -100,7 +100,8 @@ TEST(GetHeaderValueTest, DefaultValue) {
 
 TEST(GetHeaderValueTest, DefaultValueInt) {
   Headers headers = {{"Dummy", "Dummy"}};
-  auto val = detail::get_header_value_uint64(headers, "Content-Length", 100);
+  auto val =
+      detail::get_header_value<uint64_t>(headers, "Content-Length", 0, 100);
   EXPECT_EQ(100ull, val);
 }
 
@@ -112,7 +113,8 @@ TEST(GetHeaderValueTest, RegularValue) {
 
 TEST(GetHeaderValueTest, RegularValueInt) {
   Headers headers = {{"Content-Length", "100"}, {"Dummy", "Dummy"}};
-  auto val = detail::get_header_value_uint64(headers, "Content-Length", 0);
+  auto val =
+      detail::get_header_value<uint64_t>(headers, "Content-Length", 0, 0);
   EXPECT_EQ(100ull, val);
 }
 
@@ -716,6 +718,16 @@ TEST(RedirectToDifferentPort, Redirect) {
   ASSERT_FALSE(svr8080.is_running());
   ASSERT_FALSE(svr8081.is_running());
 }
+
+TEST(UrlWithSpace, Redirect) {
+  httplib::SSLClient cli("edge.forgecdn.net");
+  cli.set_follow_location(true);
+
+  auto res = cli.Get("/files/2595/310/Neat 1.4-17.jar");
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ(200, res->status);
+  EXPECT_EQ(18527, res->get_header_value<uint64_t>("Content-Length"));
+}
 #endif
 
 TEST(Server, BindDualStack) {