Browse Source

Improved Stream interface

yhirose 6 years ago
parent
commit
2e360f9dd6

+ 1 - 0
.gitignore

@@ -6,6 +6,7 @@ example/hello
 example/simplesvr
 example/benchmark
 example/redirect
+example/sse
 example/upload
 example/*.pem
 test/test

+ 5 - 2
example/Makefile

@@ -5,7 +5,7 @@ OPENSSL_DIR = /usr/local/opt/openssl
 OPENSSL_SUPPORT = -DCPPHTTPLIB_OPENSSL_SUPPORT -I$(OPENSSL_DIR)/include -L$(OPENSSL_DIR)/lib -lssl -lcrypto
 ZLIB_SUPPORT = -DCPPHTTPLIB_ZLIB_SUPPORT -lz
 
-all: server client hello simplesvr upload redirect benchmark
+all: server client hello simplesvr upload redirect sse benchmark
 
 server : server.cc ../httplib.h Makefile
 	$(CXX) -o server $(CXXFLAGS) server.cc $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT)
@@ -25,6 +25,9 @@ upload : upload.cc ../httplib.h Makefile
 redirect : redirect.cc ../httplib.h Makefile
 	$(CXX) -o redirect $(CXXFLAGS) redirect.cc $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT)
 
+sse : sse.cc ../httplib.h Makefile
+	$(CXX) -o sse $(CXXFLAGS) sse.cc $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT)
+
 benchmark : benchmark.cc ../httplib.h Makefile
 	$(CXX) -o benchmark $(CXXFLAGS) benchmark.cc $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT)
 
@@ -33,4 +36,4 @@ pem:
 	openssl req -new -key key.pem | openssl x509 -days 3650 -req -signkey key.pem > cert.pem
 
 clean:
-	rm server client hello simplesvr upload redirect benchmark *.pem
+	rm server client hello simplesvr upload redirect sse benchmark *.pem

+ 106 - 0
example/sse.cc

@@ -0,0 +1,106 @@
+//
+//  sse.cc
+//
+//  Copyright (c) 2020 Yuji Hirose. All rights reserved.
+//  MIT License
+//
+
+#include <atomic>
+#include <chrono>
+#include <condition_variable>
+#include <httplib.h>
+#include <iostream>
+#include <mutex>
+#include <thread>
+
+using namespace httplib;
+using namespace std;
+
+class EventDispatcher {
+public:
+  EventDispatcher() {
+    id_ = 0;
+    cid_ = -1;
+  }
+
+  void add_sink(DataSink *sink) {
+    unique_lock<mutex> lk(m_);
+    int id = id_;
+    cv_.wait(lk, [&] { return cid_ == id; });
+    if (sink->is_writable()) { sink->write(message_.data(), message_.size()); }
+  }
+
+  void send_event(const string &message) {
+    lock_guard<mutex> lk(m_);
+    cid_ = id_++;
+    message_ = message;
+    cv_.notify_all();
+  }
+
+private:
+  mutex m_;
+  condition_variable cv_;
+  atomic_int id_;
+  atomic_int cid_;
+  string message_;
+};
+
+const auto html = R"(
+<!DOCTYPE html>
+<html lang="en">
+<head>
+<meta charset="UTF-8">
+<title>SSE demo</title>
+</head>
+<body>
+<script>
+const ev1 = new EventSource("event1");
+ev1.onmessage = function(e) {
+  console.log('ev1', e.data);
+}
+const ev2 = new EventSource("event2");
+ev2.onmessage = function(e) {
+  console.log('ev2', e.data);
+}
+</script>
+</body>
+</html>
+)";
+
+int main(void) {
+  EventDispatcher ed;
+
+  Server svr;
+
+  svr.Get("/", [&](const Request & /*req*/, Response &res) {
+    res.set_content(html, "text/html");
+  });
+
+  svr.Get("/event1", [&](const Request & /*req*/, Response &res) {
+    cout << "connected to event1..." << endl;
+    res.set_header("Content-Type", "text/event-stream");
+    res.set_chunked_content_provider(
+        [&](uint64_t /*offset*/, DataSink &sink) { ed.add_sink(&sink); });
+  });
+
+  svr.Get("/event2", [&](const Request & /*req*/, Response &res) {
+    cout << "connected to event2..." << endl;
+    res.set_header("Content-Type", "text/event-stream");
+    res.set_chunked_content_provider(
+        [&](uint64_t /*offset*/, DataSink &sink) { ed.add_sink(&sink); });
+  });
+
+  thread t([&] {
+    int id = 0;
+    while (true) {
+      this_thread::sleep_for(chrono::seconds(1));
+      cout << "send event: " << id << std::endl;
+      std::stringstream ss;
+      ss << "data: " << id << "\n\n";
+      ed.send_event(ss.str());
+      id++;
+    }
+  });
+
+  svr.listen("localhost", 1234);
+}

+ 81 - 50
httplib.h

@@ -224,7 +224,7 @@ public:
 
   std::function<void(const char *data, size_t data_len)> write;
   std::function<void()> done;
-  // TODO: std::function<bool()> is_alive;
+  std::function<bool()> is_writable;
 };
 
 using ContentProvider =
@@ -349,14 +349,18 @@ struct Response {
 class Stream {
 public:
   virtual ~Stream() = default;
+
+  virtual bool is_readable() const = 0;
+  virtual bool is_writable() const = 0;
+
   virtual int read(char *ptr, size_t size) = 0;
-  virtual int write(const char *ptr, size_t size1) = 0;
-  virtual int write(const char *ptr) = 0;
-  virtual int write(const std::string &s) = 0;
+  virtual int write(const char *ptr, size_t size) = 0;
   virtual std::string get_remote_addr() const = 0;
 
   template <typename... Args>
   int write_format(const char *fmt, const Args &... args);
+  int write(const char *ptr);
+  int write(const std::string &s);
 };
 
 class TaskQueue {
@@ -496,7 +500,7 @@ protected:
 
 private:
   using Handlers = std::vector<std::pair<std::regex, Handler>>;
-  using HandersForContentReader =
+  using HandlersForContentReader =
       std::vector<std::pair<std::regex, HandlerWithContentReader>>;
 
   socket_t create_server_socket(const char *host, int port,
@@ -509,7 +513,7 @@ private:
   bool dispatch_request(Request &req, Response &res, Handlers &handlers);
   bool dispatch_request_for_content_reader(Request &req, Response &res,
                                            ContentReader content_reader,
-                                           HandersForContentReader &handlers);
+                                           HandlersForContentReader &handlers);
 
   bool parse_request_line(const char *s, Request &req);
   bool write_response(Stream &strm, bool last_connection, const Request &req,
@@ -537,11 +541,11 @@ private:
   Handler file_request_handler_;
   Handlers get_handlers_;
   Handlers post_handlers_;
-  HandersForContentReader post_handlers_for_content_reader_;
+  HandlersForContentReader post_handlers_for_content_reader_;
   Handlers put_handlers_;
-  HandersForContentReader put_handlers_for_content_reader_;
+  HandlersForContentReader put_handlers_for_content_reader_;
   Handlers patch_handlers_;
-  HandersForContentReader patch_handlers_for_content_reader_;
+  HandlersForContentReader patch_handlers_for_content_reader_;
   Handlers delete_handlers_;
   Handlers options_handlers_;
   Handler error_handler_;
@@ -1186,6 +1190,28 @@ inline int select_read(socket_t sock, time_t sec, time_t usec) {
 #endif
 }
 
+inline int select_write(socket_t sock, time_t sec, time_t usec) {
+#ifdef CPPHTTPLIB_USE_POLL
+  struct pollfd pfd_read;
+  pfd_read.fd = sock;
+  pfd_read.events = POLLOUT;
+
+  auto timeout = static_cast<int>(sec * 1000 + usec / 1000);
+
+  return poll(&pfd_read, 1, timeout);
+#else
+  fd_set fds;
+  FD_ZERO(&fds);
+  FD_SET(sock, &fds);
+
+  timeval tv;
+  tv.tv_sec = static_cast<long>(sec);
+  tv.tv_usec = static_cast<long>(usec);
+
+  return select(static_cast<int>(sock + 1), nullptr, &fds, nullptr, &tv);
+#endif
+}
+
 inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) {
 #ifdef CPPHTTPLIB_USE_POLL
   struct pollfd pfd_read;
@@ -1233,10 +1259,10 @@ public:
                time_t read_timeout_usec);
   ~SocketStream() override;
 
+  bool is_readable() const override;
+  bool is_writable() const override;
   int read(char *ptr, size_t size) override;
   int write(const char *ptr, size_t size) override;
-  int write(const char *ptr) override;
-  int write(const std::string &s) override;
   std::string get_remote_addr() const override;
 
 private:
@@ -1252,11 +1278,11 @@ public:
                   time_t read_timeout_usec);
   virtual ~SSLSocketStream();
 
-  virtual int read(char *ptr, size_t size);
-  virtual int write(const char *ptr, size_t size);
-  virtual int write(const char *ptr);
-  virtual int write(const std::string &s);
-  virtual std::string get_remote_addr() const;
+  bool is_readable() const override;
+  bool is_writable() const override;
+  int read(char *ptr, size_t size) override;
+  int write(const char *ptr, size_t size) override;
+  std::string get_remote_addr() const override;
 
 private:
   socket_t sock_;
@@ -1271,10 +1297,10 @@ public:
   BufferStream() = default;
   ~BufferStream() override = default;
 
+  bool is_readable() const override;
+  bool is_writable() const override;
   int read(char *ptr, size_t size) override;
   int write(const char *ptr, size_t size) override;
-  int write(const char *ptr) override;
-  int write(const std::string &s) override;
   std::string get_remote_addr() const override;
 
   const std::string &get_buffer() const;
@@ -1914,6 +1940,7 @@ inline ssize_t write_content(Stream &strm, ContentProvider content_provider,
       written_length = strm.write(d, l);
     };
     data_sink.done = [&](void) { written_length = -1; };
+    data_sink.is_writable = [&](void) { return strm.is_writable(); };
 
     content_provider(offset, end_offset - offset, data_sink);
     if (written_length < 0) { return written_length; }
@@ -1944,6 +1971,7 @@ inline ssize_t write_content_chunked(Stream &strm,
       data_available = false;
       written_length = strm.write("0\r\n\r\n");
     };
+    data_sink.is_writable = [&](void) { return strm.is_writable(); };
 
     content_provider(offset, 0, data_sink);
 
@@ -2701,6 +2729,12 @@ inline void Response::set_chunked_content_provider(
 }
 
 // Rstream implementation
+inline int Stream::write(const char *ptr) { return write(ptr, strlen(ptr)); }
+
+inline int Stream::write(const std::string &s) {
+  return write(s.data(), s.size());
+}
+
 template <typename... Args>
 inline int Stream::write_format(const char *fmt, const Args &... args) {
   std::array<char, 2048> buf;
@@ -2740,23 +2774,22 @@ inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec,
 
 inline SocketStream::~SocketStream() {}
 
-inline int SocketStream::read(char *ptr, size_t size) {
-  if (detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0) {
-    return recv(sock_, ptr, static_cast<int>(size), 0);
-  }
-  return -1;
+inline bool SocketStream::is_readable() const {
+  return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0;
 }
 
-inline int SocketStream::write(const char *ptr, size_t size) {
-  return send(sock_, ptr, static_cast<int>(size), 0);
+inline bool SocketStream::is_writable() const {
+  return detail::select_write(sock_, 0, 0) > 0;
 }
 
-inline int SocketStream::write(const char *ptr) {
-  return write(ptr, strlen(ptr));
+inline int SocketStream::read(char *ptr, size_t size) {
+  if (is_readable()) { return recv(sock_, ptr, static_cast<int>(size), 0); }
+  return -1;
 }
 
-inline int SocketStream::write(const std::string &s) {
-  return write(s.data(), s.size());
+inline int SocketStream::write(const char *ptr, size_t size) {
+  if (is_writable()) { return send(sock_, ptr, static_cast<int>(size), 0); }
+  return -1;
 }
 
 inline std::string SocketStream::get_remote_addr() const {
@@ -2764,6 +2797,10 @@ inline std::string SocketStream::get_remote_addr() const {
 }
 
 // Buffer stream implementation
+inline bool BufferStream::is_readable() const { return true; }
+
+inline bool BufferStream::is_writable() const { return true; }
+
 inline int BufferStream::read(char *ptr, size_t size) {
 #if defined(_MSC_VER) && _MSC_VER < 1900
   int len_read = static_cast<int>(buffer._Copy_s(ptr, size, size, position));
@@ -2779,14 +2816,6 @@ inline int BufferStream::write(const char *ptr, size_t size) {
   return static_cast<int>(size);
 }
 
-inline int BufferStream::write(const char *ptr) {
-  return write(ptr, strlen(ptr));
-}
-
-inline int BufferStream::write(const std::string &s) {
-  return write(s.data(), s.size());
-}
-
 inline std::string BufferStream::get_remote_addr() const { return ""; }
 
 inline const std::string &BufferStream::get_buffer() const { return buffer; }
@@ -3372,10 +3401,9 @@ inline bool Server::dispatch_request(Request &req, Response &res,
   return false;
 }
 
-inline bool
-Server::dispatch_request_for_content_reader(Request &req, Response &res,
-                                            ContentReader content_reader,
-                                            HandersForContentReader &handlers) {
+inline bool Server::dispatch_request_for_content_reader(
+    Request &req, Response &res, ContentReader content_reader,
+    HandlersForContentReader &handlers) {
   for (const auto &x : handlers) {
     const auto &pattern = x.first;
     const auto &handler = x.second;
@@ -3777,6 +3805,7 @@ inline bool Client::write_request(Stream &strm, const Request &req,
         auto written_length = strm.write(d, l);
         offset += written_length;
       };
+      data_sink.is_writable = [&](void) { return strm.is_writable(); };
 
       while (offset < end_offset) {
         req.content_provider(offset, end_offset - offset, data_sink);
@@ -3810,6 +3839,7 @@ inline std::shared_ptr<Response> Client::send_with_content_provider(
         req.body.append(data, data_len);
         offset += data_len;
       };
+      data_sink.is_writable = [&](void) { return true; };
 
       while (offset < content_length) {
         content_provider(offset, content_length - offset, data_sink);
@@ -4380,6 +4410,14 @@ inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl,
 
 inline SSLSocketStream::~SSLSocketStream() {}
 
+inline bool SSLSocketStream::is_readable() const {
+  return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0;
+}
+
+inline bool SSLSocketStream::is_writable() const {
+  return detail::select_write(sock_, 0, 0) > 0;
+}
+
 inline int SSLSocketStream::read(char *ptr, size_t size) {
   if (SSL_pending(ssl_) > 0 ||
       select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0) {
@@ -4389,15 +4427,8 @@ inline int SSLSocketStream::read(char *ptr, size_t size) {
 }
 
 inline int SSLSocketStream::write(const char *ptr, size_t size) {
-  return SSL_write(ssl_, ptr, static_cast<int>(size));
-}
-
-inline int SSLSocketStream::write(const char *ptr) {
-  return write(ptr, strlen(ptr));
-}
-
-inline int SSLSocketStream::write(const std::string &s) {
-  return write(s.data(), s.size());
+  if (is_writable()) { return SSL_write(ssl_, ptr, static_cast<int>(size)); }
+  return -1;
 }
 
 inline std::string SSLSocketStream::get_remote_addr() const {

+ 0 - 0
test/test_proxy_docker/Dockerfile → test/proxy/Dockerfile


+ 0 - 0
test/test_proxy_docker/basic_passwd → test/proxy/basic_passwd


+ 0 - 0
test/test_proxy_docker/basic_squid.conf → test/proxy/basic_squid.conf


+ 0 - 0
test/test_proxy_docker/digest_passwd → test/proxy/digest_passwd


+ 0 - 0
test/test_proxy_docker/digest_squid.conf → test/proxy/digest_squid.conf


+ 0 - 0
test/test_proxy_docker/docker-compose.yml → test/proxy/docker-compose.yml


+ 0 - 0
test/test_proxy_docker/down.sh → test/proxy/down.sh


+ 0 - 0
test/test_proxy_docker/up.sh → test/proxy/up.sh


+ 8 - 1
test/test.cc

@@ -200,7 +200,8 @@ TEST(ParseHeaderValueTest, Range) {
 }
 
 TEST(BufferStreamTest, read) {
-  detail::BufferStream strm;
+  detail::BufferStream strm1;
+  Stream& strm = strm1;
 
   EXPECT_EQ(5, strm.write("hello"));
 
@@ -724,6 +725,7 @@ protected:
              [&](const Request & /*req*/, Response &res) {
                res.set_chunked_content_provider(
                    [](uint64_t /*offset*/, DataSink &sink) {
+                     ASSERT_TRUE(sink.is_writable());
                      sink.write("123", 3);
                      sink.write("456", 3);
                      sink.write("789", 3);
@@ -735,6 +737,7 @@ protected:
                auto i = new int(0);
                res.set_chunked_content_provider(
                    [i](uint64_t /*offset*/, DataSink &sink) {
+                     ASSERT_TRUE(sink.is_writable());
                      switch (*i) {
                      case 0: sink.write("123", 3); break;
                      case 1: sink.write("456", 3); break;
@@ -758,6 +761,7 @@ protected:
                res.set_content_provider(
                    data->size(),
                    [data](uint64_t offset, uint64_t length, DataSink &sink) {
+                     ASSERT_TRUE(sink.is_writable());
                      size_t DATA_CHUNK_SIZE = 4;
                      const auto &d = *data;
                      auto out_len =
@@ -771,6 +775,7 @@ protected:
                res.set_content_provider(size_t(-1), [](uint64_t /*offset*/,
                                                        uint64_t /*length*/,
                                                        DataSink &sink) {
+                 ASSERT_TRUE(sink.is_writable());
                  std::string data = "data_chunk";
                  sink.write(data.data(), data.size());
                });
@@ -1636,6 +1641,7 @@ TEST_F(ServerTest, PutWithContentProvider) {
   auto res = cli_.Put(
       "/put", 3,
       [](size_t /*offset*/, size_t /*length*/, DataSink &sink) {
+        ASSERT_TRUE(sink.is_writable());
         sink.write("PUT", 3);
       },
       "text/plain");
@@ -1651,6 +1657,7 @@ TEST_F(ServerTest, PutWithContentProviderWithGzip) {
   auto res = cli_.Put(
       "/put", 3,
       [](size_t /*offset*/, size_t /*length*/, DataSink &sink) {
+        ASSERT_TRUE(sink.is_writable());
         sink.write("PUT", 3);
       },
       "text/plain");