Browse Source

Changed to return 416 for a request with an invalid range

yhirose 1 year ago
parent
commit
fceada9ef4
2 changed files with 93 additions and 77 deletions
  1. 78 63
      httplib.h
  2. 15 14
      test/test.cc

+ 78 - 63
httplib.h

@@ -4720,29 +4720,47 @@ serialize_multipart_formdata(const MultipartFormDataItems &items,
   return body;
 }
 
-inline std::pair<size_t, size_t>
-get_range_offset_and_length(Range range, size_t content_length) {
-  if (range.first == -1 && range.second == -1) {
-    return std::make_pair(0, content_length);
-  }
+inline bool normalize_ranges(Request &req, Response &res) {
+  ssize_t len = static_cast<ssize_t>(res.content_length_ ? res.content_length_
+                                                         : res.body.size());
+
+  if (!req.ranges.empty()) {
+    for (auto &r : req.ranges) {
+      auto &st = r.first;
+      auto &ed = r.second;
+
+      if (st == -1 && ed == -1) {
+        st = 0;
+        ed = len;
+      }
+
+      if (st == -1) {
+        st = len - ed;
+        ed = len - 1;
+      }
 
-  auto slen = static_cast<ssize_t>(content_length);
+      if (ed == -1) { ed = len - 1; }
 
-  if (range.first == -1) {
-    range.first = (std::max)(static_cast<ssize_t>(0), slen - range.second);
-    range.second = slen - 1;
+      if (!(0 <= st && st <= ed && ed <= len - 1)) { return false; }
+    }
   }
+  return true;
+}
 
-  if (range.second == -1) { range.second = slen - 1; }
-  return std::make_pair(range.first,
-                        static_cast<size_t>(range.second - range.first) + 1);
+inline std::pair<size_t, size_t>
+get_range_offset_and_length(Range r, size_t content_length) {
+  assert(r.first != -1 && r.second != -1);
+  assert(0 <= r.first && r.first < static_cast<ssize_t>(content_length));
+  assert(r.first <= r.second &&
+         r.second < static_cast<ssize_t>(content_length));
+
+  return std::make_pair(r.first, static_cast<size_t>(r.second - r.first) + 1);
 }
 
 inline std::string make_content_range_header_field(
     const std::pair<size_t, size_t> &offset_and_length, size_t content_length) {
-
   auto st = offset_and_length.first;
-  auto ed = (std::min)(st + offset_and_length.second - 1, content_length - 1);
+  auto ed = st + offset_and_length.second - 1;
 
   std::string field = "bytes ";
   field += std::to_string(st);
@@ -4754,11 +4772,11 @@ inline std::string make_content_range_header_field(
 }
 
 template <typename SToken, typename CToken, typename Content>
-bool process_multipart_ranges_data(const Request &req, Response &res,
+bool process_multipart_ranges_data(const Request &req,
                                    const std::string &boundary,
                                    const std::string &content_type,
-                                   SToken stoken, CToken ctoken,
-                                   Content content) {
+                                   size_t content_length, SToken stoken,
+                                   CToken ctoken, Content content) {
   for (size_t i = 0; i < req.ranges.size(); i++) {
     ctoken("--");
     stoken(boundary);
@@ -4770,11 +4788,10 @@ bool process_multipart_ranges_data(const Request &req, Response &res,
     }
 
     auto offset_and_length =
-        get_range_offset_and_length(req.ranges[i], res.content_length_);
+        get_range_offset_and_length(req.ranges[i], content_length);
 
     ctoken("Content-Range: ");
-    stoken(make_content_range_header_field(offset_and_length,
-                                           res.content_length_));
+    stoken(make_content_range_header_field(offset_and_length, content_length));
     ctoken("\r\n");
     ctoken("\r\n");
 
@@ -4791,31 +4808,30 @@ bool process_multipart_ranges_data(const Request &req, Response &res,
   return true;
 }
 
-inline bool make_multipart_ranges_data(const Request &req, Response &res,
+inline void make_multipart_ranges_data(const Request &req, Response &res,
                                        const std::string &boundary,
                                        const std::string &content_type,
+                                       size_t content_length,
                                        std::string &data) {
-  return process_multipart_ranges_data(
-      req, res, boundary, content_type,
+  process_multipart_ranges_data(
+      req, boundary, content_type, content_length,
       [&](const std::string &token) { data += token; },
       [&](const std::string &token) { data += token; },
       [&](size_t offset, size_t length) {
-        if (offset < res.body.size()) {
-          data += res.body.substr(offset, length);
-          return true;
-        }
-        return false;
+        assert(offset + length <= content_length);
+        data += res.body.substr(offset, length);
+        return true;
       });
 }
 
-inline size_t
-get_multipart_ranges_data_length(const Request &req, Response &res,
-                                 const std::string &boundary,
-                                 const std::string &content_type) {
+inline size_t get_multipart_ranges_data_length(const Request &req,
+                                               const std::string &boundary,
+                                               const std::string &content_type,
+                                               size_t content_length) {
   size_t data_length = 0;
 
   process_multipart_ranges_data(
-      req, res, boundary, content_type,
+      req, boundary, content_type, content_length,
       [&](const std::string &token) { data_length += token.size(); },
       [&](const std::string &token) { data_length += token.size(); },
       [&](size_t /*offset*/, size_t length) {
@@ -4827,13 +4843,13 @@ get_multipart_ranges_data_length(const Request &req, Response &res,
 }
 
 template <typename T>
-inline bool write_multipart_ranges_data(Stream &strm, const Request &req,
-                                        Response &res,
-                                        const std::string &boundary,
-                                        const std::string &content_type,
-                                        const T &is_shutting_down) {
+inline bool
+write_multipart_ranges_data(Stream &strm, const Request &req, Response &res,
+                            const std::string &boundary,
+                            const std::string &content_type,
+                            size_t content_length, const T &is_shutting_down) {
   return process_multipart_ranges_data(
-      req, res, boundary, content_type,
+      req, boundary, content_type, content_length,
       [&](const std::string &token) { strm.write(token); },
       [&](const std::string &token) { strm.write(token); },
       [&](size_t offset, size_t length) {
@@ -6012,7 +6028,6 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
       if (write_content_with_provider(strm, req, res, boundary, content_type)) {
         res.content_provider_success_ = true;
       } else {
-        res.content_provider_success_ = false;
         ret = false;
       }
     }
@@ -6045,7 +6060,8 @@ Server::write_content_with_provider(Stream &strm, const Request &req,
                                    offset_and_length.second, is_shutting_down);
     } else {
       return detail::write_multipart_ranges_data(
-          strm, req, res, boundary, content_type, is_shutting_down);
+          strm, req, res, boundary, content_type, res.content_length_,
+          is_shutting_down);
     }
   } else {
     if (res.is_chunked_content_provider_) {
@@ -6437,14 +6453,14 @@ inline void Server::apply_ranges(const Request &req, Response &res,
                                  std::string &content_type,
                                  std::string &boundary) const {
   if (req.ranges.size() > 1) {
-    boundary = detail::make_multipart_data_boundary();
-
     auto it = res.headers.find("Content-Type");
     if (it != res.headers.end()) {
       content_type = it->second;
       res.headers.erase(it);
     }
 
+    boundary = detail::make_multipart_data_boundary();
+
     res.set_header("Content-Type",
                    "multipart/byteranges; boundary=" + boundary);
   }
@@ -6466,8 +6482,8 @@ inline void Server::apply_ranges(const Request &req, Response &res,
             offset_and_length, res.content_length_);
         res.set_header("Content-Range", content_range);
       } else {
-        length = detail::get_multipart_ranges_data_length(req, res, boundary,
-                                                          content_type);
+        length = detail::get_multipart_ranges_data_length(
+            req, boundary, content_type, res.content_length_);
       }
       res.set_header("Content-Length", std::to_string(length));
     } else {
@@ -6495,21 +6511,13 @@ inline void Server::apply_ranges(const Request &req, Response &res,
           offset_and_length, res.body.size());
       res.set_header("Content-Range", content_range);
 
-      if (offset < res.body.size()) {
-        res.body = res.body.substr(offset, length);
-      } else {
-        res.body.clear();
-        res.status = StatusCode::RangeNotSatisfiable_416;
-      }
+      assert(offset + length <= res.body.size());
+      res.body = res.body.substr(offset, length);
     } else {
       std::string data;
-      if (detail::make_multipart_ranges_data(req, res, boundary, content_type,
-                                             data)) {
-        res.body.swap(data);
-      } else {
-        res.body.clear();
-        res.status = StatusCode::RangeNotSatisfiable_416;
-      }
+      detail::make_multipart_ranges_data(req, res, boundary, content_type,
+                                         res.body.size(), data);
+      res.body.swap(data);
     }
 
     if (type != detail::EncodingType::None) {
@@ -6685,13 +6693,20 @@ Server::process_request(Stream &strm, bool close_connection,
     }
   }
 #endif
-
   if (routed) {
-    if (res.status == -1) {
-      res.status = req.ranges.empty() ? StatusCode::OK_200
-                                      : StatusCode::PartialContent_206;
+    if (detail::normalize_ranges(req, res)) {
+      if (res.status == -1) {
+        res.status = req.ranges.empty() ? StatusCode::OK_200
+                                        : StatusCode::PartialContent_206;
+      }
+      return write_response_with_content(strm, close_connection, req, res);
+    } else {
+      res.body.clear();
+      res.content_length_ = 0;
+      res.content_provider_ = nullptr;
+      res.status = StatusCode::RangeNotSatisfiable_416;
+      return write_response(strm, close_connection, req, res);
     }
-    return write_response_with_content(strm, close_connection, req, res);
   } else {
     if (res.status == -1) { res.status = StatusCode::NotFound_404; }
     return write_response(strm, close_connection, req, res);

+ 15 - 14
test/test.cc

@@ -1831,7 +1831,7 @@ protected:
                    });
              })
         .Get("/streamed-with-range",
-             [&](const Request & /*req*/, Response &res) {
+             [&](const Request &req, Response &res) {
                auto data = new std::string("abcdefg");
                res.set_content_provider(
                    data->size(), "text/plain",
@@ -1845,8 +1845,8 @@ protected:
                      EXPECT_TRUE(ret);
                      return true;
                    },
-                   [data](bool success) {
-                     EXPECT_TRUE(success);
+                   [data, &req](bool success) {
+                     EXPECT_EQ(success, !req.has_param("error"));
                      delete data;
                    });
              })
@@ -2957,13 +2957,12 @@ TEST_F(ServerTest, GetStreamedWithRangeSuffix1) {
 }
 
 TEST_F(ServerTest, GetStreamedWithRangeSuffix2) {
-  auto res = cli_.Get("/streamed-with-range", {{"Range", "bytes=-9999"}});
+  auto res = cli_.Get("/streamed-with-range?error", {{"Range", "bytes=-9999"}});
   ASSERT_TRUE(res);
-  EXPECT_EQ(StatusCode::PartialContent_206, res->status);
-  EXPECT_EQ("7", res->get_header_value("Content-Length"));
-  EXPECT_EQ(true, res->has_header("Content-Range"));
-  EXPECT_EQ("bytes 0-6/7", res->get_header_value("Content-Range"));
-  EXPECT_EQ(std::string("abcdefg"), res->body);
+  EXPECT_EQ(StatusCode::RangeNotSatisfiable_416, res->status);
+  EXPECT_EQ("0", res->get_header_value("Content-Length"));
+  EXPECT_EQ(false, res->has_header("Content-Range"));
+  EXPECT_EQ(0, res->body.size());
 }
 
 TEST_F(ServerTest, GetStreamedWithRangeError) {
@@ -2972,16 +2971,18 @@ TEST_F(ServerTest, GetStreamedWithRangeError) {
                                  "92233720368547758079223372036854775807"}});
   ASSERT_TRUE(res);
   EXPECT_EQ(StatusCode::RangeNotSatisfiable_416, res->status);
+  EXPECT_EQ("0", res->get_header_value("Content-Length"));
+  EXPECT_EQ(false, res->has_header("Content-Range"));
+  EXPECT_EQ(0, res->body.size());
 }
 
 TEST_F(ServerTest, GetRangeWithMaxLongLength) {
   auto res =
       cli_.Get("/with-range", {{"Range", "bytes=0-9223372036854775807"}});
-  EXPECT_EQ(StatusCode::PartialContent_206, res->status);
-  EXPECT_EQ("7", res->get_header_value("Content-Length"));
-  EXPECT_EQ("bytes 0-6/7", res->get_header_value("Content-Range"));
-  EXPECT_EQ(true, res->has_header("Content-Range"));
-  EXPECT_EQ(std::string("abcdefg"), res->body);
+  EXPECT_EQ(StatusCode::RangeNotSatisfiable_416, res->status);
+  EXPECT_EQ("0", res->get_header_value("Content-Length"));
+  EXPECT_EQ(false, res->has_header("Content-Range"));
+  EXPECT_EQ(0, res->body.size());
 }
 
 TEST_F(ServerTest, GetStreamedWithRangeMultipart) {