Parcourir la source

Added synchronization to DataChannel

Paul-Louis Ageneau il y a 4 ans
Parent
commit
90e59435c0
2 fichiers modifiés avec 47 ajouts et 20 suppressions
  1. 3 0
      include/rtc/datachannel.hpp
  2. 44 20
      src/datachannel.cpp

+ 3 - 0
include/rtc/datachannel.hpp

@@ -30,6 +30,7 @@
 #include <functional>
 #include <type_traits>
 #include <variant>
+#include <shared_mutex>
 
 namespace rtc {
 
@@ -79,6 +80,8 @@ protected:
 	string mProtocol;
 	std::shared_ptr<Reliability> mReliability;
 
+	mutable std::shared_mutex mMutex;
+
 	std::atomic<bool> mIsOpen = false;
 	std::atomic<bool> mIsClosed = false;
 

+ 44 - 20
src/datachannel.cpp

@@ -87,21 +87,34 @@ DataChannel::~DataChannel() { close(); }
 
 uint16_t DataChannel::stream() const { return mStream; }
 
-uint16_t DataChannel::id() const { return uint16_t(mStream); }
+uint16_t DataChannel::id() const { return mStream; }
 
-string DataChannel::label() const { return mLabel; }
+string DataChannel::label() const {
+	std::shared_lock lock(mMutex);
+	return mLabel;
+}
 
-string DataChannel::protocol() const { return mProtocol; }
+string DataChannel::protocol() const {
+	std::shared_lock lock(mMutex);
+	return mProtocol;
+}
 
-Reliability DataChannel::reliability() const { return *mReliability; }
+Reliability DataChannel::reliability() const {
+	std::shared_lock lock(mMutex);
+	return *mReliability;
+}
 
 void DataChannel::close() {
+	std::shared_ptr<SctpTransport> transport;
+	{
+		std::shared_lock lock(mMutex);
+		transport = mSctpTransport.lock();
+	}
+
 	mIsClosed = true;
-	if (mIsOpen.exchange(false))
-		if (auto transport = mSctpTransport.lock())
-			transport->closeStream(mStream);
+	if (mIsOpen.exchange(false) && transport)
+		transport->closeStream(mStream);
 
-	mSctpTransport.reset();
 	resetCallbacks();
 }
 
@@ -110,7 +123,6 @@ void DataChannel::remoteClose() {
 		triggerClosed();
 
 	mIsOpen = false;
-	mSctpTransport.reset();
 }
 
 bool DataChannel::send(message_variant data) { return outgoing(make_message(std::move(data))); }
@@ -167,7 +179,10 @@ size_t DataChannel::maxMessageSize() const {
 size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); }
 
 void DataChannel::open(shared_ptr<SctpTransport> transport) {
-	mSctpTransport = transport;
+	{
+		std::unique_lock lock(mMutex);
+		mSctpTransport = transport;
+	}
 
 	if (!mIsOpen.exchange(true))
 		triggerOpen();
@@ -179,19 +194,22 @@ void DataChannel::processOpenMessage(message_ptr) {
 }
 
 bool DataChannel::outgoing(message_ptr message) {
-	if (mIsClosed)
-		throw std::runtime_error("DataChannel is closed");
+	std::shared_ptr<SctpTransport> transport;
+	{
+		std::shared_lock lock(mMutex);
+		transport = mSctpTransport.lock();
 
-	if (message->size() > maxMessageSize())
-		throw std::runtime_error("Message size exceeds limit");
+		if (!transport || mIsClosed)
+			throw std::runtime_error("DataChannel is closed");
 
-	auto transport = mSctpTransport.lock();
-	if (!transport)
-		throw std::runtime_error("DataChannel transport is not open");
+		if (message->size() > maxMessageSize())
+			throw std::runtime_error("Message size exceeds limit");
+
+		// Before the ACK has been received on a DataChannel, all messages must be sent ordered
+		message->reliability = mIsOpen ? mReliability : nullptr;
+		message->stream = mStream;
+	}
 
-	// Before the ACK has been received on a DataChannel, all messages must be sent ordered
-	message->reliability = mIsOpen ? mReliability : nullptr;
-	message->stream = mStream;
 	return transport->send(message);
 }
 
@@ -249,6 +267,7 @@ NegotiatedDataChannel::NegotiatedDataChannel(std::weak_ptr<PeerConnection> pc,
 NegotiatedDataChannel::~NegotiatedDataChannel() {}
 
 void NegotiatedDataChannel::open(shared_ptr<SctpTransport> transport) {
+	std::unique_lock lock(mMutex);
 	mSctpTransport = transport;
 
 	uint8_t channelType;
@@ -287,10 +306,13 @@ void NegotiatedDataChannel::open(shared_ptr<SctpTransport> transport) {
 	std::copy(mLabel.begin(), mLabel.end(), end);
 	std::copy(mProtocol.begin(), mProtocol.end(), end + mLabel.size());
 
+	lock.unlock();
+
 	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
 }
 
 void NegotiatedDataChannel::processOpenMessage(message_ptr message) {
+	std::unique_lock lock(mMutex);
 	auto transport = mSctpTransport.lock();
 	if (!transport)
 		throw std::runtime_error("DataChannel has no transport");
@@ -326,6 +348,8 @@ void NegotiatedDataChannel::processOpenMessage(message_ptr message) {
 		mReliability->rexmit = int(0);
 	}
 
+	lock.unlock();
+
 	binary buffer(sizeof(AckMessage), byte(0));
 	auto &ack = *reinterpret_cast<AckMessage *>(buffer.data());
 	ack.type = MESSAGE_ACK;