Forráskód Böngészése

Refactor streams: rename is_* to wait_* for clarity (#2069)

- Replace is_readable() with wait_readable() and is_writable() with
  wait_writable() in the Stream interface.
- Implement a new is_readable() function with semantics that more
  closely reflect its name. It returns immediately whether data is
  available for reading, without waiting.
- Update call sites of is_writable(), removing redundant checks.
Florian Albrechtskirchinger 9 hónapja
szülő
commit
550f728165
3 módosított fájl, 45 hozzáadás és 33 törlés
  1. 37 27
      httplib.h
  2. 3 1
      test/fuzzing/server_fuzzer.cc
  3. 5 5
      test/test.cc

+ 37 - 27
httplib.h

@@ -751,7 +751,8 @@ public:
   virtual ~Stream() = default;
 
   virtual bool is_readable() const = 0;
-  virtual bool is_writable() const = 0;
+  virtual bool wait_readable() const = 0;
+  virtual bool wait_writable() const = 0;
 
   virtual ssize_t read(char *ptr, size_t size) = 0;
   virtual ssize_t write(const char *ptr, size_t size) = 0;
@@ -2466,7 +2467,8 @@ public:
   ~BufferStream() override = default;
 
   bool is_readable() const override;
-  bool is_writable() const override;
+  bool wait_readable() const override;
+  bool wait_writable() const override;
   ssize_t read(char *ptr, size_t size) override;
   ssize_t write(const char *ptr, size_t size) override;
   void get_remote_ip_and_port(std::string &ip, int &port) const override;
@@ -3380,7 +3382,8 @@ public:
   ~SocketStream() override;
 
   bool is_readable() const override;
-  bool is_writable() const override;
+  bool wait_readable() const override;
+  bool wait_writable() const override;
   ssize_t read(char *ptr, size_t size) override;
   ssize_t write(const char *ptr, size_t size) override;
   void get_remote_ip_and_port(std::string &ip, int &port) const override;
@@ -3416,7 +3419,8 @@ public:
   ~SSLSocketStream() override;
 
   bool is_readable() const override;
-  bool is_writable() const override;
+  bool wait_readable() const override;
+  bool wait_writable() const override;
   ssize_t read(char *ptr, size_t size) override;
   ssize_t write(const char *ptr, size_t size) override;
   void get_remote_ip_and_port(std::string &ip, int &port) const override;
@@ -4578,7 +4582,7 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider,
 
   data_sink.write = [&](const char *d, size_t l) -> bool {
     if (ok) {
-      if (strm.is_writable() && write_data(strm, d, l)) {
+      if (write_data(strm, d, l)) {
         offset += l;
       } else {
         ok = false;
@@ -4587,10 +4591,10 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider,
     return ok;
   };
 
-  data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
+  data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };
 
   while (offset < end_offset && !is_shutting_down()) {
-    if (!strm.is_writable()) {
+    if (!strm.wait_writable()) {
       error = Error::Write;
       return false;
     } else if (!content_provider(offset, end_offset - offset, data_sink)) {
@@ -4628,17 +4632,17 @@ write_content_without_length(Stream &strm,
   data_sink.write = [&](const char *d, size_t l) -> bool {
     if (ok) {
       offset += l;
-      if (!strm.is_writable() || !write_data(strm, d, l)) { ok = false; }
+      if (!write_data(strm, d, l)) { ok = false; }
     }
     return ok;
   };
 
-  data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
+  data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };
 
   data_sink.done = [&](void) { data_available = false; };
 
   while (data_available && !is_shutting_down()) {
-    if (!strm.is_writable()) {
+    if (!strm.wait_writable()) {
       return false;
     } else if (!content_provider(offset, 0, data_sink)) {
       return false;
@@ -4673,10 +4677,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
           // Emit chunked response header and footer for each chunk
           auto chunk =
               from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n";
-          if (!strm.is_writable() ||
-              !write_data(strm, chunk.data(), chunk.size())) {
-            ok = false;
-          }
+          if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; }
         }
       } else {
         ok = false;
@@ -4685,7 +4686,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
     return ok;
   };
 
-  data_sink.is_writable = [&]() -> bool { return strm.is_writable(); };
+  data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };
 
   auto done_with_trailer = [&](const Headers *trailer) {
     if (!ok) { return; }
@@ -4705,8 +4706,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
     if (!payload.empty()) {
       // Emit chunked response header and footer for each chunk
       auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n";
-      if (!strm.is_writable() ||
-          !write_data(strm, chunk.data(), chunk.size())) {
+      if (!write_data(strm, chunk.data(), chunk.size())) {
         ok = false;
         return;
       }
@@ -4738,7 +4738,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
   };
 
   while (data_available && !is_shutting_down()) {
-    if (!strm.is_writable()) {
+    if (!strm.wait_writable()) {
       error = Error::Write;
       return false;
     } else if (!content_provider(offset, 0, data_sink)) {
@@ -6029,6 +6029,10 @@ inline SocketStream::SocketStream(
 inline SocketStream::~SocketStream() = default;
 
 inline bool SocketStream::is_readable() const {
+  return read_buff_off_ < read_buff_content_size_;
+}
+
+inline bool SocketStream::wait_readable() const {
   if (max_timeout_msec_ <= 0) {
     return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0;
   }
@@ -6041,7 +6045,7 @@ inline bool SocketStream::is_readable() const {
   return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0;
 }
 
-inline bool SocketStream::is_writable() const {
+inline bool SocketStream::wait_writable() const {
   return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 &&
          is_socket_alive(sock_);
 }
@@ -6068,7 +6072,7 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) {
     }
   }
 
-  if (!is_readable()) { return -1; }
+  if (!wait_readable()) { return -1; }
 
   read_buff_off_ = 0;
   read_buff_content_size_ = 0;
@@ -6093,7 +6097,7 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) {
 }
 
 inline ssize_t SocketStream::write(const char *ptr, size_t size) {
-  if (!is_writable()) { return -1; }
+  if (!wait_writable()) { return -1; }
 
 #if defined(_WIN32) && !defined(_WIN64)
   size =
@@ -6124,7 +6128,9 @@ inline time_t SocketStream::duration() const {
 // Buffer stream implementation
 inline bool BufferStream::is_readable() const { return true; }
 
-inline bool BufferStream::is_writable() const { return true; }
+inline bool BufferStream::wait_readable() const { return true; }
+
+inline bool BufferStream::wait_writable() const { return true; }
 
 inline ssize_t BufferStream::read(char *ptr, size_t size) {
 #if defined(_MSC_VER) && _MSC_VER < 1910
@@ -9161,6 +9167,10 @@ inline SSLSocketStream::SSLSocketStream(
 inline SSLSocketStream::~SSLSocketStream() = default;
 
 inline bool SSLSocketStream::is_readable() const {
+  return SSL_pending(ssl_) > 0;
+}
+
+inline bool SSLSocketStream::wait_readable() const {
   if (max_timeout_msec_ <= 0) {
     return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0;
   }
@@ -9173,7 +9183,7 @@ inline bool SSLSocketStream::is_readable() const {
   return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0;
 }
 
-inline bool SSLSocketStream::is_writable() const {
+inline bool SSLSocketStream::wait_writable() const {
   return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 &&
          is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_);
 }
@@ -9181,7 +9191,7 @@ inline bool SSLSocketStream::is_writable() const {
 inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
   if (SSL_pending(ssl_) > 0) {
     return SSL_read(ssl_, ptr, static_cast<int>(size));
-  } else if (is_readable()) {
+  } else if (wait_readable()) {
     auto ret = SSL_read(ssl_, ptr, static_cast<int>(size));
     if (ret < 0) {
       auto err = SSL_get_error(ssl_, ret);
@@ -9195,7 +9205,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
 #endif
         if (SSL_pending(ssl_) > 0) {
           return SSL_read(ssl_, ptr, static_cast<int>(size));
-        } else if (is_readable()) {
+        } else if (wait_readable()) {
           std::this_thread::sleep_for(std::chrono::microseconds{10});
           ret = SSL_read(ssl_, ptr, static_cast<int>(size));
           if (ret >= 0) { return ret; }
@@ -9212,7 +9222,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
 }
 
 inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
-  if (is_writable()) {
+  if (wait_writable()) {
     auto handle_size = static_cast<int>(
         std::min<size_t>(size, (std::numeric_limits<int>::max)()));
 
@@ -9227,7 +9237,7 @@ inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
 #else
       while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) {
 #endif
-        if (is_writable()) {
+        if (wait_writable()) {
           std::this_thread::sleep_for(std::chrono::microseconds{10});
           ret = SSL_write(ssl_, ptr, static_cast<int>(handle_size));
           if (ret >= 0) { return ret; }

+ 3 - 1
test/fuzzing/server_fuzzer.cc

@@ -25,7 +25,9 @@ public:
 
   bool is_readable() const override { return true; }
 
-  bool is_writable() const override { return true; }
+  bool wait_readable() const override { return true; }
+
+  bool wait_writable() const override { return true; }
 
   void get_remote_ip_and_port(std::string &ip, int &port) const override {
     ip = "127.0.0.1";

+ 5 - 5
test/test.cc

@@ -156,7 +156,7 @@ TEST_F(UnixSocketTest, abstract) {
 }
 #endif
 
-TEST(SocketStream, is_writable_UNIX) {
+TEST(SocketStream, wait_writable_UNIX) {
   int fds[2];
   ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fds));
 
@@ -167,17 +167,17 @@ TEST(SocketStream, is_writable_UNIX) {
   };
   asSocketStream(fds[0], [&](Stream &s0) {
     EXPECT_EQ(s0.socket(), fds[0]);
-    EXPECT_TRUE(s0.is_writable());
+    EXPECT_TRUE(s0.wait_writable());
 
     EXPECT_EQ(0, close(fds[1]));
-    EXPECT_FALSE(s0.is_writable());
+    EXPECT_FALSE(s0.wait_writable());
 
     return true;
   });
   EXPECT_EQ(0, close(fds[0]));
 }
 
-TEST(SocketStream, is_writable_INET) {
+TEST(SocketStream, wait_writable_INET) {
   sockaddr_in addr;
   memset(&addr, 0, sizeof(addr));
   addr.sin_family = AF_INET;
@@ -212,7 +212,7 @@ TEST(SocketStream, is_writable_INET) {
   };
   asSocketStream(disconnected_svr_sock, [&](Stream &ss) {
     EXPECT_EQ(ss.socket(), disconnected_svr_sock);
-    EXPECT_FALSE(ss.is_writable());
+    EXPECT_FALSE(ss.wait_writable());
 
     return true;
   });