Browse Source

Merge pull request #254 from paullouisageneau/usrsctp-no-callback

Change usrsctp callbacks to upcall
Paul-Louis Ageneau 4 years ago
parent
commit
f2dd46e589
3 changed files with 87 additions and 82 deletions
  1. 79 70
      src/sctptransport.cpp
  2. 6 10
      src/sctptransport.hpp
  3. 2 2
      src/transport.hpp

+ 79 - 70
src/sctptransport.cpp

@@ -88,7 +88,7 @@ void SctpTransport::Cleanup() {
 SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
                              message_callback recvCallback, amount_callback bufferedAmountCallback,
                              state_callback stateChangeCallback)
-    : Transport(lower, std::move(stateChangeCallback)), mPort(port), mProcessor(16),
+    : Transport(lower, std::move(stateChangeCallback)), mPort(port), mReceiving(false),
       mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
 	onRecv(recvCallback);
 
@@ -100,11 +100,12 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
 		Instances.insert(this);
 	}
 
-	mSock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, &SctpTransport::RecvCallback,
-	                       &SctpTransport::SendCallback, 0, this);
+	mSock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, nullptr, nullptr, 0, nullptr);
 	if (!mSock)
 		throw std::runtime_error("Could not create SCTP socket, errno=" + std::to_string(errno));
 
+	usrsctp_set_upcall(mSock, &SctpTransport::UpcallCallback, this);
+
 	if (usrsctp_set_non_blocking(mSock, 1))
 		throw std::runtime_error("Unable to set non-blocking mode, errno=" + std::to_string(errno));
 
@@ -122,6 +123,10 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
 	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_ENABLE_STREAM_RESET, &av, sizeof(av)))
 		throw std::runtime_error("Could not set socket option SCTP_ENABLE_STREAM_RESET, errno=" +
 		                         std::to_string(errno));
+	int on = 1;
+	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RECVRCVINFO, &on, sizeof(on)))
+		throw std::runtime_error("Could set socket option SCTP_RECVRCVINFO, errno=" +
+		                         std::to_string(errno));
 
 	struct sctp_event se = {};
 	se.se_assoc_id = SCTP_ALL_ASSOC;
@@ -225,21 +230,12 @@ bool SctpTransport::stop() {
 
 void SctpTransport::close() {
 	if (mSock) {
+		mProcessor.join();
 		usrsctp_close(mSock);
 		mSock = nullptr;
 	}
 }
 
-void SctpTransport::recv(message_ptr message) {
-	// Delegate to processor to release SCTP thread
-	mProcessor.enqueue([this, message = std::move(message)]() { Transport::recv(message); });
-}
-
-void SctpTransport::changeState(State state) {
-	// Delegate to processor to release SCTP thread
-	mProcessor.enqueue([this, state]() { Transport::changeState(state); });
-}
-
 void SctpTransport::connect() {
 	if (!mSock)
 		throw std::logic_error("Attempted SCTP connect with closed socket");
@@ -329,6 +325,61 @@ void SctpTransport::incoming(message_ptr message) {
 	usrsctp_conninput(this, message->data(), message->size(), 0);
 }
 
+void SctpTransport::doRecv() {
+	std::lock_guard lock(mRecvMutex);
+	try {
+		scope_guard scope([this]() { mReceiving = false; });
+		mReceiving = true;
+		while (true) {
+			const size_t bufferSize = 65536;
+			byte buffer[bufferSize];
+			socklen_t fromlen = 0;
+			struct sctp_rcvinfo info = {};
+			socklen_t infolen = sizeof(info);
+			unsigned int infotype = 0;
+			int flags = 0;
+			ssize_t len = usrsctp_recvv(mSock, buffer, bufferSize, nullptr, &fromlen, &info,
+			                            &infolen, &infotype, &flags);
+			if (len < 0) {
+				if (errno == EWOULDBLOCK || errno == EAGAIN || errno == ECONNRESET)
+					break;
+				else
+					throw std::runtime_error("SCTP recv failed, errno=" + std::to_string(errno));
+			}
+
+			PLOG_VERBOSE << "SCTP recv, len=" << len;
+
+			// SCTP_FRAGMENT_INTERLEAVE does not seem to work as expected for messages > 64KB,
+			// therefore partial notifications and messages need to be handled separately.
+			if (flags & MSG_NOTIFICATION) {
+				// SCTP event notification
+				mPartialNotification.insert(mPartialNotification.end(), buffer, buffer + len);
+				if (flags & MSG_EOR) {
+					// Notification is complete, process it
+					auto notification =
+					    reinterpret_cast<union sctp_notification *>(mPartialNotification.data());
+					processNotification(notification, mPartialNotification.size());
+					mPartialNotification.clear();
+				}
+			} else {
+				// SCTP message
+				mPartialMessage.insert(mPartialMessage.end(), buffer, buffer + len);
+				if (flags & MSG_EOR) {
+					// Message is complete, process it
+					if (infotype != SCTP_RECVV_RCVINFO)
+						throw std::runtime_error("Missing SCTP recv info");
+
+					processData(std::move(mPartialMessage), info.rcv_sid,
+					            PayloadId(ntohl(info.rcv_ppid)));
+					mPartialMessage.clear();
+				}
+			}
+		}
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+	}
+}
+
 bool SctpTransport::trySendQueue() {
 	// Requires mSendMutex to be locked
 	while (auto next = mSendQueue.peek()) {
@@ -482,44 +533,19 @@ bool SctpTransport::safeFlush() {
 	}
 }
 
-int SctpTransport::handleRecv(struct socket * /*sock*/, union sctp_sockstore /*addr*/,
-                              const byte *data, size_t len, struct sctp_rcvinfo info, int flags) {
-	try {
-		PLOG_VERBOSE << "Handle recv, len=" << len;
-
-		// SCTP_FRAGMENT_INTERLEAVE does not seem to work as expected for messages > 64KB,
-		// therefore partial notifications and messages need to be handled separately.
-		if (flags & MSG_NOTIFICATION) {
-			// SCTP event notification
-			mPartialNotification.insert(mPartialNotification.end(), data, data + len);
-			if (flags & MSG_EOR) {
-				// Notification is complete, process it
-				processNotification(
-				    reinterpret_cast<const union sctp_notification *>(mPartialNotification.data()),
-				    mPartialNotification.size());
-				mPartialNotification.clear();
-			}
-		} else {
-			// SCTP message
-			mPartialMessage.insert(mPartialMessage.end(), data, data + len);
-			if (flags & MSG_EOR) {
-				// Message is complete, process it
-				processData(std::move(mPartialMessage), info.rcv_sid,
-				            PayloadId(ntohl(info.rcv_ppid)));
-				mPartialMessage.clear();
-			}
-		}
+void SctpTransport::handleUpcall() {
+	if(!mSock)
+		return;
 
-	} catch (const std::exception &e) {
-		PLOG_ERROR << "SCTP recv: " << e.what();
-		return -1;
-	}
-	return 0; // success
-}
+	PLOG_VERBOSE << "Handle upcall";
+
+	int events = usrsctp_get_events(mSock);
+
+	if ((events & SCTP_EVENT_READ) && !mReceiving.exchange(true))
+		mProcessor.enqueue(&SctpTransport::doRecv, this);
 
-int SctpTransport::handleSend(size_t free) {
-	PLOG_VERBOSE << "Handle send, free=" << free;
-	return safeFlush() ? 0 : -1;
+	if (events & SCTP_EVENT_WRITE)
+		mProcessor.enqueue(&SctpTransport::safeFlush, this);
 }
 
 int SctpTransport::handleWrite(byte *data, size_t len, uint8_t /*tos*/, uint8_t /*set_df*/) {
@@ -709,31 +735,14 @@ std::optional<milliseconds> SctpTransport::rtt() {
 	return milliseconds(status.sstat_primary.spinfo_srtt);
 }
 
-int SctpTransport::RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data,
-                                size_t len, struct sctp_rcvinfo recv_info, int flags,
-                                void *ulp_info) {
-	auto *transport = static_cast<SctpTransport *>(ulp_info);
-
-	std::shared_lock lock(InstancesMutex);
-	if (Instances.find(transport) == Instances.end()) {
-		free(data);
-		return -1;
-	}
-
-	int ret =
-	    transport->handleRecv(sock, addr, static_cast<const byte *>(data), len, recv_info, flags);
-	free(data);
-	return ret;
-}
-
-int SctpTransport::SendCallback(struct socket *, uint32_t sb_free, void *ulp_info) {
-	auto *transport = static_cast<SctpTransport *>(ulp_info);
+void SctpTransport::UpcallCallback(struct socket *, void *arg, int /* flags */) {
+	auto *transport = static_cast<SctpTransport *>(arg);
 
 	std::shared_lock lock(InstancesMutex);
 	if (Instances.find(transport) == Instances.end())
-		return -1;
+		return;
 
-	return transport->handleSend(size_t(sb_free));
+	transport->handleUpcall();
 }
 
 int SctpTransport::WriteCallback(void *ptr, void *data, size_t len, uint8_t tos, uint8_t set_df) {

+ 6 - 10
src/sctptransport.hpp

@@ -24,6 +24,7 @@
 #include "processor.hpp"
 #include "queue.hpp"
 #include "transport.hpp"
+#include "processor.hpp"
 
 #include <condition_variable>
 #include <functional>
@@ -72,23 +73,19 @@ private:
 		PPID_BINARY_EMPTY = 57
 	};
 
-	void recv(message_ptr message) override;
-	void changeState(State state) override;
-
 	void connect();
 	void shutdown();
 	void close();
 	void incoming(message_ptr message) override;
 
+	void doRecv();
 	bool trySendQueue();
 	bool trySendMessage(message_ptr message);
 	void updateBufferedAmount(uint16_t streamId, long delta);
 	void sendReset(uint16_t streamId);
 	bool safeFlush();
 
-	int handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data, size_t len,
-	               struct sctp_rcvinfo recv_info, int flags);
-	int handleSend(size_t free);
+	void handleUpcall();
 	int handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df);
 
 	void processData(binary &&data, uint16_t streamId, PayloadId ppid);
@@ -98,7 +95,8 @@ private:
 	struct socket *mSock;
 
 	Processor mProcessor;
-	std::mutex mSendMutex;
+	std::mutex mRecvMutex, mSendMutex;
+	std::atomic<bool> mReceiving;
 	Queue<message_ptr> mSendQueue;
 	std::map<uint16_t, size_t> mBufferedAmount;
 	amount_callback mBufferedAmountCallback;
@@ -114,9 +112,7 @@ private:
 	// Stats
 	std::atomic<size_t> mBytesSent = 0, mBytesReceived = 0;
 
-	static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
-	                        struct sctp_rcvinfo recv_info, int flags, void *ulp_info);
-	static int SendCallback(struct socket *sock, uint32_t sb_free, void *ulp_info);
+	static void UpcallCallback(struct socket *sock, void *arg, int flags);
 	static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df);
 
 	static std::unordered_set<SctpTransport *> Instances;

+ 2 - 2
src/transport.hpp

@@ -67,14 +67,14 @@ public:
 	virtual bool send(message_ptr message) { return outgoing(message); }
 
 protected:
-	virtual void recv(message_ptr message) {
+	void recv(message_ptr message) {
 		try {
 			mRecvCallback(message);
 		} catch (const std::exception &e) {
 			PLOG_WARNING << e.what();
 		}
 	}
-	virtual void changeState(State state) {
+	void changeState(State state) {
 		try {
 			if (mState.exchange(state) != state)
 				mStateChangeCallback(state);