فهرست منبع

Merge pull request #408 from paullouisageneau/open-cb-recv-dc

Call DataChannel::onOpen() callback on receiving side
Paul-Louis Ageneau 4 سال پیش
والد
کامیت
ebe64ed8ef

+ 6 - 4
examples/client/main.cpp

@@ -222,9 +222,14 @@ shared_ptr<PeerConnection> createPeerConnection(const Configuration &config,
 		cout << "DataChannel from " << id << " received with label \"" << dc->label() << "\""
 		cout << "DataChannel from " << id << " received with label \"" << dc->label() << "\""
 		     << endl;
 		     << endl;
 
 
+		dc->onOpen([wdc = make_weak_ptr(dc)]() {
+			if (auto dc = wdc.lock())
+				dc->send("Hello from " + localId);
+		});
+
 		dc->onClosed([id]() { cout << "DataChannel from " << id << " closed" << endl; });
 		dc->onClosed([id]() { cout << "DataChannel from " << id << " closed" << endl; });
 
 
-		dc->onMessage([id, wdc = make_weak_ptr(dc)](variant<binary, string> data) {
+		dc->onMessage([id](variant<binary, string> data) {
 			if (holds_alternative<string>(data))
 			if (holds_alternative<string>(data))
 				cout << "Message from " << id << " received: " << get<string>(data) << endl;
 				cout << "Message from " << id << " received: " << get<string>(data) << endl;
 			else
 			else
@@ -232,8 +237,6 @@ shared_ptr<PeerConnection> createPeerConnection(const Configuration &config,
 				     << " received, size=" << get<binary>(data).size() << endl;
 				     << " received, size=" << get<binary>(data).size() << endl;
 		});
 		});
 
 
-		dc->send("Hello from " + localId);
-
 		dataChannelMap.emplace(id, dc);
 		dataChannelMap.emplace(id, dc);
 	});
 	});
 
 
@@ -251,4 +254,3 @@ string randomId(size_t length) {
 	generate(id.begin(), id.end(), [&]() { return characters.at(dist(rng)); });
 	generate(id.begin(), id.end(), [&]() { return characters.at(dist(rng)); });
 	return id;
 	return id;
 }
 }
-

+ 46 - 12
include/rtc/utils.hpp

@@ -22,6 +22,8 @@
 #include <functional>
 #include <functional>
 #include <memory>
 #include <memory>
 #include <mutex>
 #include <mutex>
+#include <optional>
+#include <tuple>
 
 
 namespace rtc {
 namespace rtc {
 
 
@@ -58,40 +60,35 @@ private:
 };
 };
 
 
 // callback with built-in synchronization
 // callback with built-in synchronization
-template <typename... Args> class synchronized_callback final {
+template <typename... Args> class synchronized_callback {
 public:
 public:
 	synchronized_callback() = default;
 	synchronized_callback() = default;
 	synchronized_callback(synchronized_callback &&cb) { *this = std::move(cb); }
 	synchronized_callback(synchronized_callback &&cb) { *this = std::move(cb); }
 	synchronized_callback(const synchronized_callback &cb) { *this = cb; }
 	synchronized_callback(const synchronized_callback &cb) { *this = cb; }
 	synchronized_callback(std::function<void(Args...)> func) { *this = std::move(func); }
 	synchronized_callback(std::function<void(Args...)> func) { *this = std::move(func); }
-	~synchronized_callback() { *this = nullptr; }
+	virtual ~synchronized_callback() { *this = nullptr; }
 
 
 	synchronized_callback &operator=(synchronized_callback &&cb) {
 	synchronized_callback &operator=(synchronized_callback &&cb) {
 		std::scoped_lock lock(mutex, cb.mutex);
 		std::scoped_lock lock(mutex, cb.mutex);
-		callback = std::move(cb.callback);
-		cb.callback = nullptr;
+		set(std::exchange(cb.callback, nullptr));
 		return *this;
 		return *this;
 	}
 	}
 
 
 	synchronized_callback &operator=(const synchronized_callback &cb) {
 	synchronized_callback &operator=(const synchronized_callback &cb) {
 		std::scoped_lock lock(mutex, cb.mutex);
 		std::scoped_lock lock(mutex, cb.mutex);
-		callback = cb.callback;
+		set(cb.callback);
 		return *this;
 		return *this;
 	}
 	}
 
 
 	synchronized_callback &operator=(std::function<void(Args...)> func) {
 	synchronized_callback &operator=(std::function<void(Args...)> func) {
 		std::lock_guard lock(mutex);
 		std::lock_guard lock(mutex);
-		callback = std::move(func);
+		set(std::move(func));
 		return *this;
 		return *this;
 	}
 	}
 
 
 	bool operator()(Args... args) const {
 	bool operator()(Args... args) const {
 		std::lock_guard lock(mutex);
 		std::lock_guard lock(mutex);
-		if (!callback)
-			return false;
-
-		callback(std::move(args)...);
-		return true;
+		return call(std::move(args)...);
 	}
 	}
 
 
 	operator bool() const {
 	operator bool() const {
@@ -103,11 +100,48 @@ public:
 		return [this](Args... args) { (*this)(std::move(args)...); };
 		return [this](Args... args) { (*this)(std::move(args)...); };
 	}
 	}
 
 
-private:
+protected:
+	virtual void set(std::function<void(Args...)> func) { callback = std::move(func); }
+	virtual bool call(Args... args) const {
+		if (!callback)
+			return false;
+
+		callback(std::move(args)...);
+		return true;
+	}
+
 	std::function<void(Args...)> callback;
 	std::function<void(Args...)> callback;
 	mutable std::recursive_mutex mutex;
 	mutable std::recursive_mutex mutex;
 };
 };
 
 
+// callback with built-in synchronization and replay of the last missed call
+template <typename... Args>
+class synchronized_stored_callback final : public synchronized_callback<Args...> {
+public:
+	template <typename... CArgs>
+	synchronized_stored_callback(CArgs &&...cargs)
+	    : synchronized_callback<Args...>(std::forward<CArgs>(cargs)...) {}
+	~synchronized_stored_callback() {}
+
+private:
+	void set(std::function<void(Args...)> func) {
+		synchronized_callback<Args...>::set(func);
+		if (func && stored) {
+			std::apply(func, std::move(*stored));
+			stored.reset();
+		}
+	}
+
+	bool call(Args... args) const {
+		if (!synchronized_callback<Args...>::call(args...))
+			stored.emplace(std::move(args)...);
+
+		return true;
+	}
+
+	mutable std::optional<std::tuple<Args...>> stored;
+};
+
 // pimpl base class
 // pimpl base class
 template <typename T> using impl_ptr = std::shared_ptr<T>;
 template <typename T> using impl_ptr = std::shared_ptr<T>;
 template <typename T> class CheshireCat {
 template <typename T> class CheshireCat {

+ 1 - 4
src/channel.cpp

@@ -43,10 +43,7 @@ void Channel::onError(std::function<void(string error)> callback) {
 
 
 void Channel::onMessage(std::function<void(message_variant data)> callback) {
 void Channel::onMessage(std::function<void(message_variant data)> callback) {
 	impl()->messageCallback = callback;
 	impl()->messageCallback = callback;
-
-	// Pass pending messages
-	while (auto message = receive())
-		impl()->messageCallback(*message);
+	impl()->flushPendingMessages();
 }
 }
 
 
 void Channel::onMessage(std::function<void(binary data)> binaryCallback,
 void Channel::onMessage(std::function<void(binary data)> binaryCallback,

+ 27 - 9
src/impl/channel.cpp

@@ -20,22 +20,21 @@
 
 
 namespace rtc::impl {
 namespace rtc::impl {
 
 
-void Channel::triggerOpen() { openCallback(); }
+void Channel::triggerOpen() {
+	mOpenTriggered = true;
+	openCallback();
+	flushPendingMessages();
+}
 
 
 void Channel::triggerClosed() { closedCallback(); }
 void Channel::triggerClosed() { closedCallback(); }
 
 
-void Channel::triggerError(string error) { errorCallback(error); }
+void Channel::triggerError(string error) { errorCallback(std::move(error)); }
 
 
 void Channel::triggerAvailable(size_t count) {
 void Channel::triggerAvailable(size_t count) {
 	if (count == 1)
 	if (count == 1)
 		availableCallback();
 		availableCallback();
 
 
-	while (messageCallback && count--) {
-		auto message = receive();
-		if (!message)
-			break;
-		messageCallback(*message);
-	}
+	flushPendingMessages();
 }
 }
 
 
 void Channel::triggerBufferedAmount(size_t amount) {
 void Channel::triggerBufferedAmount(size_t amount) {
@@ -45,13 +44,32 @@ void Channel::triggerBufferedAmount(size_t amount) {
 		bufferedAmountLowCallback();
 		bufferedAmountLowCallback();
 }
 }
 
 
+void Channel::flushPendingMessages() {
+	if (!mOpenTriggered)
+		return;
+
+	while (messageCallback) {
+		auto next = receive();
+		if (!next)
+			break;
+
+		messageCallback(*next);
+	}
+}
+
+void Channel::resetOpenCallback() {
+	mOpenTriggered = false;
+	openCallback = nullptr;
+}
+
 void Channel::resetCallbacks() {
 void Channel::resetCallbacks() {
+	mOpenTriggered = false;
 	openCallback = nullptr;
 	openCallback = nullptr;
 	closedCallback = nullptr;
 	closedCallback = nullptr;
 	errorCallback = nullptr;
 	errorCallback = nullptr;
-	messageCallback = nullptr;
 	availableCallback = nullptr;
 	availableCallback = nullptr;
 	bufferedAmountLowCallback = nullptr;
 	bufferedAmountLowCallback = nullptr;
+	messageCallback = nullptr;
 }
 }
 
 
 } // namespace rtc::impl
 } // namespace rtc::impl

+ 12 - 6
src/impl/channel.hpp

@@ -38,17 +38,23 @@ struct Channel {
 	virtual void triggerAvailable(size_t count);
 	virtual void triggerAvailable(size_t count);
 	virtual void triggerBufferedAmount(size_t amount);
 	virtual void triggerBufferedAmount(size_t amount);
 
 
-	virtual void resetCallbacks();
+	void flushPendingMessages();
+	void resetOpenCallback();
+	void resetCallbacks();
+
+	synchronized_stored_callback<> openCallback;
+	synchronized_stored_callback<> closedCallback;
+	synchronized_stored_callback<string> errorCallback;
+	synchronized_stored_callback<> availableCallback;
+	synchronized_stored_callback<> bufferedAmountLowCallback;
 
 
-	synchronized_callback<> openCallback;
-	synchronized_callback<> closedCallback;
-	synchronized_callback<string> errorCallback;
 	synchronized_callback<message_variant> messageCallback;
 	synchronized_callback<message_variant> messageCallback;
-	synchronized_callback<> availableCallback;
-	synchronized_callback<> bufferedAmountLowCallback;
 
 
 	std::atomic<size_t> bufferedAmount = 0;
 	std::atomic<size_t> bufferedAmount = 0;
 	std::atomic<size_t> bufferedAmountLowThreshold = 0;
 	std::atomic<size_t> bufferedAmountLowThreshold = 0;
+
+private:
+	std::atomic<bool> mOpenTriggered = false;
 };
 };
 
 
 } // namespace rtc::impl
 } // namespace rtc::impl

+ 30 - 9
src/impl/peerconnection.cpp

@@ -973,33 +973,54 @@ string PeerConnection::localBundleMid() const {
 
 
 void PeerConnection::triggerDataChannel(weak_ptr<DataChannel> weakDataChannel) {
 void PeerConnection::triggerDataChannel(weak_ptr<DataChannel> weakDataChannel) {
 	auto dataChannel = weakDataChannel.lock();
 	auto dataChannel = weakDataChannel.lock();
-	if (dataChannel)
+	if (dataChannel) {
+		dataChannel->resetOpenCallback(); // might be set internally
 		mPendingDataChannels.push(std::move(dataChannel));
 		mPendingDataChannels.push(std::move(dataChannel));
+	}
+	triggerPendingDataChannels();
+}
+
+void PeerConnection::triggerTrack(weak_ptr<Track> weakTrack) {
+	auto track = weakTrack.lock();
+	if (track) {
+		track->resetOpenCallback(); // might be set internally
+		mPendingTracks.push(std::move(track));
+	}
+	triggerPendingTracks();
+}
 
 
+void PeerConnection::triggerPendingDataChannels() {
 	while (dataChannelCallback) {
 	while (dataChannelCallback) {
 		auto next = mPendingDataChannels.tryPop();
 		auto next = mPendingDataChannels.tryPop();
 		if (!next)
 		if (!next)
 			break;
 			break;
 
 
-		mProcessor->enqueue(dataChannelCallback.wrap(),
-		                    std::make_shared<rtc::DataChannel>(std::move(*next)));
+		auto impl = std::move(*next);
+		dataChannelCallback(std::make_shared<rtc::DataChannel>(impl));
+		impl->triggerOpen();
 	}
 	}
 }
 }
 
 
-void PeerConnection::triggerTrack(weak_ptr<Track> weakTrack) {
-	auto track = weakTrack.lock();
-	if (track)
-		mPendingTracks.push(std::move(track));
-
+void PeerConnection::triggerPendingTracks() {
 	while (trackCallback) {
 	while (trackCallback) {
 		auto next = mPendingTracks.tryPop();
 		auto next = mPendingTracks.tryPop();
 		if (!next)
 		if (!next)
 			break;
 			break;
 
 
-		mProcessor->enqueue(trackCallback.wrap(), std::make_shared<rtc::Track>(std::move(*next)));
+		auto impl = std::move(*next);
+		trackCallback(std::make_shared<rtc::Track>(impl));
+		impl->triggerOpen();
 	}
 	}
 }
 }
 
 
+void PeerConnection::flushPendingDataChannels() {
+	mProcessor->enqueue(std::bind(&PeerConnection::triggerPendingDataChannels, this));
+}
+
+void PeerConnection::flushPendingTracks() {
+	mProcessor->enqueue(std::bind(&PeerConnection::triggerPendingTracks, this));
+}
+
 bool PeerConnection::changeState(State newState) {
 bool PeerConnection::changeState(State newState) {
 	State current;
 	State current;
 	do {
 	do {

+ 8 - 2
src/impl/peerconnection.hpp

@@ -84,8 +84,14 @@ struct PeerConnection : std::enable_shared_from_this<PeerConnection> {
 	void processRemoteCandidate(Candidate candidate);
 	void processRemoteCandidate(Candidate candidate);
 	string localBundleMid() const;
 	string localBundleMid() const;
 
 
-	void triggerDataChannel(weak_ptr<DataChannel> weakDataChannel = {});
-	void triggerTrack(weak_ptr<Track> weakTrack = {});
+	void triggerDataChannel(weak_ptr<DataChannel> weakDataChannel);
+	void triggerTrack(weak_ptr<Track> weakTrack);
+
+	void triggerPendingDataChannels();
+	void triggerPendingTracks();
+
+	void flushPendingDataChannels();
+	void flushPendingTracks();
 
 
 	bool changeState(State newState);
 	bool changeState(State newState);
 	bool changeGatheringState(GatheringState newState);
 	bool changeGatheringState(GatheringState newState);

+ 2 - 2
src/peerconnection.cpp

@@ -267,7 +267,7 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(string label, DataChan
 void PeerConnection::onDataChannel(
 void PeerConnection::onDataChannel(
     std::function<void(shared_ptr<DataChannel> dataChannel)> callback) {
     std::function<void(shared_ptr<DataChannel> dataChannel)> callback) {
 	impl()->dataChannelCallback = callback;
 	impl()->dataChannelCallback = callback;
-	impl()->triggerDataChannel(); // trigger pending DataChannels
+	impl()->flushPendingDataChannels();
 }
 }
 
 
 std::shared_ptr<Track> PeerConnection::addTrack(Description::Media description) {
 std::shared_ptr<Track> PeerConnection::addTrack(Description::Media description) {
@@ -282,7 +282,7 @@ std::shared_ptr<Track> PeerConnection::addTrack(Description::Media description)
 
 
 void PeerConnection::onTrack(std::function<void(std::shared_ptr<Track>)> callback) {
 void PeerConnection::onTrack(std::function<void(std::shared_ptr<Track>)> callback) {
 	impl()->trackCallback = callback;
 	impl()->trackCallback = callback;
-	impl()->triggerTrack(); // trigger pending tracks
+	impl()->flushPendingTracks();
 }
 }
 
 
 void PeerConnection::onLocalDescription(std::function<void(Description description)> callback) {
 void PeerConnection::onLocalDescription(std::function<void(Description description)> callback) {

+ 1 - 1
test/capi_connectivity.cpp

@@ -137,11 +137,11 @@ static void RTC_API dataChannelCallback(int pc, int dc, void *ptr) {
 		return;
 		return;
 	}
 	}
 
 
+	rtcSetOpenCallback(dc, openCallback);
 	rtcSetClosedCallback(dc, closedCallback);
 	rtcSetClosedCallback(dc, closedCallback);
 	rtcSetMessageCallback(dc, messageCallback);
 	rtcSetMessageCallback(dc, messageCallback);
 
 
 	peer->dc = dc;
 	peer->dc = dc;
-	peer->connected = true;
 
 
 	const char *message = peer == peer1 ? "Hello from 1" : "Hello from 2";
 	const char *message = peer == peer1 ? "Hello from 1" : "Hello from 2";
 	rtcSendMessage(peer->dc, message, -1); // negative size indicates a null-terminated string
 	rtcSendMessage(peer->dc, message, -1); // negative size indicates a null-terminated string

+ 1 - 1
test/capi_track.cpp

@@ -83,7 +83,7 @@ static void RTC_API closedCallback(int id, void *ptr) {
 static void RTC_API trackCallback(int pc, int tr, void *ptr) {
 static void RTC_API trackCallback(int pc, int tr, void *ptr) {
 	Peer *peer = (Peer *)ptr;
 	Peer *peer = (Peer *)ptr;
 	peer->tr = tr;
 	peer->tr = tr;
-	peer->connected = true;
+	rtcSetOpenCallback(tr, openCallback);
 	rtcSetClosedCallback(tr, closedCallback);
 	rtcSetClosedCallback(tr, closedCallback);
 
 
 	char buffer[1024];
 	char buffer[1024];

+ 20 - 16
test/connectivity.cpp

@@ -107,26 +107,29 @@ void test_connectivity() {
 			return;
 			return;
 		}
 		}
 
 
+		dc->onOpen([wdc = make_weak_ptr(dc)]() {
+			if (auto dc = wdc.lock())
+				dc->send("Hello from 2");
+		});
+
 		dc->onMessage([](variant<binary, string> message) {
 		dc->onMessage([](variant<binary, string> message) {
 			if (holds_alternative<string>(message)) {
 			if (holds_alternative<string>(message)) {
 				cout << "Message 2: " << get<string>(message) << endl;
 				cout << "Message 2: " << get<string>(message) << endl;
 			}
 			}
 		});
 		});
 
 
-		dc->send("Hello from 2");
-
 		std::atomic_store(&dc2, dc);
 		std::atomic_store(&dc2, dc);
 	});
 	});
 
 
 	auto dc1 = pc1.createDataChannel("test");
 	auto dc1 = pc1.createDataChannel("test");
-	dc1->onOpen([wdc1 = make_weak_ptr(dc1)]() {
-		auto dc1 = wdc1.lock();
-		if (!dc1)
-			return;
 
 
-		cout << "DataChannel 1: Open" << endl;
-		dc1->send("Hello from 1");
+	dc1->onOpen([wdc1 = make_weak_ptr(dc1)]() {
+		if (auto dc1 = wdc1.lock()) {
+			cout << "DataChannel 1: Open" << endl;
+			dc1->send("Hello from 1");
+		}
 	});
 	});
+
 	dc1->onMessage([](const variant<binary, string> &message) {
 	dc1->onMessage([](const variant<binary, string> &message) {
 		if (holds_alternative<string>(message)) {
 		if (holds_alternative<string>(message)) {
 			cout << "Message 1: " << get<string>(message) << endl;
 			cout << "Message 1: " << get<string>(message) << endl;
@@ -177,25 +180,26 @@ void test_connectivity() {
 			return;
 			return;
 		}
 		}
 
 
+		dc->onOpen([wdc = make_weak_ptr(dc)]() {
+			if (auto dc = wdc.lock())
+				dc->send("Second hello from 2");
+		});
+
 		dc->onMessage([](variant<binary, string> message) {
 		dc->onMessage([](variant<binary, string> message) {
 			if (holds_alternative<string>(message)) {
 			if (holds_alternative<string>(message)) {
 				cout << "Second Message 2: " << get<string>(message) << endl;
 				cout << "Second Message 2: " << get<string>(message) << endl;
 			}
 			}
 		});
 		});
 
 
-		dc->send("Send hello from 2");
-
 		std::atomic_store(&second2, dc);
 		std::atomic_store(&second2, dc);
 	});
 	});
 
 
 	auto second1 = pc1.createDataChannel("second");
 	auto second1 = pc1.createDataChannel("second");
 	second1->onOpen([wsecond1 = make_weak_ptr(dc1)]() {
 	second1->onOpen([wsecond1 = make_weak_ptr(dc1)]() {
-		auto second1 = wsecond1.lock();
-		if (!second1)
-			return;
-
-		cout << "Second DataChannel 1: Open" << endl;
-		second1->send("Second hello from 1");
+		if (auto second1 = wsecond1.lock()) {
+			cout << "Second DataChannel 1: Open" << endl;
+			second1->send("Second hello from 1");
+		}
 	});
 	});
 	dc1->onMessage([](const variant<binary, string> &message) {
 	dc1->onMessage([](const variant<binary, string> &message) {
 		if (holds_alternative<string>(message)) {
 		if (holds_alternative<string>(message)) {

+ 10 - 4
test/turn_connectivity.cpp

@@ -108,14 +108,17 @@ void test_turn_connectivity() {
 			return;
 			return;
 		}
 		}
 
 
+		dc->onOpen([wdc = make_weak_ptr(dc)]() {
+			if (auto dc = wdc.lock())
+				dc->send("Hello from 2");
+		});
+
 		dc->onMessage([](variant<binary, string> message) {
 		dc->onMessage([](variant<binary, string> message) {
 			if (holds_alternative<string>(message)) {
 			if (holds_alternative<string>(message)) {
 				cout << "Message 2: " << get<string>(message) << endl;
 				cout << "Message 2: " << get<string>(message) << endl;
 			}
 			}
 		});
 		});
 
 
-		dc->send("Hello from 2");
-
 		std::atomic_store(&dc2, dc);
 		std::atomic_store(&dc2, dc);
 	});
 	});
 
 
@@ -175,14 +178,17 @@ void test_turn_connectivity() {
 			return;
 			return;
 		}
 		}
 
 
+		dc->onOpen([wdc = make_weak_ptr(dc)]() {
+			if (auto dc = wdc.lock())
+				dc->send("Second hello from 2");
+		});
+
 		dc->onMessage([](variant<binary, string> message) {
 		dc->onMessage([](variant<binary, string> message) {
 			if (holds_alternative<string>(message)) {
 			if (holds_alternative<string>(message)) {
 				cout << "Second Message 2: " << get<string>(message) << endl;
 				cout << "Second Message 2: " << get<string>(message) << endl;
 			}
 			}
 		});
 		});
 
 
-		dc->send("Send hello from 2");
-
 		std::atomic_store(&second2, dc);
 		std::atomic_store(&second2, dc);
 	});
 	});