Browse Source

Added compressor class

yhirose 5 years ago
parent
commit
457a5a7501
1 changed files with 96 additions and 52 deletions
  1. 96 52
      httplib.h

+ 96 - 52
httplib.h

@@ -143,9 +143,9 @@ using ssize_t = int;
 #endif // NOMINMAX
 #endif // NOMINMAX
 
 
 #include <io.h>
 #include <io.h>
+#include <wincrypt.h>
 #include <winsock2.h>
 #include <winsock2.h>
 #include <ws2tcpip.h>
 #include <ws2tcpip.h>
-#include <wincrypt.h>
 
 
 #ifndef WSA_FLAG_NO_HANDLE_INHERIT
 #ifndef WSA_FLAG_NO_HANDLE_INHERIT
 #define WSA_FLAG_NO_HANDLE_INHERIT 0x80
 #define WSA_FLAG_NO_HANDLE_INHERIT 0x80
@@ -2271,90 +2271,106 @@ inline bool can_compress(const std::string &content_type) {
          content_type == "application/xhtml+xml";
          content_type == "application/xhtml+xml";
 }
 }
 
 
-inline bool compress(std::string &content) {
-  z_stream strm;
-  strm.zalloc = Z_NULL;
-  strm.zfree = Z_NULL;
-  strm.opaque = Z_NULL;
+class compressor {
+public:
+  compressor() {
+    std::memset(&strm_, 0, sizeof(strm_));
+    strm_.zalloc = Z_NULL;
+    strm_.zfree = Z_NULL;
+    strm_.opaque = Z_NULL;
 
 
-  auto ret = deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8,
-                          Z_DEFAULT_STRATEGY);
-  if (ret != Z_OK) { return false; }
+    is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8,
+                             Z_DEFAULT_STRATEGY) == Z_OK;
+  }
 
 
-  strm.avail_in = static_cast<decltype(strm.avail_in)>(content.size());
-  strm.next_in =
-      const_cast<Bytef *>(reinterpret_cast<const Bytef *>(content.data()));
+  ~compressor() { deflateEnd(&strm_); }
 
 
-  std::string compressed;
+  template <typename T>
+  bool compress(const char *data, size_t data_length, bool last, T callback) {
+    assert(is_valid_);
 
 
-  std::array<char, 16384> buff{};
-  do {
-    strm.avail_out = buff.size();
-    strm.next_out = reinterpret_cast<Bytef *>(buff.data());
-    ret = deflate(&strm, Z_FINISH);
-    assert(ret != Z_STREAM_ERROR);
-    compressed.append(buff.data(), buff.size() - strm.avail_out);
-  } while (strm.avail_out == 0);
+    auto flush = last ? Z_FINISH : Z_NO_FLUSH;
 
 
-  assert(ret == Z_STREAM_END);
-  assert(strm.avail_in == 0);
+    strm_.avail_in = static_cast<decltype(strm_.avail_in)>(data_length);
+    strm_.next_in = const_cast<Bytef *>(reinterpret_cast<const Bytef *>(data));
 
 
-  content.swap(compressed);
+    int ret = Z_OK;
 
 
-  deflateEnd(&strm);
-  return true;
-}
+    std::array<char, 16384> buff{};
+    do {
+      strm_.avail_out = buff.size();
+      strm_.next_out = reinterpret_cast<Bytef *>(buff.data());
+
+      ret = deflate(&strm_, flush);
+      assert(ret != Z_STREAM_ERROR);
+
+      if (!callback(buff.data(), buff.size() - strm_.avail_out)) {
+        return false;
+      }
+    } while (strm_.avail_out == 0);
+
+    assert(ret == Z_STREAM_END);
+    assert(strm_.avail_in == 0);
+    return true;
+  }
+
+private:
+  bool is_valid_ = false;
+  z_stream strm_;
+};
 
 
 class decompressor {
 class decompressor {
 public:
 public:
   decompressor() {
   decompressor() {
-    std::memset(&strm, 0, sizeof(strm));
-    strm.zalloc = Z_NULL;
-    strm.zfree = Z_NULL;
-    strm.opaque = Z_NULL;
+    std::memset(&strm_, 0, sizeof(strm_));
+    strm_.zalloc = Z_NULL;
+    strm_.zfree = Z_NULL;
+    strm_.opaque = Z_NULL;
 
 
     // 15 is the value of wbits, which should be at the maximum possible value
     // 15 is the value of wbits, which should be at the maximum possible value
     // to ensure that any gzip stream can be decoded. The offset of 32 specifies
     // to ensure that any gzip stream can be decoded. The offset of 32 specifies
     // that the stream type should be automatically detected either gzip or
     // that the stream type should be automatically detected either gzip or
     // deflate.
     // deflate.
-    is_valid_ = inflateInit2(&strm, 32 + 15) == Z_OK;
+    is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK;
   }
   }
 
 
-  ~decompressor() { inflateEnd(&strm); }
+  ~decompressor() { inflateEnd(&strm_); }
 
 
   bool is_valid() const { return is_valid_; }
   bool is_valid() const { return is_valid_; }
 
 
   template <typename T>
   template <typename T>
   bool decompress(const char *data, size_t data_length, T callback) {
   bool decompress(const char *data, size_t data_length, T callback) {
+    assert(is_valid_);
+
     int ret = Z_OK;
     int ret = Z_OK;
 
 
-    strm.avail_in = static_cast<decltype(strm.avail_in)>(data_length);
-    strm.next_in = const_cast<Bytef *>(reinterpret_cast<const Bytef *>(data));
+    strm_.avail_in = static_cast<decltype(strm_.avail_in)>(data_length);
+    strm_.next_in = const_cast<Bytef *>(reinterpret_cast<const Bytef *>(data));
 
 
     std::array<char, 16384> buff{};
     std::array<char, 16384> buff{};
     do {
     do {
-      strm.avail_out = buff.size();
-      strm.next_out = reinterpret_cast<Bytef *>(buff.data());
+      strm_.avail_out = buff.size();
+      strm_.next_out = reinterpret_cast<Bytef *>(buff.data());
 
 
-      ret = inflate(&strm, Z_NO_FLUSH);
+      ret = inflate(&strm_, Z_NO_FLUSH);
       assert(ret != Z_STREAM_ERROR);
       assert(ret != Z_STREAM_ERROR);
       switch (ret) {
       switch (ret) {
       case Z_NEED_DICT:
       case Z_NEED_DICT:
       case Z_DATA_ERROR:
       case Z_DATA_ERROR:
-      case Z_MEM_ERROR: inflateEnd(&strm); return false;
+      case Z_MEM_ERROR: inflateEnd(&strm_); return false;
       }
       }
 
 
-      if (!callback(buff.data(), buff.size() - strm.avail_out)) {
+      if (!callback(buff.data(), buff.size() - strm_.avail_out)) {
         return false;
         return false;
       }
       }
-    } while (strm.avail_out == 0);
+    } while (strm_.avail_out == 0);
 
 
     return ret == Z_OK || ret == Z_STREAM_END;
     return ret == Z_OK || ret == Z_STREAM_END;
   }
   }
 
 
 private:
 private:
-  bool is_valid_;
-  z_stream strm;
+  bool is_valid_ = false;
+  z_stream strm_;
 };
 };
 #endif
 #endif
 
 
@@ -3924,9 +3940,17 @@ inline bool Server::write_response(Stream &strm, bool close_connection,
     const auto &encodings = req.get_header_value("Accept-Encoding");
     const auto &encodings = req.get_header_value("Accept-Encoding");
     if (encodings.find("gzip") != std::string::npos &&
     if (encodings.find("gzip") != std::string::npos &&
         detail::can_compress(res.get_header_value("Content-Type"))) {
         detail::can_compress(res.get_header_value("Content-Type"))) {
-      if (detail::compress(res.body)) {
-        res.set_header("Content-Encoding", "gzip");
+      std::string compressed;
+      detail::compressor compressor;
+      if (!compressor.compress(res.body.data(), res.body.size(), true,
+                               [&](const char *data, size_t data_len) {
+                                 compressed.append(data, data_len);
+                                 return true;
+                               })) {
+        return false;
       }
       }
+      res.body.swap(compressed);
+      res.set_header("Content-Encoding", "gzip");
     }
     }
 #endif
 #endif
 
 
@@ -4730,26 +4754,47 @@ inline std::shared_ptr<Response> Client::send_with_content_provider(
 
 
 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
   if (compress_) {
   if (compress_) {
+    detail::compressor compressor;
+
     if (content_provider) {
     if (content_provider) {
+      auto ok = true;
       size_t offset = 0;
       size_t offset = 0;
 
 
       DataSink data_sink;
       DataSink data_sink;
       data_sink.write = [&](const char *data, size_t data_len) {
       data_sink.write = [&](const char *data, size_t data_len) {
-        req.body.append(data, data_len);
-        offset += data_len;
+        if (ok) {
+          auto last = offset + data_len == content_length;
+
+          auto ret = compressor.compress(
+              data, data_len, last, [&](const char *data, size_t data_len) {
+                req.body.append(data, data_len);
+                return true;
+              });
+
+          if (ret) {
+            offset += data_len;
+          } else {
+            ok = false;
+          }
+        }
       };
       };
-      data_sink.is_writable = [&](void) { return true; };
+      data_sink.is_writable = [&](void) { return ok && true; };
 
 
-      while (offset < content_length) {
+      while (ok && offset < content_length) {
         if (!content_provider(offset, content_length - offset, data_sink)) {
         if (!content_provider(offset, content_length - offset, data_sink)) {
           return nullptr;
           return nullptr;
         }
         }
       }
       }
     } else {
     } else {
-      req.body = body;
+      if (!compressor.compress(body.data(), body.size(), true,
+                               [&](const char *data, size_t data_len) {
+                                 req.body.append(data, data_len);
+                                 return true;
+                               })) {
+        return nullptr;
+      }
     }
     }
 
 
-    if (!detail::compress(req.body)) { return nullptr; }
     req.headers.emplace("Content-Encoding", "gzip");
     req.headers.emplace("Content-Encoding", "gzip");
   } else
   } else
 #endif
 #endif
@@ -5821,4 +5866,3 @@ inline bool SSLClient::check_host_name(const char *pattern,
 } // namespace httplib
 } // namespace httplib
 
 
 #endif // CPPHTTPLIB_HTTPLIB_H
 #endif // CPPHTTPLIB_HTTPLIB_H
-