Jelajahi Sumber

Implemented proper teardown waiting for remote close

Paul-Louis Ageneau 3 tahun lalu
induk
melakukan
cfccb10008

+ 19 - 16
src/impl/dtlstransport.cpp

@@ -50,7 +50,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
                              state_callback stateChangeCallback)
     : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate),
       mVerifierCallback(std::move(verifierCallback)),
-      mIsClient(lower->role() == Description::Role::Active), mCurrentDscp(0) {
+      mIsClient(lower->role() == Description::Role::Active) {
 
 	PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
 
@@ -99,26 +99,27 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 DtlsTransport::~DtlsTransport() {
 	stop();
 
+	PLOG_DEBUG << "Destroying DTLS transport";
 	gnutls_deinit(mSession);
 }
 
 void DtlsTransport::start() {
-	Transport::start();
-
-	registerIncoming();
+	if(mStarted.exchange(true))
+		return;
 
 	PLOG_DEBUG << "Starting DTLS recv thread";
+	registerIncoming();
 	mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
 }
 
-bool DtlsTransport::stop() {
-	if (!Transport::stop())
-		return false;
+void DtlsTransport::stop() {
+	if(!mStarted.exchange(false))
+		return;
 
 	PLOG_DEBUG << "Stopping DTLS recv thread";
+	unregisterIncoming();
 	mIncomingQueue.stop();
 	mRecvThread.join();
-	return true;
 }
 
 bool DtlsTransport::send(message_ptr message) {
@@ -127,6 +128,7 @@ bool DtlsTransport::send(message_ptr message) {
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 
+
 	ssize_t ret;
 	do {
 		std::lock_guard lock(mSendMutex);
@@ -385,7 +387,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
                              state_callback stateChangeCallback)
     : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate),
       mVerifierCallback(std::move(verifierCallback)),
-      mIsClient(lower->role() == Description::Role::Active), mCurrentDscp(0) {
+      mIsClient(lower->role() == Description::Role::Active) {
 	PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
 
 	if (!mCertificate)
@@ -466,28 +468,29 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 DtlsTransport::~DtlsTransport() {
 	stop();
 
+	PLOG_DEBUG << "Destroying DTLS transport";
 	SSL_free(mSsl);
 	SSL_CTX_free(mCtx);
 }
 
 void DtlsTransport::start() {
-	Transport::start();
-
-	registerIncoming();
+	if(mStarted.exchange(true))
+		return;
 
 	PLOG_DEBUG << "Starting DTLS recv thread";
+	registerIncoming();
 	mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
 }
 
-bool DtlsTransport::stop() {
-	if (!Transport::stop())
-		return false;
+void DtlsTransport::stop() {
+	if(!mStarted.exchange(false))
+		return;
 
 	PLOG_DEBUG << "Stopping DTLS recv thread";
+	unregisterIncoming();
 	mIncomingQueue.stop();
 	mRecvThread.join();
 	SSL_shutdown(mSsl);
-	return true;
 }
 
 bool DtlsTransport::send(message_ptr message) {

+ 4 - 2
src/impl/dtlstransport.hpp

@@ -47,7 +47,7 @@ public:
 	~DtlsTransport();
 
 	virtual void start() override;
-	virtual bool stop() override;
+	virtual void stop() override;
 	virtual bool send(message_ptr message) override; // false if dropped
 
 	bool isClient() const { return mIsClient; }
@@ -57,6 +57,7 @@ protected:
 	virtual bool outgoing(message_ptr message) override;
 	virtual bool demuxMessage(message_ptr message);
 	virtual void postHandshake();
+
 	void runRecvLoop();
 
 	const optional<size_t> mMtu;
@@ -66,7 +67,8 @@ protected:
 
 	Queue<message_ptr> mIncomingQueue;
 	std::thread mRecvThread;
-	std::atomic<unsigned int> mCurrentDscp;
+	std::atomic<bool> mStarted = false;
+	std::atomic<unsigned int> mCurrentDscp = 0;
 	std::atomic<bool> mOutgoingResult = true;
 
 #if USE_GNUTLS

+ 4 - 11
src/impl/icetransport.cpp

@@ -158,12 +158,10 @@ IceTransport::IceTransport(const Configuration &config, candidate_callback candi
 }
 
 IceTransport::~IceTransport() {
-	stop();
+	PLOG_DEBUG << "Destroying ICE transport";
 	mAgent.reset();
 }
 
-bool IceTransport::stop() { return Transport::stop(); }
-
 Description::Role IceTransport::role() const { return mRole; }
 
 Description IceTransport::getLocalDescription(Description::Type type) const {
@@ -572,24 +570,19 @@ IceTransport::IceTransport(const Configuration &config, candidate_callback candi
 	                       RecvCallback, this);
 }
 
-IceTransport::~IceTransport() { stop(); }
-
-bool IceTransport::stop() {
+IceTransport::~IceTransport() {
 	if (mTimeoutId) {
 		g_source_remove(mTimeoutId);
 		mTimeoutId = 0;
 	}
 
-	if (!Transport::stop())
-		return false;
-
-	PLOG_DEBUG << "Stopping ICE thread";
+	PLOG_DEBUG << "Destroying ICE transport";
 	nice_agent_attach_recv(mNiceAgent.get(), mStreamId, 1, g_main_loop_get_context(mMainLoop.get()),
 	                       NULL, NULL);
 	nice_agent_remove_stream(mNiceAgent.get(), mStreamId);
 	g_main_loop_quit(mMainLoop.get());
 	mMainLoopThread.join();
-	return true;
+	mNiceAgent.reset();
 }
 
 Description::Role IceTransport::role() const { return mRole; }

+ 0 - 1
src/impl/icetransport.hpp

@@ -62,7 +62,6 @@ public:
 	optional<string> getLocalAddress() const;
 	optional<string> getRemoteAddress() const;
 
-	bool stop() override;
 	bool send(message_ptr message) override; // false if dropped
 
 	bool getSelectedCandidatePair(Candidate *local, Candidate *remote);

+ 33 - 26
src/impl/peerconnection.cpp

@@ -79,15 +79,23 @@ PeerConnection::~PeerConnection() {
 }
 
 void PeerConnection::close() {
-	PLOG_VERBOSE << "Closing PeerConnection";
-
 	negotiationNeeded = false;
+	if (!closing.exchange(true)) {
+		PLOG_VERBOSE << "Closing PeerConnection";
+		if (auto transport = std::atomic_load(&mSctpTransport))
+			transport->stop();
+		else
+			remoteClose();
+	}
+}
 
-	// Close data channels and tracks asynchronously
-	mProcessor.enqueue(&PeerConnection::closeDataChannels, shared_from_this());
-	mProcessor.enqueue(&PeerConnection::closeTracks, shared_from_this());
-
-	closeTransports();
+void PeerConnection::remoteClose() {
+	close();
+	if (state.load() != State::Closed) {
+		closeDataChannels();
+		closeTracks();
+		closeTransports();
+	}
 }
 
 optional<Description> PeerConnection::localDescription() const {
@@ -125,11 +133,10 @@ shared_ptr<T> emplaceTransport(PeerConnection *pc, shared_ptr<T> *member, shared
 		transport->start();
 	} catch (...) {
 		std::atomic_store(member, decltype(transport)(nullptr));
-		transport->stop();
 		throw;
 	}
 
-	if (pc->state.load() == PeerConnection::State::Closed) {
+	if (pc->closing.load() || pc->state.load() == PeerConnection::State::Closed) {
 		std::atomic_store(member, decltype(transport)(nullptr));
 		transport->stop();
 		return nullptr;
@@ -226,11 +233,11 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 				    break;
 			    case DtlsTransport::State::Failed:
 				    changeState(State::Failed);
-				    mProcessor.enqueue(&PeerConnection::closeTracks, shared_from_this());
+				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
 				    break;
 			    case DtlsTransport::State::Disconnected:
 				    changeState(State::Disconnected);
-				    mProcessor.enqueue(&PeerConnection::closeTracks, shared_from_this());
+				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
 				    break;
 			    default:
 				    // Ignore
@@ -299,6 +306,7 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 			    auto shared_this = weak_this.lock();
 			    if (!shared_this)
 				    return;
+
 			    switch (transportState) {
 			    case SctpTransport::State::Connected:
 				    changeState(State::Connected);
@@ -308,13 +316,12 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 			    case SctpTransport::State::Failed:
 				    LOG_WARNING << "SCTP transport failed";
 				    changeState(State::Failed);
-				    mProcessor.enqueue(&PeerConnection::remoteCloseDataChannels,
-				                       shared_from_this());
+				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
 				    break;
 			    case SctpTransport::State::Disconnected:
+				    LOG_INFO << "SCTP transport disconnected";
 				    changeState(State::Disconnected);
-				    mProcessor.enqueue(&PeerConnection::remoteCloseDataChannels,
-				                       shared_from_this());
+				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
 				    break;
 			    default:
 				    // Ignore
@@ -370,18 +377,18 @@ void PeerConnection::closeTransports() {
 		if (t)
 			t->onStateChange(nullptr);
 
-	// Initiate transport stop on the processor after closing the data channels
-	mProcessor.enqueue([self = shared_from_this(), transports = std::move(transports)]() {
-		TearDownProcessor::Instance().enqueue(
-		    [transports = std::move(transports), token = Init::Instance().token()]() mutable {
-			    for (const auto &t : transports)
-				    if (t)
-					    t->stop();
+	TearDownProcessor::Instance().enqueue(
+	    [transports = std::move(transports), token = Init::Instance().token()]() mutable {
+		    for (const auto &t : transports) {
+			    if (t) {
+				    t->stop();
+				    break;
+			    }
+		    }
 
-			    for (auto &t : transports)
-				    t.reset();
-		    });
-	});
+		    for (auto &t : transports)
+			    t.reset();
+	    });
 }
 
 void PeerConnection::endLocalCandidates() {

+ 2 - 0
src/impl/peerconnection.hpp

@@ -46,6 +46,7 @@ struct PeerConnection : std::enable_shared_from_this<PeerConnection> {
 	~PeerConnection();
 
 	void close();
+	void remoteClose();
 
 	optional<Description> localDescription() const;
 	optional<Description> remoteDescription() const;
@@ -113,6 +114,7 @@ struct PeerConnection : std::enable_shared_from_this<PeerConnection> {
 	std::atomic<GatheringState> gatheringState = GatheringState::New;
 	std::atomic<SignalingState> signalingState = SignalingState::Stable;
 	std::atomic<bool> negotiationNeeded = false;
+	std::atomic<bool> closing = false;
 	std::mutex signalingMutex;
 
 	synchronized_callback<shared_ptr<rtc::DataChannel>> dataChannelCallback;

+ 28 - 37
src/impl/sctptransport.cpp

@@ -327,8 +327,20 @@ SctpTransport::SctpTransport(shared_ptr<Transport> lower, const Configuration &c
 }
 
 SctpTransport::~SctpTransport() {
-	stop();
-	close();
+	PLOG_DEBUG << "Destroying SCTP transport";
+
+	// Before unregistering incoming() from the lower layer, we need to make sure the thread from
+	// lower layers is not blocked in incoming() by the WrittenOnce condition.
+	mWrittenOnce = true;
+	mWrittenCondition.notify_all();
+
+	unregisterIncoming();
+
+	mProcessor.join();
+	usrsctp_close(mSock);
+
+	usrsctp_deregister_address(this);
+	Instances->erase(this);
 }
 
 void SctpTransport::onBufferedAmount(amount_callback callback) {
@@ -336,26 +348,11 @@ void SctpTransport::onBufferedAmount(amount_callback callback) {
 }
 
 void SctpTransport::start() {
-	Transport::start();
-
 	registerIncoming();
 	connect();
 }
 
-bool SctpTransport::stop() {
-	// Transport::stop() will unregister incoming() from the lower layer, therefore we need to make
-	// sure the thread from lower layers is not blocked in incoming() by the WrittenOnce condition.
-	mWrittenOnce = true;
-	mWrittenCondition.notify_all();
-
-	if (!Transport::stop())
-		return false;
-
-	mSendQueue.stop();
-	flush();
-	shutdown();
-	return true;
-}
+void SctpTransport::stop() { close(); }
 
 struct sockaddr_conn SctpTransport::getSockAddrConn(uint16_t port) {
 	struct sockaddr_conn sconn = {};
@@ -398,24 +395,6 @@ void SctpTransport::shutdown() {
 	if (usrsctp_shutdown(mSock, SHUT_RDWR) != 0 && errno != ENOTCONN) {
 		PLOG_WARNING << "SCTP shutdown failed, errno=" << errno;
 	}
-
-	close();
-
-	PLOG_INFO << "SCTP disconnected";
-	changeState(State::Disconnected);
-	mWrittenCondition.notify_all();
-}
-
-void SctpTransport::close() {
-	if (!mSock)
-		return;
-
-	mProcessor.join();
-	usrsctp_close(mSock);
-	mSock = nullptr;
-
-	usrsctp_deregister_address(this);
-	Instances->erase(this);
 }
 
 bool SctpTransport::send(message_ptr message) {
@@ -459,6 +438,11 @@ void SctpTransport::closeStream(unsigned int stream) {
 	mProcessor.enqueue(&SctpTransport::flush, shared_from_this());
 }
 
+void SctpTransport::close() {
+	mSendQueue.stop();
+	mProcessor.enqueue(&SctpTransport::flush, shared_from_this());
+}
+
 unsigned int SctpTransport::maxStream() const {
 	unsigned int streamsCount = mNegotiatedStreamsCount.value_or(MAX_SCTP_STREAMS_COUNT);
 	return streamsCount > 0 ? streamsCount - 1 : 0;
@@ -511,6 +495,8 @@ void SctpTransport::doRecv() {
 					break;
 				else
 					throw std::runtime_error("SCTP recv failed, errno=" + std::to_string(errno));
+			} else if (len == 0) {
+				break;
 			}
 
 			PLOG_VERBOSE << "SCTP recv, len=" << len;
@@ -566,6 +552,12 @@ bool SctpTransport::trySendQueue() {
 		mSendQueue.pop();
 		updateBufferedAmount(to_uint16(message->stream), -ptrdiff_t(message_size_func(message)));
 	}
+
+	if (!mSendQueue.running()) {
+		shutdown();
+		return false;
+	}
+
 	return true;
 }
 
@@ -918,7 +910,6 @@ optional<milliseconds> SctpTransport::rtt() {
 	socklen_t len = sizeof(status);
 	if (usrsctp_getsockopt(mSock, IPPROTO_SCTP, SCTP_STATUS, &status, &len)) {
 		COUNTER_BAD_SCTP_STATUS++;
-
 		return nullopt;
 	}
 	return milliseconds(status.sstat_primary.spinfo_srtt);

+ 2 - 2
src/impl/sctptransport.hpp

@@ -56,10 +56,11 @@ public:
 	void onBufferedAmount(amount_callback callback);
 
 	void start() override;
-	bool stop() override;
+	void stop() override;
 	bool send(message_ptr message) override; // false if buffered
 	bool flush();
 	void closeStream(unsigned int stream);
+	void close();
 
 	unsigned int maxStream() const;
 
@@ -86,7 +87,6 @@ private:
 
 	void connect();
 	void shutdown();
-	void close();
 	void incoming(message_ptr message) override;
 	bool outgoing(message_ptr message) override;
 

+ 0 - 11
src/impl/tcptransport.cpp

@@ -69,7 +69,6 @@ TcpTransport::TcpTransport(socket_t sock, state_callback callback)
 }
 
 TcpTransport::~TcpTransport() {
-	stop();
 	close();
 }
 
@@ -82,8 +81,6 @@ void TcpTransport::setReadTimeout(std::chrono::milliseconds readTimeout) {
 }
 
 void TcpTransport::start() {
-	Transport::start();
-
 	if (mSock == INVALID_SOCKET) {
 		connect();
 	} else {
@@ -92,14 +89,6 @@ void TcpTransport::start() {
 	}
 }
 
-bool TcpTransport::stop() {
-	if (!Transport::stop())
-		return false;
-
-	close();
-	return true;
-}
-
 bool TcpTransport::send(message_ptr message) {
 	std::lock_guard lock(mSendMutex);
 

+ 0 - 1
src/impl/tcptransport.hpp

@@ -46,7 +46,6 @@ public:
 	void setReadTimeout(std::chrono::milliseconds readTimeout);
 
 	void start() override;
-	bool stop() override;
 	bool send(message_ptr message) override;
 
 	void incoming(message_ptr message) override;

+ 14 - 16
src/impl/tlstransport.cpp

@@ -95,27 +95,26 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host
 
 TlsTransport::~TlsTransport() {
 	stop();
-
 	gnutls_deinit(mSession);
 }
 
 void TlsTransport::start() {
-	Transport::start();
-
-	registerIncoming();
+	if (mStarted.exchange(true))
+		return;
 
 	PLOG_DEBUG << "Starting TLS recv thread";
+	registerIncoming();
 	mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
 }
 
-bool TlsTransport::stop() {
-	if (!Transport::stop())
-		return false;
+void TlsTransport::stop() {
+	if (!mStarted.exchange(false))
+		return;
 
 	PLOG_DEBUG << "Stopping TLS recv thread";
+	unregisterIncoming();
 	mIncomingQueue.stop();
 	mRecvThread.join();
-	return true;
 }
 
 bool TlsTransport::send(message_ptr message) {
@@ -375,29 +374,28 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host
 
 TlsTransport::~TlsTransport() {
 	stop();
-
 	SSL_free(mSsl);
 	SSL_CTX_free(mCtx);
 }
 
 void TlsTransport::start() {
-	Transport::start();
-
-	registerIncoming();
+	if (mStarted.exchange(true))
+		return;
 
 	PLOG_DEBUG << "Starting TLS recv thread";
+	registerIncoming();
 	mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
 }
 
-bool TlsTransport::stop() {
-	if (!Transport::stop())
-		return false;
+void TlsTransport::stop() {
+	if (!mStarted.exchange(false))
+		return;
 
 	PLOG_DEBUG << "Stopping TLS recv thread";
+	unregisterIncoming();
 	mIncomingQueue.stop();
 	mRecvThread.join();
 	SSL_shutdown(mSsl);
-	return true;
 }
 
 bool TlsTransport::send(message_ptr message) {

+ 3 - 1
src/impl/tlstransport.hpp

@@ -44,7 +44,7 @@ public:
 	virtual ~TlsTransport();
 
 	void start() override;
-	bool stop() override;
+	void stop() override;
 	bool send(message_ptr message) override;
 
 	bool isClient() const { return mIsClient; }
@@ -53,6 +53,7 @@ protected:
 	virtual void incoming(message_ptr message) override;
 	virtual bool outgoing(message_ptr message) override;
 	virtual void postHandshake();
+
 	void runRecvLoop();
 
 	const optional<string> mHost;
@@ -60,6 +61,7 @@ protected:
 
 	Queue<message_ptr> mIncomingQueue;
 	std::thread mRecvThread;
+	std::atomic<bool> mStarted = false;
 
 #if USE_GNUTLS
 	gnutls_session_t mSession;

+ 15 - 12
src/impl/transport.cpp

@@ -23,20 +23,13 @@ namespace rtc::impl {
 Transport::Transport(shared_ptr<Transport> lower, state_callback callback)
     : mLower(std::move(lower)), mStateChangeCallback(std::move(callback)) {}
 
-Transport::~Transport() { stop(); }
+Transport::~Transport() {
+	unregisterIncoming();
 
-void Transport::start() { mStopped = false; }
-
-bool Transport::stop() {
-	if (mStopped.exchange(true))
-		return false;
-
-	// We don't want incoming() to be called by the lower layer anymore
 	if (mLower) {
-		PLOG_VERBOSE << "Unregistering incoming callback";
-		mLower->onRecv(nullptr);
+		mLower->stop();
+		mLower.reset();
 	}
-	return true;
 }
 
 void Transport::registerIncoming() {
@@ -46,6 +39,13 @@ void Transport::registerIncoming() {
 	}
 }
 
+void Transport::unregisterIncoming() {
+	if (mLower) {
+		PLOG_VERBOSE << "Unregistering incoming callback";
+		mLower->onRecv(nullptr);
+	}
+}
+
 Transport::State Transport::state() const { return mState; }
 
 void Transport::onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
@@ -54,6 +54,10 @@ void Transport::onStateChange(state_callback callback) {
 	mStateChangeCallback = std::move(callback);
 }
 
+void Transport::start() { registerIncoming(); }
+
+void Transport::stop() { unregisterIncoming(); }
+
 bool Transport::send(message_ptr message) { return outgoing(message); }
 
 void Transport::recv(message_ptr message) {
@@ -83,4 +87,3 @@ bool Transport::outgoing(message_ptr message) {
 }
 
 } // namespace rtc::impl
-

+ 7 - 5
src/impl/transport.hpp

@@ -20,6 +20,7 @@
 #define RTC_IMPL_TRANSPORT_H
 
 #include "common.hpp"
+#include "init.hpp"
 #include "internals.hpp"
 #include "message.hpp"
 
@@ -37,15 +38,15 @@ public:
 	Transport(shared_ptr<Transport> lower = nullptr, state_callback callback = nullptr);
 	virtual ~Transport();
 
-	virtual void start();
-	virtual bool stop();
-
 	void registerIncoming();
+	void unregisterIncoming();
 	State state() const;
 
 	void onRecv(message_callback callback);
 	void onStateChange(state_callback callback);
 
+	virtual void start();
+	virtual void stop();
 	virtual bool send(message_ptr message);
 
 protected:
@@ -55,12 +56,13 @@ protected:
 	virtual bool outgoing(message_ptr message);
 
 private:
-	const shared_ptr<Transport> mLower;
+	const init_token mInitToken = Init::Instance().token();
+
+	shared_ptr<Transport> mLower;
 	synchronized_callback<State> mStateChangeCallback;
 	synchronized_callback<message_ptr> mRecvCallback;
 
 	std::atomic<State> mState = State::Disconnected;
-	std::atomic<bool> mStopped = true;
 };
 
 } // namespace rtc::impl

+ 1 - 1
src/impl/verifiedtlstransport.cpp

@@ -37,7 +37,7 @@ VerifiedTlsTransport::VerifiedTlsTransport(shared_ptr<TcpTransport> lower, strin
 #endif
 }
 
-VerifiedTlsTransport::~VerifiedTlsTransport() {}
+VerifiedTlsTransport::~VerifiedTlsTransport() { stop(); }
 
 } // namespace rtc::impl
 

+ 8 - 6
src/impl/websocket.cpp

@@ -124,17 +124,16 @@ void WebSocket::close() {
 		PLOG_VERBOSE << "Closing WebSocket";
 		changeState(State::Closing);
 		if (auto transport = std::atomic_load(&mWsTransport))
-			transport->close();
+			transport->stop();
 		else
 			remoteClose();
 	}
 }
 
 void WebSocket::remoteClose() {
-	if (state != State::Closed) {
-		close();
+	close();
+	if (state.load() != State::Closed)
 		closeTransports();
-	}
 }
 
 bool WebSocket::isOpen() const { return state == State::Open; }
@@ -424,9 +423,12 @@ void WebSocket::closeTransports() {
 
 	TearDownProcessor::Instance().enqueue(
 	    [transports = std::move(transports), token = Init::Instance().token()]() mutable {
-		    for (const auto &t : transports)
-			    if (t)
+		    for (const auto &t : transports) {
+			    if (t) {
 				    t->stop();
+				    break;
+			    }
+		    }
 
 		    for (auto &t : transports)
 			    t.reset();

+ 1 - 1
src/impl/websocket.hpp

@@ -46,6 +46,7 @@ struct WebSocket final : public Channel, public std::enable_shared_from_this<Web
 
 	void open(const string &url);
 	void close();
+	void remoteClose();
 	bool outgoing(message_ptr message);
 	void incoming(message_ptr message);
 
@@ -58,7 +59,6 @@ struct WebSocket final : public Channel, public std::enable_shared_from_this<Web
 	size_t maxMessageSize() const;
 
 	bool changeState(State state);
-	void remoteClose();
 
 	shared_ptr<TcpTransport> setTcpTransport(shared_ptr<TcpTransport> transport);
 	shared_ptr<TlsTransport> initTlsTransport();

+ 26 - 23
src/impl/wstransport.cpp

@@ -18,6 +18,7 @@
 
 #include "wstransport.hpp"
 #include "tcptransport.hpp"
+#include "threadpool.hpp"
 #include "tlstransport.hpp"
 
 #if RTC_ENABLE_WEBSOCKET
@@ -69,11 +70,9 @@ WsTransport::WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTranspo
 	PLOG_DEBUG << "Initializing WebSocket transport";
 }
 
-WsTransport::~WsTransport() { stop(); }
+WsTransport::~WsTransport() { unregisterIncoming(); }
 
 void WsTransport::start() {
-	Transport::start();
-
 	registerIncoming();
 
 	changeState(State::Connecting);
@@ -81,13 +80,7 @@ void WsTransport::start() {
 		sendHttpRequest();
 }
 
-bool WsTransport::stop() {
-	if (!Transport::stop())
-		return false;
-
-	close();
-	return true;
-}
+void WsTransport::stop() { close(); }
 
 bool WsTransport::send(message_ptr message) {
 	if (!message || state() != State::Connected)
@@ -98,6 +91,29 @@ bool WsTransport::send(message_ptr message) {
 	                  message->size(), true, mIsClient});
 }
 
+void WsTransport::close() {
+	if (state() != State::Connected)
+		return;
+
+	PLOG_INFO << "WebSocket closing";
+	try {
+		sendFrame({CLOSE, NULL, 0, true, mIsClient});
+	} catch (const std::exception &e) {
+		// The connection might not be open anymore
+		PLOG_DEBUG << "Unable to send WebSocket close frame: " << e.what();
+		changeState(State::Disconnected);
+		return;
+	}
+
+	ThreadPool::Instance().schedule(std::chrono::milliseconds(10),
+	                                [this, weak_this = weak_from_this()]() {
+		                                if (auto shared_this = weak_this.lock()) {
+			                                PLOG_DEBUG << "WebSocket close timeout";
+			                                changeState(State::Disconnected);
+		                                }
+	                                });
+}
+
 void WsTransport::incoming(message_ptr message) {
 	auto s = state();
 	if (s != State::Connecting && s != State::Connected)
@@ -172,19 +188,6 @@ void WsTransport::incoming(message_ptr message) {
 	}
 }
 
-void WsTransport::close() {
-	if (state() == State::Connected) {
-		PLOG_INFO << "WebSocket closing";
-		try {
-			sendFrame({CLOSE, NULL, 0, true, mIsClient});
-		} catch (const std::exception &e) {
-			// Ignore error as the connection might not be open anymore
-			PLOG_DEBUG << "Unable to send WebSocket close frame: " << e.what();
-		}
-		changeState(State::Disconnected);
-	}
-}
-
 bool WsTransport::sendHttpRequest() {
 	PLOG_DEBUG << "Sending WebSocket HTTP request";
 

+ 3 - 3
src/impl/wstransport.hpp

@@ -30,7 +30,7 @@ namespace rtc::impl {
 class TcpTransport;
 class TlsTransport;
 
-class WsTransport final : public Transport {
+class WsTransport final : public Transport, public std::enable_shared_from_this<WsTransport> {
 public:
 	WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower,
 	            shared_ptr<WsHandshake> handshake, int maxOutstandingPings,
@@ -38,10 +38,10 @@ public:
 	~WsTransport();
 
 	void start() override;
-	bool stop() override;
+	void stop() override;
 	bool send(message_ptr message) override;
-	void incoming(message_ptr message) override;
 	void close();
+	void incoming(message_ptr message) override;
 
 	bool isClient() const { return mIsClient; }
 

+ 1 - 1
src/peerconnection.cpp

@@ -49,7 +49,7 @@ PeerConnection::PeerConnection(Configuration config)
 
 PeerConnection::~PeerConnection() {
 	try {
-		close();
+		impl()->remoteClose();
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 	}