瀏覽代碼

Merge branch 'v0.11'

Paul-Louis Ageneau 4 年之前
父節點
當前提交
2fd3f6ad45
共有 5 個文件被更改,包括 61 次插入和 43 次刪除
  1. +31 −23  src/impl/sctptransport.cpp
  2. +4 −3    src/impl/sctptransport.hpp
  3. +13 −12  src/impl/threadpool.cpp
  4. +1 −1    src/impl/threadpool.hpp
  5. +12 −4   test/benchmark.cpp

+ 31 - 23
src/impl/sctptransport.cpp

@@ -85,8 +85,8 @@ void SctpTransport::Init() {
 	// Change congestion control from the default TCP Reno (RFC 2581) to H-TCP
 	usrsctp_sysctl_set_sctp_default_cc_module(SCTP_CC_HTCP);
 
-	// Enable Non-Renegable Selective Acknowledgments (NR-SACKs)
-	usrsctp_sysctl_set_sctp_nrsack_enable(1);
+	// Enable Partial Reliability Extension (RFC 3758)
+	usrsctp_sysctl_set_sctp_pr_enable(1);
 
 	// Increase the initial window size to 10 MTUs (RFC 6928)
 	usrsctp_sysctl_set_sctp_initial_cwnd(10);
@@ -104,7 +104,7 @@ SctpTransport::SctpTransport(shared_ptr<Transport> lower, uint16_t port,
                              optional<size_t> mtu, message_callback recvCallback,
                              amount_callback bufferedAmountCallback,
                              state_callback stateChangeCallback)
-    : Transport(lower, std::move(stateChangeCallback)), mPort(port), mPendingRecvCount(0),
+    : Transport(lower, std::move(stateChangeCallback)), mPort(port),
       mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
 	onRecv(recvCallback);
 
@@ -260,7 +260,7 @@ bool SctpTransport::stop() {
 		return false;
 
 	mSendQueue.stop();
-	safeFlush();
+	flush();
 	shutdown();
 	onRecv(nullptr);
 	return true;
@@ -334,13 +334,20 @@ bool SctpTransport::send(message_ptr message) {
 	return false;
 }
 
-void SctpTransport::closeStream(unsigned int stream) {
-	send(make_message(0, Message::Reset, uint16_t(stream)));
+bool SctpTransport::flush() {
+	try {
+		std::lock_guard lock(mSendMutex);
+		trySendQueue();
+		return true;
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << "SCTP flush: " << e.what();
+		return false;
+	}
 }
 
-void SctpTransport::flush() {
-	std::lock_guard lock(mSendMutex);
-	trySendQueue();
+void SctpTransport::closeStream(unsigned int stream) {
+	send(make_message(0, Message::Reset, uint16_t(stream)));
 }
 
 void SctpTransport::incoming(message_ptr message) {
@@ -428,6 +435,16 @@ void SctpTransport::doRecv() {
 	}
 }
 
+void SctpTransport::doFlush() {
+	std::lock_guard lock(mSendMutex);
+	--mPendingFlushCount;
+	try {
+		trySendQueue();
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+	}
+}
+
 bool SctpTransport::trySendQueue() {
 	// Requires mSendMutex to be locked
 	while (auto next = mSendQueue.peek()) {
@@ -573,17 +590,6 @@ void SctpTransport::sendReset(uint16_t streamId) {
 	}
 }
 
-bool SctpTransport::safeFlush() {
-	try {
-		flush();
-		return true;
-
-	} catch (const std::exception &e) {
-		PLOG_WARNING << "SCTP flush: " << e.what();
-		return false;
-	}
-}
-
 void SctpTransport::handleUpcall() {
 	if (!mSock)
 		return;
@@ -597,8 +603,10 @@ void SctpTransport::handleUpcall() {
 		mProcessor.enqueue(&SctpTransport::doRecv, this);
 	}
 
-	if (events & SCTP_EVENT_WRITE)
-		mProcessor.enqueue(&SctpTransport::safeFlush, this);
+	if (events & SCTP_EVENT_WRITE && mPendingFlushCount == 0) {
+		++mPendingFlushCount;
+		mProcessor.enqueue(&SctpTransport::doFlush, this);
+	}
 }
 
 int SctpTransport::handleWrite(byte *data, size_t len, uint8_t /*tos*/, uint8_t /*set_df*/) {
@@ -713,7 +721,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 		PLOG_VERBOSE << "SCTP dry event";
 		// It should not be necessary since the send callback should have been called already,
 		// but to be sure, let's try to send now.
-		safeFlush();
+		flush();
 		break;
 	}
 

+ 4 - 3
src/impl/sctptransport.hpp

@@ -50,8 +50,8 @@ public:
 	void start() override;
 	bool stop() override;
 	bool send(message_ptr message) override; // false if buffered
+	bool flush();
 	void closeStream(unsigned int stream);
-	void flush();
 
 	// Stats
 	void clearStats();
@@ -79,12 +79,12 @@ private:
 	bool outgoing(message_ptr message) override;
 
 	void doRecv();
+	void doFlush();
 	bool trySendQueue();
 	bool trySendMessage(message_ptr message);
 	void updateBufferedAmount(uint16_t streamId, long delta);
 	void triggerBufferedAmount(uint16_t streamId, size_t amount);
 	void sendReset(uint16_t streamId);
-	bool safeFlush();
 
 	void handleUpcall();
 	int handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df);
@@ -96,7 +96,8 @@ private:
 	struct socket *mSock;
 
 	Processor mProcessor;
-	std::atomic<int> mPendingRecvCount;
+	std::atomic<int> mPendingRecvCount = 0;
+	std::atomic<int> mPendingFlushCount = 0;
 	std::mutex mRecvMutex;
 	std::recursive_mutex mSendMutex; // buffered amount callback is synchronous
 	Queue<message_ptr> mSendQueue;

+ 13 - 12
src/impl/threadpool.cpp

@@ -51,7 +51,7 @@ void ThreadPool::spawn(int count) {
 void ThreadPool::join() {
 	{
 		std::unique_lock lock(mMutex);
-		mWaitingCondition.wait(lock, [&]() { return mWaitingWorkers == int(mWorkers.size()); });
+		mWaitingCondition.wait(lock, [&]() { return mBusyWorkers == 0; });
 		mJoining = true;
 		mTasksCondition.notify_all();
 	}
@@ -66,6 +66,8 @@ void ThreadPool::join() {
 }
 
 void ThreadPool::run() {
+	++mBusyWorkers;
+	scope_guard([&]() { --mBusyWorkers; });
 	while (runOne()) {
 	}
 }
@@ -81,24 +83,23 @@ bool ThreadPool::runOne() {
 std::function<void()> ThreadPool::dequeue() {
 	std::unique_lock lock(mMutex);
 	while (!mJoining) {
+		std::optional<clock::time_point> time;
 		if (!mTasks.empty()) {
-			if (mTasks.top().time <= clock::now()) {
+			time = mTasks.top().time;
+			if (*time <= clock::now()) {
 				auto func = std::move(mTasks.top().func);
 				mTasks.pop();
 				return func;
 			}
-
-			++mWaitingWorkers;
-			mWaitingCondition.notify_all();
-			mTasksCondition.wait_until(lock, mTasks.top().time);
-
-		} else {
-			++mWaitingWorkers;
-			mWaitingCondition.notify_all();
-			mTasksCondition.wait(lock);
 		}
 
-		--mWaitingWorkers;
+		--mBusyWorkers;
+		scope_guard([&]() { ++mBusyWorkers; });
+		mWaitingCondition.notify_all();
+		if(time)
+			mTasksCondition.wait_until(lock, *time);
+		else
+			mTasksCondition.wait(lock);
 	}
 	return nullptr;
 }

+ 1 - 1
src/impl/threadpool.hpp

@@ -72,7 +72,7 @@ protected:
 	std::function<void()> dequeue(); // returns null function if joining
 
 	std::vector<std::thread> mWorkers;
-	int mWaitingWorkers = 0;
+	int mBusyWorkers = 0;
 	std::atomic<bool> mJoining = false;
 
 	struct Task {

+ 12 - 4
test/benchmark.cpp

@@ -115,8 +115,12 @@ size_t benchmark(milliseconds duration) {
 		openTime = steady_clock::now();
 
 		cout << "DataChannel open, sending data..." << endl;
-		while (dc1->bufferedAmount() == 0) {
-			dc1->send(messageData);
+		try {
+			while (dc1->bufferedAmount() == 0) {
+				dc1->send(messageData);
+			}
+		} catch (const std::exception &e) {
+			std::cout << "Send failed: " << e.what() << std::endl;
 		}
 
 		// When sent data is buffered in the DataChannel,
@@ -129,8 +133,12 @@ size_t benchmark(milliseconds duration) {
 			return;
 
 		// Continue sending
-		while (dc1->bufferedAmount() == 0) {
-			dc1->send(messageData);
+		try {
+			while (dc1->isOpen() && dc1->bufferedAmount() == 0) {
+				dc1->send(messageData);
+			}
+		} catch (const std::exception &e) {
+			std::cout << "Send failed: " << e.what() << std::endl;
 		}
 	});