Browse Source

Fixed uninitialized realiability and refactored deserialization

Paul-Louis Ageneau 6 years ago
parent
commit
2ef8690ed4
2 changed files with 17 additions and 17 deletions
  1. 17 16
      src/datachannel.cpp
  2. 0 1
      src/datachannel.hpp

+ 17 - 16
src/datachannel.cpp

@@ -32,6 +32,7 @@ enum MessageType : uint8_t {
 	MESSAGE_CLOSE = 0x04
 };
 
+#pragma pack(push, 1)
 struct OpenMessage {
 	uint8_t type = MESSAGE_OPEN;
 	uint8_t channelType;
@@ -39,8 +40,9 @@ struct OpenMessage {
 	uint32_t reliabilityParameter;
 	uint16_t labelLength;
 	uint16_t protocolLength;
-	// label
-	// protocol
+	// The following fields are:
+	// uint8_t[labelLength] label
+	// uint8_t[protocolLength] protocol
 };
 
 struct AckMessage {
@@ -50,6 +52,7 @@ struct AckMessage {
 struct CloseMessage {
 	uint8_t type = MESSAGE_CLOSE;
 };
+#pragma pack(pop)
 
 DataChannel::DataChannel(unsigned int stream, string label, string protocol,
                          Reliability reliability)
@@ -57,7 +60,8 @@ DataChannel::DataChannel(unsigned int stream, string label, string protocol,
       mReliability(std::make_shared<Reliability>(std::move(reliability))) {}
 
 DataChannel::DataChannel(unsigned int stream, shared_ptr<SctpTransport> sctpTransport)
-    : mStream(stream), mSctpTransport(sctpTransport) {}
+    : mStream(stream), mSctpTransport(sctpTransport),
+      mReliability(std::make_shared<Reliability>()) {}
 
 DataChannel::~DataChannel() { close(); }
 
@@ -172,24 +176,21 @@ void DataChannel::incoming(message_ptr message) {
 }
 
 void DataChannel::processOpenMessage(message_ptr message) {
-	auto *raw = reinterpret_cast<const uint8_t *>(message->data());
-
-	if (message->size() < 12)
+	if (message->size() < sizeof(OpenMessage))
 		throw std::invalid_argument("DataChannel open message too small");
 
-	OpenMessage open;
-	open.channelType = raw[1];
-	open.priority = (raw[2] << 8) + raw[3];
-	open.reliabilityParameter = (raw[4] << 24) + (raw[5] << 16) + (raw[6] << 8) + raw[7];
-	open.labelLength = (raw[8] << 8) + raw[9];
-	open.protocolLength = (raw[10] << 8) + raw[11];
+	OpenMessage open = *reinterpret_cast<const OpenMessage *>(message->data());
+	open.priority = ntohs(open.priority);
+	open.reliabilityParameter = ntohl(open.reliabilityParameter);
+	open.labelLength = ntohs(open.labelLength);
+	open.protocolLength = ntohs(open.protocolLength);
 
-	if (message->size() < 12 + open.labelLength + open.protocolLength)
+	if (message->size() < sizeof(OpenMessage) + size_t(open.labelLength + open.protocolLength))
 		throw std::invalid_argument("DataChannel open message truncated");
 
-	mLabel.assign(reinterpret_cast<const char *>(raw + 12), open.labelLength);
-	mProtocol.assign(reinterpret_cast<const char *>(raw + 12 + open.labelLength),
-	                 open.protocolLength);
+	auto next = message->data() + sizeof(OpenMessage);
+	mLabel.assign(reinterpret_cast<const char *>(next), open.labelLength);
+	mProtocol.assign(reinterpret_cast<const char *>(next + open.labelLength), open.protocolLength);
 
 	using std::chrono::milliseconds;
 	mReliability->unordered = (open.reliabilityParameter & 0x80) != 0;

+ 0 - 1
src/datachannel.hpp

@@ -71,4 +71,3 @@ private:
 } // namespace rtc
 
 #endif
-