Jelajahi Sumber

Implemented max message size negociation

Paul-Louis Ageneau 5 tahun lalu
induk
melakukan
c5e25bbdbc

+ 2 - 1
include/rtc/datachannel.hpp

@@ -56,6 +56,7 @@ public:
 	bool isOpen(void) const;
 	bool isClosed(void) const;
 	size_t availableAmount() const;
+	size_t maxMessageSize() const;
 
 	unsigned int stream() const;
 	string label() const;
@@ -68,7 +69,7 @@ private:
 	void incoming(message_ptr message);
 	void processOpenMessage(message_ptr message);
 
-	const std::shared_ptr<PeerConnection> mPeerConnection; // keeps the PeerConnection alive
+	const std::shared_ptr<PeerConnection> mPeerConnection;
 	std::shared_ptr<SctpTransport> mSctpTransport;
 
 	unsigned int mStream;

+ 3 - 0
include/rtc/description.hpp

@@ -44,9 +44,11 @@ public:
 	string mid() const;
 	std::optional<string> fingerprint() const;
 	std::optional<uint16_t> sctpPort() const;
+	std::optional<size_t> maxMessageSize() const;
 
 	void setFingerprint(string fingerprint);
 	void setSctpPort(uint16_t port);
+	void setMaxMessageSize(size_t size);
 
 	void addCandidate(Candidate candidate);
 	void endCandidates();
@@ -62,6 +64,7 @@ private:
 	string mIceUfrag, mIcePwd;
 	std::optional<string> mFingerprint;
 	std::optional<uint16_t> mSctpPort;
+	std::optional<size_t> mMaxMessageSize;
 	std::vector<Candidate> mCandidates;
 	bool mTrickle;
 

+ 3 - 0
include/rtc/include.hpp

@@ -43,7 +43,10 @@ using std::uint8_t;
 
 const size_t MAX_NUMERICNODE_LEN = 48; // Max IPv6 string representation length
 const size_t MAX_NUMERICSERV_LEN = 6;  // Max port string representation length
+
 const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default
+const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536;    // Remote max message size if not specified in SDP
+const size_t LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Local max message size
 
 template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
 template <class... Ts> overloaded(Ts...)->overloaded<Ts...>;

+ 1 - 0
include/rtc/message.hpp

@@ -44,6 +44,7 @@ struct Message : binary {
 using message_ptr = std::shared_ptr<const Message>;
 using mutable_message_ptr = std::shared_ptr<Message>;
 using message_callback = std::function<void(message_ptr message)>;
+
 constexpr auto message_size_func = [](const message_ptr &m) -> size_t {
 	return m->type != Message::Control ? m->size() : 0;
 };

+ 15 - 1
src/datachannel.cpp

@@ -17,6 +17,7 @@
  */
 
 #include "datachannel.hpp"
+#include "include.hpp"
 #include "peerconnection.hpp"
 #include "sctptransport.hpp"
 
@@ -128,6 +129,15 @@ bool DataChannel::isClosed(void) const { return mIsClosed; }
 
 size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); }
 
+size_t DataChannel::maxMessageSize() const {
+	size_t max = DEFAULT_MAX_MESSAGE_SIZE;
+	if (auto description = mPeerConnection->remoteDescription())
+		if (auto maxMessageSize = description->maxMessageSize())
+			return *maxMessageSize > 0 ? *maxMessageSize : LOCAL_MAX_MESSAGE_SIZE;
+
+	return std::min(max, LOCAL_MAX_MESSAGE_SIZE);
+}
+
 unsigned int DataChannel::stream() const { return mStream; }
 
 string DataChannel::label() const { return mLabel; }
@@ -169,7 +179,11 @@ void DataChannel::open(shared_ptr<SctpTransport> sctpTransport) {
 
 void DataChannel::outgoing(mutable_message_ptr message) {
 	if (mIsClosed || !mSctpTransport)
-		return;
+		throw std::runtime_error("DataChannel is closed");
+
+	if (message->size() > maxMessageSize())
+		throw std::runtime_error("Message size exceeds limit");
+
 	// Before the ACK has been received on a DataChannel, all messages must be sent ordered
 	message->reliability = mIsOpen ? mReliability : nullptr;
 	message->stream = mStream;

+ 8 - 1
src/description.cpp

@@ -81,6 +81,8 @@ Description::Description(const string &sdp, Type type, Role role)
 			mIcePwd = line.substr(line.find(':') + 1);
 		} else if (hasprefix(line, "a=sctp-port")) {
 			mSctpPort = uint16_t(std::stoul(line.substr(line.find(':') + 1)));
+		} else if (hasprefix(line, "a=max-message-size")) {
+			mMaxMessageSize = size_t(std::stoul(line.substr(line.find(':') + 1)));
 		} else if (hasprefix(line, "a=candidate")) {
 			addCandidate(Candidate(line.substr(2), mMid));
 		} else if (hasprefix(line, "a=end-of-candidates")) {
@@ -103,12 +105,16 @@ std::optional<string> Description::fingerprint() const { return mFingerprint; }
 
 std::optional<uint16_t> Description::sctpPort() const { return mSctpPort; }
 
+std::optional<size_t> Description::maxMessageSize() const { return mMaxMessageSize; }
+
 void Description::setFingerprint(string fingerprint) {
 	mFingerprint.emplace(std::move(fingerprint));
 }
 
 void Description::setSctpPort(uint16_t port) { mSctpPort.emplace(port); }
 
+void Description::setMaxMessageSize(size_t size) { mMaxMessageSize.emplace(size); }
+
 void Description::addCandidate(Candidate candidate) {
 	mCandidates.emplace_back(std::move(candidate));
 }
@@ -145,7 +151,8 @@ Description::operator string() const {
 		sdp << "a=fingerprint:sha-256 " << *mFingerprint << "\n";
 	if (mSctpPort)
 		sdp << "a=sctp-port:" << *mSctpPort << "\n";
-
+	if (mMaxMessageSize)
+		sdp << "a=max-message-size:" << *mMaxMessageSize << "\n";
 	for (const auto &candidate : mCandidates) {
 		sdp << string(candidate) << "\n";
 	}

+ 1 - 0
src/peerconnection.cpp

@@ -359,6 +359,7 @@ void PeerConnection::processLocalDescription(Description description) {
 	mLocalDescription.emplace(std::move(description));
 	mLocalDescription->setFingerprint(mCertificate->fingerprint());
 	mLocalDescription->setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT));
+	mLocalDescription->setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE);
 
 	mLocalDescriptionCallback(*mLocalDescription);
 }