Explorar el Código

Changed SCTP to non-blocking to spare a thread and fix blocking on close

Paul-Louis Ageneau hace 5 años
padre
commit
cafc674689
Se han modificado 4 ficheros con 140 adiciones y 111 borrados
  1. 8 11
      include/rtc/queue.hpp
  2. 2 2
      src/datachannel.cpp
  3. 118 89
      src/sctptransport.cpp
  4. 12 9
      src/sctptransport.hpp

+ 8 - 11
include/rtc/queue.hpp

@@ -34,8 +34,7 @@ template <typename T> class Queue {
 public:
 	using amount_function = std::function<size_t(const T &element)>;
 
-	Queue(
-	    size_t limit = 0, amount_function func = [](const T &element) -> size_t { return 1; });
+	Queue(size_t limit = 0, amount_function func = nullptr);
 	~Queue();
 
 	void stop();
@@ -45,7 +44,7 @@ public:
 	void push(const T &element);
 	void push(T &&element);
 	std::optional<T> pop();
-	std::optional<T> tryPop();
+	std::optional<T> peek();
 	void wait();
 	void wait(const std::chrono::milliseconds &duration);
 
@@ -61,8 +60,9 @@ private:
 };
 
 template <typename T>
-Queue<T>::Queue(size_t limit, amount_function func)
-    : mLimit(limit), mAmount(0), mAmountFunction(func) {}
+Queue<T>::Queue(size_t limit, amount_function func) : mLimit(limit), mAmount(0) {
+	mAmountFunction = func ? func : [](const T &element) -> size_t { return 1; };
+}
 
 template <typename T> Queue<T>::~Queue() { stop(); }
 
@@ -105,7 +105,7 @@ template <typename T> std::optional<T> Queue<T>::pop() {
 	mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
 	if (!mQueue.empty()) {
 		mAmount -= mAmountFunction(mQueue.front());
-		std::optional<T> element(std::move(mQueue.front()));
+		std::optional<T> element{std::move(mQueue.front())};
 		mQueue.pop();
 		return element;
 	} else {
@@ -113,13 +113,10 @@ template <typename T> std::optional<T> Queue<T>::pop() {
 	}
 }
 
-template <typename T> std::optional<T> Queue<T>::tryPop() {
+template <typename T> std::optional<T> Queue<T>::peek() {
 	std::unique_lock<std::mutex> lock(mMutex);
 	if (!mQueue.empty()) {
-		mAmount -= mAmountFunction(mQueue.front());
-		std::optional<T> element(std::move(mQueue.front()));
-		mQueue.pop();
-		return element;
+		return std::optional<T>{mQueue.front()};
 	} else {
 		return nullopt;
 	}

+ 2 - 2
src/datachannel.cpp

@@ -98,8 +98,8 @@ void DataChannel::send(const byte *data, size_t size) {
 }
 
 std::optional<std::variant<binary, string>> DataChannel::receive() {
-	while (auto opt = mRecvQueue.tryPop()) {
-		auto message = *opt;
+	while (!mRecvQueue.empty()) {
+		auto message = *mRecvQueue.pop();
 		switch (message->type) {
 		case Message::Control: {
 			auto raw = reinterpret_cast<const uint8_t *>(message->data());

+ 118 - 89
src/sctptransport.cpp

@@ -58,11 +58,14 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
 	GlobalInit();
 
 	usrsctp_register_address(this);
-	mSock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, &SctpTransport::ReadCallback,
-	                       nullptr, 0, this);
+	mSock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, &SctpTransport::RecvCallback,
+	                       &SctpTransport::SendCallback, 0, this);
 	if (!mSock)
 		throw std::runtime_error("Could not create SCTP socket, errno=" + std::to_string(errno));
 
+	if (usrsctp_set_non_blocking(mSock, 1))
+		throw std::runtime_error("Unable to set non-blocking mode, errno=" + std::to_string(errno));
+
 	// SCTP must stop sending after the lower layer is shut down, so disable linger
 	struct linger sol = {};
 	sol.l_onoff = 1;
@@ -81,9 +84,17 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
 	struct sctp_event se = {};
 	se.se_assoc_id = SCTP_ALL_ASSOC;
 	se.se_on = 1;
+	se.se_type = SCTP_ASSOC_CHANGE;
+	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se)))
+		throw std::runtime_error("Could not subscribe to event SCTP_ASSOC_CHANGE, errno=" +
+		                         std::to_string(errno));
+	se.se_type = SCTP_SENDER_DRY_EVENT;
+	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se)))
+		throw std::runtime_error("Could not subscribe to event SCTP_SENDER_DRY_EVENT, errno=" +
+		                         std::to_string(errno));
 	se.se_type = SCTP_STREAM_RESET_EVENT;
 	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se)))
-		throw std::runtime_error("Could not set socket option SCTP_EVENT, errno=" +
+		throw std::runtime_error("Could not subscribe to event SCTP_STREAM_RESET_EVENT, errno=" +
 		                         std::to_string(errno));
 
 	// Disable Nagle-like algorithm to reduce delay
@@ -127,18 +138,7 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
 		throw std::runtime_error("Could not set SCTP send buffer size, errno=" +
 		                         std::to_string(errno));
 
-	struct sockaddr_conn sconn = {};
-	sconn.sconn_family = AF_CONN;
-	sconn.sconn_port = htons(mPort);
-	sconn.sconn_addr = this;
-#ifdef HAVE_SCONN_LEN
-	sconn.sconn_len = sizeof(sconn);
-#endif
-
-	if (usrsctp_bind(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)))
-		throw std::runtime_error("Could not bind usrsctp socket, errno=" + std::to_string(errno));
-
-	mSendThread = std::thread(&SctpTransport::runConnectAndSendLoop, this);
+	connect();
 }
 
 SctpTransport::~SctpTransport() {
@@ -152,13 +152,32 @@ SctpTransport::~SctpTransport() {
 		usrsctp_close(mSock);
 	}
 
-	if (mSendThread.joinable())
-		mSendThread.join();
-
 	usrsctp_deregister_address(this);
 	GlobalCleanup();
 }
 
+void SctpTransport::connect() {
+	changeState(State::Connecting);
+
+	struct sockaddr_conn sconn = {};
+	sconn.sconn_family = AF_CONN;
+	sconn.sconn_port = htons(mPort);
+	sconn.sconn_addr = this;
+#ifdef HAVE_SCONN_LEN
+	sconn.sconn_len = sizeof(sconn);
+#endif
+
+	if (usrsctp_bind(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)))
+		throw std::runtime_error("Could not bind usrsctp socket, errno=" + std::to_string(errno));
+
+	// According to the IETF draft, both endpoints must initiate the SCTP association, in a
+	// simultaneous-open manner, irrelevent to the SDP setup role.
+	// See https://tools.ietf.org/html/draft-ietf-mmusic-sctp-sdp-26#section-9.3
+	int ret = usrsctp_connect(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn));
+	if (ret && errno != EINPROGRESS)
+		throw std::runtime_error("Connection attempt failed, errno=" + std::to_string(errno));
+}
+
 SctpTransport::State SctpTransport::state() const { return mState; }
 
 bool SctpTransport::send(message_ptr message) {
@@ -167,6 +186,7 @@ bool SctpTransport::send(message_ptr message) {
 
 	updateBufferedAmount(message->stream, message->size());
 	mSendQueue.push(message);
+	trySendAll();
 	return true;
 }
 
@@ -188,10 +208,9 @@ void SctpTransport::incoming(message_ptr message) {
 		return;
 	}
 
-	// There could be a race condition here where we receive the remote INIT before the thread in
-	// usrsctp_connect sends the local one, which would result in the connection being aborted.
-	// Therefore, we need to wait for data to be sent on our side (i.e. the local INIT) before
-	// proceeding.
+	// There could be a race condition here where we receive the remote INIT before the local one is
+	// sent, which would result in the connection being aborted. Therefore, we need to wait for data
+	// to be sent on our side (i.e. the local INIT) before proceeding.
 	if (!mConnectDataSent) {
 		std::unique_lock<std::mutex> lock(mConnectMutex);
 		mConnectCondition.wait(lock, [this] { return mConnectDataSent || mStopping; });
@@ -206,53 +225,22 @@ void SctpTransport::changeState(State state) {
 		mStateChangeCallback(state);
 }
 
-void SctpTransport::runConnectAndSendLoop() {
-	try {
-		changeState(State::Connecting);
-
-		struct sockaddr_conn sconn = {};
-		sconn.sconn_family = AF_CONN;
-		sconn.sconn_port = htons(mPort);
-		sconn.sconn_addr = this;
-#ifdef HAVE_SCONN_LEN
-		sconn.sconn_len = sizeof(sconn);
-#endif
-
-		// According to the IETF draft, both endpoints must initiate the SCTP association, in a
-		// simultaneous-open manner, irrelevent to the SDP setup role.
-		// See https://tools.ietf.org/html/draft-ietf-mmusic-sctp-sdp-26#section-9.3
-		if (usrsctp_connect(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)) != 0)
-			throw std::runtime_error("Connection failed, errno=" + std::to_string(errno));
-
-		if (!mStopping)
-			changeState(State::Connected);
-
-	} catch (const std::exception &e) {
-		std::cerr << "SCTP connect: " << e.what() << std::endl;
-		changeState(State::Failed);
-		mStopping = true;
-		mConnectCondition.notify_all();
-		return;
-	}
+bool SctpTransport::trySendAll() {
+	std::unique_lock<std::mutex> lock(mSendMutex, std::try_to_lock);
+	if (!lock.owns_lock())
+		return false;
 
-	try {
-		while (auto next = mSendQueue.pop()) {
-			auto message = *next;
-			bool success = doSend(message);
-			updateBufferedAmount(message->stream, -message->size());
-			if (!success)
-				throw std::runtime_error("Sending failed, errno=" + std::to_string(errno));
-		}
-	} catch (const std::exception &e) {
-		std::cerr << "SCTP send: " << e.what() << std::endl;
+	while (auto next = mSendQueue.peek()) {
+		auto message = *next;
+		if (!trySend(message))
+			return false;
+		updateBufferedAmount(message->stream, -message->size());
+		mSendQueue.pop();
 	}
-
-	changeState(State::Disconnected);
-	mStopping = true;
-	mConnectCondition.notify_all();
+	return true;
 }
 
-bool SctpTransport::doSend(message_ptr message) {
+bool SctpTransport::trySend(message_ptr message) {
 	if (!message)
 		return false;
 
@@ -309,7 +297,13 @@ bool SctpTransport::doSend(message_ptr message) {
 		const char zero = 0;
 		ret = usrsctp_sendv(mSock, &zero, 1, nullptr, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0);
 	}
-	return ret > 0;
+
+	if (ret >= 0)
+		return true;
+	else if (errno == EWOULDBLOCK && errno == EAGAIN)
+		return false;
+	else
+		throw std::runtime_error("Sending failed, errno=" + std::to_string(errno));
 }
 
 void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) {
@@ -328,6 +322,32 @@ void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) {
 		mBufferedAmount.erase(it);
 }
 
+int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, void *data,
+                              size_t len, struct sctp_rcvinfo info, int flags) {
+	if (!data) {
+		recv(nullptr);
+		return 0;
+	}
+	if (flags & MSG_NOTIFICATION)
+		processNotification(reinterpret_cast<const union sctp_notification *>(data), len);
+	else
+		processData(reinterpret_cast<const byte *>(data), len, info.rcv_sid,
+		            PayloadId(htonl(info.rcv_ppid)));
+
+	free(data);
+	return 0; // success
+}
+
+int SctpTransport::handleSend(size_t free) {
+	try {
+		trySendAll();
+	} catch (const std::exception &e) {
+		std::cerr << "SCTP send: " << e.what() << std::endl;
+		return -1;
+	}
+	return 0; // success
+}
+
 int SctpTransport::handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df) {
 	byte *b = reinterpret_cast<byte *>(data);
 	outgoing(make_message(b, b + len));
@@ -340,21 +360,6 @@ int SctpTransport::handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_
 	return 0; // success
 }
 
-int SctpTransport::process(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
-                           struct sctp_rcvinfo info, int flags) {
-	if (!data) {
-		recv(nullptr);
-		return 0;
-	}
-	if (flags & MSG_NOTIFICATION) {
-		processNotification((union sctp_notification *)data, len);
-	} else {
-		processData((const byte *)data, len, info.rcv_sid, PayloadId(htonl(info.rcv_ppid)));
-	}
-	free(data);
-	return 0;
-}
-
 void SctpTransport::processData(const byte *data, size_t len, uint16_t sid, PayloadId ppid) {
 	Message::Type type;
 	switch (ppid) {
@@ -388,6 +393,21 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 		return;
 
 	switch (notify->sn_header.sn_type) {
+	case SCTP_ASSOC_CHANGE: {
+		const struct sctp_assoc_change *assoc_change = &notify->sn_assoc_change;
+		std::unique_lock<std::mutex> lock(mConnectMutex);
+		if (assoc_change->sac_state == SCTP_COMM_UP) {
+			changeState(State::Connected);
+		} else {
+			std::cerr << "SCTP connection failed" << std::endl;
+			changeState(State::Failed);
+		}
+	}
+	case SCTP_SENDER_DRY_EVENT: {
+		// It not should be necessary since the send callback should have been called already,
+		// but to be sure, let's try to send now.
+		trySendAll();
+	}
 	case SCTP_STREAM_RESET_EVENT: {
 		const struct sctp_stream_reset_event *reset_event = &notify->sn_strreset_event;
 		const int count = (reset_event->strreset_length - sizeof(*reset_event)) / sizeof(uint16_t);
@@ -415,16 +435,25 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 		break;
 	}
 }
-int SctpTransport::WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos,
-                                 uint8_t set_df) {
-	return static_cast<SctpTransport *>(sctp_ptr)->handleWrite(data, len, tos, set_df);
+
+int SctpTransport::RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data,
+                                size_t len, struct sctp_rcvinfo recv_info, int flags, void *ptr) {
+	return static_cast<SctpTransport *>(ptr)->handleRecv(sock, addr, data, len, recv_info, flags);
+}
+
+int SctpTransport::SendCallback(struct socket *sock, uint32_t sb_free) {
+	struct sctp_paddrinfo paddrinfo = {};
+	socklen_t len = sizeof(paddrinfo);
+	if (usrsctp_getsockopt(sock, IPPROTO_SCTP, SCTP_GET_PEER_ADDR_INFO, &paddrinfo, &len))
+		return -1;
+
+	auto sconn = reinterpret_cast<struct sockaddr_conn *>(&paddrinfo.spinfo_address);
+	void *ptr = sconn->sconn_addr;
+	return static_cast<SctpTransport *>(ptr)->handleSend(size_t(sb_free));
 }
 
-int SctpTransport::ReadCallback(struct socket *sock, union sctp_sockstore addr, void *data,
-                                size_t len, struct sctp_rcvinfo recv_info, int flags,
-                                void *user_data) {
-	return static_cast<SctpTransport *>(user_data)->process(sock, addr, data, len, recv_info,
-	                                                        flags);
+int SctpTransport::WriteCallback(void *ptr, void *data, size_t len, uint8_t tos, uint8_t set_df) {
+	return static_cast<SctpTransport *>(ptr)->handleWrite(data, len, tos, set_df);
 }
 
 } // namespace rtc

+ 12 - 9
src/sctptransport.hpp

@@ -62,27 +62,29 @@ private:
 		PPID_BINARY_EMPTY = 57
 	};
 
+	void connect();
 	void incoming(message_ptr message);
 	void changeState(State state);
-	void runConnectAndSendLoop();
-	bool doSend(message_ptr message);
+	bool trySendAll();
+	bool trySend(message_ptr message);
 	void updateBufferedAmount(uint16_t streamId, long delta);
 
+	int handleRecv(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
+	               struct sctp_rcvinfo recv_info, int flags);
+	int handleSend(size_t free);
 	int handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df);
 
-	int process(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
-	            struct sctp_rcvinfo recv_info, int flags);
-
 	void processData(const byte *data, size_t len, uint16_t streamId, PayloadId ppid);
 	void processNotification(const union sctp_notification *notify, size_t len);
 
 	const uint16_t mPort;
 	struct socket *mSock;
 
+	std::mutex mSendMutex;
 	Queue<message_ptr> mSendQueue;
-	std::thread mSendThread;
-	std::map<uint16_t, size_t> mBufferedAmount;
+
 	std::mutex mBufferedAmountMutex;
+	std::map<uint16_t, size_t> mBufferedAmount;
 	amount_callback mBufferedAmountCallback;
 
 	std::mutex mConnectMutex;
@@ -93,9 +95,10 @@ private:
 	state_callback mStateChangeCallback;
 	std::atomic<State> mState;
 
-	static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df);
-	static int ReadCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
+	static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
 	                        struct sctp_rcvinfo recv_info, int flags, void *user_data);
+	static int SendCallback(struct socket *sock, uint32_t sb_free);
+	static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df);
 
 	void GlobalInit();
 	void GlobalCleanup();