Browse Source

Merge pull request #245 from paullouisageneau/oob-datachannels

Out-of-band negotiated Data Channels
Paul-Louis Ageneau 4 years ago
parent
commit
fec3b1ad8b
8 changed files with 236 additions and 141 deletions
  1. 25 10
      include/rtc/datachannel.hpp
  2. 11 6
      include/rtc/peerconnection.hpp
  3. 28 19
      include/rtc/rtc.h
  4. 32 25
      src/capi.cpp
  5. 67 50
      src/datachannel.cpp
  6. 39 28
      src/peerconnection.cpp
  7. 33 2
      test/connectivity.cpp
  8. 1 1
      test/track.cpp

+ 25 - 10
include/rtc/datachannel.hpp

@@ -36,15 +36,14 @@ namespace rtc {
 class SctpTransport;
 class SctpTransport;
 class PeerConnection;
 class PeerConnection;
 
 
-class DataChannel final : public std::enable_shared_from_this<DataChannel>, public Channel {
+class DataChannel : public std::enable_shared_from_this<DataChannel>, public Channel {
 public:
 public:
-	DataChannel(std::weak_ptr<PeerConnection> pc, unsigned int stream, string label,
+	DataChannel(std::weak_ptr<PeerConnection> pc, uint16_t stream, string label,
 	            string protocol, Reliability reliability);
 	            string protocol, Reliability reliability);
-	DataChannel(std::weak_ptr<PeerConnection> pc, std::weak_ptr<SctpTransport> transport,
-	            unsigned int stream);
-	~DataChannel();
+	virtual ~DataChannel();
 
 
-	unsigned int stream() const;
+	uint16_t stream() const;
+	uint16_t id() const;
 	string label() const;
 	string label() const;
 	string protocol() const;
 	string protocol() const;
 	Reliability reliability() const;
 	Reliability reliability() const;
@@ -64,17 +63,17 @@ public:
 	std::optional<message_variant> receive() override;
 	std::optional<message_variant> receive() override;
 	std::optional<message_variant> peek() override;
 	std::optional<message_variant> peek() override;
 
 
-private:
+protected:
+	virtual void open(std::shared_ptr<SctpTransport> transport);
+	virtual void processOpenMessage(message_ptr message);
 	void remoteClose();
 	void remoteClose();
-	void open(std::shared_ptr<SctpTransport> transport);
 	bool outgoing(message_ptr message);
 	bool outgoing(message_ptr message);
 	void incoming(message_ptr message);
 	void incoming(message_ptr message);
-	void processOpenMessage(message_ptr message);
 
 
 	const std::weak_ptr<PeerConnection> mPeerConnection;
 	const std::weak_ptr<PeerConnection> mPeerConnection;
 	std::weak_ptr<SctpTransport> mSctpTransport;
 	std::weak_ptr<SctpTransport> mSctpTransport;
 
 
-	unsigned int mStream;
+	uint16_t mStream;
 	string mLabel;
 	string mLabel;
 	string mProtocol;
 	string mProtocol;
 	std::shared_ptr<Reliability> mReliability;
 	std::shared_ptr<Reliability> mReliability;
@@ -82,11 +81,27 @@ private:
 	std::atomic<bool> mIsOpen = false;
 	std::atomic<bool> mIsOpen = false;
 	std::atomic<bool> mIsClosed = false;
 	std::atomic<bool> mIsClosed = false;
 
 
+private:
 	Queue<message_ptr> mRecvQueue;
 	Queue<message_ptr> mRecvQueue;
 
 
 	friend class PeerConnection;
 	friend class PeerConnection;
 };
 };
 
 
+class NegociatedDataChannel final : public DataChannel {
+public:
+	NegociatedDataChannel(std::weak_ptr<PeerConnection> pc, uint16_t stream, string label,
+	            string protocol, Reliability reliability);
+	NegociatedDataChannel(std::weak_ptr<PeerConnection> pc, std::weak_ptr<SctpTransport> transport,
+	            uint16_t stream);
+	~NegociatedDataChannel();
+
+private:
+	void open(std::shared_ptr<SctpTransport> transport) override;
+	void processOpenMessage(message_ptr message) override;
+
+	friend class PeerConnection;
+};
+
 template <typename Buffer> std::pair<const byte *, size_t> to_bytes(const Buffer &buf) {
 template <typename Buffer> std::pair<const byte *, size_t> to_bytes(const Buffer &buf) {
 	using T = typename std::remove_pointer<decltype(buf.data())>::type;
 	using T = typename std::remove_pointer<decltype(buf.data())>::type;
 	using E = typename std::conditional<std::is_void<T>::value, byte, T>::type;
 	using E = typename std::conditional<std::is_void<T>::value, byte, T>::type;

+ 11 - 6
include/rtc/peerconnection.hpp

@@ -50,6 +50,13 @@ class SctpTransport;
 using certificate_ptr = std::shared_ptr<Certificate>;
 using certificate_ptr = std::shared_ptr<Certificate>;
 using future_certificate_ptr = std::shared_future<certificate_ptr>;
 using future_certificate_ptr = std::shared_future<certificate_ptr>;
 
 
+struct DataChannelInit {
+	Reliability reliability = {};
+	bool negotiated = false;
+	std::optional<uint16_t> id = nullopt;
+	string protocol = "";
+};
+
 class PeerConnection final : public std::enable_shared_from_this<PeerConnection> {
 class PeerConnection final : public std::enable_shared_from_this<PeerConnection> {
 public:
 public:
 	enum class State : int {
 	enum class State : int {
@@ -98,12 +105,10 @@ public:
 	void setRemoteDescription(Description description);
 	void setRemoteDescription(Description description);
 	void addRemoteCandidate(Candidate candidate);
 	void addRemoteCandidate(Candidate candidate);
 
 
-	std::shared_ptr<DataChannel> addDataChannel(string label, string protocol = "",
-	                                            Reliability reliability = {});
+	std::shared_ptr<DataChannel> addDataChannel(string label, DataChannelInit init = {});
 
 
 	// Equivalent to calling addDataChannel() and setLocalDescription()
 	// Equivalent to calling addDataChannel() and setLocalDescription()
-	std::shared_ptr<DataChannel> createDataChannel(string label, string protocol = "",
-	                                               Reliability reliability = {});
+	std::shared_ptr<DataChannel> createDataChannel(string label, DataChannelInit init = {});
 
 
 	void onDataChannel(std::function<void(std::shared_ptr<DataChannel> dataChannel)> callback);
 	void onDataChannel(std::function<void(std::shared_ptr<DataChannel> dataChannel)> callback);
 	void onLocalDescription(std::function<void(Description description)> callback);
 	void onLocalDescription(std::function<void(Description description)> callback);
@@ -135,7 +140,7 @@ private:
 	void forwardBufferedAmount(uint16_t stream, size_t amount);
 	void forwardBufferedAmount(uint16_t stream, size_t amount);
 
 
 	std::shared_ptr<DataChannel> emplaceDataChannel(Description::Role role, string label,
 	std::shared_ptr<DataChannel> emplaceDataChannel(Description::Role role, string label,
-	                                                string protocol, Reliability reliability);
+	                                                DataChannelInit init);
 	std::shared_ptr<DataChannel> findDataChannel(uint16_t stream);
 	std::shared_ptr<DataChannel> findDataChannel(uint16_t stream);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void openDataChannels();
 	void openDataChannels();
@@ -173,7 +178,7 @@ private:
 	std::shared_ptr<DtlsTransport> mDtlsTransport;
 	std::shared_ptr<DtlsTransport> mDtlsTransport;
 	std::shared_ptr<SctpTransport> mSctpTransport;
 	std::shared_ptr<SctpTransport> mSctpTransport;
 
 
-	std::unordered_map<unsigned int, std::weak_ptr<DataChannel>> mDataChannels; // by stream ID
+	std::unordered_map<uint16_t, std::weak_ptr<DataChannel>> mDataChannels;     // by stream ID
 	std::unordered_map<string, std::weak_ptr<Track>> mTracks;                   // by mid
 	std::unordered_map<string, std::weak_ptr<Track>> mTracks;                   // by mid
 	std::shared_mutex mDataChannelsMutex, mTracksMutex;
 	std::shared_mutex mDataChannelsMutex, mTracksMutex;
 
 

+ 28 - 19
include/rtc/rtc.h

@@ -97,20 +97,28 @@ typedef struct {
 	unsigned int maxRetransmits;    // ignored if reliable
 	unsigned int maxRetransmits;    // ignored if reliable
 } rtcReliability;
 } rtcReliability;
 
 
-typedef void (RTC_API *rtcLogCallbackFunc)(rtcLogLevel level, const char *message);
-typedef void (RTC_API *rtcDescriptionCallbackFunc)(int pc, const char *sdp, const char *type, void *ptr);
-typedef void (RTC_API *rtcCandidateCallbackFunc)(int pc, const char *cand, const char *mid, void *ptr);
-typedef void (RTC_API *rtcStateChangeCallbackFunc)(int pc, rtcState state, void *ptr);
-typedef void (RTC_API *rtcGatheringStateCallbackFunc)(int pc, rtcGatheringState state, void *ptr);
-typedef void (RTC_API *rtcSignalingStateCallbackFunc)(int pc, rtcSignalingState state, void *ptr);
-typedef void (RTC_API *rtcDataChannelCallbackFunc)(int pc, int dc, void *ptr);
-typedef void (RTC_API *rtcTrackCallbackFunc)(int pc, int tr, void *ptr);
-typedef void (RTC_API *rtcOpenCallbackFunc)(int id, void *ptr);
-typedef void (RTC_API *rtcClosedCallbackFunc)(int id, void *ptr);
-typedef void (RTC_API *rtcErrorCallbackFunc)(int id, const char *error, void *ptr);
-typedef void (RTC_API *rtcMessageCallbackFunc)(int id, const char *message, int size, void *ptr);
-typedef void (RTC_API *rtcBufferedAmountLowCallbackFunc)(int id, void *ptr);
-typedef void (RTC_API *rtcAvailableCallbackFunc)(int id, void *ptr);
+typedef struct {
+	rtcReliability reliability;
+	const char *protocol; // empty string if NULL
+	bool negotiated;
+	bool manualId;
+	uint16_t id; // ignored if manualId is false
+} rtcDataChannelInit;
+
+typedef void(RTC_API *rtcLogCallbackFunc)(rtcLogLevel level, const char *message);
+typedef void(RTC_API *rtcDescriptionCallbackFunc)(int pc, const char *sdp, const char *type, void *ptr);
+typedef void(RTC_API *rtcCandidateCallbackFunc)(int pc, const char *cand, const char *mid, void *ptr);
+typedef void(RTC_API *rtcStateChangeCallbackFunc)(int pc, rtcState state, void *ptr);
+typedef void(RTC_API *rtcGatheringStateCallbackFunc)(int pc, rtcGatheringState state, void *ptr);
+typedef void(RTC_API *rtcSignalingStateCallbackFunc)(int pc, rtcSignalingState state, void *ptr);
+typedef void(RTC_API *rtcDataChannelCallbackFunc)(int pc, int dc, void *ptr);
+typedef void(RTC_API *rtcTrackCallbackFunc)(int pc, int tr, void *ptr);
+typedef void(RTC_API *rtcOpenCallbackFunc)(int id, void *ptr);
+typedef void(RTC_API *rtcClosedCallbackFunc)(int id, void *ptr);
+typedef void(RTC_API *rtcErrorCallbackFunc)(int id, const char *error, void *ptr);
+typedef void(RTC_API *rtcMessageCallbackFunc)(int id, const char *message, int size, void *ptr);
+typedef void(RTC_API *rtcBufferedAmountLowCallbackFunc)(int id, void *ptr);
+typedef void(RTC_API *rtcAvailableCallbackFunc)(int id, void *ptr);
 
 
 // Log
 // Log
 // NULL cb on the first call will log to stdout
 // NULL cb on the first call will log to stdout
@@ -139,17 +147,18 @@ RTC_EXPORT int rtcGetRemoteDescription(int pc, char *buffer, int size);
 RTC_EXPORT int rtcGetLocalAddress(int pc, char *buffer, int size);
 RTC_EXPORT int rtcGetLocalAddress(int pc, char *buffer, int size);
 RTC_EXPORT int rtcGetRemoteAddress(int pc, char *buffer, int size);
 RTC_EXPORT int rtcGetRemoteAddress(int pc, char *buffer, int size);
 
 
-RTC_EXPORT int rtcGetSelectedCandidatePair(int pc, char *local, int localSize, char *remote, int remoteSize);
+RTC_EXPORT int rtcGetSelectedCandidatePair(int pc, char *local, int localSize, char *remote,
+                                           int remoteSize);
 
 
 // DataChannel
 // DataChannel
 RTC_EXPORT int rtcSetDataChannelCallback(int pc, rtcDataChannelCallbackFunc cb);
 RTC_EXPORT int rtcSetDataChannelCallback(int pc, rtcDataChannelCallbackFunc cb);
 RTC_EXPORT int rtcAddDataChannel(int pc, const char *label); // returns dc id
 RTC_EXPORT int rtcAddDataChannel(int pc, const char *label); // returns dc id
-RTC_EXPORT int rtcAddDataChannelExt(int pc, const char *label, const char *protocol,
-                                    const rtcReliability *reliability); // returns dc id
+RTC_EXPORT int rtcAddDataChannelExt(int pc, const char *label,
+                                    const rtcDataChannelInit *init); // returns dc id
 // Equivalent to calling rtcAddDataChannel() and rtcSetLocalDescription()
 // Equivalent to calling rtcAddDataChannel() and rtcSetLocalDescription()
 RTC_EXPORT int rtcCreateDataChannel(int pc, const char *label); // returns dc id
 RTC_EXPORT int rtcCreateDataChannel(int pc, const char *label); // returns dc id
-RTC_EXPORT int rtcCreateDataChannelExt(int pc, const char *label, const char *protocol,
-                                       const rtcReliability *reliability); // returns dc id
+RTC_EXPORT int rtcCreateDataChannelExt(int pc, const char *label,
+                                       const rtcDataChannelInit *init); // returns dc id
 RTC_EXPORT int rtcDeleteDataChannel(int dc);
 RTC_EXPORT int rtcDeleteDataChannel(int dc);
 
 
 RTC_EXPORT int rtcGetDataChannelLabel(int dc, char *buffer, int size);
 RTC_EXPORT int rtcGetDataChannelLabel(int dc, char *buffer, int size);

+ 32 - 25
src/capi.cpp

@@ -305,43 +305,49 @@ int rtcDeletePeerConnection(int pc) {
 }
 }
 
 
 int rtcAddDataChannel(int pc, const char *label) {
 int rtcAddDataChannel(int pc, const char *label) {
-	return rtcAddDataChannelExt(pc, label, nullptr, nullptr);
+	return rtcAddDataChannelExt(pc, label, nullptr);
 }
 }
 
 
-int rtcAddDataChannelExt(int pc, const char *label, const char *protocol,
-                         const rtcReliability *reliability) {
+int rtcAddDataChannelExt(int pc, const char *label, const rtcDataChannelInit *init) {
 	return WRAP({
 	return WRAP({
-		Reliability r = {};
-		if (reliability) {
-			r.unordered = reliability->unordered;
+		DataChannelInit dci = {};
+		if (init) {
+			auto *reliability = &init->reliability;
+			dci.reliability.unordered = reliability->unordered;
 			if (reliability->unreliable) {
 			if (reliability->unreliable) {
 				if (reliability->maxPacketLifeTime > 0) {
 				if (reliability->maxPacketLifeTime > 0) {
-					r.type = Reliability::Type::Timed;
-					r.rexmit = milliseconds(reliability->maxPacketLifeTime);
+					dci.reliability.type = Reliability::Type::Timed;
+					dci.reliability.rexmit = milliseconds(reliability->maxPacketLifeTime);
 				} else {
 				} else {
-					r.type = Reliability::Type::Rexmit;
-					r.rexmit = int(reliability->maxRetransmits);
+					dci.reliability.type = Reliability::Type::Rexmit;
+					dci.reliability.rexmit = int(reliability->maxRetransmits);
 				}
 				}
 			} else {
 			} else {
-				r.type = Reliability::Type::Reliable;
+				dci.reliability.type = Reliability::Type::Reliable;
 			}
 			}
+
+			dci.negotiated = init->negotiated;
+			dci.id = init->manualId ? std::make_optional(init->id) : nullopt;
+			dci.protocol = init->protocol ? init->protocol : "";
 		}
 		}
+
 		auto peerConnection = getPeerConnection(pc);
 		auto peerConnection = getPeerConnection(pc);
-		int dc = emplaceDataChannel(peerConnection->addDataChannel(
-		    string(label ? label : ""), string(protocol ? protocol : ""), r));
+		int dc = emplaceDataChannel(
+		    peerConnection->addDataChannel(string(label ? label : ""), std::move(dci)));
+
 		if (auto ptr = getUserPointer(pc))
 		if (auto ptr = getUserPointer(pc))
 			rtcSetUserPointer(dc, *ptr);
 			rtcSetUserPointer(dc, *ptr);
+
 		return dc;
 		return dc;
 	});
 	});
 }
 }
 
 
 int rtcCreateDataChannel(int pc, const char *label) {
 int rtcCreateDataChannel(int pc, const char *label) {
-	return rtcCreateDataChannelExt(pc, label, nullptr, nullptr);
+	return rtcCreateDataChannelExt(pc, label, nullptr);
 }
 }
 
 
-int rtcCreateDataChannelExt(int pc, const char *label, const char *protocol,
-                            const rtcReliability *reliability) {
-	int dc = rtcAddDataChannelExt(pc, label, protocol, reliability);
+int rtcCreateDataChannelExt(int pc, const char *label, const rtcDataChannelInit *init) {
+	int dc = rtcAddDataChannelExt(pc, label, init);
 	rtcSetLocalDescription(pc, NULL);
 	rtcSetLocalDescription(pc, NULL);
 	return dc;
 	return dc;
 }
 }
@@ -370,6 +376,7 @@ int rtcAddTrack(int pc, const char *mediaDescriptionSdp) {
 		int tr = emplaceTrack(peerConnection->addTrack(std::move(media)));
 		int tr = emplaceTrack(peerConnection->addTrack(std::move(media)));
 		if (auto ptr = getUserPointer(pc))
 		if (auto ptr = getUserPointer(pc))
 			rtcSetUserPointer(tr, *ptr);
 			rtcSetUserPointer(tr, *ptr);
+
 		return tr;
 		return tr;
 	});
 	});
 }
 }
@@ -610,11 +617,11 @@ int rtcGetSelectedCandidatePair(int pc, char *local, int localSize, char *remote
 			return RTC_ERR_NOT_AVAIL;
 			return RTC_ERR_NOT_AVAIL;
 
 
 		int localRet = copyAndReturn(string(localCand), local, localSize);
 		int localRet = copyAndReturn(string(localCand), local, localSize);
-		if(localRet < 0)
+		if (localRet < 0)
 			return localRet;
 			return localRet;
 
 
 		int remoteRet = copyAndReturn(string(remoteCand), remote, remoteSize);
 		int remoteRet = copyAndReturn(string(remoteCand), remote, remoteSize);
-		if(remoteRet < 0)
+		if (remoteRet < 0)
 			return remoteRet;
 			return remoteRet;
 
 
 		return std::max(localRet, remoteRet);
 		return std::max(localRet, remoteRet);
@@ -642,15 +649,15 @@ int rtcGetDataChannelReliability(int dc, rtcReliability *reliability) {
 		if (!reliability)
 		if (!reliability)
 			throw std::invalid_argument("Unexpected null pointer for reliability");
 			throw std::invalid_argument("Unexpected null pointer for reliability");
 
 
-		Reliability r = dataChannel->reliability();
+		Reliability dcr = dataChannel->reliability();
 		std::memset(reliability, 0, sizeof(*reliability));
 		std::memset(reliability, 0, sizeof(*reliability));
-		reliability->unordered = r.unordered;
-		if (r.type == Reliability::Type::Timed) {
+		reliability->unordered = dcr.unordered;
+		if (dcr.type == Reliability::Type::Timed) {
 			reliability->unreliable = true;
 			reliability->unreliable = true;
-			reliability->maxPacketLifeTime = unsigned(std::get<milliseconds>(r.rexmit).count());
-		} else if (r.type == Reliability::Type::Rexmit) {
+			reliability->maxPacketLifeTime = unsigned(std::get<milliseconds>(dcr.rexmit).count());
+		} else if (dcr.type == Reliability::Type::Rexmit) {
 			reliability->unreliable = true;
 			reliability->unreliable = true;
-			reliability->maxRetransmits = unsigned(std::get<int>(r.rexmit));
+			reliability->maxRetransmits = unsigned(std::get<int>(dcr.rexmit));
 		} else {
 		} else {
 			reliability->unreliable = false;
 			reliability->unreliable = false;
 		}
 		}

+ 67 - 50
src/datachannel.cpp

@@ -72,24 +72,18 @@ struct CloseMessage {
 };
 };
 #pragma pack(pop)
 #pragma pack(pop)
 
 
-DataChannel::DataChannel(weak_ptr<PeerConnection> pc, unsigned int stream, string label,
+DataChannel::DataChannel(weak_ptr<PeerConnection> pc, uint16_t stream, string label,
                          string protocol, Reliability reliability)
                          string protocol, Reliability reliability)
     : mPeerConnection(pc), mStream(stream), mLabel(std::move(label)),
     : mPeerConnection(pc), mStream(stream), mLabel(std::move(label)),
       mProtocol(std::move(protocol)),
       mProtocol(std::move(protocol)),
       mReliability(std::make_shared<Reliability>(std::move(reliability))),
       mReliability(std::make_shared<Reliability>(std::move(reliability))),
       mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
       mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
 
 
-DataChannel::DataChannel(weak_ptr<PeerConnection> pc, weak_ptr<SctpTransport> transport,
-                         unsigned int stream)
-    : mPeerConnection(pc), mSctpTransport(transport), mStream(stream),
-      mReliability(std::make_shared<Reliability>()),
-      mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
+DataChannel::~DataChannel() { close(); }
 
 
-DataChannel::~DataChannel() {
-	close();
-}
+uint16_t DataChannel::stream() const { return mStream; }
 
 
-unsigned int DataChannel::stream() const { return mStream; }
+uint16_t DataChannel::id() const { return uint16_t(mStream); }
 
 
 string DataChannel::label() const { return mLabel; }
 string DataChannel::label() const { return mLabel; }
 
 
@@ -151,7 +145,6 @@ std::optional<message_variant> DataChannel::peek() {
 	return nullopt;
 	return nullopt;
 }
 }
 
 
-
 bool DataChannel::isOpen(void) const { return mIsOpen; }
 bool DataChannel::isOpen(void) const { return mIsOpen; }
 
 
 bool DataChannel::isClosed(void) const { return mIsClosed; }
 bool DataChannel::isClosed(void) const { return mIsClosed; }
@@ -172,43 +165,12 @@ size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); }
 void DataChannel::open(shared_ptr<SctpTransport> transport) {
 void DataChannel::open(shared_ptr<SctpTransport> transport) {
 	mSctpTransport = transport;
 	mSctpTransport = transport;
 
 
-	uint8_t channelType;
-	uint32_t reliabilityParameter;
-	switch (mReliability->type) {
-	case Reliability::Type::Rexmit:
-		channelType = CHANNEL_PARTIAL_RELIABLE_REXMIT;
-		reliabilityParameter = uint32_t(std::get<int>(mReliability->rexmit));
-		break;
-
-	case Reliability::Type::Timed:
-		channelType = CHANNEL_PARTIAL_RELIABLE_TIMED;
-		reliabilityParameter = uint32_t(std::get<milliseconds>(mReliability->rexmit).count());
-		break;
-
-	default:
-		channelType = CHANNEL_RELIABLE;
-		reliabilityParameter = 0;
-		break;
-	}
-
-	if (mReliability->unordered)
-		channelType |= 0x80;
-
-	const size_t len = sizeof(OpenMessage) + mLabel.size() + mProtocol.size();
-	binary buffer(len, byte(0));
-	auto &open = *reinterpret_cast<OpenMessage *>(buffer.data());
-	open.type = MESSAGE_OPEN;
-	open.channelType = channelType;
-	open.priority = htons(0);
-	open.reliabilityParameter = htonl(reliabilityParameter);
-	open.labelLength = htons(uint16_t(mLabel.size()));
-	open.protocolLength = htons(uint16_t(mProtocol.size()));
-
-	auto end = reinterpret_cast<char *>(buffer.data() + sizeof(OpenMessage));
-	std::copy(mLabel.begin(), mLabel.end(), end);
-	std::copy(mProtocol.begin(), mProtocol.end(), end + mLabel.size());
+	if (!mIsOpen.exchange(true))
+		triggerOpen();
+}
 
 
-	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
+void DataChannel::processOpenMessage(message_ptr) {
+	PLOG_WARNING << "Received an open message for a user-negotiated DataChannel, ignoring";
 }
 }
 
 
 bool DataChannel::outgoing(message_ptr message) {
 bool DataChannel::outgoing(message_ptr message) {
@@ -268,7 +230,62 @@ void DataChannel::incoming(message_ptr message) {
 	}
 	}
 }
 }
 
 
-void DataChannel::processOpenMessage(message_ptr message) {
+NegociatedDataChannel::NegociatedDataChannel(std::weak_ptr<PeerConnection> pc, uint16_t stream,
+                                             string label, string protocol, Reliability reliability)
+    : DataChannel(pc, stream, std::move(label), std::move(protocol), std::move(reliability)) {}
+
+NegociatedDataChannel::NegociatedDataChannel(std::weak_ptr<PeerConnection> pc,
+                                             std::weak_ptr<SctpTransport> transport,
+                                             uint16_t stream)
+    : DataChannel(pc, stream, "", "", {}) {
+	mSctpTransport = transport;
+}
+
+NegociatedDataChannel::~NegociatedDataChannel() {}
+
+void NegociatedDataChannel::open(shared_ptr<SctpTransport> transport) {
+	mSctpTransport = transport;
+
+	uint8_t channelType;
+	uint32_t reliabilityParameter;
+	switch (mReliability->type) {
+	case Reliability::Type::Rexmit:
+		channelType = CHANNEL_PARTIAL_RELIABLE_REXMIT;
+		reliabilityParameter = uint32_t(std::get<int>(mReliability->rexmit));
+		break;
+
+	case Reliability::Type::Timed:
+		channelType = CHANNEL_PARTIAL_RELIABLE_TIMED;
+		reliabilityParameter = uint32_t(std::get<milliseconds>(mReliability->rexmit).count());
+		break;
+
+	default:
+		channelType = CHANNEL_RELIABLE;
+		reliabilityParameter = 0;
+		break;
+	}
+
+	if (mReliability->unordered)
+		channelType |= 0x80;
+
+	const size_t len = sizeof(OpenMessage) + mLabel.size() + mProtocol.size();
+	binary buffer(len, byte(0));
+	auto &open = *reinterpret_cast<OpenMessage *>(buffer.data());
+	open.type = MESSAGE_OPEN;
+	open.channelType = channelType;
+	open.priority = htons(0);
+	open.reliabilityParameter = htonl(reliabilityParameter);
+	open.labelLength = htons(uint16_t(mLabel.size()));
+	open.protocolLength = htons(uint16_t(mProtocol.size()));
+
+	auto end = reinterpret_cast<char *>(buffer.data() + sizeof(OpenMessage));
+	std::copy(mLabel.begin(), mLabel.end(), end);
+	std::copy(mProtocol.begin(), mProtocol.end(), end + mLabel.size());
+
+	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
+}
+
+void NegociatedDataChannel::processOpenMessage(message_ptr message) {
 	auto transport = mSctpTransport.lock();
 	auto transport = mSctpTransport.lock();
 	if (!transport)
 	if (!transport)
 		throw std::runtime_error("DataChannel has no transport");
 		throw std::runtime_error("DataChannel has no transport");
@@ -310,8 +327,8 @@ void DataChannel::processOpenMessage(message_ptr message) {
 
 
 	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
 	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
 
 
-	mIsOpen = true;
-	triggerOpen();
+	if (!mIsOpen.exchange(true))
+		triggerOpen();
 }
 }
 
 
 } // namespace rtc
 } // namespace rtc

+ 39 - 28
src/peerconnection.cpp

@@ -38,11 +38,11 @@
 namespace {
 namespace {
 
 
 template <typename To, typename From>
 template <typename To, typename From>
-inline std::shared_ptr<To> reinterpret_pointer_cast(std::shared_ptr<From> const & ptr) noexcept {
-    return std::shared_ptr<To>(ptr, reinterpret_cast<To *>(ptr.get()));
+inline std::shared_ptr<To> reinterpret_pointer_cast(std::shared_ptr<From> const &ptr) noexcept {
+	return std::shared_ptr<To>(ptr, reinterpret_cast<To *>(ptr.get()));
 }
 }
 
 
-}
+} // namespace
 #else
 #else
 using std::reinterpret_pointer_cast;
 using std::reinterpret_pointer_cast;
 #endif
 #endif
@@ -325,8 +325,7 @@ std::optional<string> PeerConnection::remoteAddress() const {
 	return iceTransport ? iceTransport->getRemoteAddress() : nullopt;
 	return iceTransport ? iceTransport->getRemoteAddress() : nullopt;
 }
 }
 
 
-shared_ptr<DataChannel> PeerConnection::addDataChannel(string label, string protocol,
-                                                       Reliability reliability) {
+shared_ptr<DataChannel> PeerConnection::addDataChannel(string label, DataChannelInit init) {
 	// RFC 5763: The answerer MUST use either a setup attribute value of setup:active or
 	// RFC 5763: The answerer MUST use either a setup attribute value of setup:active or
 	// setup:passive. [...] Thus, setup:active is RECOMMENDED.
 	// setup:passive. [...] Thus, setup:active is RECOMMENDED.
 	// See https://tools.ietf.org/html/rfc5763#section-5
 	// See https://tools.ietf.org/html/rfc5763#section-5
@@ -334,8 +333,7 @@ shared_ptr<DataChannel> PeerConnection::addDataChannel(string label, string prot
 	auto iceTransport = std::atomic_load(&mIceTransport);
 	auto iceTransport = std::atomic_load(&mIceTransport);
 	auto role = iceTransport ? iceTransport->role() : Description::Role::Passive;
 	auto role = iceTransport ? iceTransport->role() : Description::Role::Passive;
 
 
-	auto channel =
-	    emplaceDataChannel(role, std::move(label), std::move(protocol), std::move(reliability));
+	auto channel = emplaceDataChannel(role, std::move(label), std::move(init));
 
 
 	if (auto transport = std::atomic_load(&mSctpTransport))
 	if (auto transport = std::atomic_load(&mSctpTransport))
 		if (transport->state() == SctpTransport::State::Connected)
 		if (transport->state() == SctpTransport::State::Connected)
@@ -349,9 +347,8 @@ shared_ptr<DataChannel> PeerConnection::addDataChannel(string label, string prot
 	return channel;
 	return channel;
 }
 }
 
 
-shared_ptr<DataChannel> PeerConnection::createDataChannel(string label, string protocol,
-                                                          Reliability reliability) {
-	auto channel = addDataChannel(label, protocol, reliability);
+shared_ptr<DataChannel> PeerConnection::createDataChannel(string label, DataChannelInit init) {
+	auto channel = addDataChannel(std::move(label), std::move(init));
 	setLocalDescription();
 	setLocalDescription();
 	return channel;
 	return channel;
 }
 }
@@ -648,7 +645,8 @@ void PeerConnection::forwardMessage(message_ptr message) {
 		return;
 		return;
 	}
 	}
 
 
-	auto channel = findDataChannel(uint16_t(message->stream));
+	uint16_t stream = uint16_t(message->stream);
+	auto channel = findDataChannel(stream);
 	if (!channel) {
 	if (!channel) {
 		auto iceTransport = std::atomic_load(&mIceTransport);
 		auto iceTransport = std::atomic_load(&mIceTransport);
 		auto sctpTransport = std::atomic_load(&mSctpTransport);
 		auto sctpTransport = std::atomic_load(&mSctpTransport);
@@ -656,15 +654,15 @@ void PeerConnection::forwardMessage(message_ptr message) {
 			return;
 			return;
 
 
 		const byte dataChannelOpenMessage{0x03};
 		const byte dataChannelOpenMessage{0x03};
-		unsigned int remoteParity = (iceTransport->role() == Description::Role::Active) ? 1 : 0;
+		uint16_t remoteParity = (iceTransport->role() == Description::Role::Active) ? 1 : 0;
 		if (message->type == Message::Control && *message->data() == dataChannelOpenMessage &&
 		if (message->type == Message::Control && *message->data() == dataChannelOpenMessage &&
-		    message->stream % 2 == remoteParity) {
+		    stream % 2 == remoteParity) {
 
 
-			channel =
-			    std::make_shared<DataChannel>(shared_from_this(), sctpTransport, message->stream);
+			channel = std::make_shared<NegociatedDataChannel>(shared_from_this(), sctpTransport,
+			                                                  message->stream);
 			channel->onOpen(weak_bind(&PeerConnection::triggerDataChannel, this,
 			channel->onOpen(weak_bind(&PeerConnection::triggerDataChannel, this,
 			                          weak_ptr<DataChannel>{channel}));
 			                          weak_ptr<DataChannel>{channel}));
-			mDataChannels.insert(std::make_pair(message->stream, channel));
+			mDataChannels.emplace(message->stream, channel);
 		} else {
 		} else {
 			// Invalid, close the DataChannel
 			// Invalid, close the DataChannel
 			sctpTransport->closeStream(message->stream);
 			sctpTransport->closeStream(message->stream);
@@ -734,20 +732,33 @@ void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
 }
 }
 
 
 shared_ptr<DataChannel> PeerConnection::emplaceDataChannel(Description::Role role, string label,
 shared_ptr<DataChannel> PeerConnection::emplaceDataChannel(Description::Role role, string label,
-                                                           string protocol,
-                                                           Reliability reliability) {
-	// The active side must use streams with even identifiers, whereas the passive side must use
-	// streams with odd identifiers.
-	// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-6
+                                                           DataChannelInit init) {
 	std::unique_lock lock(mDataChannelsMutex); // we are going to emplace
 	std::unique_lock lock(mDataChannelsMutex); // we are going to emplace
-	unsigned int stream = (role == Description::Role::Active) ? 0 : 1;
-	while (mDataChannels.find(stream) != mDataChannels.end()) {
-		stream += 2;
-		if (stream >= 65535)
-			throw std::runtime_error("Too many DataChannels");
+	uint16_t stream;
+	if (init.id) {
+		stream = *init.id;
+		if (stream == 65535)
+			throw std::invalid_argument("Invalid DataChannel id");
+	} else {
+		// The active side must use streams with even identifiers, whereas the passive side must use
+		// streams with odd identifiers.
+		// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-6
+		stream = (role == Description::Role::Active) ? 0 : 1;
+		while (mDataChannels.find(stream) != mDataChannels.end()) {
+			if (stream >= 65535 - 2)
+				throw std::runtime_error("Too many DataChannels");
+
+			stream += 2;
+		}
 	}
 	}
-	auto channel = std::make_shared<DataChannel>(shared_from_this(), stream, std::move(label),
-	                                             std::move(protocol), std::move(reliability));
+	// If the DataChannel is user-negotiated, do not negociate it here
+	auto channel =
+	    init.negotiated
+	        ? std::make_shared<DataChannel>(shared_from_this(), stream, std::move(label),
+	                                        std::move(init.protocol), std::move(init.reliability))
+	        : std::make_shared<NegociatedDataChannel>(shared_from_this(), stream, std::move(label),
+	                                                  std::move(init.protocol),
+	                                                  std::move(init.reliability));
 	mDataChannels.emplace(std::make_pair(stream, channel));
 	mDataChannels.emplace(std::make_pair(stream, channel));
 	return channel;
 	return channel;
 }
 }

+ 33 - 2
test/connectivity.cpp

@@ -156,11 +156,11 @@ void test_connectivity() {
 		cout << "Remote address 2: " << *addr << endl;
 		cout << "Remote address 2: " << *addr << endl;
 
 
 	Candidate local, remote;
 	Candidate local, remote;
-	if(pc1->getSelectedCandidatePair(&local, &remote)) {
+	if (pc1->getSelectedCandidatePair(&local, &remote)) {
 		cout << "Local candidate 1:  " << local << endl;
 		cout << "Local candidate 1:  " << local << endl;
 		cout << "Remote candidate 1: " << remote << endl;
 		cout << "Remote candidate 1: " << remote << endl;
 	}
 	}
-	if(pc2->getSelectedCandidatePair(&local, &remote)) {
+	if (pc2->getSelectedCandidatePair(&local, &remote)) {
 		cout << "Local candidate 2:  " << local << endl;
 		cout << "Local candidate 2:  " << local << endl;
 		cout << "Remote candidate 2: " << remote << endl;
 		cout << "Remote candidate 2: " << remote << endl;
 	}
 	}
@@ -208,6 +208,37 @@ void test_connectivity() {
 	    attempts--)
 	    attempts--)
 		this_thread::sleep_for(1s);
 		this_thread::sleep_for(1s);
 
 
+	if (!asecond2 || !asecond2->isOpen() || !second1->isOpen())
+		throw runtime_error("Second DataChannel is not open");
+
+	// Try to open a negotiated channel
+	DataChannelInit init;
+	init.negotiated = true;
+	init.id = 42;
+	auto negotiated1 = pc1->createDataChannel("negotiated", init);
+	auto negotiated2 = pc2->createDataChannel("negoctated", init);
+
+	if (!negotiated1->isOpen() || !negotiated2->isOpen())
+		throw runtime_error("Negociated DataChannel is not open");
+
+	std::atomic<bool> received = false;
+	negotiated2->onMessage([&received](const variant<binary, string> &message) {
+		if (holds_alternative<string>(message)) {
+			cout << "Second Message 2: " << get<string>(message) << endl;
+			received = true;
+		}
+	});
+
+	negotiated1->send("Hello from negotiated channel");
+
+	// Wait a bit
+	attempts = 5;
+	while (!received && attempts--)
+		this_thread::sleep_for(1s);
+
+	if (!received)
+		throw runtime_error("Negociated DataChannel failed");
+
 	// Delay close of peer 2 to check closing works properly
 	// Delay close of peer 2 to check closing works properly
 	pc1->close();
 	pc1->close();
 	this_thread::sleep_for(1s);
 	this_thread::sleep_for(1s);

+ 1 - 1
test/track.cpp

@@ -133,7 +133,7 @@ void test_track() {
 		this_thread::sleep_for(1s);
 		this_thread::sleep_for(1s);
 
 
 	if (!at2 || !at2->isOpen() || !t1->isOpen())
 	if (!at2 || !at2->isOpen() || !t1->isOpen())
-		throw runtime_error("Renegociated track is not open");
+		throw runtime_error("Renegotiated track is not open");
 
 
 	// TODO: Test sending RTP packets in track
 	// TODO: Test sending RTP packets in track