Browse Source

Merge pull request #7 from paullouisageneau/back-pressure

Back-pressure
Paul-Louis Ageneau 5 years ago
parent
commit
04df12b581

+ 15 - 6
include/rtc/channel.hpp

@@ -22,6 +22,7 @@
 #include "include.hpp"
 #include "include.hpp"
 
 
 #include <functional>
 #include <functional>
+#include <mutex>
 #include <variant>
 #include <variant>
 
 
 namespace rtc {
 namespace rtc {
@@ -30,30 +31,38 @@ class Channel {
 public:
 public:
 	virtual void close(void) = 0;
 	virtual void close(void) = 0;
 	virtual void send(const std::variant<binary, string> &data) = 0;
 	virtual void send(const std::variant<binary, string> &data) = 0;
-
+	virtual std::optional<std::variant<binary, string>> receive() = 0;
 	virtual bool isOpen(void) const = 0;
 	virtual bool isOpen(void) const = 0;
 	virtual bool isClosed(void) const = 0;
 	virtual bool isClosed(void) const = 0;
 
 
 	void onOpen(std::function<void()> callback);
 	void onOpen(std::function<void()> callback);
 	void onClosed(std::function<void()> callback);
 	void onClosed(std::function<void()> callback);
 	void onError(std::function<void(const string &error)> callback);
 	void onError(std::function<void(const string &error)> callback);
+
 	void onMessage(std::function<void(const std::variant<binary, string> &data)> callback);
 	void onMessage(std::function<void(const std::variant<binary, string> &data)> callback);
 	void onMessage(std::function<void(const binary &data)> binaryCallback,
 	void onMessage(std::function<void(const binary &data)> binaryCallback,
 	               std::function<void(const string &data)> stringCallback);
 	               std::function<void(const string &data)> stringCallback);
 
 
+	void onAvailable(std::function<void()> callback);
+	void onSent(std::function<void()> callback);
+
 protected:
 protected:
 	virtual void triggerOpen(void);
 	virtual void triggerOpen(void);
 	virtual void triggerClosed(void);
 	virtual void triggerClosed(void);
 	virtual void triggerError(const string &error);
 	virtual void triggerError(const string &error);
-	virtual void triggerMessage(const std::variant<binary, string> &data);
+	virtual void triggerAvailable(size_t available);
+	virtual void triggerSent();
 
 
 private:
 private:
-	std::function<void()> mOpenCallback;
-	std::function<void()> mClosedCallback;
-	std::function<void(const string &)> mErrorCallback;
-	std::function<void(const std::variant<binary, string> &)> mMessageCallback;
+	synchronized_callback<> mOpenCallback;
+	synchronized_callback<> mClosedCallback;
+	synchronized_callback<const string &> mErrorCallback;
+	synchronized_callback<const std::variant<binary, string> &> mMessageCallback;
+	synchronized_callback<> mAvailableCallback;
+	synchronized_callback<> mSentCallback;
 };
 };
 
 
 } // namespace rtc
 } // namespace rtc
 
 
 #endif // RTC_CHANNEL_H
 #endif // RTC_CHANNEL_H
+

+ 18 - 5
include/rtc/datachannel.hpp

@@ -22,8 +22,10 @@
 #include "channel.hpp"
 #include "channel.hpp"
 #include "include.hpp"
 #include "include.hpp"
 #include "message.hpp"
 #include "message.hpp"
+#include "queue.hpp"
 #include "reliability.hpp"
 #include "reliability.hpp"
 
 
+#include <atomic>
 #include <chrono>
 #include <chrono>
 #include <functional>
 #include <functional>
 #include <variant>
 #include <variant>
@@ -35,13 +37,19 @@ class PeerConnection;
 
 
 class DataChannel : public Channel {
 class DataChannel : public Channel {
 public:
 public:
-	DataChannel(unsigned int stream_, string label_, string protocol_, Reliability reliability_);
-	DataChannel(unsigned int stream, std::shared_ptr<SctpTransport> sctpTransport);
+	DataChannel(std::shared_ptr<PeerConnection> pc, unsigned int stream, string label,
+	            string protocol, Reliability reliability);
+	DataChannel(std::shared_ptr<PeerConnection> pc, std::shared_ptr<SctpTransport> transport,
+	            unsigned int stream);
 	~DataChannel();
 	~DataChannel();
 
 
 	void close(void);
 	void close(void);
 	void send(const std::variant<binary, string> &data);
 	void send(const std::variant<binary, string> &data);
 	void send(const byte *data, size_t size);
 	void send(const byte *data, size_t size);
+	std::optional<std::variant<binary, string>> receive();
+
+	size_t available() const;
+	size_t availableSize() const;
 
 
 	unsigned int stream() const;
 	unsigned int stream() const;
 	string label() const;
 	string label() const;
@@ -56,14 +64,19 @@ private:
 	void incoming(message_ptr message);
 	void incoming(message_ptr message);
 	void processOpenMessage(message_ptr message);
 	void processOpenMessage(message_ptr message);
 
 
-	unsigned int mStream;
+	const std::shared_ptr<PeerConnection> mPeerConnection; // keeps the PeerConnection alive
 	std::shared_ptr<SctpTransport> mSctpTransport;
 	std::shared_ptr<SctpTransport> mSctpTransport;
+
+	unsigned int mStream;
 	string mLabel;
 	string mLabel;
 	string mProtocol;
 	string mProtocol;
 	std::shared_ptr<Reliability> mReliability;
 	std::shared_ptr<Reliability> mReliability;
 
 
-	bool mIsOpen = false;
-	bool mIsClosed = false;
+	std::atomic<bool> mIsOpen = false;
+	std::atomic<bool> mIsClosed = false;
+
+	Queue<message_ptr> mRecvQueue;
+	std::atomic<size_t> mRecvSize = 0;
 
 
 	friend class PeerConnection;
 	friend class PeerConnection;
 };
 };

+ 30 - 0
include/rtc/include.hpp

@@ -20,7 +20,9 @@
 #define RTC_INCLUDE_H
 #define RTC_INCLUDE_H
 
 
 #include <cstddef>
 #include <cstddef>
+#include <functional>
 #include <memory>
 #include <memory>
+#include <mutex>
 #include <optional>
 #include <optional>
 #include <string>
 #include <string>
 #include <vector>
 #include <vector>
@@ -33,6 +35,7 @@ using binary = std::vector<byte>;
 
 
 using std::nullopt;
 using std::nullopt;
 
 
+using std::size_t;
 using std::uint16_t;
 using std::uint16_t;
 using std::uint32_t;
 using std::uint32_t;
 using std::uint64_t;
 using std::uint64_t;
@@ -41,10 +44,37 @@ using std::uint8_t;
 const size_t MAX_NUMERICNODE_LEN = 48; // Max IPv6 string representation length
 const size_t MAX_NUMERICNODE_LEN = 48; // Max IPv6 string representation length
 const size_t MAX_NUMERICSERV_LEN = 6;  // Max port string representation length
 const size_t MAX_NUMERICSERV_LEN = 6;  // Max port string representation length
 
 
+const size_t RECV_QUEUE_SIZE = 256; // DataChannel receive queue size in messages
+                                    // (0 means unlimited)
+
 const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default
 const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default
 
 
 template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
 template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
 template <class... Ts> overloaded(Ts...)->overloaded<Ts...>;
 template <class... Ts> overloaded(Ts...)->overloaded<Ts...>;
+
+template <typename... P> class synchronized_callback {
+public:
+	synchronized_callback() = default;
+	~synchronized_callback() { *this = nullptr; }
+
+	synchronized_callback &operator=(std::function<void(P...)> func) {
+		std::lock_guard<std::recursive_mutex> lock(mutex);
+		callback = func;
+		return *this;
+	}
+
+	void operator()(P... args) const {
+		std::lock_guard<std::recursive_mutex> lock(mutex);
+		if (callback)
+			callback(args...);
+	}
+
+	operator bool() const { return callback ? true : false; }
+
+private:
+	std::function<void(P...)> callback;
+	mutable std::recursive_mutex mutex;
+};
 }
 }
 
 
 #endif
 #endif

+ 6 - 7
include/rtc/peerconnection.hpp

@@ -89,6 +89,7 @@ private:
 
 
 	bool checkFingerprint(std::weak_ptr<PeerConnection> weak_this, const std::string &fingerprint) const;
 	bool checkFingerprint(std::weak_ptr<PeerConnection> weak_this, const std::string &fingerprint) const;
 	void forwardMessage(std::weak_ptr<PeerConnection> weak_this, message_ptr message);
 	void forwardMessage(std::weak_ptr<PeerConnection> weak_this, message_ptr message);
+	void forwardSent(std::weak_ptr<PeerConnection> weak_this, uint16_t stream);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void openDataChannels();
 	void openDataChannels();
 	void closeDataChannels();
 	void closeDataChannels();
@@ -114,13 +115,11 @@ private:
 	std::atomic<State> mState;
 	std::atomic<State> mState;
 	std::atomic<GatheringState> mGatheringState;
 	std::atomic<GatheringState> mGatheringState;
 
 
-	std::list<std::thread> mResolveThreads;
-
-	std::function<void(std::shared_ptr<DataChannel> dataChannel)> mDataChannelCallback;
-	std::function<void(const Description &description)> mLocalDescriptionCallback;
-	std::function<void(const Candidate &candidate)> mLocalCandidateCallback;
-	std::function<void(State state)> mStateChangeCallback;
-	std::function<void(GatheringState state)> mGatheringStateChangeCallback;
+	synchronized_callback<std::shared_ptr<DataChannel>> mDataChannelCallback;
+	synchronized_callback<const Description &> mLocalDescriptionCallback;
+	synchronized_callback<const Candidate &> mLocalCandidateCallback;
+	synchronized_callback<State> mStateChangeCallback;
+	synchronized_callback<GatheringState> mGatheringStateChangeCallback;
 };
 };
 
 
 } // namespace rtc
 } // namespace rtc

+ 39 - 21
src/queue.hpp → include/rtc/queue.hpp

@@ -32,32 +32,36 @@ namespace rtc {
 
 
 template <typename T> class Queue {
 template <typename T> class Queue {
 public:
 public:
-	Queue();
+	Queue(std::size_t limit = 0);
 	~Queue();
 	~Queue();
 
 
 	void stop();
 	void stop();
 	bool empty() const;
 	bool empty() const;
+	size_t size() const;
 	void push(const T &element);
 	void push(const T &element);
 	std::optional<T> pop();
 	std::optional<T> pop();
+	std::optional<T> tryPop();
 	void wait();
 	void wait();
 	void wait(const std::chrono::milliseconds &duration);
 	void wait(const std::chrono::milliseconds &duration);
 
 
 private:
 private:
+	const size_t mLimit;
 	std::queue<T> mQueue;
 	std::queue<T> mQueue;
-	std::condition_variable mCondition;
-	std::atomic<bool> mStopping;
+	std::condition_variable mPopCondition, mPushCondition;
+	bool mStopping = false;
 
 
 	mutable std::mutex mMutex;
 	mutable std::mutex mMutex;
 };
 };
 
 
-template <typename T> Queue<T>::Queue() : mStopping(false) {}
+template <typename T> Queue<T>::Queue(size_t limit) : mLimit(limit) {}
 
 
 template <typename T> Queue<T>::~Queue() { stop(); }
 template <typename T> Queue<T>::~Queue() { stop(); }
 
 
 template <typename T> void Queue<T>::stop() {
 template <typename T> void Queue<T>::stop() {
 	std::lock_guard<std::mutex> lock(mMutex);
 	std::lock_guard<std::mutex> lock(mMutex);
 	mStopping = true;
 	mStopping = true;
-	mCondition.notify_all();
+	mPopCondition.notify_all();
+	mPushCondition.notify_all();
 }
 }
 
 
 template <typename T> bool Queue<T>::empty() const {
 template <typename T> bool Queue<T>::empty() const {
@@ -65,37 +69,51 @@ template <typename T> bool Queue<T>::empty() const {
 	return mQueue.empty();
 	return mQueue.empty();
 }
 }
 
 
-template <typename T> void Queue<T>::push(const T &element) {
+template <typename T> size_t Queue<T>::size() const {
 	std::lock_guard<std::mutex> lock(mMutex);
 	std::lock_guard<std::mutex> lock(mMutex);
-	if (mStopping)
-		return;
-	mQueue.push(element);
-	mCondition.notify_one();
+	return mQueue.size();
+}
+
+template <typename T> void Queue<T>::push(const T &element) {
+	std::unique_lock<std::mutex> lock(mMutex);
+	mPushCondition.wait(lock, [this]() { return !mLimit || mQueue.size() < mLimit || mStopping; });
+	if (!mStopping) {
+		mQueue.push(element);
+		mPopCondition.notify_one();
+	}
 }
 }
 
 
 template <typename T> std::optional<T> Queue<T>::pop() {
 template <typename T> std::optional<T> Queue<T>::pop() {
 	std::unique_lock<std::mutex> lock(mMutex);
 	std::unique_lock<std::mutex> lock(mMutex);
-	while (mQueue.empty()) {
-		if (mStopping)
-			return nullopt;
-		mCondition.wait(lock);
+	mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
+	if (!mQueue.empty()) {
+		std::optional<T> element(std::move(mQueue.front()));
+		mQueue.pop();
+		return element;
+	} else {
+		return nullopt;
 	}
 	}
+}
 
 
-	std::optional<T> element = mQueue.front();
-	mQueue.pop();
-	return element;
+template <typename T> std::optional<T> Queue<T>::tryPop() {
+	std::unique_lock<std::mutex> lock(mMutex);
+	if (!mQueue.empty()) {
+		std::optional<T> element(std::move(mQueue.front()));
+		mQueue.pop();
+		return element;
+	} else {
+		return nullopt;
+	}
 }
 }
 
 
 template <typename T> void Queue<T>::wait() {
 template <typename T> void Queue<T>::wait() {
 	std::unique_lock<std::mutex> lock(mMutex);
 	std::unique_lock<std::mutex> lock(mMutex);
-	if (mQueue.empty() && !mStopping)
-		mCondition.wait(lock);
+	mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
 }
 }
 
 
 template <typename T> void Queue<T>::wait(const std::chrono::milliseconds &duration) {
 template <typename T> void Queue<T>::wait(const std::chrono::milliseconds &duration) {
 	std::unique_lock<std::mutex> lock(mMutex);
 	std::unique_lock<std::mutex> lock(mMutex);
-	if (mQueue.empty() && !mStopping)
-		mCondition.wait_for(lock, duration);
+	mPopCondition.wait_for(lock, duration, [this]() { return !mQueue.empty() || mStopping; });
 }
 }
 
 
 } // namespace rtc
 } // namespace rtc

+ 33 - 15
src/channel.cpp

@@ -18,11 +18,17 @@
 
 
 #include "channel.hpp"
 #include "channel.hpp"
 
 
+namespace {}
+
 namespace rtc {
 namespace rtc {
 
 
-void Channel::onOpen(std::function<void()> callback) { mOpenCallback = callback; }
+void Channel::onOpen(std::function<void()> callback) {
+	mOpenCallback = callback;
+}
 
 
-void Channel::onClosed(std::function<void()> callback) { mClosedCallback = callback; }
+void Channel::onClosed(std::function<void()> callback) {
+	mClosedCallback = callback;
+}
 
 
 void Channel::onError(std::function<void(const string &error)> callback) {
 void Channel::onError(std::function<void(const string &error)> callback) {
 	mErrorCallback = callback;
 	mErrorCallback = callback;
@@ -30,6 +36,10 @@ void Channel::onError(std::function<void(const string &error)> callback) {
 
 
 void Channel::onMessage(std::function<void(const std::variant<binary, string> &data)> callback) {
 void Channel::onMessage(std::function<void(const std::variant<binary, string> &data)> callback) {
 	mMessageCallback = callback;
 	mMessageCallback = callback;
+
+	// Pass pending messages
+	while (auto message = receive())
+		mMessageCallback(*message);
 }
 }
 
 
 void Channel::onMessage(std::function<void(const binary &data)> binaryCallback,
 void Channel::onMessage(std::function<void(const binary &data)> binaryCallback,
@@ -39,25 +49,33 @@ void Channel::onMessage(std::function<void(const binary &data)> binaryCallback,
 	});
 	});
 }
 }
 
 
-void Channel::triggerOpen(void) {
-	if (mOpenCallback)
-		mOpenCallback();
+void Channel::onAvailable(std::function<void()> callback) {
+	mAvailableCallback = callback;
 }
 }
 
 
-void Channel::triggerClosed(void) {
-	if (mClosedCallback)
-		mClosedCallback();
+void Channel::onSent(std::function<void()> callback) {
+	mSentCallback = callback;
 }
 }
 
 
-void Channel::triggerError(const string &error) {
-	if (mErrorCallback)
-		mErrorCallback(error);
-}
+void Channel::triggerOpen() { mOpenCallback(); }
 
 
-void Channel::triggerMessage(const std::variant<binary, string> &data) {
-	if (mMessageCallback)
-		mMessageCallback(data);
+void Channel::triggerClosed() { mClosedCallback(); }
+
+void Channel::triggerError(const string &error) { mErrorCallback(error); }
+
+void Channel::triggerAvailable(size_t available) {
+	if (available == 1)
+		mAvailableCallback();
+
+	while (mMessageCallback && available--) {
+		auto message = receive();
+		if (!message)
+			break;
+		mMessageCallback(*message);
+	}
 }
 }
 
 
+void Channel::triggerSent() { mSentCallback(); }
+
 } // namespace rtc
 } // namespace rtc
 
 

+ 57 - 23
src/datachannel.cpp

@@ -57,21 +57,23 @@ struct CloseMessage {
 };
 };
 #pragma pack(pop)
 #pragma pack(pop)
 
 
-DataChannel::DataChannel(unsigned int stream, string label, string protocol,
-                         Reliability reliability)
-    : mStream(stream), mLabel(std::move(label)), mProtocol(std::move(protocol)),
-      mReliability(std::make_shared<Reliability>(std::move(reliability))) {}
-
-DataChannel::DataChannel(unsigned int stream, shared_ptr<SctpTransport> sctpTransport)
-    : mStream(stream), mSctpTransport(sctpTransport),
+DataChannel::DataChannel(shared_ptr<PeerConnection> pc, unsigned int stream, string label,
+                         string protocol, Reliability reliability)
+    : mPeerConnection(std::move(pc)), mStream(stream), mLabel(std::move(label)),
+      mProtocol(std::move(protocol)),
+      mReliability(std::make_shared<Reliability>(std::move(reliability))),
+      mRecvQueue(RECV_QUEUE_SIZE) {}
+
+DataChannel::DataChannel(shared_ptr<PeerConnection> pc, shared_ptr<SctpTransport> transport,
+                         unsigned int stream)
+    : mPeerConnection(std::move(pc)), mSctpTransport(transport), mStream(stream),
       mReliability(std::make_shared<Reliability>()) {}
       mReliability(std::make_shared<Reliability>()) {}
 
 
 DataChannel::~DataChannel() { close(); }
 DataChannel::~DataChannel() { close(); }
 
 
 void DataChannel::close() {
 void DataChannel::close() {
 	mIsOpen = false;
 	mIsOpen = false;
-	if (!mIsClosed) {
-		mIsClosed = true;
+	if (!mIsClosed.exchange(true)) {
 		if (mSctpTransport)
 		if (mSctpTransport)
 			mSctpTransport->reset(mStream);
 			mSctpTransport->reset(mStream);
 	}
 	}
@@ -88,7 +90,8 @@ void DataChannel::send(const std::variant<binary, string> &data) {
 		    auto *b = reinterpret_cast<const byte *>(d.data());
 		    auto *b = reinterpret_cast<const byte *>(d.data());
 		    // Before the ACK has been received on a DataChannel, all messages must be sent ordered
 		    // Before the ACK has been received on a DataChannel, all messages must be sent ordered
 		    auto reliability = mIsOpen ? mReliability : nullptr;
 		    auto reliability = mIsOpen ? mReliability : nullptr;
-		    mSctpTransport->send(make_message(b, b + d.size(), type, mStream, reliability));
+		    auto message = make_message(b, b + d.size(), type, mStream, reliability);
+		    mSctpTransport->send(message);
 	    },
 	    },
 	    data);
 	    data);
 }
 }
@@ -98,9 +101,41 @@ void DataChannel::send(const byte *data, size_t size) {
 		return;
 		return;
 
 
 	auto reliability = mIsOpen ? mReliability : nullptr;
 	auto reliability = mIsOpen ? mReliability : nullptr;
-	mSctpTransport->send(make_message(data, data + size, Message::Binary, mStream, reliability));
+	auto message = make_message(data, data + size, Message::Binary, mStream, reliability);
+	mSctpTransport->send(message);
 }
 }
 
 
+std::optional<std::variant<binary, string>> DataChannel::receive() {
+	while (auto opt = mRecvQueue.tryPop()) {
+		auto message = *opt;
+		switch (message->type) {
+		case Message::Control: {
+			auto raw = reinterpret_cast<const uint8_t *>(message->data());
+			if (raw[0] == MESSAGE_CLOSE) {
+				if (mIsOpen) {
+					close();
+					triggerClosed();
+				}
+			}
+			break;
+		}
+		case Message::String:
+			mRecvSize -= message->size();
+			return std::make_optional(
+			    string(reinterpret_cast<const char *>(message->data()), message->size()));
+		case Message::Binary:
+			mRecvSize -= message->size();
+			return std::make_optional(std::move(*message));
+		}
+	}
+
+	return nullopt;
+}
+
+size_t DataChannel::available() const { return mRecvQueue.size(); }
+
+size_t DataChannel::availableSize() const { return mRecvSize; }
+
 unsigned int DataChannel::stream() const { return mStream; }
 unsigned int DataChannel::stream() const { return mStream; }
 
 
 string DataChannel::label() const { return mLabel; }
 string DataChannel::label() const { return mLabel; }
@@ -153,16 +188,14 @@ void DataChannel::incoming(message_ptr message) {
 			processOpenMessage(message);
 			processOpenMessage(message);
 			break;
 			break;
 		case MESSAGE_ACK:
 		case MESSAGE_ACK:
-			if (!mIsOpen) {
-				mIsOpen = true;
+			if (!mIsOpen.exchange(true)) {
 				triggerOpen();
 				triggerOpen();
 			}
 			}
 			break;
 			break;
 		case MESSAGE_CLOSE:
 		case MESSAGE_CLOSE:
-			if (mIsOpen) {
-				close();
-				triggerClosed();
-			}
+			// The close message will be processed in-order in receive()
+			mRecvQueue.push(message);
+			triggerAvailable(mRecvQueue.size());
 			break;
 			break;
 		default:
 		default:
 			// Ignore
 			// Ignore
@@ -170,15 +203,16 @@ void DataChannel::incoming(message_ptr message) {
 		}
 		}
 		break;
 		break;
 	}
 	}
-	case Message::String: {
-		triggerMessage(string(reinterpret_cast<const char *>(message->data()), message->size()));
+	case Message::String:
+	case Message::Binary:
+		mRecvSize += message->size();
+		mRecvQueue.push(message);
+		triggerAvailable(mRecvQueue.size());
 		break;
 		break;
-	}
-	case Message::Binary: {
-		triggerMessage(*message);
+	default:
+		// Ignore
 		break;
 		break;
 	}
 	}
-	}
 }
 }
 
 
 void DataChannel::processOpenMessage(message_ptr message) {
 void DataChannel::processOpenMessage(message_ptr message) {

+ 12 - 6
src/dtlstransport.cpp

@@ -83,11 +83,13 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 }
 }
 
 
 DtlsTransport::~DtlsTransport() {
 DtlsTransport::~DtlsTransport() {
-  onRecv(nullptr);
+	onRecv(nullptr); // unset recv callback
+
 	mIncomingQueue.stop();
 	mIncomingQueue.stop();
-  mRecvThread.join();
-  gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
-  gnutls_deinit(mSession);
+	mRecvThread.join();
+
+	gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
+	gnutls_deinit(mSession);
 }
 }
 
 
 DtlsTransport::State DtlsTransport::state() const { return mState; }
 DtlsTransport::State DtlsTransport::state() const { return mState; }
@@ -110,8 +112,8 @@ bool DtlsTransport::send(message_ptr message) {
 void DtlsTransport::incoming(message_ptr message) { mIncomingQueue.push(message); }
 void DtlsTransport::incoming(message_ptr message) { mIncomingQueue.push(message); }
 
 
 void DtlsTransport::changeState(State state) {
 void DtlsTransport::changeState(State state) {
-	mState = state;
-	mStateChangeCallback(state);
+	if (mState.exchange(state) != state)
+		mStateChangeCallback(state);
 }
 }
 
 
 void DtlsTransport::runRecvLoop() {
 void DtlsTransport::runRecvLoop() {
@@ -154,6 +156,10 @@ void DtlsTransport::runRecvLoop() {
 				ret = gnutls_record_recv(mSession, buffer, bufferSize);
 				ret = gnutls_record_recv(mSession, buffer, bufferSize);
 			} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
 			} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
 
 
+			// Consider premature termination as remote closing
+			if (ret == GNUTLS_E_PREMATURE_TERMINATION)
+				break;
+
 			if (check_gnutls(ret)) {
 			if (check_gnutls(ret)) {
 				if (ret == 0) {
 				if (ret == 0) {
 					// Closed
 					// Closed

+ 3 - 4
src/icetransport.cpp

@@ -132,8 +132,7 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
 
 
 IceTransport::~IceTransport() {
 IceTransport::~IceTransport() {
 	g_main_loop_quit(mMainLoop.get());
 	g_main_loop_quit(mMainLoop.get());
-	if (mMainLoopThread.joinable())
-		mMainLoopThread.join();
+	mMainLoopThread.join();
 }
 }
 
 
 Description::Role IceTransport::role() const { return mRole; }
 Description::Role IceTransport::role() const { return mRole; }
@@ -227,8 +226,8 @@ void IceTransport::outgoing(message_ptr message) {
 }
 }
 
 
 void IceTransport::changeState(State state) {
 void IceTransport::changeState(State state) {
-	mState = state;
-	mStateChangeCallback(mState);
+	if (mState.exchange(state) != state)
+		mStateChangeCallback(mState);
 }
 }
 
 
 void IceTransport::changeGatheringState(GatheringState state) {
 void IceTransport::changeGatheringState(GatheringState state) {

+ 53 - 30
src/peerconnection.cpp

@@ -28,7 +28,6 @@ namespace rtc {
 
 
 using namespace std::placeholders;
 using namespace std::placeholders;
 
 
-using std::function;
 using std::shared_ptr;
 using std::shared_ptr;
 using std::weak_ptr;
 using std::weak_ptr;
 
 
@@ -37,10 +36,7 @@ PeerConnection::PeerConnection() : PeerConnection(Configuration()) {}
 PeerConnection::PeerConnection(const Configuration &config)
 PeerConnection::PeerConnection(const Configuration &config)
     : mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {}
     : mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {}
 
 
-PeerConnection::~PeerConnection() {
-	for (auto &t : mResolveThreads)
-		t.join();
-}
+PeerConnection::~PeerConnection() {}
 
 
 const Configuration *PeerConnection::config() const { return &mConfig; }
 const Configuration *PeerConnection::config() const { return &mConfig; }
 
 
@@ -94,10 +90,13 @@ void PeerConnection::addRemoteCandidate(Candidate candidate) {
 		mIceTransport->addRemoteCandidate(candidate);
 		mIceTransport->addRemoteCandidate(candidate);
 	} else {
 	} else {
 		// OK, we might need a lookup, do it asynchronously
 		// OK, we might need a lookup, do it asynchronously
-		mResolveThreads.emplace_back(std::thread([this, candidate]() mutable {
+		weak_ptr<IceTransport> weakIceTransport{mIceTransport};
+		std::thread t([weakIceTransport, candidate]() mutable {
 			if (candidate.resolve(Candidate::ResolveMode::Lookup))
 			if (candidate.resolve(Candidate::ResolveMode::Lookup))
-				mIceTransport->addRemoteCandidate(candidate);
-		}));
+				if (auto iceTransport = weakIceTransport.lock())
+					iceTransport->addRemoteCandidate(candidate);
+		});
+		t.detach();
 	}
 	}
 }
 }
 
 
@@ -128,7 +127,8 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
 			throw std::runtime_error("Too many DataChannels");
 			throw std::runtime_error("Too many DataChannels");
 	}
 	}
 
 
-	auto channel = std::make_shared<DataChannel>(stream, label, protocol, reliability);
+	auto channel =
+	    std::make_shared<DataChannel>(shared_from_this(), stream, label, protocol, reliability);
 	mDataChannels.insert(std::make_pair(stream, channel));
 	mDataChannels.insert(std::make_pair(stream, channel));
 
 
 	if (!mIceTransport) {
 	if (!mIceTransport) {
@@ -232,10 +232,16 @@ void PeerConnection::initDtlsTransport() {
 void PeerConnection::initSctpTransport() {
 void PeerConnection::initSctpTransport() {
 	uint16_t sctpPort = mRemoteDescription->sctpPort().value_or(DEFAULT_SCTP_PORT);
 	uint16_t sctpPort = mRemoteDescription->sctpPort().value_or(DEFAULT_SCTP_PORT);
 	mSctpTransport = std::make_shared<SctpTransport>(
 	mSctpTransport = std::make_shared<SctpTransport>(
-	    mDtlsTransport, sctpPort, std::bind(&PeerConnection::forwardMessage, this, weak_ptr<PeerConnection>{shared_from_this()}, _1),
-	    [this, weak_this = weak_ptr<PeerConnection>{shared_from_this()}](SctpTransport::State state) {
-        auto strong_this = weak_this.lock();
-        if (!strong_this) return;
+	    mDtlsTransport, sctpPort,
+	    std::bind(&PeerConnection::forwardMessage, this,
+	              weak_ptr<PeerConnection>{shared_from_this()}, _1),
+	    std::bind(&PeerConnection::forwardSent, this, weak_ptr<PeerConnection>{shared_from_this()},
+	              _1),
+	    [this,
+	     weak_this = weak_ptr<PeerConnection>{shared_from_this()}](SctpTransport::State state) {
+		    auto strong_this = weak_this.lock();
+		    if (!strong_this)
+			    return;
 
 
 		    switch (state) {
 		    switch (state) {
 		    case SctpTransport::State::Connected:
 		    case SctpTransport::State::Connected:
@@ -292,7 +298,8 @@ void PeerConnection::forwardMessage(weak_ptr<PeerConnection> weak_this, message_
 		unsigned int remoteParity = (mIceTransport->role() == Description::Role::Active) ? 1 : 0;
 		unsigned int remoteParity = (mIceTransport->role() == Description::Role::Active) ? 1 : 0;
 		if (message->type == Message::Control && *message->data() == dataChannelOpenMessage &&
 		if (message->type == Message::Control && *message->data() == dataChannelOpenMessage &&
 		    message->stream % 2 == remoteParity) {
 		    message->stream % 2 == remoteParity) {
-			channel = std::make_shared<DataChannel>(message->stream, mSctpTransport);
+			channel =
+			    std::make_shared<DataChannel>(shared_from_this(), mSctpTransport, message->stream);
 			channel->onOpen(std::bind(&PeerConnection::triggerDataChannel, this, weak_this, weak_ptr<DataChannel>{channel}));
 			channel->onOpen(std::bind(&PeerConnection::triggerDataChannel, this, weak_this, weak_ptr<DataChannel>{channel}));
 			mDataChannels.insert(std::make_pair(message->stream, channel));
 			mDataChannels.insert(std::make_pair(message->stream, channel));
 		} else {
 		} else {
@@ -305,6 +312,24 @@ void PeerConnection::forwardMessage(weak_ptr<PeerConnection> weak_this, message_
 	channel->incoming(message);
 	channel->incoming(message);
 }
 }
 
 
+void PeerConnection::forwardSent(weak_ptr<PeerConnection> weak_this, uint16_t stream) {
+	auto strong_this = weak_this.lock();
+	if (!strong_this)
+		return;
+
+	shared_ptr<DataChannel> channel;
+	if (auto it = mDataChannels.find(stream); it != mDataChannels.end()) {
+		channel = it->second.lock();
+		if (!channel || channel->isClosed()) {
+			mDataChannels.erase(it);
+			channel = nullptr;
+		}
+	}
+
+	if (channel)
+		channel->triggerSent();
+}
+
 void PeerConnection::iterateDataChannels(
 void PeerConnection::iterateDataChannels(
     std::function<void(shared_ptr<DataChannel> channel)> func) {
     std::function<void(shared_ptr<DataChannel> channel)> func) {
 	auto it = mDataChannels.begin();
 	auto it = mDataChannels.begin();
@@ -334,43 +359,41 @@ void PeerConnection::processLocalDescription(Description description) {
 	mLocalDescription->setFingerprint(mCertificate->fingerprint());
 	mLocalDescription->setFingerprint(mCertificate->fingerprint());
 	mLocalDescription->setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT));
 	mLocalDescription->setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT));
 
 
-	if (mLocalDescriptionCallback)
-		mLocalDescriptionCallback(*mLocalDescription);
+	mLocalDescriptionCallback(*mLocalDescription);
 }
 }
 
 
 void PeerConnection::processLocalCandidate(weak_ptr<PeerConnection> weak_this, Candidate candidate) {
 void PeerConnection::processLocalCandidate(weak_ptr<PeerConnection> weak_this, Candidate candidate) {
-  auto strong_this = weak_this.lock();
-  if (!strong_this) return;
+	auto strong_this = weak_this.lock();
+	if (!strong_this)
+		return;
 
 
 	if (!mLocalDescription)
 	if (!mLocalDescription)
 		throw std::logic_error("Got a local candidate without local description");
 		throw std::logic_error("Got a local candidate without local description");
 
 
 	mLocalDescription->addCandidate(candidate);
 	mLocalDescription->addCandidate(candidate);
 
 
-	if (mLocalCandidateCallback)
-		mLocalCandidateCallback(candidate);
+	mLocalCandidateCallback(candidate);
 }
 }
 
 
 void PeerConnection::triggerDataChannel(weak_ptr<PeerConnection> weak_this, weak_ptr<DataChannel> weakDataChannel) {
 void PeerConnection::triggerDataChannel(weak_ptr<PeerConnection> weak_this, weak_ptr<DataChannel> weakDataChannel) {
-  auto strong_this = weak_this.lock();
-  if (!strong_this) return;
+	auto strong_this = weak_this.lock();
+	if (!strong_this)
+		return;
 
 
-  auto dataChannel = weakDataChannel.lock();
-  if (!dataChannel) return;
+	auto dataChannel = weakDataChannel.lock();
+	if (!dataChannel)
+		return;
 
 
-	if (mDataChannelCallback)
-		mDataChannelCallback(dataChannel);
+	mDataChannelCallback(dataChannel);
 }
 }
 
 
 void PeerConnection::changeState(State state) {
 void PeerConnection::changeState(State state) {
-	mState = state;
-	if (mStateChangeCallback)
+	if (mState.exchange(state) != state)
 		mStateChangeCallback(state);
 		mStateChangeCallback(state);
 }
 }
 
 
 void PeerConnection::changeGatheringState(GatheringState state) {
 void PeerConnection::changeGatheringState(GatheringState state) {
-	mGatheringState = state;
-	if (mGatheringStateChangeCallback)
+	if (mGatheringState.exchange(state) != state)
 		mGatheringStateChangeCallback(state);
 		mGatheringStateChangeCallback(state);
 }
 }
 
 

+ 38 - 18
src/sctptransport.cpp

@@ -48,11 +48,11 @@ void SctpTransport::GlobalCleanup() {
 }
 }
 
 
 SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recv,
 SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recv,
-                             state_callback stateChangeCallback)
+                             sent_callback sentCallback, state_callback stateChangeCallback)
     : Transport(lower), mPort(port), mState(State::Disconnected),
     : Transport(lower), mPort(port), mState(State::Disconnected),
-      mStateChangeCallback(std::move(stateChangeCallback)) {
+      mSentCallback(std::move(sentCallback)), mStateChangeCallback(std::move(stateChangeCallback)) {
 
 
-  onRecv(recv);
+	onRecv(recv);
 
 
 	GlobalInit();
 	GlobalInit();
 	usrsctp_register_address(this);
 	usrsctp_register_address(this);
@@ -116,23 +116,22 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, me
 	if (usrsctp_bind(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)))
 	if (usrsctp_bind(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)))
 		throw std::runtime_error("Could not bind usrsctp socket, errno=" + std::to_string(errno));
 		throw std::runtime_error("Could not bind usrsctp socket, errno=" + std::to_string(errno));
 
 
-	mConnectThread = std::thread(&SctpTransport::runConnect, this);
+	mSendThread = std::thread(&SctpTransport::runConnectAndSendLoop, this);
 }
 }
 
 
 SctpTransport::~SctpTransport() {
 SctpTransport::~SctpTransport() {
-	onRecv(nullptr);
+	onRecv(nullptr); // unset recv callback
 	mStopping = true;
 	mStopping = true;
 	mConnectCondition.notify_all();
 	mConnectCondition.notify_all();
 	mSendQueue.stop();
 	mSendQueue.stop();
 
 
-	if (mConnectThread.joinable())
-		mConnectThread.join();
-
 	if (mSock) {
 	if (mSock) {
 		usrsctp_shutdown(mSock, SHUT_RDWR);
 		usrsctp_shutdown(mSock, SHUT_RDWR);
 		usrsctp_close(mSock);
 		usrsctp_close(mSock);
 	}
 	}
 
 
+	mSendThread.join();
+
 	usrsctp_deregister_address(this);
 	usrsctp_deregister_address(this);
 	GlobalCleanup();
 	GlobalCleanup();
 }
 }
@@ -143,6 +142,7 @@ bool SctpTransport::send(message_ptr message) {
 	if (!message || mStopping)
 	if (!message || mStopping)
 		return false;
 		return false;
 
 
+	updateSendCount(message->stream, 1);
 	mSendQueue.push(message);
 	mSendQueue.push(message);
 	return true;
 	return true;
 }
 }
@@ -179,11 +179,11 @@ void SctpTransport::incoming(message_ptr message) {
 }
 }
 
 
 void SctpTransport::changeState(State state) {
 void SctpTransport::changeState(State state) {
-	mState = state;
-	mStateChangeCallback(state);
+	if (mState.exchange(state) != state)
+		mStateChangeCallback(state);
 }
 }
 
 
-void SctpTransport::runConnect() {
+void SctpTransport::runConnectAndSendLoop() {
 	try {
 	try {
 		changeState(State::Connecting);
 		changeState(State::Connecting);
 
 
@@ -213,15 +213,19 @@ void SctpTransport::runConnect() {
 	}
 	}
 
 
 	try {
 	try {
-		while (auto message = mSendQueue.pop()) {
-			if (!doSend(*message))
+		while (auto next = mSendQueue.pop()) {
+			auto message = *next;
+			updateSendCount(message->stream, -1);
+			if (!doSend(message))
 				throw std::runtime_error("Sending failed, errno=" + std::to_string(errno));
 				throw std::runtime_error("Sending failed, errno=" + std::to_string(errno));
 		}
 		}
 	} catch (const std::exception &e) {
 	} catch (const std::exception &e) {
 		std::cerr << "SCTP send: " << e.what() << std::endl;
 		std::cerr << "SCTP send: " << e.what() << std::endl;
-		mStopping = true;
-		return;
 	}
 	}
+
+	changeState(State::Disconnected);
+	mStopping = true;
+	mConnectCondition.notify_all();
 }
 }
 
 
 bool SctpTransport::doSend(message_ptr message) {
 bool SctpTransport::doSend(message_ptr message) {
@@ -273,12 +277,24 @@ bool SctpTransport::doSend(message_ptr message) {
 		break;
 		break;
 	}
 	}
 
 
+	ssize_t ret;
 	if (!message->empty()) {
 	if (!message->empty()) {
-		return usrsctp_sendv(mSock, message->data(), message->size(), nullptr, 0, &spa, sizeof(spa),
-		                     SCTP_SENDV_SPA, 0) > 0;
+		ret = usrsctp_sendv(mSock, message->data(), message->size(), nullptr, 0, &spa, sizeof(spa),
+		                    SCTP_SENDV_SPA, 0);
 	} else {
 	} else {
 		const char zero = 0;
 		const char zero = 0;
-		return usrsctp_sendv(mSock, &zero, 1, nullptr, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0) > 0;
+		ret = usrsctp_sendv(mSock, &zero, 1, nullptr, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0);
+	}
+	return ret > 0;
+}
+
+void SctpTransport::updateSendCount(uint16_t streamId, int delta) {
+	std::lock_guard<std::mutex> lock(mSendCountMutex);
+	auto it = mSendCount.insert(std::make_pair(streamId, 0)).first;
+	it->second = std::max(it->second + delta, 0);
+	if (it->second == 0) {
+		mSendCount.erase(it);
+		mSentCallback(streamId);
 	}
 	}
 }
 }
 
 
@@ -296,6 +312,10 @@ int SctpTransport::handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_
 
 
 int SctpTransport::process(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
 int SctpTransport::process(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
                            struct sctp_rcvinfo info, int flags) {
                            struct sctp_rcvinfo info, int flags) {
+	if (!data) {
+		recv(nullptr);
+		return 0;
+	}
 	if (flags & MSG_NOTIFICATION) {
 	if (flags & MSG_NOTIFICATION) {
 		processNotification((union sctp_notification *)data, len);
 		processNotification((union sctp_notification *)data, len);
 	} else {
 	} else {

+ 11 - 5
src/sctptransport.hpp

@@ -26,6 +26,7 @@
 
 
 #include <condition_variable>
 #include <condition_variable>
 #include <functional>
 #include <functional>
+#include <map>
 #include <mutex>
 #include <mutex>
 #include <thread>
 #include <thread>
 
 
@@ -40,10 +41,11 @@ class SctpTransport : public Transport {
 public:
 public:
 	enum class State { Disconnected, Connecting, Connected, Failed };
 	enum class State { Disconnected, Connecting, Connected, Failed };
 
 
+	using sent_callback = std::function<void(uint16_t streamId)>;
 	using state_callback = std::function<void(State state)>;
 	using state_callback = std::function<void(State state)>;
 
 
 	SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recv,
 	SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recv,
-	              state_callback stateChangeCallback);
+	              sent_callback sent, state_callback stateChangeCallback);
 	~SctpTransport();
 	~SctpTransport();
 
 
 	State state() const;
 	State state() const;
@@ -62,8 +64,9 @@ private:
 
 
 	void incoming(message_ptr message);
 	void incoming(message_ptr message);
 	void changeState(State state);
 	void changeState(State state);
-	void runConnect();
+	void runConnectAndSendLoop();
 	bool doSend(message_ptr message);
 	bool doSend(message_ptr message);
+	void updateSendCount(uint16_t streamId, int delta);
 
 
 	int handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df);
 	int handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df);
 
 
@@ -73,18 +76,21 @@ private:
 	void processData(const byte *data, size_t len, uint16_t streamId, PayloadId ppid);
 	void processData(const byte *data, size_t len, uint16_t streamId, PayloadId ppid);
 	void processNotification(const union sctp_notification *notify, size_t len);
 	void processNotification(const union sctp_notification *notify, size_t len);
 
 
+	const uint16_t mPort;
 	struct socket *mSock;
 	struct socket *mSock;
-	uint16_t mPort;
 
 
 	Queue<message_ptr> mSendQueue;
 	Queue<message_ptr> mSendQueue;
-	std::thread mConnectThread;
+	std::thread mSendThread;
+	std::map<uint16_t, int> mSendCount;
+	std::mutex mSendCountMutex;
+	sent_callback mSentCallback;
+
 	std::mutex mConnectMutex;
 	std::mutex mConnectMutex;
 	std::condition_variable mConnectCondition;
 	std::condition_variable mConnectCondition;
 	std::atomic<bool> mConnectDataSent = false;
 	std::atomic<bool> mConnectDataSent = false;
 	std::atomic<bool> mStopping = false;
 	std::atomic<bool> mStopping = false;
 
 
 	std::atomic<State> mState;
 	std::atomic<State> mState;
-
 	state_callback mStateChangeCallback;
 	state_callback mStateChangeCallback;
 
 
 	static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df);
 	static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df);