Browse Source

Fix #708 (#713)

* Fix #708

* Rename ContentReceiver2 to ContentReceiverWithProgress
yhirose 5 years ago
parent
commit
109b624dfe
1 changed files with 53 additions and 30 deletions
  1. 53 30
      httplib.h

+ 53 - 30
httplib.h

@@ -326,6 +326,10 @@ using ContentProvider =
 using ContentProviderWithoutLength =
     std::function<bool(size_t offset, DataSink &sink)>;
 
+using ContentReceiverWithProgress =
+    std::function<bool(const char *data, size_t data_length, uint64_t offset,
+                       uint64_t total_length)>;
+
 using ContentReceiver =
     std::function<bool(const char *data, size_t data_length)>;
 
@@ -378,7 +382,7 @@ struct Request {
   // for client
   size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT;
   ResponseHandler response_handler;
-  ContentReceiver content_receiver;
+  ContentReceiverWithProgress content_receiver;
   size_t content_length = 0;
   ContentProvider content_provider;
   Progress progress;
@@ -2508,7 +2512,8 @@ inline bool read_headers(Stream &strm, Headers &headers) {
 }
 
 inline bool read_content_with_length(Stream &strm, uint64_t len,
-                                     Progress progress, ContentReceiver out) {
+                                     Progress progress,
+                                     ContentReceiverWithProgress out) {
   char buf[CPPHTTPLIB_RECV_BUFSIZ];
 
   uint64_t r = 0;
@@ -2517,8 +2522,7 @@ inline bool read_content_with_length(Stream &strm, uint64_t len,
     auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ));
     if (n <= 0) { return false; }
 
-    if (!out(buf, static_cast<size_t>(n))) { return false; }
-
+    if (!out(buf, static_cast<size_t>(n), r, len)) { return false; }
     r += static_cast<uint64_t>(n);
 
     if (progress) {
@@ -2540,8 +2544,10 @@ inline void skip_content_with_length(Stream &strm, uint64_t len) {
   }
 }
 
-inline bool read_content_without_length(Stream &strm, ContentReceiver out) {
+inline bool read_content_without_length(Stream &strm,
+                                        ContentReceiverWithProgress out) {
   char buf[CPPHTTPLIB_RECV_BUFSIZ];
+  uint64_t r = 0;
   for (;;) {
     auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ);
     if (n < 0) {
@@ -2549,13 +2555,16 @@ inline bool read_content_without_length(Stream &strm, ContentReceiver out) {
     } else if (n == 0) {
       return true;
     }
-    if (!out(buf, static_cast<size_t>(n))) { return false; }
+
+    if (!out(buf, static_cast<size_t>(n), r, 0)) { return false; }
+    r += static_cast<uint64_t>(n);
   }
 
   return true;
 }
 
-inline bool read_content_chunked(Stream &strm, ContentReceiver out) {
+inline bool read_content_chunked(Stream &strm,
+                                 ContentReceiverWithProgress out) {
   const auto bufsiz = 16;
   char buf[bufsiz];
 
@@ -2600,7 +2609,8 @@ inline bool is_chunked_transfer_encoding(const Headers &headers) {
 }
 
 template <typename T, typename U>
-bool prepare_content_receiver(T &x, int &status, ContentReceiver receiver,
+bool prepare_content_receiver(T &x, int &status,
+                              ContentReceiverWithProgress receiver,
                               bool decompress, U callback) {
   if (decompress) {
     std::string encoding = x.get_header_value("Content-Encoding");
@@ -2625,10 +2635,12 @@ bool prepare_content_receiver(T &x, int &status, ContentReceiver receiver,
 
     if (decompressor) {
       if (decompressor->is_valid()) {
-        ContentReceiver out = [&](const char *buf, size_t n) {
-          return decompressor->decompress(
-              buf, n,
-              [&](const char *buf, size_t n) { return receiver(buf, n); });
+        ContentReceiverWithProgress out = [&](const char *buf, size_t n,
+                                              uint64_t off, uint64_t len) {
+          return decompressor->decompress(buf, n,
+                                          [&](const char *buf, size_t n) {
+                                            return receiver(buf, n, off, len);
+                                          });
         };
         return callback(std::move(out));
       } else {
@@ -2638,19 +2650,20 @@ bool prepare_content_receiver(T &x, int &status, ContentReceiver receiver,
     }
   }
 
-  ContentReceiver out = [&](const char *buf, size_t n) {
-    return receiver(buf, n);
+  ContentReceiverWithProgress out = [&](const char *buf, size_t n, uint64_t off,
+                                        uint64_t len) {
+    return receiver(buf, n, off, len);
   };
   return callback(std::move(out));
 }
 
 template <typename T>
 bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status,
-                  Progress progress, ContentReceiver receiver,
+                  Progress progress, ContentReceiverWithProgress receiver,
                   bool decompress) {
   return prepare_content_receiver(
       x, status, std::move(receiver), decompress,
-      [&](const ContentReceiver &out) {
+      [&](const ContentReceiverWithProgress &out) {
         auto ret = true;
         auto exceed_payload_max_length = false;
 
@@ -4251,7 +4264,7 @@ inline bool Server::read_content_core(Stream &strm, Request &req, Response &res,
                                       MultipartContentHeader mulitpart_header,
                                       ContentReceiver multipart_receiver) {
   detail::MultipartFormDataParser multipart_form_data_parser;
-  ContentReceiver out;
+  ContentReceiverWithProgress out;
 
   if (req.is_multipart_form_data()) {
     const auto &content_type = req.get_header_value("Content-Type");
@@ -4262,7 +4275,7 @@ inline bool Server::read_content_core(Stream &strm, Request &req, Response &res,
     }
 
     multipart_form_data_parser.set_boundary(std::move(boundary));
-    out = [&](const char *buf, size_t n) {
+    out = [&](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) {
       /* For debug
       size_t pos = 0;
       while (pos < n) {
@@ -4278,7 +4291,8 @@ inline bool Server::read_content_core(Stream &strm, Request &req, Response &res,
                                               mulitpart_header);
     };
   } else {
-    out = std::move(receiver);
+    out = [receiver](const char *buf, size_t n, uint64_t /*off*/,
+                     uint64_t /*len*/) { return receiver(buf, n); };
   }
 
   if (req.method == "DELETE" && !req.has_header("Content-Length")) {
@@ -5119,16 +5133,21 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req,
   if (req.method != "HEAD" && req.method != "CONNECT") {
     auto out =
         req.content_receiver
-            ? static_cast<ContentReceiver>([&](const char *buf, size_t n) {
-                auto ret = req.content_receiver(buf, n);
-                if (!ret) { error_ = Error::Canceled; }
-                return ret;
-              })
-            : static_cast<ContentReceiver>([&](const char *buf, size_t n) {
-                if (res.body.size() + n > res.body.max_size()) { return false; }
-                res.body.append(buf, n);
-                return true;
-              });
+            ? static_cast<ContentReceiverWithProgress>(
+                  [&](const char *buf, size_t n, uint64_t off, uint64_t len) {
+                    auto ret = req.content_receiver(buf, n, off, len);
+                    if (!ret) { error_ = Error::Canceled; }
+                    return ret;
+                  })
+            : static_cast<ContentReceiverWithProgress>(
+                  [&](const char *buf, size_t n, uint64_t /*off*/,
+                      uint64_t /*len*/) {
+                    if (res.body.size() + n > res.body.max_size()) {
+                      return false;
+                    }
+                    res.body.append(buf, n);
+                    return true;
+                  });
 
     auto progress = [&](uint64_t current, uint64_t total) {
       if (!req.progress) { return true; }
@@ -5249,7 +5268,11 @@ inline Result ClientImpl::Get(const char *path, const Headers &headers,
   req.headers = default_headers_;
   req.headers.insert(headers.begin(), headers.end());
   req.response_handler = std::move(response_handler);
-  req.content_receiver = std::move(content_receiver);
+  req.content_receiver =
+      [content_receiver](const char *data, size_t data_length,
+                         uint64_t /*offset*/, uint64_t /*total_length*/) {
+        return content_receiver(data, data_length);
+      };
   req.progress = std::move(progress);
 
   auto res = detail::make_unique<Response>();