Browse Source

Fix #139. Content receiver support

yhirose 6 years ago
parent
commit
6f663028e9
3 changed files with 204 additions and 32 deletions
  1. 9 0
      README.md
  2. 132 30
      httplib.h
  3. 63 2
      test/test.cc

+ 9 - 0
README.md

@@ -114,6 +114,15 @@ int main(void)
 }
 ```
 
+### GET with Content Receiver
+
+```c++
+  std::string body;
+  auto res = cli.Get("/large-data", [&](const char *data, size_t len) {
+    body.append(data, len);
+  });
+```
+
 ### POST
 
 ```c++

+ 132 - 30
httplib.h

@@ -124,6 +124,9 @@ std::pair<std::string, std::string> make_range_header(uint64_t value,
 
 typedef std::multimap<std::string, std::string> Params;
 typedef std::smatch Match;
+
+typedef std::function<std::string(uint64_t offset)> ContentProducer;
+typedef std::function<void(const char *data, size_t len)> ContentReceiver;
 typedef std::function<bool(uint64_t current, uint64_t total)> Progress;
 
 struct MultipartFile {
@@ -145,8 +148,6 @@ struct Request {
   MultipartFiles files;
   Match matches;
 
-  Progress progress;
-
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
   const SSL *ssl;
 #endif
@@ -169,7 +170,10 @@ struct Response {
   int status;
   Headers headers;
   std::string body;
-  std::function<std::string(uint64_t offset)> streamcb;
+
+  ContentProducer content_producer;
+  ContentReceiver content_receiver;
+  Progress progress;
 
   bool has_header(const char *key) const;
   std::string get_header_value(const char *key, size_t id = 0) const;
@@ -315,6 +319,13 @@ public:
   std::shared_ptr<Response> Get(const char *path, const Headers &headers,
                                 Progress progress = nullptr);
 
+  std::shared_ptr<Response> Get(const char *path,
+                                ContentReceiver content_receiver,
+                                Progress progress = nullptr);
+  std::shared_ptr<Response> Get(const char *path, const Headers &headers,
+                                ContentReceiver content_receiver,
+                                Progress progress = nullptr);
+
   std::shared_ptr<Response> Head(const char *path);
   std::shared_ptr<Response> Head(const char *path, const Headers &headers);
 
@@ -942,6 +953,63 @@ inline bool compress(std::string &content) {
   return true;
 }
 
+class decompressor {
+public:
+  decompressor() {
+    strm.zalloc = Z_NULL;
+    strm.zfree = Z_NULL;
+    strm.opaque = Z_NULL;
+
+    // 15 is the value of wbits, which should be at the maximum possible value
+    // to ensure that any gzip stream can be decoded. The offset of 16 specifies
+    // that the stream to decompress will be formatted with a gzip wrapper.
+    is_valid_ = inflateInit2(&strm, 16 + 15) == Z_OK;
+  }
+
+  ~decompressor() { inflateEnd(&strm); }
+
+  bool is_valid() const { return is_valid_; }
+
+  template <typename T>
+  bool decompress(const char *data, size_t data_len, T callback) {
+    int ret = Z_OK;
+    std::string decompressed;
+
+    // strm.avail_in = content.size();
+    // strm.next_in = (Bytef *)content.data();
+    strm.avail_in = data_len;
+    strm.next_in = (Bytef *)data;
+
+    const auto bufsiz = 16384;
+    char buff[bufsiz];
+    do {
+      strm.avail_out = bufsiz;
+      strm.next_out = (Bytef *)buff;
+
+      ret = inflate(&strm, Z_NO_FLUSH);
+      assert(ret != Z_STREAM_ERROR);
+      switch (ret) {
+      case Z_NEED_DICT:
+      case Z_DATA_ERROR:
+      case Z_MEM_ERROR: inflateEnd(&strm); return false;
+      }
+
+      decompressed.append(buff, bufsiz - strm.avail_out);
+    } while (strm.avail_out == 0);
+
+    if (ret == Z_STREAM_END) {
+      callback(decompressed.data(), decompressed.size());
+      return true;
+    }
+
+    return false;
+  }
+
+private:
+  bool is_valid_;
+  z_stream strm;
+};
+
 inline bool decompress(std::string &content) {
   z_stream strm;
   strm.zalloc = Z_NULL;
@@ -1112,26 +1180,40 @@ inline bool is_chunked_transfer_encoding(const Headers &headers) {
                      "chunked");
 }
 
-template <typename T>
+template <typename T, typename U>
 bool read_content(Stream &strm, T &x, uint64_t payload_max_length, int &status,
-                  Progress progress) {
+                  Progress progress, U callback) {
+
+  ContentReceiver out = [&](const char *buf, size_t n) { callback(buf, n); };
+
+#ifdef CPPHTTPLIB_ZLIB_SUPPORT
+  detail::decompressor decompressor;
+
+  if (!decompressor.is_valid()) {
+    status = 500;
+    return false;
+  }
 
-#ifndef CPPHTTPLIB_ZLIB_SUPPORT
+  if (x.get_header_value("Content-Encoding") == "gzip") {
+    out = [&](const char *buf, size_t n) {
+      decompressor.decompress(
+          buf, n, [&](const char *buf, size_t n) { callback(buf, n); });
+    };
+  }
+#else
   if (x.get_header_value("Content-Encoding") == "gzip") {
     status = 415;
     return false;
   }
 #endif
 
-  auto callback = [&](const char *buf, size_t n) { x.body.append(buf, n); };
-
   auto ret = true;
   auto exceed_payload_max_length = false;
 
   if (is_chunked_transfer_encoding(x.headers)) {
-    ret = read_content_chunked(strm, callback);
+    ret = read_content_chunked(strm, out);
   } else if (!has_header(x.headers, "Content-Length")) {
-    ret = read_content_without_length(strm, callback);
+    ret = read_content_without_length(strm, out);
   } else {
     auto len = get_header_value_uint64(x.headers, "Content-Length", 0);
     if (len > 0) {
@@ -1143,23 +1225,12 @@ bool read_content(Stream &strm, T &x, uint64_t payload_max_length, int &status,
         skip_content_with_length(strm, len);
         ret = false;
       } else {
-        // NOTE: We can remove it if it doesn't give us enough better
-        // performance.
-        x.body.reserve(len);
-        ret = read_content_with_length(strm, len, progress, callback);
+        ret = read_content_with_length(strm, len, progress, out);
       }
     }
   }
 
-  if (ret) {
-#ifdef CPPHTTPLIB_ZLIB_SUPPORT
-    if (x.get_header_value("Content-Encoding") == "gzip") {
-      ret = detail::decompress(x.body);
-    }
-#endif
-  } else {
-    status = exceed_payload_max_length ? 413 : 400;
-  }
+  if (!ret) { status = exceed_payload_max_length ? 413 : 400; }
 
   return ret;
 }
@@ -1177,7 +1248,7 @@ inline void write_content_chunked(Stream &strm, const T &x) {
   uint64_t offset = 0;
   auto data_available = true;
   while (data_available) {
-    auto chunk = x.streamcb(offset);
+    auto chunk = x.content_producer(offset);
     offset += chunk.size();
     data_available = !chunk.empty();
 
@@ -1696,7 +1767,7 @@ inline void Server::write_response(Stream &strm, bool last_connection,
 
   if (res.body.empty()) {
     if (!res.has_header("Content-Length")) {
-      if (res.streamcb) {
+      if (res.content_producer) {
         // Streamed response
         res.set_header("Transfer-Encoding", "chunked");
       } else {
@@ -1729,7 +1800,7 @@ inline void Server::write_response(Stream &strm, bool last_connection,
   if (req.method != "HEAD") {
     if (!res.body.empty()) {
       strm.write(res.body.c_str(), res.body.size());
-    } else if (res.streamcb) {
+    } else if (res.content_producer) {
       detail::write_content_chunked(strm, res);
     }
   }
@@ -1928,8 +1999,9 @@ Server::process_request(Stream &strm, bool last_connection,
 
   // Body
   if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") {
-    if (!detail::read_content(strm, req, payload_max_length_, res.status,
-                              Progress())) {
+    if (!detail::read_content(
+            strm, req, payload_max_length_, res.status, Progress(),
+            [&](const char *buf, size_t n) { req.body.append(buf, n); })) {
       write_response(strm, last_connection, req, res);
       return true;
     }
@@ -2107,9 +2179,17 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res,
 
   // Body
   if (req.method != "HEAD") {
+    ContentReceiver out = [&](const char *buf, size_t n) {
+      res.body.append(buf, n);
+    };
+
+    if (res.content_receiver) {
+      out = [&](const char *buf, size_t n) { res.content_receiver(buf, n); };
+    }
+
     int dummy_status;
     if (!detail::read_content(strm, res, std::numeric_limits<uint64_t>::max(),
-                              dummy_status, req.progress)) {
+                              dummy_status, res.progress, out)) {
       return false;
     }
   }
@@ -2139,9 +2219,31 @@ Client::Get(const char *path, const Headers &headers, Progress progress) {
   req.method = "GET";
   req.path = path;
   req.headers = headers;
-  req.progress = progress;
 
   auto res = std::make_shared<Response>();
+  res->progress = progress;
+
+  return send(req, *res) ? res : nullptr;
+}
+
+inline std::shared_ptr<Response> Client::Get(const char *path,
+                                             ContentReceiver content_receiver,
+                                             Progress progress) {
+  return Get(path, Headers(), content_receiver, progress);
+}
+
+inline std::shared_ptr<Response> Client::Get(const char *path,
+                                             const Headers &headers,
+                                             ContentReceiver content_receiver,
+                                             Progress progress) {
+  Request req;
+  req.method = "GET";
+  req.path = path;
+  req.headers = headers;
+
+  auto res = std::make_shared<Response>();
+  res->content_receiver = content_receiver;
+  res->progress = progress;
 
   return send(req, *res) ? res : nullptr;
 }

+ 63 - 2
test/test.cc

@@ -142,6 +142,31 @@ TEST(ChunkedEncodingTest, FromHTTPWatch) {
   EXPECT_EQ(out, res->body);
 }
 
+TEST(ChunkedEncodingTest, WithContentReceiver) {
+  auto host = "www.httpwatch.com";
+  auto sec = 2;
+
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+  auto port = 443;
+  httplib::SSLClient cli(host, port, sec);
+#else
+  auto port = 80;
+  httplib::Client cli(host, port, sec);
+#endif
+
+  std::string body;
+  auto res =
+      cli.Get("/httpgallery/chunked/chunkedimage.aspx?0.4153841143030137",
+              [&](const char *data, size_t len) { body.append(data, len); });
+  ASSERT_TRUE(res != nullptr);
+
+  std::string out;
+  httplib::detail::read_file("./image.jpg", out);
+
+  EXPECT_EQ(200, res->status);
+  EXPECT_EQ(out, body);
+}
+
 TEST(RangeTest, FromHTTPBin) {
   auto host = "httpbin.org";
   auto sec = 5;
@@ -380,7 +405,7 @@ protected:
               })
         .Get("/streamedchunked",
              [&](const Request & /*req*/, Response &res) {
-               res.streamcb = [](uint64_t offset) {
+               res.content_producer = [](uint64_t offset) {
                  if (offset < 3) return "a";
                  if (offset < 6) return "b";
                  return "";
@@ -389,7 +414,7 @@ protected:
         .Get("/streamed",
              [&](const Request & /*req*/, Response &res) {
                res.set_header("Content-Length", "6");
-               res.streamcb = [](uint64_t offset) {
+               res.content_producer = [](uint64_t offset) {
                  if (offset < 3) return "a";
                  if (offset < 6) return "b";
                  return "";
@@ -1146,6 +1171,24 @@ TEST_F(ServerTest, Gzip) {
   EXPECT_EQ(200, res->status);
 }
 
+TEST_F(ServerTest, GzipWithContentReceiver) {
+  Headers headers;
+  headers.emplace("Accept-Encoding", "gzip, deflate");
+  std::string body;
+  auto res = cli_.Get("/gzip", headers, [&](const char *data, size_t len) {
+    body.append(data, len);
+  });
+
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ("gzip", res->get_header_value("Content-Encoding"));
+  EXPECT_EQ("text/plain", res->get_header_value("Content-Type"));
+  EXPECT_EQ("33", res->get_header_value("Content-Length"));
+  EXPECT_EQ("123456789012345678901234567890123456789012345678901234567890123456"
+            "7890123456789012345678901234567890",
+            body);
+  EXPECT_EQ(200, res->status);
+}
+
 TEST_F(ServerTest, NoGzip) {
   Headers headers;
   headers.emplace("Accept-Encoding", "gzip, deflate");
@@ -1161,6 +1204,24 @@ TEST_F(ServerTest, NoGzip) {
   EXPECT_EQ(200, res->status);
 }
 
+TEST_F(ServerTest, NoGzipWithContentReceiver) {
+  Headers headers;
+  headers.emplace("Accept-Encoding", "gzip, deflate");
+  std::string body;
+  auto res = cli_.Get("/nogzip", headers, [&](const char *data, size_t len) {
+    body.append(data, len);
+  });
+
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ(false, res->has_header("Content-Encoding"));
+  EXPECT_EQ("application/octet-stream", res->get_header_value("Content-Type"));
+  EXPECT_EQ("100", res->get_header_value("Content-Length"));
+  EXPECT_EQ("123456789012345678901234567890123456789012345678901234567890123456"
+            "7890123456789012345678901234567890",
+            body);
+  EXPECT_EQ(200, res->status);
+}
+
 TEST_F(ServerTest, MultipartFormDataGzip) {
   Request req;
   req.method = "POST";