Browse Source

Check uint16 and uint32 conversions in DataChannel

Paul-Louis Ageneau 1 năm trước cách đây
mục cha
commit
e736c60684
3 tập tin đã thay đổi với 31 bổ sung27 xóa
  1. 10 7
      src/impl/datachannel.cpp
  2. 4 20
      src/impl/sctptransport.cpp
  3. 17 0
      src/impl/utils.hpp

+ 10 - 7
src/impl/datachannel.cpp

@@ -12,7 +12,7 @@
 #include "logcounter.hpp"
 #include "peerconnection.hpp"
 #include "sctptransport.hpp"
-
+#include "utils.hpp"
 #include "rtc/datachannel.hpp"
 #include "rtc/track.hpp"
 
@@ -28,6 +28,9 @@ using std::chrono::milliseconds;
 
 namespace rtc::impl {
 
+using utils::to_uint16;
+using utils::to_uint32;
+
 // Messages for the DataChannel establishment protocol (RFC 8832)
 // See https://www.rfc-editor.org/rfc/rfc8832.html
 
@@ -254,10 +257,10 @@ void OutgoingDataChannel::open(shared_ptr<SctpTransport> transport) {
 	uint32_t reliabilityParameter;
 	if (mReliability->maxPacketLifeTime) {
 		channelType = CHANNEL_PARTIAL_RELIABLE_TIMED;
-		reliabilityParameter = uint32_t(mReliability->maxPacketLifeTime->count());
+		reliabilityParameter = to_uint32(mReliability->maxPacketLifeTime->count());
 	} else if (mReliability->maxRetransmits) {
 		channelType = CHANNEL_PARTIAL_RELIABLE_REXMIT;
-		reliabilityParameter = uint32_t(*mReliability->maxRetransmits);
+		reliabilityParameter = to_uint32(*mReliability->maxRetransmits);
 	}
 	// else {
 	//	channelType = CHANNEL_RELIABLE;
@@ -268,12 +271,12 @@ void OutgoingDataChannel::open(shared_ptr<SctpTransport> transport) {
 		switch (mReliability->typeDeprecated) {
 		case Reliability::Type::Rexmit:
 			channelType = CHANNEL_PARTIAL_RELIABLE_REXMIT;
-			reliabilityParameter = uint32_t(std::max(std::get<int>(mReliability->rexmit), 0));
+			reliabilityParameter = to_uint32(std::max(std::get<int>(mReliability->rexmit), 0));
 			break;
 
 		case Reliability::Type::Timed:
 			channelType = CHANNEL_PARTIAL_RELIABLE_TIMED;
-			reliabilityParameter = uint32_t(std::get<milliseconds>(mReliability->rexmit).count());
+			reliabilityParameter = to_uint32(std::get<milliseconds>(mReliability->rexmit).count());
 			break;
 
 		default:
@@ -292,8 +295,8 @@ void OutgoingDataChannel::open(shared_ptr<SctpTransport> transport) {
 	open.channelType = channelType;
 	open.priority = htons(0);
 	open.reliabilityParameter = htonl(reliabilityParameter);
-	open.labelLength = htons(uint16_t(mLabel.size()));
-	open.protocolLength = htons(uint16_t(mProtocol.size()));
+	open.labelLength = htons(to_uint16(mLabel.size()));
+	open.protocolLength = htons(to_uint16(mProtocol.size()));
 
 	auto end = reinterpret_cast<char *>(buffer.data() + sizeof(OpenMessage));
 	std::copy(mLabel.begin(), mLabel.end(), end);

+ 4 - 20
src/impl/sctptransport.cpp

@@ -10,6 +10,7 @@
 #include "dtlstransport.hpp"
 #include "internals.hpp"
 #include "logcounter.hpp"
+#include "utils.hpp"
 
 #include <algorithm>
 #include <chrono>
@@ -50,28 +51,11 @@
 using namespace std::chrono_literals;
 using namespace std::chrono;
 
-namespace {
-
-template <typename T> uint16_t to_uint16(T i) {
-	if (i >= 0 && static_cast<typename std::make_unsigned<T>::type>(i) <=
-	                  std::numeric_limits<uint16_t>::max())
-		return static_cast<uint16_t>(i);
-	else
-		throw std::invalid_argument("Integer out of range");
-}
-
-template <typename T> uint32_t to_uint32(T i) {
-	if (i >= 0 && static_cast<typename std::make_unsigned<T>::type>(i) <=
-	                  std::numeric_limits<uint32_t>::max())
-		return static_cast<uint32_t>(i);
-	else
-		throw std::invalid_argument("Integer out of range");
-}
-
-} // namespace
-
 namespace rtc::impl {
 
+using utils::to_uint16;
+using utils::to_uint32;
+
 static LogCounter COUNTER_UNKNOWN_PPID(plog::warning,
                                        "Number of SCTP packets received with an unknown PPID");
 

+ 17 - 0
src/impl/utils.hpp

@@ -15,6 +15,7 @@
 #include <limits>
 #include <map>
 #include <random>
+#include <stdexcept>
 #include <vector>
 
 namespace rtc::impl::utils {
@@ -60,6 +61,22 @@ template <typename Generator = std::mt19937> auto random_bytes_engine() {
 	return random_engine<char_independent_bits_engine, uint8_t>();
 }
 
+template <typename T> uint16_t to_uint16(T i) {
+	if (i >= 0 && static_cast<typename std::make_unsigned<T>::type>(i) <=
+	                  std::numeric_limits<uint16_t>::max())
+		return static_cast<uint16_t>(i);
+	else
+		throw std::invalid_argument("Integer out of range");
+}
+
+template <typename T> uint32_t to_uint32(T i) {
+	if (i >= 0 && static_cast<typename std::make_unsigned<T>::type>(i) <=
+	                  std::numeric_limits<uint32_t>::max())
+		return static_cast<uint32_t>(i);
+	else
+		throw std::invalid_argument("Integer out of range");
+}
+
 namespace this_thread {
 
 void set_name(const string &name);