Browse Source

Merge pull request #1031 from paullouisageneau/refactor-reliability

Refactor Data Channel reliability API
Paul-Louis Ageneau 1 year ago
parent
commit
d8c63d6bf8

+ 1 - 0
CMakeLists.txt

@@ -191,6 +191,7 @@ set(TESTS_SOURCES
     ${CMAKE_CURRENT_SOURCE_DIR}/test/main.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/test/connectivity.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/test/negotiated.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/test/reliability.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/test/turn_connectivity.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/test/track.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/test/capi_connectivity.cpp

+ 2 - 2
DOC.md

@@ -644,8 +644,8 @@ Arguments:
   - `reliability`: a structure of reliability settings containing:
     - `unordered`: if `true`, the Data Channel will not enforce message ordering, else it will be ordered
     - `unreliable`: if `true`, the Data Channel will not enforce strict reliability, else it will be reliable
-    - `maxPacketLifeTime`: if unreliable, maximum packet life time in milliseconds
-    - `maxRetransmits`: if unreliable and maxPacketLifeTime is 0, maximum number of retransmissions (0 means no retransmission)
+    - `maxPacketLifeTime`: if unreliable, time window in milliseconds during which transmissions and retransmissions may occur
+    - `maxRetransmits`: if unreliable and maxPacketLifeTime is 0, maximum number of attempted retransmissions (0 means no retransmission)
   - `protocol` (optional): a user-defined UTF-8 string representing the Data Channel protocol, empty if NULL
   - `negotiated`: if `true`, the Data Channel is assumed to be negotiated by the user and won't be negotiated by the WebRTC layer
   - `manualStream`: if `true`, the Data Channel will use `stream` as stream ID, else an available id is automatically selected

+ 18 - 3
include/rtc/reliability.hpp

@@ -16,10 +16,25 @@
 namespace rtc {
 
 struct Reliability {
-	enum class Type { Reliable = 0, Rexmit, Timed };
-
-	Type type = Type::Reliable;
+	// It true, the channel does not enforce message ordering and out-of-order delivery is allowed
 	bool unordered = false;
+
+	// If both maxPacketLifeTime or maxRetransmits are unset, the channel is reliable.
+	// If either maxPacketLifeTime or maxRetransmits is set, the channel is unreliable.
+	// (The settings are exclusive so both maxPacketLifetime and maxRetransmits must not be set.)
+
+	// Time window during which transmissions and retransmissions may occur
+	optional<std::chrono::milliseconds> maxPacketLifeTime;
+
+	// Maximum number of retransmissions that are attempted
+	optional<unsigned int> maxRetransmits;
+
+	// For backward compatibility, do not use
+	enum class Type { Reliable = 0, Rexmit, Timed };
+	union {
+		Type typeDeprecated = Type::Reliable;
+		[[deprecated("Use maxPacketLifeTime or maxRetransmits")]] Type type;
+	};
 	variant<int, std::chrono::milliseconds> rexmit = 0;
 };
 

+ 2 - 2
include/rtc/rtc.h

@@ -245,8 +245,8 @@ RTC_C_EXPORT int rtcReceiveMessage(int id, char *buffer, int *size);
 typedef struct {
 	bool unordered;
 	bool unreliable;
-	int maxPacketLifeTime; // ignored if reliable
-	int maxRetransmits;    // ignored if reliable
+	unsigned int maxPacketLifeTime; // ignored if reliable
+	unsigned int maxRetransmits;    // ignored if reliable
 } rtcReliability;
 
 typedef struct {

+ 2 - 2
pages/content/pages/reference.md

@@ -647,8 +647,8 @@ Arguments:
   - `reliability`: a structure of reliability settings containing:
     - `unordered`: if `true`, the Data Channel will not enforce message ordering, else it will be ordered
     - `unreliable`: if `true`, the Data Channel will not enforce strict reliability, else it will be reliable
-    - `maxPacketLifeTime`: if unreliable, maximum packet life time in milliseconds
-    - `maxRetransmits`: if unreliable and maxPacketLifeTime is 0, maximum number of retransmissions (0 means no retransmission)
+    - `maxPacketLifeTime`: if unreliable, time window in milliseconds during which transmissions and retransmissions may occur
+    - `maxRetransmits`: if unreliable and maxPacketLifeTime is 0, maximum number of attempted retransmissions (0 means no retransmission)
   - `protocol` (optional): a user-defined UTF-8 string representing the Data Channel protocol, empty if NULL
   - `negotiated`: if `true`, the Data Channel is assumed to be negotiated by the user and won't be negotiated by the WebRTC layer
   - `manualStream`: if `true`, the Data Channel will use `stream` as stream ID, else an available id is automatically selected

+ 8 - 13
src/capi.cpp

@@ -901,15 +901,10 @@ int rtcCreateDataChannelEx(int pc, const char *label, const rtcDataChannelInit *
 			auto *reliability = &init->reliability;
 			dci.reliability.unordered = reliability->unordered;
 			if (reliability->unreliable) {
-				if (reliability->maxPacketLifeTime > 0) {
-					dci.reliability.type = Reliability::Type::Timed;
-					dci.reliability.rexmit = milliseconds(reliability->maxPacketLifeTime);
-				} else {
-					dci.reliability.type = Reliability::Type::Rexmit;
-					dci.reliability.rexmit = reliability->maxRetransmits;
-				}
-			} else {
-				dci.reliability.type = Reliability::Type::Reliable;
+				if (reliability->maxPacketLifeTime > 0)
+					dci.reliability.maxPacketLifeTime.emplace(milliseconds(reliability->maxPacketLifeTime));
+				else
+					dci.reliability.maxRetransmits.emplace(reliability->maxRetransmits);
 			}
 
 			dci.negotiated = init->negotiated;
@@ -971,12 +966,12 @@ int rtcGetDataChannelReliability(int dc, rtcReliability *reliability) {
 		Reliability dcr = dataChannel->reliability();
 		std::memset(reliability, 0, sizeof(*reliability));
 		reliability->unordered = dcr.unordered;
-		if (dcr.type == Reliability::Type::Timed) {
+		if(dcr.maxPacketLifeTime) {
 			reliability->unreliable = true;
-			reliability->maxPacketLifeTime = int(std::get<milliseconds>(dcr.rexmit).count());
-		} else if (dcr.type == Reliability::Type::Rexmit) {
+			reliability->maxPacketLifeTime = static_cast<unsigned int>(dcr.maxPacketLifeTime->count());
+		} else if (dcr.maxRetransmits) {
 			reliability->unreliable = true;
-			reliability->maxRetransmits = std::get<int>(dcr.rexmit);
+			reliability->maxRetransmits = *dcr.maxRetransmits;
 		} else {
 			reliability->unreliable = false;
 		}

+ 56 - 21
src/impl/datachannel.cpp

@@ -12,7 +12,7 @@
 #include "logcounter.hpp"
 #include "peerconnection.hpp"
 #include "sctptransport.hpp"
-
+#include "utils.hpp"
 #include "rtc/datachannel.hpp"
 #include "rtc/track.hpp"
 
@@ -28,6 +28,9 @@ using std::chrono::milliseconds;
 
 namespace rtc::impl {
 
+using utils::to_uint16;
+using utils::to_uint32;
+
 // Messages for the DataChannel establishment protocol (RFC 8832)
 // See https://www.rfc-editor.org/rfc/rfc8832.html
 
@@ -74,8 +77,13 @@ bool DataChannel::IsOpenMessage(message_ptr message) {
 DataChannel::DataChannel(weak_ptr<PeerConnection> pc, string label, string protocol,
                          Reliability reliability)
     : mPeerConnection(pc), mLabel(std::move(label)), mProtocol(std::move(protocol)),
-      mReliability(std::make_shared<Reliability>(std::move(reliability))),
-      mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
+      mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {
+
+	if(reliability.maxPacketLifeTime && reliability.maxRetransmits)
+		throw std::invalid_argument("Both maxPacketLifeTime and maxRetransmits are set");
+
+    mReliability = std::make_shared<Reliability>(std::move(reliability));
+}
 
 DataChannel::~DataChannel() {
 	PLOG_VERBOSE << "Destroying DataChannel";
@@ -247,22 +255,35 @@ void OutgoingDataChannel::open(shared_ptr<SctpTransport> transport) {
 
 	uint8_t channelType;
 	uint32_t reliabilityParameter;
-	switch (mReliability->type) {
-	case Reliability::Type::Rexmit:
+	if (mReliability->maxPacketLifeTime) {
+		channelType = CHANNEL_PARTIAL_RELIABLE_TIMED;
+		reliabilityParameter = to_uint32(mReliability->maxPacketLifeTime->count());
+	} else if (mReliability->maxRetransmits) {
 		channelType = CHANNEL_PARTIAL_RELIABLE_REXMIT;
-		reliabilityParameter = uint32_t(std::max(std::get<int>(mReliability->rexmit), 0));
-		break;
+		reliabilityParameter = to_uint32(*mReliability->maxRetransmits);
+	}
+	// else {
+	//	channelType = CHANNEL_RELIABLE;
+	//	reliabilityParameter = 0;
+	// }
+	// Deprecated
+	else
+		switch (mReliability->typeDeprecated) {
+		case Reliability::Type::Rexmit:
+			channelType = CHANNEL_PARTIAL_RELIABLE_REXMIT;
+			reliabilityParameter = to_uint32(std::max(std::get<int>(mReliability->rexmit), 0));
+			break;
 
-	case Reliability::Type::Timed:
-		channelType = CHANNEL_PARTIAL_RELIABLE_TIMED;
-		reliabilityParameter = uint32_t(std::get<milliseconds>(mReliability->rexmit).count());
-		break;
+		case Reliability::Type::Timed:
+			channelType = CHANNEL_PARTIAL_RELIABLE_TIMED;
+			reliabilityParameter = to_uint32(std::get<milliseconds>(mReliability->rexmit).count());
+			break;
 
-	default:
-		channelType = CHANNEL_RELIABLE;
-		reliabilityParameter = 0;
-		break;
-	}
+		default:
+			channelType = CHANNEL_RELIABLE;
+			reliabilityParameter = 0;
+			break;
+		}
 
 	if (mReliability->unordered)
 		channelType |= 0x80;
@@ -274,8 +295,8 @@ void OutgoingDataChannel::open(shared_ptr<SctpTransport> transport) {
 	open.channelType = channelType;
 	open.priority = htons(0);
 	open.reliabilityParameter = htonl(reliabilityParameter);
-	open.labelLength = htons(uint16_t(mLabel.size()));
-	open.protocolLength = htons(uint16_t(mProtocol.size()));
+	open.labelLength = htons(to_uint16(mLabel.size()));
+	open.protocolLength = htons(to_uint16(mProtocol.size()));
 
 	auto end = reinterpret_cast<char *>(buffer.data() + sizeof(OpenMessage));
 	std::copy(mLabel.begin(), mLabel.end(), end);
@@ -329,17 +350,31 @@ void IncomingDataChannel::processOpenMessage(message_ptr message) {
 	mProtocol.assign(end + open.labelLength, open.protocolLength);
 
 	mReliability->unordered = (open.channelType & 0x80) != 0;
+	mReliability->maxPacketLifeTime.reset();
+	mReliability->maxRetransmits.reset();
+	switch (open.channelType & 0x7F) {
+	case CHANNEL_PARTIAL_RELIABLE_REXMIT:
+		mReliability->maxRetransmits.emplace(open.reliabilityParameter);
+		break;
+	case CHANNEL_PARTIAL_RELIABLE_TIMED:
+		mReliability->maxPacketLifeTime.emplace(milliseconds(open.reliabilityParameter));
+		break;
+	default:
+		break;
+	}
+
+	// Deprecated
 	switch (open.channelType & 0x7F) {
 	case CHANNEL_PARTIAL_RELIABLE_REXMIT:
-		mReliability->type = Reliability::Type::Rexmit;
+		mReliability->typeDeprecated = Reliability::Type::Rexmit;
 		mReliability->rexmit = int(open.reliabilityParameter);
 		break;
 	case CHANNEL_PARTIAL_RELIABLE_TIMED:
-		mReliability->type = Reliability::Type::Timed;
+		mReliability->typeDeprecated = Reliability::Type::Timed;
 		mReliability->rexmit = milliseconds(open.reliabilityParameter);
 		break;
 	default:
-		mReliability->type = Reliability::Type::Reliable;
+		mReliability->typeDeprecated = Reliability::Type::Reliable;
 		mReliability->rexmit = int(0);
 	}
 

+ 20 - 23
src/impl/sctptransport.cpp

@@ -10,6 +10,7 @@
 #include "dtlstransport.hpp"
 #include "internals.hpp"
 #include "logcounter.hpp"
+#include "utils.hpp"
 
 #include <algorithm>
 #include <chrono>
@@ -50,28 +51,11 @@
 using namespace std::chrono_literals;
 using namespace std::chrono;
 
-namespace {
-
-template <typename T> uint16_t to_uint16(T i) {
-	if (i >= 0 && static_cast<typename std::make_unsigned<T>::type>(i) <=
-	                  std::numeric_limits<uint16_t>::max())
-		return static_cast<uint16_t>(i);
-	else
-		throw std::invalid_argument("Integer out of range");
-}
-
-template <typename T> uint32_t to_uint32(T i) {
-	if (i >= 0 && static_cast<typename std::make_unsigned<T>::type>(i) <=
-	                  std::numeric_limits<uint32_t>::max())
-		return static_cast<uint32_t>(i);
-	else
-		throw std::invalid_argument("Integer out of range");
-}
-
-} // namespace
-
 namespace rtc::impl {
 
+using utils::to_uint16;
+using utils::to_uint32;
+
 static LogCounter COUNTER_UNKNOWN_PPID(plog::warning,
                                        "Number of SCTP packets received with an unknown PPID");
 
@@ -387,7 +371,7 @@ bool SctpTransport::send(message_ptr message) {
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 
-	if(message->size() > mMaxMessageSize)
+	if (message->size() > mMaxMessageSize)
 		throw std::invalid_argument("Message is too large");
 
 	// Flush the queue, and if nothing is pending, try to send directly
@@ -522,7 +506,7 @@ void SctpTransport::doRecv() {
 			} else {
 				// SCTP message
 				mPartialMessage.insert(mPartialMessage.end(), buffer, buffer + len);
-				if(mPartialMessage.size() > mMaxMessageSize) {
+				if (mPartialMessage.size() > mMaxMessageSize) {
 					PLOG_WARNING << "SCTP message is too large, truncating it";
 					mPartialMessage.resize(mMaxMessageSize);
 				}
@@ -646,7 +630,20 @@ bool SctpTransport::trySendMessage(message_ptr message) {
 	if (reliability.unordered)
 		spa.sendv_sndinfo.snd_flags |= SCTP_UNORDERED;
 
-	switch (reliability.type) {
+	if (reliability.maxPacketLifeTime) {
+		spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
+		spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL;
+		spa.sendv_prinfo.pr_value = to_uint32(reliability.maxPacketLifeTime->count());
+	} else if (reliability.maxRetransmits) {
+		spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
+		spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX;
+		spa.sendv_prinfo.pr_value = to_uint32(*reliability.maxRetransmits);
+	}
+	// else {
+	// 	spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_NONE;
+	// }
+	// Deprecated
+	else switch (reliability.typeDeprecated) {
 	case Reliability::Type::Rexmit:
 		spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
 		spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX;

+ 17 - 0
src/impl/utils.hpp

@@ -15,6 +15,7 @@
 #include <limits>
 #include <map>
 #include <random>
+#include <stdexcept>
 #include <vector>
 
 namespace rtc::impl::utils {
@@ -60,6 +61,22 @@ template <typename Generator = std::mt19937> auto random_bytes_engine() {
 	return random_engine<char_independent_bits_engine, uint8_t>();
 }
 
+template <typename T> uint16_t to_uint16(T i) {
+	if (i >= 0 && static_cast<typename std::make_unsigned<T>::type>(i) <=
+	                  std::numeric_limits<uint16_t>::max())
+		return static_cast<uint16_t>(i);
+	else
+		throw std::invalid_argument("Integer out of range");
+}
+
+template <typename T> uint32_t to_uint32(T i) {
+	if (i >= 0 && static_cast<typename std::make_unsigned<T>::type>(i) <=
+	                  std::numeric_limits<uint32_t>::max())
+		return static_cast<uint32_t>(i);
+	else
+		throw std::invalid_argument("Integer out of range");
+}
+
 namespace this_thread {
 
 void set_name(const string &name);

+ 10 - 1
test/main.cpp

@@ -15,8 +15,9 @@
 using namespace std;
 using namespace chrono_literals;
 
-void test_negotiated();
 void test_connectivity(bool signal_wrong_fingerprint);
+void test_negotiated();
+void test_reliability();
 void test_turn_connectivity();
 void test_track();
 void test_capi_connectivity();
@@ -74,6 +75,14 @@ int main(int argc, char **argv) {
 		cerr << "WebRTC negotiated DataChannel test failed: " << e.what() << endl;
 		return -1;
 	}
+	try {
+		cout << endl << "*** Running WebRTC reliability mode test..." << endl;
+		test_reliability();
+		cout << "*** Finished WebRTC reliaility mode test" << endl;
+	} catch (const exception &e) {
+		cerr << "WebRTC reliability test failed: " << e.what() << endl;
+		return -1;
+	}
 #if RTC_ENABLE_MEDIA
 	try {
 		cout << endl << "*** Running WebRTC Track test..." << endl;

+ 128 - 0
test/reliability.cpp

@@ -0,0 +1,128 @@
+/**
+ * Copyright (c) 2019 Paul-Louis Ageneau
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/.
+ */
+
+#include "rtc/rtc.hpp"
+
+#include <atomic>
+#include <chrono>
+#include <iostream>
+#include <memory>
+#include <thread>
+
+using namespace rtc;
+using namespace std;
+
+void test_reliability() {
+	InitLogger(LogLevel::Debug);
+
+	Configuration config1;
+	// STUN server example (not necessary to connect locally)
+	config1.iceServers.emplace_back("stun:stun.l.google.com:19302");
+
+	PeerConnection pc1(config1);
+
+	Configuration config2;
+	// STUN server example (not necessary to connect locally)
+	config2.iceServers.emplace_back("stun:stun.l.google.com:19302");
+
+	PeerConnection pc2(config2);
+
+	pc1.onLocalDescription([&pc2](Description sdp) {
+		cout << "Description 1: " << sdp << endl;
+		pc2.setRemoteDescription(string(sdp));
+	});
+
+	pc1.onLocalCandidate([&pc2](Candidate candidate) {
+		cout << "Candidate 1: " << candidate << endl;
+		pc2.addRemoteCandidate(string(candidate));
+	});
+
+	pc2.onLocalDescription([&pc1](Description sdp) {
+		cout << "Description 2: " << sdp << endl;
+		pc1.setRemoteDescription(string(sdp));
+	});
+
+	pc2.onLocalCandidate([&pc1](Candidate candidate) {
+		cout << "Candidate 2: " << candidate << endl;
+		pc1.addRemoteCandidate(string(candidate));
+	});
+
+	Reliability reliableOrdered;
+	auto dcReliableOrdered = pc1.createDataChannel("reliable_ordered", {reliableOrdered});
+
+	Reliability reliableUnordered;
+	reliableUnordered.unordered = true;
+	auto dcReliableUnordered = pc1.createDataChannel("reliable_unordered", {reliableUnordered});
+
+	Reliability unreliableMaxPacketLifeTime;
+	unreliableMaxPacketLifeTime.unordered = true;
+	unreliableMaxPacketLifeTime.maxPacketLifeTime = 222ms;
+	auto dcUnreliableMaxPacketLifeTime =
+	    pc1.createDataChannel("unreliable_maxpacketlifetime", {unreliableMaxPacketLifeTime});
+
+	Reliability unreliableMaxRetransmits;
+	unreliableMaxRetransmits.unordered = true;
+	unreliableMaxRetransmits.maxRetransmits = 2;
+	auto dcUnreliableMaxRetransmits =
+	    pc1.createDataChannel("unreliable_maxretransmits", {unreliableMaxRetransmits});
+
+	std::atomic<int> count = 0;
+	std::atomic<bool> failed = false;
+	pc2.onDataChannel([&count, &failed](shared_ptr<DataChannel> dc) {
+		cout << "DataChannel 2: Received with label \"" << dc->label() << "\"" << endl;
+
+		auto label = dc->label();
+		auto reliability = dc->reliability();
+
+		try {
+			if (label == "reliable_ordered") {
+				if (reliability.unordered != false || reliability.maxPacketLifeTime ||
+				    reliability.maxRetransmits)
+					throw std::runtime_error("Expected reliable ordered");
+			} else if (label == "reliable_unordered") {
+				if (reliability.unordered != true || reliability.maxPacketLifeTime ||
+				    reliability.maxRetransmits)
+					throw std::runtime_error("Expected reliable unordered");
+			} else if (label == "unreliable_maxpacketlifetime") {
+				if (!reliability.maxPacketLifeTime || *reliability.maxPacketLifeTime != 222ms ||
+				    reliability.maxRetransmits)
+					throw std::runtime_error("Expected maxPacketLifeTime to be set");
+			} else if (label == "unreliable_maxretransmits") {
+				if (reliability.maxPacketLifeTime || !reliability.maxRetransmits ||
+				    *reliability.maxRetransmits != 2)
+					throw std::runtime_error("Expected maxRetransmits to be set");
+			} else
+				throw std::runtime_error("Unexpected label: " + label);
+		} catch (const std::exception &e) {
+			cerr << "Error: " << e.what();
+			failed = true;
+			return;
+		}
+		++count;
+	});
+
+	// Wait a bit
+	int attempts = 10;
+	shared_ptr<DataChannel> adc2;
+	while (count != 4 && !failed && attempts--)
+		this_thread::sleep_for(1s);
+
+	if (pc1.state() != PeerConnection::State::Connected ||
+	    pc2.state() != PeerConnection::State::Connected)
+		throw runtime_error("PeerConnection is not connected");
+
+	if (failed)
+		throw runtime_error("Incorrect reliability settings");
+
+	if (count != 4)
+		throw runtime_error("Some DataChannels are not open");
+
+	pc1.close();
+
+	cout << "Success" << endl;
+}