Browse Source

Merge pull request #627 from deblauwetom/websocketPingTimeouts

Verify pong replies on ping messages
Paul-Louis Ageneau 3 years ago
parent
commit
989c84dd25

+ 3 - 0
include/rtc/websocket.hpp

@@ -46,6 +46,9 @@ public:
 		bool disableTlsVerification = false; // if true, don't verify the TLS certificate
 		optional<ProxyServer> proxyServer;   // unsupported for now
 		std::vector<string> protocols;
+		optional<int> maxOutstandingPings;
+		std::chrono::milliseconds pingInterval =
+			std::chrono::seconds(10); // interval at which to send pings
 	};
 
 	WebSocket();

+ 3 - 1
src/impl/tcptransport.cpp

@@ -123,6 +123,8 @@ bool TcpTransport::outgoing(message_ptr message) {
 
 string TcpTransport::remoteAddress() const { return mHostname + ':' + mService; }
 
+void TcpTransport::setReadTimeout(std::chrono::milliseconds readTimeout) { mReadTimeout = readTimeout; }
+
 void TcpTransport::connect() {
 	PLOG_DEBUG << "Connecting to " << mHostname << ":" << mService;
 	changeState(State::Connecting);
@@ -247,7 +249,7 @@ void TcpTransport::prepare(const sockaddr *addr, socklen_t addrlen) {
 }
 
 void TcpTransport::setPoll(PollService::Direction direction) {
-	const auto timeout = 10s;
+	const auto timeout = mReadTimeout;
 	PollService::Instance().add(
 	    mSock,
 	    {direction, direction == PollService::Direction::In ? make_optional(timeout) : nullopt,

+ 2 - 0
src/impl/tcptransport.hpp

@@ -47,6 +47,7 @@ public:
 	bool isActive() const { return mIsActive; }
 
 	string remoteAddress() const;
+	void setReadTimeout(std::chrono::milliseconds readTimeout);
 
 private:
 	void connect();
@@ -61,6 +62,7 @@ private:
 
 	const bool mIsActive;
 	string mHostname, mService;
+	std::chrono::milliseconds mReadTimeout = std::chrono::seconds(10);
 
 	socket_t mSock;
 	Queue<message_ptr> mSendQueue;

+ 5 - 2
src/impl/websocket.cpp

@@ -113,7 +113,9 @@ void WebSocket::open(const string &url) {
 	std::atomic_store(&mWsHandshake, std::make_shared<WsHandshake>(host, path, config.protocols));
 
 	changeState(State::Connecting);
-	setTcpTransport(std::make_shared<TcpTransport>(hostname, service, nullptr));
+	auto tcpTransport = std::make_shared<TcpTransport>(hostname, service, nullptr);
+	tcpTransport->setReadTimeout(config.pingInterval);
+	setTcpTransport(tcpTransport);
 }
 
 void WebSocket::close() {
@@ -360,7 +362,8 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 		};
 
 		auto transport = std::make_shared<WsTransport>(
-		    lower, mWsHandshake, weak_bind(&WebSocket::incoming, this, _1), stateChangeCallback);
+			lower, mWsHandshake, weak_bind(&WebSocket::incoming, this, _1), stateChangeCallback,
+			config.maxOutstandingPings);
 
 		return emplaceTransport(this, &mWsTransport, std::move(transport));
 

+ 12 - 3
src/impl/wstransport.cpp

@@ -54,14 +54,15 @@ using random_bytes_engine =
 
 WsTransport::WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower,
                          shared_ptr<WsHandshake> handshake, message_callback recvCallback,
-                         state_callback stateCallback)
+                         state_callback stateCallback, std::optional<int> maxOutstandingPings)
     : Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
                 std::move(stateCallback)),
       mHandshake(std::move(handshake)),
       mIsClient(
           std::visit(rtc::overloaded{[](shared_ptr<TcpTransport> l) { return l->isActive(); },
                                      [](shared_ptr<TlsTransport> l) { return l->isClient(); }},
-                     lower)) {
+                     lower)),
+      mMaxPongsMissed(maxOutstandingPings) {
 
 	onRecv(std::move(recvCallback));
 
@@ -132,7 +133,7 @@ void WsTransport::incoming(message_ptr message) {
 					PLOG_DEBUG << "WebSocket sending ping";
 					uint32_t dummy = 0;
 					sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, mIsClient});
-
+					addOutstandingPing();
 				} else {
 					Frame frame;
 					while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
@@ -317,6 +318,7 @@ void WsTransport::recvFrame(const Frame &frame) {
 	}
 	case PONG: {
 		PLOG_DEBUG << "WebSocket received pong";
+		mPingsOutstanding = 0;
 		break;
 	}
 	case CLOSE: {
@@ -371,6 +373,13 @@ bool WsTransport::sendFrame(const Frame &frame) {
 	return outgoing(make_message(frame.payload, frame.payload + frame.length)); // payload
 }
 
+void WsTransport::addOutstandingPing() {
+	++mPingsOutstanding;
+	if (mMaxPongsMissed && *mMaxPongsMissed > 0 && mPingsOutstanding > *mMaxPongsMissed) {
+		changeState(State::Failed);
+	}
+}
+
 } // namespace rtc::impl
 
 #endif

+ 6 - 2
src/impl/wstransport.hpp

@@ -33,8 +33,8 @@ class TlsTransport;
 class WsTransport final : public Transport {
 public:
 	WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower,
-	            shared_ptr<WsHandshake> handshake, message_callback recvCallback,
-	            state_callback stateCallback);
+				shared_ptr<WsHandshake> handshake, message_callback recvCallback,
+				state_callback stateCallback, optional<int> maxOutstandingPings);
 	~WsTransport();
 
 	void start() override;
@@ -71,12 +71,16 @@ private:
 	void recvFrame(const Frame &frame);
 	bool sendFrame(const Frame &frame);
 
+	void addOutstandingPing();
+
 	const shared_ptr<WsHandshake> mHandshake;
 	const bool mIsClient;
+	const optional<int> mMaxPongsMissed;
 
 	binary mBuffer;
 	binary mPartial;
 	Opcode mPartialOpcode;
+	int mPingsOutstanding = 0;
 };
 
 } // namespace rtc::impl