Browse Source

Fixed DataChannel and SCTP shutdown

Paul-Louis Ageneau 5 years ago
parent
commit
23e1a75248

+ 1 - 1
include/rtc/datachannel.hpp

@@ -36,7 +36,7 @@ namespace rtc {
 class SctpTransport;
 class SctpTransport;
 class PeerConnection;
 class PeerConnection;
 
 
-class DataChannel : public Channel {
+class DataChannel : public std::enable_shared_from_this<DataChannel>, public Channel {
 public:
 public:
 	DataChannel(std::shared_ptr<PeerConnection> pc, unsigned int stream, string label,
 	DataChannel(std::shared_ptr<PeerConnection> pc, unsigned int stream, string label,
 	            string protocol, Reliability reliability);
 	            string protocol, Reliability reliability);

+ 2 - 1
src/datachannel.cpp

@@ -78,9 +78,9 @@ DataChannel::~DataChannel() {
 }
 }
 
 
 void DataChannel::close() {
 void DataChannel::close() {
-	mIsClosed = true;
 	if (mIsOpen.exchange(false) && mSctpTransport)
 	if (mIsOpen.exchange(false) && mSctpTransport)
 		mSctpTransport->reset(mStream);
 		mSctpTransport->reset(mStream);
+	mIsClosed = true;
 	mSctpTransport.reset();
 	mSctpTransport.reset();
 }
 }
 
 
@@ -88,6 +88,7 @@ void DataChannel::remoteClose() {
 	mIsOpen = false;
 	mIsOpen = false;
 	if (!mIsClosed.exchange(true))
 	if (!mIsClosed.exchange(true))
 		triggerClosed();
 		triggerClosed();
+	mSctpTransport.reset();
 }
 }
 
 
 bool DataChannel::send(const std::variant<binary, string> &data) {
 bool DataChannel::send(const std::variant<binary, string> &data) {

+ 29 - 25
src/dtlstransport.cpp

@@ -44,7 +44,7 @@ static bool check_gnutls(int ret, const string &message = "GnuTLS error") {
 			PLOG_INFO << gnutls_strerror(ret);
 			PLOG_INFO << gnutls_strerror(ret);
 			return false;
 			return false;
 		}
 		}
-		PLOG_ERROR << gnutls_strerror(ret);
+		PLOG_ERROR << message << ": " << gnutls_strerror(ret);
 		throw std::runtime_error(message + ": " + gnutls_strerror(ret));
 		throw std::runtime_error(message + ": " + gnutls_strerror(ret));
 	}
 	}
 	return true;
 	return true;
@@ -105,6 +105,7 @@ void DtlsTransport::stop() {
 	Transport::stop();
 	Transport::stop();
 
 
 	if (mRecvThread.joinable()) {
 	if (mRecvThread.joinable()) {
+		PLOG_DEBUG << "Stopping DTLS recv thread";
 		mIncomingQueue.stop();
 		mIncomingQueue.stop();
 		gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
 		gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
 		mRecvThread.join();
 		mRecvThread.join();
@@ -115,6 +116,8 @@ bool DtlsTransport::send(message_ptr message) {
 	if (!message || mState != State::Connected)
 	if (!message || mState != State::Connected)
 		return false;
 		return false;
 
 
+	PLOG_VERBOSE << "Send size=" << message->size();
+
 	ssize_t ret;
 	ssize_t ret;
 	do {
 	do {
 		ret = gnutls_record_send(mSession, message->data(), message->size());
 		ret = gnutls_record_send(mSession, message->data(), message->size());
@@ -179,12 +182,15 @@ void DtlsTransport::runRecvLoop() {
 			} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
 			} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
 
 
 			// Consider premature termination as remote closing
 			// Consider premature termination as remote closing
-			if (ret == GNUTLS_E_PREMATURE_TERMINATION)
+			if (ret == GNUTLS_E_PREMATURE_TERMINATION) {
+				PLOG_DEBUG << "DTLS connection terminated";
 				break;
 				break;
+			}
 
 
 			if (check_gnutls(ret)) {
 			if (check_gnutls(ret)) {
 				if (ret == 0) {
 				if (ret == 0) {
 					// Closed
 					// Closed
+					PLOG_DEBUG << "DTLS connection cleanly closed";
 					break;
 					break;
 				}
 				}
 				auto *b = reinterpret_cast<byte *>(buffer);
 				auto *b = reinterpret_cast<byte *>(buffer);
@@ -287,7 +293,7 @@ bool check_openssl(int success, const string &message = "OpenSSL error") {
 		return true;
 		return true;
 
 
 	string str = openssl_error_string(ERR_get_error());
 	string str = openssl_error_string(ERR_get_error());
-	PLOG_ERROR << str;
+	PLOG_ERROR << message << ": " << str;
 	throw std::runtime_error(message + ": " + str);
 	throw std::runtime_error(message + ": " + str);
 }
 }
 
 
@@ -300,7 +306,7 @@ bool check_openssl_ret(SSL *ssl, int ret, const string &message = "OpenSSL error
 		return true;
 		return true;
 	}
 	}
 	if (err == SSL_ERROR_ZERO_RETURN) {
 	if (err == SSL_ERROR_ZERO_RETURN) {
-		PLOG_DEBUG << "The TLS connection has been cleanly closed";
+		PLOG_DEBUG << "DTLS connection cleanly closed";
 		return false;
 		return false;
 	}
 	}
 	string str = openssl_error_string(err);
 	string str = openssl_error_string(err);
@@ -384,7 +390,6 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 DtlsTransport::~DtlsTransport() {
 DtlsTransport::~DtlsTransport() {
 	stop();
 	stop();
 
 
-	SSL_shutdown(mSsl);
 	SSL_free(mSsl);
 	SSL_free(mSsl);
 	SSL_CTX_free(mCtx);
 	SSL_CTX_free(mCtx);
 }
 }
@@ -393,31 +398,29 @@ void DtlsTransport::stop() {
 	Transport::stop();
 	Transport::stop();
 
 
 	if (mRecvThread.joinable()) {
 	if (mRecvThread.joinable()) {
+		PLOG_DEBUG << "Stopping DTLS recv thread";
 		mIncomingQueue.stop();
 		mIncomingQueue.stop();
 		mRecvThread.join();
 		mRecvThread.join();
+
+		SSL_shutdown(mSsl);
+		writePending();
 	}
 	}
 }
 }
 
 
 DtlsTransport::State DtlsTransport::state() const { return mState; }
 DtlsTransport::State DtlsTransport::state() const { return mState; }
 
 
 bool DtlsTransport::send(message_ptr message) {
 bool DtlsTransport::send(message_ptr message) {
-	const size_t bufferSize = 4096;
-	byte buffer[bufferSize];
-
 	if (!message || mState != State::Connected)
 	if (!message || mState != State::Connected)
 		return false;
 		return false;
 
 
+	PLOG_VERBOSE << "Send size=" << message->size();
+
 	int ret = SSL_write(mSsl, message->data(), message->size());
 	int ret = SSL_write(mSsl, message->data(), message->size());
 	if (!check_openssl_ret(mSsl, ret)) {
 	if (!check_openssl_ret(mSsl, ret)) {
 		return false;
 		return false;
 	}
 	}
 
 
-	while (BIO_ctrl_pending(mOutBio) > 0) {
-		int ret = BIO_read(mOutBio, buffer, bufferSize);
-		if (check_openssl_ret(mSsl, ret) && ret > 0)
-			outgoing(make_message(buffer, buffer + ret));
-	}
-
+	writePending();
 	return true;
 	return true;
 }
 }
 
 
@@ -441,11 +444,7 @@ void DtlsTransport::runRecvLoop() {
 		changeState(State::Connecting);
 		changeState(State::Connecting);
 
 
 		SSL_do_handshake(mSsl);
 		SSL_do_handshake(mSsl);
-		while (BIO_ctrl_pending(mOutBio) > 0) {
-			int ret = BIO_read(mOutBio, buffer, bufferSize);
-			if (check_openssl_ret(mSsl, ret) && ret > 0)
-				outgoing(make_message(buffer, buffer + ret));
-		}
+		writePending();
 
 
 		while (auto next = mIncomingQueue.pop()) {
 		while (auto next = mIncomingQueue.pop()) {
 			auto message = *next;
 			auto message = *next;
@@ -460,12 +459,7 @@ void DtlsTransport::runRecvLoop() {
 				if (unsigned long err = ERR_get_error())
 				if (unsigned long err = ERR_get_error())
 					throw std::runtime_error("handshake failed: " + openssl_error_string(err));
 					throw std::runtime_error("handshake failed: " + openssl_error_string(err));
 
 
-				while (BIO_ctrl_pending(mOutBio) > 0) {
-					ret = BIO_read(mOutBio, buffer, bufferSize);
-					if (check_openssl_ret(mSsl, ret) && ret > 0)
-						outgoing(make_message(buffer, buffer + ret));
-				}
-
+				writePending();
 				if (SSL_is_init_finished(mSsl))
 				if (SSL_is_init_finished(mSsl))
 					changeState(State::Connected);
 					changeState(State::Connected);
 			}
 			}
@@ -487,6 +481,16 @@ void DtlsTransport::runRecvLoop() {
 	}
 	}
 }
 }
 
 
+void DtlsTransport::writePending() {
+	const size_t bufferSize = 4096;
+	byte buffer[bufferSize];
+	while (BIO_ctrl_pending(mOutBio) > 0) {
+		int ret = BIO_read(mOutBio, buffer, bufferSize);
+		if (check_openssl_ret(mSsl, ret) && ret > 0)
+			outgoing(make_message(buffer, buffer + ret));
+	}
+}
+
 int DtlsTransport::CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx) {
 int DtlsTransport::CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx) {
 	SSL *ssl =
 	SSL *ssl =
 	    static_cast<SSL *>(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
 	    static_cast<SSL *>(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));

+ 2 - 0
src/dtlstransport.hpp

@@ -79,6 +79,8 @@ private:
 	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
 	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
 	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
 	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
 #else
 #else
+	void writePending();
+
 	SSL_CTX *mCtx;
 	SSL_CTX *mCtx;
 	SSL *mSsl;
 	SSL *mSsl;
 	BIO *mInBio, *mOutBio;
 	BIO *mInBio, *mOutBio;

+ 11 - 8
src/icetransport.cpp

@@ -185,6 +185,7 @@ void IceTransport::stop() {
 		mTimeoutId = 0;
 		mTimeoutId = 0;
 	}
 	}
 	if (mMainLoopThread.joinable()) {
 	if (mMainLoopThread.joinable()) {
+		PLOG_DEBUG << "Stopping ICE thread";
 		g_main_loop_quit(mMainLoop.get());
 		g_main_loop_quit(mMainLoop.get());
 		mMainLoopThread.join();
 		mMainLoopThread.join();
 	}
 	}
@@ -217,7 +218,6 @@ void IceTransport::setRemoteDescription(const Description &description) {
 
 
 bool IceTransport::addRemoteCandidate(const Candidate &candidate) {
 bool IceTransport::addRemoteCandidate(const Candidate &candidate) {
 	// Don't try to pass unresolved candidates to libnice for more safety
 	// Don't try to pass unresolved candidates to libnice for more safety
-
 	if (!candidate.isResolved())
 	if (!candidate.isResolved())
 		return false;
 		return false;
 
 
@@ -263,11 +263,11 @@ std::optional<string> IceTransport::getRemoteAddress() const {
 }
 }
 
 
 bool IceTransport::send(message_ptr message) {
 bool IceTransport::send(message_ptr message) {
-	if (!message || !mStreamId)
+	if (!message || (mState != State::Connected && mState != State::Completed))
 		return false;
 		return false;
 
 
-	outgoing(message);
-	return true;
+	PLOG_VERBOSE << "Send size=" << message->size();
+	return outgoing(message);
 }
 }
 
 
 void IceTransport::incoming(message_ptr message) { recv(message); }
 void IceTransport::incoming(message_ptr message) { recv(message); }
@@ -276,9 +276,9 @@ void IceTransport::incoming(const byte *data, int size) {
 	incoming(make_message(data, data + size));
 	incoming(make_message(data, data + size));
 }
 }
 
 
-void IceTransport::outgoing(message_ptr message) {
-	nice_agent_send(mNiceAgent.get(), mStreamId, 1, message->size(),
-	                reinterpret_cast<const char *>(message->data()));
+bool IceTransport::outgoing(message_ptr message) {
+	return nice_agent_send(mNiceAgent.get(), mStreamId, 1, message->size(),
+	                       reinterpret_cast<const char *>(message->data())) >= 0;
 }
 }
 
 
 void IceTransport::changeState(State state) {
 void IceTransport::changeState(State state) {
@@ -286,7 +286,10 @@ void IceTransport::changeState(State state) {
 		mStateChangeCallback(mState);
 		mStateChangeCallback(mState);
 }
 }
 
 
-void IceTransport::processTimeout() { changeState(State::Failed); }
+void IceTransport::processTimeout() {
+	PLOG_WARNING << "ICE timeout";
+	changeState(State::Failed);
+}
 
 
 void IceTransport::changeGatheringState(GatheringState state) {
 void IceTransport::changeGatheringState(GatheringState state) {
 	mGatheringState = state;
 	mGatheringState = state;

+ 1 - 1
src/icetransport.hpp

@@ -74,7 +74,7 @@ public:
 private:
 private:
 	void incoming(message_ptr message) override;
 	void incoming(message_ptr message) override;
 	void incoming(const byte *data, int size);
 	void incoming(const byte *data, int size);
-	void outgoing(message_ptr message) override;
+	bool outgoing(message_ptr message) override;
 
 
 	void changeState(State state);
 	void changeState(State state);
 	void changeGatheringState(GatheringState state);
 	void changeGatheringState(GatheringState state);

+ 8 - 9
src/peerconnection.cpp

@@ -50,16 +50,15 @@ void PeerConnection::close() {
 	closeDataChannels();
 	closeDataChannels();
 	mDataChannels.clear();
 	mDataChannels.clear();
 
 
-	changeState(State::Disconnected);
-
 	// Close Transports
 	// Close Transports
-	if (auto transport = std::atomic_load(&mIceTransport))
-		transport->stop();
-	if (auto transport = std::atomic_load(&mDtlsTransport))
-		transport->stop();
-	if (auto transport = std::atomic_load(&mSctpTransport))
-		transport->stop();
-
+	for (int i = 0; i < 2; ++i) { // Make sure a transport wasn't spawn behind our back
+		if (auto transport = std::atomic_load(&mSctpTransport))
+			transport->stop();
+		if (auto transport = std::atomic_load(&mDtlsTransport))
+			transport->stop();
+		if (auto transport = std::atomic_load(&mIceTransport))
+			transport->stop();
+	}
 	changeState(State::Closed);
 	changeState(State::Closed);
 }
 }
 
 

+ 65 - 36
src/sctptransport.cpp

@@ -25,6 +25,9 @@
 
 
 #include <arpa/inet.h>
 #include <arpa/inet.h>
 
 
+using namespace std::chrono_literals;
+using namespace std::chrono;
+
 using std::shared_ptr;
 using std::shared_ptr;
 
 
 namespace rtc {
 namespace rtc {
@@ -167,18 +170,14 @@ void SctpTransport::stop() {
 	onRecv(nullptr);
 	onRecv(nullptr);
 
 
 	if (!mShutdown.exchange(true)) {
 	if (!mShutdown.exchange(true)) {
-		flush();
 		mSendQueue.stop();
 		mSendQueue.stop();
-		usrsctp_shutdown(mSock, SHUT_RDWR);
-
-		// Unblock incoming
-		std::unique_lock<std::mutex> lock(mConnectMutex);
-		mConnectDataSent = true;
-		mConnectCondition.notify_all();
+		flush();
+		shutdown();
 	}
 	}
 }
 }
 
 
 void SctpTransport::connect() {
 void SctpTransport::connect() {
+	PLOG_DEBUG << "SCTP connect";
 	changeState(State::Connecting);
 	changeState(State::Connecting);
 
 
 	struct sockaddr_conn sconn = {};
 	struct sockaddr_conn sconn = {};
@@ -200,12 +199,24 @@ void SctpTransport::connect() {
 		throw std::runtime_error("Connection attempt failed, errno=" + std::to_string(errno));
 		throw std::runtime_error("Connection attempt failed, errno=" + std::to_string(errno));
 }
 }
 
 
+void SctpTransport::shutdown() {
+	PLOG_DEBUG << "SCTP shutdown";
+
+	if (usrsctp_shutdown(mSock, SHUT_WR))
+		PLOG_WARNING << "SCTP shutdown failed, errno=" << errno;
+
+	PLOG_INFO << "SCTP disconnected";
+	changeState(State::Disconnected);
+	mWrittenCondition.notify_all();
+}
+
 bool SctpTransport::send(message_ptr message) {
 bool SctpTransport::send(message_ptr message) {
 	std::lock_guard lock(mSendMutex);
 	std::lock_guard lock(mSendMutex);
-
 	if (!message)
 	if (!message)
 		return mSendQueue.empty();
 		return mSendQueue.empty();
 
 
+	PLOG_VERBOSE << "Send size=" << message->size();
+
 	// If nothing is pending, try to send directly
 	// If nothing is pending, try to send directly
 	if (mSendQueue.empty() && trySendMessage(message))
 	if (mSendQueue.empty() && trySendMessage(message))
 		return true;
 		return true;
@@ -221,6 +232,10 @@ void SctpTransport::flush() {
 }
 }
 
 
 void SctpTransport::reset(unsigned int stream) {
 void SctpTransport::reset(unsigned int stream) {
+	PLOG_DEBUG << "SCTP resetting stream " << stream;
+
+	std::unique_lock lock(mWriteMutex);
+	mWritten = false;
 	using srs_t = struct sctp_reset_streams;
 	using srs_t = struct sctp_reset_streams;
 	const size_t len = sizeof(srs_t) + sizeof(uint16_t);
 	const size_t len = sizeof(srs_t) + sizeof(uint16_t);
 	byte buffer[len] = {};
 	byte buffer[len] = {};
@@ -228,25 +243,33 @@ void SctpTransport::reset(unsigned int stream) {
 	srs.srs_flags = SCTP_STREAM_RESET_OUTGOING;
 	srs.srs_flags = SCTP_STREAM_RESET_OUTGOING;
 	srs.srs_number_streams = 1;
 	srs.srs_number_streams = 1;
 	srs.srs_stream_list[0] = uint16_t(stream);
 	srs.srs_stream_list[0] = uint16_t(stream);
-	usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len);
+	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) {
+		mWrittenCondition.wait_for(lock, 1000ms,
+		                           [&]() { return mWritten || mState != State::Connected; });
+	} else {
+		PLOG_WARNING << "SCTP reset stream " << stream << " failed, errno=" << errno;
+	}
 }
 }
 
 
 void SctpTransport::incoming(message_ptr message) {
 void SctpTransport::incoming(message_ptr message) {
-	if (!message) {
-		changeState(State::Disconnected);
-		recv(nullptr);
-		return;
-	}
-
 	// There could be a race condition here where we receive the remote INIT before the local one is
 	// There could be a race condition here where we receive the remote INIT before the local one is
 	// sent, which would result in the connection being aborted. Therefore, we need to wait for data
 	// sent, which would result in the connection being aborted. Therefore, we need to wait for data
 	// to be sent on our side (i.e. the local INIT) before proceeding.
 	// to be sent on our side (i.e. the local INIT) before proceeding.
 	{
 	{
-		std::unique_lock lock(mConnectMutex);
-		mConnectCondition.wait(lock, [this]() -> bool { return mConnectDataSent; });
+		std::unique_lock lock(mWriteMutex);
+		mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || mState != State::Connected; });
 	}
 	}
 
 
-	usrsctp_conninput(this, message->data(), message->size(), 0);
+	if (message) {
+		usrsctp_conninput(this, message->data(), message->size(), 0);
+	} else {
+		if (usrsctp_shutdown(mSock, SHUT_RD))
+			PLOG_WARNING << "SCTP shutdown reading failed, errno=" << errno;
+
+		PLOG_INFO << "SCTP disconnected";
+		changeState(State::Disconnected);
+		recv(nullptr);
+	}
 }
 }
 
 
 void SctpTransport::changeState(State state) {
 void SctpTransport::changeState(State state) {
@@ -268,7 +291,11 @@ bool SctpTransport::trySendQueue() {
 
 
 bool SctpTransport::trySendMessage(message_ptr message) {
 bool SctpTransport::trySendMessage(message_ptr message) {
 	// Requires mSendMutex to be locked
 	// Requires mSendMutex to be locked
-	//
+	if (mState != State::Connected)
+		return false;
+
+	PLOG_VERBOSE << "SCTP try send size=" << message->size();
+
 	// TODO: Implement SCTP ndata specification draft when supported everywhere
 	// TODO: Implement SCTP ndata specification draft when supported everywhere
 	// See https://tools.ietf.org/html/draft-ietf-tsvwg-sctp-ndata-08
 	// See https://tools.ietf.org/html/draft-ietf-tsvwg-sctp-ndata-08
 
 
@@ -300,7 +327,6 @@ bool SctpTransport::trySendMessage(message_ptr message) {
 	if (reliability.unordered)
 	if (reliability.unordered)
 		spa.sendv_sndinfo.snd_flags |= SCTP_UNORDERED;
 		spa.sendv_sndinfo.snd_flags |= SCTP_UNORDERED;
 
 
-	using std::chrono::milliseconds;
 	switch (reliability.type) {
 	switch (reliability.type) {
 	case Reliability::TYPE_PARTIAL_RELIABLE_REXMIT:
 	case Reliability::TYPE_PARTIAL_RELIABLE_REXMIT:
 		spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
 		spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
@@ -326,12 +352,16 @@ bool SctpTransport::trySendMessage(message_ptr message) {
 		ret = usrsctp_sendv(mSock, &zero, 1, nullptr, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0);
 		ret = usrsctp_sendv(mSock, &zero, 1, nullptr, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0);
 	}
 	}
 
 
-	if (ret >= 0)
+	if (ret >= 0) {
+		PLOG_VERBOSE << "SCTP sent size=" << message->size();
 		return true;
 		return true;
-	else if (errno == EWOULDBLOCK && errno == EAGAIN)
+	} else if (errno == EWOULDBLOCK && errno == EAGAIN) {
+		PLOG_VERBOSE << "SCTP sending not possible ";
 		return false;
 		return false;
-	else
+	} else {
+		PLOG_ERROR << "SCTP sending failed, errno=" << errno;
 		throw std::runtime_error("Sending failed, errno=" + std::to_string(errno));
 		throw std::runtime_error("Sending failed, errno=" + std::to_string(errno));
+	}
 }
 }
 
 
 void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) {
 void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) {
@@ -347,11 +377,8 @@ void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) {
 int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data,
 int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data,
                               size_t len, struct sctp_rcvinfo info, int flags) {
                               size_t len, struct sctp_rcvinfo info, int flags) {
 	try {
 	try {
-		if (!data) {
-			PLOG_INFO << "SCTP connection closed";
-			recv(nullptr);
-			return 0;
-		}
+		if (!len)
+			return -1;
 		if (flags & MSG_EOR) {
 		if (flags & MSG_EOR) {
 			if (!mPartialRecv.empty()) {
 			if (!mPartialRecv.empty()) {
 				mPartialRecv.insert(mPartialRecv.end(), data, data + len);
 				mPartialRecv.insert(mPartialRecv.end(), data, data + len);
@@ -389,11 +416,12 @@ int SctpTransport::handleSend(size_t free) {
 
 
 int SctpTransport::handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df) {
 int SctpTransport::handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df) {
 	try {
 	try {
-		outgoing(make_message(data, data + len));
-
-		std::unique_lock lock(mConnectMutex);
-		mConnectDataSent = true;
-		mConnectCondition.notify_all();
+		std::unique_lock lock(mWriteMutex);
+		if (!outgoing(make_message(data, data + len)))
+			return -1;
+		mWritten = true;
+		mWrittenOnce = true;
+		mWrittenCondition.notify_all();
 	} catch (const std::exception &e) {
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "SCTP write: " << e.what();
 		PLOG_ERROR << "SCTP write: " << e.what();
 		return -1;
 		return -1;
@@ -479,6 +507,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 				PLOG_INFO << "SCTP disconnected";
 				PLOG_INFO << "SCTP disconnected";
 				changeState(State::Disconnected);
 				changeState(State::Disconnected);
 			}
 			}
+			mWrittenCondition.notify_all();
 		}
 		}
 	}
 	}
 	case SCTP_SENDER_DRY_EVENT: {
 	case SCTP_SENDER_DRY_EVENT: {
@@ -490,15 +519,15 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 	case SCTP_STREAM_RESET_EVENT: {
 	case SCTP_STREAM_RESET_EVENT: {
 		const struct sctp_stream_reset_event &reset_event = notify->sn_strreset_event;
 		const struct sctp_stream_reset_event &reset_event = notify->sn_strreset_event;
 		const int count = (reset_event.strreset_length - sizeof(reset_event)) / sizeof(uint16_t);
 		const int count = (reset_event.strreset_length - sizeof(reset_event)) / sizeof(uint16_t);
+		const uint16_t flags = reset_event.strreset_flags;
 
 
-		if (reset_event.strreset_flags & SCTP_STREAM_RESET_INCOMING_SSN) {
+		if (flags & SCTP_STREAM_RESET_OUTGOING_SSN) {
 			for (int i = 0; i < count; ++i) {
 			for (int i = 0; i < count; ++i) {
 				uint16_t streamId = reset_event.strreset_stream_list[i];
 				uint16_t streamId = reset_event.strreset_stream_list[i];
 				reset(streamId);
 				reset(streamId);
 			}
 			}
 		}
 		}
-
-		if (reset_event.strreset_flags & SCTP_STREAM_RESET_OUTGOING_SSN) {
+		if (flags & SCTP_STREAM_RESET_INCOMING_SSN) {
 			const byte dataChannelCloseMessage{0x04};
 			const byte dataChannelCloseMessage{0x04};
 			for (int i = 0; i < count; ++i) {
 			for (int i = 0; i < count; ++i) {
 				uint16_t streamId = reset_event.strreset_stream_list[i];
 				uint16_t streamId = reset_event.strreset_stream_list[i];

+ 5 - 4
src/sctptransport.hpp

@@ -28,7 +28,6 @@
 #include <functional>
 #include <functional>
 #include <map>
 #include <map>
 #include <mutex>
 #include <mutex>
-#include <thread>
 
 
 #include <sys/socket.h>
 #include <sys/socket.h>
 #include <sys/types.h>
 #include <sys/types.h>
@@ -69,6 +68,7 @@ private:
 	};
 	};
 
 
 	void connect();
 	void connect();
+	void shutdown();
 	void incoming(message_ptr message) override;
 	void incoming(message_ptr message) override;
 	void changeState(State state);
 	void changeState(State state);
 
 
@@ -92,9 +92,10 @@ private:
 	std::map<uint16_t, size_t> mBufferedAmount;
 	std::map<uint16_t, size_t> mBufferedAmount;
 	amount_callback mBufferedAmountCallback;
 	amount_callback mBufferedAmountCallback;
 
 
-	std::mutex mConnectMutex;
-	std::condition_variable mConnectCondition;
-	bool mConnectDataSent = false;
+	std::recursive_mutex mWriteMutex;
+	std::condition_variable_any mWrittenCondition;
+	bool mWritten = false;
+	bool mWrittenOnce = false;
 
 
 	std::atomic<bool> mShutdown = false;
 	std::atomic<bool> mShutdown = false;
 
 

+ 13 - 12
src/transport.hpp

@@ -33,12 +33,16 @@ using namespace std::placeholders;
 class Transport {
 class Transport {
 public:
 public:
 	Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {
 	Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {
-		if (auto lower = std::atomic_load(&mLower))
-			lower->onRecv(std::bind(&Transport::incoming, this, _1));
+		if (mLower)
+			mLower->onRecv(std::bind(&Transport::incoming, this, _1));
+	}
+	virtual ~Transport() { stop(); }
+
+	virtual void stop() {
+		if (mLower)
+			mLower->onRecv(nullptr);
 	}
 	}
-	virtual ~Transport() {}
 
 
-	virtual void stop() { resetLower(); }
 	virtual bool send(message_ptr message) = 0;
 	virtual bool send(message_ptr message) = 0;
 
 
 	void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
 	void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
@@ -46,15 +50,12 @@ public:
 protected:
 protected:
 	void recv(message_ptr message) { mRecvCallback(message); }
 	void recv(message_ptr message) { mRecvCallback(message); }
 
 
-	void resetLower() {
-		if (auto lower = std::atomic_exchange(&mLower, std::shared_ptr<Transport>(nullptr)))
-			lower->onRecv(nullptr);
-	}
-
 	virtual void incoming(message_ptr message) = 0;
 	virtual void incoming(message_ptr message) = 0;
-	virtual void outgoing(message_ptr message) {
-		if (auto lower = std::atomic_load(&mLower))
-			lower->send(message);
+	virtual bool outgoing(message_ptr message) {
+		if (mLower)
+			return mLower->send(message);
+		else
+			return false;
 	}
 	}
 
 
 private:
 private:

+ 1 - 1
test/main.cpp

@@ -29,7 +29,7 @@ using namespace std;
 template <class T> weak_ptr<T> make_weak_ptr(shared_ptr<T> ptr) { return ptr; }
 template <class T> weak_ptr<T> make_weak_ptr(shared_ptr<T> ptr) { return ptr; }
 
 
 int main(int argc, char **argv) {
 int main(int argc, char **argv) {
-	InitLogger(LogLevel::Debug);
+	// InitLogger(LogLevel::Debug);
 	Configuration config;
 	Configuration config;
 
 
 	// config.iceServers.emplace_back("stun.l.google.com:19302");
 	// config.iceServers.emplace_back("stun.l.google.com:19302");

+ 6 - 2
test/p2p/answerer.cpp

@@ -28,7 +28,8 @@ using namespace std;
 template <class T> weak_ptr<T> make_weak_ptr(shared_ptr<T> ptr) { return ptr; }
 template <class T> weak_ptr<T> make_weak_ptr(shared_ptr<T> ptr) { return ptr; }
 
 
 int main(int argc, char **argv) {
 int main(int argc, char **argv) {
-	rtc::Configuration config;
+	// InitLogger(LogLevel::Debug);
+	Configuration config;
 	// config.iceServers.emplace_back("stun.l.google.com:19302");
 	// config.iceServers.emplace_back("stun.l.google.com:19302");
 	// config.enableIceTcp = true;
 	// config.enableIceTcp = true;
 
 
@@ -56,9 +57,12 @@ int main(int argc, char **argv) {
 	});
 	});
 
 
 	shared_ptr<DataChannel> dc = nullptr;
 	shared_ptr<DataChannel> dc = nullptr;
-	pc->onDataChannel([&dc](shared_ptr<DataChannel> _dc) {
+	pc->onDataChannel([&](shared_ptr<DataChannel> _dc) {
 		cout << "[ Got a DataChannel with label: " << _dc->label() << " ]" << endl;
 		cout << "[ Got a DataChannel with label: " << _dc->label() << " ]" << endl;
 		dc = _dc;
 		dc = _dc;
+
+		dc->onClosed([&]() { cout << "[ DataChannel closed: " << dc->label() << " ]" << endl; });
+
 		dc->onMessage([](const variant<binary, string> &message) {
 		dc->onMessage([](const variant<binary, string> &message) {
 			if (holds_alternative<string>(message)) {
 			if (holds_alternative<string>(message)) {
 				cout << "[ Received: " << get<string>(message) << " ]" << endl;
 				cout << "[ Received: " << get<string>(message) << " ]" << endl;

+ 9 - 6
test/p2p/offerer.cpp

@@ -28,7 +28,8 @@ using namespace std;
 template <class T> weak_ptr<T> make_weak_ptr(shared_ptr<T> ptr) { return ptr; }
 template <class T> weak_ptr<T> make_weak_ptr(shared_ptr<T> ptr) { return ptr; }
 
 
 int main(int argc, char **argv) {
 int main(int argc, char **argv) {
-	rtc::Configuration config;
+	// InitLogger(LogLevel::Debug);
+	Configuration config;
 	// config.iceServers.emplace_back("stun.l.google.com:19302");
 	// config.iceServers.emplace_back("stun.l.google.com:19302");
 	// config.enableIceTcp = true;
 	// config.enableIceTcp = true;
 
 
@@ -57,11 +58,9 @@ int main(int argc, char **argv) {
 	});
 	});
 
 
 	auto dc = pc->createDataChannel("test");
 	auto dc = pc->createDataChannel("test");
-	dc->onOpen([&]() {
-		if (!dc)
-			return;
-		cout << "[ DataChannel open: " << dc->label() << " ]" << endl;
-	});
+	dc->onOpen([&]() { cout << "[ DataChannel open: " << dc->label() << " ]" << endl; });
+
+	dc->onClosed([&]() { cout << "[ DataChannel closed: " << dc->label() << " ]" << endl; });
 
 
 	dc->onMessage([](const variant<binary, string> &message) {
 	dc->onMessage([](const variant<binary, string> &message) {
 		if (holds_alternative<string>(message)) {
 		if (holds_alternative<string>(message)) {
@@ -126,6 +125,10 @@ int main(int argc, char **argv) {
 			while (message.length() == 0)
 			while (message.length() == 0)
 				getline(cin, message);
 				getline(cin, message);
 			dc->send(message);
 			dc->send(message);
+			if (dc)
+				dc->close();
+			if (pc)
+				pc->close();
 			break;
 			break;
 
 
 		default:
 		default: