Bläddra i källkod

Merge pull request #992 from paullouisageneau/enforce-sctp-message-size

Enforce SCTP max message size for safety
Paul-Louis Ageneau 1 år sedan
förälder
incheckning
e60f0cc52f
2 ändrade filer med 31 tillägg och 13 borttagningar
  1. 30 13
      src/impl/sctptransport.cpp
  2. 1 0
      src/impl/sctptransport.hpp

+ 30 - 13
src/impl/sctptransport.cpp

@@ -167,8 +167,10 @@ void SctpTransport::Cleanup() {
 SctpTransport::SctpTransport(shared_ptr<Transport> lower, const Configuration &config, Ports ports,
                              message_callback recvCallback, amount_callback bufferedAmountCallback,
                              state_callback stateChangeCallback)
-    : Transport(lower, std::move(stateChangeCallback)), mPorts(std::move(ports)),
-      mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
+    : Transport(lower, std::move(stateChangeCallback)),
+      mMaxMessageSize(config.maxMessageSize.value_or(DEFAULT_LOCAL_MAX_MESSAGE_SIZE)),
+      mPorts(std::move(ports)), mSendQueue(0, message_size_func),
+      mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
 	onRecv(std::move(recvCallback));
 
 	PLOG_DEBUG << "Initializing SCTP transport";
@@ -294,8 +296,7 @@ SctpTransport::SctpTransport(shared_ptr<Transport> lower, const Configuration &c
 		                         std::to_string(errno));
 
 	// Ensure the buffer is also large enough to accomodate the largest messages
-	const size_t maxMessageSize = config.maxMessageSize.value_or(DEFAULT_LOCAL_MAX_MESSAGE_SIZE);
-	const int minBuf = int(std::min(maxMessageSize, size_t(std::numeric_limits<int>::max())));
+	const int minBuf = int(std::min(mMaxMessageSize, size_t(std::numeric_limits<int>::max())));
 	rcvBuf = std::max(rcvBuf, minBuf);
 	sndBuf = std::max(sndBuf, minBuf);
 
@@ -379,6 +380,9 @@ bool SctpTransport::send(message_ptr message) {
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 
+	if(message->size() > mMaxMessageSize)
+		throw std::invalid_argument("Message is too large");
+
 	// Flush the queue, and if nothing is pending, try to send directly
 	if (trySendQueue() && trySendMessage(message))
 		return true;
@@ -499,24 +503,31 @@ void SctpTransport::doRecv() {
 			if (flags & MSG_NOTIFICATION) {
 				// SCTP event notification
 				mPartialNotification.insert(mPartialNotification.end(), buffer, buffer + len);
+
 				if (flags & MSG_EOR) {
 					// Notification is complete, process it
-					auto notification =
-					    reinterpret_cast<union sctp_notification *>(mPartialNotification.data());
-					processNotification(notification, mPartialNotification.size());
-					mPartialNotification.clear();
+					binary notification;
+					mPartialNotification.swap(notification);
+					auto n = reinterpret_cast<union sctp_notification *>(notification.data());
+					processNotification(n, notification.size());
 				}
+
 			} else {
 				// SCTP message
 				mPartialMessage.insert(mPartialMessage.end(), buffer, buffer + len);
+				if(mPartialMessage.size() > mMaxMessageSize) {
+					PLOG_WARNING << "SCTP message is too large, truncating it";
+					mPartialMessage.resize(mMaxMessageSize);
+				}
+
 				if (flags & MSG_EOR) {
 					// Message is complete, process it
+					binary message;
+					mPartialMessage.swap(message);
 					if (infotype != SCTP_RECVV_RCVINFO)
 						throw std::runtime_error("Missing SCTP recv info");
 
-					processData(std::move(mPartialMessage), info.rcv_sid,
-					            PayloadId(ntohl(info.rcv_ppid)));
-					mPartialMessage.clear();
+					processData(std::move(message), info.rcv_sid, PayloadId(ntohl(info.rcv_ppid)));
 				}
 			}
 		}
@@ -773,6 +784,7 @@ void SctpTransport::processData(binary &&data, uint16_t sid, PayloadId ppid) {
 
 	case PPID_STRING_PARTIAL: // deprecated
 		mPartialStringData.insert(mPartialStringData.end(), data.begin(), data.end());
+		mPartialStringData.resize(mMaxMessageSize);
 		break;
 
 	case PPID_STRING:
@@ -781,9 +793,11 @@ void SctpTransport::processData(binary &&data, uint16_t sid, PayloadId ppid) {
 			recv(make_message(std::move(data), Message::String, sid));
 		} else {
 			mPartialStringData.insert(mPartialStringData.end(), data.begin(), data.end());
+			mPartialStringData.resize(mMaxMessageSize);
 			mBytesReceived += mPartialStringData.size();
-			recv(make_message(std::move(mPartialStringData), Message::String, sid));
+			auto message = make_message(std::move(mPartialStringData), Message::String, sid);
 			mPartialStringData.clear();
+			recv(std::move(message));
 		}
 		break;
 
@@ -794,6 +808,7 @@ void SctpTransport::processData(binary &&data, uint16_t sid, PayloadId ppid) {
 
 	case PPID_BINARY_PARTIAL: // deprecated
 		mPartialBinaryData.insert(mPartialBinaryData.end(), data.begin(), data.end());
+		mPartialBinaryData.resize(mMaxMessageSize);
 		break;
 
 	case PPID_BINARY:
@@ -802,9 +817,11 @@ void SctpTransport::processData(binary &&data, uint16_t sid, PayloadId ppid) {
 			recv(make_message(std::move(data), Message::Binary, sid));
 		} else {
 			mPartialBinaryData.insert(mPartialBinaryData.end(), data.begin(), data.end());
+			mPartialBinaryData.resize(mMaxMessageSize);
 			mBytesReceived += mPartialBinaryData.size();
-			recv(make_message(std::move(mPartialBinaryData), Message::Binary, sid));
+			auto message = make_message(std::move(mPartialBinaryData), Message::Binary, sid);
 			mPartialBinaryData.clear();
+			recv(std::move(message));
 		}
 		break;
 

+ 1 - 0
src/impl/sctptransport.hpp

@@ -96,6 +96,7 @@ private:
 	void processData(binary &&data, uint16_t streamId, PayloadId ppid);
 	void processNotification(const union sctp_notification *notify, size_t len);
 
+	const size_t mMaxMessageSize;
 	const Ports mPorts;
 	struct socket *mSock;
 	std::optional<uint16_t> mNegotiatedStreamsCount;