Browse Source

Verify pong replies on ping messages

Allow the ping interval to be configured for a websocket.
Allow to fail the connection when a configured amount of
pings are not answered with pongs.
Tom Deblauwe 3 years ago
parent
commit
5ef8fc8f45

+ 2 - 0
include/rtc/websocket.hpp

@@ -46,6 +46,8 @@ public:
 		bool disableTlsVerification = false; // if true, don't verify the TLS certificate
 		optional<ProxyServer> proxyServer;   // unsupported for now
 		std::vector<string> protocols;
+		int maxMissedPongsAllowed = 2;                   // -1 is no check
+		std::chrono::milliseconds sendPingInterval = std::chrono::seconds(10); // interval at which to send pings
 	};
 
 	WebSocket();

+ 3 - 1
src/impl/tcptransport.cpp

@@ -123,6 +123,8 @@ bool TcpTransport::outgoing(message_ptr message) {
 
 string TcpTransport::remoteAddress() const { return mHostname + ':' + mService; }
 
+void TcpTransport::setReadTimeout(std::chrono::milliseconds readTimeout) { mReadTimeout = readTimeout; }
+
 void TcpTransport::connect() {
 	PLOG_DEBUG << "Connecting to " << mHostname << ":" << mService;
 	changeState(State::Connecting);
@@ -247,7 +249,7 @@ void TcpTransport::prepare(const sockaddr *addr, socklen_t addrlen) {
 }
 
 void TcpTransport::setPoll(PollService::Direction direction) {
-	const auto timeout = 10s;
+	const auto timeout = mReadTimeout;
 	PollService::Instance().add(
 	    mSock,
 	    {direction, direction == PollService::Direction::In ? make_optional(timeout) : nullopt,

+ 2 - 0
src/impl/tcptransport.hpp

@@ -47,6 +47,7 @@ public:
 	bool isActive() const { return mIsActive; }
 
 	string remoteAddress() const;
+	void setReadTimeout(std::chrono::milliseconds readTimeout);
 
 private:
 	void connect();
@@ -61,6 +62,7 @@ private:
 
 	const bool mIsActive;
 	string mHostname, mService;
+	std::chrono::milliseconds mReadTimeout = std::chrono::seconds(10);
 
 	socket_t mSock;
 	Queue<message_ptr> mSendQueue;

+ 5 - 2
src/impl/websocket.cpp

@@ -113,7 +113,9 @@ void WebSocket::open(const string &url) {
 	std::atomic_store(&mWsHandshake, std::make_shared<WsHandshake>(host, path, config.protocols));
 
 	changeState(State::Connecting);
-	setTcpTransport(std::make_shared<TcpTransport>(hostname, service, nullptr));
+	auto tcpTransport = std::make_shared<TcpTransport>(hostname, service, nullptr);
+	tcpTransport->setReadTimeout(config.sendPingInterval);
+	setTcpTransport(tcpTransport);
 }
 
 void WebSocket::close() {
@@ -360,7 +362,8 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 		};
 
 		auto transport = std::make_shared<WsTransport>(
-		    lower, mWsHandshake, weak_bind(&WebSocket::incoming, this, _1), stateChangeCallback);
+			lower, mWsHandshake, weak_bind(&WebSocket::incoming, this, _1), stateChangeCallback,
+			config.maxMissedPongsAllowed);
 
 		return emplaceTransport(this, &mWsTransport, std::move(transport));
 

+ 15 - 4
src/impl/wstransport.cpp

@@ -54,14 +54,15 @@ using random_bytes_engine =
 
 WsTransport::WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower,
                          shared_ptr<WsHandshake> handshake, message_callback recvCallback,
-                         state_callback stateCallback)
+                         state_callback stateCallback, int maxPongsMissed)
     : Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
                 std::move(stateCallback)),
       mHandshake(std::move(handshake)),
       mIsClient(
           std::visit(rtc::overloaded{[](shared_ptr<TcpTransport> l) { return l->isActive(); },
                                      [](shared_ptr<TlsTransport> l) { return l->isClient(); }},
-                     lower)) {
+                     lower)),
+      mMaxPongsMissed(maxPongsMissed) {
 
 	onRecv(std::move(recvCallback));
 
@@ -132,7 +133,7 @@ void WsTransport::incoming(message_ptr message) {
 					PLOG_DEBUG << "WebSocket sending ping";
 					uint32_t dummy = 0;
 					sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, mIsClient});
-
+					addOpenPing();
 				} else {
 					Frame frame;
 					while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
@@ -312,11 +313,14 @@ void WsTransport::recvFrame(const Frame &frame) {
 	}
 	case PING: {
 		PLOG_DEBUG << "WebSocket received ping, sending pong";
-		sendFrame({PONG, frame.payload, frame.length, true, mIsClient});
+		if (!sendFrame({PONG, frame.payload, frame.length, true, mIsClient})) {
+			PLOG_ERROR << "WebSocket could not send ping";
+		}
 		break;
 	}
 	case PONG: {
 		PLOG_DEBUG << "WebSocket received pong";
+		mPingsOpen = 0;
 		break;
 	}
 	case CLOSE: {
@@ -371,6 +375,13 @@ bool WsTransport::sendFrame(const Frame &frame) {
 	return outgoing(make_message(frame.payload, frame.payload + frame.length)); // payload
 }
 
+void WsTransport::addOpenPing() {
+	++mPingsOpen;
+	if (mMaxPongsMissed > 0 && mPingsOpen > mMaxPongsMissed) {
+		changeState(State::Failed);
+	}
+}
+
 } // namespace rtc::impl
 
 #endif

+ 5 - 1
src/impl/wstransport.hpp

@@ -34,7 +34,7 @@ class WsTransport final : public Transport {
 public:
 	WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower,
 	            shared_ptr<WsHandshake> handshake, message_callback recvCallback,
-	            state_callback stateCallback);
+				state_callback stateCallback, int maxPongsMissed);
 	~WsTransport();
 
 	void start() override;
@@ -71,12 +71,16 @@ private:
 	void recvFrame(const Frame &frame);
 	bool sendFrame(const Frame &frame);
 
+	void addOpenPing();
+
 	const shared_ptr<WsHandshake> mHandshake;
 	const bool mIsClient;
+	const int mMaxPongsMissed;
 
 	binary mBuffer;
 	binary mPartial;
 	Opcode mPartialOpcode;
+	int mPingsOpen = 0;
 };
 
 } // namespace rtc::impl