Browse Source

Brotli suport on server. Fix #578

yhirose 5 years ago
parent
commit
a5b4cfadb9
2 changed files with 219 additions and 105 deletions
  1. 164 91
      httplib.h
  2. 55 14
      test/test.cc

+ 164 - 91
httplib.h

@@ -225,6 +225,7 @@ inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) {
 
 #ifdef CPPHTTPLIB_BROTLI_SUPPORT
 #include <brotli/decode.h>
+#include <brotli/encode.h>
 #endif
 
 /*
@@ -2157,8 +2158,39 @@ inline EncodingType encoding_type(const Request &req, const Response &res) {
   return EncodingType::None;
 }
 
+class compressor {
+public:
+  virtual ~compressor(){};
+
+  typedef std::function<bool(const char *data, size_t data_len)> Callback;
+  virtual bool compress(const char *data, size_t data_length, bool last,
+                        Callback callback) = 0;
+};
+
+class decompressor {
+public:
+  virtual ~decompressor() {}
+
+  virtual bool is_valid() const = 0;
+
+  typedef std::function<bool(const char *data, size_t data_len)> Callback;
+  virtual bool decompress(const char *data, size_t data_length,
+                          Callback callback) = 0;
+};
+
+class nocompressor : public compressor {
+public:
+  ~nocompressor(){};
+
+  bool compress(const char *data, size_t data_length, bool /*last*/,
+                Callback callback) override {
+    if (!data_length) { return true; }
+    return callback(data, data_length);
+  }
+};
+
 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
-class gzip_compressor {
+class gzip_compressor : public compressor {
 public:
   gzip_compressor() {
     std::memset(&strm_, 0, sizeof(strm_));
@@ -2172,8 +2204,8 @@ public:
 
   ~gzip_compressor() { deflateEnd(&strm_); }
 
-  template <typename T>
-  bool compress(const char *data, size_t data_length, bool last, T callback) {
+  bool compress(const char *data, size_t data_length, bool last,
+                Callback callback) override {
     assert(is_valid_);
 
     auto flush = last ? Z_FINISH : Z_NO_FLUSH;
@@ -2206,7 +2238,7 @@ private:
   z_stream strm_;
 };
 
-class gzip_decompressor {
+class gzip_decompressor : public decompressor {
 public:
   gzip_decompressor() {
     std::memset(&strm_, 0, sizeof(strm_));
@@ -2223,10 +2255,10 @@ public:
 
   ~gzip_decompressor() { inflateEnd(&strm_); }
 
-  bool is_valid() const { return is_valid_; }
+  bool is_valid() const override { return is_valid_; }
 
-  template <typename T>
-  bool decompress(const char *data, size_t data_length, T callback) {
+  bool decompress(const char *data, size_t data_length,
+                  Callback callback) override {
     assert(is_valid_);
 
     int ret = Z_OK;
@@ -2262,7 +2294,52 @@ private:
 #endif
 
 #ifdef CPPHTTPLIB_BROTLI_SUPPORT
-class brotli_decompressor {
+class brotli_compressor : public compressor {
+public:
+  brotli_compressor() {
+    state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr);
+  }
+
+  ~brotli_compressor() { BrotliEncoderDestroyInstance(state_); }
+
+  bool compress(const char *data, size_t data_length, bool last,
+                Callback callback) override {
+    std::array<uint8_t, 16384> buff{};
+
+    auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS;
+    auto available_in = data_length;
+    auto next_in = reinterpret_cast<const uint8_t *>(data);
+
+    for (;;) {
+      if (last) {
+        if (BrotliEncoderIsFinished(state_)) { break; }
+      } else {
+        if (!available_in) { break; }
+      }
+
+      auto available_out = buff.size();
+      auto next_out = buff.data();
+
+      if (!BrotliEncoderCompressStream(state_, operation, &available_in,
+                                       &next_in, &available_out, &next_out,
+                                       nullptr)) {
+        return false;
+      }
+
+      auto output_bytes = buff.size() - available_out;
+      if (output_bytes) {
+        callback(reinterpret_cast<const char *>(buff.data()), output_bytes);
+      }
+    }
+
+    return true;
+  }
+
+private:
+  BrotliEncoderState *state_ = nullptr;
+};
+
+class brotli_decompressor : public decompressor {
 public:
   brotli_decompressor() {
     decoder_s = BrotliDecoderCreateInstance(0, 0, 0);
@@ -2274,13 +2351,14 @@ public:
     if (decoder_s) { BrotliDecoderDestroyInstance(decoder_s); }
   }
 
-  bool is_valid() const { return decoder_s; }
+  bool is_valid() const override { return decoder_s; }
 
-  template <typename T>
-  bool decompress(const char *data, size_t data_length, T callback) {
+  bool decompress(const char *data, size_t data_length,
+                  Callback callback) override {
     if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS ||
-        decoder_r == BROTLI_DECODER_RESULT_ERROR)
+        decoder_r == BROTLI_DECODER_RESULT_ERROR) {
       return 0;
+    }
 
     const uint8_t *next_in = (const uint8_t *)data;
     size_t avail_in = data_length;
@@ -2491,32 +2569,29 @@ bool prepare_content_receiver(T &x, int &status, ContentReceiver receiver,
                               bool decompress, U callback) {
   if (decompress) {
     std::string encoding = x.get_header_value("Content-Encoding");
+    std::shared_ptr<decompressor> decompressor;
 
     if (encoding.find("gzip") != std::string::npos ||
         encoding.find("deflate") != std::string::npos) {
 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
-      gzip_decompressor decompressor;
-      if (decompressor.is_valid()) {
-        ContentReceiver out = [&](const char *buf, size_t n) {
-          return decompressor.decompress(
-              buf, n,
-              [&](const char *buf, size_t n) { return receiver(buf, n); });
-        };
-        return callback(out);
-      } else {
-        status = 500;
-        return false;
-      }
+      decompressor = std::make_shared<gzip_decompressor>();
 #else
       status = 415;
       return false;
 #endif
     } else if (encoding.find("br") != std::string::npos) {
 #ifdef CPPHTTPLIB_BROTLI_SUPPORT
-      brotli_decompressor decompressor;
-      if (decompressor.is_valid()) {
+      decompressor = std::make_shared<brotli_decompressor>();
+#else
+      status = 415;
+      return false;
+#endif
+    }
+
+    if (decompressor) {
+      if (decompressor->is_valid()) {
         ContentReceiver out = [&](const char *buf, size_t n) {
-          return decompressor.decompress(
+          return decompressor->decompress(
               buf, n,
               [&](const char *buf, size_t n) { return receiver(buf, n); });
         };
@@ -2525,17 +2600,12 @@ bool prepare_content_receiver(T &x, int &status, ContentReceiver receiver,
         status = 500;
         return false;
       }
-#else
-      status = 415;
-      return false;
-#endif
     }
   }
 
   ContentReceiver out = [&](const char *buf, size_t n) {
     return receiver(buf, n);
   };
-
   return callback(out);
 }
 
@@ -2628,10 +2698,10 @@ inline ssize_t write_content(Stream &strm, ContentProvider content_provider,
   return static_cast<ssize_t>(offset - begin_offset);
 }
 
-template <typename T>
+template <typename T, typename U>
 inline ssize_t write_content_chunked(Stream &strm,
                                      ContentProvider content_provider,
-                                     T is_shutting_down, EncodingType type) {
+                                     T is_shutting_down, U &compressor) {
   size_t offset = 0;
   auto data_available = true;
   ssize_t total_written_length = 0;
@@ -2639,10 +2709,6 @@ inline ssize_t write_content_chunked(Stream &strm,
   auto ok = true;
   DataSink data_sink;
 
-#ifdef CPPHTTPLIB_ZLIB_SUPPORT
-  detail::gzip_compressor compressor;
-#endif
-
   data_sink.write = [&](const char *d, size_t l) {
     if (!ok) { return; }
 
@@ -2650,22 +2716,13 @@ inline ssize_t write_content_chunked(Stream &strm,
     offset += l;
 
     std::string payload;
-    if (type == EncodingType::Gzip) {
-#ifdef CPPHTTPLIB_ZLIB_SUPPORT
-      if (!compressor.compress(d, l, false,
-                               [&](const char *data, size_t data_len) {
-                                 payload.append(data, data_len);
-                                 return true;
-                               })) {
-        ok = false;
-        return;
-      }
-#endif
-    } else if (type == EncodingType::Brotli) {
-#ifdef CPPHTTPLIB_BROTLI_SUPPORT
-#endif
-    } else {
-      payload = std::string(d, l);
+    if (!compressor.compress(d, l, false,
+                             [&](const char *data, size_t data_len) {
+                               payload.append(data, data_len);
+                               return true;
+                             })) {
+      ok = false;
+      return;
     }
 
     if (!payload.empty()) {
@@ -2685,32 +2742,25 @@ inline ssize_t write_content_chunked(Stream &strm,
 
     data_available = false;
 
-    if (type == EncodingType::Gzip) {
-#ifdef CPPHTTPLIB_ZLIB_SUPPORT
-      std::string payload;
-      if (!compressor.compress(nullptr, 0, true,
-                               [&](const char *data, size_t data_len) {
-                                 payload.append(data, data_len);
-                                 return true;
-                               })) {
+    std::string payload;
+    if (!compressor.compress(nullptr, 0, true,
+                             [&](const char *data, size_t data_len) {
+                               payload.append(data, data_len);
+                               return true;
+                             })) {
+      ok = false;
+      return;
+    }
+
+    if (!payload.empty()) {
+      // Emit chunked response header and footer for each chunk
+      auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n";
+      if (write_data(strm, chunk.data(), chunk.size())) {
+        total_written_length += chunk.size();
+      } else {
         ok = false;
         return;
       }
-
-      if (!payload.empty()) {
-        // Emit chunked response header and footer for each chunk
-        auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n";
-        if (write_data(strm, chunk.data(), chunk.size())) {
-          total_written_length += chunk.size();
-        } else {
-          ok = false;
-          return;
-        }
-      }
-#endif
-    } else if (type == EncodingType::Brotli) {
-#ifdef CPPHTTPLIB_BROTLI_SUPPORT
-#endif
     }
 
     static const std::string done_marker("0\r\n\r\n");
@@ -3918,25 +3968,33 @@ inline bool Server::write_response(Stream &strm, bool close_connection,
     }
 
     if (type != detail::EncodingType::None) {
-#ifdef CPPHTTPLIB_ZLIB_SUPPORT
-      std::string compressed;
+      std::shared_ptr<detail::compressor> compressor;
 
       if (type == detail::EncodingType::Gzip) {
-        detail::gzip_compressor compressor;
-        if (!compressor.compress(res.body.data(), res.body.size(), true,
-                                 [&](const char *data, size_t data_len) {
-                                   compressed.append(data, data_len);
-                                   return true;
-                                 })) {
-          return false;
-        }
+#ifdef CPPHTTPLIB_ZLIB_SUPPORT
+        compressor = std::make_shared<detail::gzip_compressor>();
         res.set_header("Content-Encoding", "gzip");
+#endif
       } else if (type == detail::EncodingType::Brotli) {
-        // TODO:
+#ifdef CPPHTTPLIB_BROTLI_SUPPORT
+        compressor = std::make_shared<detail::brotli_compressor>();
+        res.set_header("Content-Encoding", "brotli");
+#endif
       }
 
-      res.body.swap(compressed);
-#endif
+      if (compressor) {
+        std::string compressed;
+
+        if (!compressor->compress(res.body.data(), res.body.size(), true,
+                                  [&](const char *data, size_t data_len) {
+                                    compressed.append(data, data_len);
+                                    return true;
+                                  })) {
+          return false;
+        }
+
+        res.body.swap(compressed);
+      }
     }
 
     auto length = std::to_string(res.body.size());
@@ -3999,8 +4057,23 @@ Server::write_content_with_provider(Stream &strm, const Request &req,
     }
   } else {
     auto type = detail::encoding_type(req, res);
+
+    std::shared_ptr<detail::compressor> compressor;
+    if (type == detail::EncodingType::Gzip) {
+#ifdef CPPHTTPLIB_ZLIB_SUPPORT
+      compressor = std::make_shared<detail::gzip_compressor>();
+#endif
+    } else if (type == detail::EncodingType::Brotli) {
+#ifdef CPPHTTPLIB_BROTLI_SUPPORT
+      compressor = std::make_shared<detail::brotli_compressor>();
+#endif
+    } else {
+      compressor = std::make_shared<detail::nocompressor>();
+    }
+    assert(compressor != nullptr);
+
     if (detail::write_content_chunked(strm, res.content_provider_,
-                                      is_shutting_down, type) < 0) {
+                                      is_shutting_down, *compressor) < 0) {
       return false;
     }
   }

+ 55 - 14
test/test.cc

@@ -1227,22 +1227,22 @@ protected:
              [&](const Request &req, Response & /*res*/) {
                EXPECT_EQ("close", req.get_header_value("Connection"));
              })
-#ifdef CPPHTTPLIB_ZLIB_SUPPORT
-        .Get("/gzip",
+#if defined(CPPHTTPLIB_ZLIB_SUPPORT) || defined(CPPHTTPLIB_BROTLI_SUPPORT)
+        .Get("/compress",
              [&](const Request & /*req*/, Response &res) {
                res.set_content(
                    "12345678901234567890123456789012345678901234567890123456789"
                    "01234567890123456789012345678901234567890",
                    "text/plain");
              })
-        .Get("/nogzip",
+        .Get("/nocompress",
              [&](const Request & /*req*/, Response &res) {
                res.set_content(
                    "12345678901234567890123456789012345678901234567890123456789"
                    "01234567890123456789012345678901234567890",
                    "application/octet-stream");
              })
-        .Post("/gzipmultipart",
+        .Post("/compress-multipart",
               [&](const Request &req, Response & /*res*/) {
                 EXPECT_EQ(2u, req.files.size());
                 ASSERT_TRUE(!req.has_file("???"));
@@ -2123,6 +2123,28 @@ TEST_F(ServerTest, GetStreamedChunkedWithGzip2) {
 }
 #endif
 
+#ifdef CPPHTTPLIB_BROTLI_SUPPORT
+TEST_F(ServerTest, GetStreamedChunkedWithBrotli) {
+  httplib::Headers headers;
+  headers.emplace("Accept-Encoding", "brotli");
+
+  auto res = cli_.Get("/streamed-chunked", headers);
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ(200, res->status);
+  EXPECT_EQ(std::string("123456789"), res->body);
+}
+
+TEST_F(ServerTest, GetStreamedChunkedWithBrotli2) {
+  httplib::Headers headers;
+  headers.emplace("Accept-Encoding", "brotli");
+
+  auto res = cli_.Get("/streamed-chunked2", headers);
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ(200, res->status);
+  EXPECT_EQ(std::string("123456789"), res->body);
+}
+#endif
+
 TEST_F(ServerTest, Patch) {
   auto res = cli_.Patch("/patch", "PATCH", "text/plain");
   ASSERT_TRUE(res != nullptr);
@@ -2285,7 +2307,7 @@ TEST_F(ServerTest, KeepAlive) {
 TEST_F(ServerTest, Gzip) {
   Headers headers;
   headers.emplace("Accept-Encoding", "gzip, deflate");
-  auto res = cli_.Get("/gzip", headers);
+  auto res = cli_.Get("/compress", headers);
 
   ASSERT_TRUE(res != nullptr);
   EXPECT_EQ("gzip", res->get_header_value("Content-Encoding"));
@@ -2299,7 +2321,7 @@ TEST_F(ServerTest, Gzip) {
 
 TEST_F(ServerTest, GzipWithoutAcceptEncoding) {
   Headers headers;
-  auto res = cli_.Get("/gzip", headers);
+  auto res = cli_.Get("/compress", headers);
 
   ASSERT_TRUE(res != nullptr);
   EXPECT_TRUE(res->get_header_value("Content-Encoding").empty());
@@ -2316,7 +2338,7 @@ TEST_F(ServerTest, GzipWithContentReceiver) {
   headers.emplace("Accept-Encoding", "gzip, deflate");
   std::string body;
   auto res =
-      cli_.Get("/gzip", headers, [&](const char *data, uint64_t data_length) {
+      cli_.Get("/compress", headers, [&](const char *data, uint64_t data_length) {
         EXPECT_EQ(data_length, 100);
         body.append(data, data_length);
         return true;
@@ -2337,7 +2359,7 @@ TEST_F(ServerTest, GzipWithoutDecompressing) {
   headers.emplace("Accept-Encoding", "gzip, deflate");
 
   cli_.set_decompress(false);
-  auto res = cli_.Get("/gzip", headers);
+  auto res = cli_.Get("/compress", headers);
 
   ASSERT_TRUE(res != nullptr);
   EXPECT_EQ("gzip", res->get_header_value("Content-Encoding"));
@@ -2351,7 +2373,7 @@ TEST_F(ServerTest, GzipWithContentReceiverWithoutAcceptEncoding) {
   Headers headers;
   std::string body;
   auto res =
-      cli_.Get("/gzip", headers, [&](const char *data, uint64_t data_length) {
+      cli_.Get("/compress", headers, [&](const char *data, uint64_t data_length) {
         EXPECT_EQ(data_length, 100);
         body.append(data, data_length);
         return true;
@@ -2370,7 +2392,7 @@ TEST_F(ServerTest, GzipWithContentReceiverWithoutAcceptEncoding) {
 TEST_F(ServerTest, NoGzip) {
   Headers headers;
   headers.emplace("Accept-Encoding", "gzip, deflate");
-  auto res = cli_.Get("/nogzip", headers);
+  auto res = cli_.Get("/nocompress", headers);
 
   ASSERT_TRUE(res != nullptr);
   EXPECT_EQ(false, res->has_header("Content-Encoding"));
@@ -2387,7 +2409,7 @@ TEST_F(ServerTest, NoGzipWithContentReceiver) {
   headers.emplace("Accept-Encoding", "gzip, deflate");
   std::string body;
   auto res =
-      cli_.Get("/nogzip", headers, [&](const char *data, uint64_t data_length) {
+      cli_.Get("/nocompress", headers, [&](const char *data, uint64_t data_length) {
         EXPECT_EQ(data_length, 100);
         body.append(data, data_length);
         return true;
@@ -2410,13 +2432,30 @@ TEST_F(ServerTest, MultipartFormDataGzip) {
   };
 
   cli_.set_compress(true);
-  auto res = cli_.Post("/gzipmultipart", items);
+  auto res = cli_.Post("/compress-multipart", items);
 
   ASSERT_TRUE(res != nullptr);
   EXPECT_EQ(200, res->status);
 }
 #endif
 
+#ifdef CPPHTTPLIB_BROTLI_SUPPORT
+TEST_F(ServerTest, Brotli) {
+  Headers headers;
+  headers.emplace("Accept-Encoding", "br");
+  auto res = cli_.Get("/compress", headers);
+
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ("brotli", res->get_header_value("Content-Encoding"));
+  EXPECT_EQ("text/plain", res->get_header_value("Content-Type"));
+  EXPECT_EQ("19", res->get_header_value("Content-Length"));
+  EXPECT_EQ("123456789012345678901234567890123456789012345678901234567890123456"
+            "7890123456789012345678901234567890",
+            res->body);
+  EXPECT_EQ(200, res->status);
+}
+#endif
+
 // Sends a raw request to a server listening at HOST:PORT.
 static bool send_request(time_t read_timeout_sec, const std::string &req,
                          std::string *resp = nullptr) {
@@ -3149,12 +3188,14 @@ TEST(YahooRedirectTest3, SimpleInterface) {
 #ifdef CPPHTTPLIB_BROTLI_SUPPORT
 TEST(DecodeWithChunkedEncoding, BrotliEncoding) {
   httplib::Client cli("https://cdnjs.cloudflare.com");
-  auto res = cli.Get("/ajax/libs/jquery/3.5.1/jquery.js", {{"Accept-Encoding", "brotli"}});
+  auto res = cli.Get("/ajax/libs/jquery/3.5.1/jquery.js",
+                     {{"Accept-Encoding", "brotli"}});
 
   ASSERT_TRUE(res != nullptr);
   EXPECT_EQ(200, res->status);
   EXPECT_EQ(287630, res->body.size());
-  EXPECT_EQ("application/javascript; charset=utf-8", res->get_header_value("Content-Type"));
+  EXPECT_EQ("application/javascript; charset=utf-8",
+            res->get_header_value("Content-Type"));
 }
 #endif