// datachannel.cpp (10 KB)

  1. /**
  2. * Copyright (c) 2019-2021 Paul-Louis Ageneau
  3. *
  4. * This library is free software; you can redistribute it and/or
  5. * modify it under the terms of the GNU Lesser General Public
  6. * License as published by the Free Software Foundation; either
  7. * version 2.1 of the License, or (at your option) any later version.
  8. *
  9. * This library is distributed in the hope that it will be useful,
  10. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  12. * Lesser General Public License for more details.
  13. *
  14. * You should have received a copy of the GNU Lesser General Public
  15. * License along with this library; if not, write to the Free Software
  16. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  17. */
  18. #include "datachannel.hpp"
  19. #include "include.hpp"
  20. #include "logcounter.hpp"
  21. #include "peerconnection.hpp"
  22. #include "sctptransport.hpp"
  23. #include "rtc/datachannel.hpp"
  24. #include "rtc/track.hpp"
  25. #ifdef _WIN32
  26. #include <winsock2.h>
  27. #else
  28. #include <arpa/inet.h>
  29. #endif
  30. using std::chrono::milliseconds;
  31. namespace rtc::impl {
  // Messages for the DataChannel establishment protocol (DCEP)
  // See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09
  // (the draft was later published as RFC 8832).
  // The first byte of every Message::Control payload is one of these values;
  // incoming() dispatches on it.
  enum MessageType : uint8_t {
      MESSAGE_OPEN_REQUEST = 0x00,  // NOTE(review): not referenced in this file
      MESSAGE_OPEN_RESPONSE = 0x01, // NOTE(review): not referenced in this file
      MESSAGE_ACK = 0x02,           // reply to an OPEN; marks the channel open
      MESSAGE_OPEN = 0x03,          // OpenMessage header followed by label and protocol
      MESSAGE_CLOSE = 0x04          // queued and handled in order by receive()/peek()
  };
  // Channel types carried in OpenMessage::channelType. The high bit (0x80) is
  // OR-ed in to request unordered delivery.
  enum ChannelType : uint8_t {
      CHANNEL_RELIABLE = 0x00,                // fully reliable; reliability parameter is 0
      CHANNEL_PARTIAL_RELIABLE_REXMIT = 0x01, // parameter = maximum retransmissions
      CHANNEL_PARTIAL_RELIABLE_TIMED = 0x02   // parameter = lifetime in milliseconds
  };
  // On-the-wire message layouts. Packing to 1-byte alignment makes sizeof() equal
  // the serialized header size (12 bytes for OpenMessage). Multi-byte fields hold
  // network byte order; the code that reads or writes them converts with
  // htons/htonl/ntohs/ntohl.
  #pragma pack(push, 1)
  struct OpenMessage {
      uint8_t type = MESSAGE_OPEN;
      uint8_t channelType;           // ChannelType, with 0x80 set for unordered delivery
      uint16_t priority;             // network byte order
      uint32_t reliabilityParameter; // network byte order; meaning depends on channelType
      uint16_t labelLength;          // network byte order, length in bytes
      uint16_t protocolLength;       // network byte order, length in bytes
      // The following fields are:
      // uint8_t[labelLength] label
      // uint8_t[protocolLength] protocol
  };
  struct AckMessage {
      uint8_t type = MESSAGE_ACK;
  };
  // NOTE(review): CloseMessage is not constructed anywhere in this file.
  struct CloseMessage {
      uint8_t type = MESSAGE_CLOSE;
  };
  #pragma pack(pop)
  65. LogCounter COUNTER_USERNEG_OPEN_MESSAGE(
  66. plog::warning, "Number of open messages for a user-negotiated DataChannel received");
  67. DataChannel::DataChannel(weak_ptr<PeerConnection> pc, uint16_t stream, string label,
  68. string protocol, Reliability reliability)
  69. : mPeerConnection(pc), mStream(stream), mLabel(std::move(label)),
  70. mProtocol(std::move(protocol)),
  71. mReliability(std::make_shared<Reliability>(std::move(reliability))),
  72. mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
  73. DataChannel::~DataChannel() { close(); }
  74. void DataChannel::close() {
  75. std::shared_ptr<SctpTransport> transport;
  76. {
  77. std::shared_lock lock(mMutex);
  78. transport = mSctpTransport.lock();
  79. }
  80. mIsClosed = true;
  81. if (mIsOpen.exchange(false) && transport)
  82. transport->closeStream(mStream);
  83. resetCallbacks();
  84. }
  85. void DataChannel::remoteClose() {
  86. if (!mIsClosed.exchange(true))
  87. triggerClosed();
  88. mIsOpen = false;
  89. }
  90. std::optional<message_variant> DataChannel::receive() {
  91. while (auto next = mRecvQueue.tryPop()) {
  92. message_ptr message = *next;
  93. if (message->type != Message::Control)
  94. return to_variant(std::move(*message));
  95. auto raw = reinterpret_cast<const uint8_t *>(message->data());
  96. if (!message->empty() && raw[0] == MESSAGE_CLOSE)
  97. remoteClose();
  98. }
  99. return nullopt;
  100. }
  101. std::optional<message_variant> DataChannel::peek() {
  102. while (auto next = mRecvQueue.peek()) {
  103. message_ptr message = *next;
  104. if (message->type != Message::Control)
  105. return to_variant(std::move(*message));
  106. auto raw = reinterpret_cast<const uint8_t *>(message->data());
  107. if (!message->empty() && raw[0] == MESSAGE_CLOSE)
  108. remoteClose();
  109. mRecvQueue.tryPop();
  110. }
  111. return nullopt;
  112. }
  113. size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); }
  114. uint16_t DataChannel::stream() const {
  115. std::shared_lock lock(mMutex);
  116. return mStream;
  117. }
  118. string DataChannel::label() const {
  119. std::shared_lock lock(mMutex);
  120. return mLabel;
  121. }
  122. string DataChannel::protocol() const {
  123. std::shared_lock lock(mMutex);
  124. return mProtocol;
  125. }
  126. Reliability DataChannel::reliability() const {
  127. std::shared_lock lock(mMutex);
  128. return *mReliability;
  129. }
  130. bool DataChannel::isOpen(void) const { return mIsOpen; }
  131. bool DataChannel::isClosed(void) const { return mIsClosed; }
  132. size_t DataChannel::maxMessageSize() const {
  133. size_t remoteMax = DEFAULT_MAX_MESSAGE_SIZE;
  134. if (auto pc = mPeerConnection.lock())
  135. if (auto description = pc->remoteDescription())
  136. if (auto *application = description->application())
  137. if (auto maxMessageSize = application->maxMessageSize())
  138. remoteMax = *maxMessageSize > 0 ? *maxMessageSize : LOCAL_MAX_MESSAGE_SIZE;
  139. return std::min(remoteMax, LOCAL_MAX_MESSAGE_SIZE);
  140. }
  141. void DataChannel::shiftStream() {
  142. if (mStream % 2 == 1)
  143. mStream -= 1;
  144. }
  145. void DataChannel::open(shared_ptr<SctpTransport> transport) {
  146. {
  147. std::unique_lock lock(mMutex);
  148. mSctpTransport = transport;
  149. }
  150. if (!mIsOpen.exchange(true))
  151. triggerOpen();
  152. }
  153. void DataChannel::processOpenMessage(message_ptr) {
  154. PLOG_DEBUG << "Received an open message for a user-negotiated DataChannel, ignoring";
  155. COUNTER_USERNEG_OPEN_MESSAGE++;
  156. }
  157. bool DataChannel::outgoing(message_ptr message) {
  158. std::shared_ptr<SctpTransport> transport;
  159. {
  160. std::shared_lock lock(mMutex);
  161. transport = mSctpTransport.lock();
  162. if (!transport || mIsClosed)
  163. throw std::runtime_error("DataChannel is closed");
  164. if (message->size() > maxMessageSize())
  165. throw std::runtime_error("Message size exceeds limit");
  166. // Before the ACK has been received on a DataChannel, all messages must be sent ordered
  167. message->reliability = mIsOpen ? mReliability : nullptr;
  168. message->stream = mStream;
  169. }
  170. return transport->send(message);
  171. }
  172. void DataChannel::incoming(message_ptr message) {
  173. if (!message)
  174. return;
  175. switch (message->type) {
  176. case Message::Control: {
  177. if (message->size() == 0)
  178. break; // Ignore
  179. auto raw = reinterpret_cast<const uint8_t *>(message->data());
  180. switch (raw[0]) {
  181. case MESSAGE_OPEN:
  182. processOpenMessage(message);
  183. break;
  184. case MESSAGE_ACK:
  185. if (!mIsOpen.exchange(true)) {
  186. triggerOpen();
  187. }
  188. break;
  189. case MESSAGE_CLOSE:
  190. // The close message will be processed in-order in receive()
  191. mRecvQueue.push(message);
  192. triggerAvailable(mRecvQueue.size());
  193. break;
  194. default:
  195. // Ignore
  196. break;
  197. }
  198. break;
  199. }
  200. case Message::String:
  201. case Message::Binary:
  202. mRecvQueue.push(message);
  203. triggerAvailable(mRecvQueue.size());
  204. break;
  205. default:
  206. // Ignore
  207. break;
  208. }
  209. }
  210. NegotiatedDataChannel::NegotiatedDataChannel(std::weak_ptr<impl::PeerConnection> pc,
  211. uint16_t stream, string label, string protocol,
  212. Reliability reliability)
  213. : DataChannel(pc, stream, std::move(label), std::move(protocol), std::move(reliability)) {}
  214. NegotiatedDataChannel::NegotiatedDataChannel(std::weak_ptr<impl::PeerConnection> pc,
  215. std::weak_ptr<impl::SctpTransport> transport,
  216. uint16_t stream)
  217. : DataChannel(pc, stream, "", "", {}) {
  218. mSctpTransport = transport;
  219. }
  220. NegotiatedDataChannel::~NegotiatedDataChannel() {}
  221. void NegotiatedDataChannel::open(shared_ptr<impl::SctpTransport> transport) {
  222. std::unique_lock lock(mMutex);
  223. mSctpTransport = transport;
  224. uint8_t channelType;
  225. uint32_t reliabilityParameter;
  226. switch (mReliability->type) {
  227. case Reliability::Type::Rexmit:
  228. channelType = CHANNEL_PARTIAL_RELIABLE_REXMIT;
  229. reliabilityParameter = uint32_t(std::get<int>(mReliability->rexmit));
  230. break;
  231. case Reliability::Type::Timed:
  232. channelType = CHANNEL_PARTIAL_RELIABLE_TIMED;
  233. reliabilityParameter = uint32_t(std::get<milliseconds>(mReliability->rexmit).count());
  234. break;
  235. default:
  236. channelType = CHANNEL_RELIABLE;
  237. reliabilityParameter = 0;
  238. break;
  239. }
  240. if (mReliability->unordered)
  241. channelType |= 0x80;
  242. const size_t len = sizeof(OpenMessage) + mLabel.size() + mProtocol.size();
  243. binary buffer(len, byte(0));
  244. auto &open = *reinterpret_cast<OpenMessage *>(buffer.data());
  245. open.type = MESSAGE_OPEN;
  246. open.channelType = channelType;
  247. open.priority = htons(0);
  248. open.reliabilityParameter = htonl(reliabilityParameter);
  249. open.labelLength = htons(uint16_t(mLabel.size()));
  250. open.protocolLength = htons(uint16_t(mProtocol.size()));
  251. auto end = reinterpret_cast<char *>(buffer.data() + sizeof(OpenMessage));
  252. std::copy(mLabel.begin(), mLabel.end(), end);
  253. std::copy(mProtocol.begin(), mProtocol.end(), end + mLabel.size());
  254. lock.unlock();
  255. transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
  256. }
  257. void NegotiatedDataChannel::processOpenMessage(message_ptr message) {
  258. std::unique_lock lock(mMutex);
  259. auto transport = mSctpTransport.lock();
  260. if (!transport)
  261. throw std::runtime_error("DataChannel has no transport");
  262. if (message->size() < sizeof(OpenMessage))
  263. throw std::invalid_argument("DataChannel open message too small");
  264. OpenMessage open = *reinterpret_cast<const OpenMessage *>(message->data());
  265. open.priority = ntohs(open.priority);
  266. open.reliabilityParameter = ntohl(open.reliabilityParameter);
  267. open.labelLength = ntohs(open.labelLength);
  268. open.protocolLength = ntohs(open.protocolLength);
  269. if (message->size() < sizeof(OpenMessage) + size_t(open.labelLength + open.protocolLength))
  270. throw std::invalid_argument("DataChannel open message truncated");
  271. auto end = reinterpret_cast<const char *>(message->data() + sizeof(OpenMessage));
  272. mLabel.assign(end, open.labelLength);
  273. mProtocol.assign(end + open.labelLength, open.protocolLength);
  274. mReliability->unordered = (open.channelType & 0x80) != 0;
  275. switch (open.channelType & 0x7F) {
  276. case CHANNEL_PARTIAL_RELIABLE_REXMIT:
  277. mReliability->type = Reliability::Type::Rexmit;
  278. mReliability->rexmit = int(open.reliabilityParameter);
  279. break;
  280. case CHANNEL_PARTIAL_RELIABLE_TIMED:
  281. mReliability->type = Reliability::Type::Timed;
  282. mReliability->rexmit = milliseconds(open.reliabilityParameter);
  283. break;
  284. default:
  285. mReliability->type = Reliability::Type::Reliable;
  286. mReliability->rexmit = int(0);
  287. }
  288. lock.unlock();
  289. binary buffer(sizeof(AckMessage), byte(0));
  290. auto &ack = *reinterpret_cast<AckMessage *>(buffer.data());
  291. ack.type = MESSAGE_ACK;
  292. transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
  293. if (!mIsOpen.exchange(true))
  294. triggerOpen();
  295. }
  296. } // namespace rtc::impl