Browse Source

Changed sent callback to more generic bufferAmountLow

Paul-Louis Ageneau 5 years ago
parent
commit
89ff113688

+ 12 - 4
include/rtc/channel.hpp

@@ -34,6 +34,9 @@ public:
 	virtual std::optional<std::variant<binary, string>> receive() = 0;
 	virtual bool isOpen(void) const = 0;
 	virtual bool isClosed(void) const = 0;
+	virtual size_t availableAmount() const { return 0; }
+
+	size_t bufferedAmount() const;
 
 	void onOpen(std::function<void()> callback);
 	void onClosed(std::function<void()> callback);
@@ -44,14 +47,16 @@ public:
 	               std::function<void(const string &data)> stringCallback);
 
 	void onAvailable(std::function<void()> callback);
-	void onSent(std::function<void()> callback);
+	void onBufferedAmountLow(std::function<void()> callback);
+
+	void setBufferedAmountLowThreshold(size_t amount);
 
 protected:
 	virtual void triggerOpen(void);
 	virtual void triggerClosed(void);
 	virtual void triggerError(const string &error);
-	virtual void triggerAvailable(size_t available);
-	virtual void triggerSent();
+	virtual void triggerAvailable(size_t count);
+	virtual void triggerBufferedAmount(size_t amount);
 
 private:
 	synchronized_callback<> mOpenCallback;
@@ -59,7 +64,10 @@ private:
 	synchronized_callback<const string &> mErrorCallback;
 	synchronized_callback<const std::variant<binary, string> &> mMessageCallback;
 	synchronized_callback<> mAvailableCallback;
-	synchronized_callback<> mSentCallback;
+	synchronized_callback<> mBufferedAmountLowCallback;
+
+	size_t mBufferedAmount = 0;
+	size_t mBufferedAmountLowThreshold = 0;
 };
 
 } // namespace rtc

+ 5 - 6
include/rtc/datachannel.hpp

@@ -49,20 +49,19 @@ public:
 	void send(const byte *data, size_t size);
 	std::optional<std::variant<binary, string>> receive();
 
+	// Directly send a buffer to avoid a copy
 	template <typename Buffer> void sendBuffer(const Buffer &buf);
 	template <typename Iterator> void sendBuffer(Iterator first, Iterator last);
 
-	size_t available() const;
-	size_t availableSize() const;
+	bool isOpen(void) const;
+	bool isClosed(void) const;
+	size_t availableAmount() const;
 
 	unsigned int stream() const;
 	string label() const;
 	string protocol() const;
 	Reliability reliability() const;
 
-	bool isOpen(void) const;
-	bool isClosed(void) const;
-
 private:
 	void open(std::shared_ptr<SctpTransport> sctpTransport);
 	void outgoing(mutable_message_ptr message);
@@ -81,7 +80,7 @@ private:
 	std::atomic<bool> mIsClosed = false;
 
 	Queue<message_ptr> mRecvQueue;
-	std::atomic<size_t> mRecvSize = 0;
+	std::atomic<size_t> mRecvAmount = 0;
 
 	friend class PeerConnection;
 };

+ 3 - 0
include/rtc/message.hpp

@@ -44,6 +44,9 @@ struct Message : binary {
 using message_ptr = std::shared_ptr<const Message>;
 using mutable_message_ptr = std::shared_ptr<Message>;
 using message_callback = std::function<void(message_ptr message)>;
+constexpr auto message_size_func = [](const message_ptr &m) -> size_t {
+	return m->type != Message::Control ? m->size() : 0;
+};
 
 template <typename Iterator>
 message_ptr make_message(Iterator begin, Iterator end, Message::Type type = Message::Binary,

+ 2 - 1
include/rtc/peerconnection.hpp

@@ -89,7 +89,8 @@ private:
 
 	bool checkFingerprint(std::weak_ptr<PeerConnection> weak_this, const std::string &fingerprint) const;
 	void forwardMessage(std::weak_ptr<PeerConnection> weak_this, message_ptr message);
-	void forwardSent(std::weak_ptr<PeerConnection> weak_this, uint16_t stream);
+	void forwardBufferedAmount(std::weak_ptr<PeerConnection> weak_this, uint16_t stream,
+	                           size_t amount);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void openDataChannels();
 	void closeDataChannels();

+ 19 - 3
include/rtc/queue.hpp

@@ -32,12 +32,16 @@ namespace rtc {
 
 template <typename T> class Queue {
 public:
-	Queue(std::size_t limit = 0);
+	using amount_function = std::function<size_t(const T &element)>;
+
+	Queue(
+	    size_t limit = 0, amount_function func = [](const T &element) -> size_t { return 1; });
 	~Queue();
 
 	void stop();
 	bool empty() const;
-	size_t size() const;
+	size_t size() const;   // elements
+	size_t amount() const; // amount
 	void push(const T &element);
 	void push(T &&element);
 	std::optional<T> pop();
@@ -47,14 +51,18 @@ public:
 
 private:
 	const size_t mLimit;
+	size_t mAmount;
 	std::queue<T> mQueue;
 	std::condition_variable mPopCondition, mPushCondition;
+	amount_function mAmountFunction;
 	bool mStopping = false;
 
 	mutable std::mutex mMutex;
 };
 
-template <typename T> Queue<T>::Queue(size_t limit) : mLimit(limit) {}
+template <typename T>
+Queue<T>::Queue(size_t limit, amount_function func)
+    : mLimit(limit), mAmount(0), mAmountFunction(func) {}
 
 template <typename T> Queue<T>::~Queue() { stop(); }
 
@@ -75,12 +83,18 @@ template <typename T> size_t Queue<T>::size() const {
 	return mQueue.size();
 }
 
+template <typename T> size_t Queue<T>::amount() const {
+	std::lock_guard<std::mutex> lock(mMutex);
+	return mAmount;
+}
+
 template <typename T> void Queue<T>::push(const T &element) { push(T{element}); }
 
 template <typename T> void Queue<T>::push(T &&element) {
 	std::unique_lock<std::mutex> lock(mMutex);
 	mPushCondition.wait(lock, [this]() { return !mLimit || mQueue.size() < mLimit || mStopping; });
 	if (!mStopping) {
+		mAmount += mAmountFunction(element);
 		mQueue.emplace(std::move(element));
 		mPopCondition.notify_one();
 	}
@@ -90,6 +104,7 @@ template <typename T> std::optional<T> Queue<T>::pop() {
 	std::unique_lock<std::mutex> lock(mMutex);
 	mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
 	if (!mQueue.empty()) {
+		mAmount -= mAmountFunction(mQueue.front());
 		std::optional<T> element(std::move(mQueue.front()));
 		mQueue.pop();
 		return element;
@@ -101,6 +116,7 @@ template <typename T> std::optional<T> Queue<T>::pop() {
 template <typename T> std::optional<T> Queue<T>::tryPop() {
 	std::unique_lock<std::mutex> lock(mMutex);
 	if (!mQueue.empty()) {
+		mAmount -= mAmountFunction(mQueue.front());
 		std::optional<T> element(std::move(mQueue.front()));
 		mQueue.pop();
 		return element;

+ 16 - 6
src/channel.cpp

@@ -53,21 +53,25 @@ void Channel::onAvailable(std::function<void()> callback) {
 	mAvailableCallback = callback;
 }
 
-void Channel::onSent(std::function<void()> callback) {
-	mSentCallback = callback;
+void Channel::onBufferedAmountLow(std::function<void()> callback) {
+	mBufferedAmountLowCallback = callback;
 }
 
+size_t Channel::bufferedAmount() const { return mBufferedAmount; }
+
+void Channel::setBufferedAmountLowThreshold(size_t amount) { mBufferedAmountLowThreshold = amount; }
+
 void Channel::triggerOpen() { mOpenCallback(); }
 
 void Channel::triggerClosed() { mClosedCallback(); }
 
 void Channel::triggerError(const string &error) { mErrorCallback(error); }
 
-void Channel::triggerAvailable(size_t available) {
-	if (available == 1)
+void Channel::triggerAvailable(size_t count) {
+	if (count == 1)
 		mAvailableCallback();
 
-	while (mMessageCallback && available--) {
+	while (mMessageCallback && count--) {
 		auto message = receive();
 		if (!message)
 			break;
@@ -75,7 +79,13 @@ void Channel::triggerAvailable(size_t available) {
 	}
 }
 
-void Channel::triggerSent() { mSentCallback(); }
+void Channel::triggerBufferedAmount(size_t amount) {
+	bool lowThresholdCrossed =
+	    mBufferedAmount > mBufferedAmountLowThreshold && amount <= mBufferedAmountLowThreshold;
+	mBufferedAmount = amount;
+	if (lowThresholdCrossed)
+		mBufferedAmountLowCallback();
+}
 
 } // namespace rtc
 

+ 7 - 11
src/datachannel.cpp

@@ -62,12 +62,13 @@ DataChannel::DataChannel(shared_ptr<PeerConnection> pc, unsigned int stream, str
     : mPeerConnection(std::move(pc)), mStream(stream), mLabel(std::move(label)),
       mProtocol(std::move(protocol)),
       mReliability(std::make_shared<Reliability>(std::move(reliability))),
-      mRecvQueue(RECV_QUEUE_SIZE) {}
+      mRecvQueue(RECV_QUEUE_SIZE, message_size_func) {}
 
 DataChannel::DataChannel(shared_ptr<PeerConnection> pc, shared_ptr<SctpTransport> transport,
                          unsigned int stream)
     : mPeerConnection(std::move(pc)), mSctpTransport(transport), mStream(stream),
-      mReliability(std::make_shared<Reliability>()) {}
+      mReliability(std::make_shared<Reliability>()),
+      mRecvQueue(RECV_QUEUE_SIZE, message_size_func) {}
 
 DataChannel::~DataChannel() { close(); }
 
@@ -109,11 +110,9 @@ std::optional<std::variant<binary, string>> DataChannel::receive() {
 			break;
 		}
 		case Message::String:
-			mRecvSize -= message->size();
 			return std::make_optional(
 			    string(reinterpret_cast<const char *>(message->data()), message->size()));
 		case Message::Binary:
-			mRecvSize -= message->size();
 			return std::make_optional(std::move(*message));
 		}
 	}
@@ -121,9 +120,11 @@ std::optional<std::variant<binary, string>> DataChannel::receive() {
 	return nullopt;
 }
 
-size_t DataChannel::available() const { return mRecvQueue.size(); }
+bool DataChannel::isOpen(void) const { return mIsOpen; }
+
+bool DataChannel::isClosed(void) const { return mIsClosed; }
 
-size_t DataChannel::availableSize() const { return mRecvSize; }
+size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); }
 
 unsigned int DataChannel::stream() const { return mStream; }
 
@@ -133,10 +134,6 @@ string DataChannel::protocol() const { return mProtocol; }
 
 Reliability DataChannel::reliability() const { return *mReliability; }
 
-bool DataChannel::isOpen(void) const { return mIsOpen; }
-
-bool DataChannel::isClosed(void) const { return mIsClosed; }
-
 void DataChannel::open(shared_ptr<SctpTransport> sctpTransport) {
 	mSctpTransport = sctpTransport;
 
@@ -203,7 +200,6 @@ void DataChannel::incoming(message_ptr message) {
 	}
 	case Message::String:
 	case Message::Binary:
-		mRecvSize += message->size();
 		mRecvQueue.push(message);
 		triggerAvailable(mRecvQueue.size());
 		break;

+ 5 - 4
src/peerconnection.cpp

@@ -235,8 +235,8 @@ void PeerConnection::initSctpTransport() {
 	    mDtlsTransport, sctpPort,
 	    std::bind(&PeerConnection::forwardMessage, this,
 	              weak_ptr<PeerConnection>{shared_from_this()}, _1),
-	    std::bind(&PeerConnection::forwardSent, this, weak_ptr<PeerConnection>{shared_from_this()},
-	              _1),
+	    std::bind(&PeerConnection::forwardBufferedAmount, this,
+	              weak_ptr<PeerConnection>{shared_from_this()}, _1, _2),
 	    [this,
 	     weak_this = weak_ptr<PeerConnection>{shared_from_this()}](SctpTransport::State state) {
 		    auto strong_this = weak_this.lock();
@@ -312,7 +312,8 @@ void PeerConnection::forwardMessage(weak_ptr<PeerConnection> weak_this, message_
 	channel->incoming(message);
 }
 
-void PeerConnection::forwardSent(weak_ptr<PeerConnection> weak_this, uint16_t stream) {
+void PeerConnection::forwardBufferedAmount(weak_ptr<PeerConnection> weak_this, uint16_t stream,
+                                           size_t amount) {
 	auto strong_this = weak_this.lock();
 	if (!strong_this)
 		return;
@@ -327,7 +328,7 @@ void PeerConnection::forwardSent(weak_ptr<PeerConnection> weak_this, uint16_t st
 	}
 
 	if (channel)
-		channel->triggerSent();
+		channel->triggerBufferedAmount(amount);
 }
 
 void PeerConnection::iterateDataChannels(

+ 26 - 18
src/sctptransport.cpp

@@ -47,12 +47,13 @@ void SctpTransport::GlobalCleanup() {
 	}
 }
 
-SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recv,
-                             sent_callback sentCallback, state_callback stateChangeCallback)
-    : Transport(lower), mPort(port), mSentCallback(std::move(sentCallback)),
-      mStateChangeCallback(std::move(stateChangeCallback)) {
-	mState = State::Disconnected;
-	onRecv(recv);
+SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
+                             message_callback recvCallback, amount_callback bufferedAmountCallback,
+                             state_callback stateChangeCallback)
+    : Transport(lower), mPort(port), mSendQueue(0, message_size_func),
+      mBufferedAmountCallback(std::move(bufferedAmountCallback)),
+      mStateChangeCallback(std::move(stateChangeCallback)), mState(State::Disconnected) {
+	onRecv(recvCallback);
 
 	GlobalInit();
 
@@ -138,7 +139,7 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, me
 		throw std::runtime_error("Could not bind usrsctp socket, errno=" + std::to_string(errno));
 
 	mSendThread = std::thread(&SctpTransport::runConnectAndSendLoop, this);
-		}
+}
 
 SctpTransport::~SctpTransport() {
 	onRecv(nullptr); // unset recv callback
@@ -164,7 +165,7 @@ bool SctpTransport::send(message_ptr message) {
 	if (!message || mStopping)
 		return false;
 
-	updateSendCount(message->stream, 1);
+	updateBufferedAmount(message->stream, message->size());
 	mSendQueue.push(message);
 	return true;
 }
@@ -237,8 +238,9 @@ void SctpTransport::runConnectAndSendLoop() {
 	try {
 		while (auto next = mSendQueue.pop()) {
 			auto message = *next;
-			updateSendCount(message->stream, -1);
-			if (!doSend(message))
+			bool success = doSend(message);
+			updateBufferedAmount(message->stream, -message->size());
+			if (!success)
 				throw std::runtime_error("Sending failed, errno=" + std::to_string(errno));
 		}
 	} catch (const std::exception &e) {
@@ -310,14 +312,20 @@ bool SctpTransport::doSend(message_ptr message) {
 	return ret > 0;
 }
 
-void SctpTransport::updateSendCount(uint16_t streamId, int delta) {
-	std::lock_guard<std::mutex> lock(mSendCountMutex);
-	auto it = mSendCount.insert(std::make_pair(streamId, 0)).first;
-	it->second = std::max(it->second + delta, 0);
-	if (it->second == 0) {
-		mSendCount.erase(it);
-		mSentCallback(streamId);
-	}
+void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) {
+	if (delta == 0)
+		return;
+	std::lock_guard<std::mutex> lock(mBufferedAmountMutex);
+	auto it = mBufferedAmount.insert(std::make_pair(streamId, 0)).first;
+	if (delta > 0)
+		it->second += size_t(delta);
+	else if (it->second > size_t(-delta))
+		it->second -= size_t(-delta);
+	else
+		it->second = 0;
+	mBufferedAmountCallback(streamId, it->second);
+	if (it->second == 0)
+		mBufferedAmount.erase(it);
 }
 
 int SctpTransport::handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df) {

+ 8 - 8
src/sctptransport.hpp

@@ -41,11 +41,11 @@ class SctpTransport : public Transport {
 public:
 	enum class State { Disconnected, Connecting, Connected, Failed };
 
-	using sent_callback = std::function<void(uint16_t streamId)>;
+	using amount_callback = std::function<void(uint16_t streamId, size_t amount)>;
 	using state_callback = std::function<void(State state)>;
 
-	SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recv,
-	              sent_callback sent, state_callback stateChangeCallback);
+	SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recvCallback,
+	              amount_callback bufferedAmountCallback, state_callback stateChangeCallback);
 	~SctpTransport();
 
 	State state() const;
@@ -66,7 +66,7 @@ private:
 	void changeState(State state);
 	void runConnectAndSendLoop();
 	bool doSend(message_ptr message);
-	void updateSendCount(uint16_t streamId, int delta);
+	void updateBufferedAmount(uint16_t streamId, long delta);
 
 	int handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df);
 
@@ -81,17 +81,17 @@ private:
 
 	Queue<message_ptr> mSendQueue;
 	std::thread mSendThread;
-	std::map<uint16_t, int> mSendCount;
-	std::mutex mSendCountMutex;
-	sent_callback mSentCallback;
+	std::map<uint16_t, size_t> mBufferedAmount;
+	std::mutex mBufferedAmountMutex;
+	amount_callback mBufferedAmountCallback;
 
 	std::mutex mConnectMutex;
 	std::condition_variable mConnectCondition;
 	std::atomic<bool> mConnectDataSent = false;
 	std::atomic<bool> mStopping = false;
 
-	std::atomic<State> mState;
 	state_callback mStateChangeCallback;
+	std::atomic<State> mState;
 
 	static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df);
 	static int ReadCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,