yhirose 2 years ago
parent
commit
9bb3ca8169
2 changed files with 54 additions and 7 deletions
  1. 23 7
      httplib.h
  2. 31 0
      test/test.cc

+ 23 - 7
httplib.h

@@ -1823,7 +1823,8 @@ std::string params_to_query_str(const Params &params);
 
 void parse_query_text(const std::string &s, Params &params);
 
-bool parse_multipart_boundary(const std::string &content_type, std::string &boundary);
+bool parse_multipart_boundary(const std::string &content_type,
+                              std::string &boundary);
 
 bool parse_range_header(const std::string &s, Ranges &ranges);
 
@@ -3391,6 +3392,14 @@ inline const char *get_header_value(const Headers &headers,
   return def;
 }
 
+inline bool compare_case_ignore(const std::string &a, const std::string &b) {
+  if (a.size() != b.size()) { return false; }
+  for (size_t i = 0; i < b.size(); i++) {
+    if (::tolower(a[i]) != ::tolower(b[i])) { return false; }
+  }
+  return true;
+}
+
 template <typename T>
 inline bool parse_header(const char *beg, const char *end, T fn) {
   // Skip trailing spaces and tabs.
@@ -3414,7 +3423,11 @@ inline bool parse_header(const char *beg, const char *end, T fn) {
   }
 
   if (p < end) {
-    fn(std::string(beg, key_end), decode_url(std::string(p, end), false));
+    auto key = std::string(beg, key_end);
+    auto val = compare_case_ignore(key, "Location")
+                   ? std::string(p, end)
+                   : decode_url(std::string(p, end), false);
+    fn(std::move(key), std::move(val));
     return true;
   }
 
@@ -6463,11 +6476,11 @@ inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) {
     return false;
   }
 
-  auto location = detail::decode_url(res.get_header_value("location"), true);
+  auto location = res.get_header_value("location");
   if (location.empty()) { return false; }
 
   const static std::regex re(
-      R"((?:(https?):)?(?://(?:\[([\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*(?:\?[^#]*)?)(?:#.*)?)");
+      R"((?:(https?):)?(?://(?:\[([\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)");
 
   std::smatch m;
   if (!std::regex_match(location, m, re)) { return false; }
@@ -6479,6 +6492,7 @@ inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) {
   if (next_host.empty()) { next_host = m[3].str(); }
   auto port_str = m[4].str();
   auto next_path = m[5].str();
+  auto next_query = m[6].str();
 
   auto next_port = port_;
   if (!port_str.empty()) {
@@ -6491,22 +6505,24 @@ inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) {
   if (next_host.empty()) { next_host = host_; }
   if (next_path.empty()) { next_path = "/"; }
 
+  auto path = detail::decode_url(next_path, true) + next_query;
+
   if (next_scheme == scheme && next_host == host_ && next_port == port_) {
-    return detail::redirect(*this, req, res, next_path, location, error);
+    return detail::redirect(*this, req, res, path, location, error);
   } else {
     if (next_scheme == "https") {
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
       SSLClient cli(next_host.c_str(), next_port);
       cli.copy_settings(*this);
       if (ca_cert_store_) { cli.set_ca_cert_store(ca_cert_store_); }
-      return detail::redirect(cli, req, res, next_path, location, error);
+      return detail::redirect(cli, req, res, path, location, error);
 #else
       return false;
 #endif
     } else {
       ClientImpl cli(next_host.c_str(), next_port);
       cli.copy_settings(*this);
-      return detail::redirect(cli, req, res, next_path, location, error);
+      return detail::redirect(cli, req, res, path, location, error);
     }
   }
 }

+ 31 - 0
test/test.cc

@@ -5978,3 +5978,34 @@ TEST(TaskQueueTest, IncreaseAtomicInteger) {
   EXPECT_NO_THROW(task_queue->shutdown());
   EXPECT_EQ(number_of_task, count.load());
 }
+
+TEST(RedirectTest, RedirectToUrlWithQueryParameters) {
+  Server svr;
+
+  svr.Get("/", [](const Request & /*req*/, Response &res) {
+    res.set_redirect(R"(/hello?key=val%26key2%3Dval2)");
+  });
+
+  svr.Get("/hello", [](const Request &req, Response &res) {
+    res.set_content(req.get_param_value("key"), "text/plain");
+  });
+
+  auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
+
+  std::this_thread::sleep_for(std::chrono::seconds(1));
+
+  {
+    Client cli(HOST, PORT);
+    cli.set_follow_location(true);
+
+    auto res = cli.Get("/");
+    ASSERT_TRUE(res);
+    EXPECT_EQ(200, res->status);
+    EXPECT_EQ("val&key2=val2", res->body);
+  }
+
+  svr.stop();
+  thread.join();
+  ASSERT_FALSE(svr.is_running());
+}
+