Browse Source

Fixed race condition with transport start() and stop()

Paul-Louis Ageneau 3 years ago
parent
commit
262a22928c
2 changed files with 60 additions and 68 deletions
  1. 31 35
      src/impl/peerconnection.cpp
  2. 29 33
      src/impl/websocket.cpp

+ 31 - 35
src/impl/peerconnection.cpp

@@ -115,6 +115,19 @@ size_t PeerConnection::remoteMaxMessageSize() const {
 	return std::min(remoteMax, localMax);
 	return std::min(remoteMax, localMax);
 }
 }
 
 
+// Helper for PeerConnection::initXTransport methods: start and emplace the transport
+template <typename T>
+shared_ptr<T> emplaceTransport(PeerConnection *pc, shared_ptr<T> *member, shared_ptr<T> transport) {
+	transport->start();
+	std::atomic_store(member, transport);
+	if (pc->state.load() == PeerConnection::State::Closed) {
+		std::atomic_store(member, decltype(transport)(nullptr));
+		transport->stop();
+		throw std::runtime_error("Connection is closed");
+	}
+	return transport;
+}
+
 shared_ptr<IceTransport> PeerConnection::initIceTransport() {
 shared_ptr<IceTransport> PeerConnection::initIceTransport() {
 	try {
 	try {
 		if (auto transport = std::atomic_load(&mIceTransport))
 		if (auto transport = std::atomic_load(&mIceTransport))
@@ -164,13 +177,7 @@ shared_ptr<IceTransport> PeerConnection::initIceTransport() {
 			    }
 			    }
 		    });
 		    });
 
 
-		std::atomic_store(&mIceTransport, transport);
-		if (state.load() == State::Closed) {
-			std::atomic_store(&mIceTransport, decltype(mIceTransport)(nullptr));
-			throw std::runtime_error("Connection is closed");
-		}
-		transport->start();
-		return transport;
+		return emplaceTransport(this, &mIceTransport, std::move(transport));
 
 
 	} catch (const std::exception &e) {
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 		PLOG_ERROR << e.what();
@@ -239,13 +246,7 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 			                                            verifierCallback, dtlsStateChangeCallback);
 			                                            verifierCallback, dtlsStateChangeCallback);
 		}
 		}
 
 
-		std::atomic_store(&mDtlsTransport, transport);
-		if (state.load() == State::Closed) {
-			std::atomic_store(&mDtlsTransport, decltype(mDtlsTransport)(nullptr));
-			throw std::runtime_error("Connection is closed");
-		}
-		transport->start();
-		return transport;
+		return emplaceTransport(this, &mDtlsTransport, std::move(transport));
 
 
 	} catch (const std::exception &e) {
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 		PLOG_ERROR << e.what();
@@ -301,13 +302,7 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 			    }
 			    }
 		    });
 		    });
 
 
-		std::atomic_store(&mSctpTransport, transport);
-		if (state.load() == State::Closed) {
-			std::atomic_store(&mSctpTransport, decltype(mSctpTransport)(nullptr));
-			throw std::runtime_error("Connection is closed");
-		}
-		transport->start();
-		return transport;
+		return emplaceTransport(this, &mSctpTransport, std::move(transport));
 
 
 	} catch (const std::exception &e) {
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 		PLOG_ERROR << e.what();
@@ -344,18 +339,19 @@ void PeerConnection::closeTransports() {
 		auto sctp = std::atomic_exchange(&mSctpTransport, decltype(mSctpTransport)(nullptr));
 		auto sctp = std::atomic_exchange(&mSctpTransport, decltype(mSctpTransport)(nullptr));
 		auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr));
 		auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr));
 		auto ice = std::atomic_exchange(&mIceTransport, decltype(mIceTransport)(nullptr));
 		auto ice = std::atomic_exchange(&mIceTransport, decltype(mIceTransport)(nullptr));
-		ThreadPool::Instance().enqueue([sctp, dtls, ice]() mutable {
-			if (sctp)
-				sctp->stop();
-			if (dtls)
-				dtls->stop();
-			if (ice)
-				ice->stop();
-
-			sctp.reset();
-			dtls.reset();
-			ice.reset();
-		});
+		ThreadPool::Instance().enqueue(
+		    [sctp = std::move(sctp), dtls = std::move(dtls), ice = std::move(ice)]() mutable {
+			    if (sctp)
+				    sctp->stop();
+			    if (dtls)
+				    dtls->stop();
+			    if (ice)
+				    ice->stop();
+
+			    sctp.reset();
+			    dtls.reset();
+			    ice.reset();
+		    });
 	});
 	});
 }
 }
 
 
@@ -1037,11 +1033,11 @@ void PeerConnection::triggerPendingTracks() {
 }
 }
 
 
 void PeerConnection::flushPendingDataChannels() {
 void PeerConnection::flushPendingDataChannels() {
-	mProcessor->enqueue(std::bind(&PeerConnection::triggerPendingDataChannels, this));
+	mProcessor->enqueue(&PeerConnection::triggerPendingDataChannels, this);
 }
 }
 
 
 void PeerConnection::flushPendingTracks() {
 void PeerConnection::flushPendingTracks() {
-	mProcessor->enqueue(std::bind(&PeerConnection::triggerPendingTracks, this));
+	mProcessor->enqueue(&PeerConnection::triggerPendingTracks, this);
 }
 }
 
 
 bool PeerConnection::changeState(State newState) {
 bool PeerConnection::changeState(State newState) {

+ 29 - 33
src/impl/websocket.cpp

@@ -175,6 +175,19 @@ void WebSocket::incoming(message_ptr message) {
 	}
 	}
 }
 }
 
 
+// Helper for WebSocket::initXTransport methods: start and emplace the transport
+template <typename T>
+shared_ptr<T> emplaceTransport(WebSocket *ws, shared_ptr<T> *member, shared_ptr<T> transport) {
+	transport->start();
+	std::atomic_store(member, transport);
+	if (ws->state.load() == WebSocket::State::Closed) {
+		std::atomic_store(member, decltype(transport)(nullptr));
+		transport->stop();
+		throw std::runtime_error("Connection is closed");
+	}
+	return transport;
+}
+
 shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> transport) {
 shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> transport) {
 	PLOG_VERBOSE << "Starting TCP transport";
 	PLOG_VERBOSE << "Starting TCP transport";
 
 
@@ -210,13 +223,7 @@ shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> tra
 			}
 			}
 		});
 		});
 
 
-		std::atomic_store(&mTcpTransport, transport);
-		if (state == WebSocket::State::Closed) {
-			std::atomic_store(&mTcpTransport, decltype(mTcpTransport)(nullptr));
-			throw std::runtime_error("Connection is closed");
-		}
-		transport->start();
-		return transport;
+		return emplaceTransport(this, &mTcpTransport, std::move(transport));
 
 
 	} catch (const std::exception &e) {
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 		PLOG_ERROR << e.what();
@@ -273,13 +280,7 @@ shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 			transport =
 			transport =
 			    std::make_shared<TlsTransport>(lower, mHostname, mCertificate, stateChangeCallback);
 			    std::make_shared<TlsTransport>(lower, mHostname, mCertificate, stateChangeCallback);
 
 
-		std::atomic_store(&mTlsTransport, transport);
-		if (state == WebSocket::State::Closed) {
-			std::atomic_store(&mTlsTransport, decltype(mTlsTransport)(nullptr));
-			throw std::runtime_error("Connection is closed");
-		}
-		transport->start();
-		return transport;
+		return emplaceTransport(this, &mTlsTransport, std::move(transport));
 
 
 	} catch (const std::exception &e) {
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 		PLOG_ERROR << e.what();
@@ -341,13 +342,7 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 		auto transport = std::make_shared<WsTransport>(
 		auto transport = std::make_shared<WsTransport>(
 		    lower, mWsHandshake, weak_bind(&WebSocket::incoming, this, _1), stateChangeCallback);
 		    lower, mWsHandshake, weak_bind(&WebSocket::incoming, this, _1), stateChangeCallback);
 
 
-		std::atomic_store(&mWsTransport, transport);
-		if (state == WebSocket::State::Closed) {
-			std::atomic_store(&mWsTransport, decltype(mWsTransport)(nullptr));
-			throw std::runtime_error("Connection is closed");
-		}
-		transport->start();
-		return transport;
+		return emplaceTransport(this, &mWsTransport, std::move(transport));
 
 
 	} catch (const std::exception &e) {
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 		PLOG_ERROR << e.what();
@@ -387,18 +382,19 @@ void WebSocket::closeTransports() {
 	auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
 	auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
 	auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr));
 	auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr));
 	auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
 	auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
-	ThreadPool::Instance().enqueue([ws, tls, tcp]() mutable {
-		if (ws)
-			ws->stop();
-		if (tls)
-			tls->stop();
-		if (tcp)
-			tcp->stop();
-
-		ws.reset();
-		tls.reset();
-		tcp.reset();
-	});
+	ThreadPool::Instance().enqueue(
+	    [ws = std::move(ws), tls = std::move(tls), tcp = std::move(tcp)]() mutable {
+		    if (ws)
+			    ws->stop();
+		    if (tls)
+			    tls->stop();
+		    if (tcp)
+			    tcp->stop();
+
+		    ws.reset();
+		    tls.reset();
+		    tcp.reset();
+	    });
 }
 }
 
 
 } // namespace rtc::impl
 } // namespace rtc::impl