Browse Source

Write error handling

yhirose 6 years ago
parent
commit
7267b3f3e2
1 changed files with 58 additions and 37 deletions
  1. 58 37
      httplib.h

+ 58 - 37
httplib.h

@@ -197,7 +197,7 @@ public:
   virtual std::string get_remote_addr() const = 0;
 
   template <typename... Args>
-  void write_format(const char *fmt, const Args &... args);
+  int write_format(const char *fmt, const Args &... args);
 };
 
 class SocketStream : public Stream {
@@ -286,7 +286,7 @@ private:
   bool dispatch_request(Request &req, Response &res, Handlers &handlers);
 
   bool parse_request_line(const char *s, Request &req);
-  void write_response(Stream &strm, bool last_connection, const Request &req,
+  bool write_response(Stream &strm, bool last_connection, const Request &req,
                       Response &res);
 
   virtual bool read_and_close_socket(socket_t sock);
@@ -1228,18 +1228,29 @@ bool read_content(Stream &strm, T &x, uint64_t payload_max_length, int &status,
   return ret;
 }
 
-template <typename T> inline void write_headers(Stream &strm, const T &info) {
+template <typename T> inline int write_headers(Stream &strm, const T &info) {
+  auto write_len = 0;
   for (const auto &x : info.headers) {
-    strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str());
+    auto len = strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str());
+    if (len < 0) {
+      return len;
+    }
+    write_len += len;
+  }
+  auto len = strm.write("\r\n");
+  if (len < 0) {
+    return len;
   }
-  strm.write("\r\n");
+  write_len += len;
+  return write_len;
 }
 
 template <typename T>
-inline void write_content_chunked(Stream &strm, const T &x) {
+inline int write_content_chunked(Stream &strm, const T &x) {
   auto chunked_response = !x.has_header("Content-Length");
   uint64_t offset = 0;
   auto data_available = true;
+  auto write_len = 0;
   while (data_available) {
     auto chunk = x.content_producer(offset);
     offset += chunk.size();
@@ -1250,10 +1261,13 @@ inline void write_content_chunked(Stream &strm, const T &x) {
       chunk = from_i_to_hex(chunk.size()) + "\r\n" + chunk + "\r\n";
     }
 
-    if (strm.write(chunk.c_str(), chunk.size()) < 0) {
-      break; // Stop on error
+    auto len = strm.write(chunk.c_str(), chunk.size());
+    if (len < 0) {
+      return len;
     }
+    write_len += len;
   }
+  return write_len;
 }
 
 inline std::string encode_url(const std::string &s) {
@@ -1560,7 +1574,7 @@ inline void Response::set_content(const std::string &s,
 
 // Rstream implementation
 template <typename... Args>
-inline void Stream::write_format(const char *fmt, const Args &... args) {
+inline int Stream::write_format(const char *fmt, const Args &... args) {
   const auto bufsiz = 2048;
   char buf[bufsiz];
 
@@ -1569,23 +1583,25 @@ inline void Stream::write_format(const char *fmt, const Args &... args) {
 #else
   auto n = snprintf(buf, bufsiz - 1, fmt, args...);
 #endif
-  if (n > 0) {
-    if (n >= bufsiz - 1) {
-      std::vector<char> glowable_buf(bufsiz);
+  if (n <= 0) {
+    return n;
+  }
 
-      while (n >= static_cast<int>(glowable_buf.size() - 1)) {
-        glowable_buf.resize(glowable_buf.size() * 2);
+  if (n >= bufsiz - 1) {
+    std::vector<char> glowable_buf(bufsiz);
+
+    while (n >= static_cast<int>(glowable_buf.size() - 1)) {
+      glowable_buf.resize(glowable_buf.size() * 2);
 #if defined(_MSC_VER) && _MSC_VER < 1900
-        n = _snprintf_s(&glowable_buf[0], glowable_buf.size(),
-                        glowable_buf.size() - 1, fmt, args...);
+      n = _snprintf_s(&glowable_buf[0], glowable_buf.size(),
+                      glowable_buf.size() - 1, fmt, args...);
 #else
-        n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...);
+      n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...);
 #endif
-      }
-      write(&glowable_buf[0], n);
-    } else {
-      write(buf, n);
     }
+    return write(&glowable_buf[0], n);
+  } else {
+    return write(buf, n);
   }
 }
 
@@ -1745,15 +1761,17 @@ inline bool Server::parse_request_line(const char *s, Request &req) {
   return false;
 }
 
-inline void Server::write_response(Stream &strm, bool last_connection,
+inline bool Server::write_response(Stream &strm, bool last_connection,
                                    const Request &req, Response &res) {
   assert(res.status != -1);
 
   if (400 <= res.status && error_handler_) { error_handler_(req, res); }
 
   // Response line
-  strm.write_format("HTTP/1.1 %d %s\r\n", res.status,
-                    detail::status_message(res.status));
+  if (!strm.write_format("HTTP/1.1 %d %s\r\n", res.status,
+                    detail::status_message(res.status))) {
+    return false;
+  }
 
   // Headers
   if (last_connection || req.get_header_value("Connection") == "close") {
@@ -1793,19 +1811,27 @@ inline void Server::write_response(Stream &strm, bool last_connection,
     res.set_header("Content-Length", length.c_str());
   }
 
-  detail::write_headers(strm, res);
+  if (!detail::write_headers(strm, res)) {
+    return false;
+  }
 
   // Body
   if (req.method != "HEAD") {
     if (!res.body.empty()) {
-      strm.write(res.body.c_str(), res.body.size());
+      if (!strm.write(res.body.c_str(), res.body.size())) {
+        return false;
+      }
     } else if (res.content_producer) {
-      detail::write_content_chunked(strm, res);
+      if (!detail::write_content_chunked(strm, res)) {
+        return false;
+      }
     }
   }
 
   // Log
   if (logger_) { logger_(req, res); }
+
+  return true;
 }
 
 inline bool Server::handle_file_request(Request &req, Response &res) {
@@ -1978,16 +2004,14 @@ Server::process_request(Stream &strm, bool last_connection,
   // Check if the request URI doesn't exceed the limit
   if (reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) {
     res.status = 414;
-    write_response(strm, last_connection, req, res);
-    return true;
+    return write_response(strm, last_connection, req, res);
   }
 
   // Request line and headers
   if (!parse_request_line(reader.ptr(), req) ||
       !detail::read_headers(strm, req.headers)) {
     res.status = 400;
-    write_response(strm, last_connection, req, res);
-    return true;
+    return write_response(strm, last_connection, req, res);
   }
 
   if (req.get_header_value("Connection") == "close") {
@@ -2001,8 +2025,7 @@ Server::process_request(Stream &strm, bool last_connection,
     if (!detail::read_content(
             strm, req, payload_max_length_, res.status, Progress(),
             [&](const char *buf, size_t n) { req.body.append(buf, n); })) {
-      write_response(strm, last_connection, req, res);
-      return true;
+      return write_response(strm, last_connection, req, res);
     }
 
     const auto &content_type = req.get_header_value("Content-Type");
@@ -2014,8 +2037,7 @@ Server::process_request(Stream &strm, bool last_connection,
       if (!detail::parse_multipart_boundary(content_type, boundary) ||
           !detail::parse_multipart_formdata(boundary, req.body, req.files)) {
         res.status = 400;
-        write_response(strm, last_connection, req, res);
-        return true;
+        return write_response(strm, last_connection, req, res);
       }
     }
   }
@@ -2029,8 +2051,7 @@ Server::process_request(Stream &strm, bool last_connection,
     res.status = 404;
   }
 
-  write_response(strm, last_connection, req, res);
-  return true;
+  return write_response(strm, last_connection, req, res);
 }
 
 inline bool Server::is_valid() const { return true; }