Browse Source

Enforce WebSocket message size limit at reception

Paul-Louis Ageneau 1 năm trước cách đây
mục cha
commit
e492a19d9b
3 tập tin đã thay đổi với 62 bổ sung28 xóa
  1. 1 2
      src/impl/websocket.cpp
  2. 51 20
      src/impl/wstransport.cpp
  3. 10 6
      src/impl/wstransport.hpp

+ 1 - 2
src/impl/websocket.cpp

@@ -443,8 +443,7 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 			}
 		};
 
-		auto maxOutstandingPings = config.maxOutstandingPings.value_or(0);
-		auto transport = std::make_shared<WsTransport>(lower, mWsHandshake, maxOutstandingPings,
+		auto transport = std::make_shared<WsTransport>(lower, mWsHandshake, config,
 		                                               weak_bind(&WebSocket::incoming, this, _1),
 		                                               stateChangeCallback);
 

+ 51 - 20
src/impl/wstransport.cpp

@@ -43,11 +43,9 @@ using std::to_integer;
 using std::to_string;
 using std::chrono::system_clock;
 
-WsTransport::WsTransport(
-    variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>
-        lower,
-    shared_ptr<WsHandshake> handshake, int maxOutstandingPings, message_callback recvCallback,
-    state_callback stateCallback)
+WsTransport::WsTransport(LowerTransport lower, shared_ptr<WsHandshake> handshake,
+                         const WebSocketConfiguration &config, message_callback recvCallback,
+                         state_callback stateCallback)
     : Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
                 std::move(stateCallback)),
       mHandshake(std::move(handshake)),
@@ -55,7 +53,8 @@ WsTransport::WsTransport(
           std::visit(rtc::overloaded{[](auto l) { return l->isActive(); },
                                      [](shared_ptr<TlsTransport> l) { return l->isClient(); }},
                      lower)),
-      mMaxOutstandingPings(maxOutstandingPings) {
+      mMaxMessageSize(config.maxMessageSize.value_or(DEFAULT_MAX_MESSAGE_SIZE)),
+      mMaxOutstandingPings(config.maxOutstandingPings.value_or(0)) {
 
 	onRecv(std::move(recvCallback));
 
@@ -75,7 +74,10 @@ void WsTransport::start() {
 void WsTransport::stop() { close(); }
 
 bool WsTransport::send(message_ptr message) {
-	if (!message || state() != State::Connected)
+	if (state() != State::Connected)
+		throw std::runtime_error("WebSocket is not open");
+
+	if (!message)
 		return false;
 
 	PLOG_VERBOSE << "Send size=" << message->size();
@@ -146,10 +148,22 @@ void WsTransport::incoming(message_ptr message) {
 					sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, mIsClient});
 					addOutstandingPing();
 				} else {
-					Frame frame;
-					while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
-						recvFrame(frame);
+					if (mIgnoreLength > 0) {
+						size_t len = std::min(mIgnoreLength, mBuffer.size());
 						mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+						mIgnoreLength -= len;
+					}
+					if (mIgnoreLength == 0) {
+						Frame frame;
+						while (size_t len = parseFrame(mBuffer.data(), mBuffer.size(), frame)) {
+							recvFrame(frame);
+							if (len > mBuffer.size()) {
+								mIgnoreLength = len - mBuffer.size();
+								mBuffer.clear();
+								break;
+							}
+							mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+						}
 					}
 				}
 			}
@@ -229,7 +243,7 @@ bool WsTransport::sendHttpError(int code) {
 // |                     Payload Data continued ...                |
 // +---------------------------------------------------------------+
 
-size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
+size_t WsTransport::parseFrame(byte *buffer, size_t size, Frame &frame) {
 	const byte *end = buffer + size;
 	if (end - buffer < 2)
 		return 0;
@@ -263,16 +277,25 @@ size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
 		cur += 4;
 	}
 
-	if (size_t(end - cur) < frame.length)
+	const size_t maxControlFrameLength = 125;
+	const size_t maxFrameLength = std::max(maxControlFrameLength, mMaxMessageSize);
+	if (size_t(end - cur) < std::min(frame.length, maxFrameLength))
 		return 0;
 
+	size_t length = frame.length;
+	if (frame.length > maxFrameLength) {
+		PLOG_WARNING << "WebSocket frame is too large (length=" << frame.length
+		             << "), truncating it";
+		frame.length = maxFrameLength;
+	}
+
 	frame.payload = cur;
+
 	if (maskingKey)
 		for (size_t i = 0; i < frame.length; ++i)
 			frame.payload[i] ^= maskingKey[i % 4];
-	cur += frame.length;
 
-	return size_t(cur - buffer);
+	return frame.payload + length - buffer; // can be more than buffer size
 }
 
 void WsTransport::recvFrame(const Frame &frame) {
@@ -282,10 +305,15 @@ void WsTransport::recvFrame(const Frame &frame) {
 	switch (frame.opcode) {
 	case TEXT_FRAME:
 	case BINARY_FRAME: {
+		size_t size = frame.length;
+		if (size > mMaxMessageSize) {
+			PLOG_WARNING << "WebSocket message is too large, truncating it";
+			size = mMaxMessageSize;
+		}
 		if (!mPartial.empty()) {
 			PLOG_WARNING << "WebSocket unfinished message: type="
 			             << (mPartialOpcode == TEXT_FRAME ? "text" : "binary")
-			             << ", length=" << mPartial.size();
+			             << ", size=" << mPartial.size();
 			auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
 			recv(make_message(mPartial.begin(), mPartial.end(), type));
 			mPartial.clear();
@@ -293,21 +321,24 @@ void WsTransport::recvFrame(const Frame &frame) {
 		mPartialOpcode = frame.opcode;
 		if (frame.fin) {
 			PLOG_DEBUG << "WebSocket finished message: type="
-			           << (frame.opcode == TEXT_FRAME ? "text" : "binary")
-			           << ", length=" << frame.length;
+			           << (frame.opcode == TEXT_FRAME ? "text" : "binary") << ", size=" << size;
 			auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary;
-			recv(make_message(frame.payload, frame.payload + frame.length, type));
+			recv(make_message(frame.payload, frame.payload + size, type));
 		} else {
-			mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
+			mPartial.insert(mPartial.end(), frame.payload, frame.payload + size);
 		}
 		break;
 	}
 	case CONTINUATION: {
 		mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
+		if (mPartial.size() > mMaxMessageSize) {
+			PLOG_WARNING << "WebSocket message is too large, truncating it";
+			mPartial.resize(mMaxMessageSize);
+		}
 		if (frame.fin) {
 			PLOG_DEBUG << "WebSocket finished message: type="
 			           << (frame.opcode == TEXT_FRAME ? "text" : "binary")
-			           << ", length=" << mPartial.size();
+			           << ", size=" << mPartial.size();
 			auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
 			recv(make_message(mPartial.begin(), mPartial.end(), type));
 			mPartial.clear();

+ 10 - 6
src/impl/wstransport.hpp

@@ -11,6 +11,7 @@
 
 #include "common.hpp"
 #include "transport.hpp"
+#include "configuration.hpp"
 #include "wshandshake.hpp"
 
 #if RTC_ENABLE_WEBSOCKET
@@ -25,11 +26,12 @@ class TlsTransport;
 
 class WsTransport final : public Transport, public std::enable_shared_from_this<WsTransport> {
 public:
-	WsTransport(
-	    variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>
-	        lower,
-	    shared_ptr<WsHandshake> handshake, int maxOutstandingPings, message_callback recvCallback,
-	    state_callback stateCallback);
+	using LowerTransport =
+	    variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>;
+
+	WsTransport(LowerTransport lower, shared_ptr<WsHandshake> handshake,
+	            const WebSocketConfiguration &config, message_callback recvCallback,
+	            state_callback stateCallback);
 	~WsTransport();
 
 	void start() override;
@@ -62,7 +64,7 @@ private:
 	bool sendHttpError(int code);
 	bool sendHttpResponse();
 
-	size_t readFrame(byte *buffer, size_t size, Frame &frame);
+	size_t parseFrame(byte *buffer, size_t size, Frame &frame);
 	void recvFrame(const Frame &frame);
 	bool sendFrame(const Frame &frame);
 
@@ -70,11 +72,13 @@ private:
 
 	const shared_ptr<WsHandshake> mHandshake;
 	const bool mIsClient;
+	const size_t mMaxMessageSize;
 	const int mMaxOutstandingPings;
 
 	binary mBuffer;
 	binary mPartial;
 	Opcode mPartialOpcode;
+	size_t mIgnoreLength = 0;
 	std::mutex mSendMutex;
 	int mOutstandingPings = 0;
 	std::atomic<bool> mCloseSent = false;