Browse Source

Merge pull request #252 from paullouisageneau/update-usrsctp

Update usrsctp and enhance threading
Paul-Louis Ageneau 4 years ago
parent
commit
86c3f914fb

+ 3 - 5
CMakeLists.txt

@@ -115,17 +115,15 @@ set(CMAKE_POLICY_DEFAULT_CMP0048 NEW)
 add_subdirectory(deps/plog)
 add_subdirectory(deps/plog)
 
 
 option(sctp_build_programs 0)
 option(sctp_build_programs 0)
+option(sctp_build_shared_lib 0)
 add_subdirectory(deps/usrsctp EXCLUDE_FROM_ALL)
 add_subdirectory(deps/usrsctp EXCLUDE_FROM_ALL)
 if (MSYS OR MINGW)
 if (MSYS OR MINGW)
 	target_compile_definitions(usrsctp PUBLIC -DSCTP_STDINT_INCLUDE=<stdint.h>)
 	target_compile_definitions(usrsctp PUBLIC -DSCTP_STDINT_INCLUDE=<stdint.h>)
-	target_compile_definitions(usrsctp-static PUBLIC -DSCTP_STDINT_INCLUDE=<stdint.h>)
 endif()
 endif()
 if (CMAKE_CXX_COMPILER_ID MATCHES "GNU")
 if (CMAKE_CXX_COMPILER_ID MATCHES "GNU")
     target_compile_options(usrsctp PRIVATE -Wno-error=format-truncation)
     target_compile_options(usrsctp PRIVATE -Wno-error=format-truncation)
-	target_compile_options(usrsctp-static PRIVATE -Wno-error=format-truncation)
 endif()
 endif()
 add_library(Usrsctp::Usrsctp ALIAS usrsctp)
 add_library(Usrsctp::Usrsctp ALIAS usrsctp)
-add_library(Usrsctp::UsrsctpStatic ALIAS usrsctp-static)
 
 
 if (NO_WEBSOCKET)
 if (NO_WEBSOCKET)
 	add_library(datachannel SHARED
 	add_library(datachannel SHARED
@@ -156,13 +154,13 @@ target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/includ
 target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
 target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
 target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
 target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
 target_link_libraries(datachannel PUBLIC Threads::Threads plog::plog)
 target_link_libraries(datachannel PUBLIC Threads::Threads plog::plog)
-target_link_libraries(datachannel PRIVATE Usrsctp::UsrsctpStatic)
+target_link_libraries(datachannel PRIVATE Usrsctp::Usrsctp)
 
 
 target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
 target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
 target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
 target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
 target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
 target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
 target_link_libraries(datachannel-static PUBLIC Threads::Threads plog::plog)
 target_link_libraries(datachannel-static PUBLIC Threads::Threads plog::plog)
-target_link_libraries(datachannel-static PRIVATE Usrsctp::UsrsctpStatic)
+target_link_libraries(datachannel-static PRIVATE Usrsctp::Usrsctp)
 
 
 if(WIN32)
 if(WIN32)
 	target_link_libraries(datachannel PRIVATE ws2_32) # winsock2
 	target_link_libraries(datachannel PRIVATE ws2_32) # winsock2

+ 1 - 1
deps/usrsctp

@@ -1 +1 @@
-Subproject commit 0db969100094422d9ea74a08ae5e5d9a4cfdb06b
+Subproject commit 2e754d58227f76b8d8c7358ee5f5770b78cc239a

+ 10 - 6
include/rtc/include.hpp

@@ -102,12 +102,12 @@ private:
 	std::function<void()> function;
 	std::function<void()> function;
 };
 };
 
 
-template <typename... P> class synchronized_callback {
+template <typename... Args> class synchronized_callback {
 public:
 public:
 	synchronized_callback() = default;
 	synchronized_callback() = default;
 	synchronized_callback(synchronized_callback &&cb) { *this = std::move(cb); }
 	synchronized_callback(synchronized_callback &&cb) { *this = std::move(cb); }
 	synchronized_callback(const synchronized_callback &cb) { *this = cb; }
 	synchronized_callback(const synchronized_callback &cb) { *this = cb; }
-	synchronized_callback(std::function<void(P...)> func) { *this = std::move(func); }
+	synchronized_callback(std::function<void(Args...)> func) { *this = std::move(func); }
 	~synchronized_callback() { *this = nullptr; }
 	~synchronized_callback() { *this = nullptr; }
 
 
 	synchronized_callback &operator=(synchronized_callback &&cb) {
 	synchronized_callback &operator=(synchronized_callback &&cb) {
@@ -123,16 +123,16 @@ public:
 		return *this;
 		return *this;
 	}
 	}
 
 
-	synchronized_callback &operator=(std::function<void(P...)> func) {
+	synchronized_callback &operator=(std::function<void(Args...)> func) {
 		std::lock_guard lock(mutex);
 		std::lock_guard lock(mutex);
 		callback = std::move(func);
 		callback = std::move(func);
 		return *this;
 		return *this;
 	}
 	}
 
 
-	void operator()(P... args) const {
+	void operator()(Args... args) const {
 		std::lock_guard lock(mutex);
 		std::lock_guard lock(mutex);
 		if (callback)
 		if (callback)
-			callback(args...);
+			callback(std::move(args)...);
 	}
 	}
 
 
 	operator bool() const {
 	operator bool() const {
@@ -140,8 +140,12 @@ public:
 		return callback ? true : false;
 		return callback ? true : false;
 	}
 	}
 
 
+	std::function<void(Args...)> wrap() const {
+		return [this](Args... args) { (*this)(std::move(args)...); };
+	}
+
 private:
 private:
-	std::function<void(P...)> callback;
+	std::function<void(Args...)> callback;
 	mutable std::recursive_mutex mutex;
 	mutable std::recursive_mutex mutex;
 };
 };
 } // namespace rtc
 } // namespace rtc

+ 1 - 1
src/description.cpp

@@ -106,7 +106,7 @@ Description::Description(const string &sdp, Type type, Role role)
 					               mFingerprint->begin(),
 					               mFingerprint->begin(),
 					               [](char c) { return char(std::toupper(c)); });
 					               [](char c) { return char(std::toupper(c)); });
 				} else {
 				} else {
-					PLOG_WARNING << "Unknown SDP fingerprint type: " << value;
+					PLOG_WARNING << "Unknown SDP fingerprint format: " << value;
 				}
 				}
 			} else if (key == "ice-ufrag") {
 			} else if (key == "ice-ufrag") {
 				mIceUfrag = value;
 				mIceUfrag = value;

+ 16 - 21
src/peerconnection.cpp

@@ -77,7 +77,7 @@ void PeerConnection::close() {
 	mNegotiationNeeded = false;
 	mNegotiationNeeded = false;
 
 
 	// Close data channels asynchronously
 	// Close data channels asynchronously
-	mProcessor->enqueue(std::bind(&PeerConnection::closeDataChannels, this));
+	mProcessor->enqueue(&PeerConnection::closeDataChannels, this);
 
 
 	closeTransports();
 	closeTransports();
 }
 }
@@ -490,7 +490,7 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 				else
 				else
 					changeState(State::Connected);
 					changeState(State::Connected);
 
 
-				mProcessor->enqueue(std::bind(&PeerConnection::openTracks, this));
+				mProcessor->enqueue(&PeerConnection::openTracks, this);
 				break;
 				break;
 			case DtlsTransport::State::Failed:
 			case DtlsTransport::State::Failed:
 				changeState(State::Failed);
 				changeState(State::Failed);
@@ -561,16 +561,16 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 			    switch (state) {
 			    switch (state) {
 			    case SctpTransport::State::Connected:
 			    case SctpTransport::State::Connected:
 				    changeState(State::Connected);
 				    changeState(State::Connected);
-				    mProcessor->enqueue(std::bind(&PeerConnection::openDataChannels, this));
+				    mProcessor->enqueue(&PeerConnection::openDataChannels, this);
 				    break;
 				    break;
 			    case SctpTransport::State::Failed:
 			    case SctpTransport::State::Failed:
 				    LOG_WARNING << "SCTP transport failed";
 				    LOG_WARNING << "SCTP transport failed";
 				    changeState(State::Failed);
 				    changeState(State::Failed);
-				    mProcessor->enqueue(std::bind(&PeerConnection::remoteCloseDataChannels, this));
+				    mProcessor->enqueue(&PeerConnection::remoteCloseDataChannels, this);
 				    break;
 				    break;
 			    case SctpTransport::State::Disconnected:
 			    case SctpTransport::State::Disconnected:
 				    changeState(State::Disconnected);
 				    changeState(State::Disconnected);
-				    mProcessor->enqueue(std::bind(&PeerConnection::remoteCloseDataChannels, this));
+				    mProcessor->enqueue(&PeerConnection::remoteCloseDataChannels, this);
 				    break;
 				    break;
 			    default:
 			    default:
 				    // Ignore
 				    // Ignore
@@ -1069,19 +1069,17 @@ void PeerConnection::processLocalDescription(Description description) {
 			mCurrentLocalDescription.emplace(std::move(*mLocalDescription));
 			mCurrentLocalDescription.emplace(std::move(*mLocalDescription));
 		}
 		}
 
 
-		mLocalDescription.emplace(std::move(description));
+		mLocalDescription.emplace(description);
 		mLocalDescription->addCandidates(std::move(existingCandidates));
 		mLocalDescription->addCandidates(std::move(existingCandidates));
 	}
 	}
 
 
-	mProcessor->enqueue([this, description = *mLocalDescription]() {
-		PLOG_VERBOSE << "Issuing local description: " << description;
-		mLocalDescriptionCallback(std::move(description));
-	});
+	PLOG_VERBOSE << "Issuing local description: " << description;
+	mProcessor->enqueue(mLocalDescriptionCallback.wrap(), std::move(description));
 
 
 	// Reciprocated tracks might need to be open
 	// Reciprocated tracks might need to be open
 	if (auto dtlsTransport = std::atomic_load(&mDtlsTransport);
 	if (auto dtlsTransport = std::atomic_load(&mDtlsTransport);
 	    dtlsTransport && dtlsTransport->state() == Transport::State::Connected)
 	    dtlsTransport && dtlsTransport->state() == Transport::State::Connected)
-		mProcessor->enqueue(std::bind(&PeerConnection::openTracks, this));
+		mProcessor->enqueue(&PeerConnection::openTracks, this);
 }
 }
 
 
 void PeerConnection::processLocalCandidate(Candidate candidate) {
 void PeerConnection::processLocalCandidate(Candidate candidate) {
@@ -1092,10 +1090,8 @@ void PeerConnection::processLocalCandidate(Candidate candidate) {
 	candidate.resolve(Candidate::ResolveMode::Simple); // for proper SDP generation later
 	candidate.resolve(Candidate::ResolveMode::Simple); // for proper SDP generation later
 	mLocalDescription->addCandidate(candidate);
 	mLocalDescription->addCandidate(candidate);
 
 
-	mProcessor->enqueue([this, candidate = std::move(candidate)]() {
-		PLOG_VERBOSE << "Issuing local candidate: " << candidate;
-		mLocalCandidateCallback(std::move(candidate));
-	});
+	PLOG_VERBOSE << "Issuing local candidate: " << candidate;
+	mProcessor->enqueue(mLocalCandidateCallback.wrap(), std::move(candidate));
 }
 }
 
 
 void PeerConnection::processRemoteDescription(Description description) {
 void PeerConnection::processRemoteDescription(Description description) {
@@ -1150,12 +1146,11 @@ void PeerConnection::triggerDataChannel(weak_ptr<DataChannel> weakDataChannel) {
 	if (!dataChannel)
 	if (!dataChannel)
 		return;
 		return;
 
 
-	mProcessor->enqueue(
-	    [this, dataChannel = std::move(dataChannel)]() { mDataChannelCallback(dataChannel); });
+	mProcessor->enqueue(mDataChannelCallback.wrap(), std::move(dataChannel));
 }
 }
 
 
 void PeerConnection::triggerTrack(std::shared_ptr<Track> track) {
 void PeerConnection::triggerTrack(std::shared_ptr<Track> track) {
-	mProcessor->enqueue([this, track = std::move(track)]() { mTrackCallback(track); });
+	mProcessor->enqueue(mTrackCallback.wrap(), std::move(track));
 }
 }
 
 
 bool PeerConnection::changeState(State state) {
 bool PeerConnection::changeState(State state) {
@@ -1177,7 +1172,7 @@ bool PeerConnection::changeState(State state) {
 		// This is the last state change, so we may steal the callback
 		// This is the last state change, so we may steal the callback
 		mProcessor->enqueue([cb = std::move(mStateChangeCallback)]() { cb(State::Closed); });
 		mProcessor->enqueue([cb = std::move(mStateChangeCallback)]() { cb(State::Closed); });
 	else
 	else
-		mProcessor->enqueue([this, state]() { mStateChangeCallback(state); });
+		mProcessor->enqueue(mStateChangeCallback.wrap(), state);
 
 
 	return true;
 	return true;
 }
 }
@@ -1189,7 +1184,7 @@ bool PeerConnection::changeGatheringState(GatheringState state) {
 	std::ostringstream s;
 	std::ostringstream s;
 	s << state;
 	s << state;
 	PLOG_INFO << "Changed gathering state to " << s.str();
 	PLOG_INFO << "Changed gathering state to " << s.str();
-	mProcessor->enqueue([this, state] { mGatheringStateChangeCallback(state); });
+	mProcessor->enqueue(mGatheringStateChangeCallback.wrap(), state);
 	return true;
 	return true;
 }
 }
 
 
@@ -1200,7 +1195,7 @@ bool PeerConnection::changeSignalingState(SignalingState state) {
 	std::ostringstream s;
 	std::ostringstream s;
 	s << state;
 	s << state;
 	PLOG_INFO << "Changed signaling state to " << s.str();
 	PLOG_INFO << "Changed signaling state to " << s.str();
-	mProcessor->enqueue([this, state] { mSignalingStateChangeCallback(state); });
+	mProcessor->enqueue(mSignalingStateChangeCallback.wrap(), state);
 	return true;
 	return true;
 }
 }
 
 

+ 5 - 5
src/processor.cpp

@@ -20,6 +20,8 @@
 
 
 namespace rtc {
 namespace rtc {
 
 
+Processor::Processor(size_t limit) : mTasks(limit) {}
+
 Processor::~Processor() { join(); }
 Processor::~Processor() { join(); }
 
 
 void Processor::join() {
 void Processor::join() {
@@ -29,15 +31,13 @@ void Processor::join() {
 
 
 void Processor::schedule() {
 void Processor::schedule() {
 	std::unique_lock lock(mMutex);
 	std::unique_lock lock(mMutex);
-	if (mTasks.empty()) {
+	if (auto next = mTasks.tryPop()) {
+		ThreadPool::Instance().enqueue(std::move(*next));
+	} else {
 		// No more tasks
 		// No more tasks
 		mPending = false;
 		mPending = false;
 		mCondition.notify_all();
 		mCondition.notify_all();
-		return;
 	}
 	}
-
-	ThreadPool::Instance().enqueue(std::move(mTasks.front()));
-	mTasks.pop();
 }
 }
 
 
 } // namespace rtc
 } // namespace rtc

+ 4 - 3
src/processor.hpp

@@ -22,6 +22,7 @@
 #include "include.hpp"
 #include "include.hpp"
 #include "init.hpp"
 #include "init.hpp"
 #include "threadpool.hpp"
 #include "threadpool.hpp"
+#include "queue.hpp"
 
 
 #include <condition_variable>
 #include <condition_variable>
 #include <future>
 #include <future>
@@ -34,7 +35,7 @@ namespace rtc {
 // Processed tasks in order by delegating them to the thread pool
 // Processed tasks in order by delegating them to the thread pool
 class Processor final {
 class Processor final {
 public:
 public:
-	Processor() = default;
+	Processor(size_t limit = 0);
 	~Processor();
 	~Processor();
 
 
 	Processor(const Processor &) = delete;
 	Processor(const Processor &) = delete;
@@ -52,7 +53,7 @@ protected:
 	// Keep an init token
 	// Keep an init token
 	const init_token mInitToken = Init::Token();
 	const init_token mInitToken = Init::Token();
 
 
-	std::queue<std::function<void()>> mTasks;
+	Queue<std::function<void()>> mTasks;
 	bool mPending = false; // true iff a task is pending in the thread pool
 	bool mPending = false; // true iff a task is pending in the thread pool
 
 
 	mutable std::mutex mMutex;
 	mutable std::mutex mMutex;
@@ -71,7 +72,7 @@ template <class F, class... Args> void Processor::enqueue(F &&f, Args &&... args
 		ThreadPool::Instance().enqueue(std::move(task));
 		ThreadPool::Instance().enqueue(std::move(task));
 		mPending = true;
 		mPending = true;
 	} else {
 	} else {
-		mTasks.emplace(std::move(task));
+		mTasks.push(std::move(task));
 	}
 	}
 }
 }
 
 

+ 11 - 1
src/sctptransport.cpp

@@ -88,7 +88,7 @@ void SctpTransport::Cleanup() {
 SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
 SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
                              message_callback recvCallback, amount_callback bufferedAmountCallback,
                              message_callback recvCallback, amount_callback bufferedAmountCallback,
                              state_callback stateChangeCallback)
                              state_callback stateChangeCallback)
-    : Transport(lower, std::move(stateChangeCallback)), mPort(port),
+    : Transport(lower, std::move(stateChangeCallback)), mPort(port), mProcessor(16),
       mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
       mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
 	onRecv(recvCallback);
 	onRecv(recvCallback);
 
 
@@ -230,6 +230,16 @@ void SctpTransport::close() {
 	}
 	}
 }
 }
 
 
+void SctpTransport::recv(message_ptr message) {
+	// Delegate to processor to release SCTP thread
+	mProcessor.enqueue([this, message = std::move(message)]() { Transport::recv(message); });
+}
+
+void SctpTransport::changeState(State state) {
+	// Delegate to processor to release SCTP thread
+	mProcessor.enqueue([this, state]() { Transport::changeState(state); });
+}
+
 void SctpTransport::connect() {
 void SctpTransport::connect() {
 	if (!mSock)
 	if (!mSock)
 		throw std::logic_error("Attempted SCTP connect with closed socket");
 		throw std::logic_error("Attempted SCTP connect with closed socket");

+ 6 - 1
src/sctptransport.hpp

@@ -21,6 +21,7 @@
 
 
 #include "include.hpp"
 #include "include.hpp"
 #include "peerconnection.hpp"
 #include "peerconnection.hpp"
+#include "processor.hpp"
 #include "queue.hpp"
 #include "queue.hpp"
 #include "transport.hpp"
 #include "transport.hpp"
 
 
@@ -35,7 +36,7 @@
 
 
 namespace rtc {
 namespace rtc {
 
 
-class SctpTransport : public Transport {
+class SctpTransport final : public Transport {
 public:
 public:
 	static void Init();
 	static void Init();
 	static void Cleanup();
 	static void Cleanup();
@@ -71,6 +72,9 @@ private:
 		PPID_BINARY_EMPTY = 57
 		PPID_BINARY_EMPTY = 57
 	};
 	};
 
 
+	void recv(message_ptr message) override;
+	void changeState(State state) override;
+
 	void connect();
 	void connect();
 	void shutdown();
 	void shutdown();
 	void close();
 	void close();
@@ -93,6 +97,7 @@ private:
 	const uint16_t mPort;
 	const uint16_t mPort;
 	struct socket *mSock;
 	struct socket *mSock;
 
 
+	Processor mProcessor;
 	std::mutex mSendMutex;
 	std::mutex mSendMutex;
 	Queue<message_ptr> mSendQueue;
 	Queue<message_ptr> mSendQueue;
 	std::map<uint16_t, size_t> mBufferedAmount;
 	std::map<uint16_t, size_t> mBufferedAmount;

+ 2 - 2
src/transport.hpp

@@ -67,14 +67,14 @@ public:
 	virtual bool send(message_ptr message) { return outgoing(message); }
 	virtual bool send(message_ptr message) { return outgoing(message); }
 
 
 protected:
 protected:
-	void recv(message_ptr message) {
+	virtual void recv(message_ptr message) {
 		try {
 		try {
 			mRecvCallback(message);
 			mRecvCallback(message);
 		} catch (const std::exception &e) {
 		} catch (const std::exception &e) {
 			PLOG_WARNING << e.what();
 			PLOG_WARNING << e.what();
 		}
 		}
 	}
 	}
-	void changeState(State state) {
+	virtual void changeState(State state) {
 		try {
 		try {
 			if (mState.exchange(state) != state)
 			if (mState.exchange(state) != state)
 				mStateChangeCallback(state);
 				mStateChangeCallback(state);