Просмотр исходного кода

Proper handling of SCTP EOR flag at reception

Paul-Louis Ageneau 6 лет назад
Родитель
Сommit
5213f12f1a
2 измененных файлов с 46 добавлено и 23 удалено
  1. 43 20
      src/sctptransport.cpp
  2. 3 3
      src/sctptransport.hpp

+ 43 - 20
src/sctptransport.cpp

@@ -331,19 +331,34 @@ void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) {
 		mBufferedAmount.erase(it);
 }
 
-int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, void *data,
+int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data,
                               size_t len, struct sctp_rcvinfo info, int flags) {
-	if (!data) {
-		recv(nullptr);
-		return 0;
-	}
-	if (flags & MSG_NOTIFICATION)
-		processNotification(reinterpret_cast<const union sctp_notification *>(data), len);
-	else
-		processData(reinterpret_cast<const byte *>(data), len, info.rcv_sid,
-		            PayloadId(htonl(info.rcv_ppid)));
+	try {
+		if (!data) {
+			recv(nullptr);
+			return 0;
+		}
+		if (flags & MSG_EOR) {
+			if (!mPartialRecv.empty()) {
+				mPartialRecv.insert(mPartialRecv.end(), data, data + len);
+				data = mPartialRecv.data();
+				len = mPartialRecv.size();
+			}
+			// Message is complete, process it
+			if (flags & MSG_NOTIFICATION)
+				processNotification(reinterpret_cast<const union sctp_notification *>(data), len);
+			else
+				processData(data, len, info.rcv_sid, PayloadId(htonl(info.rcv_ppid)));
 
-	free(data);
+			mPartialRecv.clear();
+		} else {
+			// Message is not complete
+			mPartialRecv.insert(mPartialRecv.end(), data, data + len);
+		}
+	} catch (const std::exception &e) {
+		std::cerr << "SCTP recv: " << e.what() << std::endl;
+		return -1;
+	}
 	return 0; // success
 }
 
@@ -357,14 +372,18 @@ int SctpTransport::handleSend(size_t free) {
 	return 0; // success
 }
 
-int SctpTransport::handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df) {
-	byte *b = reinterpret_cast<byte *>(data);
-	outgoing(make_message(b, b + len));
+int SctpTransport::handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df) {
+	try {
+		outgoing(make_message(data, data + len));
 
-	if (!mConnectDataSent) {
-		std::unique_lock<std::mutex> lock(mConnectMutex);
-		mConnectDataSent = true;
-		mConnectCondition.notify_all();
+		if (!mConnectDataSent) {
+			std::unique_lock<std::mutex> lock(mConnectMutex);
+			mConnectDataSent = true;
+			mConnectCondition.notify_all();
+		}
+	} catch (const std::exception &e) {
+		std::cerr << "SCTP write: " << e.what() << std::endl;
+		return -1;
 	}
 	return 0; // success
 }
@@ -483,7 +502,10 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 
 int SctpTransport::RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data,
                                 size_t len, struct sctp_rcvinfo recv_info, int flags, void *ptr) {
-	return static_cast<SctpTransport *>(ptr)->handleRecv(sock, addr, data, len, recv_info, flags);
+	int ret = static_cast<SctpTransport *>(ptr)->handleRecv(
+	    sock, addr, static_cast<const byte *>(data), len, recv_info, flags);
+	free(data);
+	return ret;
 }
 
 int SctpTransport::SendCallback(struct socket *sock, uint32_t sb_free) {
@@ -498,7 +520,8 @@ int SctpTransport::SendCallback(struct socket *sock, uint32_t sb_free) {
 }
 
 int SctpTransport::WriteCallback(void *ptr, void *data, size_t len, uint8_t tos, uint8_t set_df) {
-	return static_cast<SctpTransport *>(ptr)->handleWrite(data, len, tos, set_df);
+	return static_cast<SctpTransport *>(ptr)->handleWrite(static_cast<byte *>(data), len, tos,
+	                                                      set_df);
 }
 
 } // namespace rtc

+ 3 - 3
src/sctptransport.hpp

@@ -73,10 +73,10 @@ private:
 	bool trySend(message_ptr message);
 	void updateBufferedAmount(uint16_t streamId, long delta);
 
-	int handleRecv(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
+	int handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data, size_t len,
 	               struct sctp_rcvinfo recv_info, int flags);
 	int handleSend(size_t free);
-	int handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df);
+	int handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df);
 
 	void processData(const byte *data, size_t len, uint16_t streamId, PayloadId ppid);
 	void processNotification(const union sctp_notification *notify, size_t len);
@@ -99,7 +99,7 @@ private:
 	state_callback mStateChangeCallback;
 	std::atomic<State> mState;
 
-	binary mPartialStringData, mPartialBinaryData;
+	binary mPartialRecv, mPartialStringData, mPartialBinaryData;
 
 	static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
 	                        struct sctp_rcvinfo recv_info, int flags, void *user_data);