Browse Source

Refactored SCTP upcall callback and prevent bad_weak_ptr errors

Paul-Louis Ageneau 2 years ago
parent
commit
adffd7b37c
4 changed files with 44 additions and 21 deletions
  1. 2 2
      src/impl/processor.hpp
  2. 30 11
      src/impl/sctptransport.cpp
  3. 4 2
      src/impl/sctptransport.hpp
  4. 8 6
      src/impl/threadpool.hpp

+ 2 - 2
src/impl/processor.hpp

@@ -44,7 +44,7 @@ public:
 
 	void join();
 
-	template <class F, class... Args> void enqueue(F &&f, Args &&...args);
+	template <class F, class... Args> void enqueue(F &&f, Args &&...args) noexcept;
 
 private:
 	void schedule();
@@ -65,7 +65,7 @@ private:
 	~TearDownProcessor();
 };
 
-template <class F, class... Args> void Processor::enqueue(F &&f, Args &&...args) {
+template <class F, class... Args> void Processor::enqueue(F &&f, Args &&...args) noexcept {
 	std::unique_lock lock(mMutex);
 	auto bound = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
 	auto task = [this, bound = std::move(bound)]() mutable {

+ 30 - 11
src/impl/sctptransport.cpp

@@ -103,7 +103,7 @@ public:
 	}
 
 	using shared_lock = std::shared_lock<std::shared_mutex>;
-	optional<shared_lock> lock(SctpTransport *instance) {
+	optional<shared_lock> lock(SctpTransport *instance) noexcept {
 		shared_lock lock(mMutex);
 		return mSet.find(instance) != mSet.end() ? std::make_optional(std::move(lock)) : nullopt;
 	}
@@ -529,6 +529,28 @@ void SctpTransport::doFlush() {
 	}
 }
 
+void SctpTransport::enqueueRecv() {
+	if (mPendingRecvCount > 0)
+		return;
+
+	if (auto shared_this = weak_from_this().lock()) {
+		// This is called from the upcall callback, we must not release the shared ptr here
+		++mPendingRecvCount;
+		mProcessor.enqueue(&SctpTransport::doRecv, std::move(shared_this));
+	}
+}
+
+void SctpTransport::enqueueFlush() {
+	if (mPendingFlushCount > 0)
+		return;
+
+	if (auto shared_this = weak_from_this().lock()) {
+		// This is called from the upcall callback, we must not release the shared ptr here
+		++mPendingFlushCount;
+		mProcessor.enqueue(&SctpTransport::doFlush, std::move(shared_this));
+	}
+}
+
 bool SctpTransport::trySendQueue() {
 	// Requires mSendMutex to be locked
 	while (auto next = mSendQueue.peek()) {
@@ -693,28 +715,25 @@ void SctpTransport::sendReset(uint16_t streamId) {
 	}
 }
 
-void SctpTransport::handleUpcall() {
+void SctpTransport::handleUpcall() noexcept {
 	try {
 		PLOG_VERBOSE << "Handle upcall";
 
 		int events = usrsctp_get_events(mSock);
 
-		if (events & SCTP_EVENT_READ && mPendingRecvCount == 0) {
-			++mPendingRecvCount;
-			mProcessor.enqueue(&SctpTransport::doRecv, shared_from_this());
-		}
+		if (events & SCTP_EVENT_READ)
+			enqueueRecv();
 
-		if (events & SCTP_EVENT_WRITE && mPendingFlushCount == 0) {
-			++mPendingFlushCount;
-			mProcessor.enqueue(&SctpTransport::doFlush, shared_from_this());
-		}
+		if (events & SCTP_EVENT_WRITE)
+			enqueueFlush();
 
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "SCTP upcall: " << e.what();
 	}
 }
 
-int SctpTransport::handleWrite(byte *data, size_t len, uint8_t /*tos*/, uint8_t /*set_df*/) {
+int SctpTransport::handleWrite(byte *data, size_t len, uint8_t /*tos*/,
+                               uint8_t /*set_df*/) noexcept {
 	try {
 		std::unique_lock lock(mWriteMutex);
 		PLOG_VERBOSE << "Handle write, len=" << len;

+ 4 - 2
src/impl/sctptransport.hpp

@@ -92,14 +92,16 @@ private:
 
 	void doRecv();
 	void doFlush();
+	void enqueueRecv();
+	void enqueueFlush();
 	bool trySendQueue();
 	bool trySendMessage(message_ptr message);
 	void updateBufferedAmount(uint16_t streamId, ptrdiff_t delta);
 	void triggerBufferedAmount(uint16_t streamId, size_t amount);
 	void sendReset(uint16_t streamId);
 
-	void handleUpcall();
-	int handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df);
+	void handleUpcall() noexcept;
+	int handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df) noexcept;
 
 	void processData(binary &&data, uint16_t streamId, PayloadId ppid);
 	void processNotification(const union sctp_notification *notify, size_t len);

+ 8 - 6
src/impl/threadpool.hpp

@@ -59,13 +59,15 @@ public:
 	bool runOne();
 
 	template <class F, class... Args>
-	auto enqueue(F &&f, Args &&...args) -> invoke_future_t<F, Args...>;
+	auto enqueue(F &&f, Args &&...args) noexcept -> invoke_future_t<F, Args...>;
 
 	template <class F, class... Args>
-	auto schedule(clock::duration delay, F &&f, Args &&...args) -> invoke_future_t<F, Args...>;
+	auto schedule(clock::duration delay, F &&f, Args &&...args) noexcept
+	    -> invoke_future_t<F, Args...>;
 
 	template <class F, class... Args>
-	auto schedule(clock::time_point time, F &&f, Args &&...args) -> invoke_future_t<F, Args...>;
+	auto schedule(clock::time_point time, F &&f, Args &&...args) noexcept
+	    -> invoke_future_t<F, Args...>;
 
 private:
 	ThreadPool();
@@ -90,18 +92,18 @@ private:
 };
 
 template <class F, class... Args>
-auto ThreadPool::enqueue(F &&f, Args &&...args) -> invoke_future_t<F, Args...> {
+auto ThreadPool::enqueue(F &&f, Args &&...args) noexcept -> invoke_future_t<F, Args...> {
 	return schedule(clock::now(), std::forward<F>(f), std::forward<Args>(args)...);
 }
 
 template <class F, class... Args>
-auto ThreadPool::schedule(clock::duration delay, F &&f, Args &&...args)
+auto ThreadPool::schedule(clock::duration delay, F &&f, Args &&...args) noexcept
     -> invoke_future_t<F, Args...> {
 	return schedule(clock::now() + delay, std::forward<F>(f), std::forward<Args>(args)...);
 }
 
 template <class F, class... Args>
-auto ThreadPool::schedule(clock::time_point time, F &&f, Args &&...args)
+auto ThreadPool::schedule(clock::time_point time, F &&f, Args &&...args) noexcept
     -> invoke_future_t<F, Args...> {
 	std::unique_lock lock(mMutex);
 	using R = std::invoke_result_t<std::decay_t<F>, std::decay_t<Args>...>;