Browse Source

Implemented user-negotiated DataChannels

Paul-Louis Ageneau 5 years ago
parent
commit
37ebe8cc58
5 changed files with 128 additions and 88 deletions
  1. 21 7
      include/rtc/datachannel.hpp
  2. 10 6
      include/rtc/peerconnection.hpp
  3. 6 3
      src/capi.cpp
  4. 64 49
      src/datachannel.cpp
  5. 27 23
      src/peerconnection.cpp

+ 21 - 7
include/rtc/datachannel.hpp

@@ -36,13 +36,11 @@ namespace rtc {
 class SctpTransport;
 class SctpTransport;
 class PeerConnection;
 class PeerConnection;
 
 
-class DataChannel final : public std::enable_shared_from_this<DataChannel>, public Channel {
+class DataChannel : public std::enable_shared_from_this<DataChannel>, public Channel {
 public:
 public:
 	DataChannel(std::weak_ptr<PeerConnection> pc, unsigned int stream, string label,
 	DataChannel(std::weak_ptr<PeerConnection> pc, unsigned int stream, string label,
 	            string protocol, Reliability reliability);
 	            string protocol, Reliability reliability);
-	DataChannel(std::weak_ptr<PeerConnection> pc, std::weak_ptr<SctpTransport> transport,
-	            unsigned int stream);
-	~DataChannel();
+	virtual ~DataChannel();
 
 
 	unsigned int stream() const;
 	unsigned int stream() const;
 	string label() const;
 	string label() const;
@@ -64,12 +62,12 @@ public:
 	std::optional<message_variant> receive() override;
 	std::optional<message_variant> receive() override;
 	std::optional<message_variant> peek() override;
 	std::optional<message_variant> peek() override;
 
 
-private:
+protected:
+	virtual void open(std::shared_ptr<SctpTransport> transport);
+	virtual void processOpenMessage(message_ptr message);
 	void remoteClose();
 	void remoteClose();
-	void open(std::shared_ptr<SctpTransport> transport);
 	bool outgoing(message_ptr message);
 	bool outgoing(message_ptr message);
 	void incoming(message_ptr message);
 	void incoming(message_ptr message);
-	void processOpenMessage(message_ptr message);
 
 
 	const std::weak_ptr<PeerConnection> mPeerConnection;
 	const std::weak_ptr<PeerConnection> mPeerConnection;
 	std::weak_ptr<SctpTransport> mSctpTransport;
 	std::weak_ptr<SctpTransport> mSctpTransport;
@@ -82,11 +80,27 @@ private:
 	std::atomic<bool> mIsOpen = false;
 	std::atomic<bool> mIsOpen = false;
 	std::atomic<bool> mIsClosed = false;
 	std::atomic<bool> mIsClosed = false;
 
 
+private:
 	Queue<message_ptr> mRecvQueue;
 	Queue<message_ptr> mRecvQueue;
 
 
 	friend class PeerConnection;
 	friend class PeerConnection;
 };
 };
 
 
+class NegociatedDataChannel final : public DataChannel {
+public:
+	NegociatedDataChannel(std::weak_ptr<PeerConnection> pc, unsigned int stream, string label,
+	            string protocol, Reliability reliability);
+	NegociatedDataChannel(std::weak_ptr<PeerConnection> pc, std::weak_ptr<SctpTransport> transport,
+	            unsigned int stream);
+	~NegociatedDataChannel();
+
+private:
+	void open(std::shared_ptr<SctpTransport> transport) override;
+	void processOpenMessage(message_ptr message) override;
+
+	friend class PeerConnection;
+};
+
 template <typename Buffer> std::pair<const byte *, size_t> to_bytes(const Buffer &buf) {
 template <typename Buffer> std::pair<const byte *, size_t> to_bytes(const Buffer &buf) {
 	using T = typename std::remove_pointer<decltype(buf.data())>::type;
 	using T = typename std::remove_pointer<decltype(buf.data())>::type;
 	using E = typename std::conditional<std::is_void<T>::value, byte, T>::type;
 	using E = typename std::conditional<std::is_void<T>::value, byte, T>::type;

+ 10 - 6
include/rtc/peerconnection.hpp

@@ -50,6 +50,13 @@ class SctpTransport;
 using certificate_ptr = std::shared_ptr<Certificate>;
 using certificate_ptr = std::shared_ptr<Certificate>;
 using future_certificate_ptr = std::shared_future<certificate_ptr>;
 using future_certificate_ptr = std::shared_future<certificate_ptr>;
 
 
+struct DataChannelInit {
+	Reliability reliability = {};
+	string protocol = "";
+	bool negociated = false;
+	std::optional<uint16_t> id = nullopt;
+};
+
 class PeerConnection final : public std::enable_shared_from_this<PeerConnection> {
 class PeerConnection final : public std::enable_shared_from_this<PeerConnection> {
 public:
 public:
 	enum class State : int {
 	enum class State : int {
@@ -98,12 +105,10 @@ public:
 	void setRemoteDescription(Description description);
 	void setRemoteDescription(Description description);
 	void addRemoteCandidate(Candidate candidate);
 	void addRemoteCandidate(Candidate candidate);
 
 
-	std::shared_ptr<DataChannel> addDataChannel(string label, string protocol = "",
-	                                            Reliability reliability = {});
+	std::shared_ptr<DataChannel> addDataChannel(string label, DataChannelInit init = {});
 
 
 	// Equivalent to calling addDataChannel() and setLocalDescription()
 	// Equivalent to calling addDataChannel() and setLocalDescription()
-	std::shared_ptr<DataChannel> createDataChannel(string label, string protocol = "",
-	                                               Reliability reliability = {});
+	std::shared_ptr<DataChannel> createDataChannel(string label, DataChannelInit init = {});
 
 
 	void onDataChannel(std::function<void(std::shared_ptr<DataChannel> dataChannel)> callback);
 	void onDataChannel(std::function<void(std::shared_ptr<DataChannel> dataChannel)> callback);
 	void onLocalDescription(std::function<void(Description description)> callback);
 	void onLocalDescription(std::function<void(Description description)> callback);
@@ -135,8 +140,7 @@ private:
 	void forwardBufferedAmount(uint16_t stream, size_t amount);
 	void forwardBufferedAmount(uint16_t stream, size_t amount);
 
 
 	std::shared_ptr<DataChannel> emplaceDataChannel(Description::Role role, string label,
 	std::shared_ptr<DataChannel> emplaceDataChannel(Description::Role role, string label,
-	                                                string protocol, Reliability reliability,
-	                                                std::optional<unsigned int> stream = nullopt);
+	                                                DataChannelInit init);
 	std::shared_ptr<DataChannel> findDataChannel(uint16_t stream);
 	std::shared_ptr<DataChannel> findDataChannel(uint16_t stream);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void openDataChannels();
 	void openDataChannels();

+ 6 - 3
src/capi.cpp

@@ -328,7 +328,10 @@ int rtcAddDataChannelExt(int pc, const char *label, const char *protocol,
 		}
 		}
 		auto peerConnection = getPeerConnection(pc);
 		auto peerConnection = getPeerConnection(pc);
 		int dc = emplaceDataChannel(peerConnection->addDataChannel(
 		int dc = emplaceDataChannel(peerConnection->addDataChannel(
-		    string(label ? label : ""), string(protocol ? protocol : ""), r));
+		    string(label ? label : ""), {.reliability = std::move(r),
+		                                 .protocol = protocol ? protocol : "",
+		                                 .negociated = false,
+		                                 .id = nullopt}));
 		if (auto ptr = getUserPointer(pc))
 		if (auto ptr = getUserPointer(pc))
 			rtcSetUserPointer(dc, *ptr);
 			rtcSetUserPointer(dc, *ptr);
 		return dc;
 		return dc;
@@ -610,11 +613,11 @@ int rtcGetSelectedCandidatePair(int pc, char *local, int localSize, char *remote
 			return RTC_ERR_NOT_AVAIL;
 			return RTC_ERR_NOT_AVAIL;
 
 
 		int localRet = copyAndReturn(string(localCand), local, localSize);
 		int localRet = copyAndReturn(string(localCand), local, localSize);
-		if(localRet < 0)
+		if (localRet < 0)
 			return localRet;
 			return localRet;
 
 
 		int remoteRet = copyAndReturn(string(remoteCand), remote, remoteSize);
 		int remoteRet = copyAndReturn(string(remoteCand), remote, remoteSize);
-		if(remoteRet < 0)
+		if (remoteRet < 0)
 			return remoteRet;
 			return remoteRet;
 
 
 		return std::max(localRet, remoteRet);
 		return std::max(localRet, remoteRet);

+ 64 - 49
src/datachannel.cpp

@@ -79,15 +79,7 @@ DataChannel::DataChannel(weak_ptr<PeerConnection> pc, unsigned int stream, strin
       mReliability(std::make_shared<Reliability>(std::move(reliability))),
       mReliability(std::make_shared<Reliability>(std::move(reliability))),
       mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
       mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
 
 
-DataChannel::DataChannel(weak_ptr<PeerConnection> pc, weak_ptr<SctpTransport> transport,
-                         unsigned int stream)
-    : mPeerConnection(pc), mSctpTransport(transport), mStream(stream),
-      mReliability(std::make_shared<Reliability>()),
-      mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
-
-DataChannel::~DataChannel() {
-	close();
-}
+DataChannel::~DataChannel() { close(); }
 
 
 unsigned int DataChannel::stream() const { return mStream; }
 unsigned int DataChannel::stream() const { return mStream; }
 
 
@@ -151,7 +143,6 @@ std::optional<message_variant> DataChannel::peek() {
 	return nullopt;
 	return nullopt;
 }
 }
 
 
-
 bool DataChannel::isOpen(void) const { return mIsOpen; }
 bool DataChannel::isOpen(void) const { return mIsOpen; }
 
 
 bool DataChannel::isClosed(void) const { return mIsClosed; }
 bool DataChannel::isClosed(void) const { return mIsClosed; }
@@ -172,43 +163,12 @@ size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); }
 void DataChannel::open(shared_ptr<SctpTransport> transport) {
 void DataChannel::open(shared_ptr<SctpTransport> transport) {
 	mSctpTransport = transport;
 	mSctpTransport = transport;
 
 
-	uint8_t channelType;
-	uint32_t reliabilityParameter;
-	switch (mReliability->type) {
-	case Reliability::Type::Rexmit:
-		channelType = CHANNEL_PARTIAL_RELIABLE_REXMIT;
-		reliabilityParameter = uint32_t(std::get<int>(mReliability->rexmit));
-		break;
-
-	case Reliability::Type::Timed:
-		channelType = CHANNEL_PARTIAL_RELIABLE_TIMED;
-		reliabilityParameter = uint32_t(std::get<milliseconds>(mReliability->rexmit).count());
-		break;
-
-	default:
-		channelType = CHANNEL_RELIABLE;
-		reliabilityParameter = 0;
-		break;
-	}
-
-	if (mReliability->unordered)
-		channelType |= 0x80;
-
-	const size_t len = sizeof(OpenMessage) + mLabel.size() + mProtocol.size();
-	binary buffer(len, byte(0));
-	auto &open = *reinterpret_cast<OpenMessage *>(buffer.data());
-	open.type = MESSAGE_OPEN;
-	open.channelType = channelType;
-	open.priority = htons(0);
-	open.reliabilityParameter = htonl(reliabilityParameter);
-	open.labelLength = htons(uint16_t(mLabel.size()));
-	open.protocolLength = htons(uint16_t(mProtocol.size()));
-
-	auto end = reinterpret_cast<char *>(buffer.data() + sizeof(OpenMessage));
-	std::copy(mLabel.begin(), mLabel.end(), end);
-	std::copy(mProtocol.begin(), mProtocol.end(), end + mLabel.size());
+	if (!mIsOpen.exchange(true))
+		triggerOpen();
+}
 
 
-	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
+void DataChannel::processOpenMessage(message_ptr) {
+	PLOG_WARNING << "Received an open message for a user-negociated DataChannel, ignoring";
 }
 }
 
 
 bool DataChannel::outgoing(message_ptr message) {
 bool DataChannel::outgoing(message_ptr message) {
@@ -268,7 +228,62 @@ void DataChannel::incoming(message_ptr message) {
 	}
 	}
 }
 }
 
 
-void DataChannel::processOpenMessage(message_ptr message) {
+NegociatedDataChannel::NegociatedDataChannel(std::weak_ptr<PeerConnection> pc, unsigned int stream,
+                                             string label, string protocol, Reliability reliability)
+    : DataChannel(pc, stream, std::move(label), std::move(protocol), std::move(reliability)) {}
+
+NegociatedDataChannel::NegociatedDataChannel(std::weak_ptr<PeerConnection> pc,
+                                             std::weak_ptr<SctpTransport> transport,
+                                             unsigned int stream)
+    : DataChannel(pc, stream, "", "", {}) {
+	mSctpTransport = transport;
+}
+
+NegociatedDataChannel::~NegociatedDataChannel() {}
+
+void NegociatedDataChannel::open(shared_ptr<SctpTransport> transport) {
+	mSctpTransport = transport;
+
+	uint8_t channelType;
+	uint32_t reliabilityParameter;
+	switch (mReliability->type) {
+	case Reliability::Type::Rexmit:
+		channelType = CHANNEL_PARTIAL_RELIABLE_REXMIT;
+		reliabilityParameter = uint32_t(std::get<int>(mReliability->rexmit));
+		break;
+
+	case Reliability::Type::Timed:
+		channelType = CHANNEL_PARTIAL_RELIABLE_TIMED;
+		reliabilityParameter = uint32_t(std::get<milliseconds>(mReliability->rexmit).count());
+		break;
+
+	default:
+		channelType = CHANNEL_RELIABLE;
+		reliabilityParameter = 0;
+		break;
+	}
+
+	if (mReliability->unordered)
+		channelType |= 0x80;
+
+	const size_t len = sizeof(OpenMessage) + mLabel.size() + mProtocol.size();
+	binary buffer(len, byte(0));
+	auto &open = *reinterpret_cast<OpenMessage *>(buffer.data());
+	open.type = MESSAGE_OPEN;
+	open.channelType = channelType;
+	open.priority = htons(0);
+	open.reliabilityParameter = htonl(reliabilityParameter);
+	open.labelLength = htons(uint16_t(mLabel.size()));
+	open.protocolLength = htons(uint16_t(mProtocol.size()));
+
+	auto end = reinterpret_cast<char *>(buffer.data() + sizeof(OpenMessage));
+	std::copy(mLabel.begin(), mLabel.end(), end);
+	std::copy(mProtocol.begin(), mProtocol.end(), end + mLabel.size());
+
+	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
+}
+
+void NegociatedDataChannel::processOpenMessage(message_ptr message) {
 	auto transport = mSctpTransport.lock();
 	auto transport = mSctpTransport.lock();
 	if (!transport)
 	if (!transport)
 		throw std::runtime_error("DataChannel has no transport");
 		throw std::runtime_error("DataChannel has no transport");
@@ -310,8 +325,8 @@ void DataChannel::processOpenMessage(message_ptr message) {
 
 
 	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
 	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
 
 
-	mIsOpen = true;
-	triggerOpen();
+	if (!mIsOpen.exchange(true))
+		triggerOpen();
 }
 }
 
 
 } // namespace rtc
 } // namespace rtc

+ 27 - 23
src/peerconnection.cpp

@@ -38,11 +38,11 @@
 namespace {
 namespace {
 
 
 template <typename To, typename From>
 template <typename To, typename From>
-inline std::shared_ptr<To> reinterpret_pointer_cast(std::shared_ptr<From> const & ptr) noexcept {
-    return std::shared_ptr<To>(ptr, reinterpret_cast<To *>(ptr.get()));
+inline std::shared_ptr<To> reinterpret_pointer_cast(std::shared_ptr<From> const &ptr) noexcept {
+	return std::shared_ptr<To>(ptr, reinterpret_cast<To *>(ptr.get()));
 }
 }
 
 
-}
+} // namespace
 #else
 #else
 using std::reinterpret_pointer_cast;
 using std::reinterpret_pointer_cast;
 #endif
 #endif
@@ -325,8 +325,7 @@ std::optional<string> PeerConnection::remoteAddress() const {
 	return iceTransport ? iceTransport->getRemoteAddress() : nullopt;
 	return iceTransport ? iceTransport->getRemoteAddress() : nullopt;
 }
 }
 
 
-shared_ptr<DataChannel> PeerConnection::addDataChannel(string label, string protocol,
-                                                       Reliability reliability) {
+shared_ptr<DataChannel> PeerConnection::addDataChannel(string label, DataChannelInit init) {
 	// RFC 5763: The answerer MUST use either a setup attribute value of setup:active or
 	// RFC 5763: The answerer MUST use either a setup attribute value of setup:active or
 	// setup:passive. [...] Thus, setup:active is RECOMMENDED.
 	// setup:passive. [...] Thus, setup:active is RECOMMENDED.
 	// See https://tools.ietf.org/html/rfc5763#section-5
 	// See https://tools.ietf.org/html/rfc5763#section-5
@@ -334,8 +333,7 @@ shared_ptr<DataChannel> PeerConnection::addDataChannel(string label, string prot
 	auto iceTransport = std::atomic_load(&mIceTransport);
 	auto iceTransport = std::atomic_load(&mIceTransport);
 	auto role = iceTransport ? iceTransport->role() : Description::Role::Passive;
 	auto role = iceTransport ? iceTransport->role() : Description::Role::Passive;
 
 
-	auto channel =
-	    emplaceDataChannel(role, std::move(label), std::move(protocol), std::move(reliability));
+	auto channel = emplaceDataChannel(role, std::move(label), std::move(init));
 
 
 	if (auto transport = std::atomic_load(&mSctpTransport))
 	if (auto transport = std::atomic_load(&mSctpTransport))
 		if (transport->state() == SctpTransport::State::Connected)
 		if (transport->state() == SctpTransport::State::Connected)
@@ -349,9 +347,8 @@ shared_ptr<DataChannel> PeerConnection::addDataChannel(string label, string prot
 	return channel;
 	return channel;
 }
 }
 
 
-shared_ptr<DataChannel> PeerConnection::createDataChannel(string label, string protocol,
-                                                          Reliability reliability) {
-	auto channel = addDataChannel(label, protocol, reliability);
+shared_ptr<DataChannel> PeerConnection::createDataChannel(string label, DataChannelInit init) {
+	auto channel = addDataChannel(std::move(label), std::move(init));
 	setLocalDescription();
 	setLocalDescription();
 	return channel;
 	return channel;
 }
 }
@@ -660,11 +657,11 @@ void PeerConnection::forwardMessage(message_ptr message) {
 		if (message->type == Message::Control && *message->data() == dataChannelOpenMessage &&
 		if (message->type == Message::Control && *message->data() == dataChannelOpenMessage &&
 		    message->stream % 2 == remoteParity) {
 		    message->stream % 2 == remoteParity) {
 
 
-			channel =
-			    std::make_shared<DataChannel>(shared_from_this(), sctpTransport, message->stream);
+			channel = std::make_shared<NegociatedDataChannel>(shared_from_this(), sctpTransport,
+			                                                  message->stream);
 			channel->onOpen(weak_bind(&PeerConnection::triggerDataChannel, this,
 			channel->onOpen(weak_bind(&PeerConnection::triggerDataChannel, this,
 			                          weak_ptr<DataChannel>{channel}));
 			                          weak_ptr<DataChannel>{channel}));
-			mDataChannels.insert(std::make_pair(message->stream, channel));
+			mDataChannels.emplace(message->stream, channel);
 		} else {
 		} else {
 			// Invalid, close the DataChannel
 			// Invalid, close the DataChannel
 			sctpTransport->closeStream(message->stream);
 			sctpTransport->closeStream(message->stream);
@@ -734,23 +731,30 @@ void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
 }
 }
 
 
 shared_ptr<DataChannel> PeerConnection::emplaceDataChannel(Description::Role role, string label,
 shared_ptr<DataChannel> PeerConnection::emplaceDataChannel(Description::Role role, string label,
-                                                           string protocol,
-                                                           Reliability reliability,
-                                                           std::optional<unsigned int> stream) {
+                                                           DataChannelInit init) {
 	std::unique_lock lock(mDataChannelsMutex); // we are going to emplace
 	std::unique_lock lock(mDataChannelsMutex); // we are going to emplace
-	if(!stream) {
+	unsigned int stream;
+	if (init.id) {
+		stream = *init.id;
+	} else {
 		// The active side must use streams with even identifiers, whereas the passive side must use
 		// The active side must use streams with even identifiers, whereas the passive side must use
 		// streams with odd identifiers.
 		// streams with odd identifiers.
 		// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-6
 		// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-6
-		*stream = (role == Description::Role::Active) ? 0 : 1;
-		while (mDataChannels.find(*stream) != mDataChannels.end()) {
-			*stream += 2;
-			if (*stream >= 65535)
+		stream = (role == Description::Role::Active) ? 0 : 1;
+		while (mDataChannels.find(stream) != mDataChannels.end()) {
+			stream += 2;
+			if (stream >= 65535)
 				throw std::runtime_error("Too many DataChannels");
 				throw std::runtime_error("Too many DataChannels");
 		}
 		}
 	}
 	}
-	auto channel = std::make_shared<DataChannel>(shared_from_this(), *stream, std::move(label),
-	                                             std::move(protocol), std::move(reliability));
+	// If the DataChannel is user-negociated, do not negociate it here
+	auto channel =
+	    init.negociated
+	        ? std::make_shared<DataChannel>(shared_from_this(), stream, std::move(label),
+	                                        std::move(init.protocol), std::move(init.reliability))
+	        : std::make_shared<NegociatedDataChannel>(shared_from_this(), stream, std::move(label),
+	                                                  std::move(init.protocol),
+	                                                  std::move(init.reliability));
 	mDataChannels.emplace(std::make_pair(stream, channel));
 	mDataChannels.emplace(std::make_pair(stream, channel));
 	return channel;
 	return channel;
 }
 }