Browse Source

Fix problem with invalid range

yhirose 5 years ago
parent
commit
041122908c
2 changed files with 171 additions and 146 deletions
  1. 152 120
      httplib.h
  2. 19 26
      test/test.cc

+ 152 - 120
httplib.h

@@ -676,8 +676,14 @@ private:
                                       const HandlersForContentReader &handlers);
                                       const HandlersForContentReader &handlers);
 
 
   bool parse_request_line(const char *s, Request &req);
   bool parse_request_line(const char *s, Request &req);
+  void apply_ranges(const Request &req, Response &res,
+                    std::string &content_type, std::string &boundary);
   bool write_response(Stream &strm, bool close_connection, const Request &req,
   bool write_response(Stream &strm, bool close_connection, const Request &req,
                       Response &res);
                       Response &res);
+  bool write_response_with_content(Stream &strm, bool close_connection,
+                                   const Request &req, Response &res,
+                                   std::string &content_type,
+                                   std::string &boundary);
   bool write_content_with_provider(Stream &strm, const Request &req,
   bool write_content_with_provider(Stream &strm, const Request &req,
                                    Response &res, const std::string &boundary,
                                    Response &res, const std::string &boundary,
                                    const std::string &content_type);
                                    const std::string &content_type);
@@ -3171,9 +3177,7 @@ get_range_offset_and_length(const Request &req, size_t content_length,
     r.second = slen - 1;
     r.second = slen - 1;
   }
   }
 
 
-  if (r.second == -1) {
-    r.second = slen - 1;
-  }
+  if (r.second == -1) { r.second = slen - 1; }
   return std::make_pair(r.first, static_cast<size_t>(r.second - r.first) + 1);
   return std::make_pair(r.first, static_cast<size_t>(r.second - r.first) + 1);
 }
 }
 
 
@@ -3223,21 +3227,21 @@ bool process_multipart_ranges_data(const Request &req, Response &res,
   return true;
   return true;
 }
 }
 
 
-inline std::string make_multipart_ranges_data(const Request &req, Response &res,
-                                              const std::string &boundary,
-                                              const std::string &content_type) {
-  std::string data;
-
-  process_multipart_ranges_data(
+inline bool make_multipart_ranges_data(const Request &req, Response &res,
+                                       const std::string &boundary,
+                                       const std::string &content_type,
+                                       std::string &data) {
+  return process_multipart_ranges_data(
       req, res, boundary, content_type,
       req, res, boundary, content_type,
       [&](const std::string &token) { data += token; },
       [&](const std::string &token) { data += token; },
       [&](const char *token) { data += token; },
       [&](const char *token) { data += token; },
       [&](size_t offset, size_t length) {
       [&](size_t offset, size_t length) {
-        data += res.body.substr(offset, length);
-        return true;
+        if (offset < res.body.size()) {
+          data += res.body.substr(offset, length);
+          return true;
+        }
+        return false;
       });
       });
-
-  return data;
 }
 }
 
 
 inline size_t
 inline size_t
@@ -4006,18 +4010,19 @@ inline bool Server::parse_request_line(const char *s, Request &req) {
 
 
 inline bool Server::write_response(Stream &strm, bool close_connection,
 inline bool Server::write_response(Stream &strm, bool close_connection,
                                    const Request &req, Response &res) {
                                    const Request &req, Response &res) {
+  std::string content_type;
+  std::string boundary;
+  return write_response_with_content(strm, close_connection, req, res,
+                                     content_type, boundary);
+}
+
+inline bool Server::write_response_with_content(
+    Stream &strm, bool close_connection, const Request &req, Response &res,
+    std::string &content_type, std::string &boundary) {
   assert(res.status != -1);
   assert(res.status != -1);
 
 
   if (400 <= res.status && error_handler_) { error_handler_(req, res); }
   if (400 <= res.status && error_handler_) { error_handler_(req, res); }
 
 
-  detail::BufferStream bstrm;
-
-  // Response line
-  if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status,
-                          detail::status_message(res.status))) {
-    return false;
-  }
-
   // Headers
   // Headers
   if (close_connection || req.get_header_value("Connection") == "close") {
   if (close_connection || req.get_header_value("Connection") == "close") {
     res.set_header("Connection", "close");
     res.set_header("Connection", "close");
@@ -4033,109 +4038,21 @@ inline bool Server::write_response(Stream &strm, bool close_connection,
     res.set_header("Content-Type", "text/plain");
     res.set_header("Content-Type", "text/plain");
   }
   }
 
 
-  if (!res.has_header("Accept-Ranges") && req.method == "HEAD") {
-    res.set_header("Accept-Ranges", "bytes");
+  if (!res.has_header("Content-Length") && res.body.empty() &&
+      !res.content_length_ && !res.content_provider_) {
+    res.set_header("Content-Length", "0");
   }
   }
 
 
-  std::string content_type;
-  std::string boundary;
-
-  if (req.ranges.size() > 1) {
-    boundary = detail::make_multipart_data_boundary();
-
-    auto it = res.headers.find("Content-Type");
-    if (it != res.headers.end()) {
-      content_type = it->second;
-      res.headers.erase(it);
-    }
-
-    res.headers.emplace("Content-Type",
-                        "multipart/byteranges; boundary=" + boundary);
+  if (!res.has_header("Accept-Ranges") && req.method == "HEAD") {
+    res.set_header("Accept-Ranges", "bytes");
   }
   }
 
 
-  auto type = detail::encoding_type(req, res);
-
-  if (res.body.empty()) {
-    if (res.content_length_ > 0) {
-      size_t length = 0;
-      if (req.ranges.empty()) {
-        length = res.content_length_;
-      } else if (req.ranges.size() == 1) {
-        auto offsets =
-            detail::get_range_offset_and_length(req, res.content_length_, 0);
-        auto offset = offsets.first;
-        length = offsets.second;
-        auto content_range = detail::make_content_range_header_field(
-            offset, length, res.content_length_);
-        res.set_header("Content-Range", content_range);
-      } else {
-        length = detail::get_multipart_ranges_data_length(req, res, boundary,
-                                                          content_type);
-      }
-      res.set_header("Content-Length", std::to_string(length));
-    } else {
-      if (res.content_provider_) {
-        if (res.is_chunked_content_provider) {
-          res.set_header("Transfer-Encoding", "chunked");
-          if (type == detail::EncodingType::Gzip) {
-            res.set_header("Content-Encoding", "gzip");
-          } else if (type == detail::EncodingType::Brotli) {
-            res.set_header("Content-Encoding", "br");
-          }
-        }
-      } else {
-        res.set_header("Content-Length", "0");
-      }
-    }
-  } else {
-    if (req.ranges.empty()) {
-      ;
-    } else if (req.ranges.size() == 1) {
-      auto offsets =
-          detail::get_range_offset_and_length(req, res.body.size(), 0);
-      auto offset = offsets.first;
-      auto length = offsets.second;
-      auto content_range = detail::make_content_range_header_field(
-          offset, length, res.body.size());
-      res.set_header("Content-Range", content_range);
-      res.body = res.body.substr(offset, length);
-    } else {
-      res.body =
-          detail::make_multipart_ranges_data(req, res, boundary, content_type);
-    }
-
-    if (type != detail::EncodingType::None) {
-      std::unique_ptr<detail::compressor> compressor;
-
-      if (type == detail::EncodingType::Gzip) {
-#ifdef CPPHTTPLIB_ZLIB_SUPPORT
-        compressor = detail::make_unique<detail::gzip_compressor>();
-        res.set_header("Content-Encoding", "gzip");
-#endif
-      } else if (type == detail::EncodingType::Brotli) {
-#ifdef CPPHTTPLIB_BROTLI_SUPPORT
-        compressor = detail::make_unique<detail::brotli_compressor>();
-        res.set_header("Content-Encoding", "brotli");
-#endif
-      }
-
-      if (compressor) {
-        std::string compressed;
-
-        if (!compressor->compress(res.body.data(), res.body.size(), true,
-                                  [&](const char *data, size_t data_len) {
-                                    compressed.append(data, data_len);
-                                    return true;
-                                  })) {
-          return false;
-        }
-
-        res.body.swap(compressed);
-      }
-    }
+  detail::BufferStream bstrm;
 
 
-    auto length = std::to_string(res.body.size());
-    res.set_header("Content-Length", length);
+  // Response line
+  if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status,
+                          detail::status_message(res.status))) {
+    return false;
   }
   }
 
 
   if (!detail::write_headers(bstrm, res, Headers())) { return false; }
   if (!detail::write_headers(bstrm, res, Headers())) { return false; }
@@ -4535,6 +4452,116 @@ inline bool Server::dispatch_request(Request &req, Response &res,
   return false;
   return false;
 }
 }
 
 
+inline void Server::apply_ranges(const Request &req, Response &res,
+                                 std::string &content_type,
+                                 std::string &boundary) {
+  if (req.ranges.size() > 1) {
+    boundary = detail::make_multipart_data_boundary();
+
+    auto it = res.headers.find("Content-Type");
+    if (it != res.headers.end()) {
+      content_type = it->second;
+      res.headers.erase(it);
+    }
+
+    res.headers.emplace("Content-Type",
+                        "multipart/byteranges; boundary=" + boundary);
+  }
+
+  auto type = detail::encoding_type(req, res);
+
+  if (res.body.empty()) {
+    if (res.content_length_ > 0) {
+      size_t length = 0;
+      if (req.ranges.empty()) {
+        length = res.content_length_;
+      } else if (req.ranges.size() == 1) {
+        auto offsets =
+            detail::get_range_offset_and_length(req, res.content_length_, 0);
+        auto offset = offsets.first;
+        length = offsets.second;
+        auto content_range = detail::make_content_range_header_field(
+            offset, length, res.content_length_);
+        res.set_header("Content-Range", content_range);
+      } else {
+        length = detail::get_multipart_ranges_data_length(req, res, boundary,
+                                                          content_type);
+      }
+      res.set_header("Content-Length", std::to_string(length));
+    } else {
+      if (res.content_provider_) {
+        if (res.is_chunked_content_provider) {
+          res.set_header("Transfer-Encoding", "chunked");
+          if (type == detail::EncodingType::Gzip) {
+            res.set_header("Content-Encoding", "gzip");
+          } else if (type == detail::EncodingType::Brotli) {
+            res.set_header("Content-Encoding", "br");
+          }
+        }
+      }
+    }
+  } else {
+    if (req.ranges.empty()) {
+      ;
+    } else if (req.ranges.size() == 1) {
+      auto offsets =
+          detail::get_range_offset_and_length(req, res.body.size(), 0);
+      auto offset = offsets.first;
+      auto length = offsets.second;
+      auto content_range = detail::make_content_range_header_field(
+          offset, length, res.body.size());
+      res.set_header("Content-Range", content_range);
+      if (offset < res.body.size()) {
+        res.body = res.body.substr(offset, length);
+      } else {
+        res.body.clear();
+        res.status = 416;
+      }
+    } else {
+      std::string data;
+      if (detail::make_multipart_ranges_data(req, res, boundary, content_type,
+                                             data)) {
+        res.body.swap(data);
+      } else {
+        res.body.clear();
+        res.status = 416;
+      }
+    }
+
+    if (type != detail::EncodingType::None) {
+      std::unique_ptr<detail::compressor> compressor;
+      std::string content_encoding;
+
+      if (type == detail::EncodingType::Gzip) {
+#ifdef CPPHTTPLIB_ZLIB_SUPPORT
+        compressor = detail::make_unique<detail::gzip_compressor>();
+        content_encoding = "gzip";
+#endif
+      } else if (type == detail::EncodingType::Brotli) {
+#ifdef CPPHTTPLIB_BROTLI_SUPPORT
+        compressor = detail::make_unique<detail::brotli_compressor>();
+        content_encoding = "brotli";
+#endif
+      }
+
+      if (compressor) {
+        std::string compressed;
+        if (compressor->compress(res.body.data(), res.body.size(), true,
+                                 [&](const char *data, size_t data_len) {
+                                   compressed.append(data, data_len);
+                                   return true;
+                                 })) {
+          res.body.swap(compressed);
+          res.set_header("Content-Encoding", content_encoding);
+        }
+      }
+    }
+
+    auto length = std::to_string(res.body.size());
+    res.set_header("Content-Length", length);
+  }
+}
+
 inline bool Server::dispatch_request_for_content_reader(
 inline bool Server::dispatch_request_for_content_reader(
     Request &req, Response &res, ContentReader content_reader,
     Request &req, Response &res, ContentReader content_reader,
     const HandlersForContentReader &handlers) {
     const HandlersForContentReader &handlers) {
@@ -4626,7 +4653,12 @@ Server::process_request(Stream &strm, bool close_connection,
     if (res.status == -1) { res.status = 404; }
     if (res.status == -1) { res.status = 404; }
   }
   }
 
 
-  return write_response(strm, close_connection, req, res);
+  std::string content_type;
+  std::string boundary;
+  apply_ranges(req, res, content_type, boundary);
+
+  return write_response_with_content(strm, close_connection, req, res,
+                                     content_type, boundary);
 }
 }
 
 
 inline bool Server::is_valid() const { return true; }
 inline bool Server::is_valid() const { return true; }

+ 19 - 26
test/test.cc

@@ -1899,9 +1899,7 @@ TEST_F(ServerTest, GetStreamedWithRange2) {
 }
 }
 
 
 TEST_F(ServerTest, GetStreamedWithRangeSuffix1) {
 TEST_F(ServerTest, GetStreamedWithRangeSuffix1) {
-  auto res = cli_.Get("/streamed-with-range", {
-    {"Range", "bytes=-3"}
-  });
+  auto res = cli_.Get("/streamed-with-range", {{"Range", "bytes=-3"}});
   ASSERT_TRUE(res);
   ASSERT_TRUE(res);
   EXPECT_EQ(206, res->status);
   EXPECT_EQ(206, res->status);
   EXPECT_EQ("3", res->get_header_value("Content-Length"));
   EXPECT_EQ("3", res->get_header_value("Content-Length"));
@@ -1909,11 +1907,8 @@ TEST_F(ServerTest, GetStreamedWithRangeSuffix1) {
   EXPECT_EQ(std::string("efg"), res->body);
   EXPECT_EQ(std::string("efg"), res->body);
 }
 }
 
 
-
 TEST_F(ServerTest, GetStreamedWithRangeSuffix2) {
 TEST_F(ServerTest, GetStreamedWithRangeSuffix2) {
-  auto res = cli_.Get("/streamed-with-range", {
-    {"Range", "bytes=-9999"}
-  });
+  auto res = cli_.Get("/streamed-with-range", {{"Range", "bytes=-9999"}});
   ASSERT_TRUE(res);
   ASSERT_TRUE(res);
   EXPECT_EQ(206, res->status);
   EXPECT_EQ(206, res->status);
   EXPECT_EQ("7", res->get_header_value("Content-Length"));
   EXPECT_EQ("7", res->get_header_value("Content-Length"));
@@ -1921,18 +1916,17 @@ TEST_F(ServerTest, GetStreamedWithRangeSuffix2) {
   EXPECT_EQ(std::string("abcdefg"), res->body);
   EXPECT_EQ(std::string("abcdefg"), res->body);
 }
 }
 
 
-
 TEST_F(ServerTest, GetStreamedWithRangeError) {
 TEST_F(ServerTest, GetStreamedWithRangeError) {
-  auto res = cli_.Get("/streamed-with-range", {
-    {"Range", "bytes=92233720368547758079223372036854775806-92233720368547758079223372036854775807"}
-  });
+  auto res = cli_.Get("/streamed-with-range",
+                      {{"Range", "bytes=92233720368547758079223372036854775806-"
+                                 "92233720368547758079223372036854775807"}});
   ASSERT_TRUE(res);
   ASSERT_TRUE(res);
   EXPECT_EQ(416, res->status);
   EXPECT_EQ(416, res->status);
 }
 }
 
 
-//Tests long long overflow.
 TEST_F(ServerTest, GetRangeWithMaxLongLength) {
 TEST_F(ServerTest, GetRangeWithMaxLongLength) {
-  auto res = cli_.Get("/with-range",{{"Range", "bytes=0-9223372036854775807"}});
+  auto res =
+      cli_.Get("/with-range", {{"Range", "bytes=0-9223372036854775807"}});
   EXPECT_EQ(206, res->status);
   EXPECT_EQ(206, res->status);
   EXPECT_EQ("7", res->get_header_value("Content-Length"));
   EXPECT_EQ("7", res->get_header_value("Content-Length"));
   EXPECT_EQ(true, res->has_header("Content-Range"));
   EXPECT_EQ(true, res->has_header("Content-Range"));
@@ -2020,11 +2014,11 @@ TEST_F(ServerTest, GetWithRange4) {
   EXPECT_EQ(std::string("fg"), res->body);
   EXPECT_EQ(std::string("fg"), res->body);
 }
 }
 
 
-//TEST_F(ServerTest, GetWithRangeOffsetGreaterThanContent) {
-//  auto res = cli_.Get("/with-range", {{make_range_header({{10000, 20000}})}});
-//  ASSERT_TRUE(res);
-//  EXPECT_EQ(416, res->status);
-//}
+TEST_F(ServerTest, GetWithRangeOffsetGreaterThanContent) {
+  auto res = cli_.Get("/with-range", {{make_range_header({{10000, 20000}})}});
+  ASSERT_TRUE(res);
+  EXPECT_EQ(416, res->status);
+}
 
 
 TEST_F(ServerTest, GetWithRangeMultipart) {
 TEST_F(ServerTest, GetWithRangeMultipart) {
   auto res = cli_.Get("/with-range", {{make_range_header({{1, 2}, {4, 5}})}});
   auto res = cli_.Get("/with-range", {{make_range_header({{1, 2}, {4, 5}})}});
@@ -2035,11 +2029,12 @@ TEST_F(ServerTest, GetWithRangeMultipart) {
   EXPECT_EQ(269, res->body.size());
   EXPECT_EQ(269, res->body.size());
 }
 }
 
 
-//TEST_F(ServerTest, GetWithRangeMultipartOffsetGreaterThanContent) {
-//  auto res = cli_.Get("/with-range", {{make_range_header({{-1, 2}, {10000, 30000}})}});
-//  ASSERT_TRUE(res);
-//  EXPECT_EQ(416, res->status);
-//}
+TEST_F(ServerTest, GetWithRangeMultipartOffsetGreaterThanContent) {
+  auto res =
+      cli_.Get("/with-range", {{make_range_header({{-1, 2}, {10000, 30000}})}});
+  ASSERT_TRUE(res);
+  EXPECT_EQ(416, res->status);
+}
 
 
 TEST_F(ServerTest, GetStreamedChunked) {
 TEST_F(ServerTest, GetStreamedChunked) {
   auto res = cli_.Get("/streamed-chunked");
   auto res = cli_.Get("/streamed-chunked");
@@ -3058,9 +3053,7 @@ TEST(KeepAliveTest, ReadTimeoutSSL) {
     res.set_content("b", "text/plain");
     res.set_content("b", "text/plain");
   });
   });
 
 
-  auto listen_thread = std::thread([&svr]() {
-    svr.listen("localhost", PORT);
-  });
+  auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); });
   while (!svr.is_running()) {
   while (!svr.is_running()) {
     std::this_thread::sleep_for(std::chrono::milliseconds(1));
     std::this_thread::sleep_for(std::chrono::milliseconds(1));
   }
   }