소스 검색

Merge pull request #735 from paullouisageneau/proper-teardown

Fix connection teardown process to wait for remote close
Paul-Louis Ageneau 2 년 전
부모
커밋
648f717277

+ 2 - 0
CMakeLists.txt

@@ -124,6 +124,7 @@ set(LIBDATACHANNEL_IMPL_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcpserver.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/transport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/verifiedtlstransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocket.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocketserver.cpp
@@ -155,6 +156,7 @@ set(LIBDATACHANNEL_IMPL_HEADERS
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcpserver.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.hpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/transport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/verifiedtlstransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocket.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocketserver.hpp

+ 19 - 16
src/impl/dtlstransport.cpp

@@ -50,7 +50,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
                              state_callback stateChangeCallback)
     : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate),
       mVerifierCallback(std::move(verifierCallback)),
-      mIsClient(lower->role() == Description::Role::Active), mCurrentDscp(0) {
+      mIsClient(lower->role() == Description::Role::Active) {
 
 	PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
 
@@ -99,26 +99,27 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 DtlsTransport::~DtlsTransport() {
 	stop();
 
+	PLOG_DEBUG << "Destroying DTLS transport";
 	gnutls_deinit(mSession);
 }
 
 void DtlsTransport::start() {
-	Transport::start();
-
-	registerIncoming();
+	if(mStarted.exchange(true))
+		return;
 
 	PLOG_DEBUG << "Starting DTLS recv thread";
+	registerIncoming();
 	mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
 }
 
-bool DtlsTransport::stop() {
-	if (!Transport::stop())
-		return false;
+void DtlsTransport::stop() {
+	if(!mStarted.exchange(false))
+		return;
 
 	PLOG_DEBUG << "Stopping DTLS recv thread";
+	unregisterIncoming();
 	mIncomingQueue.stop();
 	mRecvThread.join();
-	return true;
 }
 
 bool DtlsTransport::send(message_ptr message) {
@@ -127,6 +128,7 @@ bool DtlsTransport::send(message_ptr message) {
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 
+
 	ssize_t ret;
 	do {
 		std::lock_guard lock(mSendMutex);
@@ -385,7 +387,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
                              state_callback stateChangeCallback)
     : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate),
       mVerifierCallback(std::move(verifierCallback)),
-      mIsClient(lower->role() == Description::Role::Active), mCurrentDscp(0) {
+      mIsClient(lower->role() == Description::Role::Active) {
 	PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
 
 	if (!mCertificate)
@@ -466,28 +468,29 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 DtlsTransport::~DtlsTransport() {
 	stop();
 
+	PLOG_DEBUG << "Destroying DTLS transport";
 	SSL_free(mSsl);
 	SSL_CTX_free(mCtx);
 }
 
 void DtlsTransport::start() {
-	Transport::start();
-
-	registerIncoming();
+	if(mStarted.exchange(true))
+		return;
 
 	PLOG_DEBUG << "Starting DTLS recv thread";
+	registerIncoming();
 	mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
 }
 
-bool DtlsTransport::stop() {
-	if (!Transport::stop())
-		return false;
+void DtlsTransport::stop() {
+	if(!mStarted.exchange(false))
+		return;
 
 	PLOG_DEBUG << "Stopping DTLS recv thread";
+	unregisterIncoming();
 	mIncomingQueue.stop();
 	mRecvThread.join();
 	SSL_shutdown(mSsl);
-	return true;
 }
 
 bool DtlsTransport::send(message_ptr message) {

+ 4 - 2
src/impl/dtlstransport.hpp

@@ -47,7 +47,7 @@ public:
 	~DtlsTransport();
 
 	virtual void start() override;
-	virtual bool stop() override;
+	virtual void stop() override;
 	virtual bool send(message_ptr message) override; // false if dropped
 
 	bool isClient() const { return mIsClient; }
@@ -57,6 +57,7 @@ protected:
 	virtual bool outgoing(message_ptr message) override;
 	virtual bool demuxMessage(message_ptr message);
 	virtual void postHandshake();
+
 	void runRecvLoop();
 
 	const optional<size_t> mMtu;
@@ -66,7 +67,8 @@ protected:
 
 	Queue<message_ptr> mIncomingQueue;
 	std::thread mRecvThread;
-	std::atomic<unsigned int> mCurrentDscp;
+	std::atomic<bool> mStarted = false;
+	std::atomic<unsigned int> mCurrentDscp = 0;
 	std::atomic<bool> mOutgoingResult = true;
 
 #if USE_GNUTLS

+ 4 - 11
src/impl/icetransport.cpp

@@ -158,12 +158,10 @@ IceTransport::IceTransport(const Configuration &config, candidate_callback candi
 }
 
 IceTransport::~IceTransport() {
-	stop();
+	PLOG_DEBUG << "Destroying ICE transport";
 	mAgent.reset();
 }
 
-bool IceTransport::stop() { return Transport::stop(); }
-
 Description::Role IceTransport::role() const { return mRole; }
 
 Description IceTransport::getLocalDescription(Description::Type type) const {
@@ -572,24 +570,19 @@ IceTransport::IceTransport(const Configuration &config, candidate_callback candi
 	                       RecvCallback, this);
 }
 
-IceTransport::~IceTransport() { stop(); }
-
-bool IceTransport::stop() {
+IceTransport::~IceTransport() {
 	if (mTimeoutId) {
 		g_source_remove(mTimeoutId);
 		mTimeoutId = 0;
 	}
 
-	if (!Transport::stop())
-		return false;
-
-	PLOG_DEBUG << "Stopping ICE thread";
+	PLOG_DEBUG << "Destroying ICE transport";
 	nice_agent_attach_recv(mNiceAgent.get(), mStreamId, 1, g_main_loop_get_context(mMainLoop.get()),
 	                       NULL, NULL);
 	nice_agent_remove_stream(mNiceAgent.get(), mStreamId);
 	g_main_loop_quit(mMainLoop.get());
 	mMainLoopThread.join();
-	return true;
+	mNiceAgent.reset();
 }
 
 Description::Role IceTransport::role() const { return mRole; }

+ 0 - 1
src/impl/icetransport.hpp

@@ -62,7 +62,6 @@ public:
 	optional<string> getLocalAddress() const;
 	optional<string> getRemoteAddress() const;
 
-	bool stop() override;
 	bool send(message_ptr message) override; // false if dropped
 
 	bool getSelectedCandidatePair(Candidate *local, Candidate *remote);

+ 38 - 29
src/impl/peerconnection.cpp

@@ -79,15 +79,23 @@ PeerConnection::~PeerConnection() {
 }
 
 void PeerConnection::close() {
-	PLOG_VERBOSE << "Closing PeerConnection";
-
 	negotiationNeeded = false;
+	if (!closing.exchange(true)) {
+		PLOG_VERBOSE << "Closing PeerConnection";
+		if (auto transport = std::atomic_load(&mSctpTransport))
+			transport->stop();
+		else
+			remoteClose();
+	}
+}
 
-	// Close data channels and tracks asynchronously
-	mProcessor.enqueue(&PeerConnection::closeDataChannels, shared_from_this());
-	mProcessor.enqueue(&PeerConnection::closeTracks, shared_from_this());
-
-	closeTransports();
+void PeerConnection::remoteClose() {
+	close();
+	if (state.load() != State::Closed) {
+		closeDataChannels();
+		closeTracks();
+		closeTransports();
+	}
 }
 
 optional<Description> PeerConnection::localDescription() const {
@@ -125,11 +133,10 @@ shared_ptr<T> emplaceTransport(PeerConnection *pc, shared_ptr<T> *member, shared
 		transport->start();
 	} catch (...) {
 		std::atomic_store(member, decltype(transport)(nullptr));
-		transport->stop();
 		throw;
 	}
 
-	if (pc->state.load() == PeerConnection::State::Closed) {
+	if (pc->closing.load() || pc->state.load() == PeerConnection::State::Closed) {
 		std::atomic_store(member, decltype(transport)(nullptr));
 		transport->stop();
 		return nullptr;
@@ -155,14 +162,16 @@ shared_ptr<IceTransport> PeerConnection::initIceTransport() {
 			    case IceTransport::State::Connecting:
 				    changeState(State::Connecting);
 				    break;
-			    case IceTransport::State::Failed:
-				    changeState(State::Failed);
-				    break;
 			    case IceTransport::State::Connected:
 				    initDtlsTransport();
 				    break;
+			    case IceTransport::State::Failed:
+				    changeState(State::Failed);
+				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
+				    break;
 			    case IceTransport::State::Disconnected:
 				    changeState(State::Disconnected);
+				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
 				    break;
 			    default:
 				    // Ignore
@@ -226,11 +235,11 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 				    break;
 			    case DtlsTransport::State::Failed:
 				    changeState(State::Failed);
-				    mProcessor.enqueue(&PeerConnection::closeTracks, shared_from_this());
+				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
 				    break;
 			    case DtlsTransport::State::Disconnected:
 				    changeState(State::Disconnected);
-				    mProcessor.enqueue(&PeerConnection::closeTracks, shared_from_this());
+				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
 				    break;
 			    default:
 				    // Ignore
@@ -299,6 +308,7 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 			    auto shared_this = weak_this.lock();
 			    if (!shared_this)
 				    return;
+
 			    switch (transportState) {
 			    case SctpTransport::State::Connected:
 				    changeState(State::Connected);
@@ -308,13 +318,12 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 			    case SctpTransport::State::Failed:
 				    LOG_WARNING << "SCTP transport failed";
 				    changeState(State::Failed);
-				    mProcessor.enqueue(&PeerConnection::remoteCloseDataChannels,
-				                       shared_from_this());
+				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
 				    break;
 			    case SctpTransport::State::Disconnected:
+				    LOG_INFO << "SCTP transport disconnected";
 				    changeState(State::Disconnected);
-				    mProcessor.enqueue(&PeerConnection::remoteCloseDataChannels,
-				                       shared_from_this());
+				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
 				    break;
 			    default:
 				    // Ignore
@@ -370,18 +379,18 @@ void PeerConnection::closeTransports() {
 		if (t)
 			t->onStateChange(nullptr);
 
-	// Initiate transport stop on the processor after closing the data channels
-	mProcessor.enqueue([self = shared_from_this(), transports = std::move(transports)]() {
-		TearDownProcessor::Instance().enqueue(
-		    [transports = std::move(transports), token = Init::Instance().token()]() mutable {
-			    for (const auto &t : transports)
-				    if (t)
-					    t->stop();
+	TearDownProcessor::Instance().enqueue(
+	    [transports = std::move(transports), token = Init::Instance().token()]() mutable {
+		    for (const auto &t : transports) {
+			    if (t) {
+				    t->stop();
+				    break;
+			    }
+		    }
 
-			    for (auto &t : transports)
-				    t.reset();
-		    });
-	});
+		    for (auto &t : transports)
+			    t.reset();
+	    });
 }
 
 void PeerConnection::endLocalCandidates() {

+ 2 - 0
src/impl/peerconnection.hpp

@@ -46,6 +46,7 @@ struct PeerConnection : std::enable_shared_from_this<PeerConnection> {
 	~PeerConnection();
 
 	void close();
+	void remoteClose();
 
 	optional<Description> localDescription() const;
 	optional<Description> remoteDescription() const;
@@ -113,6 +114,7 @@ struct PeerConnection : std::enable_shared_from_this<PeerConnection> {
 	std::atomic<GatheringState> gatheringState = GatheringState::New;
 	std::atomic<SignalingState> signalingState = SignalingState::Stable;
 	std::atomic<bool> negotiationNeeded = false;
+	std::atomic<bool> closing = false;
 	std::mutex signalingMutex;
 
 	synchronized_callback<shared_ptr<rtc::DataChannel>> dataChannelCallback;

+ 28 - 37
src/impl/sctptransport.cpp

@@ -327,8 +327,20 @@ SctpTransport::SctpTransport(shared_ptr<Transport> lower, const Configuration &c
 }
 
 SctpTransport::~SctpTransport() {
-	stop();
-	close();
+	PLOG_DEBUG << "Destroying SCTP transport";
+
+	// Before unregistering incoming() from the lower layer, we need to make sure the thread from
+	// lower layers is not blocked in incoming() by the WrittenOnce condition.
+	mWrittenOnce = true;
+	mWrittenCondition.notify_all();
+
+	unregisterIncoming();
+
+	mProcessor.join();
+	usrsctp_close(mSock);
+
+	usrsctp_deregister_address(this);
+	Instances->erase(this);
 }
 
 void SctpTransport::onBufferedAmount(amount_callback callback) {
@@ -336,26 +348,11 @@ void SctpTransport::onBufferedAmount(amount_callback callback) {
 }
 
 void SctpTransport::start() {
-	Transport::start();
-
 	registerIncoming();
 	connect();
 }
 
-bool SctpTransport::stop() {
-	// Transport::stop() will unregister incoming() from the lower layer, therefore we need to make
-	// sure the thread from lower layers is not blocked in incoming() by the WrittenOnce condition.
-	mWrittenOnce = true;
-	mWrittenCondition.notify_all();
-
-	if (!Transport::stop())
-		return false;
-
-	mSendQueue.stop();
-	flush();
-	shutdown();
-	return true;
-}
+void SctpTransport::stop() { close(); }
 
 struct sockaddr_conn SctpTransport::getSockAddrConn(uint16_t port) {
 	struct sockaddr_conn sconn = {};
@@ -398,24 +395,6 @@ void SctpTransport::shutdown() {
 	if (usrsctp_shutdown(mSock, SHUT_RDWR) != 0 && errno != ENOTCONN) {
 		PLOG_WARNING << "SCTP shutdown failed, errno=" << errno;
 	}
-
-	close();
-
-	PLOG_INFO << "SCTP disconnected";
-	changeState(State::Disconnected);
-	mWrittenCondition.notify_all();
-}
-
-void SctpTransport::close() {
-	if (!mSock)
-		return;
-
-	mProcessor.join();
-	usrsctp_close(mSock);
-	mSock = nullptr;
-
-	usrsctp_deregister_address(this);
-	Instances->erase(this);
 }
 
 bool SctpTransport::send(message_ptr message) {
@@ -459,6 +438,11 @@ void SctpTransport::closeStream(unsigned int stream) {
 	mProcessor.enqueue(&SctpTransport::flush, shared_from_this());
 }
 
+void SctpTransport::close() {
+	mSendQueue.stop();
+	mProcessor.enqueue(&SctpTransport::flush, shared_from_this());
+}
+
 unsigned int SctpTransport::maxStream() const {
 	unsigned int streamsCount = mNegotiatedStreamsCount.value_or(MAX_SCTP_STREAMS_COUNT);
 	return streamsCount > 0 ? streamsCount - 1 : 0;
@@ -511,6 +495,8 @@ void SctpTransport::doRecv() {
 					break;
 				else
 					throw std::runtime_error("SCTP recv failed, errno=" + std::to_string(errno));
+			} else if (len == 0) {
+				break;
 			}
 
 			PLOG_VERBOSE << "SCTP recv, len=" << len;
@@ -566,6 +552,12 @@ bool SctpTransport::trySendQueue() {
 		mSendQueue.pop();
 		updateBufferedAmount(to_uint16(message->stream), -ptrdiff_t(message_size_func(message)));
 	}
+
+	if (!mSendQueue.running()) {
+		shutdown();
+		return false;
+	}
+
 	return true;
 }
 
@@ -918,7 +910,6 @@ optional<milliseconds> SctpTransport::rtt() {
 	socklen_t len = sizeof(status);
 	if (usrsctp_getsockopt(mSock, IPPROTO_SCTP, SCTP_STATUS, &status, &len)) {
 		COUNTER_BAD_SCTP_STATUS++;
-
 		return nullopt;
 	}
 	return milliseconds(status.sstat_primary.spinfo_srtt);

+ 2 - 2
src/impl/sctptransport.hpp

@@ -56,10 +56,11 @@ public:
 	void onBufferedAmount(amount_callback callback);
 
 	void start() override;
-	bool stop() override;
+	void stop() override;
 	bool send(message_ptr message) override; // false if buffered
 	bool flush();
 	void closeStream(unsigned int stream);
+	void close();
 
 	unsigned int maxStream() const;
 
@@ -86,7 +87,6 @@ private:
 
 	void connect();
 	void shutdown();
-	void close();
 	void incoming(message_ptr message) override;
 	bool outgoing(message_ptr message) override;
 

+ 0 - 11
src/impl/tcptransport.cpp

@@ -69,7 +69,6 @@ TcpTransport::TcpTransport(socket_t sock, state_callback callback)
 }
 
 TcpTransport::~TcpTransport() {
-	stop();
 	close();
 }
 
@@ -82,8 +81,6 @@ void TcpTransport::setReadTimeout(std::chrono::milliseconds readTimeout) {
 }
 
 void TcpTransport::start() {
-	Transport::start();
-
 	if (mSock == INVALID_SOCKET) {
 		connect();
 	} else {
@@ -92,14 +89,6 @@ void TcpTransport::start() {
 	}
 }
 
-bool TcpTransport::stop() {
-	if (!Transport::stop())
-		return false;
-
-	close();
-	return true;
-}
-
 bool TcpTransport::send(message_ptr message) {
 	std::lock_guard lock(mSendMutex);
 

+ 0 - 1
src/impl/tcptransport.hpp

@@ -46,7 +46,6 @@ public:
 	void setReadTimeout(std::chrono::milliseconds readTimeout);
 
 	void start() override;
-	bool stop() override;
 	bool send(message_ptr message) override;
 
 	void incoming(message_ptr message) override;

+ 14 - 16
src/impl/tlstransport.cpp

@@ -95,27 +95,26 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host
 
 TlsTransport::~TlsTransport() {
 	stop();
-
 	gnutls_deinit(mSession);
 }
 
 void TlsTransport::start() {
-	Transport::start();
-
-	registerIncoming();
+	if (mStarted.exchange(true))
+		return;
 
 	PLOG_DEBUG << "Starting TLS recv thread";
+	registerIncoming();
 	mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
 }
 
-bool TlsTransport::stop() {
-	if (!Transport::stop())
-		return false;
+void TlsTransport::stop() {
+	if (!mStarted.exchange(false))
+		return;
 
 	PLOG_DEBUG << "Stopping TLS recv thread";
+	unregisterIncoming();
 	mIncomingQueue.stop();
 	mRecvThread.join();
-	return true;
 }
 
 bool TlsTransport::send(message_ptr message) {
@@ -375,29 +374,28 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host
 
 TlsTransport::~TlsTransport() {
 	stop();
-
 	SSL_free(mSsl);
 	SSL_CTX_free(mCtx);
 }
 
 void TlsTransport::start() {
-	Transport::start();
-
-	registerIncoming();
+	if (mStarted.exchange(true))
+		return;
 
 	PLOG_DEBUG << "Starting TLS recv thread";
+	registerIncoming();
 	mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
 }
 
-bool TlsTransport::stop() {
-	if (!Transport::stop())
-		return false;
+void TlsTransport::stop() {
+	if (!mStarted.exchange(false))
+		return;
 
 	PLOG_DEBUG << "Stopping TLS recv thread";
+	unregisterIncoming();
 	mIncomingQueue.stop();
 	mRecvThread.join();
 	SSL_shutdown(mSsl);
-	return true;
 }
 
 bool TlsTransport::send(message_ptr message) {

+ 3 - 1
src/impl/tlstransport.hpp

@@ -44,7 +44,7 @@ public:
 	virtual ~TlsTransport();
 
 	void start() override;
-	bool stop() override;
+	void stop() override;
 	bool send(message_ptr message) override;
 
 	bool isClient() const { return mIsClient; }
@@ -53,6 +53,7 @@ protected:
 	virtual void incoming(message_ptr message) override;
 	virtual bool outgoing(message_ptr message) override;
 	virtual void postHandshake();
+
 	void runRecvLoop();
 
 	const optional<string> mHost;
@@ -60,6 +61,7 @@ protected:
 
 	Queue<message_ptr> mIncomingQueue;
 	std::thread mRecvThread;
+	std::atomic<bool> mStarted = false;
 
 #if USE_GNUTLS
 	gnutls_session_t mSession;

+ 89 - 0
src/impl/transport.cpp

@@ -0,0 +1,89 @@
+/**
+ * Copyright (c) 2019-2022 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "transport.hpp"
+
+namespace rtc::impl {
+
+Transport::Transport(shared_ptr<Transport> lower, state_callback callback)
+    : mLower(std::move(lower)), mStateChangeCallback(std::move(callback)) {}
+
+Transport::~Transport() {
+	unregisterIncoming();
+
+	if (mLower) {
+		mLower->stop();
+		mLower.reset();
+	}
+}
+
+void Transport::registerIncoming() {
+	if (mLower) {
+		PLOG_VERBOSE << "Registering incoming callback";
+		mLower->onRecv(std::bind(&Transport::incoming, this, std::placeholders::_1));
+	}
+}
+
+void Transport::unregisterIncoming() {
+	if (mLower) {
+		PLOG_VERBOSE << "Unregistering incoming callback";
+		mLower->onRecv(nullptr);
+	}
+}
+
+Transport::State Transport::state() const { return mState; }
+
+void Transport::onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
+
+void Transport::onStateChange(state_callback callback) {
+	mStateChangeCallback = std::move(callback);
+}
+
+void Transport::start() { registerIncoming(); }
+
+void Transport::stop() { unregisterIncoming(); }
+
+bool Transport::send(message_ptr message) { return outgoing(message); }
+
+void Transport::recv(message_ptr message) {
+	try {
+		mRecvCallback(message);
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+	}
+}
+
+void Transport::changeState(State state) {
+	try {
+		if (mState.exchange(state) != state)
+			mStateChangeCallback(state);
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+	}
+}
+
+void Transport::incoming(message_ptr message) { recv(message); }
+
+bool Transport::outgoing(message_ptr message) {
+	if (mLower)
+		return mLower->send(message);
+	else
+		return false;
+}
+
+} // namespace rtc::impl

+ 19 - 54
src/impl/transport.hpp

@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2019 Paul-Louis Ageneau
+ * Copyright (c) 2019-2022 Paul-Louis Ageneau
  *
  * This library is free software; you can redistribute it and/or
  * modify it under the terms of the GNU Lesser General Public
@@ -20,6 +20,7 @@
 #define RTC_IMPL_TRANSPORT_H
 
 #include "common.hpp"
+#include "init.hpp"
 #include "internals.hpp"
 #include "message.hpp"
 
@@ -34,70 +35,34 @@ public:
 	enum class State { Disconnected, Connecting, Connected, Completed, Failed };
 	using state_callback = std::function<void(State state)>;
 
-	Transport(shared_ptr<Transport> lower = nullptr, state_callback callback = nullptr)
-	    : mLower(std::move(lower)), mStateChangeCallback(std::move(callback)) {}
+	Transport(shared_ptr<Transport> lower = nullptr, state_callback callback = nullptr);
+	virtual ~Transport();
 
-	virtual ~Transport() { stop(); }
+	void registerIncoming();
+	void unregisterIncoming();
+	State state() const;
 
-	virtual void start() { mStopped = false; }
+	void onRecv(message_callback callback);
+	void onStateChange(state_callback callback);
 
-	virtual bool stop() {
-		if (mStopped.exchange(true))
-			return false;
-
-		// We don't want incoming() to be called by the lower layer anymore
-		if (mLower) {
-			PLOG_VERBOSE << "Unregistering incoming callback";
-			mLower->onRecv(nullptr);
-		}
-		return true;
-	}
-
-	void registerIncoming() {
-		if (mLower) {
-			PLOG_VERBOSE << "Registering incoming callback";
-			mLower->onRecv(std::bind(&Transport::incoming, this, std::placeholders::_1));
-		}
-	}
-
-	void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
-	void onStateChange(state_callback callback) { mStateChangeCallback = std::move(callback); }
-	State state() const { return mState; }
-
-	virtual bool send(message_ptr message) { return outgoing(message); }
+	virtual void start();
+	virtual void stop();
+	virtual bool send(message_ptr message);
 
 protected:
-	void recv(message_ptr message) {
-		try {
-			mRecvCallback(message);
-		} catch (const std::exception &e) {
-			PLOG_WARNING << e.what();
-		}
-	}
-	void changeState(State state) {
-		try {
-			if (mState.exchange(state) != state)
-				mStateChangeCallback(state);
-		} catch (const std::exception &e) {
-			PLOG_WARNING << e.what();
-		}
-	}
-
-	virtual void incoming(message_ptr message) { recv(message); }
-	virtual bool outgoing(message_ptr message) {
-		if (mLower)
-			return mLower->send(message);
-		else
-			return false;
-	}
+	void recv(message_ptr message);
+	void changeState(State state);
+	virtual void incoming(message_ptr message);
+	virtual bool outgoing(message_ptr message);
 
 private:
-	const shared_ptr<Transport> mLower;
+	const init_token mInitToken = Init::Instance().token();
+
+	shared_ptr<Transport> mLower;
 	synchronized_callback<State> mStateChangeCallback;
 	synchronized_callback<message_ptr> mRecvCallback;
 
 	std::atomic<State> mState = State::Disconnected;
-	std::atomic<bool> mStopped = true;
 };
 
 } // namespace rtc::impl

+ 1 - 1
src/impl/verifiedtlstransport.cpp

@@ -37,7 +37,7 @@ VerifiedTlsTransport::VerifiedTlsTransport(shared_ptr<TcpTransport> lower, strin
 #endif
 }
 
-VerifiedTlsTransport::~VerifiedTlsTransport() {}
+VerifiedTlsTransport::~VerifiedTlsTransport() { stop(); }
 
 } // namespace rtc::impl
 

+ 8 - 6
src/impl/websocket.cpp

@@ -124,17 +124,16 @@ void WebSocket::close() {
 		PLOG_VERBOSE << "Closing WebSocket";
 		changeState(State::Closing);
 		if (auto transport = std::atomic_load(&mWsTransport))
-			transport->close();
+			transport->stop();
 		else
 			remoteClose();
 	}
 }
 
 void WebSocket::remoteClose() {
-	if (state != State::Closed) {
-		close();
+	close();
+	if (state.load() != State::Closed)
 		closeTransports();
-	}
 }
 
 bool WebSocket::isOpen() const { return state == State::Open; }
@@ -424,9 +423,12 @@ void WebSocket::closeTransports() {
 
 	TearDownProcessor::Instance().enqueue(
 	    [transports = std::move(transports), token = Init::Instance().token()]() mutable {
-		    for (const auto &t : transports)
-			    if (t)
+		    for (const auto &t : transports) {
+			    if (t) {
 				    t->stop();
+				    break;
+			    }
+		    }
 
 		    for (auto &t : transports)
 			    t.reset();

+ 1 - 1
src/impl/websocket.hpp

@@ -46,6 +46,7 @@ struct WebSocket final : public Channel, public std::enable_shared_from_this<Web
 
 	void open(const string &url);
 	void close();
+	void remoteClose();
 	bool outgoing(message_ptr message);
 	void incoming(message_ptr message);
 
@@ -58,7 +59,6 @@ struct WebSocket final : public Channel, public std::enable_shared_from_this<Web
 	size_t maxMessageSize() const;
 
 	bool changeState(State state);
-	void remoteClose();
 
 	shared_ptr<TcpTransport> setTcpTransport(shared_ptr<TcpTransport> transport);
 	shared_ptr<TlsTransport> initTlsTransport();

+ 26 - 23
src/impl/wstransport.cpp

@@ -18,6 +18,7 @@
 
 #include "wstransport.hpp"
 #include "tcptransport.hpp"
+#include "threadpool.hpp"
 #include "tlstransport.hpp"
 
 #if RTC_ENABLE_WEBSOCKET
@@ -69,11 +70,9 @@ WsTransport::WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTranspo
 	PLOG_DEBUG << "Initializing WebSocket transport";
 }
 
-WsTransport::~WsTransport() { stop(); }
+WsTransport::~WsTransport() { unregisterIncoming(); }
 
 void WsTransport::start() {
-	Transport::start();
-
 	registerIncoming();
 
 	changeState(State::Connecting);
@@ -81,13 +80,7 @@ void WsTransport::start() {
 		sendHttpRequest();
 }
 
-bool WsTransport::stop() {
-	if (!Transport::stop())
-		return false;
-
-	close();
-	return true;
-}
+void WsTransport::stop() { close(); }
 
 bool WsTransport::send(message_ptr message) {
 	if (!message || state() != State::Connected)
@@ -98,6 +91,29 @@ bool WsTransport::send(message_ptr message) {
 	                  message->size(), true, mIsClient});
 }
 
+void WsTransport::close() {
+	if (state() != State::Connected)
+		return;
+
+	PLOG_INFO << "WebSocket closing";
+	try {
+		sendFrame({CLOSE, NULL, 0, true, mIsClient});
+	} catch (const std::exception &e) {
+		// The connection might not be open anymore
+		PLOG_DEBUG << "Unable to send WebSocket close frame: " << e.what();
+		changeState(State::Disconnected);
+		return;
+	}
+
+	ThreadPool::Instance().schedule(std::chrono::milliseconds(10),
+	                                [this, weak_this = weak_from_this()]() {
+		                                if (auto shared_this = weak_this.lock()) {
+			                                PLOG_DEBUG << "WebSocket close timeout";
+			                                changeState(State::Disconnected);
+		                                }
+	                                });
+}
+
 void WsTransport::incoming(message_ptr message) {
 	auto s = state();
 	if (s != State::Connecting && s != State::Connected)
@@ -172,19 +188,6 @@ void WsTransport::incoming(message_ptr message) {
 	}
 }
 
-void WsTransport::close() {
-	if (state() == State::Connected) {
-		PLOG_INFO << "WebSocket closing";
-		try {
-			sendFrame({CLOSE, NULL, 0, true, mIsClient});
-		} catch (const std::exception &e) {
-			// Ignore error as the connection might not be open anymore
-			PLOG_DEBUG << "Unable to send WebSocket close frame: " << e.what();
-		}
-		changeState(State::Disconnected);
-	}
-}
-
 bool WsTransport::sendHttpRequest() {
 	PLOG_DEBUG << "Sending WebSocket HTTP request";
 

+ 3 - 3
src/impl/wstransport.hpp

@@ -30,7 +30,7 @@ namespace rtc::impl {
 class TcpTransport;
 class TlsTransport;
 
-class WsTransport final : public Transport {
+class WsTransport final : public Transport, public std::enable_shared_from_this<WsTransport> {
 public:
 	WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower,
 	            shared_ptr<WsHandshake> handshake, int maxOutstandingPings,
@@ -38,10 +38,10 @@ public:
 	~WsTransport();
 
 	void start() override;
-	bool stop() override;
+	void stop() override;
 	bool send(message_ptr message) override;
-	void incoming(message_ptr message) override;
 	void close();
+	void incoming(message_ptr message) override;
 
 	bool isClient() const { return mIsClient; }
 

+ 1 - 1
src/peerconnection.cpp

@@ -49,7 +49,7 @@ PeerConnection::PeerConnection(Configuration config)
 
 PeerConnection::~PeerConnection() {
 	try {
-		close();
+		impl()->remoteClose();
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 	}