Browse Source

Merge pull request #10 from paullouisageneau/big-messages

Large messages support
Paul-Louis Ageneau 5 years ago
parent
commit
5416b66116

+ 8 - 7
include/rtc/channel.hpp

@@ -21,8 +21,8 @@
 
 
 #include "include.hpp"
 #include "include.hpp"
 
 
+#include <atomic>
 #include <functional>
 #include <functional>
-#include <mutex>
 #include <variant>
 #include <variant>
 
 
 namespace rtc {
 namespace rtc {
@@ -30,13 +30,14 @@ namespace rtc {
 class Channel {
 class Channel {
 public:
 public:
 	virtual void close() = 0;
 	virtual void close() = 0;
-	virtual void send(const std::variant<binary, string> &data) = 0;
-	virtual std::optional<std::variant<binary, string>> receive() = 0;
+	virtual bool send(const std::variant<binary, string> &data) = 0; // returns false if buffered
+	virtual std::optional<std::variant<binary, string>> receive() = 0; // only if onMessage unset
+
 	virtual bool isOpen() const = 0;
 	virtual bool isOpen() const = 0;
 	virtual bool isClosed() const = 0;
 	virtual bool isClosed() const = 0;
-	virtual size_t availableAmount() const { return 0; }
 
 
-	size_t bufferedAmount() const;
+	virtual size_t availableAmount() const; // total size available to receive
+	virtual size_t bufferedAmount() const; // total size buffered to send
 
 
 	void onOpen(std::function<void()> callback);
 	void onOpen(std::function<void()> callback);
 	void onClosed(std::function<void()> callback);
 	void onClosed(std::function<void()> callback);
@@ -66,8 +67,8 @@ private:
 	synchronized_callback<> mAvailableCallback;
 	synchronized_callback<> mAvailableCallback;
 	synchronized_callback<> mBufferedAmountLowCallback;
 	synchronized_callback<> mBufferedAmountLowCallback;
 
 
-	size_t mBufferedAmount = 0;
-	size_t mBufferedAmountLowThreshold = 0;
+	std::atomic<size_t> mBufferedAmount = 0;
+	std::atomic<size_t> mBufferedAmountLowThreshold = 0;
 };
 };
 
 
 } // namespace rtc
 } // namespace rtc

+ 19 - 16
include/rtc/datachannel.hpp

@@ -44,18 +44,21 @@ public:
 	            unsigned int stream);
 	            unsigned int stream);
 	~DataChannel();
 	~DataChannel();
 
 
-	void close(void);
-	void send(const std::variant<binary, string> &data);
-	void send(const byte *data, size_t size);
-	std::optional<std::variant<binary, string>> receive();
+	void close(void) override;
 
 
-	// Directly send a buffer to avoid a copy
-	template <typename Buffer> void sendBuffer(const Buffer &buf);
-	template <typename Iterator> void sendBuffer(Iterator first, Iterator last);
+	bool send(const std::variant<binary, string> &data) override;
+	bool send(const byte *data, size_t size);
 
 
-	bool isOpen(void) const;
-	bool isClosed(void) const;
-	size_t availableAmount() const;
+	template <typename Buffer> bool sendBuffer(const Buffer &buf);
+	template <typename Iterator> bool sendBuffer(Iterator first, Iterator last);
+
+	std::optional<std::variant<binary, string>> receive() override;
+
+	bool isOpen(void) const override;
+	bool isClosed(void) const override;
+	size_t availableAmount() const override;
+
+	size_t maxMessageSize() const;  // maximum message size in a call to send or sendBuffer
 
 
 	unsigned int stream() const;
 	unsigned int stream() const;
 	string label() const;
 	string label() const;
@@ -64,11 +67,11 @@ public:
 
 
 private:
 private:
 	void open(std::shared_ptr<SctpTransport> sctpTransport);
 	void open(std::shared_ptr<SctpTransport> sctpTransport);
-	void outgoing(mutable_message_ptr message);
+	bool outgoing(mutable_message_ptr message);
 	void incoming(message_ptr message);
 	void incoming(message_ptr message);
 	void processOpenMessage(message_ptr message);
 	void processOpenMessage(message_ptr message);
 
 
-	const std::shared_ptr<PeerConnection> mPeerConnection; // keeps the PeerConnection alive
+	const std::shared_ptr<PeerConnection> mPeerConnection;
 	std::shared_ptr<SctpTransport> mSctpTransport;
 	std::shared_ptr<SctpTransport> mSctpTransport;
 
 
 	unsigned int mStream;
 	unsigned int mStream;
@@ -92,14 +95,14 @@ template <typename Buffer> std::pair<const byte *, size_t> to_bytes(const Buffer
 	                      buf.size() * sizeof(E));
 	                      buf.size() * sizeof(E));
 }
 }
 
 
-template <typename Buffer> void DataChannel::sendBuffer(const Buffer &buf) {
+template <typename Buffer> bool DataChannel::sendBuffer(const Buffer &buf) {
 	auto [bytes, size] = to_bytes(buf);
 	auto [bytes, size] = to_bytes(buf);
 	auto message = std::make_shared<Message>(size);
 	auto message = std::make_shared<Message>(size);
 	std::copy(bytes, bytes + size, message->data());
 	std::copy(bytes, bytes + size, message->data());
-	outgoing(message);
+	return outgoing(message);
 }
 }
 
 
-template <typename Iterator> void DataChannel::sendBuffer(Iterator first, Iterator last) {
+template <typename Iterator> bool DataChannel::sendBuffer(Iterator first, Iterator last) {
 	size_t size = 0;
 	size_t size = 0;
 	for (Iterator it = first; it != last; ++it)
 	for (Iterator it = first; it != last; ++it)
 		size += it->size();
 		size += it->size();
@@ -110,7 +113,7 @@ template <typename Iterator> void DataChannel::sendBuffer(Iterator first, Iterat
 		auto [bytes, size] = to_bytes(*it);
 		auto [bytes, size] = to_bytes(*it);
 		pos = std::copy(bytes, bytes + size, pos);
 		pos = std::copy(bytes, bytes + size, pos);
 	}
 	}
-	outgoing(message);
+	return outgoing(message);
 }
 }
 
 
 } // namespace rtc
 } // namespace rtc

+ 3 - 0
include/rtc/description.hpp

@@ -44,9 +44,11 @@ public:
 	string mid() const;
 	string mid() const;
 	std::optional<string> fingerprint() const;
 	std::optional<string> fingerprint() const;
 	std::optional<uint16_t> sctpPort() const;
 	std::optional<uint16_t> sctpPort() const;
+	std::optional<size_t> maxMessageSize() const;
 
 
 	void setFingerprint(string fingerprint);
 	void setFingerprint(string fingerprint);
 	void setSctpPort(uint16_t port);
 	void setSctpPort(uint16_t port);
+	void setMaxMessageSize(size_t size);
 
 
 	void addCandidate(Candidate candidate);
 	void addCandidate(Candidate candidate);
 	void endCandidates();
 	void endCandidates();
@@ -62,6 +64,7 @@ private:
 	string mIceUfrag, mIcePwd;
 	string mIceUfrag, mIcePwd;
 	std::optional<string> mFingerprint;
 	std::optional<string> mFingerprint;
 	std::optional<uint16_t> mSctpPort;
 	std::optional<uint16_t> mSctpPort;
+	std::optional<size_t> mMaxMessageSize;
 	std::vector<Candidate> mCandidates;
 	std::vector<Candidate> mCandidates;
 	bool mTrickle;
 	bool mTrickle;
 
 

+ 3 - 0
include/rtc/include.hpp

@@ -43,7 +43,10 @@ using std::uint8_t;
 
 
 const size_t MAX_NUMERICNODE_LEN = 48; // Max IPv6 string representation length
 const size_t MAX_NUMERICNODE_LEN = 48; // Max IPv6 string representation length
 const size_t MAX_NUMERICSERV_LEN = 6;  // Max port string representation length
 const size_t MAX_NUMERICSERV_LEN = 6;  // Max port string representation length
+
 const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default
 const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default
+const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536;    // Remote max message size if not specified in SDP
+const size_t LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Local max message size
 
 
 template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
 template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
 template <class... Ts> overloaded(Ts...)->overloaded<Ts...>;
 template <class... Ts> overloaded(Ts...)->overloaded<Ts...>;

+ 1 - 0
include/rtc/message.hpp

@@ -44,6 +44,7 @@ struct Message : binary {
 using message_ptr = std::shared_ptr<const Message>;
 using message_ptr = std::shared_ptr<const Message>;
 using mutable_message_ptr = std::shared_ptr<Message>;
 using mutable_message_ptr = std::shared_ptr<Message>;
 using message_callback = std::function<void(message_ptr message)>;
 using message_callback = std::function<void(message_ptr message)>;
+
 constexpr auto message_size_func = [](const message_ptr &m) -> size_t {
 constexpr auto message_size_func = [](const message_ptr &m) -> size_t {
 	return m->type != Message::Control ? m->size() : 0;
 	return m->type != Message::Control ? m->size() : 0;
 };
 };

+ 8 - 11
include/rtc/queue.hpp

@@ -34,8 +34,7 @@ template <typename T> class Queue {
 public:
 public:
 	using amount_function = std::function<size_t(const T &element)>;
 	using amount_function = std::function<size_t(const T &element)>;
 
 
-	Queue(
-	    size_t limit = 0, amount_function func = [](const T &element) -> size_t { return 1; });
+	Queue(size_t limit = 0, amount_function func = nullptr);
 	~Queue();
 	~Queue();
 
 
 	void stop();
 	void stop();
@@ -45,7 +44,7 @@ public:
 	void push(const T &element);
 	void push(const T &element);
 	void push(T &&element);
 	void push(T &&element);
 	std::optional<T> pop();
 	std::optional<T> pop();
-	std::optional<T> tryPop();
+	std::optional<T> peek();
 	void wait();
 	void wait();
 	void wait(const std::chrono::milliseconds &duration);
 	void wait(const std::chrono::milliseconds &duration);
 
 
@@ -61,8 +60,9 @@ private:
 };
 };
 
 
 template <typename T>
 template <typename T>
-Queue<T>::Queue(size_t limit, amount_function func)
-    : mLimit(limit), mAmount(0), mAmountFunction(func) {}
+Queue<T>::Queue(size_t limit, amount_function func) : mLimit(limit), mAmount(0) {
+	mAmountFunction = func ? func : [](const T &element) -> size_t { return 1; };
+}
 
 
 template <typename T> Queue<T>::~Queue() { stop(); }
 template <typename T> Queue<T>::~Queue() { stop(); }
 
 
@@ -105,7 +105,7 @@ template <typename T> std::optional<T> Queue<T>::pop() {
 	mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
 	mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
 	if (!mQueue.empty()) {
 	if (!mQueue.empty()) {
 		mAmount -= mAmountFunction(mQueue.front());
 		mAmount -= mAmountFunction(mQueue.front());
-		std::optional<T> element(std::move(mQueue.front()));
+		std::optional<T> element{std::move(mQueue.front())};
 		mQueue.pop();
 		mQueue.pop();
 		return element;
 		return element;
 	} else {
 	} else {
@@ -113,13 +113,10 @@ template <typename T> std::optional<T> Queue<T>::pop() {
 	}
 	}
 }
 }
 
 
-template <typename T> std::optional<T> Queue<T>::tryPop() {
+template <typename T> std::optional<T> Queue<T>::peek() {
 	std::unique_lock<std::mutex> lock(mMutex);
 	std::unique_lock<std::mutex> lock(mMutex);
 	if (!mQueue.empty()) {
 	if (!mQueue.empty()) {
-		mAmount -= mAmountFunction(mQueue.front());
-		std::optional<T> element(std::move(mQueue.front()));
-		mQueue.pop();
-		return element;
+		return std::optional<T>{mQueue.front()};
 	} else {
 	} else {
 		return nullopt;
 		return nullopt;
 	}
 	}

+ 5 - 4
src/channel.cpp

@@ -57,6 +57,8 @@ void Channel::onBufferedAmountLow(std::function<void()> callback) {
 	mBufferedAmountLowCallback = callback;
 	mBufferedAmountLowCallback = callback;
 }
 }
 
 
+size_t Channel::availableAmount() const { return 0; }
+
 size_t Channel::bufferedAmount() const { return mBufferedAmount; }
 size_t Channel::bufferedAmount() const { return mBufferedAmount; }
 
 
 void Channel::setBufferedAmountLowThreshold(size_t amount) { mBufferedAmountLowThreshold = amount; }
 void Channel::setBufferedAmountLowThreshold(size_t amount) { mBufferedAmountLowThreshold = amount; }
@@ -80,10 +82,9 @@ void Channel::triggerAvailable(size_t count) {
 }
 }
 
 
 void Channel::triggerBufferedAmount(size_t amount) {
 void Channel::triggerBufferedAmount(size_t amount) {
-	bool lowThresholdCrossed =
-	    mBufferedAmount > mBufferedAmountLowThreshold && amount <= mBufferedAmountLowThreshold;
-	mBufferedAmount = amount;
-	if (lowThresholdCrossed)
+	size_t previous = mBufferedAmount.exchange(amount);
+	size_t threshold = mBufferedAmountLowThreshold.load();
+	if (previous > threshold && amount <= threshold)
 		mBufferedAmountLowCallback();
 		mBufferedAmountLowCallback();
 }
 }
 
 

+ 24 - 10
src/datachannel.cpp

@@ -17,6 +17,7 @@
  */
  */
 
 
 #include "datachannel.hpp"
 #include "datachannel.hpp"
+#include "include.hpp"
 #include "peerconnection.hpp"
 #include "peerconnection.hpp"
 #include "sctptransport.hpp"
 #include "sctptransport.hpp"
 
 
@@ -82,24 +83,24 @@ void DataChannel::close() {
 	}
 	}
 }
 }
 
 
-void DataChannel::send(const std::variant<binary, string> &data) {
-	std::visit(
+bool DataChannel::send(const std::variant<binary, string> &data) {
+	return std::visit(
 	    [&](const auto &d) {
 	    [&](const auto &d) {
 		    using T = std::decay_t<decltype(d)>;
 		    using T = std::decay_t<decltype(d)>;
 		    constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
 		    constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
 		    auto *b = reinterpret_cast<const byte *>(d.data());
 		    auto *b = reinterpret_cast<const byte *>(d.data());
-		    outgoing(std::make_shared<Message>(b, b + d.size(), type));
+		    return outgoing(std::make_shared<Message>(b, b + d.size(), type));
 	    },
 	    },
 	    data);
 	    data);
 }
 }
 
 
-void DataChannel::send(const byte *data, size_t size) {
-	outgoing(std::make_shared<Message>(data, data + size, Message::Binary));
+bool DataChannel::send(const byte *data, size_t size) {
+	return outgoing(std::make_shared<Message>(data, data + size, Message::Binary));
 }
 }
 
 
 std::optional<std::variant<binary, string>> DataChannel::receive() {
 std::optional<std::variant<binary, string>> DataChannel::receive() {
-	while (auto opt = mRecvQueue.tryPop()) {
-		auto message = *opt;
+	while (!mRecvQueue.empty()) {
+		auto message = *mRecvQueue.pop();
 		switch (message->type) {
 		switch (message->type) {
 		case Message::Control: {
 		case Message::Control: {
 			auto raw = reinterpret_cast<const uint8_t *>(message->data());
 			auto raw = reinterpret_cast<const uint8_t *>(message->data());
@@ -128,6 +129,15 @@ bool DataChannel::isClosed(void) const { return mIsClosed; }
 
 
 size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); }
 size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); }
 
 
+size_t DataChannel::maxMessageSize() const {
+	size_t max = DEFAULT_MAX_MESSAGE_SIZE;
+	if (auto description = mPeerConnection->remoteDescription())
+		if (auto maxMessageSize = description->maxMessageSize())
+			return *maxMessageSize > 0 ? *maxMessageSize : LOCAL_MAX_MESSAGE_SIZE;
+
+	return std::min(max, LOCAL_MAX_MESSAGE_SIZE);
+}
+
 unsigned int DataChannel::stream() const { return mStream; }
 unsigned int DataChannel::stream() const { return mStream; }
 
 
 string DataChannel::label() const { return mLabel; }
 string DataChannel::label() const { return mLabel; }
@@ -167,13 +177,17 @@ void DataChannel::open(shared_ptr<SctpTransport> sctpTransport) {
 	mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
 	mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
 }
 }
 
 
-void DataChannel::outgoing(mutable_message_ptr message) {
+bool DataChannel::outgoing(mutable_message_ptr message) {
 	if (mIsClosed || !mSctpTransport)
 	if (mIsClosed || !mSctpTransport)
-		return;
+		throw std::runtime_error("DataChannel is closed");
+
+	if (message->size() > maxMessageSize())
+		throw std::runtime_error("Message size exceeds limit");
+
 	// Before the ACK has been received on a DataChannel, all messages must be sent ordered
 	// Before the ACK has been received on a DataChannel, all messages must be sent ordered
 	message->reliability = mIsOpen ? mReliability : nullptr;
 	message->reliability = mIsOpen ? mReliability : nullptr;
 	message->stream = mStream;
 	message->stream = mStream;
-	mSctpTransport->send(message);
+	return mSctpTransport->send(message);
 }
 }
 
 
 void DataChannel::incoming(message_ptr message) {
 void DataChannel::incoming(message_ptr message) {

+ 8 - 1
src/description.cpp

@@ -81,6 +81,8 @@ Description::Description(const string &sdp, Type type, Role role)
 			mIcePwd = line.substr(line.find(':') + 1);
 			mIcePwd = line.substr(line.find(':') + 1);
 		} else if (hasprefix(line, "a=sctp-port")) {
 		} else if (hasprefix(line, "a=sctp-port")) {
 			mSctpPort = uint16_t(std::stoul(line.substr(line.find(':') + 1)));
 			mSctpPort = uint16_t(std::stoul(line.substr(line.find(':') + 1)));
+		} else if (hasprefix(line, "a=max-message-size")) {
+			mMaxMessageSize = size_t(std::stoul(line.substr(line.find(':') + 1)));
 		} else if (hasprefix(line, "a=candidate")) {
 		} else if (hasprefix(line, "a=candidate")) {
 			addCandidate(Candidate(line.substr(2), mMid));
 			addCandidate(Candidate(line.substr(2), mMid));
 		} else if (hasprefix(line, "a=end-of-candidates")) {
 		} else if (hasprefix(line, "a=end-of-candidates")) {
@@ -103,12 +105,16 @@ std::optional<string> Description::fingerprint() const { return mFingerprint; }
 
 
 std::optional<uint16_t> Description::sctpPort() const { return mSctpPort; }
 std::optional<uint16_t> Description::sctpPort() const { return mSctpPort; }
 
 
+std::optional<size_t> Description::maxMessageSize() const { return mMaxMessageSize; }
+
 void Description::setFingerprint(string fingerprint) {
 void Description::setFingerprint(string fingerprint) {
 	mFingerprint.emplace(std::move(fingerprint));
 	mFingerprint.emplace(std::move(fingerprint));
 }
 }
 
 
 void Description::setSctpPort(uint16_t port) { mSctpPort.emplace(port); }
 void Description::setSctpPort(uint16_t port) { mSctpPort.emplace(port); }
 
 
+void Description::setMaxMessageSize(size_t size) { mMaxMessageSize.emplace(size); }
+
 void Description::addCandidate(Candidate candidate) {
 void Description::addCandidate(Candidate candidate) {
 	mCandidates.emplace_back(std::move(candidate));
 	mCandidates.emplace_back(std::move(candidate));
 }
 }
@@ -145,7 +151,8 @@ Description::operator string() const {
 		sdp << "a=fingerprint:sha-256 " << *mFingerprint << "\n";
 		sdp << "a=fingerprint:sha-256 " << *mFingerprint << "\n";
 	if (mSctpPort)
 	if (mSctpPort)
 		sdp << "a=sctp-port:" << *mSctpPort << "\n";
 		sdp << "a=sctp-port:" << *mSctpPort << "\n";
-
+	if (mMaxMessageSize)
+		sdp << "a=max-message-size:" << *mMaxMessageSize << "\n";
 	for (const auto &candidate : mCandidates) {
 	for (const auto &candidate : mCandidates) {
 		sdp << string(candidate) << "\n";
 		sdp << string(candidate) << "\n";
 	}
 	}

+ 3 - 5
src/dtlstransport.cpp

@@ -85,10 +85,10 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 }
 }
 
 
 DtlsTransport::~DtlsTransport() {
 DtlsTransport::~DtlsTransport() {
-	onRecv(nullptr); // unset recv callback
-
 	mIncomingQueue.stop();
 	mIncomingQueue.stop();
-	mRecvThread.join();
+
+	if (mRecvThread.joinable())
+		mRecvThread.join();
 
 
 	gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
 	gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
 	gnutls_deinit(mSession);
 	gnutls_deinit(mSession);
@@ -356,8 +356,6 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 }
 }
 
 
 DtlsTransport::~DtlsTransport() {
 DtlsTransport::~DtlsTransport() {
-	onRecv(nullptr); // unset recv callback
-
 	mIncomingQueue.stop();
 	mIncomingQueue.stop();
 
 
 	if (mRecvThread.joinable())
 	if (mRecvThread.joinable())

+ 1 - 1
src/dtlstransport.hpp

@@ -54,7 +54,7 @@ public:
 
 
 	State state() const;
 	State state() const;
 
 
-	bool send(message_ptr message);
+	bool send(message_ptr message); // false if dropped
 
 
 private:
 private:
 	void incoming(message_ptr message);
 	void incoming(message_ptr message);

+ 1 - 1
src/icetransport.hpp

@@ -67,7 +67,7 @@ public:
 	std::optional<string> getLocalAddress() const;
 	std::optional<string> getLocalAddress() const;
 	std::optional<string> getRemoteAddress() const;
 	std::optional<string> getRemoteAddress() const;
 
 
-	bool send(message_ptr message);
+	bool send(message_ptr message); // false if dropped
 
 
 private:
 private:
 	void incoming(message_ptr message);
 	void incoming(message_ptr message);

+ 1 - 0
src/peerconnection.cpp

@@ -359,6 +359,7 @@ void PeerConnection::processLocalDescription(Description description) {
 	mLocalDescription.emplace(std::move(description));
 	mLocalDescription.emplace(std::move(description));
 	mLocalDescription->setFingerprint(mCertificate->fingerprint());
 	mLocalDescription->setFingerprint(mCertificate->fingerprint());
 	mLocalDescription->setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT));
 	mLocalDescription->setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT));
+	mLocalDescription->setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE);
 
 
 	mLocalDescriptionCallback(*mLocalDescription);
 	mLocalDescriptionCallback(*mLocalDescription);
 }
 }

+ 231 - 137
src/sctptransport.cpp

@@ -58,11 +58,14 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
 	GlobalInit();
 	GlobalInit();
 
 
 	usrsctp_register_address(this);
 	usrsctp_register_address(this);
-	mSock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, &SctpTransport::ReadCallback,
-	                       nullptr, 0, this);
+	mSock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, &SctpTransport::RecvCallback,
+	                       &SctpTransport::SendCallback, 0, this);
 	if (!mSock)
 	if (!mSock)
 		throw std::runtime_error("Could not create SCTP socket, errno=" + std::to_string(errno));
 		throw std::runtime_error("Could not create SCTP socket, errno=" + std::to_string(errno));
 
 
+	if (usrsctp_set_non_blocking(mSock, 1))
+		throw std::runtime_error("Unable to set non-blocking mode, errno=" + std::to_string(errno));
+
 	// SCTP must stop sending after the lower layer is shut down, so disable linger
 	// SCTP must stop sending after the lower layer is shut down, so disable linger
 	struct linger sol = {};
 	struct linger sol = {};
 	sol.l_onoff = 1;
 	sol.l_onoff = 1;
@@ -81,12 +84,21 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
 	struct sctp_event se = {};
 	struct sctp_event se = {};
 	se.se_assoc_id = SCTP_ALL_ASSOC;
 	se.se_assoc_id = SCTP_ALL_ASSOC;
 	se.se_on = 1;
 	se.se_on = 1;
+	se.se_type = SCTP_ASSOC_CHANGE;
+	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se)))
+		throw std::runtime_error("Could not subscribe to event SCTP_ASSOC_CHANGE, errno=" +
+		                         std::to_string(errno));
+	se.se_type = SCTP_SENDER_DRY_EVENT;
+	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se)))
+		throw std::runtime_error("Could not subscribe to event SCTP_SENDER_DRY_EVENT, errno=" +
+		                         std::to_string(errno));
 	se.se_type = SCTP_STREAM_RESET_EVENT;
 	se.se_type = SCTP_STREAM_RESET_EVENT;
 	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se)))
 	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se)))
-		throw std::runtime_error("Could not set socket option SCTP_EVENT, errno=" +
+		throw std::runtime_error("Could not subscribe to event SCTP_STREAM_RESET_EVENT, errno=" +
 		                         std::to_string(errno));
 		                         std::to_string(errno));
 
 
-	// Disable Nagle-like algorithm to reduce delay
+	// The sender SHOULD disable the Nagle algorithm (see RFC1122) to minimize the latency.
+	// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.6
 	int nodelay = 1;
 	int nodelay = 1;
 	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_NODELAY, &nodelay, sizeof(nodelay)))
 	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_NODELAY, &nodelay, sizeof(nodelay)))
 		throw std::runtime_error("Could not set socket option SCTP_NODELAY, errno=" +
 		throw std::runtime_error("Could not set socket option SCTP_NODELAY, errno=" +
@@ -127,47 +139,67 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
 		throw std::runtime_error("Could not set SCTP send buffer size, errno=" +
 		throw std::runtime_error("Could not set SCTP send buffer size, errno=" +
 		                         std::to_string(errno));
 		                         std::to_string(errno));
 
 
-	struct sockaddr_conn sconn = {};
-	sconn.sconn_family = AF_CONN;
-	sconn.sconn_port = htons(mPort);
-	sconn.sconn_addr = this;
-#ifdef HAVE_SCONN_LEN
-	sconn.sconn_len = sizeof(sconn);
-#endif
-
-	if (usrsctp_bind(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)))
-		throw std::runtime_error("Could not bind usrsctp socket, errno=" + std::to_string(errno));
-
-	mSendThread = std::thread(&SctpTransport::runConnectAndSendLoop, this);
+	connect();
 }
 }
 
 
 SctpTransport::~SctpTransport() {
 SctpTransport::~SctpTransport() {
 	onRecv(nullptr); // unset recv callback
 	onRecv(nullptr); // unset recv callback
-	mStopping = true;
-	mConnectCondition.notify_all();
+
 	mSendQueue.stop();
 	mSendQueue.stop();
 
 
+	// Unblock incoming
+	if (!mConnectDataSent) {
+		std::unique_lock<std::mutex> lock(mConnectMutex);
+		mConnectDataSent = true;
+		mConnectCondition.notify_all();
+	}
+
 	if (mSock) {
 	if (mSock) {
 		usrsctp_shutdown(mSock, SHUT_RDWR);
 		usrsctp_shutdown(mSock, SHUT_RDWR);
 		usrsctp_close(mSock);
 		usrsctp_close(mSock);
 	}
 	}
 
 
-	if (mSendThread.joinable())
-		mSendThread.join();
-
 	usrsctp_deregister_address(this);
 	usrsctp_deregister_address(this);
 	GlobalCleanup();
 	GlobalCleanup();
 }
 }
 
 
+void SctpTransport::connect() {
+	changeState(State::Connecting);
+
+	struct sockaddr_conn sconn = {};
+	sconn.sconn_family = AF_CONN;
+	sconn.sconn_port = htons(mPort);
+	sconn.sconn_addr = this;
+#ifdef HAVE_SCONN_LEN
+	sconn.sconn_len = sizeof(sconn);
+#endif
+
+	if (usrsctp_bind(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)))
+		throw std::runtime_error("Could not bind usrsctp socket, errno=" + std::to_string(errno));
+
+	// According to the IETF draft, both endpoints must initiate the SCTP association, in a
+	// simultaneous-open manner, irrelevent to the SDP setup role.
+	// See https://tools.ietf.org/html/draft-ietf-mmusic-sctp-sdp-26#section-9.3
+	int ret = usrsctp_connect(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn));
+	if (ret && errno != EINPROGRESS)
+		throw std::runtime_error("Connection attempt failed, errno=" + std::to_string(errno));
+}
+
 SctpTransport::State SctpTransport::state() const { return mState; }
 SctpTransport::State SctpTransport::state() const { return mState; }
 
 
 bool SctpTransport::send(message_ptr message) {
 bool SctpTransport::send(message_ptr message) {
-	if (!message || mStopping)
-		return false;
+	std::lock_guard<std::mutex> lock(mSendMutex);
+
+	if (!message)
+		return mSendQueue.empty();
+
+	// If nothing is pending, try to send directly
+	if (mSendQueue.empty() && trySendMessage(message))
+		return true;
 
 
-	updateBufferedAmount(message->stream, message->size());
 	mSendQueue.push(message);
 	mSendQueue.push(message);
-	return true;
+	updateBufferedAmount(message->stream, message_size_func(message));
+	return false;
 }
 }
 
 
 void SctpTransport::reset(unsigned int stream) {
 void SctpTransport::reset(unsigned int stream) {
@@ -188,17 +220,15 @@ void SctpTransport::incoming(message_ptr message) {
 		return;
 		return;
 	}
 	}
 
 
-	// There could be a race condition here where we receive the remote INIT before the thread in
-	// usrsctp_connect sends the local one, which would result in the connection being aborted.
-	// Therefore, we need to wait for data to be sent on our side (i.e. the local INIT) before
-	// proceeding.
+	// There could be a race condition here where we receive the remote INIT before the local one is
+	// sent, which would result in the connection being aborted. Therefore, we need to wait for data
+	// to be sent on our side (i.e. the local INIT) before proceeding.
 	if (!mConnectDataSent) {
 	if (!mConnectDataSent) {
 		std::unique_lock<std::mutex> lock(mConnectMutex);
 		std::unique_lock<std::mutex> lock(mConnectMutex);
-		mConnectCondition.wait(lock, [this] { return mConnectDataSent || mStopping; });
+		mConnectCondition.wait(lock, [this]() -> bool { return mConnectDataSent; });
 	}
 	}
 
 
-	if (!mStopping)
-		usrsctp_conninput(this, message->data(), message->size(), 0);
+	usrsctp_conninput(this, message->data(), message->size(), 0);
 }
 }
 
 
 void SctpTransport::changeState(State state) {
 void SctpTransport::changeState(State state) {
@@ -206,60 +236,26 @@ void SctpTransport::changeState(State state) {
 		mStateChangeCallback(state);
 		mStateChangeCallback(state);
 }
 }
 
 
-void SctpTransport::runConnectAndSendLoop() {
-	try {
-		changeState(State::Connecting);
-
-		struct sockaddr_conn sconn = {};
-		sconn.sconn_family = AF_CONN;
-		sconn.sconn_port = htons(mPort);
-		sconn.sconn_addr = this;
-#ifdef HAVE_SCONN_LEN
-		sconn.sconn_len = sizeof(sconn);
-#endif
-
-		// According to the IETF draft, both endpoints must initiate the SCTP association, in a
-		// simultaneous-open manner, irrelevent to the SDP setup role.
-		// See https://tools.ietf.org/html/draft-ietf-mmusic-sctp-sdp-26#section-9.3
-		if (usrsctp_connect(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)) != 0)
-			throw std::runtime_error("Connection failed, errno=" + std::to_string(errno));
-
-		if (!mStopping)
-			changeState(State::Connected);
-
-	} catch (const std::exception &e) {
-		std::cerr << "SCTP connect: " << e.what() << std::endl;
-		changeState(State::Failed);
-		mStopping = true;
-		mConnectCondition.notify_all();
-		return;
+bool SctpTransport::trySendQueue() {
+	// Requires mSendMutex to be locked
+	while (auto next = mSendQueue.peek()) {
+		auto message = *next;
+		if (!trySendMessage(message))
+			return false;
+		mSendQueue.pop();
+		updateBufferedAmount(message->stream, -message_size_func(message));
 	}
 	}
-
-	try {
-		while (auto next = mSendQueue.pop()) {
-			auto message = *next;
-			bool success = doSend(message);
-			updateBufferedAmount(message->stream, -message->size());
-			if (!success)
-				throw std::runtime_error("Sending failed, errno=" + std::to_string(errno));
-		}
-	} catch (const std::exception &e) {
-		std::cerr << "SCTP send: " << e.what() << std::endl;
-	}
-
-	changeState(State::Disconnected);
-	mStopping = true;
-	mConnectCondition.notify_all();
+	return true;
 }
 }
 
 
-bool SctpTransport::doSend(message_ptr message) {
-	if (!message)
-		return false;
+bool SctpTransport::trySendMessage(message_ptr message) {
+	// Requires mSendMutex to be locked
+	//
+	// TODO: Implement SCTP ndata specification draft when supported everywhere
+	// See https://tools.ietf.org/html/draft-ietf-tsvwg-sctp-ndata-08
 
 
 	const Reliability reliability = message->reliability ? *message->reliability : Reliability();
 	const Reliability reliability = message->reliability ? *message->reliability : Reliability();
 
 
-	struct sctp_sendv_spa spa = {};
-
 	uint32_t ppid;
 	uint32_t ppid;
 	switch (message->type) {
 	switch (message->type) {
 	case Message::String:
 	case Message::String:
@@ -273,11 +269,13 @@ bool SctpTransport::doSend(message_ptr message) {
 		break;
 		break;
 	}
 	}
 
 
+	struct sctp_sendv_spa spa = {};
+
 	// set sndinfo
 	// set sndinfo
 	spa.sendv_flags |= SCTP_SEND_SNDINFO_VALID;
 	spa.sendv_flags |= SCTP_SEND_SNDINFO_VALID;
 	spa.sendv_sndinfo.snd_sid = uint16_t(message->stream);
 	spa.sendv_sndinfo.snd_sid = uint16_t(message->stream);
 	spa.sendv_sndinfo.snd_ppid = htonl(ppid);
 	spa.sendv_sndinfo.snd_ppid = htonl(ppid);
-	spa.sendv_sndinfo.snd_flags |= SCTP_EOR;
+	spa.sendv_sndinfo.snd_flags |= SCTP_EOR; // implicit here
 
 
 	// set prinfo
 	// set prinfo
 	spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
 	spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
@@ -309,78 +307,141 @@ bool SctpTransport::doSend(message_ptr message) {
 		const char zero = 0;
 		const char zero = 0;
 		ret = usrsctp_sendv(mSock, &zero, 1, nullptr, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0);
 		ret = usrsctp_sendv(mSock, &zero, 1, nullptr, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0);
 	}
 	}
-	return ret > 0;
+
+	if (ret >= 0)
+		return true;
+	else if (errno == EWOULDBLOCK && errno == EAGAIN)
+		return false;
+	else
+		throw std::runtime_error("Sending failed, errno=" + std::to_string(errno));
 }
 }
 
 
 void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) {
 void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) {
-	if (delta == 0)
-		return;
-	std::lock_guard<std::mutex> lock(mBufferedAmountMutex);
+	// Requires mSendMutex to be locked
 	auto it = mBufferedAmount.insert(std::make_pair(streamId, 0)).first;
 	auto it = mBufferedAmount.insert(std::make_pair(streamId, 0)).first;
-	if (delta > 0)
-		it->second += size_t(delta);
-	else if (it->second > size_t(-delta))
-		it->second -= size_t(-delta);
-	else
-		it->second = 0;
-	mBufferedAmountCallback(streamId, it->second);
-	if (it->second == 0)
+	size_t amount = it->second;
+	amount = size_t(std::max(long(amount) + delta, long(0)));
+	if (amount == 0)
 		mBufferedAmount.erase(it);
 		mBufferedAmount.erase(it);
+	mBufferedAmountCallback(streamId, amount);
 }
 }
 
 
-int SctpTransport::handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df) {
-	byte *b = reinterpret_cast<byte *>(data);
-	outgoing(make_message(b, b + len));
-
-	if (!mConnectDataSent) {
-		std::unique_lock<std::mutex> lock(mConnectMutex);
-		mConnectDataSent = true;
-		mConnectCondition.notify_all();
+int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data,
+                              size_t len, struct sctp_rcvinfo info, int flags) {
+	try {
+		if (!data) {
+			recv(nullptr);
+			return 0;
+		}
+		if (flags & MSG_EOR) {
+			if (!mPartialRecv.empty()) {
+				mPartialRecv.insert(mPartialRecv.end(), data, data + len);
+				data = mPartialRecv.data();
+				len = mPartialRecv.size();
+			}
+			// Message is complete, process it
+			if (flags & MSG_NOTIFICATION)
+				processNotification(reinterpret_cast<const union sctp_notification *>(data), len);
+			else
+				processData(data, len, info.rcv_sid, PayloadId(htonl(info.rcv_ppid)));
+
+			mPartialRecv.clear();
+		} else {
+			// Message is not complete
+			mPartialRecv.insert(mPartialRecv.end(), data, data + len);
+		}
+	} catch (const std::exception &e) {
+		std::cerr << "SCTP recv: " << e.what() << std::endl;
+		return -1;
 	}
 	}
 	return 0; // success
 	return 0; // success
 }
 }
 
 
-int SctpTransport::process(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
-                           struct sctp_rcvinfo info, int flags) {
-	if (!data) {
-		recv(nullptr);
-		return 0;
+int SctpTransport::handleSend(size_t free) {
+	try {
+		std::lock_guard<std::mutex> lock(mSendMutex);
+		trySendQueue();
+	} catch (const std::exception &e) {
+		std::cerr << "SCTP send: " << e.what() << std::endl;
+		return -1;
 	}
 	}
-	if (flags & MSG_NOTIFICATION) {
-		processNotification((union sctp_notification *)data, len);
-	} else {
-		processData((const byte *)data, len, info.rcv_sid, PayloadId(htonl(info.rcv_ppid)));
+	return 0; // success
+}
+
+int SctpTransport::handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df) {
+	try {
+		outgoing(make_message(data, data + len));
+
+		if (!mConnectDataSent) {
+			std::unique_lock<std::mutex> lock(mConnectMutex);
+			mConnectDataSent = true;
+			mConnectCondition.notify_all();
+		}
+	} catch (const std::exception &e) {
+		std::cerr << "SCTP write: " << e.what() << std::endl;
+		return -1;
 	}
 	}
-	free(data);
-	return 0;
+	return 0; // success
 }
 }
 
 
 void SctpTransport::processData(const byte *data, size_t len, uint16_t sid, PayloadId ppid) {
 void SctpTransport::processData(const byte *data, size_t len, uint16_t sid, PayloadId ppid) {
-	Message::Type type;
+	// The usage of the PPIDs "WebRTC String Partial" and "WebRTC Binary Partial" is deprecated.
+	// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.6
+	// We handle them at reception for compatibility reasons but should never send them.
 	switch (ppid) {
 	switch (ppid) {
+	case PPID_CONTROL:
+		recv(make_message(data, data + len, Message::Control, sid));
+		break;
+
+	case PPID_STRING_PARTIAL: // deprecated
+		mPartialStringData.insert(mPartialStringData.end(), data, data + len);
+		break;
+
 	case PPID_STRING:
 	case PPID_STRING:
-		type = Message::String;
+		if (mPartialStringData.empty()) {
+			recv(make_message(data, data + len, Message::String, sid));
+		} else {
+			mPartialStringData.insert(mPartialStringData.end(), data, data + len);
+			recv(make_message(mPartialStringData.begin(), mPartialStringData.end(), Message::String,
+			                  sid));
+			mPartialStringData.clear();
+		}
 		break;
 		break;
+
 	case PPID_STRING_EMPTY:
 	case PPID_STRING_EMPTY:
-		type = Message::String;
-		len = 0;
+		// This only accounts for when the partial data is empty
+		recv(make_message(mPartialStringData.begin(), mPartialStringData.end(), Message::String,
+		                  sid));
+		mPartialStringData.clear();
 		break;
 		break;
+
+	case PPID_BINARY_PARTIAL: // deprecated
+		mPartialBinaryData.insert(mPartialBinaryData.end(), data, data + len);
+		break;
+
 	case PPID_BINARY:
 	case PPID_BINARY:
-		type = Message::Binary;
+		if (mPartialBinaryData.empty()) {
+			recv(make_message(data, data + len, Message::Binary, sid));
+		} else {
+			mPartialBinaryData.insert(mPartialBinaryData.end(), data, data + len);
+			recv(make_message(mPartialBinaryData.begin(), mPartialBinaryData.end(), Message::Binary,
+			                  sid));
+			mPartialBinaryData.clear();
+		}
 		break;
 		break;
+
 	case PPID_BINARY_EMPTY:
 	case PPID_BINARY_EMPTY:
-		type = Message::Binary;
-		len = 0;
-		break;
-	case PPID_CONTROL:
-		type = Message::Control;
+		// This only accounts for when the partial data is empty
+		recv(make_message(mPartialBinaryData.begin(), mPartialBinaryData.end(), Message::Binary,
+		                  sid));
+		mPartialBinaryData.clear();
 		break;
 		break;
+
 	default:
 	default:
 		// Unknown
 		// Unknown
 		std::cerr << "Unknown PPID: " << uint32_t(ppid) << std::endl;
 		std::cerr << "Unknown PPID: " << uint32_t(ppid) << std::endl;
 		return;
 		return;
 	}
 	}
-	recv(make_message(data, data + len, type, sid));
 }
 }
 
 
 void SctpTransport::processNotification(const union sctp_notification *notify, size_t len) {
 void SctpTransport::processNotification(const union sctp_notification *notify, size_t len) {
@@ -388,21 +449,41 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 		return;
 		return;
 
 
 	switch (notify->sn_header.sn_type) {
 	switch (notify->sn_header.sn_type) {
+	case SCTP_ASSOC_CHANGE: {
+		const struct sctp_assoc_change &assoc_change = notify->sn_assoc_change;
+		std::unique_lock<std::mutex> lock(mConnectMutex);
+		if (assoc_change.sac_state == SCTP_COMM_UP) {
+			changeState(State::Connected);
+		} else {
+			if (mState == State::Connecting) {
+				std::cerr << "SCTP connection failed" << std::endl;
+				changeState(State::Failed);
+			} else {
+				changeState(State::Disconnected);
+			}
+		}
+	}
+	case SCTP_SENDER_DRY_EVENT: {
+		// It not should be necessary since the send callback should have been called already,
+		// but to be sure, let's try to send now.
+		std::lock_guard<std::mutex> lock(mSendMutex);
+		trySendQueue();
+	}
 	case SCTP_STREAM_RESET_EVENT: {
 	case SCTP_STREAM_RESET_EVENT: {
-		const struct sctp_stream_reset_event *reset_event = &notify->sn_strreset_event;
-		const int count = (reset_event->strreset_length - sizeof(*reset_event)) / sizeof(uint16_t);
+		const struct sctp_stream_reset_event &reset_event = notify->sn_strreset_event;
+		const int count = (reset_event.strreset_length - sizeof(reset_event)) / sizeof(uint16_t);
 
 
-		if (reset_event->strreset_flags & SCTP_STREAM_RESET_INCOMING_SSN) {
+		if (reset_event.strreset_flags & SCTP_STREAM_RESET_INCOMING_SSN) {
 			for (int i = 0; i < count; ++i) {
 			for (int i = 0; i < count; ++i) {
-				uint16_t streamId = reset_event->strreset_stream_list[i];
+				uint16_t streamId = reset_event.strreset_stream_list[i];
 				reset(streamId);
 				reset(streamId);
 			}
 			}
 		}
 		}
 
 
-		if (reset_event->strreset_flags & SCTP_STREAM_RESET_OUTGOING_SSN) {
+		if (reset_event.strreset_flags & SCTP_STREAM_RESET_OUTGOING_SSN) {
 			const byte dataChannelCloseMessage{0x04};
 			const byte dataChannelCloseMessage{0x04};
 			for (int i = 0; i < count; ++i) {
 			for (int i = 0; i < count; ++i) {
-				uint16_t streamId = reset_event->strreset_stream_list[i];
+				uint16_t streamId = reset_event.strreset_stream_list[i];
 				recv(make_message(&dataChannelCloseMessage, &dataChannelCloseMessage + 1,
 				recv(make_message(&dataChannelCloseMessage, &dataChannelCloseMessage + 1,
 				                  Message::Control, streamId));
 				                  Message::Control, streamId));
 			}
 			}
@@ -415,16 +496,29 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 		break;
 		break;
 	}
 	}
 }
 }
-int SctpTransport::WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos,
-                                 uint8_t set_df) {
-	return static_cast<SctpTransport *>(sctp_ptr)->handleWrite(data, len, tos, set_df);
+
+int SctpTransport::RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data,
+                                size_t len, struct sctp_rcvinfo recv_info, int flags, void *ptr) {
+	int ret = static_cast<SctpTransport *>(ptr)->handleRecv(
+	    sock, addr, static_cast<const byte *>(data), len, recv_info, flags);
+	free(data);
+	return ret;
+}
+
+int SctpTransport::SendCallback(struct socket *sock, uint32_t sb_free) {
+	struct sctp_paddrinfo paddrinfo = {};
+	socklen_t len = sizeof(paddrinfo);
+	if (usrsctp_getsockopt(sock, IPPROTO_SCTP, SCTP_GET_PEER_ADDR_INFO, &paddrinfo, &len))
+		return -1;
+
+	auto sconn = reinterpret_cast<struct sockaddr_conn *>(&paddrinfo.spinfo_address);
+	void *ptr = sconn->sconn_addr;
+	return static_cast<SctpTransport *>(ptr)->handleSend(size_t(sb_free));
 }
 }
 
 
-int SctpTransport::ReadCallback(struct socket *sock, union sctp_sockstore addr, void *data,
-                                size_t len, struct sctp_rcvinfo recv_info, int flags,
-                                void *user_data) {
-	return static_cast<SctpTransport *>(user_data)->process(sock, addr, data, len, recv_info,
-	                                                        flags);
+int SctpTransport::WriteCallback(void *ptr, void *data, size_t len, uint8_t tos, uint8_t set_df) {
+	return static_cast<SctpTransport *>(ptr)->handleWrite(static_cast<byte *>(data), len, tos,
+	                                                      set_df);
 }
 }
 
 
 } // namespace rtc
 } // namespace rtc

+ 19 - 11
src/sctptransport.hpp

@@ -50,28 +50,34 @@ public:
 
 
 	State state() const;
 	State state() const;
 
 
-	bool send(message_ptr message);
+	bool send(message_ptr message); // false if buffered
 	void reset(unsigned int stream);
 	void reset(unsigned int stream);
 
 
 private:
 private:
+	// Order seems wrong but these are the actual values
+	// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-8
 	enum PayloadId : uint32_t {
 	enum PayloadId : uint32_t {
 		PPID_CONTROL = 50,
 		PPID_CONTROL = 50,
 		PPID_STRING = 51,
 		PPID_STRING = 51,
+		PPID_BINARY_PARTIAL = 52,
 		PPID_BINARY = 53,
 		PPID_BINARY = 53,
+		PPID_STRING_PARTIAL = 54,
 		PPID_STRING_EMPTY = 56,
 		PPID_STRING_EMPTY = 56,
 		PPID_BINARY_EMPTY = 57
 		PPID_BINARY_EMPTY = 57
 	};
 	};
 
 
+	void connect();
 	void incoming(message_ptr message);
 	void incoming(message_ptr message);
 	void changeState(State state);
 	void changeState(State state);
-	void runConnectAndSendLoop();
-	bool doSend(message_ptr message);
-	void updateBufferedAmount(uint16_t streamId, long delta);
 
 
-	int handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df);
+	bool trySendQueue();
+	bool trySendMessage(message_ptr message);
+	void updateBufferedAmount(uint16_t streamId, long delta);
 
 
-	int process(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
-	            struct sctp_rcvinfo recv_info, int flags);
+	int handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data, size_t len,
+	               struct sctp_rcvinfo recv_info, int flags);
+	int handleSend(size_t free);
+	int handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df);
 
 
 	void processData(const byte *data, size_t len, uint16_t streamId, PayloadId ppid);
 	void processData(const byte *data, size_t len, uint16_t streamId, PayloadId ppid);
 	void processNotification(const union sctp_notification *notify, size_t len);
 	void processNotification(const union sctp_notification *notify, size_t len);
@@ -79,10 +85,9 @@ private:
 	const uint16_t mPort;
 	const uint16_t mPort;
 	struct socket *mSock;
 	struct socket *mSock;
 
 
+	std::mutex mSendMutex;
 	Queue<message_ptr> mSendQueue;
 	Queue<message_ptr> mSendQueue;
-	std::thread mSendThread;
 	std::map<uint16_t, size_t> mBufferedAmount;
 	std::map<uint16_t, size_t> mBufferedAmount;
-	std::mutex mBufferedAmountMutex;
 	amount_callback mBufferedAmountCallback;
 	amount_callback mBufferedAmountCallback;
 
 
 	std::mutex mConnectMutex;
 	std::mutex mConnectMutex;
@@ -93,9 +98,12 @@ private:
 	state_callback mStateChangeCallback;
 	state_callback mStateChangeCallback;
 	std::atomic<State> mState;
 	std::atomic<State> mState;
 
 
-	static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df);
-	static int ReadCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
+	binary mPartialRecv, mPartialStringData, mPartialBinaryData;
+
+	static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
 	                        struct sctp_rcvinfo recv_info, int flags, void *user_data);
 	                        struct sctp_rcvinfo recv_info, int flags, void *user_data);
+	static int SendCallback(struct socket *sock, uint32_t sb_free);
+	static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df);
 
 
 	void GlobalInit();
 	void GlobalInit();
 	void GlobalCleanup();
 	void GlobalCleanup();

+ 10 - 12
src/transport.hpp

@@ -31,27 +31,25 @@ using namespace std::placeholders;
 
 
 class Transport {
 class Transport {
 public:
 public:
-	Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(lower) { init(); }
-	virtual ~Transport() {}
+	Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {
+		if (mLower)
+			mLower->onRecv(std::bind(&Transport::incoming, this, _1));
+	}
+	virtual ~Transport() {
+		if (mLower)
+			mLower->onRecv(nullptr);
+	}
 
 
 	virtual bool send(message_ptr message) = 0;
 	virtual bool send(message_ptr message) = 0;
 	void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
 	void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
 
 
 protected:
 protected:
-	void recv(message_ptr message) {
-		if (mRecvCallback)
-			mRecvCallback(message);
-	}
+	void recv(message_ptr message) { mRecvCallback(message); }
 
 
 	virtual void incoming(message_ptr message) = 0;
 	virtual void incoming(message_ptr message) = 0;
 	virtual void outgoing(message_ptr message) { getLower()->send(message); }
 	virtual void outgoing(message_ptr message) { getLower()->send(message); }
 
 
 private:
 private:
-	void init() {
-		if (mLower)
-			mLower->onRecv(std::bind(&Transport::incoming, this, _1));
-	}
-
 	std::shared_ptr<Transport> getLower() {
 	std::shared_ptr<Transport> getLower() {
 		if (mLower)
 		if (mLower)
 			return mLower;
 			return mLower;
@@ -60,7 +58,7 @@ private:
 	}
 	}
 
 
 	std::shared_ptr<Transport> mLower;
 	std::shared_ptr<Transport> mLower;
-	message_callback mRecvCallback;
+	synchronized_callback<message_ptr> mRecvCallback;
 };
 };
 
 
 } // namespace rtc
 } // namespace rtc