Browse Source

Merge pull request #176 from stazio/api_updates

Changing Behavior of RtcpSession and Many Bug Fixes
Paul-Louis Ageneau 4 years ago
parent
commit
d47492a54e

+ 3 - 0
.gitmodules

@@ -10,3 +10,6 @@
 [submodule "deps/json"]
 	path = deps/json
 	url = https://github.com/nlohmann/json.git
+[submodule "deps/libsrtp"]
+	path = deps/libsrtp
+	url = https://github.com/cisco/libsrtp.git

+ 12 - 2
CMakeLists.txt

@@ -102,7 +102,7 @@ set(TESTS_SOURCES
     ${CMAKE_CURRENT_SOURCE_DIR}/test/capi_track.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/test/websocket.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/test/benchmark.cpp
-)
+        include/rtc/rtp.hpp)
 
 set(CMAKE_THREAD_PREFER_PTHREAD TRUE)
 set(THREADS_PREFER_PTHREAD_FLAG TRUE)
@@ -173,6 +173,15 @@ if(USE_SRTP STREQUAL "AUTO")
 	else()
 		message(STATUS "LibSRTP NOT found, compiling WITHOUT media transport")
 	endif()
+elseif (USE_SRTP STREQUAL "COMPILE")
+	message(STATUS "Compiling LibSRTP from source; compiling with media transport")
+	add_subdirectory(deps/libsrtp EXCLUDE_FROM_ALL)
+	target_compile_definitions(datachannel PUBLIC RTC_ENABLE_MEDIA=1)
+	target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_MEDIA=1)
+	target_compile_definitions(datachannel PUBLIC RTC_SRTP_FROM_SOURCE=1)
+	target_compile_definitions(datachannel-static PUBLIC RTC_SRTP_FROM_SOURCE=1)
+	target_link_libraries(datachannel PRIVATE srtp2)
+	target_link_libraries(datachannel-static PRIVATE srtp2)
 elseif(USE_SRTP)
 	find_package(SRTP REQUIRED)
 endif()
@@ -189,7 +198,7 @@ if(USE_SRTP AND SRTP_FOUND)
 	target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_MEDIA=1)
 	target_link_libraries(datachannel PRIVATE SRTP::SRTP)
 	target_link_libraries(datachannel-static PRIVATE SRTP::SRTP)
-else()
+elseif (NOT USE_SRTP  STREQUAL "COMPILE")
 	target_compile_definitions(datachannel PUBLIC RTC_ENABLE_MEDIA=0)
 	target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_MEDIA=0)
 endif()
@@ -296,6 +305,7 @@ if(NOT NO_EXAMPLES)
 	add_subdirectory(deps/json)
 	add_subdirectory(examples/client)
 	add_subdirectory(examples/media)
+	add_subdirectory(examples/sfu-media)
 	add_subdirectory(examples/copy-paste)
 	add_subdirectory(examples/copy-paste-capi)
 endif()

+ 1 - 0
deps/libsrtp

@@ -0,0 +1 @@
+Subproject commit 7d351de8177b33c96669bb79dc684a8dc64c2483

+ 2 - 2
examples/media/main.cpp

@@ -1,6 +1,6 @@
 /*
  * libdatachannel client example
- * Copyright (c) 2020 Staz M
+ * Copyright (c) 2020 Staz Modrzynski
  * Copyright (c) 2020 Paul-Louis Ageneau
  *
  * This program is free software; you can redistribute it and/or
@@ -67,7 +67,7 @@ int main() {
 
 		auto track = pc->addTrack(media);
 
-		auto session = std::make_shared<rtc::RtcpSession>();
+		auto session = std::make_shared<rtc::RtcpReceivingSession>();
 		track->setRtcpHandler(session);
 
 		track->onMessage(

+ 14 - 0
examples/sfu-media/CMakeLists.txt

@@ -0,0 +1,14 @@
+cmake_minimum_required(VERSION 3.7)
+
+add_executable(datachannel-sfu-media main.cpp)
+set_target_properties(datachannel-sfu-media PROPERTIES
+        CXX_STANDARD 17
+        OUTPUT_NAME sfu-media)
+
+if(WIN32)
+    target_link_libraries(datachannel-sfu-media datachannel-static) # DLL exports only the C API
+else()
+    target_link_libraries(datachannel-sfu-media datachannel)
+endif()
+
+target_link_libraries(datachannel-sfu-media datachannel nlohmann_json)

+ 131 - 0
examples/sfu-media/main.cpp

@@ -0,0 +1,131 @@
+/*
+ * libdatachannel client example
+ * Copyright (c) 2020 Staz Modrzynski
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This program is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU General Public License
+ * as published by the Free Software Foundation; either version 2
+ * of the License, or (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#define _WINSOCK_DEPRECATED_NO_WARNINGS
+
+#include "rtc/rtc.hpp"
+
+#include <iostream>
+#include <memory>
+
+#include <nlohmann/json.hpp>
+
+using nlohmann::json;
+
+struct Receiver {
+    std::shared_ptr<rtc::PeerConnection> conn;
+    std::shared_ptr<rtc::Track> track;
+};
+int main() {
+    std::vector<std::shared_ptr<Receiver>> receivers;
+
+	try {
+		rtc::InitLogger(rtc::LogLevel::Info);
+
+		auto pc = std::make_shared<rtc::PeerConnection>();
+		pc->onStateChange(
+		    [](rtc::PeerConnection::State state) { std::cout << "State: " << state << std::endl; });
+		pc->onGatheringStateChange([pc](rtc::PeerConnection::GatheringState state) {
+			std::cout << "Gathering State: " << state << std::endl;
+			if (state == rtc::PeerConnection::GatheringState::Complete) {
+				auto description = pc->localDescription();
+				json message = {{"type", description->typeString()},
+				                {"sdp", std::string(description.value())}};
+                std::cout << "Please copy/paste this offer to the SENDER: " << message << std::endl;
+			}
+		});
+
+		rtc::Description::Video media("video", rtc::Description::Direction::RecvOnly);
+		media.addH264Codec(96);
+		media.setBitrate(3000); // Request 3Mbps (Browsers do not encode more than 2.5MBps from a webcam)
+
+		auto track = pc->addTrack(media);
+        pc->setLocalDescription();
+
+		auto session = std::make_shared<rtc::RtcpReceivingSession>();
+		track->setRtcpHandler(session);
+
+		const rtc::SSRC targetSSRC = 4;
+
+		track->onMessage(
+		    [&receivers](rtc::binary message) {
+			    // This is an RTP packet
+			    auto rtp = (rtc::RTP*) message.data();
+			    rtp->setSsrc(targetSSRC);
+			    for (auto pc : receivers) {
+			        if (pc->track != nullptr && pc->track->isOpen()) {
+                        pc->track->send(message);
+                    }
+			    }
+		    },
+		    nullptr);
+
+        // Set the SENDERS Answer
+        {
+            std::cout << "Please copy/paste the answer provided by the SENDER: " << std::endl;
+            std::string sdp;
+            std::getline(std::cin, sdp);
+            std::cout << "Got answer" << sdp << std::endl;
+            json j = json::parse(sdp);
+            rtc::Description answer(j["sdp"].get<std::string>(), j["type"].get<std::string>());
+            pc->setRemoteDescription(answer);
+        }
+
+        // For each receiver
+		while (true) {
+            auto pc = std::make_shared<Receiver>();
+            pc->conn = std::make_shared<rtc::PeerConnection>();
+            pc->conn->onStateChange(
+                    [](rtc::PeerConnection::State state) { std::cout << "State: " << state << std::endl; });
+            pc->conn->onGatheringStateChange([pc](rtc::PeerConnection::GatheringState state) {
+                std::cout << "Gathering State: " << state << std::endl;
+                if (state == rtc::PeerConnection::GatheringState::Complete) {
+                    auto description = pc->conn->localDescription();
+                    json message = {{"type", description->typeString()},
+                                    {"sdp", std::string(description.value())}};
+                    std::cout << "Please copy/paste this offer to the RECEIVER: " << message << std::endl;
+                }
+            });
+            rtc::Description::Video media("video", rtc::Description::Direction::SendOnly);
+            media.addH264Codec(96);
+            media.setBitrate(
+                    3000); // Request 3Mbps (Browsers do not encode more than 2.5MBps from a webcam)
+
+            media.addSSRC(targetSSRC, "video-send");
+
+            pc->track = pc->conn->addTrack(media);
+            pc->conn->setLocalDescription();
+
+            pc->track->onMessage([](rtc::binary var){}, nullptr);
+
+            std::cout << "Please copy/paste the answer provided by the RECEIVER: " << std::endl;
+            std::string sdp;
+            std::getline(std::cin, sdp);
+            std::cout << "Got answer" << sdp << std::endl;
+            json j = json::parse(sdp);
+            rtc::Description answer(j["sdp"].get<std::string>(), j["type"].get<std::string>());
+            pc->conn->setRemoteDescription(answer);
+
+            receivers.push_back(pc);
+		}
+
+	} catch (const std::exception &e) {
+		std::cerr << "Error: " << e.what() << std::endl;
+	}
+}

+ 87 - 0
examples/sfu-media/main.html

@@ -0,0 +1,87 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <title>libdatachannel media example</title>
+</head>
+<body>
+
+<div style="display:inline-block; width:40%;">
+    <h1>SENDER</h1>
+    <p id="send-help">Please enter the offer provided to you by the application: </p>
+    <textarea style="width:100%;" id=send-text rows="50"></textarea>
+    <button id=send-btn>Submit</button>
+</div>
+<div style="display:inline-block; width:40%;">
+    <h1>RECEIVER</h1>
+    <p id="recv-help">Please enter the offer provided to you by the application: </p>
+    <textarea id=recv-text style="width:100%;" rows="50"></textarea>
+    <button id=recv-btn>Submit</button>
+</div>
+<div id="videos">
+
+</div>
+<script>
+    document.querySelector('#send-btn').addEventListener('click',  async () => {
+        let offer = JSON.parse(document.querySelector('#send-text').value);
+        rtc = new RTCPeerConnection({
+            // Recommended for libdatachannel
+            bundlePolicy: "max-bundle",
+        });
+
+        rtc.onicegatheringstatechange = (state) => {
+            if (rtc.iceGatheringState === 'complete') {
+                // We only want to provide an answer once all of our candidates have been added to the SDP.
+                let answer = rtc.localDescription;
+                document.querySelector('#send-text').value = JSON.stringify({"type": answer.type, sdp: answer.sdp});
+                document.querySelector('#send-help').value = 'Please paste the answer in the application.';
+                alert('Please paste the answer in the application.');
+            }
+        }
+        await rtc.setRemoteDescription(offer);
+
+        let media = await navigator.mediaDevices.getUserMedia({
+            video: {
+                width: 1280,
+                height: 720
+            }
+        });
+        media.getTracks().forEach(track => rtc.addTrack(track, media));
+        let answer = await rtc.createAnswer();
+        await rtc.setLocalDescription(answer);
+    });
+
+    document.querySelector('#recv-btn').addEventListener('click',  async () => {
+        let offer = JSON.parse(document.querySelector('#recv-text').value);
+        rtc = new RTCPeerConnection({
+            // Recommended for libdatachannel
+            bundlePolicy: "max-bundle",
+        });
+
+        rtc.onicegatheringstatechange = (state) => {
+            if (rtc.iceGatheringState === 'complete') {
+                // We only want to provide an answer once all of our candidates have been added to the SDP.
+                let answer = rtc.localDescription;
+                document.querySelector('#recv-text').value = JSON.stringify({"type": answer.type, sdp: answer.sdp});
+                document.querySelector('#recv-help').value = 'Please paste the answer in the application.';
+                alert('Please paste the answer in the application.');
+            }
+        }
+        let trackCount = 0;
+        rtc.ontrack = (ev) => {
+            let thisID = trackCount++;
+
+            document.querySelector("#videos").innerHTML += "<video width=100% height=100% id='video-" + thisID + "'></video>";
+            let tracks = [];
+            rtc.getReceivers().forEach(recv => tracks.push(recv.track));
+            document.querySelector("#video-" + thisID).srcObject = new MediaStream(tracks);
+            document.querySelector("#video-" + thisID).play();
+        };
+        await rtc.setRemoteDescription(offer);
+        let answer = await rtc.createAnswer();
+        await rtc.setLocalDescription(answer);
+    });
+</script>
+
+</body>
+</html>

+ 54 - 22
include/rtc/description.hpp

@@ -1,6 +1,6 @@
 /**
  * Copyright (c) 2019-2020 Paul-Louis Ageneau
- * Copyright (c) 2020 Staz M
+ * Copyright (c) 2020 Staz Modrzynski
  *
  * This library is free software; you can redistribute it and/or
  * modify it under the terms of the GNU Lesser General Public
@@ -77,6 +77,12 @@ public:
 
 		virtual void parseSdpLine(string_view line);
 
+
+        std::vector<string>::iterator beginAttributes();
+        std::vector<string>::iterator endAttributes();
+        std::vector<string>::iterator removeAttribute(std::vector<string>::iterator iterator);
+
+
 	protected:
 		Entry(const string &mline, string mid, Direction dir = Direction::Unknown);
 		virtual string generateSdpLines(string_view eol) const;
@@ -124,12 +130,13 @@ public:
 		string description() const override;
 		Media reciprocate() const;
 
-		void removeFormat(const string &fmt);
+        void removeFormat(const string &fmt);
 
-		void addVideoCodec(int payloadType, const string &codec);
-		void addH264Codec(int payloadType);
-		void addVP8Codec(int payloadType);
-		void addVP9Codec(int payloadType);
+        void addSSRC(uint32_t ssrc, std::string name);
+        void addSSRC(uint32_t ssrc);
+        void replaceSSRC(uint32_t oldSSRC, uint32_t ssrc, string name);
+        bool hasSSRC(uint32_t ssrc);
+        std::vector<uint32_t> getSSRCs();
 
 		void setBitrate(int bitrate);
 		int getBitrate() const;
@@ -138,40 +145,65 @@ public:
 
 		virtual void parseSdpLine(string_view line) override;
 
-	private:
-		virtual string generateSdpLines(string_view eol) const override;
+        struct RTPMap {
+            RTPMap(string_view mline);
+            RTPMap() {}
 
-		int mBas = -1;
+            void removeFB(const string &string);
+            void addFB(const string &string);
+            void addAttribute(std::string attr) {
+                fmtps.emplace_back(attr);
+            }
+
+
+            int pt;
+            string format;
+            int clockRate;
+            string encParams;
+
+            std::vector<string> rtcpFbs;
+            std::vector<string> fmtps;
 
-		struct RTPMap {
-			RTPMap(string_view mline);
+            static int parsePT(string_view view);
+            void setMLine(string_view view);
+        };
 
-			void removeFB(const string &string);
-			void addFB(const string &string);
+        std::map<int, RTPMap>::iterator beginMaps();
+        std::map<int, RTPMap>::iterator endMaps();
+        std::map<int, RTPMap>::iterator removeMap(std::map<int, RTPMap>::iterator iterator);
 
-			int pt;
-			string format;
-			int clockRate;
-			string encParams;
+	private:
+		virtual string generateSdpLines(string_view eol) const override;
 
-			std::vector<string> rtcpFbs;
-			std::vector<string> fmtps;
-		};
+		int mBas = -1;
 
 		Media::RTPMap &getFormat(int fmt);
 		Media::RTPMap &getFormat(const string &fmt);
 
 		std::map<int, RTPMap> mRtpMap;
-	};
+		std::vector<uint32_t> mSsrcs;
+
+	public:
+        void addRTPMap(const RTPMap& map);
+
+    };
 
 	class Audio : public Media {
 	public:
 		Audio(string mid = "audio", Direction dir = Direction::SendOnly);
+
+        void addAudioCodec(int payloadType, const string &codec);
+        void addOpusCodec(int payloadType);
 	};
 
 	class Video : public Media {
 	public:
 		Video(string mid = "video", Direction dir = Direction::SendOnly);
+
+        void addVideoCodec(int payloadType, const string &codec);
+        void addH264Codec(int payloadType);
+        void addVP8Codec(int payloadType);
+        void addVP9Codec(int payloadType);
 	};
 
 	bool hasApplication() const;
@@ -186,7 +218,7 @@ public:
 
 	std::variant<Media *, Application *> media(int index);
 	std::variant<const Media *, const Application *> media(int index) const;
-	int mediaCount() const;
+	size_t mediaCount() const;
 
 	Application *application();
 

+ 7 - 1
include/rtc/peerconnection.hpp

@@ -102,6 +102,7 @@ public:
 	bool getSelectedCandidatePair(Candidate *local, Candidate *remote);
 
 	void setLocalDescription(Description::Type type = Description::Type::Unspec);
+
 	void setRemoteDescription(Description description);
 	void addRemoteCandidate(Candidate candidate);
 
@@ -138,6 +139,8 @@ private:
 	void forwardMessage(message_ptr message);
 	void forwardMedia(message_ptr message);
 	void forwardBufferedAmount(uint16_t stream, size_t amount);
+    std::optional<std::string> getMidFromSSRC(SSRC ssrc);
+    std::optional<uint32_t> getMLineFromSSRC(SSRC ssrc);
 
 	std::shared_ptr<DataChannel> emplaceDataChannel(Description::Role role, string label,
 	                                                DataChannelInit init);
@@ -150,6 +153,7 @@ private:
 	void incomingTrack(Description::Media description);
 	void openTracks();
 
+
 	void validateRemoteDescription(const Description &description);
 	void processLocalDescription(Description description);
 	void processLocalCandidate(Candidate candidate);
@@ -180,9 +184,11 @@ private:
 
 	std::unordered_map<uint16_t, std::weak_ptr<DataChannel>> mDataChannels;     // by stream ID
 	std::unordered_map<string, std::weak_ptr<Track>> mTracks;                   // by mid
+	std::vector<std::weak_ptr<Track>> mTrackLines;                              // by SDP order
 	std::shared_mutex mDataChannelsMutex, mTracksMutex;
 
-	std::unordered_map<unsigned int, string> mMidFromPayloadType; // cache
+	std::unordered_map<uint32_t, string> mMidFromSssrc; // cache
+    std::unordered_map<uint32_t , unsigned int> mMLineFromSssrc; // cache
 
 	std::atomic<State> mState;
 	std::atomic<GatheringState> mGatheringState;

+ 45 - 11
include/rtc/rtcp.hpp

@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2020 Staz M
+ * Copyright (c) 2020 Staz Modrzynski
  * Copyright (c) 2020 Paul-Louis Ageneau
  *
  * This library is free software; you can redistribute it and/or
@@ -20,35 +20,69 @@
 #ifndef RTC_RTCP_H
 #define RTC_RTCP_H
 
+#include <utility>
+
 #include "include.hpp"
 #include "log.hpp"
 #include "message.hpp"
+#include "rtp.hpp"
 
 namespace rtc {
 
-typedef uint32_t SSRC;
-
 class RtcpHandler {
+protected:
+    /**
+     * Use this callback when trying to send custom data (such as RTCP) to the client.
+     */
+    synchronized_callback<rtc::message_ptr> outgoingCallback;
 public:
-	virtual void onOutgoing(std::function<void(rtc::message_ptr)> cb) = 0;
-	virtual std::optional<rtc::message_ptr> incoming(rtc::message_ptr ptr) = 0;
+    /**
+     * Called when there is traffic coming from the peer
+     * @param ptr
+     * @return
+     */
+    virtual rtc::message_ptr incoming(rtc::message_ptr ptr) = 0;
+
+    /**
+     * Called when there is traffic that needs to be sent to the peer
+     * @param ptr
+     * @return
+     */
+    virtual rtc::message_ptr outgoing(rtc::message_ptr ptr) = 0;
+
+
+    /**
+     * This callback is used to send traffic back to the peer.
+     * This callback skips calling the track's methods.
+     * @param cb
+     */
+    void onOutgoing(const std::function<void(rtc::message_ptr)>& cb);
+
+    virtual bool requestKeyframe() {return false;}
+
 };
 
+class Track;
+
 // An RtcpSession can be plugged into a Track to handle the whole RTCP session
-class RtcpSession : public RtcpHandler {
+class RtcpReceivingSession : public RtcpHandler {
 public:
-	void onOutgoing(std::function<void(rtc::message_ptr)> cb) override;
 
-	std::optional<rtc::message_ptr> incoming(rtc::message_ptr ptr) override;
+    rtc::message_ptr incoming(rtc::message_ptr ptr) override;
+    rtc::message_ptr outgoing(rtc::message_ptr ptr) override;
+    bool send(rtc::message_ptr ptr);
+
 	void requestBitrate(unsigned int newBitrate);
 
-private:
+    bool requestKeyframe() override;
+
+protected:
 	void pushREMB(unsigned int bitrate);
 	void pushRR(unsigned int lastSR_delay);
-	void tx(message_ptr msg);
+
+    void pushPLI();
 
 	unsigned int mRequestedBitrate = 0;
-	synchronized_callback<rtc::message_ptr> mTxCallback;
 	SSRC mSsrc = 0;
 	uint32_t mGreatestSeqNo = 0;
 	uint64_t mSyncRTPTS, mSyncNTPTS;

+ 527 - 0
include/rtc/rtp.hpp

@@ -0,0 +1,527 @@
+/**
+ * Copyright (c) 2020 Staz Modrzynski
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef WEBRTC_SERVER_RTP_HPP
+#define WEBRTC_SERVER_RTP_HPP
+
+#include <cmath>
+#ifdef _WIN32
+#include <winsock2.h>
+#else
+#include <arpa/inet.h>
+#endif
+#include <rtc/log.hpp>
+
+#ifndef htonll
+#define htonll(x)                                                                                  \
+	((uint64_t)htonl(((uint64_t)(x)&0xFFFFFFFF) << 32) | (uint64_t)htonl((uint64_t)(x) >> 32))
+#endif
+#ifndef ntohll
+#define ntohll(x) htonll(x)
+#endif
+
+namespace rtc {
+    typedef uint32_t SSRC;
+
+#pragma pack(push, 1)
+
+struct RTP {
+private:
+    uint8_t _first;
+    uint8_t _payloadType;
+    uint16_t _seqNumber;
+    uint32_t _timestamp;
+    SSRC _ssrc;
+
+public:
+    SSRC csrc[16];
+
+    inline uint8_t version() const { return _first >> 6; }
+    inline bool padding() const { return (_first >> 5) & 0x01; }
+    inline uint8_t csrcCount() const { return _first & 0x0F; }
+    inline uint8_t marker() const { return _payloadType & 0b10000000; }
+    inline uint8_t payloadType() const { return _payloadType & 0b01111111; }
+    inline uint16_t seqNumber() const { return ntohs(_seqNumber); }
+    inline uint32_t timestamp() const { return ntohl(_timestamp); }
+    inline uint32_t ssrc() const { return ntohl(_ssrc);}
+
+    inline size_t getSize() const {
+        return ((char*)&csrc) - ((char*)this) + sizeof(SSRC)*csrcCount();
+    }
+
+    char * getBody() const {
+        return ((char*) &csrc) + sizeof(SSRC)*csrcCount();
+    }
+
+    inline void setSeqNumber(uint16_t newSeqNo) {
+        _seqNumber = htons(newSeqNo);
+    }
+    inline void setPayloadType(uint8_t newPayloadType) {
+        _payloadType = (_payloadType & 0b10000000u) | (0b01111111u & newPayloadType);
+    }
+    inline void setSsrc(uint32_t ssrc) {
+        _ssrc = htonl(ssrc);
+    }
+
+    void setTimestamp(uint32_t i) {
+        _timestamp = htonl(i);
+    }
+};
+
+struct RTCP_ReportBlock {
+    SSRC ssrc;
+
+private:
+    uint32_t _fractionLostAndPacketsLost; // fraction lost is 8-bit, packets lost is 24-bit
+    uint16_t _seqNoCycles;
+    uint16_t _highestSeqNo;
+    uint32_t _jitter;
+    uint32_t _lastReport;
+    uint32_t _delaySinceLastReport;
+
+public:
+    inline void preparePacket(SSRC ssrc, [[maybe_unused]] unsigned int packetsLost,
+                              [[maybe_unused]] unsigned int totalPackets, uint16_t highestSeqNo,
+                              uint16_t seqNoCycles, uint32_t jitter, uint64_t lastSR_NTP,
+                              uint64_t lastSR_DELAY) {
+        setSeqNo(highestSeqNo, seqNoCycles);
+        setJitter(jitter);
+        setSSRC(ssrc);
+
+        // Middle 32 bits of NTP Timestamp
+        //		  this->lastReport = lastSR_NTP >> 16u;
+        setNTPOfSR(uint64_t(lastSR_NTP));
+        setDelaySinceSR(uint32_t(lastSR_DELAY));
+
+        // The delay, expressed in units of 1/65536 seconds
+        //		  this->delaySinceLastReport = lastSR_DELAY;
+    }
+
+    inline void setSSRC(SSRC ssrc) { this->ssrc = htonl(ssrc); }
+    inline SSRC getSSRC() const { return ntohl(ssrc); }
+
+    inline void setPacketsLost([[maybe_unused]] unsigned int packetsLost,
+                               [[maybe_unused]] unsigned int totalPackets) {
+        // TODO Implement loss percentages.
+        _fractionLostAndPacketsLost = 0;
+    }
+    inline unsigned int getLossPercentage() const {
+        // TODO Implement loss percentages.
+        return 0;
+    }
+    inline unsigned int getPacketLostCount() const {
+        // TODO Implement total packets lost.
+        return 0;
+    }
+
+    inline uint16_t seqNoCycles() const { return ntohs(_seqNoCycles); }
+    inline uint16_t highestSeqNo() const { return ntohs(_highestSeqNo); }
+    inline uint32_t jitter() const { return ntohl(_jitter); }
+
+    inline void setSeqNo(uint16_t highestSeqNo, uint16_t seqNoCycles) {
+        _highestSeqNo = htons(highestSeqNo);
+        _seqNoCycles = htons(seqNoCycles);
+    }
+
+    inline void setJitter(uint32_t jitter) { _jitter = htonl(jitter); }
+
+    inline void setNTPOfSR(uint64_t ntp) { _lastReport = htonll(ntp >> 16u); }
+    inline uint32_t getNTPOfSR() const { return ntohl(_lastReport) << 16u; }
+
+    inline void setDelaySinceSR(uint32_t sr) {
+        // The delay, expressed in units of 1/65536 seconds
+        _delaySinceLastReport = htonl(sr);
+    }
+    inline uint32_t getDelaySinceSR() const { return ntohl(_delaySinceLastReport); }
+
+    inline void log() const {
+        PLOG_VERBOSE << "RTCP report block: "
+                   << "ssrc="
+                   << ntohl(ssrc)
+                   // TODO: Implement these reports
+                   //	<< ", fractionLost=" << fractionLost
+                   //	<< ", packetsLost=" << packetsLost
+                   << ", highestSeqNo=" << highestSeqNo() << ", seqNoCycles=" << seqNoCycles()
+                   << ", jitter=" << jitter() << ", lastSR=" << getNTPOfSR()
+                   << ", lastSRDelay=" << getDelaySinceSR();
+    }
+};
+
+struct RTCP_HEADER {
+private:
+    uint8_t _first;
+    uint8_t _payloadType;
+    uint16_t _length;
+
+public:
+    inline uint8_t version() const { return _first >> 6; }
+    inline bool padding() const { return (_first >> 5) & 0x01; }
+    inline uint8_t reportCount() const { return _first & 0x0F; }
+    inline uint8_t payloadType() const { return _payloadType; }
+    inline uint16_t length() const { return ntohs(_length); }
+    inline size_t lengthInBytes() const {
+        return (1+length())*4;
+    }
+
+    inline void setPayloadType(uint8_t type) { _payloadType = type; }
+    inline void setReportCount(uint8_t count) { _first = (_first & 0b11100000u) | (count & 0b00011111u); }
+    inline void setLength(uint16_t length) { _length = htons(length); }
+
+    inline void prepareHeader(uint8_t payloadType, uint8_t reportCount, uint16_t length) {
+        _first = 0b10000000; // version 2, no padding
+        setReportCount(reportCount);
+        setPayloadType(payloadType);
+        setLength(length);
+    }
+
+    inline void log() const {
+        PLOG_INFO << "RTCP header: "
+                   << "version=" << unsigned(version()) << ", padding=" << padding()
+                   << ", reportCount=" << unsigned(reportCount())
+                   << ", payloadType=" << unsigned(payloadType()) << ", length=" << length();
+    }
+};
+
+struct RTCP_FB_HEADER {
+    RTCP_HEADER header;
+    SSRC packetSender;
+    SSRC mediaSource;
+
+    [[nodiscard]] SSRC getPacketSenderSSRC() const {
+        return ntohl(packetSender);
+    }
+
+    [[nodiscard]] SSRC getMediaSourceSSRC() const {
+        return ntohl(mediaSource);
+    }
+
+    void setPacketSenderSSRC(SSRC ssrc) {
+        this->packetSender = htonl(ssrc);
+    }
+
+    void setMediaSourceSSRC(SSRC ssrc) {
+        this->mediaSource = htonl(ssrc);
+    }
+
+    void log() {
+        header.log();
+        PLOG_VERBOSE << "FB: " << " packet sender: " << getPacketSenderSSRC() << " media source: " << getMediaSourceSSRC();
+    }
+};
+
+struct RTCP_SR {
+    RTCP_HEADER header;
+    SSRC _senderSSRC;
+
+private:
+    uint64_t _ntpTimestamp;
+    uint32_t _rtpTimestamp;
+    uint32_t _packetCount;
+    uint32_t _octetCount;
+
+    RTCP_ReportBlock _reportBlocks;
+
+public:
+    inline void preparePacket(SSRC senderSSRC, uint8_t reportCount) {
+        unsigned int length =
+                ((sizeof(header) + 24 + reportCount * sizeof(RTCP_ReportBlock)) / 4) - 1;
+        header.prepareHeader(200, reportCount, uint16_t(length));
+        this->_senderSSRC = htonl(senderSSRC);
+    }
+
+    inline RTCP_ReportBlock *getReportBlock(int num) { return &_reportBlocks + num; }
+    inline const RTCP_ReportBlock *getReportBlock(int num) const { return &_reportBlocks + num; }
+
+    [[nodiscard]] inline size_t getSize() const {
+        // "length" in packet is one less than the number of 32 bit words in the packet.
+        return sizeof(uint32_t) * (1 + size_t(header.length()));
+    }
+
+    inline uint64_t ntpTimestamp() const { return ntohll(_ntpTimestamp); }
+    inline uint32_t rtpTimestamp() const { return ntohl(_rtpTimestamp); }
+    inline uint32_t packetCount() const { return ntohl(_packetCount); }
+    inline uint32_t octetCount() const { return ntohl(_octetCount); }
+    inline uint32_t senderSSRC() const {return ntohl(_senderSSRC);}
+
+    inline void setNtpTimestamp(uint32_t ts) { _ntpTimestamp = htonll(ts); }
+    inline void setRtpTimestamp(uint32_t ts) { _rtpTimestamp = htonl(ts); }
+
+    inline void log() const {
+        header.log();
+        PLOG_VERBOSE << "RTCP SR: "
+                   << " SSRC=" << senderSSRC() << ", NTP_TS=" << ntpTimestamp()
+                   << ", RTP_TS=" << rtpTimestamp() << ", packetCount=" << packetCount()
+                   << ", octetCount=" << octetCount();
+
+        for (unsigned i = 0; i < unsigned(header.reportCount()); i++) {
+            getReportBlock(i)->log();
+        }
+    }
+};
+
+struct RTCP_RR {
+    RTCP_HEADER header;
+    SSRC _senderSSRC;
+
+private:
+    RTCP_ReportBlock _reportBlocks;
+
+public:
+    inline RTCP_ReportBlock *getReportBlock(int num) { return &_reportBlocks + num; }
+    inline const RTCP_ReportBlock *getReportBlock(int num) const { return &_reportBlocks + num; }
+
+    inline SSRC senderSSRC() const { return ntohl(_senderSSRC); }
+    inline void setSenderSSRC(SSRC ssrc) { this->_senderSSRC = htonl(ssrc); }
+
+    [[nodiscard]] inline size_t getSize() const {
+        // "length" in packet is one less than the number of 32 bit words in the packet.
+        return sizeof(uint32_t) * (1 + size_t(header.length()));
+    }
+
+    inline void preparePacket(SSRC senderSSRC, uint8_t reportCount) {
+        // "length" in packet is one less than the number of 32 bit words in the packet.
+        size_t length = (sizeWithReportBlocks(reportCount) / 4) - 1;
+        header.prepareHeader(201, reportCount, uint16_t(length));
+        this->_senderSSRC = htonl(senderSSRC);
+    }
+
+    inline static size_t sizeWithReportBlocks(uint8_t reportCount) {
+        return sizeof(header) + 4 + size_t(reportCount) * sizeof(RTCP_ReportBlock);
+    }
+
+    inline bool isSenderReport() {
+        return header.payloadType() == 200;
+    }
+
+    inline bool isReceiverReport() {
+        return header.payloadType() == 201;
+    }
+
+    inline void log() const {
+        header.log();
+        PLOG_VERBOSE << "RTCP RR: "
+                   << " SSRC=" << ntohl(_senderSSRC);
+
+        for (unsigned i = 0; i < unsigned(header.reportCount()); i++) {
+            getReportBlock(i)->log();
+        }
+    }
+};
+
+
+struct RTCP_REMB {
+    RTCP_FB_HEADER header;
+
+    /*! \brief Unique identifier ('R' 'E' 'M' 'B') */
+    char id[4];
+
+    /*! \brief Num SSRC, Br Exp, Br Mantissa (bit mask) */
+    uint32_t bitrate;
+
+    SSRC ssrc[1];
+
+    [[nodiscard]] unsigned int getSize() const {
+        // "length" in packet is one less than the number of 32 bit words in the packet.
+        return sizeof(uint32_t) * (1 + header.header.length());
+    }
+
+    void preparePacket(SSRC senderSSRC, unsigned int numSSRC, unsigned int bitrate) {
+
+        // Report Count becomes the format here.
+        header.header.prepareHeader(206, 15, 0);
+
+        // Always zero.
+        header.setMediaSourceSSRC(0);
+
+        header.setPacketSenderSSRC(senderSSRC);
+
+        id[0] = 'R';
+        id[1] = 'E';
+        id[2] = 'M';
+        id[3] = 'B';
+
+        setBitrate(numSSRC, bitrate);
+    }
+
+    void setBitrate(unsigned int numSSRC, unsigned int bitrate) {
+        unsigned int exp = 0;
+        while (bitrate > pow(2, 18) - 1) {
+            exp++;
+            bitrate /= 2;
+        }
+
+        // "length" in packet is one less than the number of 32 bit words in the packet.
+        header.header.setLength((offsetof(RTCP_REMB, ssrc) / sizeof(uint32_t)) - 1 + numSSRC);
+
+        this->bitrate = htonl(
+                (numSSRC << (32u - 8u)) | (exp << (32u - 8u - 6u)) | bitrate
+        );
+    }
+
+    void setSsrc(int iterator, SSRC newSssrc){
+        ssrc[iterator] = htonl(newSssrc);
+    }
+    
+    size_t static inline sizeWithSSRCs(int count) {
+        return sizeof(RTCP_REMB) + (count-1)*sizeof(SSRC);
+    }
+};
+
+
+struct RTCP_PLI {
+    RTCP_FB_HEADER header;
+
+    void preparePacket(SSRC messageSSRC) {
+        header.header.prepareHeader(206, 1, 2);
+        header.setPacketSenderSSRC(messageSSRC);
+        header.setMediaSourceSSRC(messageSSRC);
+    }
+
+    void print() {
+        header.log();
+    }
+
+    [[nodiscard]] static unsigned int size() {
+        return sizeof(RTCP_FB_HEADER);
+    }
+};
+
+struct RTCP_FIR_PART {
+    uint32_t ssrc;
+#if __BYTE_ORDER == __BIG_ENDIAN
+    uint32_t seqNo: 8;
+    uint32_t: 24;
+#elif __BYTE_ORDER == __LITTLE_ENDIAN
+    uint32_t: 24;
+    uint32_t seqNo: 8;
+#endif
+};
+
+struct RTCP_FIR {
+    RTCP_FB_HEADER header;
+    RTCP_FIR_PART parts[1];
+
+    void preparePacket(SSRC messageSSRC, uint8_t seqNo) {
+        header.header.prepareHeader(206, 4, 2 + 2 * 1);
+        header.setPacketSenderSSRC(messageSSRC);
+        header.setMediaSourceSSRC(messageSSRC);
+        parts[0].ssrc = htonl(messageSSRC);
+        parts[0].seqNo = seqNo;
+    }
+
+    void print() {
+        header.log();
+    }
+
+    [[nodiscard]] static unsigned int size() {
+        return sizeof(RTCP_FB_HEADER) + sizeof(RTCP_FIR_PART);
+    }
+};
+
+struct RTCP_NACK_PART {
+    uint16_t pid;
+    uint16_t blp;
+};
+
+class RTCP_NACK {
+public:
+    RTCP_FB_HEADER header;
+    RTCP_NACK_PART parts[1];
+public:
+    void preparePacket(SSRC ssrc, unsigned int discreteSeqNoCount) {
+        header.header.prepareHeader(205, 1, 2 + discreteSeqNoCount);
+        header.setMediaSourceSSRC(ssrc);
+        header.setPacketSenderSSRC(ssrc);
+    }
+
+    /**
+     * Add a packet to the list of missing packets.
+     * @param fciCount The number of FCI fields that are present in this packet.
+     *                  Let the number start at zero and let this function grow the number.
+     * @param fciPID The seq no of the active FCI. It will be initialized automatically, and will change automatically.
+     * @param missingPacket The seq no of the missing packet. This will be added to the queue.
+     * @return true if the packet has grown, false otherwise.
+     */
+    bool addMissingPacket(unsigned int *fciCount, uint16_t *fciPID, const uint16_t &missingPacket) {
+        if (*fciCount == 0 || missingPacket < *fciPID || missingPacket > (*fciPID + 16)) {
+            parts[*fciCount].pid = htons(missingPacket);
+            parts[*fciCount].blp = 0;
+            *fciPID = missingPacket;
+            (*fciCount)++;
+            return true;
+        } else {
+            // TODO SPEEED!
+            parts[(*fciCount) - 1].blp = htons(
+                    ntohs(parts[(*fciCount) - 1].blp) | (1u << (unsigned int) (missingPacket - *fciPID)));
+            return false;
+        }
+    }
+
+    [[nodiscard]] static unsigned int getSize(unsigned int discreteSeqNoCount) {
+        return offsetof(RTCP_NACK, parts) + sizeof(RTCP_NACK_PART) * discreteSeqNoCount;
+    }
+
+    [[nodiscard]] unsigned int getSeqNoCount() {
+        return header.header.length() - 2;
+    }
+};
+
+class RTP_RTX {
+private:
+    RTP header;
+public:
+
+    size_t copyTo(RTP *dest, size_t totalSize, uint8_t originalPayloadType) {
+        memmove((char*)dest, (char*)this, header.getSize());
+        dest->setSeqNumber(getOriginalSeqNo());
+        dest->setPayloadType(originalPayloadType);
+        memmove(dest->getBody(), getBody(), getBodySize(totalSize));
+        return totalSize;
+    }
+
+    [[nodiscard]] uint16_t getOriginalSeqNo() const {
+        return ntohs(*(uint16_t *) (header.getBody()));
+    }
+
+    char *getBody() {
+        return header.getBody() + sizeof(uint16_t);
+    }
+
+    size_t getBodySize(size_t totalSize) {
+        return totalSize - ((char *) getBody() - (char *) this);
+    }
+
+    RTP &getHeader() {
+        return header;
+    }
+
+    size_t normalizePacket(size_t totalSize, SSRC originalSSRC, uint8_t originalPayloadType) {
+        header.setSeqNumber(getOriginalSeqNo());
+        header.setSsrc(originalSSRC); // TODO Endianess
+        header.setPayloadType(originalPayloadType);
+        // TODO, the -12 is the size of the header (which is variable!)
+        memmove(header.getBody(), header.getBody() + sizeof(uint16_t), totalSize - 12 - sizeof(uint16_t));
+        return totalSize - sizeof(uint16_t);
+    }
+};
+
+
+#pragma pack(pop)
+};
+#endif //WEBRTC_SERVER_RTP_HPP

+ 4 - 0
include/rtc/track.hpp

@@ -58,8 +58,11 @@ public:
 	std::optional<message_variant> receive() override;
 	std::optional<message_variant> peek() override;
 
+	bool requestKeyframe();
+
 	// RTCP handler
 	void setRtcpHandler(std::shared_ptr<RtcpHandler> handler);
+	std::shared_ptr<RtcpHandler> getRtcpHandler();
 
 private:
 #if RTC_ENABLE_MEDIA
@@ -77,6 +80,7 @@ private:
 	std::shared_ptr<RtcpHandler> mRtcpHandler;
 
 	friend class PeerConnection;
+
 };
 
 } // namespace rtc

+ 161 - 46
src/description.cpp

@@ -1,7 +1,6 @@
 /**
  * Copyright (c) 2019-2020 Paul-Louis Ageneau
- * Copyright (c) 2020 Staz M
- * Copyright (c) 2020 Filip Klembara (in2core)
+ * Copyright (c) 2020 Staz Modrzynski
  *
  * This library is free software; you can redistribute it and/or
  * modify it under the terms of the GNU Lesser General Public
@@ -26,8 +25,8 @@
 #include <chrono>
 #include <iostream>
 #include <random>
-#include <sstream>
-#include <unordered_map>
+//#include <sstream>
+//#include <unordered_map>
 
 using std::shared_ptr;
 using std::size_t;
@@ -423,7 +422,7 @@ Description::media(int index) const {
 	}
 }
 
-int Description::mediaCount() const { return int(mEntries.size()); }
+size_t Description::mediaCount() const { return mEntries.size(); }
 
 Description::Entry::Entry(const string &mline, string mid, Direction dir)
     : mMid(std::move(mid)), mDirection(dir) {
@@ -433,6 +432,7 @@ Description::Entry::Entry(const string &mline, string mid, Direction dir)
 	ss >> mType;
 	ss >> port; // ignored
 	ss >> mDescription;
+
 }
 
 void Description::Entry::setDirection(Direction dir) { mDirection = dir; }
@@ -471,8 +471,10 @@ string Description::Entry::generateSdpLines(string_view eol) const {
 		break;
 	}
 
-	for (const auto &attr : mAttributes)
-		sdp << "a=" << attr << eol;
+	for (const auto &attr : mAttributes) {
+	    if (attr.find("extmap") == std::string::npos && attr.find("rtcp-rsize") == std::string::npos)
+            sdp << "a=" << attr << eol;
+    }
 
 	return sdp.str();
 }
@@ -498,6 +500,39 @@ void Description::Entry::parseSdpLine(string_view line) {
 			mAttributes.emplace_back(line.substr(2));
 	}
 }
+std::vector< string>::iterator Description::Entry::beginAttributes() {
+    return mAttributes.begin();
+}
+std::vector< string>::iterator Description::Entry::endAttributes() {
+    return mAttributes.end();
+}
+std::vector< string>::iterator Description::Entry::removeAttribute(std::vector<string>::iterator it) {
+    return mAttributes.erase(it);
+}
+
+void Description::Media::addSSRC(uint32_t ssrc, std::string name) {
+    mAttributes.emplace_back("ssrc:" + std::to_string(ssrc) + " cname:" + name);
+    mSsrcs.emplace_back(ssrc);
+}
+
+void Description::Media::replaceSSRC(uint32_t oldSSRC, uint32_t ssrc, std::string name) {
+    auto it = mAttributes.begin();
+    while (it != mAttributes.end()) {
+        if (it->find("ssrc:" + std::to_string(oldSSRC)) == 0) {
+            it = mAttributes.erase(it);
+        }else
+            it++;
+    }
+    mAttributes.emplace_back("ssrc:" + std::to_string(ssrc) + " cname:" + name);
+}
+
+void Description::Media::addSSRC(uint32_t ssrc) {
+    mAttributes.emplace_back("ssrc:" + std::to_string(ssrc));
+}
+
+bool Description::Media::hasSSRC(uint32_t ssrc) {
+    return std::find(mSsrcs.begin(), mSsrcs.end(), ssrc) != mSsrcs.end();
+}
 
 Description::Application::Application(string mid)
     : Entry("application 9 UDP/DTLS/SCTP", std::move(mid), Direction::SendRecv) {}
@@ -645,24 +680,57 @@ void Description::Media::removeFormat(const string &fmt) {
 	}
 }
 
-void Description::Media::addVideoCodec(int payloadType, const string &codec) {
+void Description::Video::addVideoCodec(int payloadType, const string &codec) {
 	RTPMap map(std::to_string(payloadType) + ' ' + codec + "/90000");
-	map.addFB("nack");
+    map.addFB("nack");
+    map.addFB("nack pli");
+//    map.addFB("nack fir");
 	map.addFB("goog-remb");
 	if (codec == "H264") {
 		// Use Constrained Baseline profile Level 4.2 (necessary for Firefox)
 		// https://developer.mozilla.org/en-US/docs/Web/Media/Formats/WebRTC_codecs#Supported_video_codecs
 		// TODO: Should be 42E0 but 42C0 appears to be more compatible. Investigate this.
-		map.fmtps.emplace_back("profile-level-id=42E02A;level-asymmetry-allowed=1");
+		map.fmtps.emplace_back("profile-level-id=4de01f;packetization-mode=1;level-asymmetry-allowed=1");
+		
+		// Because certain Android devices don't like me, let us just negotiate some random
+		{
+			RTPMap map(std::to_string(payloadType+1) + ' ' + codec + "/90000");
+			map.addFB("nack");
+            map.addFB("nack pli");
+//            map.addFB("nack fir");
+			map.addFB("goog-remb");
+			addRTPMap(map);
+			}
 	}
-	mRtpMap.emplace(map.pt, map);
+	addRTPMap(map);
+
+
+//	// RTX Packets
+/* TODO
+ *  TIL that Firefox does not properly support the negotiation of RTX! It works, but doesn't negotiate the SSRC so
+ *  we have no idea what SSRC is RTX going to be. Three solutions:
+ *  One) we don't negotitate it and (maybe) break RTX support with Edge.
+ *  Two) we do negotiate it and rebuild the original packet before we send it distribute it to each track.
+ *  Three) we complain to mozilla. This one probably won't do much.
+*/
+//    RTPMap rtx(std::to_string(payloadType+1) + " rtx/90000");
+//    // TODO rtx-time is how long can a request be stashed for before needing to resend it. Needs to be parameterized
+//    rtx.addAttribute("apt=" + std::to_string(payloadType) + ";rtx-time=3000");
+//    addRTPMap(rtx);
+}
+
+void Description::Audio::addAudioCodec(int payloadType, const string &codec) {
+    // TODO This 48000/2 should be parameterized
+    RTPMap map(std::to_string(payloadType) + ' ' + codec + "/48000/2");
+    map.fmtps.emplace_back("maxaveragebitrate=96000; stereo=1; sprop-stereo=1; useinbandfec=1");
+    addRTPMap(map);
 }
 
-void Description::Media::addH264Codec(int pt) { addVideoCodec(pt, "H264"); }
+void Description::Video::addH264Codec(int pt) { addVideoCodec(pt, "H264"); }
 
-void Description::Media::addVP8Codec(int payloadType) { addVideoCodec(payloadType, "VP8"); }
+void Description::Video::addVP8Codec(int payloadType) { addVideoCodec(payloadType, "VP8"); }
 
-void Description::Media::addVP9Codec(int payloadType) { addVideoCodec(payloadType, "VP9"); }
+void Description::Video::addVP9Codec(int payloadType) { addVideoCodec(payloadType, "VP9"); }
 
 void Description::Media::setBitrate(int bitrate) { mBas = bitrate; }
 
@@ -689,8 +757,10 @@ string Description::Media::generateSdpLines(string_view eol) const {
 			sdp << '/' << map.encParams;
 		sdp << eol;
 
-		for (const auto &val : map.rtcpFbs)
-			sdp << "a=rtcp-fb:" << map.pt << ' ' << val << eol;
+		for (const auto &val : map.rtcpFbs) {
+		    if (val != "transport-cc" )
+                sdp << "a=rtcp-fb:" << map.pt << ' ' << val << eol;
+        }
 		for (const auto &val : map.fmtps)
 			sdp << "a=fmtp:" << map.pt << ' ' << val << eol;
 	}
@@ -704,29 +774,32 @@ void Description::Media::parseSdpLine(string_view line) {
 		auto [key, value] = parse_pair(attr);
 
 		if (key == "rtpmap") {
-			Description::Media::RTPMap map(value);
-			int pt = map.pt;
-			mRtpMap.emplace(pt, std::move(map));
+		    auto pt = Description::Media::RTPMap::parsePT(value);
+            auto it = mRtpMap.find(pt);
+            if (it == mRtpMap.end()) {
+                it = mRtpMap.insert(std::make_pair(pt, Description::Media::RTPMap(value))).first;
+            }else {
+                it->second.setMLine(value);
+            }
 		} else if (key == "rtcp-fb") {
 			size_t p = value.find(' ');
 			int pt = to_integer<int>(value.substr(0, p));
 			auto it = mRtpMap.find(pt);
 			if (it == mRtpMap.end()) {
-				PLOG_WARNING << "rtcp-fb applied before the corresponding rtpmap, ignoring";
-			} else {
-				it->second.rtcpFbs.emplace_back(value.substr(p + 1));
-			}
+			    it = mRtpMap.insert(std::make_pair(pt, Description::Media::RTPMap())).first;
+            }
+            it->second.rtcpFbs.emplace_back(value.substr(p + 1));
 		} else if (key == "fmtp") {
 			size_t p = value.find(' ');
 			int pt = to_integer<int>(value.substr(0, p));
 			auto it = mRtpMap.find(pt);
-			if (it == mRtpMap.end()) {
-				PLOG_WARNING << "fmtp applied before the corresponding rtpmap, ignoring";
-			} else {
-				it->second.fmtps.emplace_back(value.substr(p + 1));
-			}
+			if (it == mRtpMap.end())
+                it = mRtpMap.insert(std::make_pair(pt, Description::Media::RTPMap())).first;
+            it->second.fmtps.emplace_back(value.substr(p + 1));
 		} else if (key == "rtcp-mux") {
-			// always added
+            // always added
+        }else if (key == "ssrc") {
+		    mSsrcs.emplace_back(std::stoul((std::string)value));
 		} else {
 			Entry::parseSdpLine(line);
 		}
@@ -737,26 +810,36 @@ void Description::Media::parseSdpLine(string_view line) {
 	}
 }
 
-Description::Media::RTPMap::RTPMap(string_view mline) {
-	size_t p = mline.find(' ');
+void Description::Media::addRTPMap(const Description::Media::RTPMap& map) {
+    mRtpMap.emplace(map.pt, map);
+}
 
-	this->pt = to_integer<int>(mline.substr(0, p));
+std::vector<uint32_t> Description::Media::getSSRCs() {
+    std::vector<uint32_t> vec;
+    for (auto &val : mAttributes) {
+        PLOG_DEBUG << val;
+        if (val.find("ssrc:") == 0) {
+            vec.emplace_back(std::stoul((std::string)val.substr(5, val.find(" "))));
+        }
+    }
+    return vec;
+}
 
-	string_view line = mline.substr(p + 1);
-	size_t spl = line.find('/');
-	this->format = line.substr(0, spl);
+std::map<int, Description::Media::RTPMap>::iterator Description::Media::beginMaps() {
+    return mRtpMap.begin();
+}
 
-	line = line.substr(spl + 1);
-	spl = line.find('/');
-	if (spl == string::npos) {
-		spl = line.find(' ');
-	}
-	if (spl == string::npos)
-		this->clockRate = to_integer<int>(line);
-	else {
-		this->clockRate = to_integer<int>(line.substr(0, spl));
-		this->encParams = line.substr(spl + 1);
-	}
+std::map<int, Description::Media::RTPMap>::iterator Description::Media::endMaps() {
+    return mRtpMap.end();
+}
+
+std::map<int, Description::Media::RTPMap>::iterator
+Description::Media::removeMap(std::map<int, Description::Media::RTPMap>::iterator iterator) {
+    return mRtpMap.erase(iterator);
+}
+
+Description::Media::RTPMap::RTPMap(string_view mline) {
+    setMLine(mline);
 }
 
 void Description::Media::RTPMap::removeFB(const string &str) {
@@ -771,9 +854,41 @@ void Description::Media::RTPMap::removeFB(const string &str) {
 
 void Description::Media::RTPMap::addFB(const string &str) { rtcpFbs.emplace_back(str); }
 
+int Description::Media::RTPMap::parsePT(string_view view) {
+    size_t p = view.find(' ');
+
+    return to_integer<int>(view.substr(0, p));
+}
+
+void Description::Media::RTPMap::setMLine(string_view mline) {
+    size_t p = mline.find(' ');
+
+    this->pt = to_integer<int>(mline.substr(0, p));
+
+    string_view line = mline.substr(p + 1);
+    size_t spl = line.find('/');
+    this->format = line.substr(0, spl);
+
+    line = line.substr(spl + 1);
+    spl = line.find('/');
+    if (spl == string::npos) {
+        spl = line.find(' ');
+    }
+    if (spl == string::npos)
+        this->clockRate = to_integer<int>(line);
+    else {
+        this->clockRate = to_integer<int>(line.substr(0, spl));
+        this->encParams = line.substr(spl+1);
+    }
+}
+
 Description::Audio::Audio(string mid, Direction dir)
     : Media("audio 9 UDP/TLS/RTP/SAVPF", std::move(mid), dir) {}
 
+void Description::Audio::addOpusCodec(int payloadType) {
+    addAudioCodec(payloadType, "OPUS");
+}
+
 Description::Video::Video(string mid, Direction dir)
     : Media("video 9 UDP/TLS/RTP/SAVPF", std::move(mid), dir) {}
 

+ 99 - 42
src/dtlssrtptransport.cpp

@@ -109,15 +109,31 @@ bool DtlsSrtpTransport::sendMedia(message_ptr message) {
 		if (srtp_err_status_t err = srtp_protect_rtcp(mSrtpOut, message->data(), &size)) {
 			if (err == srtp_err_status_replay_fail)
 				throw std::runtime_error("SRTCP packet is a replay");
-			else
-				throw std::runtime_error("SRTCP protect error, status=" +
-				                         to_string(static_cast<int>(err)));
+			else if (err == srtp_err_status_no_ctx) {
+			    auto ssrc = ((RTCP_SR*) message->data())->senderSSRC();
+			    PLOG_INFO << "Adding SSRC to SRTCP: " << ssrc;
+			    addSSRC(ssrc);
+                if ((err = srtp_protect_rtcp(mSrtpOut, message->data(), &size)))
+                    throw std::runtime_error("SRTCP protect error, status=" +
+                                             to_string(static_cast<int>(err)));
+            }else {
+		throw std::runtime_error("SRTCP protect error, status=" +
+					 to_string(static_cast<int>(err)));
+	    }
 		}
 		PLOG_VERBOSE << "Protected SRTCP packet, size=" << size;
 	} else {
 		if (srtp_err_status_t err = srtp_protect(mSrtpOut, message->data(), &size)) {
 			if (err == srtp_err_status_replay_fail)
-				throw std::runtime_error("SRTP packet is a replay");
+				throw std::runtime_error("Outgoing SRTP packet is a replay");
+            else if (err == srtp_err_status_no_ctx) {
+                auto ssrc = ((RTP*) message->data())->ssrc();
+                PLOG_INFO << "Adding SSRC to RTP: " << ssrc;
+                addSSRC(ssrc);
+                if ((err = srtp_protect_rtcp(mSrtpOut, message->data(), &size)))
+                    throw std::runtime_error("SRTCP protect error, status=" +
+                                             to_string(static_cast<int>(err)));
+            }
 			else
 				throw std::runtime_error("SRTP protect error, status=" +
 				                         to_string(static_cast<int>(err)));
@@ -127,7 +143,6 @@ bool DtlsSrtpTransport::sendMedia(message_ptr message) {
 
 	message->resize(size);
 	return outgoing(message);
-//	return DtlsTransport::send(message);
 }
 
 void DtlsSrtpTransport::incoming(message_ptr message) {
@@ -174,13 +189,24 @@ void DtlsSrtpTransport::incoming(message_ptr message) {
 					PLOG_WARNING << "Incoming SRTCP packet is a replay";
 				else if (err == srtp_err_status_auth_fail)
 					PLOG_WARNING << "Incoming SRTCP packet failed authentication check";
-				else
-					PLOG_WARNING << "SRTCP unprotect error, status=" << err;
-				return;
+                else if (err == srtp_err_status_no_ctx) {
+                    auto ssrc = ((RTCP_SR*) message->data())->senderSSRC();
+                    PLOG_INFO << "Adding SSRC to RTCP: " << ssrc;
+                    addSSRC(ssrc);
+                    if ((err = srtp_unprotect_rtcp(mSrtpIn, message->data(), &size)))
+                        throw std::runtime_error("SRTCP unprotect error, status=" +
+                                                 to_string(static_cast<int>(err)));
+                }
+				else {
+                    PLOG_WARNING << "SRTCP unprotect error, status=" << err << " SSRC="
+                                 << ((RTCP_SR *) message->data())->senderSSRC();
+                }
+                return;
 			}
 			PLOG_VERBOSE << "Unprotected SRTCP packet, size=" << size;
 			message->type = Message::Type::Control;
-			message->stream = to_integer<uint8_t>(*(message->begin() + 1)); // Payload Type
+            auto rtp = (RTCP_SR*) message->data();
+			message->stream = rtp->senderSSRC();
 		} else {
 			PLOG_VERBOSE << "Incoming SRTP packet, size=" << size;
 			if (srtp_err_status_t err = srtp_unprotect(mSrtpIn, message->data(), &size)) {
@@ -188,13 +214,22 @@ void DtlsSrtpTransport::incoming(message_ptr message) {
 					PLOG_WARNING << "Incoming SRTP packet is a replay";
 				else if (err == srtp_err_status_auth_fail)
 					PLOG_WARNING << "Incoming SRTP packet failed authentication check";
+                else if (err == srtp_err_status_no_ctx) {
+                    auto ssrc = ((RTP*) message->data())->ssrc();
+                    PLOG_INFO << "Adding SSRC to RTP: " << ssrc;
+                    addSSRC(ssrc);
+                    if ((err = srtp_unprotect(mSrtpIn, message->data(), &size)))
+                        throw std::runtime_error("SRTCP unprotect error, status=" +
+                                                 to_string(static_cast<int>(err)));
+                }
 				else
-					PLOG_WARNING << "SRTP unprotect error, status=" << err;
+					PLOG_WARNING << "SRTP unprotect error, status=" << err << " SSRC=" << ((RTP*)message->data())->ssrc();
 				return;
 			}
 			PLOG_VERBOSE << "Unprotected SRTP packet, size=" << size;
 			message->type = Message::Type::Binary;
-			message->stream = value2; // Payload Type
+			auto rtp = (RTP*) message->data();
+			message->stream = rtp->ssrc();
 		}
 
 		message->resize(size);
@@ -209,6 +244,8 @@ void DtlsSrtpTransport::postHandshake() {
 	if (mInitDone)
 		return;
 
+	static_assert(SRTP_AES_ICM_128_KEY_LEN_WSALT == SRTP_AES_128_KEY_LEN + SRTP_SALT_LEN);
+
 	const size_t materialLen = SRTP_AES_ICM_128_KEY_LEN_WSALT * 2;
 	unsigned char material[materialLen];
 	const unsigned char *clientKey, *clientSalt, *serverKey, *serverSalt;
@@ -257,41 +294,61 @@ void DtlsSrtpTransport::postHandshake() {
 	serverSalt = clientSalt + SRTP_SALT_LEN;
 #endif
 
-	unsigned char clientSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
-	std::memcpy(clientSessionKey, clientKey, SRTP_AES_128_KEY_LEN);
-	std::memcpy(clientSessionKey + SRTP_AES_128_KEY_LEN, clientSalt, SRTP_SALT_LEN);
-
-	unsigned char serverSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
-	std::memcpy(serverSessionKey, serverKey, SRTP_AES_128_KEY_LEN);
-	std::memcpy(serverSessionKey + SRTP_AES_128_KEY_LEN, serverSalt, SRTP_SALT_LEN);
-
-	srtp_policy_t inbound = {};
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
-	inbound.ssrc.type = ssrc_any_inbound;
-	inbound.ssrc.value = 0;
-	inbound.key = mIsClient ? serverSessionKey : clientSessionKey;
-	inbound.next = nullptr;
-
-	if (srtp_err_status_t err = srtp_add_stream(mSrtpIn, &inbound))
-		throw std::runtime_error("SRTP add inbound stream failed, status=" +
-		                         to_string(static_cast<int>(err)));
-
-	srtp_policy_t outbound = {};
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
-	outbound.ssrc.type = ssrc_any_outbound;
-	outbound.ssrc.value = 0;
-	outbound.key = mIsClient ? clientSessionKey : serverSessionKey;
-	outbound.next = nullptr;
-
-	if (srtp_err_status_t err = srtp_add_stream(mSrtpOut, &outbound))
-		throw std::runtime_error("SRTP add outbound stream failed, status=" +
-		                         to_string(static_cast<int>(err)));
+	std::memcpy(mClientSessionKey, clientKey, SRTP_AES_128_KEY_LEN);
+	std::memcpy(mClientSessionKey + SRTP_AES_128_KEY_LEN, clientSalt, SRTP_SALT_LEN);
+
+	std::memcpy(mServerSessionKey, serverKey, SRTP_AES_128_KEY_LEN);
+	std::memcpy(mServerSessionKey + SRTP_AES_128_KEY_LEN, serverSalt, SRTP_SALT_LEN);
+
+	// Add SSRC=1 as an inbound because that is what Chrome does.
+    srtp_policy_t inbound = {};
+    srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
+    srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
+    inbound.ssrc.type = ssrc_specific;
+    inbound.ssrc.value = 1;
+    inbound.key = mIsClient ? mServerSessionKey : mClientSessionKey;
+    inbound.next = nullptr;
+
+    if (srtp_err_status_t err = srtp_add_stream(mSrtpIn, &inbound)) {
+        throw std::runtime_error("SRTP add inbound stream failed, status=" +
+                                 to_string(static_cast<int>(err)));
+    }
 
 	mInitDone = true;
 }
 
+void DtlsSrtpTransport::addSSRC(uint32_t ssrc) {
+	if (!mInitDone)
+		throw std::logic_error("Attempted to add SSRC before SRTP keying material is derived");
+
+    srtp_policy_t inbound = {};
+    srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
+    srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
+    inbound.ssrc.type = ssrc_specific;
+    inbound.ssrc.value = ssrc;
+    inbound.key = mIsClient ? mServerSessionKey : mClientSessionKey;
+    inbound.next = nullptr;
+    inbound.allow_repeat_tx = true;
+
+    if (srtp_err_status_t err = srtp_add_stream(mSrtpIn, &inbound))
+        throw std::runtime_error("SRTP add inbound stream failed, status=" +
+                                 to_string(static_cast<int>(err)));
+
+    srtp_policy_t outbound = {};
+    srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
+    srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
+    outbound.ssrc.type = ssrc_specific;
+    outbound.ssrc.value = ssrc;
+    outbound.key = mIsClient ? mClientSessionKey : mServerSessionKey;
+    outbound.next = nullptr;
+    outbound.allow_repeat_tx = true;
+
+    if (srtp_err_status_t err = srtp_add_stream(mSrtpOut, &outbound))
+        throw std::runtime_error("SRTP add outbound stream failed, status=" +
+                                 to_string(static_cast<int>(err)));
+}
+
+
 } // namespace rtc
 
 #endif

+ 10 - 1
src/dtlssrtptransport.hpp

@@ -24,7 +24,12 @@
 
 #if RTC_ENABLE_MEDIA
 
+#ifdef RTC_SRTP_FROM_SOURCE
+#include "srtp.h"
+#else
 #include <srtp2/srtp.h>
+#endif
+#include <atomic>
 
 namespace rtc {
 
@@ -39,6 +44,7 @@ public:
 	~DtlsSrtpTransport();
 
 	bool sendMedia(message_ptr message);
+	void addSSRC(uint32_t ssrc);
 
 private:
 	void incoming(message_ptr message) override;
@@ -47,7 +53,10 @@ private:
 	message_callback mSrtpRecvCallback;
 
 	srtp_t mSrtpIn, mSrtpOut;
-	bool mInitDone = false;
+
+	std::atomic<bool> mInitDone = false;
+	unsigned char mClientSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
+	unsigned char mServerSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
 };
 
 } // namespace rtc

+ 2 - 2
src/dtlstransport.cpp

@@ -177,8 +177,8 @@ void DtlsTransport::runRecvLoop() {
 	// Receive loop
 	try {
 		PLOG_INFO << "DTLS handshake finished";
+        postHandshake();
 		changeState(State::Connected);
-		postHandshake();
 
 		const size_t bufferSize = maxMtu;
 		char buffer[bufferSize];
@@ -453,8 +453,8 @@ void DtlsTransport::runRecvLoop() {
 						SSL_set_mtu(mSsl, maxMtu + 1);
 
 						PLOG_INFO << "DTLS handshake finished";
-						changeState(State::Connected);
 						postHandshake();
+						changeState(State::Connected);
 					}
 				} else {
 					ret = SSL_read(mSsl, buffer, bufferSize);

+ 3 - 3
src/icetransport.cpp

@@ -141,12 +141,12 @@ Description IceTransport::getLocalDescription(Description::Type type) const {
 	if (juice_get_local_description(mAgent.get(), sdp, JUICE_MAX_SDP_STRING_LEN) < 0)
 		throw std::runtime_error("Failed to generate local SDP");
 
-	return Description(string(sdp), type, mRole);
+    return Description(string(sdp), type, type == Description::Type::Offer ? Description::Role::ActPass : mRole);
 }
 
 void IceTransport::setRemoteDescription(const Description &description) {
-	mRole = description.role() == Description::Role::Active ? Description::Role::Passive
-	                                                        : Description::Role::Active;
+    mRole = description.role() == Description::Role::Active ? Description::Role::Passive  : Description::Role::Active;
+
 	mMid = description.bundleMid();
 	if (juice_set_remote_description(mAgent.get(),
 	                                 description.generateApplicationSdp("\r\n").c_str()) < 0)

+ 225 - 68
src/peerconnection.cpp

@@ -119,6 +119,7 @@ bool PeerConnection::hasMedia() const {
 void PeerConnection::setLocalDescription(Description::Type type) {
 	PLOG_VERBOSE << "Setting local description, type=" << Description::typeToString(type);
 
+
 	SignalingState signalingState = mSignalingState.load();
 	if (type == Description::Type::Rollback) {
 		if (signalingState == SignalingState::HaveLocalOffer ||
@@ -393,6 +394,7 @@ std::shared_ptr<Track> PeerConnection::addTrack(Description::Media description)
 	if (!track) {
 		track = std::make_shared<Track>(std::move(description));
 		mTracks.emplace(std::make_pair(track->mid(), track));
+		mTrackLines.emplace_back(track);
 	}
 
 	// Renegotiation is needed for the new or updated track
@@ -677,53 +679,205 @@ void PeerConnection::forwardMedia(message_ptr message) {
 	if (!message)
 		return;
 
-	if (message->type == Message::Type::Control) {
-		std::shared_lock lock(mTracksMutex); // read-only
-		for (auto it = mTracks.begin(); it != mTracks.end(); ++it)
-			if (auto track = it->second.lock())
-				return track->incoming(message);
 
-		PLOG_WARNING << "No track available to receive control, dropping";
+	// Browsers like to compound their packets with a random SSRC.
+	// we have to do this monstrosity to distribute the report blocks
+    std::optional<unsigned int> mediaLine;
+    if (message->type == Message::Control) {
+        unsigned int offset = 0;
+        std::vector<SSRC> ssrcsFound;
+        bool hasFound = false;
+
+        while ((sizeof(rtc::RTCP_HEADER) + offset) <= message->size()) {
+            auto header = (rtc::RTCP_HEADER *) (message->data() + offset);
+            if (header->lengthInBytes() > message->size() - offset) {
+                PLOG_WARNING << "Packet was truncated";
+                break;
+            }
+            offset += header->lengthInBytes();
+            if (header->payloadType() == 205 || header->payloadType() == 206) {
+                auto rtcpfb = (RTCP_FB_HEADER *) header;
+                auto ssrc = rtcpfb->getPacketSenderSSRC();
+                if (std::find(ssrcsFound.begin(), ssrcsFound.end(), ssrc) == ssrcsFound.end()) {
+                    mediaLine = getMLineFromSSRC(ssrc);
+                    if (mediaLine.has_value()) {
+                        hasFound = true;
+                        std::shared_lock lock(mTracksMutex); // read-only
+                        if (auto track = mTrackLines[*mediaLine].lock()) {
+                            track->incoming(message);
+                        }
+                        ssrcsFound.emplace_back(ssrc);
+                    }
+                }
+
+                ssrc = rtcpfb->getMediaSourceSSRC();
+                if (std::find(ssrcsFound.begin(), ssrcsFound.end(), ssrc) == ssrcsFound.end()) {
+                    mediaLine = getMLineFromSSRC(ssrc);
+                    if (mediaLine.has_value()) {
+                        hasFound = true;
+                        std::shared_lock lock(mTracksMutex); // read-only
+                        if (auto track = mTrackLines[*mediaLine].lock()) {
+                            track->incoming(message);
+                        }
+                        ssrcsFound.emplace_back(ssrc);
+                    }
+                }
+            }else if (header->payloadType() == 200 || header->payloadType() == 201) {
+                auto rtcpsr = (RTCP_SR *) header;
+                auto ssrc = rtcpsr->senderSSRC();
+                if (std::find(ssrcsFound.begin(), ssrcsFound.end(), ssrc) == ssrcsFound.end()) {
+                    mediaLine = getMLineFromSSRC(ssrc);
+                    if (mediaLine.has_value()) {
+                        hasFound = true;
+                        std::shared_lock lock(mTracksMutex); // read-only
+                        if (auto track = mTrackLines[*mediaLine].lock()) {
+                            track->incoming(message);
+                        }
+                        ssrcsFound.emplace_back(ssrc);
+                    }
+                }
+                for (int i = 0; i < rtcpsr->header.reportCount(); i++) {
+                    auto block = rtcpsr->getReportBlock(i);
+                    ssrc = block->getSSRC();
+                    if (std::find(ssrcsFound.begin(), ssrcsFound.end(), ssrc) == ssrcsFound.end()) {
+                        mediaLine = getMLineFromSSRC(ssrc);
+                        if (mediaLine.has_value()) {
+                            hasFound = true;
+                            std::shared_lock lock(mTracksMutex); // read-only
+                            if (auto track = mTrackLines[*mediaLine].lock()) {
+                                track->incoming(message);
+                            }
+                            ssrcsFound.emplace_back(ssrc);
+                        }
+                    }
+                }
+            } else {
+                //PT=202 == SDES
+                //PT=207 == Extended Report
+                if (header->payloadType() != 202 && header->payloadType() != 207) {
+                    PLOG_WARNING << "Unknown packet type: " << (int) header->version() << " " << header->payloadType() << "";
+                }
+            }
+        }
+
+        if (hasFound)
+            return;
+    }
+
+    unsigned int ssrc = message->stream;
+    mediaLine = getMLineFromSSRC(ssrc);
+
+	if (!mediaLine) {
+	    /* TODO
+	     *   So the problem is that when stop sending streams, we stop getting report blocks for those streams
+	     *   Therefore when we get compound RTCP packets, they are empty, and we can't forward them.
+	     *   Therefore, it is expected that we don't know where to forward packets.
+	     *   Is this ideal? No! Do I know how to fix it? No!
+	     */
+	//	PLOG_WARNING << "Track not found for SSRC " << ssrc << ", dropping";
 		return;
 	}
 
-	unsigned int payloadType = message->stream;
-	std::optional<string> mid;
-	if (auto it = mMidFromPayloadType.find(payloadType); it != mMidFromPayloadType.end()) {
-		mid = it->second;
-	} else {
-		std::lock_guard lock(mLocalDescriptionMutex);
-		if (!mLocalDescription)
-			return;
+	std::shared_lock lock(mTracksMutex); // read-only
+    if (auto track = mTrackLines[*mediaLine].lock()) {
+        track->incoming(message);
+    }
+}
 
-		for (int i = 0; i < mLocalDescription->mediaCount(); ++i) {
-			if (auto found = std::visit(
-			        rtc::overloaded{[&](Description::Application *) -> std::optional<string> {
-				                        return std::nullopt;
-			                        },
-			                        [&](Description::Media *media) -> std::optional<string> {
-				                        return media->hasPayloadType(payloadType)
-				                                   ? std::make_optional(media->mid())
-				                                   : nullopt;
-			                        }},
-			        mLocalDescription->media(i))) {
-
-				mMidFromPayloadType.emplace(payloadType, *found);
-				mid = *found;
-				break;
-			}
-		}
-	}
+std::optional<unsigned int> PeerConnection::getMLineFromSSRC(SSRC ssrc) {
+    if (auto it = mMLineFromSssrc.find(ssrc); it != mMLineFromSssrc.end()) {
+        return it->second;
+    }else {
+        {
+            std::lock_guard lock(mRemoteDescriptionMutex);
+            if (!mRemoteDescription)
+                return nullopt;
+            for (unsigned int i = 0; i < mRemoteDescription->mediaCount(); ++i) {
+                if (std::visit(
+                        rtc::overloaded{[&](Description::Application *) -> bool {
+                            return false;
+                        },
+                                        [&](Description::Media *media) -> bool {
+                                            return media->hasSSRC(ssrc);
+                                        }},
+                        mRemoteDescription->media(i))) {
+
+                    mMLineFromSssrc.emplace(ssrc, i);
+                    return i;
+                }
+            }
+        }
+        {
+            std::lock_guard lock(mLocalDescriptionMutex);
+            if (!mLocalDescription)
+                return nullopt;
+            for (unsigned int i = 0; i < mLocalDescription->mediaCount(); ++i) {
+                if (std::visit(
+                        rtc::overloaded{[&](Description::Application *) -> bool {
+                            return false;
+                        },
+                                        [&](Description::Media *media) -> bool {
+                                            return media->hasSSRC(ssrc);
+                                        }},
+                        mLocalDescription->media(i))) {
+
+                    mMLineFromSssrc.emplace(ssrc, i);
+                    return i;
+                }
+            }
+        }
+    }
+    return std::nullopt;
+}
 
-	if (!mid) {
-		PLOG_WARNING << "Track not found for payload type " << payloadType << ", dropping";
-		return;
+std::optional<std::string> PeerConnection::getMidFromSSRC(SSRC ssrc) {
+    if (auto it = mMidFromSssrc.find(ssrc); it != mMidFromSssrc.end()) {
+        return it->second;
+    } else {
+        {
+            std::lock_guard lock(mRemoteDescriptionMutex);
+            if (!mRemoteDescription)
+                return nullopt;
+            for (unsigned int i = 0; i < mRemoteDescription->mediaCount(); ++i) {
+                if (auto found = std::visit(
+                        rtc::overloaded{[&](Description::Application *) -> std::optional<string> {
+                            return std::nullopt;
+                        },
+                                        [&](Description::Media *media) -> std::optional<string> {
+                                            return media->hasSSRC(ssrc)
+                                                   ? std::make_optional(media->mid())
+                                                   : nullopt;
+                                        }},
+                        mRemoteDescription->media(i))) {
+
+                    mMidFromSssrc.emplace(ssrc, *found);
+                    return *found;
+                }
+            }
+        }
+        {
+            std::lock_guard lock(mLocalDescriptionMutex);
+            if (!mLocalDescription)
+                return nullopt;
+            for (unsigned int i = 0; i < mLocalDescription->mediaCount(); ++i) {
+                if (auto found = std::visit(
+                        rtc::overloaded{[&](Description::Application *) -> std::optional<string> {
+                            return std::nullopt;
+                        },
+                                        [&](Description::Media *media) -> std::optional<string> {
+                                            return media->hasSSRC(ssrc)
+                                                   ? std::make_optional(media->mid())
+                                                   : nullopt;
+                                        }},
+                        mLocalDescription->media(i))) {
+
+                    mMidFromSssrc.emplace(ssrc, *found);
+                    return *found;
+                }
+            }
+        }
 	}
-
-	std::shared_lock lock(mTracksMutex); // read-only
-	if (auto it = mTracks.find(*mid); it != mTracks.end())
-		if (auto track = it->second.lock())
-			track->incoming(message);
+    return nullopt;
 }
 
 void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
@@ -825,7 +979,8 @@ void PeerConnection::incomingTrack(Description::Media description) {
 	if (mTracks.find(description.mid()) == mTracks.end()) {
 		auto track = std::make_shared<Track>(std::move(description));
 		mTracks.emplace(std::make_pair(track->mid(), track));
-		triggerTrack(std::move(track));
+        mTrackLines.emplace_back(track);
+		triggerTrack(track);
 	}
 }
 
@@ -856,7 +1011,7 @@ void PeerConnection::validateRemoteDescription(const Description &description) {
 		throw std::invalid_argument("Remote description has no media line");
 
 	int activeMediaCount = 0;
-	for (int i = 0; i < description.mediaCount(); ++i)
+	for (size_t i = 0; i < description.mediaCount(); ++i)
 		std::visit(rtc::overloaded{[&](const Description::Application *) { ++activeMediaCount; },
 		                           [&](const Description::Media *media) {
 			                           if (media->direction() != Description::Direction::Inactive)
@@ -876,9 +1031,10 @@ void PeerConnection::validateRemoteDescription(const Description &description) {
 }
 
 void PeerConnection::processLocalDescription(Description description) {
+
 	if (auto remote = remoteDescription()) {
 		// Reciprocate remote description
-		for (int i = 0; i < remote->mediaCount(); ++i)
+		for (size_t i = 0; i < remote->mediaCount(); ++i)
 			std::visit( // reciprocate each media
 			    rtc::overloaded{
 			        [&](Description::Application *remoteApp) {
@@ -936,21 +1092,21 @@ void PeerConnection::processLocalDescription(Description description) {
 
 				        auto reciprocated = remoteMedia->reciprocate();
 #if !RTC_ENABLE_MEDIA
-				        // No media support, mark as inactive
-				        reciprocated.setDirection(Description::Direction::Inactive);
+                                    // No media support, mark as inactive
+                                    reciprocated.setDirection(Description::Direction::Inactive);
 #endif
-				        incomingTrack(reciprocated);
+                                    incomingTrack(reciprocated);
 
-				        PLOG_DEBUG
-				            << "Reciprocating media in local description, mid=\""
-				            << reciprocated.mid() << "\", active=" << std::boolalpha
-				            << (reciprocated.direction() != Description::Direction::Inactive);
+                                    PLOG_DEBUG
+                                                << "Reciprocating media in local description, mid=\""
+                                                << reciprocated.mid() << "\", active=" << std::boolalpha
+                                                << (reciprocated.direction() != Description::Direction::Inactive);
 
-				        description.addMedia(std::move(reciprocated));
-			        },
-			    },
-			    remote->media(i));
-	}
+                                    description.addMedia(std::move(reciprocated));
+                                },
+                        },
+                        remote->media(i));
+        }
 
 	if (description.type() == Description::Type::Offer) {
 		// This is an offer, add locally created data channels and tracks
@@ -970,24 +1126,25 @@ void PeerConnection::processLocalDescription(Description description) {
 		}
 
 		// Add media for local tracks
+
 		std::shared_lock lock(mTracksMutex);
-		for (auto it = mTracks.begin(); it != mTracks.end(); ++it) {
-			if (description.hasMid(it->first))
-				continue;
+        for (auto it = mTrackLines.begin(); it != mTrackLines.end(); ++it) {
+            if (auto track = it->lock()) {
+                if (description.hasMid(track->mid()))
+                    continue;
 
-			if (auto track = it->second.lock()) {
-				auto media = track->description();
+                auto media = track->description();
 #if !RTC_ENABLE_MEDIA
-				// No media support, mark as inactive
-				media.setDirection(Description::Direction::Inactive);
+                // No media support, mark as inactive
+                media.setDirection(Description::Direction::Inactive);
 #endif
-				PLOG_DEBUG << "Adding media to local description, mid=\"" << media.mid()
-				           << "\", active=" << std::boolalpha
-				           << (media.direction() != Description::Direction::Inactive);
+                PLOG_DEBUG << "Adding media to local description, mid=\"" << media.mid()
+                           << "\", active=" << std::boolalpha
+                           << (media.direction() != Description::Direction::Inactive);
 
-				description.addMedia(std::move(media));
-			}
-		}
+                description.addMedia(std::move(media));
+            }
+        }
 	}
 
 	// Set local fingerprint (wait for certificate if necessary)

+ 37 - 322
src/rtcp.cpp

@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2020 Staz M
+ * Copyright (c) 2020 Staz Modrzynski
  * Copyright (c) 2020 Paul-Louis Ageneau
  *
  * This library is free software; you can redistribute it and/or
@@ -21,6 +21,7 @@
 
 #include <cmath>
 #include <utility>
+#include "track.hpp"
 
 #ifdef _WIN32
 #include <winsock2.h>
@@ -28,310 +29,14 @@
 #include <arpa/inet.h>
 #endif
 
-#ifndef htonll
-#define htonll(x)                                                                                  \
-	((uint64_t)htonl(((uint64_t)(x)&0xFFFFFFFF) << 32) | (uint64_t)htonl((uint64_t)(x) >> 32))
-#endif
-#ifndef ntohll
-#define ntohll(x) htonll(x)
-#endif
 
 namespace rtc {
 
-#pragma pack(push, 1)
-
-struct RTP {
-private:
-	uint8_t _first;
-	uint8_t _payloadType;
-	uint16_t _seqNumber;
-	uint32_t _timestamp;
-
-public:
-	SSRC ssrc;
-	SSRC csrc[16];
-
-	inline uint8_t version() const { return _first >> 6; }
-	inline bool padding() const { return (_first >> 5) & 0x01; }
-	inline uint8_t csrcCount() const { return _first & 0x0F; }
-	inline uint8_t payloadType() const { return _payloadType; }
-	inline uint16_t seqNumber() const { return ntohs(_seqNumber); }
-	inline uint32_t timestamp() const { return ntohl(_timestamp); }
-};
-
-struct RTCP_ReportBlock {
-	SSRC ssrc;
-
-private:
-	uint32_t _fractionLostAndPacketsLost; // fraction lost is 8-bit, packets lost is 24-bit
-	uint16_t _seqNoCycles;
-	uint16_t _highestSeqNo;
-	uint32_t _jitter;
-	uint32_t _lastReport;
-	uint32_t _delaySinceLastReport;
-
-public:
-	inline void preparePacket(SSRC ssrc_, [[maybe_unused]] unsigned int packetsLost,
-	                          [[maybe_unused]] unsigned int totalPackets, uint16_t highestSeqNo,
-	                          uint16_t seqNoCycles, uint32_t jitter, uint64_t lastSR_NTP,
-	                          uint64_t lastSR_DELAY) {
-		setSeqNo(highestSeqNo, seqNoCycles);
-		setJitter(jitter);
-		setSSRC(ssrc_);
-
-		// Middle 32 bits of NTP Timestamp
-		// _lastReport = lastSR_NTP >> 16u;
-		setNTPOfSR(uint32_t(lastSR_NTP));
-		setDelaySinceSR(uint32_t(lastSR_DELAY));
-
-		// The delay, expressed in units of 1/65536 seconds
-		// _delaySinceLastReport = lastSR_DELAY;
-	}
-
-	inline void setSSRC(SSRC ssrc_) { ssrc = htonl(ssrc_); }
-	inline SSRC getSSRC() const { return ntohl(ssrc); }
-
-	inline void setPacketsLost([[maybe_unused]] unsigned int packetsLost,
-	                           [[maybe_unused]] unsigned int totalPackets) {
-		// TODO Implement loss percentages.
-		_fractionLostAndPacketsLost = 0;
-	}
-	inline unsigned int getLossPercentage() const {
-		// TODO Implement loss percentages.
-		return 0;
-	}
-	inline unsigned int getPacketLostCount() const {
-		// TODO Implement total packets lost.
-		return 0;
-	}
-
-	inline uint16_t seqNoCycles() const { return ntohs(_seqNoCycles); }
-	inline uint16_t highestSeqNo() const { return ntohs(_highestSeqNo); }
-	inline uint32_t jitter() const { return ntohl(_jitter); }
-
-	inline void setSeqNo(uint16_t highestSeqNo, uint16_t seqNoCycles) {
-		_highestSeqNo = htons(highestSeqNo);
-		_seqNoCycles = htons(seqNoCycles);
-	}
-
-	inline void setJitter(uint32_t jitter) { _jitter = htonl(jitter); }
-
-	inline void setNTPOfSR(uint32_t ntp) { _lastReport = htonl(ntp >> 16u); }
-	inline uint32_t getNTPOfSR() const { return ntohl(_lastReport) << 16u; }
-
-	inline void setDelaySinceSR(uint32_t sr) {
-		// The delay, expressed in units of 1/65536 seconds
-		_delaySinceLastReport = htonl(sr);
-	}
-	inline uint32_t getDelaySinceSR() const { return ntohl(_delaySinceLastReport); }
-
-	inline void log() const {
-		PLOG_DEBUG << "RTCP report block: "
-		           << "ssrc="
-		           << ntohl(ssrc)
-		           // TODO: Implement these reports
-		           //	<< ", fractionLost=" << fractionLost
-		           //	<< ", packetsLost=" << packetsLost
-		           << ", highestSeqNo=" << highestSeqNo() << ", seqNoCycles=" << seqNoCycles()
-		           << ", jitter=" << jitter() << ", lastSR=" << getNTPOfSR()
-		           << ", lastSRDelay=" << getDelaySinceSR();
-	}
-};
-
-struct RTCP_HEADER {
-private:
-	uint8_t _first;
-	uint8_t _payloadType;
-	uint16_t _length;
-
-public:
-	inline uint8_t version() const { return _first >> 6; }
-	inline bool padding() const { return (_first >> 5) & 0x01; }
-	inline uint8_t reportCount() const { return _first & 0x0F; }
-	inline uint8_t payloadType() const { return _payloadType; }
-	inline uint16_t length() const { return ntohs(_length); }
-
-	inline void setPayloadType(uint8_t type) { _payloadType = type; }
-	inline void setReportCount(uint8_t count) { _first = (_first & 0xF0) | (count & 0x0F); }
-	inline void setLength(uint16_t length) { _length = htons(length); }
-
-	inline void prepareHeader(uint8_t payloadType, uint8_t reportCount, uint16_t length) {
-		_first = 0x02 << 6; // version 2, no padding
-		setReportCount(reportCount);
-		setPayloadType(payloadType);
-		setLength(length);
-	}
-
-	inline void log() const {
-		PLOG_DEBUG << "RTCP header: "
-		           << "version=" << unsigned(version()) << ", padding=" << padding()
-		           << ", reportCount=" << unsigned(reportCount())
-		           << ", payloadType=" << unsigned(payloadType()) << ", length=" << length();
-	}
-};
-
-struct RTCP_SR {
-	RTCP_HEADER header;
-	SSRC senderSsrc;
-
-private:
-	uint64_t _ntpTimestamp;
-	uint32_t _rtpTimestamp;
-	uint32_t _packetCount;
-	uint32_t _octetCount;
-
-	RTCP_ReportBlock _reportBlocks;
-
-public:
-	inline void preparePacket(SSRC senderSsrc_, uint8_t reportCount) {
-		unsigned int length =
-		    ((sizeof(header) + 24 + reportCount * sizeof(RTCP_ReportBlock)) / 4) - 1;
-		header.prepareHeader(200, reportCount, uint16_t(length));
-		senderSsrc = htonl(senderSsrc_);
-	}
-
-	inline RTCP_ReportBlock *getReportBlock(int num) { return &_reportBlocks + num; }
-	inline const RTCP_ReportBlock *getReportBlock(int num) const { return &_reportBlocks + num; }
-
-	[[nodiscard]] inline size_t getSize() const {
-		// "length" in packet is one less than the number of 32 bit words in the packet.
-		return sizeof(uint32_t) * (1 + size_t(header.length()));
-	}
-
-	inline uint32_t ntpTimestamp() const { return ntohll(_ntpTimestamp); }
-	inline uint32_t rtpTimestamp() const { return ntohl(_rtpTimestamp); }
-	inline uint32_t packetCount() const { return ntohl(_packetCount); }
-	inline uint32_t octetCount() const { return ntohl(_octetCount); }
-
-	inline void setNtpTimestamp(uint32_t ts) { _ntpTimestamp = htonll(ts); }
-	inline void setRtpTimestamp(uint32_t ts) { _rtpTimestamp = htonl(ts); }
-
-	inline void log() const {
-		header.log();
-		PLOG_DEBUG << "RTCP SR: "
-		           << " SSRC=" << ntohl(senderSsrc) << ", NTP_TS=" << ntpTimestamp()
-		           << ", RTP_TS=" << rtpTimestamp() << ", packetCount=" << packetCount()
-		           << ", octetCount=" << octetCount();
-
-		for (unsigned i = 0; i < unsigned(header.reportCount()); i++) {
-			getReportBlock(i)->log();
-		}
-	}
-};
-
-struct RTCP_RR {
-	RTCP_HEADER header;
-	SSRC senderSsrc;
-
-private:
-	RTCP_ReportBlock _reportBlocks;
-
-public:
-	inline RTCP_ReportBlock *getReportBlock(int num) { return &_reportBlocks + num; }
-	inline const RTCP_ReportBlock *getReportBlock(int num) const { return &_reportBlocks + num; }
-
-	inline SSRC getSenderSSRC() const { return ntohl(senderSsrc); }
-	inline void setSenderSSRC(SSRC ssrc) { senderSsrc = htonl(ssrc); }
-
-	[[nodiscard]] inline size_t getSize() const {
-		// "length" in packet is one less than the number of 32 bit words in the packet.
-		return sizeof(uint32_t) * (1 + size_t(header.length()));
-	}
-
-	inline void preparePacket(SSRC ssrc, uint8_t reportCount) {
-		// "length" in packet is one less than the number of 32 bit words in the packet.
-		size_t length = (sizeWithReportBlocks(reportCount) / 4) - 1;
-		header.prepareHeader(201, reportCount, uint16_t(length));
-		senderSsrc = htonl(ssrc);
-	}
-
-	inline static size_t sizeWithReportBlocks(uint8_t reportCount) {
-		return sizeof(header) + 4 + size_t(reportCount) * sizeof(RTCP_ReportBlock);
-	}
-
-	inline void log() const {
-		header.log();
-		PLOG_DEBUG << "RTCP RR: "
-		           << " SSRC=" << ntohl(senderSsrc);
-
-		for (unsigned i = 0; i < unsigned(header.reportCount()); i++) {
-			getReportBlock(i)->log();
-		}
-	}
-};
-
-struct RTCP_REMB {
-	RTCP_HEADER header;
-	SSRC senderSsrc;
-	SSRC mediaSourceSSRC;
-
-	// Unique identifier
-	const char id[4] = {'R', 'E', 'M', 'B'};
-
-	// Num SSRC, Br Exp, Br Mantissa (bit mask)
-	uint32_t bitrate;
-
-	SSRC ssrc[1];
-
-	[[nodiscard]] inline size_t getSize() const {
-		// "length" in packet is one less than the number of 32 bit words in the packet.
-		return sizeof(uint32_t) * (1 + size_t(header.length()));
-	}
-
-	inline void preparePacket(SSRC senderSsrc_, unsigned int numSSRC, unsigned int br) {
-		// Report Count becomes the format here.
-		header.prepareHeader(206, 15, 0);
-
-		// Always zero.
-		mediaSourceSSRC = 0;
-
-		senderSsrc = htonl(senderSsrc_);
-		setBitrate(numSSRC, br);
-	}
-
-	inline void setBitrate(unsigned int numSSRC, unsigned int br) {
-		unsigned int exp = 0;
-		while (br > pow(2, 18) - 1) {
-			exp++;
-			br /= 2;
-		}
-
-		// "length" in packet is one less than the number of 32 bit words in the packet.
-		header.setLength(uint16_t(((sizeof(header) + 4 * 2 + 4 + 4) / 4) - 1 + numSSRC));
-
-		bitrate = htonl((numSSRC << (32u - 8u)) | (exp << (32u - 8u - 6u)) | br);
-	}
-
-	// TODO Make this work
-	//	  uint64_t getBitrate() const{
-	//		  uint32_t ntohed = ntohl(bitrate);
-	//		  uint64_t bitrate = ntohed & (unsigned int)(pow(2, 18)-1);
-	//		  unsigned int exp = ntohed & ((unsigned int)( (pow(2, 6)-1)) << (32u-8u-6u));
-	//		  return bitrate * pow(2,exp);
-	//	  }
-	//
-	//	  uint8_t getNumSSRCS() const {
-	//		  return ntohl(bitrate) & (((unsigned int) pow(2,8)-1) << (32u-8u));
-	//	  }
-
-	inline void setSSRC(uint8_t iterator, SSRC ssrc_) { ssrc[iterator] = htonl(ssrc_); }
-
-	inline void log() const {
-		header.log();
-		PLOG_DEBUG << "RTCP REMB: "
-		           << " SSRC=" << ntohl(senderSsrc);
-	}
-
-	static unsigned int sizeWithSSRCs(int numSSRC) {
-		return (sizeof(header) + 4 * 2 + 4 + 4) + sizeof(SSRC) * numSSRC;
-	}
-};
-
-#pragma pack(pop)
-
-void RtcpSession::onOutgoing(std::function<void(rtc::message_ptr)> cb) { mTxCallback = cb; }
+rtc::message_ptr RtcpReceivingSession::outgoing(rtc::message_ptr ptr) {
+    return ptr;
+}
 
-std::optional<rtc::message_ptr> RtcpSession::incoming(rtc::message_ptr ptr) {
+rtc::message_ptr RtcpReceivingSession::incoming(rtc::message_ptr ptr) {
 	if (ptr->type == rtc::Message::Type::Binary) {
 		auto rtp = reinterpret_cast<const RTP *>(ptr->data());
 
@@ -339,12 +44,12 @@ std::optional<rtc::message_ptr> RtcpSession::incoming(rtc::message_ptr ptr) {
 		if (rtp->version() != 2) {
 			PLOG_WARNING << "RTP packet is not version 2";
 
-			return std::nullopt;
+			return nullptr;
 		}
 		if (rtp->payloadType() == 201 || rtp->payloadType() == 200) {
 			PLOG_WARNING << "RTP packet has a payload type indicating RR/SR";
 
-			return std::nullopt;
+			return nullptr;
 		}
 
 		// TODO Implement the padding bit
@@ -352,13 +57,7 @@ std::optional<rtc::message_ptr> RtcpSession::incoming(rtc::message_ptr ptr) {
 			PLOG_WARNING << "Padding processing not implemented";
 		}
 
-		mSsrc = ntohl(rtp->ssrc);
-
-		uint32_t seqNo = rtp->seqNumber();
-		// uint32_t rtpTS = rtp->getTS();
-
-		if (mGreatestSeqNo < seqNo)
-			mGreatestSeqNo = seqNo;
+		mSsrc = rtp->ssrc();
 
 		return ptr;
 	}
@@ -367,11 +66,11 @@ std::optional<rtc::message_ptr> RtcpSession::incoming(rtc::message_ptr ptr) {
 	auto rr = reinterpret_cast<const RTCP_RR *>(ptr->data());
 	if (rr->header.payloadType() == 201) {
 		// RR
-		mSsrc = rr->getSenderSSRC();
+		mSsrc = rr->senderSSRC();
 		rr->log();
 	} else if (rr->header.payloadType() == 200) {
 		// SR
-		mSsrc = rr->getSenderSSRC();
+		mSsrc = rr->senderSSRC();
 		auto sr = reinterpret_cast<const RTCP_SR *>(ptr->data());
 		mSyncRTPTS = sr->rtpTimestamp();
 		mSyncNTPTS = sr->ntpTimestamp();
@@ -382,28 +81,27 @@ std::optional<rtc::message_ptr> RtcpSession::incoming(rtc::message_ptr ptr) {
 		if (mRequestedBitrate > 0)
 			pushREMB(mRequestedBitrate);
 	}
-	return std::nullopt;
+	return nullptr;
 }
 
-void RtcpSession::requestBitrate(unsigned int newBitrate) {
+void RtcpReceivingSession::requestBitrate(unsigned int newBitrate) {
 	mRequestedBitrate = newBitrate;
 
 	PLOG_DEBUG << "[GOOG-REMB] Requesting bitrate: " << newBitrate << std::endl;
 	pushREMB(newBitrate);
 }
 
-void RtcpSession::pushREMB(unsigned int bitrate) {
+void RtcpReceivingSession::pushREMB(unsigned int bitrate) {
 	rtc::message_ptr msg =
 	    rtc::make_message(RTCP_REMB::sizeWithSSRCs(1), rtc::Message::Type::Control);
 	auto remb = reinterpret_cast<RTCP_REMB *>(msg->data());
 	remb->preparePacket(mSsrc, 1, bitrate);
-	remb->setSSRC(0, mSsrc);
-	remb->log();
+	remb->setSsrc(0, mSsrc);
 
-	tx(msg);
+    send(msg);
 }
 
-void RtcpSession::pushRR(unsigned int lastSR_delay) {
+void RtcpReceivingSession::pushRR(unsigned int lastSR_delay) {
 	auto msg = rtc::make_message(RTCP_RR::sizeWithReportBlocks(1), rtc::Message::Type::Control);
 	auto rr = reinterpret_cast<RTCP_RR *>(msg->data());
 	rr->preparePacket(mSsrc, 1);
@@ -411,16 +109,33 @@ void RtcpSession::pushRR(unsigned int lastSR_delay) {
 	                                     lastSR_delay);
 	rr->log();
 
-	tx(msg);
+    send(msg);
 }
 
-void RtcpSession::tx(message_ptr msg) {
+bool RtcpReceivingSession::send(message_ptr msg) {
 	try {
-		mTxCallback(msg);
+	    outgoingCallback(std::move(msg));
+	    return true;
 	} catch (const std::exception &e) {
 		LOG_DEBUG << "RTCP tx failed: " << e.what();
 	}
+	return false;
 }
 
+bool RtcpReceivingSession::requestKeyframe() {
+    pushPLI();
+    return true; // TODO Make this false when it is impossible (i.e. Opus).
+}
+
+void RtcpReceivingSession::pushPLI() {
+    auto msg = rtc::make_message(rtc::RTCP_PLI::size(), rtc::Message::Type::Control);
+    auto *pli = (rtc::RTCP_PLI *) msg->data();
+    pli->preparePacket(mSsrc);
+    send(msg);
+}
+
+void RtcpHandler::onOutgoing(const std::function<void(rtc::message_ptr)>& cb) {
+    this->outgoingCallback = synchronized_callback<rtc::message_ptr>(cb);
+}
 } // namespace rtc
 

+ 50 - 29
src/track.cpp

@@ -32,6 +32,7 @@ string Track::mid() const { return mMediaDescription.mid(); }
 
 Description::Media Track::description() const { return mMediaDescription; }
 
+
 void Track::setDescription(Description::Media description) {
 	if(description.mid() != mMediaDescription.mid())
 		throw std::logic_error("Media description mid does not match track mid");
@@ -91,25 +92,35 @@ void Track::open(shared_ptr<DtlsSrtpTransport> transport) {
 #endif
 
 bool Track::outgoing(message_ptr message) {
-	auto direction = mMediaDescription.direction();
-	if (direction == Description::Direction::RecvOnly ||
-	    direction == Description::Direction::Inactive)
-		throw std::runtime_error("Track media direction does not allow sending");
 
-	if (mIsClosed)
-		throw std::runtime_error("Track is closed");
+    if (mRtcpHandler) {
+        message = mRtcpHandler->outgoing(message);
+        if (!message)
+            return false;
+    }
+
+    auto direction = mMediaDescription.direction();
+    if ((direction == Description::Direction::RecvOnly ||
+         direction == Description::Direction::Inactive) &&
+        message->type != Message::Control) {
+        PLOG_WARNING << "Track media direction does not allow transmission, dropping";
+        return false;
+    }
+
+    if (mIsClosed)
+        throw std::runtime_error("Track is closed");
 
-	if (message->size() > maxMessageSize())
-		throw std::runtime_error("Message size exceeds limit");
+    if (message->size() > maxMessageSize())
+        throw std::runtime_error("Message size exceeds limit");
 
 #if RTC_ENABLE_MEDIA
-	auto transport = mDtlsSrtpTransport.lock();
-	if (!transport)
-		throw std::runtime_error("Track transport is not open");
+    auto transport = mDtlsSrtpTransport.lock();
+    if (!transport)
+        throw std::runtime_error("Track transport is not open");
 
-	return transport->sendMedia(message);
+    return transport->sendMedia(message);
 #else
-	PLOG_WARNING << "Ignoring track send (not compiled with SRTP support)";
+    PLOG_WARNING << "Ignoring track send (not compiled with SRTP support)";
 	return false;
 #endif
 }
@@ -119,18 +130,17 @@ void Track::incoming(message_ptr message) {
 		return;
 
 	if (mRtcpHandler) {
-		auto opt = mRtcpHandler->incoming(message);
-		if (!opt)
-			return;
-
-		message = *opt;
-	}
+        message = mRtcpHandler->incoming(message);
+        if (!message)
+            return ;
+    }
 
 	auto direction = mMediaDescription.direction();
 	if ((direction == Description::Direction::SendOnly ||
 	     direction == Description::Direction::Inactive) &&
 	    message->type != Message::Control) {
 		PLOG_WARNING << "Track media direction does not allow reception, dropping";
+		return;
 	}
 
 	// Tail drop if queue is full
@@ -142,21 +152,32 @@ void Track::incoming(message_ptr message) {
 }
 
 void Track::setRtcpHandler(std::shared_ptr<RtcpHandler> handler) {
-	if (mRtcpHandler)
-		mRtcpHandler->onOutgoing(nullptr);
-
 	mRtcpHandler = std::move(handler);
 	if (mRtcpHandler) {
-		mRtcpHandler->onOutgoing([&]([[maybe_unused]] message_ptr message) {
-#if RTC_ENABLE_MEDIA
-			if (auto transport = mDtlsSrtpTransport.lock())
-				transport->sendMedia(message);
-#else
-			PLOG_WARNING << "Ignoring RTCP send (not compiled with SRTP support)";
-#endif
+		mRtcpHandler->onOutgoing([&]([[maybe_unused]] const rtc::message_ptr& message) {
+		#if RTC_ENABLE_MEDIA
+			auto transport = mDtlsSrtpTransport.lock();
+			if (!transport)
+			    throw std::runtime_error("Track transport is not open");
+
+			return transport->sendMedia(message);
+		#else
+			PLOG_WARNING << "Ignoring track send (not compiled with SRTP support)";
+		    return false;
+		#endif
 		});
 	}
 }
 
+bool Track::requestKeyframe() {
+    if (mRtcpHandler)
+        return mRtcpHandler->requestKeyframe();
+    return false;
+}
+
+std::shared_ptr<RtcpHandler> Track::getRtcpHandler() {
+    return mRtcpHandler;
+}
+
 } // namespace rtc