|
@@ -43,11 +43,9 @@ using std::to_integer;
|
|
|
using std::to_string;
|
|
|
using std::chrono::system_clock;
|
|
|
|
|
|
-WsTransport::WsTransport(
|
|
|
- variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>
|
|
|
- lower,
|
|
|
- shared_ptr<WsHandshake> handshake, int maxOutstandingPings, message_callback recvCallback,
|
|
|
- state_callback stateCallback)
|
|
|
+WsTransport::WsTransport(LowerTransport lower, shared_ptr<WsHandshake> handshake,
|
|
|
+ const WebSocketConfiguration &config, message_callback recvCallback,
|
|
|
+ state_callback stateCallback)
|
|
|
: Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
|
|
|
std::move(stateCallback)),
|
|
|
mHandshake(std::move(handshake)),
|
|
@@ -55,7 +53,8 @@ WsTransport::WsTransport(
|
|
|
std::visit(rtc::overloaded{[](auto l) { return l->isActive(); },
|
|
|
[](shared_ptr<TlsTransport> l) { return l->isClient(); }},
|
|
|
lower)),
|
|
|
- mMaxOutstandingPings(maxOutstandingPings) {
|
|
|
+ mMaxMessageSize(config.maxMessageSize.value_or(DEFAULT_WS_MAX_MESSAGE_SIZE)),
|
|
|
+ mMaxOutstandingPings(config.maxOutstandingPings.value_or(0)) {
|
|
|
|
|
|
onRecv(std::move(recvCallback));
|
|
|
|
|
@@ -75,7 +74,10 @@ void WsTransport::start() {
|
|
|
void WsTransport::stop() { close(); }
|
|
|
|
|
|
bool WsTransport::send(message_ptr message) {
|
|
|
- if (!message || state() != State::Connected)
|
|
|
+ if (state() != State::Connected)
|
|
|
+ throw std::runtime_error("WebSocket is not open");
|
|
|
+
|
|
|
+ if (!message)
|
|
|
return false;
|
|
|
|
|
|
PLOG_VERBOSE << "Send size=" << message->size();
|
|
@@ -146,10 +148,22 @@ void WsTransport::incoming(message_ptr message) {
|
|
|
sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, mIsClient});
|
|
|
addOutstandingPing();
|
|
|
} else {
|
|
|
- Frame frame;
|
|
|
- while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
|
|
|
- recvFrame(frame);
|
|
|
+ if (mIgnoreLength > 0) {
|
|
|
+ size_t len = std::min(mIgnoreLength, mBuffer.size());
|
|
|
mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
|
|
|
+ mIgnoreLength -= len;
|
|
|
+ }
|
|
|
+ if (mIgnoreLength == 0) {
|
|
|
+ Frame frame;
|
|
|
+ while (size_t len = parseFrame(mBuffer.data(), mBuffer.size(), frame)) {
|
|
|
+ recvFrame(frame);
|
|
|
+ if (len > mBuffer.size()) {
|
|
|
+ mIgnoreLength = len - mBuffer.size();
|
|
|
+ mBuffer.clear();
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -229,7 +243,7 @@ bool WsTransport::sendHttpError(int code) {
|
|
|
// | Payload Data continued ... |
|
|
|
// +---------------------------------------------------------------+
|
|
|
|
|
|
-size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
|
|
|
+size_t WsTransport::parseFrame(byte *buffer, size_t size, Frame &frame) {
|
|
|
const byte *end = buffer + size;
|
|
|
if (end - buffer < 2)
|
|
|
return 0;
|
|
@@ -263,16 +277,25 @@ size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
|
|
|
cur += 4;
|
|
|
}
|
|
|
|
|
|
- if (size_t(end - cur) < frame.length)
|
|
|
+ const size_t maxControlFrameLength = 125;
|
|
|
+ const size_t maxFrameLength = std::max(maxControlFrameLength, mMaxMessageSize);
|
|
|
+ if (size_t(end - cur) < std::min(frame.length, maxFrameLength))
|
|
|
return 0;
|
|
|
|
|
|
+ size_t length = frame.length;
|
|
|
+ if (frame.length > maxFrameLength) {
|
|
|
+ PLOG_WARNING << "WebSocket frame is too large (length=" << frame.length
|
|
|
+ << "), truncating it";
|
|
|
+ frame.length = maxFrameLength;
|
|
|
+ }
|
|
|
+
|
|
|
frame.payload = cur;
|
|
|
+
|
|
|
if (maskingKey)
|
|
|
for (size_t i = 0; i < frame.length; ++i)
|
|
|
frame.payload[i] ^= maskingKey[i % 4];
|
|
|
- cur += frame.length;
|
|
|
|
|
|
- return size_t(cur - buffer);
|
|
|
+ return frame.payload + length - buffer; // can be more than buffer size
|
|
|
}
|
|
|
|
|
|
void WsTransport::recvFrame(const Frame &frame) {
|
|
@@ -282,10 +305,15 @@ void WsTransport::recvFrame(const Frame &frame) {
|
|
|
switch (frame.opcode) {
|
|
|
case TEXT_FRAME:
|
|
|
case BINARY_FRAME: {
|
|
|
+ size_t size = frame.length;
|
|
|
+ if (size > mMaxMessageSize) {
|
|
|
+ PLOG_WARNING << "WebSocket message is too large, truncating it";
|
|
|
+ size = mMaxMessageSize;
|
|
|
+ }
|
|
|
if (!mPartial.empty()) {
|
|
|
PLOG_WARNING << "WebSocket unfinished message: type="
|
|
|
<< (mPartialOpcode == TEXT_FRAME ? "text" : "binary")
|
|
|
- << ", length=" << mPartial.size();
|
|
|
+ << ", size=" << mPartial.size();
|
|
|
auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
|
|
|
recv(make_message(mPartial.begin(), mPartial.end(), type));
|
|
|
mPartial.clear();
|
|
@@ -293,21 +321,24 @@ void WsTransport::recvFrame(const Frame &frame) {
|
|
|
mPartialOpcode = frame.opcode;
|
|
|
if (frame.fin) {
|
|
|
PLOG_DEBUG << "WebSocket finished message: type="
|
|
|
- << (frame.opcode == TEXT_FRAME ? "text" : "binary")
|
|
|
- << ", length=" << frame.length;
|
|
|
+ << (frame.opcode == TEXT_FRAME ? "text" : "binary") << ", size=" << size;
|
|
|
auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary;
|
|
|
- recv(make_message(frame.payload, frame.payload + frame.length, type));
|
|
|
+ recv(make_message(frame.payload, frame.payload + size, type));
|
|
|
} else {
|
|
|
- mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
|
|
|
+ mPartial.insert(mPartial.end(), frame.payload, frame.payload + size);
|
|
|
}
|
|
|
break;
|
|
|
}
|
|
|
case CONTINUATION: {
|
|
|
mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
|
|
|
+ if (mPartial.size() > mMaxMessageSize) {
|
|
|
+ PLOG_WARNING << "WebSocket message is too large, truncating it";
|
|
|
+ mPartial.resize(mMaxMessageSize);
|
|
|
+ }
|
|
|
if (frame.fin) {
|
|
|
PLOG_DEBUG << "WebSocket finished message: type="
|
|
|
<< (frame.opcode == TEXT_FRAME ? "text" : "binary")
|
|
|
- << ", length=" << mPartial.size();
|
|
|
+ << ", size=" << mPartial.size();
|
|
|
auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
|
|
|
recv(make_message(mPartial.begin(), mPartial.end(), type));
|
|
|
mPartial.clear();
|