Browse Source

Made pingInterval optional and refactored outstanding pings

Paul-Louis Ageneau 3 years ago
parent
commit
95c558f13a

+ 2 - 3
include/rtc/websocket.hpp

@@ -23,7 +23,7 @@
 
 #include "channel.hpp"
 #include "common.hpp"
-#include "configuration.hpp"
+#include "configuration.hpp" // for ProxyServer
 
 namespace rtc {
 
@@ -46,9 +46,8 @@ public:
 		bool disableTlsVerification = false; // if true, don't verify the TLS certificate
 		optional<ProxyServer> proxyServer;   // unsupported for now
 		std::vector<string> protocols;
+		optional<std::chrono::milliseconds> pingInterval; // zero to disable
 		optional<int> maxOutstandingPings;
-		std::chrono::milliseconds pingInterval =
-			std::chrono::seconds(10); // interval at which to send pings
 	};
 
 	WebSocket();

+ 2 - 4
src/impl/tcptransport.cpp

@@ -249,11 +249,9 @@ void TcpTransport::prepare(const sockaddr *addr, socklen_t addrlen) {
 }
 
 void TcpTransport::setPoll(PollService::Direction direction) {
-	const auto timeout = mReadTimeout;
 	PollService::Instance().add(
-	    mSock,
-	    {direction, direction == PollService::Direction::In ? make_optional(timeout) : nullopt,
-	     std::bind(&TcpTransport::process, this, _1)});
+	    mSock, {direction, direction == PollService::Direction::In ? mReadTimeout : nullopt,
+	            std::bind(&TcpTransport::process, this, _1)});
 }
 
 void TcpTransport::close() {

+ 3 - 2
src/impl/tcptransport.hpp

@@ -28,6 +28,7 @@
 #if RTC_ENABLE_WEBSOCKET
 
 #include <mutex>
+#include <chrono>
 
 namespace rtc::impl {
 
@@ -45,8 +46,8 @@ public:
 	bool outgoing(message_ptr message) override;
 
 	bool isActive() const { return mIsActive; }
-
 	string remoteAddress() const;
+
 	void setReadTimeout(std::chrono::milliseconds readTimeout);
 
 private:
@@ -62,7 +63,7 @@ private:
 
 	const bool mIsActive;
 	string mHostname, mService;
-	std::chrono::milliseconds mReadTimeout = std::chrono::seconds(10);
+	optional<std::chrono::milliseconds> mReadTimeout;
 
 	socket_t mSock;
 	Queue<message_ptr> mSendQueue;

+ 12 - 6
src/impl/websocket.cpp

@@ -30,6 +30,7 @@
 #include "wstransport.hpp"
 
 #include <array>
+#include <chrono>
 #include <regex>
 
 #ifdef _WIN32
@@ -39,6 +40,7 @@
 namespace rtc::impl {
 
 using namespace std::placeholders;
+using namespace std::chrono_literals;
 
 WebSocket::WebSocket(optional<Configuration> optConfig, certificate_ptr certificate)
     : config(optConfig ? std::move(*optConfig) : Configuration()),
@@ -113,9 +115,7 @@ void WebSocket::open(const string &url) {
 	std::atomic_store(&mWsHandshake, std::make_shared<WsHandshake>(host, path, config.protocols));
 
 	changeState(State::Connecting);
-	auto tcpTransport = std::make_shared<TcpTransport>(hostname, service, nullptr);
-	tcpTransport->setReadTimeout(config.pingInterval);
-	setTcpTransport(tcpTransport);
+	setTcpTransport(std::make_shared<TcpTransport>(hostname, service, nullptr));
 }
 
 void WebSocket::close() {
@@ -245,6 +245,11 @@ shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> tra
 			}
 		});
 
+		// WS transport sends a ping on read timeout
+		auto pingInterval = config.pingInterval.value_or(10000ms);
+		if (pingInterval > std::chrono::milliseconds::zero())
+			transport->setReadTimeout(pingInterval);
+
 		return emplaceTransport(this, &mTcpTransport, std::move(transport));
 
 	} catch (const std::exception &e) {
@@ -361,9 +366,10 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 			}
 		};
 
-		auto transport = std::make_shared<WsTransport>(
-			lower, mWsHandshake, weak_bind(&WebSocket::incoming, this, _1), stateChangeCallback,
-			config.maxOutstandingPings);
+		auto maxOutstandingPings = config.maxOutstandingPings.value_or(0);
+		auto transport = std::make_shared<WsTransport>(lower, mWsHandshake, maxOutstandingPings,
+		                                               weak_bind(&WebSocket::incoming, this, _1),
+		                                               stateChangeCallback);
 
 		return emplaceTransport(this, &mWsTransport, std::move(transport));
 

+ 6 - 6
src/impl/wstransport.cpp

@@ -53,8 +53,8 @@ using random_bytes_engine =
     std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned short>;
 
 WsTransport::WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower,
-                         shared_ptr<WsHandshake> handshake, message_callback recvCallback,
-                         state_callback stateCallback, std::optional<int> maxOutstandingPings)
+                         shared_ptr<WsHandshake> handshake, int maxOutstandingPings,
+                         message_callback recvCallback, state_callback stateCallback)
     : Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
                 std::move(stateCallback)),
       mHandshake(std::move(handshake)),
@@ -62,7 +62,7 @@ WsTransport::WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTranspo
           std::visit(rtc::overloaded{[](shared_ptr<TcpTransport> l) { return l->isActive(); },
                                      [](shared_ptr<TlsTransport> l) { return l->isClient(); }},
                      lower)),
-      mMaxOutstandingPings(std::move(maxOutstandingPings)) {
+      mMaxOutstandingPings(maxOutstandingPings) {
 
 	onRecv(std::move(recvCallback));
 
@@ -318,7 +318,7 @@ void WsTransport::recvFrame(const Frame &frame) {
 	}
 	case PONG: {
 		PLOG_DEBUG << "WebSocket received pong";
-		mPingsOutstanding = 0;
+		mOutstandingPings = 0;
 		break;
 	}
 	case CLOSE: {
@@ -374,8 +374,8 @@ bool WsTransport::sendFrame(const Frame &frame) {
 }
 
 void WsTransport::addOutstandingPing() {
-	++mPingsOutstanding;
-	if (mMaxOutstandingPings && *mMaxOutstandingPings > 0 && mPingsOutstanding > *mMaxOutstandingPings) {
+	++mOutstandingPings;
+	if (mMaxOutstandingPings > 0 && mOutstandingPings > mMaxOutstandingPings) {
 		changeState(State::Failed);
 	}
 }

+ 4 - 4
src/impl/wstransport.hpp

@@ -33,8 +33,8 @@ class TlsTransport;
 class WsTransport final : public Transport {
 public:
 	WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower,
-				shared_ptr<WsHandshake> handshake, message_callback recvCallback,
-				state_callback stateCallback, optional<int> maxOutstandingPings);
+	            shared_ptr<WsHandshake> handshake, int maxOutstandingPings,
+	            message_callback recvCallback, state_callback stateCallback);
 	~WsTransport();
 
 	void start() override;
@@ -75,12 +75,12 @@ private:
 
 	const shared_ptr<WsHandshake> mHandshake;
 	const bool mIsClient;
-	const optional<int> mMaxOutstandingPings;
+	const int mMaxOutstandingPings;
 
 	binary mBuffer;
 	binary mPartial;
 	Opcode mPartialOpcode;
-	int mPingsOutstanding = 0;
+	int mOutstandingPings = 0;
 };
 
 } // namespace rtc::impl