Browse Source

Implemented reading back-pressure and callbacks synchronization

Paul-Louis Ageneau 5 years ago
parent
commit
900c482146
7 changed files with 144 additions and 51 deletions
  1. 9 2
      include/rtc/channel.hpp
  2. 7 2
      include/rtc/datachannel.hpp
  3. 4 0
      include/rtc/include.hpp
  4. 39 21
      include/rtc/queue.hpp
  5. 35 5
      src/channel.cpp
  6. 42 17
      src/datachannel.cpp
  7. 8 4
      src/sctptransport.cpp

+ 9 - 2
include/rtc/channel.hpp

@@ -22,6 +22,7 @@
 #include "include.hpp"
 
 #include <functional>
+#include <mutex>
 #include <variant>
 
 namespace rtc {
@@ -30,30 +31,36 @@ class Channel {
 public:
 	virtual void close(void) = 0;
 	virtual void send(const std::variant<binary, string> &data) = 0;
-
+	virtual std::optional<std::variant<binary, string>> receive() = 0;
 	virtual bool isOpen(void) const = 0;
 	virtual bool isClosed(void) const = 0;
 
 	void onOpen(std::function<void()> callback);
 	void onClosed(std::function<void()> callback);
 	void onError(std::function<void(const string &error)> callback);
+
 	void onMessage(std::function<void(const std::variant<binary, string> &data)> callback);
 	void onMessage(std::function<void(const binary &data)> binaryCallback,
 	               std::function<void(const string &data)> stringCallback);
 
+	void onAvailable(std::function<void()> callback);
+
 protected:
 	virtual void triggerOpen(void);
 	virtual void triggerClosed(void);
 	virtual void triggerError(const string &error);
-	virtual void triggerMessage(const std::variant<binary, string> &data);
+	virtual void triggerAvailable(size_t available);
 
 private:
 	std::function<void()> mOpenCallback;
 	std::function<void()> mClosedCallback;
 	std::function<void(const string &)> mErrorCallback;
 	std::function<void(const std::variant<binary, string> &)> mMessageCallback;
+	std::function<void()> mAvailableCallback;
+	std::recursive_mutex mCallbackMutex;
 };
 
 } // namespace rtc
 
 #endif // RTC_CHANNEL_H
+

+ 7 - 2
include/rtc/datachannel.hpp

@@ -22,8 +22,10 @@
 #include "channel.hpp"
 #include "include.hpp"
 #include "message.hpp"
+#include "queue.hpp"
 #include "reliability.hpp"
 
+#include <atomic>
 #include <chrono>
 #include <functional>
 #include <variant>
@@ -42,6 +44,7 @@ public:
 	void close(void);
 	void send(const std::variant<binary, string> &data);
 	void send(const byte *data, size_t size);
+	std::optional<std::variant<binary, string>> receive();
 
 	unsigned int stream() const;
 	string label() const;
@@ -62,8 +65,10 @@ private:
 	string mProtocol;
 	std::shared_ptr<Reliability> mReliability;
 
-	bool mIsOpen = false;
-	bool mIsClosed = false;
+	std::atomic<bool> mIsOpen = false;
+	std::atomic<bool> mIsClosed = false;
+
+	Queue<message_ptr> mRecvQueue;
 
 	friend class PeerConnection;
 };

+ 4 - 0
include/rtc/include.hpp

@@ -33,6 +33,7 @@ using binary = std::vector<byte>;
 
 using std::nullopt;
 
+using std::size_t;
 using std::uint16_t;
 using std::uint32_t;
 using std::uint64_t;
@@ -41,6 +42,9 @@ using std::uint8_t;
 const size_t MAX_NUMERICNODE_LEN = 48; // Max IPv6 string representation length
 const size_t MAX_NUMERICSERV_LEN = 6;  // Max port string representation length
 
+const size_t RECV_QUEUE_SIZE = 256; // DataChannel receive queue size in messages
+                                    // (0 means unlimited)
+
 const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default
 
 template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };

+ 39 - 21
src/queue.hpp → include/rtc/queue.hpp

@@ -32,32 +32,36 @@ namespace rtc {
 
 template <typename T> class Queue {
 public:
-	Queue();
+	Queue(std::size_t limit = 0);
 	~Queue();
 
 	void stop();
 	bool empty() const;
+	size_t size() const;
 	void push(const T &element);
 	std::optional<T> pop();
+	std::optional<T> tryPop();
 	void wait();
 	void wait(const std::chrono::milliseconds &duration);
 
 private:
+	const size_t mLimit;
 	std::queue<T> mQueue;
-	std::condition_variable mCondition;
-	std::atomic<bool> mStopping;
+	std::condition_variable mPopCondition, mPushCondition;
+	bool mStopping = false;
 
 	mutable std::mutex mMutex;
 };
 
-template <typename T> Queue<T>::Queue() : mStopping(false) {}
+template <typename T> Queue<T>::Queue(size_t limit) : mLimit(limit) {}
 
 template <typename T> Queue<T>::~Queue() { stop(); }
 
 template <typename T> void Queue<T>::stop() {
 	std::lock_guard<std::mutex> lock(mMutex);
 	mStopping = true;
-	mCondition.notify_all();
+	mPopCondition.notify_all();
+	mPushCondition.notify_all();
 }
 
 template <typename T> bool Queue<T>::empty() const {
@@ -65,37 +69,51 @@ template <typename T> bool Queue<T>::empty() const {
 	return mQueue.empty();
 }
 
-template <typename T> void Queue<T>::push(const T &element) {
+template <typename T> size_t Queue<T>::size() const {
 	std::lock_guard<std::mutex> lock(mMutex);
-	if (mStopping)
-		return;
-	mQueue.push(element);
-	mCondition.notify_one();
+	return mQueue.size();
+}
+
+template <typename T> void Queue<T>::push(const T &element) {
+	std::unique_lock<std::mutex> lock(mMutex);
+	mPushCondition.wait(lock, [this]() { return !mLimit || mQueue.size() < mLimit || mStopping; });
+	if (!mStopping) {
+		mQueue.push(element);
+		mPopCondition.notify_one();
+	}
 }
 
 template <typename T> std::optional<T> Queue<T>::pop() {
 	std::unique_lock<std::mutex> lock(mMutex);
-	while (mQueue.empty()) {
-		if (mStopping)
-			return nullopt;
-		mCondition.wait(lock);
+	mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
+	if (!mQueue.empty()) {
+		std::optional<T> element(std::move(mQueue.front()));
+		mQueue.pop();
+		return element;
+	} else {
+		return nullopt;
 	}
+}
 
-	std::optional<T> element = mQueue.front();
-	mQueue.pop();
-	return element;
+template <typename T> std::optional<T> Queue<T>::tryPop() {
+	std::unique_lock<std::mutex> lock(mMutex);
+	if (!mQueue.empty()) {
+		std::optional<T> element(std::move(mQueue.front()));
+		mQueue.pop();
+		return element;
+	} else {
+		return nullopt;
+	}
 }
 
 template <typename T> void Queue<T>::wait() {
 	std::unique_lock<std::mutex> lock(mMutex);
-	if (mQueue.empty() && !mStopping)
-		mCondition.wait(lock);
+	mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
 }
 
 template <typename T> void Queue<T>::wait(const std::chrono::milliseconds &duration) {
 	std::unique_lock<std::mutex> lock(mMutex);
-	if (mQueue.empty() && !mStopping)
-		mCondition.wait_for(lock, duration);
+	mPopCondition.wait_for(lock, duration, [this]() { return !mQueue.empty() || mStopping; });
 }
 
 } // namespace rtc

+ 35 - 5
src/channel.cpp

@@ -20,16 +20,29 @@
 
 namespace rtc {
 
-void Channel::onOpen(std::function<void()> callback) { mOpenCallback = callback; }
+void Channel::onOpen(std::function<void()> callback) {
+	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
+	mOpenCallback = callback;
+}
 
-void Channel::onClosed(std::function<void()> callback) { mClosedCallback = callback; }
+void Channel::onClosed(std::function<void()> callback) {
+	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
+	mClosedCallback = callback;
+}
 
 void Channel::onError(std::function<void(const string &error)> callback) {
+	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
 	mErrorCallback = callback;
 }
 
 void Channel::onMessage(std::function<void(const std::variant<binary, string> &data)> callback) {
+	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
 	mMessageCallback = callback;
+
+	// Pass pending messages
+	while (auto message = receive()) {
+		mMessageCallback(*message);
+	}
 }
 
 void Channel::onMessage(std::function<void(const binary &data)> binaryCallback,
@@ -39,24 +52,41 @@ void Channel::onMessage(std::function<void(const binary &data)> binaryCallback,
 	});
 }
 
+void Channel::onAvailable(std::function<void()> callback) {
+	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
+	mAvailableCallback = callback;
+}
+
 void Channel::triggerOpen(void) {
+	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
 	if (mOpenCallback)
 		mOpenCallback();
 }
 
 void Channel::triggerClosed(void) {
+	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
 	if (mClosedCallback)
 		mClosedCallback();
 }
 
 void Channel::triggerError(const string &error) {
+	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
 	if (mErrorCallback)
 		mErrorCallback(error);
 }
 
-void Channel::triggerMessage(const std::variant<binary, string> &data) {
-	if (mMessageCallback)
-		mMessageCallback(data);
+void Channel::triggerAvailable(size_t available) {
+	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
+	if (mAvailableCallback && available == 1) {
+		mAvailableCallback();
+	}
+	// The callback might be changed from itself
+	while (mMessageCallback && available--) {
+		auto message = receive();
+		if (!message)
+			break;
+		mMessageCallback(*message);
+	}
 }
 
 } // namespace rtc

+ 42 - 17
src/datachannel.cpp

@@ -60,7 +60,8 @@ struct CloseMessage {
 DataChannel::DataChannel(unsigned int stream, string label, string protocol,
                          Reliability reliability)
     : mStream(stream), mLabel(std::move(label)), mProtocol(std::move(protocol)),
-      mReliability(std::make_shared<Reliability>(std::move(reliability))) {}
+      mReliability(std::make_shared<Reliability>(std::move(reliability))),
+      mRecvQueue(RECV_QUEUE_SIZE) {}
 
 DataChannel::DataChannel(unsigned int stream, shared_ptr<SctpTransport> sctpTransport)
     : mStream(stream), mSctpTransport(sctpTransport),
@@ -70,8 +71,7 @@ DataChannel::~DataChannel() { close(); }
 
 void DataChannel::close() {
 	mIsOpen = false;
-	if (!mIsClosed) {
-		mIsClosed = true;
+	if (!mIsClosed.exchange(true)) {
 		if (mSctpTransport)
 			mSctpTransport->reset(mStream);
 	}
@@ -88,7 +88,8 @@ void DataChannel::send(const std::variant<binary, string> &data) {
 		    auto *b = reinterpret_cast<const byte *>(d.data());
 		    // Before the ACK has been received on a DataChannel, all messages must be sent ordered
 		    auto reliability = mIsOpen ? mReliability : nullptr;
-		    mSctpTransport->send(make_message(b, b + d.size(), type, mStream, reliability));
+		    auto message = make_message(b, b + d.size(), type, mStream, reliability);
+		    mSctpTransport->send(message);
 	    },
 	    data);
 }
@@ -98,7 +99,33 @@ void DataChannel::send(const byte *data, size_t size) {
 		return;
 
 	auto reliability = mIsOpen ? mReliability : nullptr;
-	mSctpTransport->send(make_message(data, data + size, Message::Binary, mStream, reliability));
+	auto message = make_message(data, data + size, Message::Binary, mStream, reliability);
+	mSctpTransport->send(message);
+}
+
+std::optional<std::variant<binary, string>> DataChannel::receive() {
+	while (auto opt = mRecvQueue.tryPop()) {
+		auto message = *opt;
+		switch (message->type) {
+		case Message::Control: {
+			auto raw = reinterpret_cast<const uint8_t *>(message->data());
+			if (raw[0] == MESSAGE_CLOSE) {
+				if (mIsOpen) {
+					close();
+					triggerClosed();
+				}
+			}
+			break;
+		}
+		case Message::String:
+			return std::make_optional(
+			    string(reinterpret_cast<const char *>(message->data()), message->size()));
+		case Message::Binary:
+			return std::make_optional(std::move(*message));
+		}
+	}
+
+	return nullopt;
 }
 
 unsigned int DataChannel::stream() const { return mStream; }
@@ -153,16 +180,14 @@ void DataChannel::incoming(message_ptr message) {
 			processOpenMessage(message);
 			break;
 		case MESSAGE_ACK:
-			if (!mIsOpen) {
-				mIsOpen = true;
+			if (!mIsOpen.exchange(true)) {
 				triggerOpen();
 			}
 			break;
 		case MESSAGE_CLOSE:
-			if (mIsOpen) {
-				close();
-				triggerClosed();
-			}
+			// The close message will be processed in-order in receive()
+			mRecvQueue.push(message);
+			triggerAvailable(mRecvQueue.size());
 			break;
 		default:
 			// Ignore
@@ -170,15 +195,15 @@ void DataChannel::incoming(message_ptr message) {
 		}
 		break;
 	}
-	case Message::String: {
-		triggerMessage(string(reinterpret_cast<const char *>(message->data()), message->size()));
+	case Message::String:
+	case Message::Binary:
+		mRecvQueue.push(message);
+		triggerAvailable(mRecvQueue.size());
 		break;
-	}
-	case Message::Binary: {
-		triggerMessage(*message);
+	default:
+		// Ignore
 		break;
 	}
-	}
 }
 
 void DataChannel::processOpenMessage(message_ptr message) {

+ 8 - 4
src/sctptransport.cpp

@@ -52,7 +52,7 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, me
     : Transport(lower), mPort(port), mState(State::Disconnected),
       mStateChangeCallback(std::move(stateChangeCallback)) {
 
-  onRecv(recv);
+	onRecv(recv);
 
 	GlobalInit();
 	usrsctp_register_address(this);
@@ -273,13 +273,15 @@ bool SctpTransport::doSend(message_ptr message) {
 		break;
 	}
 
+	ssize_t ret;
 	if (!message->empty()) {
-		return usrsctp_sendv(mSock, message->data(), message->size(), nullptr, 0, &spa, sizeof(spa),
-		                     SCTP_SENDV_SPA, 0) > 0;
+		ret = usrsctp_sendv(mSock, message->data(), message->size(), nullptr, 0, &spa, sizeof(spa),
+		                    SCTP_SENDV_SPA, 0);
 	} else {
 		const char zero = 0;
-		return usrsctp_sendv(mSock, &zero, 1, nullptr, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0) > 0;
+		ret = usrsctp_sendv(mSock, &zero, 1, nullptr, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0);
 	}
+	return ret > 0;
 }
 
 int SctpTransport::handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df) {
@@ -296,6 +298,8 @@ int SctpTransport::handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_
 
 int SctpTransport::process(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
                            struct sctp_rcvinfo info, int flags) {
+	if (!data)
+		recv(nullptr);
 	if (flags & MSG_NOTIFICATION) {
 		processNotification((union sctp_notification *)data, len);
 	} else {