yhirose 8 years ago
parent
commit
9bc2883090
2 changed files with 132 additions and 36 deletions
  1. 105 36
      httplib.h
  2. 27 0
      test/test.cc

+ 105 - 36
httplib.h

@@ -67,6 +67,8 @@ typedef int socket_t;
 namespace httplib
 namespace httplib
 {
 {
 
 
+enum class HttpVersion { v1_0 = 0, v1_1 };
+
 typedef std::map<std::string, std::string>      Map;
 typedef std::map<std::string, std::string>      Map;
 typedef std::multimap<std::string, std::string> MultiMap;
 typedef std::multimap<std::string, std::string> MultiMap;
 typedef std::smatch                             Match;
 typedef std::smatch                             Match;
@@ -169,7 +171,7 @@ private:
 
 
 class Client {
 class Client {
 public:
 public:
-    Client(const char* host, int port);
+    Client(const char* host, int port, HttpVersion http_version = HttpVersion::v1_0);
     virtual ~Client();
     virtual ~Client();
 
 
     std::shared_ptr<Response> get(const char* path, Progress callback = [](int64_t,int64_t){});
     std::shared_ptr<Response> get(const char* path, Progress callback = [](int64_t,int64_t){});
@@ -184,6 +186,7 @@ protected:
 
 
     const std::string host_;
     const std::string host_;
     const int         port_;
     const int         port_;
+    const HttpVersion http_version_;
     const std::string host_and_port_;
     const std::string host_and_port_;
 
 
 private:
 private:
@@ -220,7 +223,7 @@ private:
 
 
 class SSLClient : public Client {
 class SSLClient : public Client {
 public:
 public:
-    SSLClient(const char* host, int port);
+    SSLClient(const char* host, int port, HttpVersion http_version = HttpVersion:v1_0);
     virtual ~SSLClient();
     virtual ~SSLClient();
 
 
 private:
 private:
@@ -235,6 +238,8 @@ private:
  */
  */
 namespace detail {
 namespace detail {
 
 
+static std::vector<const char*> http_version_strings = { "HTTP/1.0", "HTTP/1.1" };
+
 template <class Fn>
 template <class Fn>
 void split(const char* b, const char* e, char d, Fn fn)
 void split(const char* b, const char* e, char d, Fn fn)
 {
 {
@@ -567,36 +572,98 @@ inline bool read_headers(Stream& strm, MultiMap& headers)
 }
 }
 
 
 template <typename T>
 template <typename T>
-bool read_content(Stream& strm, T& x, bool allow_no_content_length, Progress progress = [](int64_t,int64_t){})
+bool read_content_with_length(Stream& strm, T& x, size_t len, Progress progress)
+{
+    x.body.assign(len, 0);
+    size_t r = 0;
+    while (r < len){
+        auto r_incr = strm.read(&x.body[r], len - r);
+        if (r_incr <= 0) {
+            return false;
+        }
+        r += r_incr;
+        if (progress) {
+            progress(r, len);
+        }
+    }
+
+    return true;
+}
+
+template <typename T>
+bool read_content_without_length(Stream& strm, T& x)
+{
+    for (;;) {
+        char byte;
+        auto n = strm.read(&byte, 1);
+        if (n < 0) {
+            return false;
+        } else if (n == 0) {
+            break;
+        }
+        x.body += byte;
+    }
+
+    return true;
+}
+
+template <typename T>
+bool read_content_chunked(Stream& strm, T& x)
+{
+    const auto BUFSIZ_CHUNK_LEN = 16;
+    char buf[BUFSIZ_CHUNK_LEN];
+
+    if (!socket_gets(strm, buf, BUFSIZ_CHUNK_LEN)) {
+        return false;
+    }
+
+    auto chunk_len = std::stoi(buf, 0, 16);
+
+    while (chunk_len > 0){
+        std::string chunk(chunk_len, 0);
+
+        auto n = strm.read(&chunk[0], chunk_len);
+        if (n <= 0) {
+            return false;
+        }
+
+        if (!socket_gets(strm, buf, BUFSIZ_CHUNK_LEN)) {
+            return false;
+        }
+
+        if (strcmp(buf, "\r\n")) {
+            break;
+        }
+
+        x.body += chunk;
+
+        if (!socket_gets(strm, buf, BUFSIZ_CHUNK_LEN)) {
+            return false;
+        }
+
+        chunk_len = std::stoi(buf, 0, 16);
+    }
+
+    return true;
+}
+
+template <typename T>
+bool read_content(Stream& strm, T& x, Progress progress = [](int64_t,int64_t){})
 {
 {
     auto len = get_header_value_int(x.headers, "Content-Length", 0);
     auto len = get_header_value_int(x.headers, "Content-Length", 0);
+
     if (len) {
     if (len) {
-        x.body.assign(len, 0);
-        auto r = 0;
-        while (r < len){
-            auto r_incr = strm.read(&x.body[r], len - r);
-            if (r_incr <= 0) {
-                return false;
-            }
-            r += r_incr;
-            if (progress) {
-                progress(r, len);
-            }
-        }
-    } else if (allow_no_content_length) {
-        for (;;) {
-            char byte;
-            auto n = strm.read(&byte, 1);
-            if (n < 1) {
-                if (x.body.size() == 0) {
-                    return true; // no body
-                } else {
-                    break;
-                }
-            }
-            x.body += byte;
+        return read_content_with_length(strm, x, len, progress);
+    } else {
+        auto encoding = get_header_value(x.headers, "Transfer-Encoding", "");
+
+        if (!strcmp(encoding, "chunked")) {
+            return read_content_chunked(strm, x);
+        } else {
+            return read_content_without_length(strm, x);
         }
         }
     }
     }
+
     return true;
     return true;
 }
 }
 
 
@@ -759,10 +826,10 @@ inline std::string decode_url(const std::string& s)
     return result;
     return result;
 }
 }
 
 
-inline void write_request(Stream& strm, const Request& req)
+inline void write_request(Stream& strm, const Request& req, const char* ver)
 {
 {
     auto path = encode_url(req.path);
     auto path = encode_url(req.path);
-    socket_printf(strm, "%s %s HTTP/1.0\r\n", req.method.c_str(), path.c_str());
+    socket_printf(strm, "%s %s %s\r\n", req.method.c_str(), path.c_str(), ver);
 
 
     write_headers(strm, req);
     write_headers(strm, req);
 
 
@@ -1074,7 +1141,7 @@ inline void Server::process_request(Stream& strm)
     }
     }
 
 
     if (req.method == "POST") {
     if (req.method == "POST") {
-        if (!detail::read_content(strm, req, false)) {
+        if (!detail::read_content(strm, req)) {
             res.status = 400;
             res.status = 400;
             write_response(strm, req, res);
             write_response(strm, req, res);
             return;
             return;
@@ -1106,9 +1173,10 @@ inline bool Server::read_and_close_socket(socket_t sock)
 }
 }
 
 
 // HTTP client implementation
 // HTTP client implementation
-inline Client::Client(const char* host, int port)
+inline Client::Client(const char* host, int port, HttpVersion http_version)
     : host_(host)
     : host_(host)
     , port_(port)
     , port_(port)
+    , http_version_(http_version)
     , host_and_port_(host_ + ":" + std::to_string(port_))
     , host_and_port_(host_ + ":" + std::to_string(port_))
 {
 {
 }
 }
@@ -1148,7 +1216,8 @@ inline bool Client::send(const Request& req, Response& res)
 inline bool Client::process_request(Stream& strm, const Request& req, Response& res)
 inline bool Client::process_request(Stream& strm, const Request& req, Response& res)
 {
 {
     // Send request
     // Send request
-    detail::write_request(strm, req);
+    auto ver = detail::http_version_strings[static_cast<size_t>(http_version_)];
+    detail::write_request(strm, req, ver);
 
 
     // Receive response
     // Receive response
     if (!read_response_line(strm, res) ||
     if (!read_response_line(strm, res) ||
@@ -1157,7 +1226,7 @@ inline bool Client::process_request(Stream& strm, const Request& req, Response&
     }
     }
 
 
     if (req.method != "HEAD") {
     if (req.method != "HEAD") {
-        if (!detail::read_content(strm, res, true, req.progress)) {
+        if (!detail::read_content(strm, res, req.progress)) {
             return false;
             return false;
         }
         }
     }
     }
@@ -1334,7 +1403,7 @@ inline bool SSLServer::read_and_close_socket(socket_t sock)
     return detail::read_and_close_socket_ssl(
     return detail::read_and_close_socket_ssl(
         sock, ctx_,
         sock, ctx_,
         SSL_accept,
         SSL_accept,
-        [](SSL* ssl) {},
+        [](SSL* /*ssl*/) {},
         [this](Stream& strm) {
         [this](Stream& strm) {
             process_request(strm);
             process_request(strm);
             return true;
             return true;
@@ -1342,8 +1411,8 @@ inline bool SSLServer::read_and_close_socket(socket_t sock)
 }
 }
 
 
 // SSL HTTP client implementation
 // SSL HTTP client implementation
-inline SSLClient::SSLClient(const char* host, int port)
-    : Client(host, port)
+inline SSLClient::SSLClient(const char* host, int port, HttpVersion http_version)
+    : Client(host, port, http_version)
 {
 {
     ctx_ = SSL_CTX_new(SSLv23_client_method());
     ctx_ = SSL_CTX_new(SSLv23_client_method());
 }
 }

+ 27 - 0
test/test.cc

@@ -109,6 +109,33 @@ TEST(GetHeaderValueTest, RegularValueInt)
     EXPECT_EQ(100, val);
     EXPECT_EQ(100, val);
 }
 }
 
 
+void testChunkedEncoding(httplib::HttpVersion ver)
+{
+    auto host = "www.httpwatch.com";
+    auto port = 80;
+
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+    httplib::SSLClient cli(host, port, ver);
+#else
+    httplib::Client cli(host, port, ver);
+#endif
+
+    auto res = cli.get("/httpgallery/chunked/chunkedimage.aspx?0.4153841143030137");
+    ASSERT_TRUE(res != nullptr);
+
+    std::string out;
+    httplib::detail::read_file("./image.jpg", out);
+
+    EXPECT_EQ(200, res->status);
+    EXPECT_EQ(out, res->body);
+}
+
+TEST(ChunkedEncodingTest, FromHTTPWatch)
+{
+    testChunkedEncoding(httplib::HttpVersion::v1_0);
+    testChunkedEncoding(httplib::HttpVersion::v1_1);
+}
+
 class ServerTest : public ::testing::Test {
 class ServerTest : public ::testing::Test {
 protected:
 protected:
     ServerTest()
     ServerTest()