Browse Source

Properly handle unexpected DTLS termination

Paul-Louis Ageneau 6 years ago
parent
commit
68b0ab73b5
5 changed files with 77 additions and 38 deletions
  1. 3 1
      include/rtc/peerconnection.hpp
  2. 25 15
      src/dtlstransport.cpp
  3. 1 1
      src/icetransport.cpp
  4. 16 2
      src/peerconnection.cpp
  5. 32 19
      src/sctptransport.cpp

+ 3 - 1
include/rtc/peerconnection.hpp

@@ -66,7 +66,9 @@ private:
 
 	bool checkFingerprint(const std::string &fingerprint) const;
 	void forwardMessage(message_ptr message);
-	void openDataChannels(void);
+	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
+	void openDataChannels();
+	void closeDataChannels();
 
 	void processLocalDescription(Description description);
 	void processLocalCandidate(std::optional<Candidate> candidate);

+ 25 - 15
src/dtlstransport.cpp

@@ -22,6 +22,7 @@
 #include <cassert>
 #include <cstring>
 #include <exception>
+#include <iostream>
 
 #include <gnutls/dtls.h>
 
@@ -82,6 +83,9 @@ DtlsTransport::~DtlsTransport() {
 }
 
 bool DtlsTransport::send(message_ptr message) {
+	if (!message)
+		return false;
+
 	while (true) {
 		ssize_t ret = gnutls_record_send(mSession, message->data(), message->size());
 		if (check_gnutls(ret)) {
@@ -93,24 +97,31 @@ bool DtlsTransport::send(message_ptr message) {
 void DtlsTransport::incoming(message_ptr message) { mIncomingQueue.push(message); }
 
 void DtlsTransport::runRecvLoop() {
-	while (!check_gnutls(gnutls_handshake(mSession), "TLS handshake failed")) {}
+	try {
+		while (!check_gnutls(gnutls_handshake(mSession), "TLS handshake failed")) {
+		}
 
-	mReadyCallback();
+		mReadyCallback();
 
-	const size_t bufferSize = 2048;
-	char buffer[bufferSize];
+		const size_t bufferSize = 2048;
+		char buffer[bufferSize];
 
-	while (true) {
-		ssize_t ret = gnutls_record_recv(mSession, buffer, bufferSize);
-		if (check_gnutls(ret)) {
-			if (ret == 0) {
-				// Closed
-				break;
+		while (true) {
+			ssize_t ret = gnutls_record_recv(mSession, buffer, bufferSize);
+			if (check_gnutls(ret)) {
+				if (ret == 0) {
+					// Closed
+					break;
+				}
+				auto *b = reinterpret_cast<byte *>(buffer);
+				recv(make_message(b, b + ret));
 			}
-			auto *b = reinterpret_cast<byte *>(buffer);
-			recv(make_message(b, b + ret));
 		}
+	} catch (const std::exception &e) {
+		std::cerr << "DTLS recv: " << e.what() << std::endl;
 	}
+
+	recv(nullptr);
 }
 
 int DtlsTransport::CertificateCallback(gnutls_session_t session) {
@@ -120,7 +131,6 @@ int DtlsTransport::CertificateCallback(gnutls_session_t session) {
 		return GNUTLS_E_CERTIFICATE_ERROR;
 	}
 
-	// Get peer's certificate
 	unsigned int count = 0;
 	const gnutls_datum_t *array = gnutls_certificate_get_peers(session, &count);
 	if (!array || count == 0) {
@@ -155,13 +165,13 @@ ssize_t DtlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *dat
 ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
 	auto next = t->mIncomingQueue.pop();
-	if (!next) {
+	auto message = next ? *next : nullptr;
+	if (!message) {
 		// Closed
 		gnutls_transport_set_errno(t->mSession, 0);
 		return 0;
 	}
 
-	auto message = *next;
 	ssize_t len = std::min(maxlen, message->size());
 	std::memcpy(data, message->data(), len);
 	gnutls_transport_set_errno(t->mSession, 0);

+ 1 - 1
src/icetransport.cpp

@@ -167,7 +167,7 @@ bool IceTransport::addRemoteCandidate(const Candidate &candidate) {
 }
 
 bool IceTransport::send(message_ptr message) {
-	if (!mStreamId)
+	if (!message || !mStreamId)
 		return false;
 
 	outgoing(message);

+ 16 - 2
src/peerconnection.cpp

@@ -138,6 +138,11 @@ void PeerConnection::forwardMessage(message_ptr message) {
 	if (!mIceTransport || !mSctpTransport)
 		throw std::logic_error("Got a DataChannel message without transport");
 
+	if (!message) {
+		closeDataChannels();
+		return;
+	}
+
 	shared_ptr<DataChannel> channel;
 	if (auto it = mDataChannels.find(message->stream); it != mDataChannels.end()) {
 		channel = it->second.lock();
@@ -165,7 +170,8 @@ void PeerConnection::forwardMessage(message_ptr message) {
 	channel->incoming(message);
 }
 
-void PeerConnection::openDataChannels(void) {
+void PeerConnection::iterateDataChannels(
+    std::function<void(shared_ptr<DataChannel> channel)> func) {
 	auto it = mDataChannels.begin();
 	while (it != mDataChannels.end()) {
 		auto channel = it->second.lock();
@@ -173,11 +179,19 @@ void PeerConnection::openDataChannels(void) {
 			it = mDataChannels.erase(it);
 			continue;
 		}
-		channel->open(mSctpTransport);
+		func(channel);
 		++it;
 	}
 }
 
+void PeerConnection::openDataChannels() {
+	iterateDataChannels([this](shared_ptr<DataChannel> channel) { channel->open(mSctpTransport); });
+}
+
+void PeerConnection::closeDataChannels() {
+	iterateDataChannels([](shared_ptr<DataChannel> channel) { channel->close(); });
+}
+
 void PeerConnection::processLocalDescription(Description description) {
 	auto remoteSctpPort = mRemoteDescription ? mRemoteDescription->sctpPort() : nullopt;
 

+ 32 - 19
src/sctptransport.cpp

@@ -120,6 +120,7 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, re
 
 SctpTransport::~SctpTransport() {
 	mStopping = true;
+	mConnectCondition.notify_all();
 	if (mConnectThread.joinable())
 		mConnectThread.join();
 
@@ -135,6 +136,9 @@ SctpTransport::~SctpTransport() {
 bool SctpTransport::isReady() const { return mIsReady; }
 
 bool SctpTransport::send(message_ptr message) {
+	if (!message)
+		return false;
+
 	const Reliability reliability = message->reliability ? *message->reliability : Reliability();
 
 	struct sctp_sendv_spa spa = {};
@@ -201,6 +205,11 @@ void SctpTransport::reset(unsigned int stream) {
 }
 
 void SctpTransport::incoming(message_ptr message) {
+	if (!message) {
+		recv(nullptr);
+		return;
+	}
+
 	// There could be a race condition here where we receive the remote INIT before the thread in
 	// usrsctp_connect sends the local one, which would result in the connection being aborted.
 	// Therefore, we need to wait for data to be sent on our side (i.e. the local INIT) before
@@ -215,26 +224,31 @@ void SctpTransport::incoming(message_ptr message) {
 }
 
 void SctpTransport::runConnect() {
-	struct sockaddr_conn sconn = {};
-	sconn.sconn_family = AF_CONN;
-	sconn.sconn_port = htons(mPort);
-	sconn.sconn_addr = this;
+	try {
+		struct sockaddr_conn sconn = {};
+		sconn.sconn_family = AF_CONN;
+		sconn.sconn_port = htons(mPort);
+		sconn.sconn_addr = this;
 #ifdef HAVE_SCONN_LEN
-	sconn.sconn_len = sizeof(sconn);
+		sconn.sconn_len = sizeof(sconn);
 #endif
 
-	// According to the IETF draft, both endpoints must initiate the SCTP association, in a
-	// simultaneous-open manner, irrelevent to the SDP setup role.
-	// See https://tools.ietf.org/html/draft-ietf-mmusic-sctp-sdp-26#section-9.3
-	if (usrsctp_connect(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)) != 0) {
-		std::cerr << "SCTP connection failed, errno=" << errno << std::endl;
-		mStopping = true;
-		return;
-	}
+		// According to the IETF draft, both endpoints must initiate the SCTP association, in a
+		// simultaneous-open manner, irrelevent to the SDP setup role.
+		// See https://tools.ietf.org/html/draft-ietf-mmusic-sctp-sdp-26#section-9.3
+		if (usrsctp_connect(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)) !=
+		    0) {
+			std::cerr << "SCTP connection failed, errno=" << errno << std::endl;
+			mStopping = true;
+			return;
+		}
 
-	if (!mStopping) {
-		mIsReady = true;
-		mReadyCallback();
+		if (!mStopping) {
+			mIsReady = true;
+			mReadyCallback();
+		}
+	} catch (const std::exception &e) {
+		std::cerr << "SCTP connect: " << e.what() << std::endl;
 	}
 }
 
@@ -251,12 +265,11 @@ int SctpTransport::handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_
 }
 
 int SctpTransport::process(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
-                           struct sctp_rcvinfo recv_info, int flags) {
+                           struct sctp_rcvinfo info, int flags) {
 	if (flags & MSG_NOTIFICATION) {
 		processNotification((union sctp_notification *)data, len);
 	} else {
-		processData((const byte *)data, len, recv_info.rcv_sid,
-		            PayloadId(htonl(recv_info.rcv_ppid)));
+		processData((const byte *)data, len, info.rcv_sid, PayloadId(htonl(info.rcv_ppid)));
 	}
 	free(data);
 	return 0;