Browse Source

Added WebSocket connection timeout

Paul-Louis Ageneau 2 years ago
parent
commit
203add1058

+ 1 - 0
include/rtc/websocket.hpp

@@ -36,6 +36,7 @@ public:
 		bool disableTlsVerification = false; // if true, don't verify the TLS certificate
 		optional<ProxyServer> proxyServer;   // only non-authenticated http supported for now
 		std::vector<string> protocols;
+		optional<std::chrono::milliseconds> connectionTimeout; // zero to disable
 		optional<std::chrono::milliseconds> pingInterval; // zero to disable
 		optional<int> maxOutstandingPings;
 	};

+ 1 - 0
include/rtc/websocketserver.hpp

@@ -31,6 +31,7 @@ public:
 		optional<string> keyPemFile;
 		optional<string> keyPemPass;
 		optional<string> bindAddress;
+		optional<std::chrono::milliseconds> connectionTimeout;
 	};
 
 	WebSocketServer();

+ 24 - 4
src/impl/websocket.cpp

@@ -32,15 +32,17 @@ namespace rtc::impl {
 
 using namespace std::placeholders;
 using namespace std::chrono_literals;
+using std::chrono::milliseconds;
 
 WebSocket::WebSocket(optional<Configuration> optConfig, certificate_ptr certificate)
     : config(optConfig ? std::move(*optConfig) : Configuration()),
       mCertificate(std::move(certificate)), mIsSecure(mCertificate != nullptr),
       mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {
 	PLOG_VERBOSE << "Creating WebSocket";
-	if (config.proxyServer) {		
-		if( config.proxyServer->type == ProxyServer::Type::Socks5)
-			throw std::invalid_argument("Proxy server support for WebSocket is not implemented for Socks5");
+	if (config.proxyServer) {
+		if (config.proxyServer->type == ProxyServer::Type::Socks5)
+			throw std::invalid_argument(
+			    "Proxy server support for WebSocket is not implemented for Socks5");
 		if (config.proxyServer->username || config.proxyServer->password) {
 			PLOG_WARNING << "HTTP authentication support for proxy is not implemented";
 		}
@@ -251,9 +253,11 @@ shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> tra
 
 		// WS transport sends a ping on read timeout
 		auto pingInterval = config.pingInterval.value_or(10000ms);
-		if (pingInterval > std::chrono::milliseconds::zero())
+		if (pingInterval > milliseconds::zero())
 			transport->setReadTimeout(pingInterval);
 
+		scheduleConnectionTimeout();
+
 		return emplaceTransport(this, &mTcpTransport, std::move(transport));
 
 	} catch (const std::exception &e) {
@@ -505,6 +509,22 @@ void WebSocket::closeTransports() {
 	triggerClosed();
 }
 
+void WebSocket::scheduleConnectionTimeout() {
+	auto defaultTimeout = 30s;
+	auto timeout = config.connectionTimeout.value_or(milliseconds(defaultTimeout));
+	if (timeout > milliseconds::zero()) {
+		ThreadPool::Instance().schedule(timeout, [weak_this = weak_from_this()]() {
+			if (auto locked = weak_this.lock()) {
+				if (locked->state == WebSocket::State::Connecting) {
+					PLOG_WARNING << "WebSocket connection timed out";
+					locked->triggerError("Connection timed out");
+					locked->remoteClose();
+				}
+			}
+		});
+	}
+}
+
 } // namespace rtc::impl
 
 #endif

+ 2 - 0
src/impl/websocket.hpp

@@ -67,6 +67,8 @@ struct WebSocket final : public Channel, public std::enable_shared_from_this<Web
 	std::atomic<State> state = State::Closed;
 
 private:
+	void scheduleConnectionTimeout();
+
 	const init_token mInitToken = Init::Instance().token();
 
 	const certificate_ptr mCertificate;

+ 5 - 2
src/impl/websocketserver.cpp

@@ -40,7 +40,7 @@ WebSocketServer::WebSocketServer(Configuration config_)
 			    "Either none or both certificate and key PEM files must be specified");
 		}
 	}
-	
+
 	const char* bindAddress = nullptr;
 	if(config.bindAddress){
 		bindAddress = config.bindAddress->c_str();
@@ -75,7 +75,10 @@ void WebSocketServer::runLoop() {
 				if (!clientCallback)
 					continue;
 
-				auto impl = std::make_shared<WebSocket>(nullopt, mCertificate);
+				WebSocket::Configuration clientConfig;
+				clientConfig.connectionTimeout = config.connectionTimeout;
+
+				auto impl = std::make_shared<WebSocket>(std::move(clientConfig), mCertificate);
 				impl->changeState(WebSocket::State::Connecting);
 				impl->setTcpTransport(incoming);
 				clientCallback(std::make_shared<rtc::WebSocket>(impl));

+ 2 - 0
test/websocket.cpp

@@ -35,6 +35,8 @@ void test_websocket() {
 		ws.send(myMessage);
 	});
 
+	ws.onError([](string error) { cout << "WebSocket: Error: " << error << endl; });
+
 	ws.onClosed([]() { cout << "WebSocket: Closed" << endl; });
 
 	std::atomic<bool> received = false;