Преглед изворни кода

Merge pull request #344 from paullouisageneau/fix-datachannel-data-race

Fix possible data race in DataChannel
Paul-Louis Ageneau пре 4 година
родитељ
комит
d748016446
7 измењених фајлова са 116 додато и 68 уклоњено
  1. 7 4
      include/rtc/datachannel.hpp
  2. 5 3
      include/rtc/track.hpp
  3. 49 25
      src/datachannel.cpp
  4. 2 2
      src/peerconnection.cpp
  5. 49 30
      src/track.cpp
  6. 2 2
      test/connectivity.cpp
  7. 2 2
      test/turn_connectivity.cpp

+ 7 - 4
include/rtc/datachannel.hpp

@@ -30,6 +30,7 @@
 #include <functional>
 #include <type_traits>
 #include <variant>
+#include <shared_mutex>
 
 namespace rtc {
 
@@ -79,6 +80,8 @@ protected:
 	string mProtocol;
 	std::shared_ptr<Reliability> mReliability;
 
+	mutable std::shared_mutex mMutex;
+
 	std::atomic<bool> mIsOpen = false;
 	std::atomic<bool> mIsClosed = false;
 
@@ -88,13 +91,13 @@ private:
 	friend class PeerConnection;
 };
 
-class RTC_CPP_EXPORT NegociatedDataChannel final : public DataChannel {
+class RTC_CPP_EXPORT NegotiatedDataChannel final : public DataChannel {
 public:
-	NegociatedDataChannel(std::weak_ptr<PeerConnection> pc, uint16_t stream, string label,
+	NegotiatedDataChannel(std::weak_ptr<PeerConnection> pc, uint16_t stream, string label,
 	                      string protocol, Reliability reliability);
-	NegociatedDataChannel(std::weak_ptr<PeerConnection> pc, std::weak_ptr<SctpTransport> transport,
+	NegotiatedDataChannel(std::weak_ptr<PeerConnection> pc, std::weak_ptr<SctpTransport> transport,
 	                      uint16_t stream);
-	~NegociatedDataChannel();
+	~NegotiatedDataChannel();
 
 private:
 	void open(std::shared_ptr<SctpTransport> transport) override;

+ 5 - 3
include/rtc/track.hpp

@@ -43,6 +43,7 @@ public:
 
 	string mid() const;
 	Description::Media description() const;
+	Description::Direction direction() const;
 
 	void setDescription(Description::Media description);
 
@@ -75,13 +76,14 @@ private:
 	bool outgoing(message_ptr message);
 
 	Description::Media mMediaDescription;
+	std::shared_ptr<MediaHandler> mRtcpHandler;
+
+	mutable std::shared_mutex mMutex;
+
 	std::atomic<bool> mIsClosed = false;
 
 	Queue<message_ptr> mRecvQueue;
 
-	std::shared_mutex mRtcpHandlerMutex;
-	std::shared_ptr<MediaHandler> mRtcpHandler;
-
 	friend class PeerConnection;
 };
 

+ 49 - 25
src/datachannel.cpp

@@ -87,21 +87,34 @@ DataChannel::~DataChannel() { close(); }
 
 uint16_t DataChannel::stream() const { return mStream; }
 
-uint16_t DataChannel::id() const { return uint16_t(mStream); }
+uint16_t DataChannel::id() const { return mStream; }
 
-string DataChannel::label() const { return mLabel; }
+string DataChannel::label() const {
+	std::shared_lock lock(mMutex);
+	return mLabel;
+}
 
-string DataChannel::protocol() const { return mProtocol; }
+string DataChannel::protocol() const {
+	std::shared_lock lock(mMutex);
+	return mProtocol;
+}
 
-Reliability DataChannel::reliability() const { return *mReliability; }
+Reliability DataChannel::reliability() const {
+	std::shared_lock lock(mMutex);
+	return *mReliability;
+}
 
 void DataChannel::close() {
+	std::shared_ptr<SctpTransport> transport;
+	{
+		std::shared_lock lock(mMutex);
+		transport = mSctpTransport.lock();
+	}
+
 	mIsClosed = true;
-	if (mIsOpen.exchange(false))
-		if (auto transport = mSctpTransport.lock())
-			transport->closeStream(mStream);
+	if (mIsOpen.exchange(false) && transport)
+		transport->closeStream(mStream);
 
-	mSctpTransport.reset();
 	resetCallbacks();
 }
 
@@ -110,7 +123,6 @@ void DataChannel::remoteClose() {
 		triggerClosed();
 
 	mIsOpen = false;
-	mSctpTransport.reset();
 }
 
 bool DataChannel::send(message_variant data) { return outgoing(make_message(std::move(data))); }
@@ -167,7 +179,10 @@ size_t DataChannel::maxMessageSize() const {
 size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); }
 
 void DataChannel::open(shared_ptr<SctpTransport> transport) {
-	mSctpTransport = transport;
+	{
+		std::unique_lock lock(mMutex);
+		mSctpTransport = transport;
+	}
 
 	if (!mIsOpen.exchange(true))
 		triggerOpen();
@@ -179,19 +194,22 @@ void DataChannel::processOpenMessage(message_ptr) {
 }
 
 bool DataChannel::outgoing(message_ptr message) {
-	if (mIsClosed)
-		throw std::runtime_error("DataChannel is closed");
+	std::shared_ptr<SctpTransport> transport;
+	{
+		std::shared_lock lock(mMutex);
+		transport = mSctpTransport.lock();
 
-	if (message->size() > maxMessageSize())
-		throw std::runtime_error("Message size exceeds limit");
+		if (!transport || mIsClosed)
+			throw std::runtime_error("DataChannel is closed");
 
-	auto transport = mSctpTransport.lock();
-	if (!transport)
-		throw std::runtime_error("DataChannel transport is not open");
+		if (message->size() > maxMessageSize())
+			throw std::runtime_error("Message size exceeds limit");
+
+		// Before the ACK has been received on a DataChannel, all messages must be sent ordered
+		message->reliability = mIsOpen ? mReliability : nullptr;
+		message->stream = mStream;
+	}
 
-	// Before the ACK has been received on a DataChannel, all messages must be sent ordered
-	message->reliability = mIsOpen ? mReliability : nullptr;
-	message->stream = mStream;
 	return transport->send(message);
 }
 
@@ -235,20 +253,21 @@ void DataChannel::incoming(message_ptr message) {
 	}
 }
 
-NegociatedDataChannel::NegociatedDataChannel(std::weak_ptr<PeerConnection> pc, uint16_t stream,
+NegotiatedDataChannel::NegotiatedDataChannel(std::weak_ptr<PeerConnection> pc, uint16_t stream,
                                              string label, string protocol, Reliability reliability)
     : DataChannel(pc, stream, std::move(label), std::move(protocol), std::move(reliability)) {}
 
-NegociatedDataChannel::NegociatedDataChannel(std::weak_ptr<PeerConnection> pc,
+NegotiatedDataChannel::NegotiatedDataChannel(std::weak_ptr<PeerConnection> pc,
                                              std::weak_ptr<SctpTransport> transport,
                                              uint16_t stream)
     : DataChannel(pc, stream, "", "", {}) {
 	mSctpTransport = transport;
 }
 
-NegociatedDataChannel::~NegociatedDataChannel() {}
+NegotiatedDataChannel::~NegotiatedDataChannel() {}
 
-void NegociatedDataChannel::open(shared_ptr<SctpTransport> transport) {
+void NegotiatedDataChannel::open(shared_ptr<SctpTransport> transport) {
+	std::unique_lock lock(mMutex);
 	mSctpTransport = transport;
 
 	uint8_t channelType;
@@ -287,10 +306,13 @@ void NegociatedDataChannel::open(shared_ptr<SctpTransport> transport) {
 	std::copy(mLabel.begin(), mLabel.end(), end);
 	std::copy(mProtocol.begin(), mProtocol.end(), end + mLabel.size());
 
+	lock.unlock();
+
 	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
 }
 
-void NegociatedDataChannel::processOpenMessage(message_ptr message) {
+void NegotiatedDataChannel::processOpenMessage(message_ptr message) {
+	std::unique_lock lock(mMutex);
 	auto transport = mSctpTransport.lock();
 	if (!transport)
 		throw std::runtime_error("DataChannel has no transport");
@@ -326,6 +348,8 @@ void NegociatedDataChannel::processOpenMessage(message_ptr message) {
 		mReliability->rexmit = int(0);
 	}
 
+	lock.unlock();
+
 	binary buffer(sizeof(AckMessage), byte(0));
 	auto &ack = *reinterpret_cast<AckMessage *>(buffer.data());
 	ack.type = MESSAGE_ACK;

+ 2 - 2
src/peerconnection.cpp

@@ -663,7 +663,7 @@ void PeerConnection::forwardMessage(message_ptr message) {
 		if (message->type == Message::Control && *message->data() == dataChannelOpenMessage &&
 		    stream % 2 == remoteParity) {
 
-			channel = std::make_shared<NegociatedDataChannel>(shared_from_this(), sctpTransport,
+			channel = std::make_shared<NegotiatedDataChannel>(shared_from_this(), sctpTransport,
 			                                                  stream);
 			channel->onOpen(weak_bind(&PeerConnection::triggerDataChannel, this,
 			                          weak_ptr<DataChannel>{channel}));
@@ -835,7 +835,7 @@ shared_ptr<DataChannel> PeerConnection::emplaceDataChannel(Description::Role rol
 	    init.negotiated
 	        ? std::make_shared<DataChannel>(shared_from_this(), stream, std::move(label),
 	                                        std::move(init.protocol), std::move(init.reliability))
-	        : std::make_shared<NegociatedDataChannel>(shared_from_this(), stream, std::move(label),
+	        : std::make_shared<NegotiatedDataChannel>(shared_from_this(), stream, std::move(label),
 	                                                  std::move(init.protocol),
 	                                                  std::move(init.reliability));
 	mDataChannels.emplace(std::make_pair(stream, channel));

+ 49 - 30
src/track.cpp

@@ -35,11 +35,23 @@ using std::weak_ptr;
 Track::Track(Description::Media description)
     : mMediaDescription(std::move(description)), mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
 
-string Track::mid() const { return mMediaDescription.mid(); }
+string Track::mid() const {
+	std::shared_lock lock(mMutex);
+	return mMediaDescription.mid();
+}
+
+Description::Media Track::description() const {
+	std::shared_lock lock(mMutex);
+	return mMediaDescription;
+}
 
-Description::Media Track::description() const { return mMediaDescription; }
+Description::Direction Track::direction() const {
+	std::shared_lock lock(mMutex);
+	return mMediaDescription.direction();
+}
 
 void Track::setDescription(Description::Media description) {
+	std::unique_lock lock(mMutex);
 	if (description.mid() != mMediaDescription.mid())
 		throw std::logic_error("Media description mid does not match track mid");
 
@@ -48,17 +60,17 @@ void Track::setDescription(Description::Media description) {
 
 void Track::close() {
 	mIsClosed = true;
-	resetCallbacks();
+
 	setRtcpHandler(nullptr);
+	resetCallbacks();
 }
 
 bool Track::send(message_variant data) {
 	if (mIsClosed)
 		throw std::runtime_error("Track is closed");
 
-	auto direction = mMediaDescription.direction();
-	if ((direction == Description::Direction::RecvOnly ||
-	     direction == Description::Direction::Inactive)) {
+	auto dir = direction();
+	if ((dir == Description::Direction::RecvOnly || dir == Description::Direction::Inactive)) {
 		COUNTER_MEDIA_BAD_DIRECTION++;
 		return false;
 	}
@@ -92,6 +104,7 @@ std::optional<message_variant> Track::peek() {
 
 bool Track::isOpen(void) const {
 #if RTC_ENABLE_MEDIA
+	std::shared_lock lock(mMutex);
 	return !mIsClosed && mDtlsSrtpTransport.lock();
 #else
 	return !mIsClosed;
@@ -108,7 +121,11 @@ size_t Track::availableAmount() const { return mRecvQueue.amount(); }
 
 #if RTC_ENABLE_MEDIA
 void Track::open(shared_ptr<DtlsSrtpTransport> transport) {
-	mDtlsSrtpTransport = transport;
+	{
+		std::lock_guard lock(mMutex);
+		mDtlsSrtpTransport = transport;
+	}
+
 	triggerOpen();
 }
 #endif
@@ -117,9 +134,8 @@ void Track::incoming(message_ptr message) {
 	if (!message)
 		return;
 
-	auto direction = mMediaDescription.direction();
-	if ((direction == Description::Direction::SendOnly ||
-	     direction == Description::Direction::Inactive) &&
+	auto dir = direction();
+	if ((dir == Description::Direction::SendOnly || dir == Description::Direction::Inactive) &&
 	    message->type != Message::Control) {
 		COUNTER_MEDIA_BAD_DIRECTION++;
 		return;
@@ -142,17 +158,21 @@ void Track::incoming(message_ptr message) {
 }
 
 bool Track::outgoing([[maybe_unused]] message_ptr message) {
-#if RTC_ENABLE_MEDIA
-	auto transport = mDtlsSrtpTransport.lock();
-	if (!transport)
-		throw std::runtime_error("Track transport is not open");
-
-	// Set recommended medium-priority DSCP value
-	// See https://tools.ietf.org/html/draft-ietf-tsvwg-rtcweb-qos-18
-	if (mMediaDescription.type() == "audio")
-		message->dscp = 46; // EF: Expedited Forwarding
-	else
-		message->dscp = 36; // AF42: Assured Forwarding class 4, medium drop probability
+#if RTC_ENABLfiE_MEDIA
+	std::shared_ptr<DtlsSrtpTransport> transport;
+	{
+		std::shared_lock lock(mMutex);
+		transport = mDtlsSrtpTransport.lock();
+		if (!transport)
+			throw std::runtime_error("Track is closed");
+
+		// Set recommended medium-priority DSCP value
+		// See https://tools.ietf.org/html/draft-ietf-tsvwg-rtcweb-qos-18
+		if (mMediaDescription.type() == "audio")
+			message->dscp = 46; // EF: Expedited Forwarding
+		else
+			message->dscp = 36; // AF42: Assured Forwarding class 4, medium drop probability
+	}
 
 	return transport->sendMedia(message);
 #else
@@ -162,24 +182,23 @@ bool Track::outgoing([[maybe_unused]] message_ptr message) {
 }
 
 void Track::setRtcpHandler(std::shared_ptr<MediaHandler> handler) {
-	std::unique_lock lock(mRtcpHandlerMutex);
-	mRtcpHandler = std::move(handler);
-	if (mRtcpHandler) {
-		auto copy = mRtcpHandler;
-		lock.unlock();
-		copy->onOutgoing(std::bind(&Track::outgoing, this, std::placeholders::_1));
+	{
+		std::unique_lock lock(mMutex);
+		mRtcpHandler = handler;
 	}
+
+	handler->onOutgoing(std::bind(&Track::outgoing, this, std::placeholders::_1));
 }
 
 bool Track::requestKeyframe() {
-	if (auto handler = getRtcpHandler()) {
+	if (auto handler = getRtcpHandler())
 		return handler->requestKeyframe();
-	}
+
 	return false;
 }
 
 std::shared_ptr<MediaHandler> Track::getRtcpHandler() {
-	std::shared_lock lock(mRtcpHandlerMutex);
+	std::shared_lock lock(mMutex);
 	return mRtcpHandler;
 }
 

+ 2 - 2
test/connectivity.cpp

@@ -221,7 +221,7 @@ void test_connectivity() {
 	auto negotiated2 = pc2->createDataChannel("negoctated", init);
 
 	if (!negotiated1->isOpen() || !negotiated2->isOpen())
-		throw runtime_error("Negociated DataChannel is not open");
+		throw runtime_error("Negotiated DataChannel is not open");
 
 	std::atomic<bool> received = false;
 	negotiated2->onMessage([&received](const variant<binary, string> &message) {
@@ -239,7 +239,7 @@ void test_connectivity() {
 		this_thread::sleep_for(1s);
 
 	if (!received)
-		throw runtime_error("Negociated DataChannel failed");
+		throw runtime_error("Negotiated DataChannel failed");
 
 	// Delay close of peer 2 to check closing works properly
 	pc1->close();

+ 2 - 2
test/turn_connectivity.cpp

@@ -232,7 +232,7 @@ void test_turn_connectivity() {
 	auto negotiated2 = pc2->createDataChannel("negoctated", init);
 
 	if (!negotiated1->isOpen() || !negotiated2->isOpen())
-		throw runtime_error("Negociated DataChannel is not open");
+		throw runtime_error("Negotiated DataChannel is not open");
 
 	std::atomic<bool> received = false;
 	negotiated2->onMessage([&received](const variant<binary, string> &message) {
@@ -250,7 +250,7 @@ void test_turn_connectivity() {
 		this_thread::sleep_for(1s);
 
 	if (!received)
-		throw runtime_error("Negociated DataChannel failed");
+		throw runtime_error("Negotiated DataChannel failed");
 
 	// Delay close of peer 2 to check closing works properly
 	pc1->close();