Browse Source

Merge pull request #763 from paullouisageneau/refactor-sctp-shutdown

Refactor SCTP shutdown
Paul-Louis Ageneau 2 years ago
parent
commit
5bca9bb21c

+ 0 - 2
src/impl/peerconnection.cpp

@@ -318,12 +318,10 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 				    mProcessor.enqueue(&PeerConnection::openDataChannels, shared_from_this());
 				    mProcessor.enqueue(&PeerConnection::openDataChannels, shared_from_this());
 				    break;
 				    break;
 			    case SctpTransport::State::Failed:
 			    case SctpTransport::State::Failed:
-				    LOG_WARNING << "SCTP transport failed";
 				    changeState(State::Failed);
 				    changeState(State::Failed);
 				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
 				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
 				    break;
 				    break;
 			    case SctpTransport::State::Disconnected:
 			    case SctpTransport::State::Disconnected:
-				    LOG_INFO << "SCTP transport disconnected";
 				    changeState(State::Disconnected);
 				    changeState(State::Disconnected);
 				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
 				    mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this());
 				    break;
 				    break;

+ 2 - 2
src/impl/processor.hpp

@@ -44,7 +44,7 @@ public:
 
 
 	void join();
 	void join();
 
 
-	template <class F, class... Args> void enqueue(F &&f, Args &&...args);
+	template <class F, class... Args> void enqueue(F &&f, Args &&...args) noexcept;
 
 
 private:
 private:
 	void schedule();
 	void schedule();
@@ -65,7 +65,7 @@ private:
 	~TearDownProcessor();
 	~TearDownProcessor();
 };
 };
 
 
-template <class F, class... Args> void Processor::enqueue(F &&f, Args &&...args) {
+template <class F, class... Args> void Processor::enqueue(F &&f, Args &&...args) noexcept {
 	std::unique_lock lock(mMutex);
 	std::unique_lock lock(mMutex);
 	auto bound = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
 	auto bound = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
 	auto task = [this, bound = std::move(bound)]() mutable {
 	auto task = [this, bound = std::move(bound)]() mutable {

+ 48 - 42
src/impl/sctptransport.cpp

@@ -84,11 +84,6 @@ namespace rtc::impl {
 
 
 static LogCounter COUNTER_UNKNOWN_PPID(plog::warning,
 static LogCounter COUNTER_UNKNOWN_PPID(plog::warning,
                                        "Number of SCTP packets received with an unknown PPID");
                                        "Number of SCTP packets received with an unknown PPID");
-static LogCounter
-    COUNTER_BAD_NOTIF_LEN(plog::warning,
-                          "Number of SCTP packets received with an bad notification length");
-static LogCounter COUNTER_BAD_SCTP_STATUS(plog::warning,
-                                          "Number of SCTP packets received with a bad status");
 
 
 class SctpTransport::InstancesSet {
 class SctpTransport::InstancesSet {
 public:
 public:
@@ -103,7 +98,7 @@ public:
 	}
 	}
 
 
 	using shared_lock = std::shared_lock<std::shared_mutex>;
 	using shared_lock = std::shared_lock<std::shared_mutex>;
-	optional<shared_lock> lock(SctpTransport *instance) {
+	optional<shared_lock> lock(SctpTransport *instance) noexcept {
 		shared_lock lock(mMutex);
 		shared_lock lock(mMutex);
 		return mSet.find(instance) != mSet.end() ? std::make_optional(std::move(lock)) : nullopt;
 		return mSet.find(instance) != mSet.end() ? std::make_optional(std::move(lock)) : nullopt;
 	}
 	}
@@ -175,7 +170,7 @@ void SctpTransport::SetSettings(const SctpSettings &s) {
 }
 }
 
 
 void SctpTransport::Cleanup() {
 void SctpTransport::Cleanup() {
-	while (usrsctp_finish() != 0)
+	while (usrsctp_finish())
 		std::this_thread::sleep_for(100ms);
 		std::this_thread::sleep_for(100ms);
 }
 }
 
 
@@ -329,6 +324,8 @@ SctpTransport::SctpTransport(shared_ptr<Transport> lower, const Configuration &c
 SctpTransport::~SctpTransport() {
 SctpTransport::~SctpTransport() {
 	PLOG_DEBUG << "Destroying SCTP transport";
 	PLOG_DEBUG << "Destroying SCTP transport";
 
 
+	mProcessor.join(); // if we are here, the processor must be empty
+
 	// Before unregistering incoming() from the lower layer, we need to make sure the thread from
 	// Before unregistering incoming() from the lower layer, we need to make sure the thread from
 	// lower layers is not blocked in incoming() by the WrittenOnce condition.
 	// lower layers is not blocked in incoming() by the WrittenOnce condition.
 	mWrittenOnce = true;
 	mWrittenOnce = true;
@@ -336,7 +333,6 @@ SctpTransport::~SctpTransport() {
 
 
 	unregisterIncoming();
 	unregisterIncoming();
 
 
-	mProcessor.join();
 	usrsctp_close(mSock);
 	usrsctp_close(mSock);
 
 
 	usrsctp_deregister_address(this);
 	usrsctp_deregister_address(this);
@@ -366,9 +362,6 @@ struct sockaddr_conn SctpTransport::getSockAddrConn(uint16_t port) {
 }
 }
 
 
 void SctpTransport::connect() {
 void SctpTransport::connect() {
-	if (!mSock)
-		throw std::logic_error("Attempted SCTP connect with closed socket");
-
 	PLOG_DEBUG << "SCTP connecting (local port=" << mPorts.local
 	PLOG_DEBUG << "SCTP connecting (local port=" << mPorts.local
 	           << ", remote port=" << mPorts.remote << ")";
 	           << ", remote port=" << mPorts.remote << ")";
 	changeState(State::Connecting);
 	changeState(State::Connecting);
@@ -386,17 +379,6 @@ void SctpTransport::connect() {
 		throw std::runtime_error("Connection attempt failed, errno=" + std::to_string(errno));
 		throw std::runtime_error("Connection attempt failed, errno=" + std::to_string(errno));
 }
 }
 
 
-void SctpTransport::shutdown() {
-	if (!mSock)
-		return;
-
-	PLOG_DEBUG << "SCTP shutdown";
-
-	if (usrsctp_shutdown(mSock, SHUT_RDWR) != 0 && errno != ENOTCONN) {
-		PLOG_WARNING << "SCTP shutdown failed, errno=" << errno;
-	}
-}
-
 bool SctpTransport::send(message_ptr message) {
 bool SctpTransport::send(message_ptr message) {
 	std::lock_guard lock(mSendMutex);
 	std::lock_guard lock(mSendMutex);
 
 
@@ -542,6 +524,28 @@ void SctpTransport::doFlush() {
 	}
 	}
 }
 }
 
 
+void SctpTransport::enqueueRecv() {
+	if (mPendingRecvCount > 0)
+		return;
+
+	if (auto shared_this = weak_from_this().lock()) {
+		// This is called from the upcall callback, we must not release the shared ptr here
+		++mPendingRecvCount;
+		mProcessor.enqueue(&SctpTransport::doRecv, std::move(shared_this));
+	}
+}
+
+void SctpTransport::enqueueFlush() {
+	if (mPendingFlushCount > 0)
+		return;
+
+	if (auto shared_this = weak_from_this().lock()) {
+		// This is called from the upcall callback, we must not release the shared ptr here
+		++mPendingFlushCount;
+		mProcessor.enqueue(&SctpTransport::doFlush, std::move(shared_this));
+	}
+}
+
 bool SctpTransport::trySendQueue() {
 bool SctpTransport::trySendQueue() {
 	// Requires mSendMutex to be locked
 	// Requires mSendMutex to be locked
 	while (auto next = mSendQueue.peek()) {
 	while (auto next = mSendQueue.peek()) {
@@ -553,9 +557,17 @@ bool SctpTransport::trySendQueue() {
 		updateBufferedAmount(to_uint16(message->stream), -ptrdiff_t(message_size_func(message)));
 		updateBufferedAmount(to_uint16(message->stream), -ptrdiff_t(message_size_func(message)));
 	}
 	}
 
 
-	if (!mSendQueue.running()) {
-		shutdown();
-		return false;
+	if (!mSendQueue.running() && !std::exchange(mSendShutdown, true)) {
+		PLOG_DEBUG << "SCTP shutdown";
+		if (usrsctp_shutdown(mSock, SHUT_WR)) {
+			if (errno == ENOTCONN) {
+				PLOG_VERBOSE << "SCTP already shut down";
+			} else {
+				PLOG_WARNING << "SCTP shutdown failed, errno=" << errno;
+				changeState(State::Disconnected);
+				recv(nullptr);
+			}
+		}
 	}
 	}
 
 
 	return true;
 	return true;
@@ -698,31 +710,25 @@ void SctpTransport::sendReset(uint16_t streamId) {
 	}
 	}
 }
 }
 
 
-void SctpTransport::handleUpcall() {
+void SctpTransport::handleUpcall() noexcept {
 	try {
 	try {
-		if (!mSock)
-			return;
-
 		PLOG_VERBOSE << "Handle upcall";
 		PLOG_VERBOSE << "Handle upcall";
 
 
 		int events = usrsctp_get_events(mSock);
 		int events = usrsctp_get_events(mSock);
 
 
-		if (events & SCTP_EVENT_READ && mPendingRecvCount == 0) {
-			++mPendingRecvCount;
-			mProcessor.enqueue(&SctpTransport::doRecv, shared_from_this());
-		}
+		if (events & SCTP_EVENT_READ)
+			enqueueRecv();
 
 
-		if (events & SCTP_EVENT_WRITE && mPendingFlushCount == 0) {
-			++mPendingFlushCount;
-			mProcessor.enqueue(&SctpTransport::doFlush, shared_from_this());
-		}
+		if (events & SCTP_EVENT_WRITE)
+			enqueueFlush();
 
 
 	} catch (const std::exception &e) {
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "SCTP upcall: " << e.what();
 		PLOG_ERROR << "SCTP upcall: " << e.what();
 	}
 	}
 }
 }
 
 
-int SctpTransport::handleWrite(byte *data, size_t len, uint8_t /*tos*/, uint8_t /*set_df*/) {
+int SctpTransport::handleWrite(byte *data, size_t len, uint8_t /*tos*/,
+                               uint8_t /*set_df*/) noexcept {
 	try {
 	try {
 		std::unique_lock lock(mWriteMutex);
 		std::unique_lock lock(mWriteMutex);
 		PLOG_VERBOSE << "Handle write, len=" << len;
 		PLOG_VERBOSE << "Handle write, len=" << len;
@@ -806,7 +812,8 @@ void SctpTransport::processData(binary &&data, uint16_t sid, PayloadId ppid) {
 
 
 void SctpTransport::processNotification(const union sctp_notification *notify, size_t len) {
 void SctpTransport::processNotification(const union sctp_notification *notify, size_t len) {
 	if (len != size_t(notify->sn_header.sn_length)) {
 	if (len != size_t(notify->sn_header.sn_length)) {
-		COUNTER_BAD_NOTIF_LEN++;
+		PLOG_WARNING << "Unexpected notification length, expected=" << notify->sn_header.sn_length
+		             << ", actual=" << len;
 		return;
 		return;
 	}
 	}
 
 
@@ -908,10 +915,9 @@ optional<milliseconds> SctpTransport::rtt() {
 
 
 	struct sctp_status status = {};
 	struct sctp_status status = {};
 	socklen_t len = sizeof(status);
 	socklen_t len = sizeof(status);
-	if (usrsctp_getsockopt(mSock, IPPROTO_SCTP, SCTP_STATUS, &status, &len)) {
-		COUNTER_BAD_SCTP_STATUS++;
+	if (usrsctp_getsockopt(mSock, IPPROTO_SCTP, SCTP_STATUS, &status, &len))
 		return nullopt;
 		return nullopt;
-	}
+
 	return milliseconds(status.sstat_primary.spinfo_srtt);
 	return milliseconds(status.sstat_primary.spinfo_srtt);
 }
 }
 
 

+ 5 - 2
src/impl/sctptransport.hpp

@@ -92,14 +92,16 @@ private:
 
 
 	void doRecv();
 	void doRecv();
 	void doFlush();
 	void doFlush();
+	void enqueueRecv();
+	void enqueueFlush();
 	bool trySendQueue();
 	bool trySendQueue();
 	bool trySendMessage(message_ptr message);
 	bool trySendMessage(message_ptr message);
 	void updateBufferedAmount(uint16_t streamId, ptrdiff_t delta);
 	void updateBufferedAmount(uint16_t streamId, ptrdiff_t delta);
 	void triggerBufferedAmount(uint16_t streamId, size_t amount);
 	void triggerBufferedAmount(uint16_t streamId, size_t amount);
 	void sendReset(uint16_t streamId);
 	void sendReset(uint16_t streamId);
 
 
-	void handleUpcall();
-	int handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df);
+	void handleUpcall() noexcept;
+	int handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df) noexcept;
 
 
 	void processData(binary &&data, uint16_t streamId, PayloadId ppid);
 	void processData(binary &&data, uint16_t streamId, PayloadId ppid);
 	void processNotification(const union sctp_notification *notify, size_t len);
 	void processNotification(const union sctp_notification *notify, size_t len);
@@ -114,6 +116,7 @@ private:
 	std::mutex mRecvMutex;
 	std::mutex mRecvMutex;
 	std::recursive_mutex mSendMutex; // buffered amount callback is synchronous
 	std::recursive_mutex mSendMutex; // buffered amount callback is synchronous
 	Queue<message_ptr> mSendQueue;
 	Queue<message_ptr> mSendQueue;
+	bool mSendShutdown = false;
 	std::map<uint16_t, size_t> mBufferedAmount;
 	std::map<uint16_t, size_t> mBufferedAmount;
 	amount_callback mBufferedAmountCallback;
 	amount_callback mBufferedAmountCallback;
 
 

+ 8 - 6
src/impl/threadpool.hpp

@@ -59,13 +59,15 @@ public:
 	bool runOne();
 	bool runOne();
 
 
 	template <class F, class... Args>
 	template <class F, class... Args>
-	auto enqueue(F &&f, Args &&...args) -> invoke_future_t<F, Args...>;
+	auto enqueue(F &&f, Args &&...args) noexcept -> invoke_future_t<F, Args...>;
 
 
 	template <class F, class... Args>
 	template <class F, class... Args>
-	auto schedule(clock::duration delay, F &&f, Args &&...args) -> invoke_future_t<F, Args...>;
+	auto schedule(clock::duration delay, F &&f, Args &&...args) noexcept
+	    -> invoke_future_t<F, Args...>;
 
 
 	template <class F, class... Args>
 	template <class F, class... Args>
-	auto schedule(clock::time_point time, F &&f, Args &&...args) -> invoke_future_t<F, Args...>;
+	auto schedule(clock::time_point time, F &&f, Args &&...args) noexcept
+	    -> invoke_future_t<F, Args...>;
 
 
 private:
 private:
 	ThreadPool();
 	ThreadPool();
@@ -90,18 +92,18 @@ private:
 };
 };
 
 
 template <class F, class... Args>
 template <class F, class... Args>
-auto ThreadPool::enqueue(F &&f, Args &&...args) -> invoke_future_t<F, Args...> {
+auto ThreadPool::enqueue(F &&f, Args &&...args) noexcept -> invoke_future_t<F, Args...> {
 	return schedule(clock::now(), std::forward<F>(f), std::forward<Args>(args)...);
 	return schedule(clock::now(), std::forward<F>(f), std::forward<Args>(args)...);
 }
 }
 
 
 template <class F, class... Args>
 template <class F, class... Args>
-auto ThreadPool::schedule(clock::duration delay, F &&f, Args &&...args)
+auto ThreadPool::schedule(clock::duration delay, F &&f, Args &&...args) noexcept
     -> invoke_future_t<F, Args...> {
     -> invoke_future_t<F, Args...> {
 	return schedule(clock::now() + delay, std::forward<F>(f), std::forward<Args>(args)...);
 	return schedule(clock::now() + delay, std::forward<F>(f), std::forward<Args>(args)...);
 }
 }
 
 
 template <class F, class... Args>
 template <class F, class... Args>
-auto ThreadPool::schedule(clock::time_point time, F &&f, Args &&...args)
+auto ThreadPool::schedule(clock::time_point time, F &&f, Args &&...args) noexcept
     -> invoke_future_t<F, Args...> {
     -> invoke_future_t<F, Args...> {
 	std::unique_lock lock(mMutex);
 	std::unique_lock lock(mMutex);
 	using R = std::invoke_result_t<std::decay_t<F>, std::decay_t<Args>...>;
 	using R = std::invoke_result_t<std::decay_t<F>, std::decay_t<Args>...>;