Ver código fonte

Merge pull request #1127 from melpon/dependency-descriptor

Support dependency descriptor and two-byte header for RTP header extension
Paul-Louis Ageneau 2 meses atrás
pai
commit
2040b439ea

+ 2 - 0
CMakeLists.txt

@@ -65,6 +65,7 @@ set(LIBDATACHANNEL_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/channel.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/configuration.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/datachannel.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/dependencydescriptor.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/description.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/iceudpmuxlistener.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/mediahandler.cpp
@@ -99,6 +100,7 @@ set(LIBDATACHANNEL_HEADERS
 	${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/channel.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/configuration.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/datachannel.hpp
+	${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/dependencydescriptor.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/description.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/iceudpmuxlistener.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/mediahandler.hpp

+ 109 - 0
include/rtc/dependencydescriptor.hpp

@@ -0,0 +1,109 @@
+/**
+ * Copyright (c) 2024 Shigemasa Watanabe (Wandbox)
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/.
+ */
+
+#ifndef RTC_DEPENDENCY_DESCRIPTOR_H
+#define RTC_DEPENDENCY_DESCRIPTOR_H
+
+#include <bitset>
+#include <cassert>
+#include <optional>
+#include <stdint.h>
+#include <vector>
+
+namespace rtc {
+
+struct BitWriter {
+	static BitWriter fromSizeBits(std::byte *buf, size_t offsetBits, size_t sizeBits);
+	static BitWriter fromNull();
+
+	size_t getWrittenBits() const;
+
+	bool write(uint64_t v, size_t bits);
+	// Write non-symmetric unsigned encoded integer
+	// ref: https://aomediacodec.github.io/av1-rtp-spec/#a82-syntax
+	bool writeNonSymmetric(uint64_t v, uint64_t n);
+
+private:
+	size_t writePartialByte(uint8_t *p, size_t offset, uint64_t v, size_t bits);
+
+private:
+	std::byte *mBuf = nullptr;
+	size_t mInitialOffset = 0;
+	size_t mOffset = 0;
+	size_t mSize = 0;
+};
+
+enum class DecodeTargetIndication {
+	NotPresent = 0,
+	Discardable = 1,
+	Switch = 2,
+	Required = 3,
+};
+
+struct RenderResolution {
+	int width = 0;
+	int height = 0;
+};
+
+struct FrameDependencyTemplate {
+	int spatialId = 0;
+	int temporalId = 0;
+	std::vector<DecodeTargetIndication> decodeTargetIndications;
+	std::vector<int> frameDiffs;
+	std::vector<int> chainDiffs;
+};
+
+struct FrameDependencyStructure {
+	int templateIdOffset = 0;
+	int decodeTargetCount = 0;
+	int chainCount = 0;
+	std::vector<int> decodeTargetProtectedBy;
+	std::vector<RenderResolution> resolutions;
+	std::vector<FrameDependencyTemplate> templates;
+};
+
+struct DependencyDescriptor {
+	bool startOfFrame = true;
+	bool endOfFrame = true;
+	int frameNumber = 0;
+	FrameDependencyTemplate dependencyTemplate;
+	std::optional<RenderResolution> resolution;
+	std::optional<uint32_t> activeDecodeTargetsBitmask;
+	bool structureAttached;
+};
+
+struct DependencyDescriptorContext {
+	DependencyDescriptor descriptor;
+	std::bitset<32> activeChains;
+	FrameDependencyStructure structure;
+};
+
+// Write dependency descriptor to RTP Header Extension
+// Dependency descriptor specification is here:
+// https://aomediacodec.github.io/av1-rtp-spec/#dependency-descriptor-rtp-header-extension
+class DependencyDescriptorWriter {
+public:
+	explicit DependencyDescriptorWriter(const DependencyDescriptorContext& context);
+	size_t getSizeBits() const;
+	size_t getSize() const;
+	void writeTo(std::byte *buf, size_t sizeBytes) const;
+
+private:
+	void doWriteTo(BitWriter &writer) const;
+	void writeBits(BitWriter &writer, uint64_t v, size_t bits) const;
+	void writeNonSymmetric(BitWriter &writer, uint64_t v, uint64_t n) const;
+
+private:
+	const FrameDependencyStructure &mStructure;
+	std::bitset<32> mActiveChains;
+	const DependencyDescriptor &mDescriptor;
+};
+
+} // namespace rtc
+
+#endif

+ 1 - 0
include/rtc/rtc.hpp

@@ -30,6 +30,7 @@
 
 // Media
 #include "av1rtppacketizer.hpp"
+#include "dependencydescriptor.hpp"
 #include "h264rtppacketizer.hpp"
 #include "h264rtpdepacketizer.hpp"
 #include "h265rtppacketizer.hpp"

+ 6 - 2
include/rtc/rtp.hpp

@@ -38,8 +38,12 @@ struct RTC_CPP_EXPORT RtpExtensionHeader {
 	void setHeaderLength(uint16_t headerLength);
 
 	void clearBody();
-	void writeCurrentVideoOrientation(size_t offset, uint8_t id, uint8_t value);
-	void writeOneByteHeader(size_t offset, uint8_t id, const byte *value, size_t size);
+	size_t writeCurrentVideoOrientation(bool twoByteHeader, size_t offset, uint8_t id,
+	                                    uint8_t value);
+	size_t writeOneByteHeader(size_t offset, uint8_t id, const byte *value, size_t size);
+	size_t writeTwoByteHeader(size_t offset, uint8_t id, const byte *value, size_t size);
+	size_t writeHeader(bool twoByteHeader, size_t offset, uint8_t id, const byte *value,
+	                   size_t size);
 };
 
 struct RTC_CPP_EXPORT RtpHeader {

+ 5 - 0
include/rtc/rtppacketizationconfig.hpp

@@ -11,6 +11,7 @@
 
 #if RTC_ENABLE_MEDIA
 
+#include "dependencydescriptor.hpp"
 #include "rtp.hpp"
 
 namespace rtc {
@@ -61,6 +62,10 @@ public:
 	uint8_t ridId = 0;
 	optional<std::string> rid;
 
+	// Dependency Descriptor Extension Header
+	uint8_t dependencyDescriptorId = 0;
+
+	optional<DependencyDescriptorContext> dependencyDescriptorContext;
 	// the negotiated ID of the playout delay header extension
 	// https://webrtc.googlesource.com/src/+/main/docs/native-code/rtp-hdrext/playout-delay/README.md
 	uint8_t playoutDelayId = 0;

+ 409 - 0
src/dependencydescriptor.cpp

@@ -0,0 +1,409 @@
+/**
+ * Copyright (c) 2024 Shigemasa Watanabe (Wandbox)
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/.
+ */
+
+#include "dependencydescriptor.hpp"
+
+#include <algorithm>
+#include <functional>
+#include <limits>
+#include <stdexcept>
+
+namespace rtc {
+
+BitWriter BitWriter::fromSizeBits(std::byte *buf, size_t offsetBits, size_t sizeBits) {
+	BitWriter writer;
+	writer.mBuf = buf;
+	writer.mInitialOffset = offsetBits;
+	writer.mOffset = offsetBits;
+	writer.mSize = sizeBits;
+	return writer;
+}
+BitWriter BitWriter::fromNull() {
+	BitWriter writer;
+	writer.mSize = std::numeric_limits<size_t>::max();
+	return writer;
+}
+
+size_t BitWriter::getWrittenBits() const { return mOffset - mInitialOffset; }
+
+bool BitWriter::write(uint64_t v, size_t bits) {
+	if (mOffset + bits > mSize) {
+		return false;
+	}
+	uint8_t *p = mBuf == nullptr ? nullptr : reinterpret_cast<uint8_t *>(mBuf + mOffset / 8);
+	// First, write up to the 8-bit boundary
+	size_t written_bits = writePartialByte(p, mOffset % 8, v, bits);
+
+	if (p != nullptr) {
+		p++;
+	}
+	bits -= written_bits;
+	mOffset += written_bits;
+
+	if (bits == 0) {
+		return true;
+	}
+
+	// Write 8 bits at a time
+	while (bits >= 8) {
+		if (p != nullptr) {
+			*p = (v >> (bits - 8)) & 0xff;
+			p++;
+		}
+		bits -= 8;
+		mOffset += 8;
+	}
+
+	// Write the remaining bits
+	written_bits = writePartialByte(p, 0, v, bits);
+	bits -= written_bits;
+	mOffset += written_bits;
+
+	assert(bits == 0);
+
+	return true;
+}
+
+bool BitWriter::writeNonSymmetric(uint64_t v, uint64_t n) {
+	if (n == 1) {
+		return true;
+	}
+	size_t w = 0;
+	uint64_t x = n;
+	while (x != 0) {
+		x = x >> 1;
+		w++;
+	}
+	uint64_t m = (1ULL << w) - n;
+	if (v < m) {
+		return write(v, w - 1);
+	} else {
+		return write(v + m, w);
+	}
+}
+
+size_t BitWriter::writePartialByte(uint8_t *p, size_t offset, uint64_t v, size_t bits) {
+	// How many bits are remaining
+	size_t remaining_bits = 8 - offset;
+	// Number of bits to write
+	size_t need_write_bits = std::min(remaining_bits, bits);
+	// Number of remaining bits
+	size_t shift = remaining_bits - need_write_bits;
+	// The relationship between each values are as follows
+	// 0bxxxxxxxx
+	//   ^        - offset == 1
+	//    ^-----^ - remaining_bits == 7
+	//    ^---^   - need_write_bits == 5
+	//         ^^ - shift == 2
+	assert(offset + remaining_bits == 8);
+	assert(remaining_bits == need_write_bits + shift);
+
+	// For writing 4 bits from the 3rd bit of 0bxxxxxxxx with 0byyyy, it becomes
+	// (0bxxxxxxxx & 0b11100001) | ((0byyyy >> (4 - 4)) << 1)
+	// For writing 2 bits from the 6th bit of 0bxxxxxxxx with 0byyyyy, it becomes
+	// (0bxxxxxxxx & 0b11111100) | (((0byyyyy >> (5 - 2)) << 0)
+
+	// Creating a mask
+	// For need_write_bits == 4, shift == 1
+	// 1 << 4 == 0b00010000
+	// 0b00010000 - 1 == 0b00001111
+	// 0b00001111 << 1 == 0b00011110
+	// ~0b00011110 == 0b11100001
+	uint8_t mask = ~(((1 << need_write_bits) - 1) << shift);
+
+	uint8_t vv = static_cast<uint8_t>(v >> (bits - need_write_bits));
+
+	if (p != nullptr) {
+		*p = (*p & mask) | (vv << shift);
+	}
+
+	return need_write_bits;
+}
+
+using TemplateIterator = std::vector<FrameDependencyTemplate>::const_iterator;
+
+struct TemplateMatch {
+	size_t templatePosition;
+	bool needCustomDtis;
+	bool needCustomFdiffs;
+	bool needCustomChains;
+	// Size in bits to store frame-specific details, i.e.
+	// excluding mandatory fields and template dependency structure.
+	size_t extraSizeBits;
+};
+
+static TemplateMatch calculate_match(TemplateIterator frameTemplate,
+                                     const FrameDependencyStructure &structure,
+                                     std::bitset<32> activeChains,
+                                     const DependencyDescriptor &descriptor) {
+	TemplateMatch result;
+	result.templatePosition = frameTemplate - structure.templates.begin();
+	result.needCustomFdiffs = descriptor.dependencyTemplate.frameDiffs != frameTemplate->frameDiffs;
+	result.needCustomDtis = descriptor.dependencyTemplate.decodeTargetIndications !=
+	                        frameTemplate->decodeTargetIndications;
+	result.needCustomChains = false;
+	for (int i = 0; i < structure.chainCount; ++i) {
+		if (activeChains[i] &&
+		    descriptor.dependencyTemplate.chainDiffs[i] != frameTemplate->chainDiffs[i]) {
+			result.needCustomChains = true;
+			break;
+		}
+	}
+
+	result.extraSizeBits = 0;
+	if (result.needCustomFdiffs) {
+		result.extraSizeBits += 2 * (1 + descriptor.dependencyTemplate.frameDiffs.size());
+		for (int fdiff : descriptor.dependencyTemplate.frameDiffs) {
+			if (fdiff <= (1 << 4)) {
+				result.extraSizeBits += 4;
+			} else if (fdiff <= (1 << 8)) {
+				result.extraSizeBits += 8;
+			} else {
+				result.extraSizeBits += 12;
+			}
+		}
+	}
+	if (result.needCustomDtis) {
+		result.extraSizeBits += 2 * descriptor.dependencyTemplate.decodeTargetIndications.size();
+	}
+	if (result.needCustomChains) {
+		result.extraSizeBits += 8 * structure.chainCount;
+	}
+	return result;
+}
+
+static bool find_best_template(const FrameDependencyStructure &structure,
+                               std::bitset<32> activeChains, const DependencyDescriptor &descriptor,
+                               TemplateMatch *best) {
+	auto &templates = structure.templates;
+	// Find range of templates with matching spatial/temporal id.
+	auto sameLayer = [&](const FrameDependencyTemplate &frameTemplate) {
+		return descriptor.dependencyTemplate.spatialId == frameTemplate.spatialId &&
+		       descriptor.dependencyTemplate.temporalId == frameTemplate.temporalId;
+	};
+	auto first = std::find_if(templates.begin(), templates.end(), sameLayer);
+	if (first == templates.end()) {
+		return false;
+	}
+	auto last = std::find_if_not(first, templates.end(), sameLayer);
+
+	*best = calculate_match(first, structure, activeChains, descriptor);
+	// Search if there any better template than the first one.
+	for (auto next = std::next(first); next != last; ++next) {
+		auto match = calculate_match(next, structure, activeChains, descriptor);
+		if (match.extraSizeBits < best->extraSizeBits) {
+			*best = match;
+		}
+	}
+	return true;
+}
+
+static const uint32_t MaxTemplates = 64;
+
+DependencyDescriptorWriter::DependencyDescriptorWriter(const DependencyDescriptorContext& context)
+	: mStructure(context.structure), mActiveChains(context.activeChains), mDescriptor(context.descriptor) {}
+
+size_t DependencyDescriptorWriter::getSizeBits() const {
+	auto writer = rtc::BitWriter::fromNull();
+	doWriteTo(writer);
+	return writer.getWrittenBits();
+}
+size_t DependencyDescriptorWriter::getSize() const {
+	return (getSizeBits() + 7) / 8;
+}
+
+void DependencyDescriptorWriter::writeTo(std::byte *buf, size_t sizeBytes) const {
+	auto writer = BitWriter::fromSizeBits(buf, 0, sizeBytes * 8);
+	doWriteTo(writer);
+	// Pad up to the byte boundary
+	if (auto bits = (writer.getWrittenBits() % 8); bits != 0) {
+		writer.write(0, 8 - bits);
+	}
+}
+
+void DependencyDescriptorWriter::doWriteTo(BitWriter &w) const {
+	TemplateMatch bestTemplate;
+	if (!find_best_template(mStructure, mActiveChains, mDescriptor, &bestTemplate)) {
+		throw std::logic_error("No matching template found");
+	}
+
+	// mandatory_descriptor_fields()
+	writeBits(w, mDescriptor.startOfFrame ? 1 : 0, 1);
+	writeBits(w, mDescriptor.endOfFrame ? 1 : 0, 1);
+	uint32_t templateId =
+	    (bestTemplate.templatePosition + mStructure.templateIdOffset) % MaxTemplates;
+	writeBits(w, templateId, 6);
+	writeBits(w, mDescriptor.frameNumber, 16);
+
+	bool hasExtendedFields = bestTemplate.extraSizeBits > 0 ||
+	                         (mDescriptor.startOfFrame && mDescriptor.structureAttached) ||
+	                         mDescriptor.activeDecodeTargetsBitmask != std::nullopt;
+	if (hasExtendedFields) {
+		// extended_descriptor_fields()
+		bool templateDependencyStructurePresentFlag = mDescriptor.structureAttached;
+		writeBits(w, templateDependencyStructurePresentFlag ? 1 : 0, 1);
+		bool activeDecodeTargetsPresentFlag = std::invoke([&]() {
+			if (!mDescriptor.activeDecodeTargetsBitmask)
+				return false;
+			const uint64_t allDecodeTargetsBitmask = (1ULL << mStructure.decodeTargetCount) - 1;
+			if (mDescriptor.structureAttached &&
+			    mDescriptor.activeDecodeTargetsBitmask == allDecodeTargetsBitmask)
+				return false;
+			return true;
+		});
+		writeBits(w, activeDecodeTargetsPresentFlag ? 1 : 0, 1);
+		writeBits(w, bestTemplate.needCustomDtis ? 1 : 0, 1);
+		writeBits(w, bestTemplate.needCustomFdiffs ? 1 : 0, 1);
+		writeBits(w, bestTemplate.needCustomChains ? 1 : 0, 1);
+		if (templateDependencyStructurePresentFlag) {
+			// template_dependency_structure()
+			writeBits(w, mStructure.templateIdOffset, 6);
+			writeBits(w, mStructure.decodeTargetCount - 1, 5);
+
+			// template_layers()
+			const auto &templates = mStructure.templates;
+			assert(!templates.empty());
+			assert(templates.size() < MaxTemplates);
+			assert(templates[0].spatialId == 0);
+			assert(templates[0].temporalId == 0);
+			for (size_t i = 1; i < templates.size(); ++i) {
+				auto &prev = templates[i - 1];
+				auto &next = templates[i];
+
+				uint32_t nextLayerIdc;
+				if (next.spatialId == prev.spatialId && next.temporalId == prev.temporalId) {
+					// same layer
+					nextLayerIdc = 0;
+				} else if (next.spatialId == prev.spatialId &&
+				           next.temporalId == prev.temporalId + 1) {
+					// next temporal
+					nextLayerIdc = 1;
+				} else if (next.spatialId == prev.spatialId + 1 && next.temporalId == 0) {
+					// new spatial
+					nextLayerIdc = 2;
+				} else {
+					throw std::logic_error("Invalid layer");
+				}
+				writeBits(w, nextLayerIdc, 2);
+			}
+			// no more layers
+			writeBits(w, 3, 2);
+
+			// template_dtis()
+			for (const FrameDependencyTemplate &frameTemplate : mStructure.templates) {
+				assert(frameTemplate.decodeTargetIndications.size() ==
+				       static_cast<size_t>(mStructure.decodeTargetCount));
+				for (DecodeTargetIndication dti : frameTemplate.decodeTargetIndications) {
+					writeBits(w, static_cast<uint64_t>(dti), 2);
+				}
+			}
+
+			// template_fdiffs()
+			for (const FrameDependencyTemplate &frameTemplate : mStructure.templates) {
+				for (int fdiff : frameTemplate.frameDiffs) {
+					assert(fdiff - 1 >= 0);
+					assert(fdiff - 1 < (1 << 4));
+					writeBits(w, (1u << 4) | (fdiff - 1), 1 + 4);
+				}
+				// No more diffs for current template.
+				writeBits(w, 0, 1);
+			}
+
+			// template_chains()
+			assert(mStructure.chainCount >= 0);
+			assert(mStructure.chainCount <= mStructure.decodeTargetCount);
+			writeNonSymmetric(w, mStructure.chainCount, mStructure.decodeTargetCount + 1);
+			if (mStructure.chainCount != 0) {
+				assert(mStructure.decodeTargetProtectedBy.size() ==
+				       static_cast<size_t>(mStructure.decodeTargetCount));
+				for (int protectedBy : mStructure.decodeTargetProtectedBy) {
+					assert(protectedBy >= 0);
+					assert(protectedBy < mStructure.chainCount);
+					writeNonSymmetric(w, protectedBy, mStructure.chainCount);
+				}
+				for (const auto &frameTemplate : mStructure.templates) {
+					assert(frameTemplate.chainDiffs.size() ==
+					       static_cast<size_t>(mStructure.chainCount));
+					for (int chain_diff : frameTemplate.chainDiffs) {
+						assert(chain_diff >= 0);
+						assert(chain_diff < (1 << 4));
+						writeBits(w, chain_diff, 4);
+					}
+				}
+			}
+
+			bool hasResolutions = !mStructure.resolutions.empty();
+			writeBits(w, hasResolutions ? 1 : 0, 1);
+			if (hasResolutions) {
+				// render_resolutions()
+				assert(mStructure.resolutions.size() ==
+				       static_cast<size_t>(mStructure.templates.back().spatialId) + 1);
+				for (const RenderResolution &resolution : mStructure.resolutions) {
+					assert(resolution.width > 0);
+					assert(resolution.width <= (1 << 16));
+					assert(resolution.height > 0);
+					assert(resolution.height <= (1 << 16));
+
+					writeBits(w, resolution.width - 1, 16);
+					writeBits(w, resolution.height - 1, 16);
+				}
+			}
+		}
+		if (activeDecodeTargetsPresentFlag) {
+			writeBits(w, *mDescriptor.activeDecodeTargetsBitmask, mStructure.decodeTargetCount);
+		}
+	}
+
+	// frame_dependency_definition()
+	if (bestTemplate.needCustomDtis) {
+		// frame_dtis()
+		assert(mDescriptor.dependencyTemplate.decodeTargetIndications.size() ==
+		       static_cast<size_t>(mStructure.decodeTargetCount));
+		for (DecodeTargetIndication dti : mDescriptor.dependencyTemplate.decodeTargetIndications) {
+			writeBits(w, static_cast<uint32_t>(dti), 2);
+		}
+	}
+	if (bestTemplate.needCustomFdiffs) {
+		// frame_fdiffs()
+		for (int fdiff : mDescriptor.dependencyTemplate.frameDiffs) {
+			assert(fdiff > 0);
+			assert(fdiff <= (1 << 12));
+			if (fdiff <= (1 << 4)) {
+				writeBits(w, (1u << 4) | (fdiff - 1), 2 + 4);
+			} else if (fdiff <= (1 << 8)) {
+				writeBits(w, (2u << 8) | (fdiff - 1), 2 + 8);
+			} else { // fdiff <= (1 << 12)
+				writeBits(w, (3u << 12) | (fdiff - 1), 2 + 12);
+			}
+		}
+		// No more diffs.
+		writeBits(w, 0, 2);
+	}
+	if (bestTemplate.needCustomChains) {
+		// frame_chains()
+		for (int i = 0; i < mStructure.chainCount; ++i) {
+			int chainDiff = mActiveChains[i] ? mDescriptor.dependencyTemplate.chainDiffs[i] : 0;
+			assert(chainDiff >= 0);
+			assert(chainDiff < (1 << 8));
+			writeBits(w, chainDiff, 8);
+		}
+	}
+}
+void DependencyDescriptorWriter::writeBits(BitWriter &writer, uint64_t v, size_t bits) const {
+	if (!writer.write(v, bits)) {
+		throw std::logic_error("Failed to write bits");
+	}
+}
+void DependencyDescriptorWriter::writeNonSymmetric(BitWriter &writer, uint64_t v,
+                                                   uint64_t n) const {
+	if (!writer.writeNonSymmetric(v, n)) {
+		throw std::logic_error("Failed to write non-symmetric value");
+	}
+}
+
+} // namespace rtc

+ 30 - 6
src/rtp.cpp

@@ -130,22 +130,46 @@ void RtpExtensionHeader::setHeaderLength(uint16_t headerLength) {
 
 void RtpExtensionHeader::clearBody() { std::memset(getBody(), 0, getSize()); }
 
-void RtpExtensionHeader::writeOneByteHeader(size_t offset, uint8_t id, const byte *value,
-                                            size_t size) {
+size_t RtpExtensionHeader::writeOneByteHeader(size_t offset, uint8_t id, const byte *value,
+                                              size_t size) {
 	if ((id == 0) || (id > 14) || (size == 0) || (size > 16) || ((offset + 1 + size) > getSize()))
-		return;
+		return 0;
 	auto buf = getBody() + offset;
 	buf[0] = id << 4;
 	if (size != 1) {
 		buf[0] |= (uint8_t(size) - 1);
 	}
 	std::memcpy(buf + 1, value, size);
+	return 1 + size;
+}
+
+size_t RtpExtensionHeader::writeTwoByteHeader(size_t offset, uint8_t id, const byte *value,
+                                              size_t size) {
+	if ((id == 0) || (size > 255) || ((offset + 2 + size) > getSize()))
+		return 0;
+	auto buf = getBody() + offset;
+	buf[0] = id;
+	buf[1] = uint8_t(size);
+	std::memcpy(buf + 2, value, size);
+	return 2 + size;
 }
 
-void RtpExtensionHeader::writeCurrentVideoOrientation(size_t offset, const uint8_t id,
-                                                      uint8_t value) {
+size_t RtpExtensionHeader::writeCurrentVideoOrientation(bool twoByteHeader, size_t offset,
+                                                        const uint8_t id, uint8_t value) {
 	auto v = std::byte{value};
-	writeOneByteHeader(offset, id, &v, 1);
+	if (twoByteHeader) {
+		return writeTwoByteHeader(offset, id, &v, 1);
+	} else {
+		return writeOneByteHeader(offset, id, &v, 1);
+	}
+}
+size_t RtpExtensionHeader::writeHeader(bool twoByteHeader, size_t offset, uint8_t id,
+                                       const byte *value, size_t size) {
+	if (twoByteHeader) {
+		return writeTwoByteHeader(offset, id, value, size);
+	} else {
+		return writeOneByteHeader(offset, id, value, size);
+	}
 }
 
 SSRC RtcpReportBlock::getSSRC() const { return ntohl(_ssrc); }

+ 63 - 30
src/rtppacketizer.cpp

@@ -26,26 +26,50 @@ std::vector<binary> RtpPacketizer::fragment(binary data) {
 
 message_ptr RtpPacketizer::packetize(const binary &payload, bool mark) {
 	size_t rtpExtHeaderSize = 0;
+	bool twoByteHeader = false;
 
-	const bool setVideoRotation = (rtpConfig->videoOrientationId != 0) &&
-	                              (rtpConfig->videoOrientationId <
-	                               15) && // needs fixing if longer extension headers are supported
-	                              mark &&
-	                              (rtpConfig->videoOrientation != 0);
+	const bool setVideoRotation =
+	    (rtpConfig->videoOrientationId != 0) && mark && (rtpConfig->videoOrientation != 0);
+
+	std::optional<DependencyDescriptorWriter> ddWriter;
+	if (rtpConfig->dependencyDescriptorContext.has_value()) {
+		ddWriter.emplace(*rtpConfig->dependencyDescriptorContext);
+	}
+
+	// Determine if a two-byte header is necessary
+	// Check for dependency descriptor extension
+	if (ddWriter.has_value()) {
+		auto sizeBytes = ddWriter->getSize();
+		if (sizeBytes > 16 || rtpConfig->dependencyDescriptorId > 14) {
+			twoByteHeader = true;
+		}
+	}
+	// Check for other extensions
+	if ((setVideoRotation && rtpConfig->videoOrientationId > 14) ||
+	    (rtpConfig->mid.has_value() && rtpConfig->midId > 14) ||
+	    (rtpConfig->rid.has_value() && rtpConfig->ridId > 14) ||
+	    rtpConfig->playoutDelayId > 14) {
+		twoByteHeader = true;
+	}
+	size_t headerSize = twoByteHeader ? 2 : 1;
 
 	if (setVideoRotation)
-		rtpExtHeaderSize += 2;
+		rtpExtHeaderSize += headerSize + 1;
 
-	const bool setPlayoutDelay = (rtpConfig->playoutDelayId > 0 && rtpConfig->playoutDelayId < 15);
+	const bool setPlayoutDelay = rtpConfig->playoutDelayId > 0;
 
 	if (setPlayoutDelay)
-		rtpExtHeaderSize += 4;
+		rtpExtHeaderSize += headerSize + 3;
 
 	if (rtpConfig->mid.has_value())
-		rtpExtHeaderSize += (1 + rtpConfig->mid->length());
+		rtpExtHeaderSize += headerSize + rtpConfig->mid->length();
 
 	if (rtpConfig->rid.has_value())
-		rtpExtHeaderSize += (1 + rtpConfig->rid->length());
+		rtpExtHeaderSize += headerSize + rtpConfig->rid->length();
+
+	if (ddWriter.has_value()) {
+		rtpExtHeaderSize += headerSize + ddWriter->getSize();
+	}
 
 	if (rtpExtHeaderSize != 0)
 		rtpExtHeaderSize += 4;
@@ -67,7 +91,7 @@ message_ptr RtpPacketizer::packetize(const binary &payload, bool mark) {
 		rtp->setExtension(true);
 
 		auto extHeader = rtp->getExtensionHeader();
-		extHeader->setProfileSpecificId(0xbede);
+		extHeader->setProfileSpecificId(twoByteHeader ? 0x1000 : 0xbede);
 
 		auto headerLength = static_cast<uint16_t>(rtpExtHeaderSize / 4) - 1;
 
@@ -76,24 +100,30 @@ message_ptr RtpPacketizer::packetize(const binary &payload, bool mark) {
 
 		size_t offset = 0;
 		if (setVideoRotation) {
-			extHeader->writeCurrentVideoOrientation(offset, rtpConfig->videoOrientationId,
-			                                        rtpConfig->videoOrientation);
-			offset += 2;
+			offset += extHeader->writeCurrentVideoOrientation(
+			    twoByteHeader, offset, rtpConfig->videoOrientationId, rtpConfig->videoOrientation);
 		}
 
 		if (rtpConfig->mid.has_value()) {
-			extHeader->writeOneByteHeader(
-			    offset, rtpConfig->midId,
-			    reinterpret_cast<const std::byte *>(rtpConfig->mid->c_str()),
-			    rtpConfig->mid->length());
-			offset += (1 + rtpConfig->mid->length());
+			offset +=
+			    extHeader->writeHeader(twoByteHeader, offset, rtpConfig->midId,
+			                           reinterpret_cast<const std::byte *>(rtpConfig->mid->c_str()),
+			                           rtpConfig->mid->length());
 		}
 
 		if (rtpConfig->rid.has_value()) {
-			extHeader->writeOneByteHeader(
-			    offset, rtpConfig->ridId,
-			    reinterpret_cast<const std::byte *>(rtpConfig->rid->c_str()),
-			    rtpConfig->rid->length());
+			offset +=
+			    extHeader->writeHeader(twoByteHeader, offset, rtpConfig->ridId,
+			                           reinterpret_cast<const std::byte *>(rtpConfig->rid->c_str()),
+			                           rtpConfig->rid->length());
+		}
+
+		if (ddWriter.has_value()) {
+			auto sizeBytes = ddWriter->getSize();
+			std::vector<std::byte> buf(sizeBytes);
+			ddWriter->writeTo(buf.data(), sizeBytes);
+			offset += extHeader->writeHeader(
+			    twoByteHeader, offset, rtpConfig->dependencyDescriptorId, buf.data(), sizeBytes);
 		}
 
 		if (setPlayoutDelay) {
@@ -104,8 +134,8 @@ message_ptr RtpPacketizer::packetize(const binary &payload, bool mark) {
 			byte data[] = {byte((min >> 4) & 0xFF), byte(((min & 0xF) << 4) | ((max >> 8) & 0xF)),
 			               byte(max & 0xFF)};
 
-			extHeader->writeOneByteHeader(offset, rtpConfig->playoutDelayId, data, 3);
-			offset += 4;
+			offset += extHeader->writeHeader(
+			    twoByteHeader, offset, rtpConfig->playoutDelayId, data, 3);
 		}
 	}
 
@@ -140,11 +170,14 @@ void RtpPacketizer::outgoing(message_vector &messages,
 		}
 
 		auto payloads = fragment(std::move(*message));
-		if (payloads.size() > 0) {
-			for (size_t i = 0; i < payloads.size() - 1; i++)
-				result.push_back(packetize(payloads[i], false));
-
-			result.push_back(packetize(payloads[payloads.size() - 1], true));
+		for (size_t i = 0; i < payloads.size(); i++) {
+			if (rtpConfig->dependencyDescriptorContext.has_value()) {
+				auto &ctx = *rtpConfig->dependencyDescriptorContext;
+				ctx.descriptor.startOfFrame = i == 0;
+				ctx.descriptor.endOfFrame = i == payloads.size() - 1;
+			}
+			bool mark = i == payloads.size() - 1;
+			result.push_back(packetize(payloads[i], mark));
 		}
 	}