Przeglądaj źródła

Added sent callback on DataChannel

Paul-Louis Ageneau 5 lat temu
rodzic
commit
d9bfcbd6be

+ 9 - 1
include/rtc/channel.hpp

@@ -44,12 +44,14 @@ public:
 	               std::function<void(const string &data)> stringCallback);
 
 	void onAvailable(std::function<void()> callback);
+	void onSent(std::function<void()> callback);
 
 protected:
 	virtual void triggerOpen(void);
 	virtual void triggerClosed(void);
 	virtual void triggerError(const string &error);
 	virtual void triggerAvailable(size_t available);
+	virtual void triggerSent();
 
 private:
 	std::function<void()> mOpenCallback;
@@ -57,7 +59,13 @@ private:
 	std::function<void(const string &)> mErrorCallback;
 	std::function<void(const std::variant<binary, string> &)> mMessageCallback;
 	std::function<void()> mAvailableCallback;
-	std::recursive_mutex mCallbackMutex;
+	std::function<void()> mSentCallback;
+	std::mutex mCallbackMutex;
+
+	template <typename T> T getCallback(const T &callback) {
+		std::lock_guard<std::mutex> lock(mCallbackMutex);
+		return callback;
+	}
 };
 
 } // namespace rtc

+ 1 - 0
include/rtc/peerconnection.hpp

@@ -89,6 +89,7 @@ private:
 
 	bool checkFingerprint(std::weak_ptr<PeerConnection> weak_this, const std::string &fingerprint) const;
 	void forwardMessage(std::weak_ptr<PeerConnection> weak_this, message_ptr message);
+	void forwardSent(std::weak_ptr<PeerConnection> weak_this, uint16_t stream);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void openDataChannels();
 	void closeDataChannels();

+ 35 - 23
src/channel.cpp

@@ -18,30 +18,34 @@
 
 #include "channel.hpp"
 
+namespace {}
+
 namespace rtc {
 
 void Channel::onOpen(std::function<void()> callback) {
-	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
+	std::lock_guard<std::mutex> lock(mCallbackMutex);
 	mOpenCallback = callback;
 }
 
 void Channel::onClosed(std::function<void()> callback) {
-	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
+	std::lock_guard<std::mutex> lock(mCallbackMutex);
 	mClosedCallback = callback;
 }
 
 void Channel::onError(std::function<void(const string &error)> callback) {
-	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
+	std::lock_guard<std::mutex> lock(mCallbackMutex);
 	mErrorCallback = callback;
 }
 
 void Channel::onMessage(std::function<void(const std::variant<binary, string> &data)> callback) {
-	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
+	std::lock_guard<std::mutex> lock(mCallbackMutex);
 	mMessageCallback = callback;
 
 	// Pass pending messages
 	while (auto message = receive()) {
-		mMessageCallback(*message);
+		// The callback might be changed from itself
+		if (auto callback = getCallback(mMessageCallback))
+			callback(*message);
 	}
 }
 
@@ -53,41 +57,49 @@ void Channel::onMessage(std::function<void(const binary &data)> binaryCallback,
 }
 
 void Channel::onAvailable(std::function<void()> callback) {
-	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
+	std::lock_guard<std::mutex> lock(mCallbackMutex);
 	mAvailableCallback = callback;
 }
 
-void Channel::triggerOpen(void) {
-	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
-	if (mOpenCallback)
-		mOpenCallback();
+void Channel::onSent(std::function<void()> callback) {
+	std::lock_guard<std::mutex> lock(mCallbackMutex);
+	mSentCallback = callback;
+}
+
+void Channel::triggerOpen() {
+	if (auto callback = getCallback(mOpenCallback))
+		callback();
 }
 
-void Channel::triggerClosed(void) {
-	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
-	if (mClosedCallback)
-		mClosedCallback();
+void Channel::triggerClosed() {
+	if (auto callback = getCallback(mClosedCallback))
+		callback();
 }
 
 void Channel::triggerError(const string &error) {
-	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
-	if (mErrorCallback)
-		mErrorCallback(error);
+	if (auto callback = getCallback(mErrorCallback))
+		callback(error);
 }
 
 void Channel::triggerAvailable(size_t available) {
-	std::lock_guard<std::recursive_mutex> lock(mCallbackMutex);
-	if (mAvailableCallback && available == 1) {
-		mAvailableCallback();
+	if (available == 1) {
+		if (auto callback = getCallback(mAvailableCallback))
+			callback();
 	}
-	// The callback might be changed from itself
-	while (mMessageCallback && available--) {
+	while (available--) {
 		auto message = receive();
 		if (!message)
 			break;
-		mMessageCallback(*message);
+		// The callback might be changed from itself
+		if (auto callback = getCallback(mMessageCallback))
+			callback(*message);
 	}
 }
 
+void Channel::triggerSent() {
+	if (auto callback = getCallback(mSentCallback))
+		callback();
+}
+
 } // namespace rtc
 

+ 28 - 4
src/peerconnection.cpp

@@ -232,10 +232,16 @@ void PeerConnection::initDtlsTransport() {
 void PeerConnection::initSctpTransport() {
 	uint16_t sctpPort = mRemoteDescription->sctpPort().value_or(DEFAULT_SCTP_PORT);
 	mSctpTransport = std::make_shared<SctpTransport>(
-	    mDtlsTransport, sctpPort, std::bind(&PeerConnection::forwardMessage, this, weak_ptr<PeerConnection>{shared_from_this()}, _1),
-	    [this, weak_this = weak_ptr<PeerConnection>{shared_from_this()}](SctpTransport::State state) {
-        auto strong_this = weak_this.lock();
-        if (!strong_this) return;
+	    mDtlsTransport, sctpPort,
+	    std::bind(&PeerConnection::forwardMessage, this,
+	              weak_ptr<PeerConnection>{shared_from_this()}, _1),
+	    std::bind(&PeerConnection::forwardSent, this, weak_ptr<PeerConnection>{shared_from_this()},
+	              _1),
+	    [this,
+	     weak_this = weak_ptr<PeerConnection>{shared_from_this()}](SctpTransport::State state) {
+		    auto strong_this = weak_this.lock();
+		    if (!strong_this)
+			    return;
 
 		    switch (state) {
 		    case SctpTransport::State::Connected:
@@ -305,6 +311,24 @@ void PeerConnection::forwardMessage(weak_ptr<PeerConnection> weak_this, message_
 	channel->incoming(message);
 }
 
+void PeerConnection::forwardSent(weak_ptr<PeerConnection> weak_this, uint16_t stream) {
+	auto strong_this = weak_this.lock();
+	if (!strong_this)
+		return;
+
+	shared_ptr<DataChannel> channel;
+	if (auto it = mDataChannels.find(stream); it != mDataChannels.end()) {
+		channel = it->second.lock();
+		if (!channel || channel->isClosed()) {
+			mDataChannels.erase(it);
+			channel = nullptr;
+		}
+	}
+
+	if (channel)
+		channel->triggerSent();
+}
+
 void PeerConnection::iterateDataChannels(
     std::function<void(shared_ptr<DataChannel> channel)> func) {
 	auto it = mDataChannels.begin();

+ 17 - 4
src/sctptransport.cpp

@@ -48,9 +48,9 @@ void SctpTransport::GlobalCleanup() {
 }
 
 SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recv,
-                             state_callback stateChangeCallback)
+                             sent_callback sentCallback, state_callback stateChangeCallback)
     : Transport(lower), mPort(port), mState(State::Disconnected),
-      mStateChangeCallback(std::move(stateChangeCallback)) {
+      mSentCallback(std::move(sentCallback)), mStateChangeCallback(std::move(stateChangeCallback)) {
 
 	onRecv(recv);
 
@@ -142,6 +142,7 @@ bool SctpTransport::send(message_ptr message) {
 	if (!message || mStopping)
 		return false;
 
+	updateSendCount(message->stream, 1);
 	mSendQueue.push(message);
 	return true;
 }
@@ -212,8 +213,10 @@ void SctpTransport::runConnectAndSendLoop() {
 	}
 
 	try {
-		while (auto message = mSendQueue.pop()) {
-			if (!doSend(*message))
+		while (auto next = mSendQueue.pop()) {
+			auto message = *next;
+			updateSendCount(message->stream, -1);
+			if (!doSend(message))
 				throw std::runtime_error("Sending failed, errno=" + std::to_string(errno));
 		}
 	} catch (const std::exception &e) {
@@ -285,6 +288,16 @@ bool SctpTransport::doSend(message_ptr message) {
 	return ret > 0;
 }
 
+void SctpTransport::updateSendCount(uint16_t streamId, int delta) {
+	std::lock_guard<std::mutex> lock(mSendCountMutex);
+	auto it = mSendCount.insert(std::make_pair(streamId, 0)).first;
+	it->second = std::max(it->second + delta, 0);
+	if (it->second == 0) {
+		mSendCount.erase(it);
+		mSentCallback(streamId);
+	}
+}
+
 int SctpTransport::handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df) {
 	byte *b = reinterpret_cast<byte *>(data);
 	outgoing(make_message(b, b + len));

+ 8 - 1
src/sctptransport.hpp

@@ -26,6 +26,7 @@
 
 #include <condition_variable>
 #include <functional>
+#include <map>
 #include <mutex>
 #include <thread>
 
@@ -40,10 +41,11 @@ class SctpTransport : public Transport {
 public:
 	enum class State { Disconnected, Connecting, Connected, Failed };
 
+	using sent_callback = std::function<void(uint16_t streamId)>;
 	using state_callback = std::function<void(State state)>;
 
 	SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recv,
-	              state_callback stateChangeCallback);
+	              sent_callback sent, state_callback stateChangeCallback);
 	~SctpTransport();
 
 	State state() const;
@@ -64,6 +66,7 @@ private:
 	void changeState(State state);
 	void runConnectAndSendLoop();
 	bool doSend(message_ptr message);
+	void updateSendCount(uint16_t streamId, int delta);
 
 	int handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df);
 
@@ -78,6 +81,10 @@ private:
 
 	Queue<message_ptr> mSendQueue;
 	std::thread mSendThread;
+	std::map<uint16_t, int> mSendCount;
+	std::mutex mSendCountMutex;
+	sent_callback mSentCallback;
+
 	std::mutex mConnectMutex;
 	std::condition_variable mConnectCondition;
 	std::atomic<bool> mConnectDataSent = false;