Browse Source

Don't allow invalid status code format (It sould be a three-digit code.)

yhirose 5 years ago
parent
commit
7c1c952f5a
2 changed files with 34 additions and 8 deletions
  1. 9 8
      httplib.h
  2. 25 0
      test/test.cc

+ 9 - 8
httplib.h

@@ -1025,7 +1025,7 @@ protected:
 
 private:
   socket_t create_client_socket(Error &error) const;
-  bool read_response_line(Stream &strm, Response &res);
+  bool read_response_line(Stream &strm, const Request &req, Response &res);
   bool write_request(Stream &strm, const Request &req, bool close_connection,
                      Error &error);
   bool redirect(const Request &req, Response &res, Error &error);
@@ -4947,17 +4947,20 @@ inline void ClientImpl::lock_socket_and_shutdown_and_close() {
   close_socket(socket_);
 }
 
-inline bool ClientImpl::read_response_line(Stream &strm, Response &res) {
+inline bool ClientImpl::read_response_line(Stream &strm, const Request &req,
+                                           Response &res) {
   std::array<char, 2048> buf;
 
   detail::stream_line_reader line_reader(strm, buf.data(), buf.size());
 
   if (!line_reader.getline()) { return false; }
 
-  const static std::regex re("(HTTP/1\\.[01]) (\\d+) (.*?)\r\n");
+  const static std::regex re("(HTTP/1\\.[01]) (\\d{3}) (.*?)\r\n");
 
   std::cmatch m;
-  if (!std::regex_match(line_reader.ptr(), m, re)) { return true; }
+  if (!std::regex_match(line_reader.ptr(), m, re)) {
+    return req.method == "CONNECT";
+  }
   res.version = std::string(m[1]);
   res.status = std::stoi(std::string(m[2]));
   res.reason = std::string(m[3]);
@@ -5404,7 +5407,7 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req,
   if (!write_request(strm, req, close_connection, error)) { return false; }
 
   // Receive response and headers
-  if (!read_response_line(strm, res) ||
+  if (!read_response_line(strm, req, res) ||
       !detail::read_headers(strm, res.headers)) {
     error = Error::Read;
     return false;
@@ -5448,9 +5451,7 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req,
     if (!detail::read_content(strm, res, (std::numeric_limits<size_t>::max)(),
                               dummy_status, std::move(progress), std::move(out),
                               decompress_)) {
-      if (error != Error::Canceled) {
-        error = Error::Read;
-      }
+      if (error != Error::Canceled) { error = Error::Read; }
       return false;
     }
   }

+ 25 - 0
test/test.cc

@@ -930,6 +930,31 @@ TEST(ErrorHandlerTest, ContentLength) {
   ASSERT_FALSE(svr.is_running());
 }
 
+TEST(InvalidFormatTest, StatusCode) {
+  Server svr;
+
+  svr.Get("/hi", [](const Request & /*req*/, Response &res) {
+    res.set_content("Hello World!\n", "text/plain");
+    res.status = 9999; // Status should be a three-digit code...
+  });
+
+  auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
+
+  // Give GET time to get a few messages.
+  std::this_thread::sleep_for(std::chrono::seconds(1));
+
+  {
+    Client cli(HOST, PORT);
+
+    auto res = cli.Get("/hi");
+    ASSERT_FALSE(res);
+  }
+
+  svr.stop();
+  thread.join();
+  ASSERT_FALSE(svr.is_running());
+}
+
 class ServerTest : public ::testing::Test {
 protected:
   ServerTest()