Преглед изворни кода

Enhanced handling of usrsctp shutdown

Paul-Louis Ageneau пре 5 година
родитељ
комит
3a737e940c
5 измењених фајлова са 92 додато и 48 уклоњено
  1. 12 3
      include/rtc/message.hpp
  2. 8 4
      src/datachannel.cpp
  3. 2 2
      src/peerconnection.cpp
  4. 68 38
      src/sctptransport.cpp
  5. 2 1
      src/sctptransport.hpp

+ 12 - 3
include/rtc/message.hpp

@@ -28,9 +28,9 @@
 namespace rtc {
 
 struct Message : binary {
-	enum Type { Binary, String, Control };
+	enum Type { Binary, String, Control, Reset };
 
-	Message(size_t size) : binary(size), type(Binary) {}
+	Message(size_t size, Type type_ = Binary) : binary(size), type(type_) {}
 
 	template <typename Iterator>
 	Message(Iterator begin_, Iterator end_, Type type_ = Binary)
@@ -46,7 +46,7 @@ using mutable_message_ptr = std::shared_ptr<Message>;
 using message_callback = std::function<void(message_ptr message)>;
 
 constexpr auto message_size_func = [](const message_ptr &m) -> size_t {
-	return m->type != Message::Control ? m->size() : 0;
+	return m->type == Message::Binary || m->type == Message::String ? m->size() : 0;
 };
 
 template <typename Iterator>
@@ -59,6 +59,15 @@ message_ptr make_message(Iterator begin, Iterator end, Message::Type type = Mess
 	return message;
 }
 
+inline message_ptr make_message(size_t size, Message::Type type = Message::Binary,
+                                unsigned int stream = 0,
+                                std::shared_ptr<Reliability> reliability = nullptr) {
+	auto message = std::make_shared<Message>(size, type);
+	message->stream = stream;
+	message->reliability = reliability;
+	return message;
+}
+
 } // namespace rtc
 
 #endif

+ 8 - 4
src/datachannel.cpp

@@ -93,19 +93,20 @@ string DataChannel::protocol() const { return mProtocol; }
 Reliability DataChannel::reliability() const { return *mReliability; }
 
 void DataChannel::close() {
+	mIsClosed = true;
 	if (mIsOpen.exchange(false))
 		if (auto transport = mSctpTransport.lock())
-			transport->reset(mStream);
-	mIsClosed = true;
-	mSctpTransport.reset();
+			transport->close(mStream);
 
+	mSctpTransport.reset();
 	resetCallbacks();
 }
 
 void DataChannel::remoteClose() {
-	mIsOpen = false;
 	if (!mIsClosed.exchange(true))
 		triggerClosed();
+
+	mIsOpen = false;
 	mSctpTransport.reset();
 }
 
@@ -139,6 +140,9 @@ std::optional<std::variant<binary, string>> DataChannel::receive() {
 			    string(reinterpret_cast<const char *>(message->data()), message->size()));
 		case Message::Binary:
 			return std::make_optional(std::move(*message));
+		default:
+			// Ignore
+			break;
 		}
 	}
 

+ 2 - 2
src/peerconnection.cpp

@@ -422,8 +422,8 @@ void PeerConnection::forwardMessage(message_ptr message) {
 			                          weak_ptr<DataChannel>{channel}));
 			mDataChannels.insert(std::make_pair(message->stream, channel));
 		} else {
-			// Invalid, close the DataChannel by resetting the stream
-			sctpTransport->reset(message->stream);
+			// Invalid, close the DataChannel
+			sctpTransport->close(message->stream);
 			return;
 		}
 	}

+ 68 - 38
src/sctptransport.cpp

@@ -170,7 +170,9 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
 SctpTransport::~SctpTransport() {
 	stop();
 
-	usrsctp_close(mSock);
+	if (mSock)
+		usrsctp_close(mSock);
+
 	usrsctp_deregister_address(this);
 }
 
@@ -188,6 +190,9 @@ bool SctpTransport::stop() {
 }
 
 void SctpTransport::connect() {
+	if (!mSock)
+		return;
+
 	PLOG_DEBUG << "SCTP connect";
 	changeState(State::Connecting);
 
@@ -211,12 +216,19 @@ void SctpTransport::connect() {
 }
 
 void SctpTransport::shutdown() {
+	if (!mSock)
+		return;
+
 	PLOG_DEBUG << "SCTP shutdown";
 
-	if (usrsctp_shutdown(mSock, SHUT_RDWR)) {
+	if (usrsctp_shutdown(mSock, SHUT_RDWR) != 0 && errno != ENOTCONN) {
 		PLOG_WARNING << "SCTP shutdown failed, errno=" << errno;
 	}
 
+	// close() abort the connection when linger is disabled, call it now
+	usrsctp_close(mSock);
+	mSock = nullptr;
+
 	PLOG_INFO << "SCTP disconnected";
 	changeState(State::Disconnected);
 	mWrittenCondition.notify_all();
@@ -238,32 +250,15 @@ bool SctpTransport::send(message_ptr message) {
 	return false;
 }
 
+void SctpTransport::close(unsigned int stream) {
+	send(make_message(0, Message::Reset, uint16_t(stream)));
+}
+
 void SctpTransport::flush() {
 	std::lock_guard lock(mSendMutex);
 	trySendQueue();
 }
 
-void SctpTransport::reset(unsigned int stream) {
-	PLOG_DEBUG << "SCTP resetting stream " << stream;
-
-	using srs_t = struct sctp_reset_streams;
-	const size_t len = sizeof(srs_t) + sizeof(uint16_t);
-	byte buffer[len] = {};
-	srs_t &srs = *reinterpret_cast<srs_t *>(buffer);
-	srs.srs_flags = SCTP_STREAM_RESET_OUTGOING;
-	srs.srs_number_streams = 1;
-	srs.srs_stream_list[0] = uint16_t(stream);
-
-	mWritten = false;
-	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) {
-		std::unique_lock lock(mWriteMutex); // locking before setsockopt might deadlock usrsctp...
-		mWrittenCondition.wait_for(lock, 1000ms,
-		                           [&]() { return mWritten || mState != State::Connected; });
-	} else {
-		PLOG_WARNING << "SCTP reset stream " << stream << " failed, errno=" << errno;
-	}
-}
-
 void SctpTransport::incoming(message_ptr message) {
 	// There could be a race condition here where we receive the remote INIT before the local one is
 	// sent, which would result in the connection being aborted. Therefore, we need to wait for data
@@ -303,16 +298,9 @@ bool SctpTransport::trySendQueue() {
 
 bool SctpTransport::trySendMessage(message_ptr message) {
 	// Requires mSendMutex to be locked
-	if (mState != State::Connected)
+	if (!mSock || mState != State::Connected)
 		return false;
 
-	PLOG_VERBOSE << "SCTP try send size=" << message->size();
-
-	// TODO: Implement SCTP ndata specification draft when supported everywhere
-	// See https://tools.ietf.org/html/draft-ietf-tsvwg-sctp-ndata-08
-
-	const Reliability reliability = message->reliability ? *message->reliability : Reliability();
-
 	uint32_t ppid;
 	switch (message->type) {
 	case Message::String:
@@ -321,11 +309,24 @@ bool SctpTransport::trySendMessage(message_ptr message) {
 	case Message::Binary:
 		ppid = !message->empty() ? PPID_BINARY : PPID_BINARY_EMPTY;
 		break;
-	default:
+	case Message::Control:
 		ppid = PPID_CONTROL;
 		break;
+	case Message::Reset:
+		sendReset(message->stream);
+		return true;
+	default:
+		// Ignore
+		return true;
 	}
 
+	PLOG_VERBOSE << "SCTP try send size=" << message->size();
+
+	// TODO: Implement SCTP ndata specification draft when supported everywhere
+	// See https://tools.ietf.org/html/draft-ietf-tsvwg-sctp-ndata-08
+
+	const Reliability reliability = message->reliability ? *message->reliability : Reliability();
+
 	struct sctp_sendv_spa spa = {};
 
 	// set sndinfo
@@ -390,6 +391,33 @@ void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) {
 	mBufferedAmountCallback(streamId, amount);
 }
 
+void SctpTransport::sendReset(uint16_t streamId) {
+	// Requires mSendMutex to be locked
+	if (!mSock || state() != State::Connected)
+		return;
+
+	PLOG_DEBUG << "SCTP resetting stream " << streamId;
+
+	using srs_t = struct sctp_reset_streams;
+	const size_t len = sizeof(srs_t) + sizeof(uint16_t);
+	byte buffer[len] = {};
+	srs_t &srs = *reinterpret_cast<srs_t *>(buffer);
+	srs.srs_flags = SCTP_STREAM_RESET_OUTGOING;
+	srs.srs_number_streams = 1;
+	srs.srs_stream_list[0] = streamId;
+
+	mWritten = false;
+	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) {
+		std::unique_lock lock(mWriteMutex); // locking before setsockopt might deadlock usrsctp...
+		mWrittenCondition.wait_for(lock, 1000ms,
+		                           [&]() { return mWritten || mState != State::Connected; });
+	} else if (errno == EINVAL) {
+		PLOG_VERBOSE << "SCTP stream " << streamId << " already reset";
+	} else {
+		PLOG_WARNING << "SCTP reset stream " << streamId << " failed, errno=" << errno;
+	}
+}
+
 bool SctpTransport::safeFlush() {
 	try {
 		flush();
@@ -566,7 +594,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 		if (flags & SCTP_STREAM_RESET_OUTGOING_SSN) {
 			for (int i = 0; i < count; ++i) {
 				uint16_t streamId = reset_event.strreset_stream_list[i];
-				reset(streamId);
+				close(streamId);
 			}
 		}
 		if (flags & SCTP_STREAM_RESET_INCOMING_SSN) {
@@ -595,15 +623,17 @@ size_t SctpTransport::bytesSent() { return mBytesSent; }
 
 size_t SctpTransport::bytesReceived() { return mBytesReceived; }
 
-std::optional<std::chrono::milliseconds> SctpTransport::rtt() {
+std::optional<milliseconds> SctpTransport::rtt() {
+	if (!mSock || state() != State::Connected)
+		return nullopt;
+
 	struct sctp_status status = {};
 	socklen_t len = sizeof(status);
-
-	if (usrsctp_getsockopt(this->mSock, IPPROTO_SCTP, SCTP_STATUS, &status, &len)) {
+	if (usrsctp_getsockopt(mSock, IPPROTO_SCTP, SCTP_STATUS, &status, &len)) {
 		PLOG_WARNING << "Could not read SCTP_STATUS";
-		return std::nullopt;
+		return nullopt;
 	}
-	return std::chrono::milliseconds(status.sstat_primary.spinfo_srtt);
+	return milliseconds(status.sstat_primary.spinfo_srtt);
 }
 
 int SctpTransport::RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data,

+ 2 - 1
src/sctptransport.hpp

@@ -51,8 +51,8 @@ public:
 
 	bool stop() override;
 	bool send(message_ptr message) override; // false if buffered
+	void close(unsigned int stream);
 	void flush();
-	void reset(unsigned int stream);
 
 	// Stats
 	void clearStats();
@@ -81,6 +81,7 @@ private:
 	bool trySendQueue();
 	bool trySendMessage(message_ptr message);
 	void updateBufferedAmount(uint16_t streamId, long delta);
+	void sendReset(uint16_t streamId);
 	bool safeFlush();
 
 	int handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data, size_t len,