Browse Source

Merge pull request #881 from paullouisageneau/ws-connection-timeout

Add WebSocket connection timeout
Paul-Louis Ageneau 2 years ago
parent
commit
bf48620648

+ 6 - 2
DOC.md

@@ -812,7 +812,8 @@ typedef struct {
 	bool disableTlsVerification;
 	const char **protocols;
 	int protocolsCount;
-	int pingInterval;
+	int connectionTimeoutMs;
+	int pingIntervalMs;
 	int maxOutstandingPings;
 } rtcWsConfiguration;
 ```
@@ -826,7 +827,8 @@ Arguments:
   - `disableTlsVerification`: if true, don't verify the TLS certificate, else try to verify it if possible
   - `protocols` (optional): an array of pointers on null-terminated protocol names (NULL if unused)
   - `protocolsCount` (optional): number of URLs in the array pointed by `protocols` (0 if unused)
-  - `pingInterval` (optional): ping interval in milliseconds (0 if default, < 0 if disabled)
+  - `connectionTimeoutMs` (optional): connection timeout in milliseconds (0 if default, < 0 if disabled)
+  - `pingIntervalMs` (optional): ping interval in milliseconds (0 if default, < 0 if disabled)
   - `maxOutstandingPings` (optional): number of unanswered pings before declaring failure (0 if default, < 0 if disabled)
 
 Return value: the identifier of the new WebSocket or a negative error code
@@ -894,6 +896,7 @@ typedef struct {
 	const char *certificatePemFile;
 	const char *keyPemFile;
 	const char *keyPemPass;
+	int connectionTimeoutMs;
 } rtcWsServerConfiguration;
 ```
 
@@ -907,6 +910,7 @@ Arguments:
   - `certificatePemFile` (optional): PEM certificate or path of the file containing the PEM certificate (`NULL` for an autogenerated certificate)
   - `keyPemFile` (optional): PEM key or path of the file containing the PEM key (`NULL` for an autogenerated certificate)
   - `keyPemPass` (optional): PEM key file passphrase (NULL if no passphrase)
+  - `connectionTimeoutMs` (optional): connection timeout in milliseconds (0 if default, < 0 if disabled)
 - `cb`: the callback for incoming client WebSocket connections (must not be `NULL`)
 
 `cb` must have the following signature: `void rtcWebSocketClientCallbackFunc(int wsserver, int ws, void *user_ptr)`

+ 20 - 18
include/rtc/rtc.h

@@ -105,8 +105,8 @@ typedef enum {
 
 	// audio
 	RTC_CODEC_OPUS = 128,
-    RTC_CODEC_PCMU = 129,
-    RTC_CODEC_PCMA = 130
+	RTC_CODEC_PCMU = 129,
+	RTC_CODEC_PCMA = 130
 } rtcCodec;
 
 typedef enum {
@@ -196,7 +196,7 @@ RTC_C_EXPORT int rtcGetLocalAddress(int pc, char *buffer, int size);
 RTC_C_EXPORT int rtcGetRemoteAddress(int pc, char *buffer, int size);
 
 RTC_C_EXPORT int rtcGetSelectedCandidatePair(int pc, char *local, int localSize, char *remote,
-                                           int remoteSize);
+                                             int remoteSize);
 
 RTC_C_EXPORT int rtcGetMaxDataChannelStream(int pc);
 
@@ -242,7 +242,7 @@ typedef struct {
 RTC_C_EXPORT int rtcSetDataChannelCallback(int pc, rtcDataChannelCallbackFunc cb);
 RTC_C_EXPORT int rtcCreateDataChannel(int pc, const char *label); // returns dc id
 RTC_C_EXPORT int rtcCreateDataChannelEx(int pc, const char *label,
-                                      const rtcDataChannelInit *init); // returns dc id
+                                        const rtcDataChannelInit *init); // returns dc id
 RTC_C_EXPORT int rtcDeleteDataChannel(int dc);
 
 RTC_C_EXPORT int rtcGetDataChannelStream(int dc);
@@ -308,7 +308,7 @@ typedef struct {
 // Opaque message
 
 // Opaque type used (via rtcMessage*) to reference an rtc::Message
-typedef void* rtcMessage;
+typedef void *rtcMessage;
 
 // Allocate a new opaque message.
 // Must be explicitly freed by rtcDeleteOpaqueMessage() unless
@@ -377,7 +377,8 @@ typedef struct {
 	const char *proxyServer;     // only non-authenticated http supported for now
 	const char **protocols;
 	int protocolsCount;
-	int pingInterval;        // in milliseconds, 0 means default, < 0 means disabled
+	int connectionTimeoutMs; // in milliseconds, 0 means default, < 0 means disabled
+	int pingIntervalMs;      // in milliseconds, 0 means default, < 0 means disabled
 	int maxOutstandingPings; // 0 means default, < 0 means disabled
 } rtcWsConfiguration;
 
@@ -399,10 +400,11 @@ typedef struct {
 	const char *keyPemFile;         // NULL for autogenerated certificate
 	const char *keyPemPass;         // NULL if no pass
 	const char *bindAddress;        // NULL for IP_ANY_ADDR
+	int connectionTimeoutMs;        // in milliseconds, 0 means default, < 0 means disabled
 } rtcWsServerConfiguration;
 
 RTC_C_EXPORT int rtcCreateWebSocketServer(const rtcWsServerConfiguration *config,
-                                        rtcWebSocketClientCallbackFunc cb); // returns wsserver id
+                                          rtcWebSocketClientCallbackFunc cb); // returns wsserver id
 RTC_C_EXPORT int rtcDeleteWebSocketServer(int wsserver);
 
 RTC_C_EXPORT int rtcGetWebSocketServerPort(int wsserver);
@@ -417,18 +419,18 @@ RTC_C_EXPORT void rtcCleanup(void);
 // SCTP global settings
 
 typedef struct {
-	int recvBufferSize;             // in bytes, <= 0 means optimized default
-	int sendBufferSize;             // in bytes, <= 0 means optimized default
-	int maxChunksOnQueue;           // in chunks, <= 0 means optimized default
-	int initialCongestionWindow;    // in MTUs, <= 0 means optimized default
-	int maxBurst;                   // in MTUs, 0 means optimized default, < 0 means disabled
-	int congestionControlModule;    // 0: RFC2581 (default), 1: HSTCP, 2: H-TCP, 3: RTCC
-	int delayedSackTimeMs;          // in msecs, 0 means optimized default, < 0 means disabled
-	int minRetransmitTimeoutMs;     // in msecs, <= 0 means optimized default
-	int maxRetransmitTimeoutMs;     // in msecs, <= 0 means optimized default
-	int initialRetransmitTimeoutMs; // in msecs, <= 0 means optimized default
+	int recvBufferSize;          // in bytes, <= 0 means optimized default
+	int sendBufferSize;          // in bytes, <= 0 means optimized default
+	int maxChunksOnQueue;        // in chunks, <= 0 means optimized default
+	int initialCongestionWindow; // in MTUs, <= 0 means optimized default
+	int maxBurst;                // in MTUs, 0 means optimized default, < 0 means disabled
+	int congestionControlModule; // 0: RFC2581 (default), 1: HSTCP, 2: H-TCP, 3: RTCC
+	int delayedSackTimeMs;       // in milliseconds, 0 means optimized default, < 0 means disabled
+	int minRetransmitTimeoutMs;  // in milliseconds, <= 0 means optimized default
+	int maxRetransmitTimeoutMs;  // in milliseconds, <= 0 means optimized default
+	int initialRetransmitTimeoutMs; // in milliseconds, <= 0 means optimized default
 	int maxRetransmitAttempts;      // number of retransmissions, <= 0 means optimized default
-	int heartbeatIntervalMs;        // in msecs, <= 0 means optimized default
+	int heartbeatIntervalMs;        // in milliseconds, <= 0 means optimized default
 } rtcSctpSettings;
 
 // Note: SCTP settings apply to newly-created PeerConnections only

+ 1 - 0
include/rtc/websocket.hpp

@@ -36,6 +36,7 @@ public:
 		bool disableTlsVerification = false; // if true, don't verify the TLS certificate
 		optional<ProxyServer> proxyServer;   // only non-authenticated http supported for now
 		std::vector<string> protocols;
+		optional<std::chrono::milliseconds> connectionTimeout; // zero to disable
 		optional<std::chrono::milliseconds> pingInterval; // zero to disable
 		optional<int> maxOutstandingPings;
 	};

+ 1 - 0
include/rtc/websocketserver.hpp

@@ -31,6 +31,7 @@ public:
 		optional<string> keyPemFile;
 		optional<string> keyPemPass;
 		optional<string> bindAddress;
+		optional<std::chrono::milliseconds> connectionTimeout;
 	};
 
 	WebSocketServer();

+ 6 - 2
pages/content/pages/reference.md

@@ -815,7 +815,8 @@ typedef struct {
 	bool disableTlsVerification;
 	const char **protocols;
 	int protocolsCount;
-	int pingInterval;
+	int connectionTimeoutMs;
+	int pingIntervalMs;
 	int maxOutstandingPings;
 } rtcWsConfiguration;
 ```
@@ -829,7 +830,8 @@ Arguments:
   - `disableTlsVerification`: if true, don't verify the TLS certificate, else try to verify it if possible
   - `protocols` (optional): an array of pointers on null-terminated protocol names (NULL if unused)
   - `protocolsCount` (optional): number of URLs in the array pointed by `protocols` (0 if unused)
-  - `pingInterval` (optional): ping interval in milliseconds (0 if default, < 0 if disabled)
+  - `connectionTimeoutMs` (optional): connection timeout in milliseconds (0 if default, < 0 if disabled)
+  - `pingIntervalMs` (optional): ping interval in milliseconds (0 if default, < 0 if disabled)
   - `maxOutstandingPings` (optional): number of unanswered pings before declaring failure (0 if default, < 0 if disabled)
 
 Return value: the identifier of the new WebSocket or a negative error code
@@ -897,6 +899,7 @@ typedef struct {
 	const char *certificatePemFile;
 	const char *keyPemFile;
 	const char *keyPemPass;
+	int connectionTimeoutMs;
 } rtcWsServerConfiguration;
 ```
 
@@ -910,6 +913,7 @@ Arguments:
   - `certificatePemFile` (optional): PEM certificate or path of the file containing the PEM certificate (`NULL` for an autogenerated certificate)
   - `keyPemFile` (optional): PEM key or path of the file containing the PEM key (`NULL` for an autogenerated certificate)
   - `keyPemPass` (optional): PEM key file passphrase (NULL if no passphrase)
+  - `connectionTimeoutMs` (optional): connection timeout in milliseconds (0 if default, < 0 if disabled)
 - `cb`: the callback for incoming client WebSocket connections (must not be `NULL`)
 
 `cb` must have the following signature: `void rtcWebSocketClientCallbackFunc(int wsserver, int ws, void *user_ptr)`

+ 27 - 24
src/capi.cpp

@@ -1010,8 +1010,8 @@ int rtcAddTrackEx(int pc, const rtcTrackInit *init) {
 				mid = "video";
 				break;
 			case RTC_CODEC_OPUS:
-            case RTC_CODEC_PCMU:
-            case RTC_CODEC_PCMA:
+			case RTC_CODEC_PCMU:
+			case RTC_CODEC_PCMA:
 				mid = "audio";
 				break;
 			default:
@@ -1044,19 +1044,19 @@ int rtcAddTrackEx(int pc, const rtcTrackInit *init) {
 			break;
 		}
 		case RTC_CODEC_OPUS:
-        case RTC_CODEC_PCMU:
-        case RTC_CODEC_PCMA:{
+		case RTC_CODEC_PCMU:
+		case RTC_CODEC_PCMA: {
 			auto desc = Description::Audio(mid, direction);
 			switch (init->codec) {
 			case RTC_CODEC_OPUS:
 				desc.addOpusCodec(init->payloadType);
 				break;
-            case RTC_CODEC_PCMU:
-                desc.addPCMUCodec(init->payloadType);
-                break;
-            case RTC_CODEC_PCMA:
-                desc.addPCMACodec(init->payloadType);
-                break;
+			case RTC_CODEC_PCMU:
+				desc.addPCMUCodec(init->payloadType);
+				break;
+			case RTC_CODEC_PCMA:
+				desc.addPCMACodec(init->payloadType);
+				break;
 			default:
 				break;
 			}
@@ -1386,12 +1386,16 @@ int rtcCreateWebSocketEx(const char *url, const rtcWsConfiguration *config) {
 		for (int i = 0; i < config->protocolsCount; ++i)
 			c.protocols.emplace_back(string(config->protocols[i]));
 
-		if (config->pingInterval > 0)
-			c.pingInterval = std::chrono::milliseconds(config->pingInterval);
-		else if (config->pingInterval < 0)
-			c.pingInterval = std::chrono::milliseconds::zero(); // setting to 0 disables,
-			                                                    // not setting keeps default
-
+		if (config->connectionTimeoutMs > 0)
+			c.connectionTimeout = milliseconds(config->connectionTimeoutMs);
+		else if (config->connectionTimeoutMs < 0)
+			c.connectionTimeout = milliseconds::zero(); // setting to 0 disables,
+			                                            // not setting keeps default
+		if (config->pingIntervalMs > 0)
+			c.pingInterval = milliseconds(config->pingIntervalMs);
+		else if (config->pingIntervalMs < 0)
+			c.pingInterval = milliseconds::zero(); // setting to 0 disables,
+			                                       // not setting keeps default
 		if (config->maxOutstandingPings > 0)
 			c.maxOutstandingPings = config->maxOutstandingPings;
 		else if (config->maxOutstandingPings < 0)
@@ -1434,7 +1438,7 @@ int rtcGetWebSocketPath(int ws, char *buffer, int size) {
 }
 
 RTC_C_EXPORT int rtcCreateWebSocketServer(const rtcWsServerConfiguration *config,
-                                        rtcWebSocketClientCallbackFunc cb) {
+                                          rtcWebSocketClientCallbackFunc cb) {
 	return wrap([&] {
 		if (!config)
 			throw std::invalid_argument("Unexpected null pointer for config");
@@ -1534,25 +1538,24 @@ int rtcSetSctpSettings(const rtcSctpSettings *settings) {
 			s.congestionControlModule = unsigned(settings->congestionControlModule);
 
 		if (settings->delayedSackTimeMs > 0)
-			s.delayedSackTime = std::chrono::milliseconds(settings->delayedSackTimeMs);
+			s.delayedSackTime = milliseconds(settings->delayedSackTimeMs);
 		else if (settings->delayedSackTimeMs < 0)
-			s.delayedSackTime = std::chrono::milliseconds(0);
+			s.delayedSackTime = milliseconds(0);
 
 		if (settings->minRetransmitTimeoutMs > 0)
-			s.minRetransmitTimeout = std::chrono::milliseconds(settings->minRetransmitTimeoutMs);
+			s.minRetransmitTimeout = milliseconds(settings->minRetransmitTimeoutMs);
 
 		if (settings->maxRetransmitTimeoutMs > 0)
-			s.maxRetransmitTimeout = std::chrono::milliseconds(settings->maxRetransmitTimeoutMs);
+			s.maxRetransmitTimeout = milliseconds(settings->maxRetransmitTimeoutMs);
 
 		if (settings->initialRetransmitTimeoutMs > 0)
-			s.initialRetransmitTimeout =
-			    std::chrono::milliseconds(settings->initialRetransmitTimeoutMs);
+			s.initialRetransmitTimeout = milliseconds(settings->initialRetransmitTimeoutMs);
 
 		if (settings->maxRetransmitAttempts > 0)
 			s.maxRetransmitAttempts = settings->maxRetransmitAttempts;
 
 		if (settings->heartbeatIntervalMs > 0)
-			s.heartbeatInterval = std::chrono::milliseconds(settings->heartbeatIntervalMs);
+			s.heartbeatInterval = milliseconds(settings->heartbeatIntervalMs);
 
 		SetSctpSettings(std::move(s));
 		return RTC_ERR_SUCCESS;

+ 24 - 4
src/impl/websocket.cpp

@@ -32,15 +32,17 @@ namespace rtc::impl {
 
 using namespace std::placeholders;
 using namespace std::chrono_literals;
+using std::chrono::milliseconds;
 
 WebSocket::WebSocket(optional<Configuration> optConfig, certificate_ptr certificate)
     : config(optConfig ? std::move(*optConfig) : Configuration()),
       mCertificate(std::move(certificate)), mIsSecure(mCertificate != nullptr),
       mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {
 	PLOG_VERBOSE << "Creating WebSocket";
-	if (config.proxyServer) {		
-		if( config.proxyServer->type == ProxyServer::Type::Socks5)
-			throw std::invalid_argument("Proxy server support for WebSocket is not implemented for Socks5");
+	if (config.proxyServer) {
+		if (config.proxyServer->type == ProxyServer::Type::Socks5)
+			throw std::invalid_argument(
+			    "Proxy server support for WebSocket is not implemented for Socks5");
 		if (config.proxyServer->username || config.proxyServer->password) {
 			PLOG_WARNING << "HTTP authentication support for proxy is not implemented";
 		}
@@ -251,9 +253,11 @@ shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> tra
 
 		// WS transport sends a ping on read timeout
 		auto pingInterval = config.pingInterval.value_or(10000ms);
-		if (pingInterval > std::chrono::milliseconds::zero())
+		if (pingInterval > milliseconds::zero())
 			transport->setReadTimeout(pingInterval);
 
+		scheduleConnectionTimeout();
+
 		return emplaceTransport(this, &mTcpTransport, std::move(transport));
 
 	} catch (const std::exception &e) {
@@ -505,6 +509,22 @@ void WebSocket::closeTransports() {
 	triggerClosed();
 }
 
+void WebSocket::scheduleConnectionTimeout() {
+	auto defaultTimeout = 30s;
+	auto timeout = config.connectionTimeout.value_or(milliseconds(defaultTimeout));
+	if (timeout > milliseconds::zero()) {
+		ThreadPool::Instance().schedule(timeout, [weak_this = weak_from_this()]() {
+			if (auto locked = weak_this.lock()) {
+				if (locked->state == WebSocket::State::Connecting) {
+					PLOG_WARNING << "WebSocket connection timed out";
+					locked->triggerError("Connection timed out");
+					locked->remoteClose();
+				}
+			}
+		});
+	}
+}
+
 } // namespace rtc::impl
 
 #endif

+ 2 - 0
src/impl/websocket.hpp

@@ -67,6 +67,8 @@ struct WebSocket final : public Channel, public std::enable_shared_from_this<Web
 	std::atomic<State> state = State::Closed;
 
 private:
+	void scheduleConnectionTimeout();
+
 	const init_token mInitToken = Init::Instance().token();
 
 	const certificate_ptr mCertificate;

+ 5 - 2
src/impl/websocketserver.cpp

@@ -40,7 +40,7 @@ WebSocketServer::WebSocketServer(Configuration config_)
 			    "Either none or both certificate and key PEM files must be specified");
 		}
 	}
-	
+
 	const char* bindAddress = nullptr;
 	if(config.bindAddress){
 		bindAddress = config.bindAddress->c_str();
@@ -75,7 +75,10 @@ void WebSocketServer::runLoop() {
 				if (!clientCallback)
 					continue;
 
-				auto impl = std::make_shared<WebSocket>(nullopt, mCertificate);
+				WebSocket::Configuration clientConfig;
+				clientConfig.connectionTimeout = config.connectionTimeout;
+
+				auto impl = std::make_shared<WebSocket>(std::move(clientConfig), mCertificate);
 				impl->changeState(WebSocket::State::Connecting);
 				impl->setTcpTransport(incoming);
 				clientCallback(std::make_shared<rtc::WebSocket>(impl));

+ 2 - 0
test/websocket.cpp

@@ -35,6 +35,8 @@ void test_websocket() {
 		ws.send(myMessage);
 	});
 
+	ws.onError([](string error) { cout << "WebSocket: Error: " << error << endl; });
+
 	ws.onClosed([]() { cout << "WebSocket: Closed" << endl; });
 
 	std::atomic<bool> received = false;