yhirose 11 months ago
parent
commit
ba6845925d
2 changed files with 67 additions and 11 deletions
  1. 26 4
      httplib.h
  2. 41 7
      test/test.cc

+ 26 - 4
httplib.h

@@ -2012,18 +2012,34 @@ inline void duration_to_sec_and_usec(const T &duration, U callback) {
   callback(static_cast<time_t>(sec), static_cast<time_t>(usec));
 }
 
+inline bool is_numeric(const std::string &str) {
+  return !str.empty() && std::all_of(str.begin(), str.end(), ::isdigit);
+}
+
 inline uint64_t get_header_value_u64(const Headers &headers,
                                      const std::string &key, uint64_t def,
-                                     size_t id) {
+                                     size_t id, bool &is_invalid_value) {
+  is_invalid_value = false;
   auto rng = headers.equal_range(key);
   auto it = rng.first;
   std::advance(it, static_cast<ssize_t>(id));
   if (it != rng.second) {
-    return std::strtoull(it->second.data(), nullptr, 10);
+    if (is_numeric(it->second)) {
+      return std::strtoull(it->second.data(), nullptr, 10);
+    } else {
+      is_invalid_value = true;
+    }
   }
   return def;
 }
 
+inline uint64_t get_header_value_u64(const Headers &headers,
+                                     const std::string &key, uint64_t def,
+                                     size_t id) {
+  bool dummy = false;
+  return get_header_value_u64(headers, key, def, id, dummy);
+}
+
 } // namespace detail
 
 inline uint64_t Request::get_header_value_u64(const std::string &key,
@@ -4433,8 +4449,14 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status,
         } else if (!has_header(x.headers, "Content-Length")) {
           ret = read_content_without_length(strm, out);
         } else {
-          auto len = get_header_value_u64(x.headers, "Content-Length", 0, 0);
-          if (len > payload_max_length) {
+          auto is_invalid_value = false;
+          auto len = get_header_value_u64(x.headers, "Content-Length",
+                                          std::numeric_limits<uint64_t>::max(),
+                                          0, is_invalid_value);
+
+          if (is_invalid_value) {
+            ret = false;
+          } else if (len > payload_max_length) {
             exceed_payload_max_length = true;
             skip_content_with_length(strm, len);
             ret = false;

+ 41 - 7
test/test.cc

@@ -511,6 +511,15 @@ TEST(GetHeaderValueTest, RegularValueInt) {
   EXPECT_EQ(100ull, val);
 }
 
+TEST(GetHeaderValueTest, RegularInvalidValueInt) {
+  Headers headers = {{"Content-Length", "x"}};
+  auto is_invalid_value = false;
+  auto val = detail::get_header_value_u64(headers, "Content-Length", 0, 0,
+                                          is_invalid_value);
+  EXPECT_EQ(0ull, val);
+  EXPECT_TRUE(is_invalid_value);
+}
+
 TEST(GetHeaderValueTest, Range) {
   {
     Headers headers = {make_range_header({{1, -1}})};
@@ -7496,9 +7505,9 @@ TEST(MultipartFormDataTest, CloseDelimiterWithoutCRLF) {
              "text2"
              "\r\n------------";
 
-  std::string resonse;
-  ASSERT_TRUE(send_request(1, req, &resonse));
-  ASSERT_EQ("200", resonse.substr(9, 3));
+  std::string response;
+  ASSERT_TRUE(send_request(1, req, &response));
+  ASSERT_EQ("200", response.substr(9, 3));
 }
 
 TEST(MultipartFormDataTest, ContentLength) {
@@ -7543,11 +7552,10 @@ TEST(MultipartFormDataTest, ContentLength) {
              "text2"
              "\r\n------------\r\n";
 
-  std::string resonse;
-  ASSERT_TRUE(send_request(1, req, &resonse));
-  ASSERT_EQ("200", resonse.substr(9, 3));
+  std::string response;
+  ASSERT_TRUE(send_request(1, req, &response));
+  ASSERT_EQ("200", response.substr(9, 3));
 }
-
 #endif
 
 TEST(TaskQueueTest, IncreaseAtomicInteger) {
@@ -8007,6 +8015,32 @@ TEST(InvalidHeaderCharsTest, OnServer) {
   }
 }
 
+TEST(InvalidHeaderValueTest, InvalidContentLength) {
+  auto handled = false;
+
+  Server svr;
+  svr.Post("/test", [&](const Request &, Response &) { handled = true; });
+
+  thread t = thread([&] { svr.listen(HOST, PORT); });
+  auto se = detail::scope_exit([&] {
+    svr.stop();
+    t.join();
+    ASSERT_FALSE(svr.is_running());
+    ASSERT_FALSE(handled);
+  });
+
+  svr.wait_until_ready();
+
+  auto req = "POST /test HTTP/1.1\r\n"
+             "Content-Length: x\r\n"
+             "\r\n";
+
+  std::string response;
+  ASSERT_TRUE(send_request(1, req, &response));
+  ASSERT_EQ("HTTP/1.1 400 Bad Request",
+            response.substr(0, response.find("\r\n")));
+}
+
 #ifndef _WIN32
 TEST(Expect100ContinueTest, ServerClosesConnection) {
   static constexpr char reject[] = "Unauthorized";