Răsfoiți Sursa

Changed buffer amount low behavior to prevent deadlock situations

Paul-Louis Ageneau 5 ani în urmă
părinte
comite
1ab81731e3

+ 4 - 4
include/rtc/channel.hpp

@@ -21,8 +21,8 @@
 
 #include "include.hpp"
 
+#include <atomic>
 #include <functional>
-#include <mutex>
 #include <variant>
 
 namespace rtc {
@@ -30,7 +30,7 @@ namespace rtc {
 class Channel {
 public:
 	virtual void close() = 0;
-	virtual void send(const std::variant<binary, string> &data) = 0;
+	virtual bool send(const std::variant<binary, string> &data) = 0;
 	virtual std::optional<std::variant<binary, string>> receive() = 0;
 	virtual bool isOpen() const = 0;
 	virtual bool isClosed() const = 0;
@@ -66,8 +66,8 @@ private:
 	synchronized_callback<> mAvailableCallback;
 	synchronized_callback<> mBufferedAmountLowCallback;
 
-	size_t mBufferedAmount = 0;
-	size_t mBufferedAmountLowThreshold = 0;
+	std::atomic<size_t> mBufferedAmount = 0;
+	std::atomic<size_t> mBufferedAmountLowThreshold = 0;
 };
 
 } // namespace rtc

+ 9 - 9
include/rtc/datachannel.hpp

@@ -45,13 +45,13 @@ public:
 	~DataChannel();
 
 	void close(void);
-	void send(const std::variant<binary, string> &data);
-	void send(const byte *data, size_t size);
+	bool send(const std::variant<binary, string> &data);
+	bool send(const byte *data, size_t size);
 	std::optional<std::variant<binary, string>> receive();
 
 	// Directly send a buffer to avoid a copy
-	template <typename Buffer> void sendBuffer(const Buffer &buf);
-	template <typename Iterator> void sendBuffer(Iterator first, Iterator last);
+	template <typename Buffer> bool sendBuffer(const Buffer &buf);
+	template <typename Iterator> bool sendBuffer(Iterator first, Iterator last);
 
 	bool isOpen(void) const;
 	bool isClosed(void) const;
@@ -65,7 +65,7 @@ public:
 
 private:
 	void open(std::shared_ptr<SctpTransport> sctpTransport);
-	void outgoing(mutable_message_ptr message);
+	bool outgoing(mutable_message_ptr message);
 	void incoming(message_ptr message);
 	void processOpenMessage(message_ptr message);
 
@@ -93,14 +93,14 @@ template <typename Buffer> std::pair<const byte *, size_t> to_bytes(const Buffer
 	                      buf.size() * sizeof(E));
 }
 
-template <typename Buffer> void DataChannel::sendBuffer(const Buffer &buf) {
+template <typename Buffer> bool DataChannel::sendBuffer(const Buffer &buf) {
 	auto [bytes, size] = to_bytes(buf);
 	auto message = std::make_shared<Message>(size);
 	std::copy(bytes, bytes + size, message->data());
-	outgoing(message);
+	return outgoing(message);
 }
 
-template <typename Iterator> void DataChannel::sendBuffer(Iterator first, Iterator last) {
+template <typename Iterator> bool DataChannel::sendBuffer(Iterator first, Iterator last) {
 	size_t size = 0;
 	for (Iterator it = first; it != last; ++it)
 		size += it->size();
@@ -111,7 +111,7 @@ template <typename Iterator> void DataChannel::sendBuffer(Iterator first, Iterat
 		auto [bytes, size] = to_bytes(*it);
 		pos = std::copy(bytes, bytes + size, pos);
 	}
-	outgoing(message);
+	return outgoing(message);
 }
 
 } // namespace rtc

+ 3 - 4
src/channel.cpp

@@ -80,10 +80,9 @@ void Channel::triggerAvailable(size_t count) {
 }
 
 void Channel::triggerBufferedAmount(size_t amount) {
-	bool lowThresholdCrossed =
-	    mBufferedAmount > mBufferedAmountLowThreshold && amount <= mBufferedAmountLowThreshold;
-	mBufferedAmount = amount;
-	if (lowThresholdCrossed)
+	size_t previous = mBufferedAmount.exchange(amount);
+	size_t threshold = mBufferedAmountLowThreshold.load();
+	if (previous > threshold && amount <= threshold)
 		mBufferedAmountLowCallback();
 }
 

+ 7 - 7
src/datachannel.cpp

@@ -83,19 +83,19 @@ void DataChannel::close() {
 	}
 }
 
-void DataChannel::send(const std::variant<binary, string> &data) {
-	std::visit(
+bool DataChannel::send(const std::variant<binary, string> &data) {
+	return std::visit(
 	    [&](const auto &d) {
 		    using T = std::decay_t<decltype(d)>;
 		    constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
 		    auto *b = reinterpret_cast<const byte *>(d.data());
-		    outgoing(std::make_shared<Message>(b, b + d.size(), type));
+		    return outgoing(std::make_shared<Message>(b, b + d.size(), type));
 	    },
 	    data);
 }
 
-void DataChannel::send(const byte *data, size_t size) {
-	outgoing(std::make_shared<Message>(data, data + size, Message::Binary));
+bool DataChannel::send(const byte *data, size_t size) {
+	return outgoing(std::make_shared<Message>(data, data + size, Message::Binary));
 }
 
 std::optional<std::variant<binary, string>> DataChannel::receive() {
@@ -177,7 +177,7 @@ void DataChannel::open(shared_ptr<SctpTransport> sctpTransport) {
 	mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
 }
 
-void DataChannel::outgoing(mutable_message_ptr message) {
+bool DataChannel::outgoing(mutable_message_ptr message) {
 	if (mIsClosed || !mSctpTransport)
 		throw std::runtime_error("DataChannel is closed");
 
@@ -187,7 +187,7 @@ void DataChannel::outgoing(mutable_message_ptr message) {
 	// Before the ACK has been received on a DataChannel, all messages must be sent ordered
 	message->reliability = mIsOpen ? mReliability : nullptr;
 	message->stream = mStream;
-	mSctpTransport->send(message);
+	return mSctpTransport->send(message);
 }
 
 void DataChannel::incoming(message_ptr message) {

+ 25 - 28
src/sctptransport.cpp

@@ -188,13 +188,18 @@ void SctpTransport::connect() {
 SctpTransport::State SctpTransport::state() const { return mState; }
 
 bool SctpTransport::send(message_ptr message) {
+	std::lock_guard<std::mutex> lock(mSendMutex);
+
 	if (!message)
-		return false;
+		return mSendQueue.empty();
+
+	// If nothing is pending, try to send directly
+	if (mSendQueue.empty() && trySendMessage(message))
+		return true;
 
-	updateBufferedAmount(message->stream, message->size());
 	mSendQueue.push(message);
-	trySendAll();
-	return true;
+	updateBufferedAmount(message->stream, message_size_func(message));
+	return false;
 }
 
 void SctpTransport::reset(unsigned int stream) {
@@ -231,25 +236,21 @@ void SctpTransport::changeState(State state) {
 		mStateChangeCallback(state);
 }
 
-bool SctpTransport::trySendAll() {
-	std::unique_lock<std::mutex> lock(mSendMutex, std::try_to_lock);
-	if (!lock.owns_lock())
-		return false;
-
+bool SctpTransport::trySendQueue() {
+	// Requires mSendMutex to be locked
 	while (auto next = mSendQueue.peek()) {
 		auto message = *next;
-		if (!trySend(message))
+		if (!trySendMessage(message))
 			return false;
-		updateBufferedAmount(message->stream, -message->size());
 		mSendQueue.pop();
+		updateBufferedAmount(message->stream, -message_size_func(message));
 	}
 	return true;
 }
 
-bool SctpTransport::trySend(message_ptr message) {
-	if (!message)
-		return false;
-
+bool SctpTransport::trySendMessage(message_ptr message) {
+	// Requires mSendMutex to be locked
+	//
 	// TODO: Implement SCTP ndata specification draft when supported everywhere
 	// See https://tools.ietf.org/html/draft-ietf-tsvwg-sctp-ndata-08
 
@@ -316,19 +317,13 @@ bool SctpTransport::trySend(message_ptr message) {
 }
 
 void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) {
-	if (delta == 0)
-		return;
-	std::lock_guard<std::mutex> lock(mBufferedAmountMutex);
+	// Requires mSendMutex to be locked
 	auto it = mBufferedAmount.insert(std::make_pair(streamId, 0)).first;
-	if (delta > 0)
-		it->second += size_t(delta);
-	else if (it->second > size_t(-delta))
-		it->second -= size_t(-delta);
-	else
-		it->second = 0;
-	mBufferedAmountCallback(streamId, it->second);
-	if (it->second == 0)
+	size_t amount = it->second;
+	amount = size_t(std::max(long(amount) + delta, long(0)));
+	if (amount == 0)
 		mBufferedAmount.erase(it);
+	mBufferedAmountCallback(streamId, amount);
 }
 
 int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data,
@@ -364,7 +359,8 @@ int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, co
 
 int SctpTransport::handleSend(size_t free) {
 	try {
-		trySendAll();
+		std::lock_guard<std::mutex> lock(mSendMutex);
+		trySendQueue();
 	} catch (const std::exception &e) {
 		std::cerr << "SCTP send: " << e.what() << std::endl;
 		return -1;
@@ -470,7 +466,8 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 	case SCTP_SENDER_DRY_EVENT: {
 		// It not should be necessary since the send callback should have been called already,
 		// but to be sure, let's try to send now.
-		trySendAll();
+		std::lock_guard<std::mutex> lock(mSendMutex);
+		trySendQueue();
 	}
 	case SCTP_STREAM_RESET_EVENT: {
 		const struct sctp_stream_reset_event &reset_event = notify->sn_strreset_event;

+ 3 - 4
src/sctptransport.hpp

@@ -69,8 +69,9 @@ private:
 	void connect();
 	void incoming(message_ptr message);
 	void changeState(State state);
-	bool trySendAll();
-	bool trySend(message_ptr message);
+
+	bool trySendQueue();
+	bool trySendMessage(message_ptr message);
 	void updateBufferedAmount(uint16_t streamId, long delta);
 
 	int handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data, size_t len,
@@ -86,8 +87,6 @@ private:
 
 	std::mutex mSendMutex;
 	Queue<message_ptr> mSendQueue;
-
-	std::mutex mBufferedAmountMutex;
 	std::map<uint16_t, size_t> mBufferedAmount;
 	amount_callback mBufferedAmountCallback;