Browse Source

Merge pull request #1482 from paullouisageneau/fix-ws-onerror-reset-crash

Fix crash if WebSocket is deleted in onError callback
Paul-Louis Ageneau 2 weeks ago
parent
commit
8d51a090f4

+ 4 - 4
.github/workflows/build-mbedtls.yml

@@ -12,11 +12,11 @@ jobs:
     - name: Set up Homebrew
       uses: Homebrew/actions/setup-homebrew@master
     - name: Install Mbed TLS
-      run: brew update && brew install mbedtls
+      run: brew update && brew install mbedtls@3
     - name: submodules
       run: git submodule update --init --recursive --depth 1
     - name: cmake
-      run: cmake -B build -DUSE_MBEDTLS=1 -DWARNINGS_AS_ERRORS=1  -DCMAKE_PREFIX_PATH=$(brew --prefix mbedtls)
+      run: cmake -B build -DUSE_MBEDTLS=1 -DWARNINGS_AS_ERRORS=1  -DCMAKE_PREFIX_PATH=$(brew --prefix mbedtls@3)
     - name: make
       run: (cd build; make -j2)
     - name: test
@@ -26,11 +26,11 @@ jobs:
     steps:
     - uses: actions/checkout@v4
     - name: Install Mbed TLS
-      run: brew update && brew install mbedtls
+      run: brew update && brew install mbedtls@3
     - name: submodules
       run: git submodule update --init --recursive --depth 1
     - name: cmake
-      run: cmake -B build -DUSE_MBEDTLS=1 -DWARNINGS_AS_ERRORS=1 -DENABLE_LOCAL_ADDRESS_TRANSLATION=1  -DCMAKE_PREFIX_PATH=$(brew --prefix mbedtls)
+      run: cmake -B build -DUSE_MBEDTLS=1 -DWARNINGS_AS_ERRORS=1 -DENABLE_LOCAL_ADDRESS_TRANSLATION=1 -DCMAKE_PREFIX_PATH=$(brew --prefix mbedtls@3)
     - name: make
       run: (cd build; make -j2)
     - name: test

+ 1 - 0
include/rtc/pacinghandler.hpp

@@ -40,6 +40,7 @@ private:
 	std::queue<message_ptr> mRtpBuffer;
 
 	void schedule(const message_callback &send);
+	void run(const message_callback &send);
 };
 
 } // namespace rtc

+ 93 - 94
src/impl/peerconnection.cpp

@@ -162,52 +162,52 @@ shared_ptr<IceTransport> PeerConnection::initIceTransport() {
 		auto transport = std::make_shared<IceTransport>(
 		    config, weak_bind(&PeerConnection::processLocalCandidate, this, _1),
 		    [this, weak_this = weak_from_this()](IceTransport::State transportState) {
-			    auto shared_this = weak_this.lock();
-			    if (!shared_this)
-				    return;
-			    switch (transportState) {
-			    case IceTransport::State::Connecting:
-				    changeIceState(IceState::Checking);
-				    changeState(State::Connecting);
-				    break;
-			    case IceTransport::State::Connected:
-				    changeIceState(IceState::Connected);
-				    initDtlsTransport();
-				    break;
-			    case IceTransport::State::Completed:
-				    changeIceState(IceState::Completed);
-				    break;
-			    case IceTransport::State::Failed:
-				    changeIceState(IceState::Failed);
-				    changeState(State::Failed);
-				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
-				    break;
-			    case IceTransport::State::Disconnected:
-				    changeIceState(IceState::Disconnected);
-				    changeState(State::Disconnected);
-				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
-				    break;
-			    default:
-				    // Ignore
-				    break;
-			    }
+			    if (auto locked = weak_this.lock())
+				    std::invoke([=]() {
+					    switch (transportState) {
+					    case IceTransport::State::Connecting:
+						    changeIceState(IceState::Checking);
+						    changeState(State::Connecting);
+						    break;
+					    case IceTransport::State::Connected:
+						    changeIceState(IceState::Connected);
+						    initDtlsTransport();
+						    break;
+					    case IceTransport::State::Completed:
+						    changeIceState(IceState::Completed);
+						    break;
+					    case IceTransport::State::Failed:
+						    changeIceState(IceState::Failed);
+						    changeState(State::Failed);
+						    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
+						    break;
+					    case IceTransport::State::Disconnected:
+						    changeIceState(IceState::Disconnected);
+						    changeState(State::Disconnected);
+						    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
+						    break;
+					    default:
+						    // Ignore
+						    break;
+					    }
+				    });
 		    },
 		    [this, weak_this = weak_from_this()](IceTransport::GatheringState gatheringState) {
-			    auto shared_this = weak_this.lock();
-			    if (!shared_this)
-				    return;
-			    switch (gatheringState) {
-			    case IceTransport::GatheringState::InProgress:
-				    changeGatheringState(GatheringState::InProgress);
-				    break;
-			    case IceTransport::GatheringState::Complete:
-				    endLocalCandidates();
-				    changeGatheringState(GatheringState::Complete);
-				    break;
-			    default:
-				    // Ignore
-				    break;
-			    }
+			    if (auto locked = weak_this.lock())
+				    std::invoke([=]() {
+					    switch (gatheringState) {
+					    case IceTransport::GatheringState::InProgress:
+						    changeGatheringState(GatheringState::InProgress);
+						    break;
+					    case IceTransport::GatheringState::Complete:
+						    endLocalCandidates();
+						    changeGatheringState(GatheringState::Complete);
+						    break;
+					    default:
+						    // Ignore
+						    break;
+					    }
+				    });
 		    });
 
 		return emplaceTransport(this, &mIceTransport, std::move(transport));
@@ -241,34 +241,33 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 
 		auto certificate = mCertificate.get();
 		auto verifierCallback = weak_bind(&PeerConnection::checkFingerprint, this, _1);
-		auto dtlsStateChangeCallback =
-		    [this, weak_this = weak_from_this()](DtlsTransport::State transportState) {
-			    auto shared_this = weak_this.lock();
-			    if (!shared_this)
-				    return;
-
-			    switch (transportState) {
-			    case DtlsTransport::State::Connected:
-				    if (auto remote = remoteDescription(); remote && remote->hasApplication())
-					    initSctpTransport();
-				    else
-					    changeState(State::Connected);
-
-				    mProcessor.enqueue(&PeerConnection::openTracks, shared_from_this());
-				    break;
-			    case DtlsTransport::State::Failed:
-				    changeState(State::Failed);
-				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
-				    break;
-			    case DtlsTransport::State::Disconnected:
-				    changeState(State::Disconnected);
-				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
-				    break;
-			    default:
-				    // Ignore
-				    break;
-			    }
-		    };
+		auto dtlsStateChangeCallback = [this, weak_this = weak_from_this()](
+		                                   DtlsTransport::State transportState) {
+			if (auto locked = weak_this.lock())
+				std::invoke([=]() {
+					switch (transportState) {
+					case DtlsTransport::State::Connected:
+						if (auto remote = remoteDescription(); remote && remote->hasApplication())
+							initSctpTransport();
+						else
+							changeState(State::Connected);
+
+						mProcessor.enqueue(&PeerConnection::openTracks, shared_from_this());
+						break;
+					case DtlsTransport::State::Failed:
+						changeState(State::Failed);
+						mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
+						break;
+					case DtlsTransport::State::Disconnected:
+						changeState(State::Disconnected);
+						mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
+						break;
+					default:
+						// Ignore
+						break;
+					}
+				});
+		};
 
 		shared_ptr<DtlsTransport> transport;
 		auto local = localDescription();
@@ -329,28 +328,28 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 		    lower, config, std::move(ports), weak_bind(&PeerConnection::forwardMessage, this, _1),
 		    weak_bind(&PeerConnection::forwardBufferedAmount, this, _1, _2),
 		    [this, weak_this = weak_from_this()](SctpTransport::State transportState) {
-			    auto shared_this = weak_this.lock();
-			    if (!shared_this)
-				    return;
-
-			    switch (transportState) {
-			    case SctpTransport::State::Connected:
-				    changeState(State::Connected);
-				    assignDataChannels();
-				    mProcessor.enqueue(&PeerConnection::openDataChannels, shared_from_this());
-				    break;
-			    case SctpTransport::State::Failed:
-				    changeState(State::Failed);
-				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
-				    break;
-			    case SctpTransport::State::Disconnected:
-				    changeState(State::Disconnected);
-				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
-				    break;
-			    default:
-				    // Ignore
-				    break;
-			    }
+			    if (auto locked = weak_this.lock())
+				    std::invoke([=]() {
+					    switch (transportState) {
+					    case SctpTransport::State::Connected:
+						    changeState(State::Connected);
+						    assignDataChannels();
+						    mProcessor.enqueue(&PeerConnection::openDataChannels,
+						                       shared_from_this());
+						    break;
+					    case SctpTransport::State::Failed:
+						    changeState(State::Failed);
+						    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
+						    break;
+					    case SctpTransport::State::Disconnected:
+						    changeState(State::Disconnected);
+						    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
+						    break;
+					    default:
+						    // Ignore
+						    break;
+					    }
+				    });
 		    });
 
 		return emplaceTransport(this, &mSctpTransport, std::move(transport));

+ 5 - 5
src/impl/track.cpp

@@ -73,7 +73,7 @@ void Track::close() {
 		triggerClosed();
 		setMediaHandler(nullptr);
 		resetCallbacks();
-	}		
+	}
 }
 
 message_variant Track::trackMessageToVariant(message_ptr message) {
@@ -144,9 +144,9 @@ void Track::incoming(message_ptr message) {
 	message_vector messages{std::move(message)};
 	if (auto handler = getMediaHandler()) {
 		try {
-			handler->incomingChain(messages, [this, weak_this = weak_from_this()](message_ptr m) {
+			handler->incomingChain(messages, [weak_this = weak_from_this()](message_ptr m) {
 				if (auto locked = weak_this.lock()) {
-					transportSend(m);
+					locked->transportSend(m);
 				}
 			});
 		} catch (const std::exception &e) {
@@ -186,9 +186,9 @@ bool Track::outgoing(message_ptr message) {
 
 	if (handler) {
 		message_vector messages{std::move(message)};
-		handler->outgoingChain(messages, [this, weak_this = weak_from_this()](message_ptr m) {
+		handler->outgoingChain(messages, [weak_this = weak_from_this()](message_ptr m) {
 			if (auto locked = weak_this.lock()) {
-				transportSend(m);
+				locked->transportSend(m);
 			}
 		});
 

+ 87 - 87
src/impl/websocket.cpp

@@ -235,30 +235,30 @@ shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> tra
 		transport->onBufferedAmount(weak_bind(&WebSocket::triggerBufferedAmount, this, _1));
 
 		transport->onStateChange([this, weak_this = weak_from_this()](State transportState) {
-			auto shared_this = weak_this.lock();
-			if (!shared_this)
-				return;
-			switch (transportState) {
-			case State::Connected:
-				if (config.proxyServer)
-					initProxyTransport();
-				else if (mIsSecure)
-					initTlsTransport();
-				else
-					initWsTransport();
-				break;
-			case State::Failed:
-				triggerError("TCP connection failed");
-				remoteClose();
-				break;
-			case State::Disconnected:
-				if(state == WebSocket::State::Connecting)
-					remoteClose();
-				break;
-			default:
-				// Ignore
-				break;
-			}
+			if(auto locked = weak_this.lock())
+				std::invoke([=]() {
+					switch (transportState) {
+					case State::Connected:
+						if (config.proxyServer)
+							initProxyTransport();
+						else if (mIsSecure)
+							initTlsTransport();
+						else
+							initWsTransport();
+						break;
+					case State::Failed:
+						triggerError("TCP connection failed");
+						remoteClose();
+						break;
+					case State::Disconnected:
+						if(state == WebSocket::State::Connecting)
+							remoteClose();
+						break;
+					default:
+						// Ignore
+						break;
+					}
+				});
 		});
 
 		// WS transport sends a ping on read timeout
@@ -289,28 +289,28 @@ shared_ptr<HttpProxyTransport> WebSocket::initProxyTransport() {
 			throw std::logic_error("No underlying TCP transport for Proxy transport");
 
 		auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
-			auto shared_this = weak_this.lock();
-			if (!shared_this)
-				return;
-			switch (transportState) {
-			case State::Connected:
-				if (mIsSecure)
-					initTlsTransport();
-				else
-					initWsTransport();
-				break;
-			case State::Failed:
-				triggerError("Proxy connection failed");
-				remoteClose();
-				break;
-			case State::Disconnected:
-				if(state == WebSocket::State::Connecting)
-					remoteClose();
-				break;
-			default:
-				// Ignore
-				break;
-			}
+			if(auto locked = weak_this.lock())
+				std::invoke([=]() {
+					switch (transportState) {
+					case State::Connected:
+						if (mIsSecure)
+							initTlsTransport();
+						else
+							initWsTransport();
+						break;
+					case State::Failed:
+						triggerError("Proxy connection failed");
+						remoteClose();
+						break;
+					case State::Disconnected:
+						if(state == WebSocket::State::Connecting)
+							remoteClose();
+						break;
+					default:
+						// Ignore
+						break;
+					}
+				});
 		};
 
 		auto transport = std::make_shared<HttpProxyTransport>(
@@ -348,25 +348,25 @@ shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 		}
 
 		auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
-			auto shared_this = weak_this.lock();
-			if (!shared_this)
-				return;
-			switch (transportState) {
-			case State::Connected:
-				initWsTransport();
-				break;
-			case State::Failed:
-				triggerError("TLS connection failed");
-				remoteClose();
-				break;
-			case State::Disconnected:
-				if(state == WebSocket::State::Connecting)
-					remoteClose();
-				break;
-			default:
-				// Ignore
-				break;
-			}
+			if(auto locked = weak_this.lock())
+				std::invoke([=]() {
+					switch (transportState) {
+					case State::Connected:
+						initWsTransport();
+						break;
+					case State::Failed:
+						triggerError("TLS connection failed");
+						remoteClose();
+						break;
+					case State::Disconnected:
+						if(state == WebSocket::State::Connecting)
+							remoteClose();
+						break;
+					default:
+						// Ignore
+						break;
+					}
+				});
 		};
 
 		bool verify = mHostname.has_value() && !config.disableTlsVerification;
@@ -428,28 +428,28 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 			atomic_store(&mWsHandshake, std::make_shared<WsHandshake>());
 
 		auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
-			auto shared_this = weak_this.lock();
-			if (!shared_this)
-				return;
-			switch (transportState) {
-			case State::Connected:
-				if (state == WebSocket::State::Connecting) {
-					PLOG_DEBUG << "WebSocket open";
-					if (changeState(WebSocket::State::Open))
-						triggerOpen();
-				}
-				break;
-			case State::Failed:
-				triggerError("WebSocket connection failed");
-				remoteClose();
-				break;
-			case State::Disconnected:
-				remoteClose();
-				break;
-			default:
-				// Ignore
-				break;
-			}
+			if(auto locked = weak_this.lock())
+				std::invoke([=]() {
+					switch (transportState) {
+					case State::Connected:
+						if (state == WebSocket::State::Connecting) {
+							PLOG_DEBUG << "WebSocket open";
+							if (changeState(WebSocket::State::Open))
+								triggerOpen();
+						}
+						break;
+					case State::Failed:
+						triggerError("WebSocket connection failed");
+						remoteClose();
+						break;
+					case State::Disconnected:
+						remoteClose();
+						break;
+					default:
+						// Ignore
+						break;
+					}
+				});
 		};
 
 		auto transport = std::make_shared<WsTransport>(lower, mWsHandshake, config,

+ 6 - 7
src/impl/wstransport.cpp

@@ -102,13 +102,12 @@ void WsTransport::close() {
 		return;
 	}
 
-	ThreadPool::Instance().schedule(std::chrono::seconds(10),
-	                                [this, weak_this = weak_from_this()]() {
-		                                if (auto shared_this = weak_this.lock()) {
-			                                PLOG_DEBUG << "WebSocket close timeout";
-			                                changeState(State::Disconnected);
-		                                }
-	                                });
+	ThreadPool::Instance().schedule(std::chrono::seconds(10), [weak_this = weak_from_this()]() {
+		if (auto locked = weak_this.lock()) {
+			PLOG_DEBUG << "WebSocket close timeout";
+			locked->changeState(State::Disconnected);
+		}
+	});
 }
 
 void WsTransport::incoming(message_ptr message) {

+ 26 - 31
src/pacinghandler.cpp

@@ -18,41 +18,36 @@
 namespace rtc {
 
 PacingHandler::PacingHandler(double bitsPerSecond, std::chrono::milliseconds sendInterval)
-    : mBytesPerSecond(bitsPerSecond / 8), mBudget(0.), mSendInterval(sendInterval){};
+    : mBytesPerSecond(bitsPerSecond / 8), mBudget(0.), mSendInterval(sendInterval) {};
 
 void PacingHandler::schedule(const message_callback &send) {
-	if (mHaveScheduled.exchange(true)) {
-		return;
+	if (!mHaveScheduled.exchange(true))
+		impl::ThreadPool::Instance().schedule(mSendInterval,
+		                                      weak_bind(&PacingHandler::run, this, send));
+}
+
+void PacingHandler::run(const message_callback &send) {
+	const std::lock_guard<std::mutex> lock(mMutex);
+	mHaveScheduled.store(false);
+
+	// Update the budget and cap it
+	auto now = std::chrono::high_resolution_clock::now();
+	auto newBudget = std::chrono::duration<double>(now - mLastRun).count() * mBytesPerSecond;
+	auto maxBudget = std::chrono::duration<double>(mSendInterval).count() * mBytesPerSecond;
+	mBudget = std::min(mBudget + newBudget, maxBudget);
+	mLastRun = std::chrono::high_resolution_clock::now();
+
+	// Send packets while there is budget, allow a single partial packet over budget
+	while (!mRtpBuffer.empty() && mBudget > 0) {
+		auto size = int(mRtpBuffer.front()->size());
+		send(std::move(mRtpBuffer.front()));
+		mRtpBuffer.pop();
+		mBudget -= size;
 	}
 
-	impl::ThreadPool::Instance().schedule(mSendInterval, [this, weak_this = weak_from_this(),
-	                                                      send]() {
-		if (auto locked = weak_this.lock()) {
-			const std::lock_guard<std::mutex> lock(mMutex);
-			mHaveScheduled.store(false);
-
-			// Update the budget and cap it
-			auto newBudget =
-			    std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - mLastRun)
-			        .count() *
-			    mBytesPerSecond;
-			auto maxBudget = std::chrono::duration<double>(mSendInterval).count() * mBytesPerSecond;
-			mBudget = std::min(mBudget + newBudget, maxBudget);
-			mLastRun = std::chrono::high_resolution_clock::now();
-
-			// Send packets while there is budget, allow a single partial packet over budget
-			while (!mRtpBuffer.empty() && mBudget > 0) {
-				auto size = int(mRtpBuffer.front()->size());
-				send(std::move(mRtpBuffer.front()));
-				mRtpBuffer.pop();
-				mBudget -= size;
-			}
-
-			if (!mRtpBuffer.empty()) {
-				schedule(send);
-			}
-		}
-	});
+	if (!mRtpBuffer.empty()) {
+		schedule(send);
+	}
 }
 
 void PacingHandler::outgoing(message_vector &messages, const message_callback &send) {