2
0
Эх сурвалжийг харах

Merge pull request #1094 from paullouisageneau/websocket-enforce-max-message-size

Enforce WebSocket max message size at reception
Paul-Louis Ageneau 1 жил өмнө
parent
commit
dbdfb49a2b

+ 29 - 0
include/rtc/configuration.hpp

@@ -88,6 +88,35 @@ struct RTC_CPP_EXPORT Configuration {
 	optional<size_t> maxMessageSize;
 };
 
+#ifdef RTC_ENABLE_WEBSOCKET
+
+struct WebSocketConfiguration {
+	bool disableTlsVerification = false; // if true, don't verify the TLS certificate
+	optional<ProxyServer> proxyServer;   // only non-authenticated http supported for now
+	std::vector<string> protocols;
+	optional<std::chrono::milliseconds> connectionTimeout; // zero to disable
+	optional<std::chrono::milliseconds> pingInterval;      // zero to disable
+	optional<int> maxOutstandingPings;
+	optional<string> caCertificatePemFile;
+	optional<string> certificatePemFile;
+	optional<string> keyPemFile;
+	optional<string> keyPemPass;
+	optional<size_t> maxMessageSize;
+};
+
+struct WebSocketServerConfiguration {
+	uint16_t port = 8080;
+	bool enableTls = false;
+	optional<string> certificatePemFile;
+	optional<string> keyPemFile;
+	optional<string> keyPemPass;
+	optional<string> bindAddress;
+	optional<std::chrono::milliseconds> connectionTimeout;
+	optional<size_t> maxMessageSize;
+};
+
+#endif
+
 } // namespace rtc
 
 #endif

+ 3 - 1
include/rtc/rtc.h

@@ -422,6 +422,7 @@ typedef struct {
 	int connectionTimeoutMs; // in milliseconds, 0 means default, < 0 means disabled
 	int pingIntervalMs;      // in milliseconds, 0 means default, < 0 means disabled
 	int maxOutstandingPings; // 0 means default, < 0 means disabled
+	int maxMessageSize;      // <= 0 means default
 } rtcWsConfiguration;
 
 RTC_C_EXPORT int rtcCreateWebSocket(const char *url); // returns ws id
@@ -441,8 +442,9 @@ typedef struct {
 	const char *certificatePemFile; // NULL for autogenerated certificate
 	const char *keyPemFile;         // NULL for autogenerated certificate
 	const char *keyPemPass;         // NULL if no pass
-	const char *bindAddress;        // NULL for IP_ANY_ADDR
+	const char *bindAddress;        // NULL for any
 	int connectionTimeoutMs;        // in milliseconds, 0 means default, < 0 means disabled
+	int maxMessageSize;             // <= 0 means default
 } rtcWsServerConfiguration;
 
 RTC_C_EXPORT int rtcCreateWebSocketServer(const rtcWsServerConfiguration *config,

+ 2 - 14
include/rtc/websocket.hpp

@@ -13,7 +13,7 @@
 
 #include "channel.hpp"
 #include "common.hpp"
-#include "configuration.hpp" // for ProxyServer
+#include "configuration.hpp"
 
 namespace rtc {
 
@@ -32,19 +32,7 @@ public:
 		Closed = 3,
 	};
 
-	struct Configuration {
-		bool disableTlsVerification = false; // if true, don't verify the TLS certificate
-		optional<ProxyServer> proxyServer;   // only non-authenticated http supported for now
-		std::vector<string> protocols;
-		optional<std::chrono::milliseconds> connectionTimeout; // zero to disable
-		optional<std::chrono::milliseconds> pingInterval;      // zero to disable
-		optional<int> maxOutstandingPings;
-		optional<string> caCertificatePemFile;
-		optional<string> certificatePemFile;
-		optional<string> keyPemFile;
-		optional<string> keyPemPass;
-		optional<size_t> maxMessageSize;
-	};
+	using Configuration = WebSocketConfiguration;
 
 	WebSocket();
 	WebSocket(Configuration config);

+ 2 - 9
include/rtc/websocketserver.hpp

@@ -12,6 +12,7 @@
 #if RTC_ENABLE_WEBSOCKET
 
 #include "common.hpp"
+#include "configuration.hpp"
 #include "websocket.hpp"
 
 namespace rtc {
@@ -24,15 +25,7 @@ struct WebSocketServer;
 
 class RTC_CPP_EXPORT WebSocketServer final : private CheshireCat<impl::WebSocketServer> {
 public:
-	struct Configuration {
-		uint16_t port = 8080;
-		bool enableTls = false;
-		optional<string> certificatePemFile;
-		optional<string> keyPemFile;
-		optional<string> keyPemPass;
-		optional<string> bindAddress;
-		optional<std::chrono::milliseconds> connectionTimeout;
-	};
+	using Configuration = WebSocketServerConfiguration;
 
 	WebSocketServer();
 	WebSocketServer(Configuration config);

+ 7 - 0
src/capi.cpp

@@ -1479,6 +1479,9 @@ int rtcCreateWebSocketEx(const char *url, const rtcWsConfiguration *config) {
 		else if (config->maxOutstandingPings < 0)
 			c.maxOutstandingPings = 0; // setting to 0 disables, not setting keeps default
 
+		if(config->maxMessageSize > 0)
+			c.maxMessageSize = size_t(config->maxMessageSize);
+
 		auto webSocket = std::make_shared<WebSocket>(std::move(c));
 		webSocket->open(url);
 		return emplaceWebSocket(webSocket);
@@ -1533,6 +1536,10 @@ RTC_C_EXPORT int rtcCreateWebSocketServer(const rtcWsServerConfiguration *config
 		c.keyPemFile = config->keyPemFile ? make_optional(string(config->keyPemFile)) : nullopt;
 		c.keyPemPass = config->keyPemPass ? make_optional(string(config->keyPemPass)) : nullopt;
 		c.bindAddress = config->bindAddress ? make_optional(string(config->bindAddress)) : nullopt;
+
+		if(config->maxMessageSize > 0)
+			c.maxMessageSize = size_t(config->maxMessageSize);
+
 		auto webSocketServer = std::make_shared<WebSocketServer>(std::move(c));
 		int wsserver = emplaceWebSocketServer(webSocketServer);
 

+ 1 - 1
src/channel.cpp

@@ -17,7 +17,7 @@ Channel::~Channel() { impl()->resetCallbacks(); }
 
 Channel::Channel(impl_ptr<impl::Channel> impl) : CheshireCat<impl::Channel>(std::move(impl)) {}
 
-size_t Channel::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
+size_t Channel::maxMessageSize() const { return 0; }
 
 size_t Channel::bufferedAmount() const { return impl()->bufferedAmount; }
 

+ 1 - 1
src/impl/datachannel.cpp

@@ -153,7 +153,7 @@ bool DataChannel::isClosed(void) const { return mIsClosed; }
 
 size_t DataChannel::maxMessageSize() const {
 	auto pc = mPeerConnection.lock();
-	return pc ? pc->remoteMaxMessageSize() : DEFAULT_MAX_MESSAGE_SIZE;
+	return pc ? pc->remoteMaxMessageSize() : DEFAULT_REMOTE_MAX_MESSAGE_SIZE;
 }
 
 void DataChannel::assignStream(uint16_t stream) {

+ 3 - 1
src/impl/internals.hpp

@@ -39,7 +39,9 @@ const uint16_t MAX_SCTP_STREAMS_COUNT = 1024; // Max number of negotiated SCTP s
                                               // of memory, Chromium historically limits to 1024.
 
 const size_t DEFAULT_LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Default local max message size
-const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536; // Remote max message size if not specified in SDP
+const size_t DEFAULT_REMOTE_MAX_MESSAGE_SIZE = 65536;     // Remote max message size if not in SDP
+
+const size_t DEFAULT_WS_MAX_MESSAGE_SIZE = 256 * 1024;   // Default max message size for WebSockets
 
 const size_t RECV_QUEUE_LIMIT = 1024 * 1024; // Max per-channel queue size
 

+ 1 - 1
src/impl/peerconnection.cpp

@@ -103,7 +103,7 @@ optional<Description> PeerConnection::remoteDescription() const {
 size_t PeerConnection::remoteMaxMessageSize() const {
 	const size_t localMax = config.maxMessageSize.value_or(DEFAULT_LOCAL_MAX_MESSAGE_SIZE);
 
-	size_t remoteMax = DEFAULT_MAX_MESSAGE_SIZE;
+	size_t remoteMax = DEFAULT_REMOTE_MAX_MESSAGE_SIZE;
 	std::lock_guard lock(mRemoteDescriptionMutex);
 	if (mRemoteDescription)
 		if (auto *application = mRemoteDescription->application())

+ 2 - 3
src/impl/websocket.cpp

@@ -156,7 +156,7 @@ bool WebSocket::isOpen() const { return state == State::Open; }
 
 bool WebSocket::isClosed() const { return state == State::Closed; }
 
-size_t WebSocket::maxMessageSize() const { return config.maxMessageSize.value_or(DEFAULT_MAX_MESSAGE_SIZE); }
+size_t WebSocket::maxMessageSize() const { return config.maxMessageSize.value_or(DEFAULT_WS_MAX_MESSAGE_SIZE); }
 
 optional<message_variant> WebSocket::receive() {
 	auto next = mRecvQueue.pop();
@@ -443,8 +443,7 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 			}
 		};
 
-		auto maxOutstandingPings = config.maxOutstandingPings.value_or(0);
-		auto transport = std::make_shared<WsTransport>(lower, mWsHandshake, maxOutstandingPings,
+		auto transport = std::make_shared<WsTransport>(lower, mWsHandshake, config,
 		                                               weak_bind(&WebSocket::incoming, this, _1),
 		                                               stateChangeCallback);
 

+ 1 - 0
src/impl/websocketserver.cpp

@@ -79,6 +79,7 @@ void WebSocketServer::runLoop() {
 
 				WebSocket::Configuration clientConfig;
 				clientConfig.connectionTimeout = config.connectionTimeout;
+				clientConfig.maxMessageSize = config.maxMessageSize;
 
 				auto impl = std::make_shared<WebSocket>(std::move(clientConfig), mCertificate);
 				impl->changeState(WebSocket::State::Connecting);

+ 51 - 20
src/impl/wstransport.cpp

@@ -43,11 +43,9 @@ using std::to_integer;
 using std::to_string;
 using std::chrono::system_clock;
 
-WsTransport::WsTransport(
-    variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>
-        lower,
-    shared_ptr<WsHandshake> handshake, int maxOutstandingPings, message_callback recvCallback,
-    state_callback stateCallback)
+WsTransport::WsTransport(LowerTransport lower, shared_ptr<WsHandshake> handshake,
+                         const WebSocketConfiguration &config, message_callback recvCallback,
+                         state_callback stateCallback)
     : Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
                 std::move(stateCallback)),
       mHandshake(std::move(handshake)),
@@ -55,7 +53,8 @@ WsTransport::WsTransport(
           std::visit(rtc::overloaded{[](auto l) { return l->isActive(); },
                                      [](shared_ptr<TlsTransport> l) { return l->isClient(); }},
                      lower)),
-      mMaxOutstandingPings(maxOutstandingPings) {
+      mMaxMessageSize(config.maxMessageSize.value_or(DEFAULT_WS_MAX_MESSAGE_SIZE)),
+      mMaxOutstandingPings(config.maxOutstandingPings.value_or(0)) {
 
 	onRecv(std::move(recvCallback));
 
@@ -75,7 +74,10 @@ void WsTransport::start() {
 void WsTransport::stop() { close(); }
 
 bool WsTransport::send(message_ptr message) {
-	if (!message || state() != State::Connected)
+	if (state() != State::Connected)
+		throw std::runtime_error("WebSocket is not open");
+
+	if (!message)
 		return false;
 
 	PLOG_VERBOSE << "Send size=" << message->size();
@@ -146,10 +148,22 @@ void WsTransport::incoming(message_ptr message) {
 					sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, mIsClient});
 					addOutstandingPing();
 				} else {
-					Frame frame;
-					while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
-						recvFrame(frame);
+					if (mIgnoreLength > 0) {
+						size_t len = std::min(mIgnoreLength, mBuffer.size());
 						mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+						mIgnoreLength -= len;
+					}
+					if (mIgnoreLength == 0) {
+						Frame frame;
+						while (size_t len = parseFrame(mBuffer.data(), mBuffer.size(), frame)) {
+							recvFrame(frame);
+							if (len > mBuffer.size()) {
+								mIgnoreLength = len - mBuffer.size();
+								mBuffer.clear();
+								break;
+							}
+							mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+						}
 					}
 				}
 			}
@@ -229,7 +243,7 @@ bool WsTransport::sendHttpError(int code) {
 // |                     Payload Data continued ...                |
 // +---------------------------------------------------------------+
 
-size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
+size_t WsTransport::parseFrame(byte *buffer, size_t size, Frame &frame) {
 	const byte *end = buffer + size;
 	if (end - buffer < 2)
 		return 0;
@@ -263,16 +277,25 @@ size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
 		cur += 4;
 	}
 
-	if (size_t(end - cur) < frame.length)
+	const size_t maxControlFrameLength = 125;
+	const size_t maxFrameLength = std::max(maxControlFrameLength, mMaxMessageSize);
+	if (size_t(end - cur) < std::min(frame.length, maxFrameLength))
 		return 0;
 
+	size_t length = frame.length;
+	if (frame.length > maxFrameLength) {
+		PLOG_WARNING << "WebSocket frame is too large (length=" << frame.length
+		             << "), truncating it";
+		frame.length = maxFrameLength;
+	}
+
 	frame.payload = cur;
+
 	if (maskingKey)
 		for (size_t i = 0; i < frame.length; ++i)
 			frame.payload[i] ^= maskingKey[i % 4];
-	cur += frame.length;
 
-	return size_t(cur - buffer);
+	return frame.payload + length - buffer; // can be more than buffer size
 }
 
 void WsTransport::recvFrame(const Frame &frame) {
@@ -282,10 +305,15 @@ void WsTransport::recvFrame(const Frame &frame) {
 	switch (frame.opcode) {
 	case TEXT_FRAME:
 	case BINARY_FRAME: {
+		size_t size = frame.length;
+		if (size > mMaxMessageSize) {
+			PLOG_WARNING << "WebSocket message is too large, truncating it";
+			size = mMaxMessageSize;
+		}
 		if (!mPartial.empty()) {
 			PLOG_WARNING << "WebSocket unfinished message: type="
 			             << (mPartialOpcode == TEXT_FRAME ? "text" : "binary")
-			             << ", length=" << mPartial.size();
+			             << ", size=" << mPartial.size();
 			auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
 			recv(make_message(mPartial.begin(), mPartial.end(), type));
 			mPartial.clear();
@@ -293,21 +321,24 @@ void WsTransport::recvFrame(const Frame &frame) {
 		mPartialOpcode = frame.opcode;
 		if (frame.fin) {
 			PLOG_DEBUG << "WebSocket finished message: type="
-			           << (frame.opcode == TEXT_FRAME ? "text" : "binary")
-			           << ", length=" << frame.length;
+			           << (frame.opcode == TEXT_FRAME ? "text" : "binary") << ", size=" << size;
 			auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary;
-			recv(make_message(frame.payload, frame.payload + frame.length, type));
+			recv(make_message(frame.payload, frame.payload + size, type));
 		} else {
-			mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
+			mPartial.insert(mPartial.end(), frame.payload, frame.payload + size);
 		}
 		break;
 	}
 	case CONTINUATION: {
 		mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
+		if (mPartial.size() > mMaxMessageSize) {
+			PLOG_WARNING << "WebSocket message is too large, truncating it";
+			mPartial.resize(mMaxMessageSize);
+		}
 		if (frame.fin) {
 			PLOG_DEBUG << "WebSocket finished message: type="
 			           << (frame.opcode == TEXT_FRAME ? "text" : "binary")
-			           << ", length=" << mPartial.size();
+			           << ", size=" << mPartial.size();
 			auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
 			recv(make_message(mPartial.begin(), mPartial.end(), type));
 			mPartial.clear();

+ 10 - 6
src/impl/wstransport.hpp

@@ -11,6 +11,7 @@
 
 #include "common.hpp"
 #include "transport.hpp"
+#include "configuration.hpp"
 #include "wshandshake.hpp"
 
 #if RTC_ENABLE_WEBSOCKET
@@ -25,11 +26,12 @@ class TlsTransport;
 
 class WsTransport final : public Transport, public std::enable_shared_from_this<WsTransport> {
 public:
-	WsTransport(
-	    variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>
-	        lower,
-	    shared_ptr<WsHandshake> handshake, int maxOutstandingPings, message_callback recvCallback,
-	    state_callback stateCallback);
+	using LowerTransport =
+	    variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>;
+
+	WsTransport(LowerTransport lower, shared_ptr<WsHandshake> handshake,
+	            const WebSocketConfiguration &config, message_callback recvCallback,
+	            state_callback stateCallback);
 	~WsTransport();
 
 	void start() override;
@@ -62,7 +64,7 @@ private:
 	bool sendHttpError(int code);
 	bool sendHttpResponse();
 
-	size_t readFrame(byte *buffer, size_t size, Frame &frame);
+	size_t parseFrame(byte *buffer, size_t size, Frame &frame);
 	void recvFrame(const Frame &frame);
 	bool sendFrame(const Frame &frame);
 
@@ -70,11 +72,13 @@ private:
 
 	const shared_ptr<WsHandshake> mHandshake;
 	const bool mIsClient;
+	const size_t mMaxMessageSize;
 	const int mMaxOutstandingPings;
 
 	binary mBuffer;
 	binary mPartial;
 	Opcode mPartialOpcode;
+	size_t mIgnoreLength = 0;
 	std::mutex mSendMutex;
 	int mOutstandingPings = 0;
 	std::atomic<bool> mCloseSent = false;

+ 15 - 5
test/websocketserver.cpp

@@ -24,14 +24,13 @@ template <class T> weak_ptr<T> make_weak_ptr(shared_ptr<T> ptr) { return ptr; }
 void test_websocketserver() {
 	InitLogger(LogLevel::Debug);
 
-	const string myMessage = "Hello world from client";
-
 	WebSocketServer::Configuration serverConfig;
 	serverConfig.port = 48080;
 	serverConfig.enableTls = true;
 	// serverConfig.certificatePemFile = ...
 	// serverConfig.keyPemFile = ...
 	serverConfig.bindAddress = "127.0.0.1"; // to test IPv4 fallback
+	serverConfig.maxMessageSize = 1000;     // to test max message size
 	WebSocketServer server(std::move(serverConfig));
 
 	shared_ptr<WebSocket> client;
@@ -63,15 +62,19 @@ void test_websocketserver() {
 	config.disableTlsVerification = true;
 	WebSocket ws(std::move(config));
 
+	const string myMessage = "Hello world from client";
+
 	ws.onOpen([&ws, &myMessage]() {
 		cout << "WebSocket: Open" << endl;
+		ws.send(binary(1001, byte(0))); // test max message size
 		ws.send(myMessage);
 	});
 
 	ws.onClosed([]() { cout << "WebSocket: Closed" << endl; });
 
 	std::atomic<bool> received = false;
-	ws.onMessage([&received, &myMessage](variant<binary, string> message) {
+	std::atomic<bool> maxSizeReceived = false;
+	ws.onMessage([&received, &maxSizeReceived, &myMessage](variant<binary, string> message) {
 		if (holds_alternative<string>(message)) {
 			string str = std::move(get<string>(message));
 			if ((received = (str == myMessage)))
@@ -79,6 +82,13 @@ void test_websocketserver() {
 			else
 				cout << "WebSocket: Received UNEXPECTED message" << endl;
 		}
+		else {
+			binary bin = std::move(get<binary>(message));
+			if ((maxSizeReceived = (bin.size() == 1000)))
+				cout << "WebSocket: Received large message truncated at max size" << endl;
+			else
+				cout << "WebSocket: Received large message NOT TRUNCATED" << endl;
+		}
 	});
 
 	ws.open("wss://localhost:48080/");
@@ -90,8 +100,8 @@ void test_websocketserver() {
 	if (!ws.isOpen())
 		throw runtime_error("WebSocket is not open");
 
-	if (!received)
-		throw runtime_error("Expected message not received");
+	if (!received || !maxSizeReceived)
+		throw runtime_error("Expected messages not received");
 
 	ws.close();
 	this_thread::sleep_for(1s);