浏览代码

Modified the RtcpSession class to a better API design

Staz M 4 年之前
父节点
当前提交
4930e666ac
共有 6 个文件被更改,包括 259 次插入64 次删除
  1. 1 9
      examples/sfu-media/main.cpp
  2. 26 8
      include/rtc/rtcp.hpp
  3. 182 1
      include/rtc/rtp.hpp
  4. 2 0
      include/rtc/track.hpp
  5. 17 14
      src/rtcp.cpp
  6. 31 32
      src/track.cpp

+ 1 - 9
examples/sfu-media/main.cpp

@@ -28,12 +28,7 @@
 
 using nlohmann::json;
 
-class Sender {
-    rtc::PeerConnection conn;
-};
-
 struct Receiver {
-    // TODO @paul.
     std::shared_ptr<rtc::PeerConnection> conn;
     std::shared_ptr<rtc::Track> track;
 };
@@ -63,7 +58,7 @@ int main() {
 		auto track = pc->addTrack(media);
         pc->setLocalDescription();
 
-		auto session = std::make_shared<rtc::RtcpSession>();
+		auto session = std::make_shared<rtc::RtcpReceivingSession>(track);
 		track->setRtcpHandler(session);
 
 		const rtc::SSRC targetSSRC = 4;
@@ -117,9 +112,6 @@ int main() {
             pc->track = pc->conn->addTrack(media);
             pc->conn->setLocalDescription();
 
-            auto session = std::make_shared<rtc::RtcpSession>();
-            pc->track->setRtcpHandler(session);
-
             pc->track->onMessage([](rtc::binary var){}, nullptr);
 
             std::cout << "Please copy/paste the answer provided by the RECEIVER: " << std::endl;

+ 26 - 8
include/rtc/rtcp.hpp

@@ -20,6 +20,8 @@
 #ifndef RTC_RTCP_H
 #define RTC_RTCP_H
 
+#include <utility>
+
 #include "include.hpp"
 #include "log.hpp"
 #include "message.hpp"
@@ -29,25 +31,41 @@ namespace rtc {
 
 class RtcpHandler {
 public:
-	virtual void onOutgoing(std::function<void(rtc::message_ptr)> cb) = 0;
-	virtual std::optional<rtc::message_ptr> incoming(rtc::message_ptr ptr) = 0;
+    /**
+     * If there is traffic coming from the remote side
+     * @param ptr
+     * @return
+     */
+    virtual rtc::message_ptr incoming(rtc::message_ptr ptr) = 0;
+
+    /**
+     * If there is traffic being sent to the remote side
+     * @param ptr
+     * @return
+     */
+    virtual rtc::message_ptr outgoing(rtc::message_ptr ptr) = 0;
 };
 
+class Track;
+
 // An RtcpSession can be plugged into a Track to handle the whole RTCP session
-class RtcpSession : public RtcpHandler {
+class RtcpReceivingSession : public RtcpHandler {
+protected:
+    std::shared_ptr<Track> track;
 public:
-	void onOutgoing(std::function<void(rtc::message_ptr)> cb) override;
+    RtcpReceivingSession(std::shared_ptr<Track> track): track(std::move(track)) {}
+
+    rtc::message_ptr incoming(rtc::message_ptr ptr) override;
+    rtc::message_ptr outgoing(rtc::message_ptr ptr) override;
+    bool send(rtc::message_ptr ptr);
 
-	std::optional<rtc::message_ptr> incoming(rtc::message_ptr ptr) override;
 	void requestBitrate(unsigned int newBitrate);
 
-private:
+protected:
 	void pushREMB(unsigned int bitrate);
 	void pushRR(unsigned int lastSR_delay);
-	void tx(message_ptr msg);
 
 	unsigned int mRequestedBitrate = 0;
-	synchronized_callback<rtc::message_ptr> mTxCallback;
 	SSRC mSsrc = 0;
 	uint32_t mGreatestSeqNo = 0;
 	uint64_t mSyncRTPTS, mSyncNTPTS;

+ 182 - 1
include/rtc/rtp.hpp

@@ -42,9 +42,9 @@ private:
     uint8_t _payloadType;
     uint16_t _seqNumber;
     uint32_t _timestamp;
+    SSRC _ssrc;
 
 public:
-    SSRC ssrc;
     SSRC csrc[16];
 
     inline uint8_t version() const { return _first >> 6; }
@@ -53,6 +53,25 @@ public:
     inline uint8_t payloadType() const { return _payloadType; }
     inline uint16_t seqNumber() const { return ntohs(_seqNumber); }
     inline uint32_t timestamp() const { return ntohl(_timestamp); }
+    inline uint32_t ssrc() const { return ntohl(_ssrc);}
+
+    inline size_t getSize() const {
+        return ((char*)&_ssrc) - ((char*)this) + sizeof(SSRC)*csrcCount();
+    }
+
+    char * getBody() const {
+        return ((char*) this) + getSize();
+    }
+
+    inline void setSeqNumber(uint16_t newSeqNo) {
+        _seqNumber = htons(newSeqNo);
+    }
+    inline void setPayloadType(uint16_t newPayloadType) {
+        _payloadType = newPayloadType;
+    }
+    inline void setSsrc(uint32_t ssrc) {
+        _ssrc = htonl(ssrc);
+    }
 };
 
 struct RTCP_ReportBlock {
@@ -166,6 +185,33 @@ public:
     }
 };
 
+struct RTCP_FB_HEADER {
+    RTCP_HEADER header;
+    SSRC packetSender;
+    SSRC mediaSource;
+
+    [[nodiscard]] SSRC getPacketSenderSSRC() const {
+        return ntohl(packetSender);
+    }
+
+    [[nodiscard]] SSRC getMediaSourceSSRC() const {
+        return ntohl(mediaSource);
+    }
+
+    void setPacketSenderSSRC(SSRC ssrc) {
+        this->packetSender = htonl(ssrc);
+    }
+
+    void setMediaSourceSSRC(SSRC ssrc) {
+        this->mediaSource = htonl(ssrc);
+    }
+
+    void log() {
+        header.log();
+        PLOG_DEBUG << "FB: " << " packet sender: " << getPacketSenderSSRC() << " media source: " << getMediaSourceSSRC();
+    }
+};
+
 struct RTCP_SR {
     RTCP_HEADER header;
     SSRC senderSSRC;
@@ -323,6 +369,141 @@ struct RTCP_REMB {
     }
 };
 
+
+
+struct RTCP_PLI {
+    RTCP_FB_HEADER header;
+
+    void preparePacket(SSRC messageSSRC) {
+        header.header.prepareHeader(206, 1, 2);
+        header.setPacketSenderSSRC(messageSSRC);
+        header.setMediaSourceSSRC(messageSSRC);
+    }
+
+    void print() {
+        header.log();
+    }
+
+    [[nodiscard]] static unsigned int size() {
+        return sizeof(RTCP_FB_HEADER);
+    }
+};
+
+struct RTCP_FIR_PART {
+    uint32_t ssrc;
+#if __BYTE_ORDER == __BIG_ENDIAN
+    uint32_t seqNo: 8;
+    uint32_t: 24;
+#elif __BYTE_ORDER == __LITTLE_ENDIAN
+    uint32_t: 24;
+    uint32_t seqNo: 8;
+#endif
+};
+
+struct RTCP_FIR {
+    RTCP_FB_HEADER header;
+    RTCP_FIR_PART parts[1];
+
+    void preparePacket(SSRC messageSSRC, uint8_t seqNo) {
+        header.header.prepareHeader(206, 4, 2 + 2 * 1);
+        header.setPacketSenderSSRC(messageSSRC);
+        header.setMediaSourceSSRC(messageSSRC);
+        parts[0].ssrc = htonl(messageSSRC);
+        parts[0].seqNo = seqNo;
+    }
+
+    void print() {
+        header.log();
+    }
+
+    [[nodiscard]] static unsigned int size() {
+        return sizeof(RTCP_FB_HEADER) + sizeof(RTCP_FIR_PART);
+    }
+};
+
+struct RTCP_NACK_PART {
+    uint16_t pid;
+    uint16_t blp;
+};
+
+class RTCP_NACK {
+    RTCP_FB_HEADER header;
+    RTCP_NACK_PART parts[1];
+public:
+    void preparePacket(SSRC ssrc, unsigned int discreteSeqNoCount) {
+        header.header.prepareHeader(205, 1, 2 + discreteSeqNoCount);
+        header.setMediaSourceSSRC(ssrc);
+        header.setPacketSenderSSRC(ssrc);
+    }
+
+    /**
+     * Add a packet to the list of missing packets.
+     * @param fciCount The number of FCI fields that are present in this packet.
+     *                  Let the number start at zero and let this function grow the number.
+     * @param fciPID The seq no of the active FCI. It will be initialized automatically, and will change automatically.
+     * @param missingPacket The seq no of the missing packet. This will be added to the queue.
+     * @return true if the packet has grown, false otherwise.
+     */
+    bool addMissingPacket(unsigned int *fciCount, uint16_t *fciPID, const uint16_t &missingPacket) {
+        if (*fciCount == 0 || missingPacket < *fciPID || missingPacket > (*fciPID + 16)) {
+            parts[*fciCount].pid = htons(missingPacket);
+            parts[*fciCount].blp = 0;
+            *fciPID = missingPacket;
+            (*fciCount)++;
+            return true;
+        } else {
+            // TODO SPEEED!
+            parts[(*fciCount) - 1].blp = htons(
+                    ntohs(parts[(*fciCount) - 1].blp) | (1u << (unsigned int) (missingPacket - *fciPID)));
+            return false;
+        }
+    }
+
+    [[nodiscard]] static unsigned int getSize(unsigned int discreteSeqNoCount) {
+        return offsetof(RTCP_NACK, parts) + sizeof(RTCP_NACK_PART) * discreteSeqNoCount;
+    }
+};
+
+class RTP_RTX {
+private:
+    RTP header;
+public:
+
+    size_t copyTo(RTP *dest, size_t totalSize, uint8_t originalPayloadType) {
+        memmove((char*)dest, (char*)this, header.getSize());
+        dest->setSeqNumber(getOriginalSeqNo());
+        dest->setPayloadType(originalPayloadType);
+        memmove(dest->getBody(), getBody(), getBodySize(totalSize));
+        return totalSize;
+    }
+
+    [[nodiscard]] uint16_t getOriginalSeqNo() const {
+        return ntohs(*(uint16_t *) (header.getBody()));
+    }
+
+    char *getBody() {
+        return header.getBody() + sizeof(uint16_t);
+    }
+
+    size_t getBodySize(size_t totalSize) {
+        return totalSize - ((char *) getBody() - (char *) this);
+    }
+
+    RTP &getHeader() {
+        return header;
+    }
+
+    size_t normalizePacket(size_t totalSize, SSRC originalSSRC, uint8_t originalPayloadType) {
+        header.setSeqNumber(getOriginalSeqNo());
+        header.setSsrc(originalSSRC); // TODO Endianess
+        header.setPayloadType(originalPayloadType);
+        // TODO, the -12 is the size of the header (which is variable!)
+        memmove(header.getBody(), header.getBody() + 2, totalSize - 12 - sizeof(uint16_t));
+        return totalSize - sizeof(uint16_t);
+    }
+};
+
+
 #pragma pack(pop)
 };
 #endif //WEBRTC_SERVER_RTP_HPP

+ 2 - 0
include/rtc/track.hpp

@@ -46,6 +46,7 @@ public:
 	void close(void) override;
 	bool send(message_variant data) override;
 	bool send(const byte *data, size_t size);
+    bool sendControl(message_ptr msg);
 
 	bool isOpen(void) const override;
 	bool isClosed(void) const override;
@@ -74,6 +75,7 @@ private:
 	std::shared_ptr<RtcpHandler> mRtcpHandler;
 
 	friend class PeerConnection;
+
 };
 
 } // namespace rtc

+ 17 - 14
src/rtcp.cpp

@@ -21,6 +21,7 @@
 
 #include <cmath>
 #include <utility>
+#include "track.hpp"
 
 #ifdef _WIN32
 #include <winsock2.h>
@@ -31,10 +32,11 @@
 
 namespace rtc {
 
+rtc::message_ptr RtcpReceivingSession::outgoing(rtc::message_ptr ptr) {
+    return ptr;
+}
 
-void RtcpSession::onOutgoing(std::function<void(rtc::message_ptr)> cb) { mTxCallback = cb; }
-
-std::optional<rtc::message_ptr> RtcpSession::incoming(rtc::message_ptr ptr) {
+rtc::message_ptr RtcpReceivingSession::incoming(rtc::message_ptr ptr) {
 	if (ptr->type == rtc::Message::Type::Binary) {
 		auto rtp = reinterpret_cast<const RTP *>(ptr->data());
 
@@ -42,12 +44,12 @@ std::optional<rtc::message_ptr> RtcpSession::incoming(rtc::message_ptr ptr) {
 		if (rtp->version() != 2) {
 			PLOG_WARNING << "RTP packet is not version 2";
 
-			return std::nullopt;
+			return nullptr;
 		}
 		if (rtp->payloadType() == 201 || rtp->payloadType() == 200) {
 			PLOG_WARNING << "RTP packet has a payload type indicating RR/SR";
 
-			return std::nullopt;
+			return nullptr;
 		}
 
 		// TODO Implement the padding bit
@@ -55,7 +57,7 @@ std::optional<rtc::message_ptr> RtcpSession::incoming(rtc::message_ptr ptr) {
 			PLOG_WARNING << "Padding processing not implemented";
 		}
 
-		mSsrc = ntohl(rtp->ssrc);
+		mSsrc = rtp->ssrc();
 
 		uint32_t seqNo = rtp->seqNumber();
 		// uint32_t rtpTS = rtp->getTS();
@@ -85,17 +87,17 @@ std::optional<rtc::message_ptr> RtcpSession::incoming(rtc::message_ptr ptr) {
 		if (mRequestedBitrate > 0)
 			pushREMB(mRequestedBitrate);
 	}
-	return std::nullopt;
+	return nullptr;
 }
 
-void RtcpSession::requestBitrate(unsigned int newBitrate) {
+void RtcpReceivingSession::requestBitrate(unsigned int newBitrate) {
 	mRequestedBitrate = newBitrate;
 
 	PLOG_DEBUG << "[GOOG-REMB] Requesting bitrate: " << newBitrate << std::endl;
 	pushREMB(newBitrate);
 }
 
-void RtcpSession::pushREMB(unsigned int bitrate) {
+void RtcpReceivingSession::pushREMB(unsigned int bitrate) {
 	rtc::message_ptr msg =
 	    rtc::make_message(RTCP_REMB::sizeWithSSRCs(1), rtc::Message::Type::Control);
 	auto remb = reinterpret_cast<RTCP_REMB *>(msg->data());
@@ -103,10 +105,10 @@ void RtcpSession::pushREMB(unsigned int bitrate) {
 	remb->setSSRC(0, mSsrc);
 	remb->log();
 
-	tx(msg);
+    send(msg);
 }
 
-void RtcpSession::pushRR(unsigned int lastSR_delay) {
+void RtcpReceivingSession::pushRR(unsigned int lastSR_delay) {
 	auto msg = rtc::make_message(RTCP_RR::sizeWithReportBlocks(1), rtc::Message::Type::Control);
 	auto rr = reinterpret_cast<RTCP_RR *>(msg->data());
 	rr->preparePacket(mSsrc, 1);
@@ -114,15 +116,16 @@ void RtcpSession::pushRR(unsigned int lastSR_delay) {
 	                                     lastSR_delay);
 	rr->log();
 
-	tx(msg);
+    send(msg);
 }
 
-void RtcpSession::tx(message_ptr msg) {
+bool RtcpReceivingSession::send(message_ptr msg) {
 	try {
-		mTxCallback(msg);
+	    return track->sendControl(std::move(msg));
 	} catch (const std::exception &e) {
 		LOG_DEBUG << "RTCP tx failed: " << e.what();
 	}
+	return false;
 }
 
 } // namespace rtc

+ 31 - 32
src/track.cpp

@@ -77,25 +77,39 @@ void Track::open(shared_ptr<DtlsSrtpTransport> transport) {
 #endif
 
 bool Track::outgoing(message_ptr message) {
-	auto direction = mMediaDescription.direction();
-	if (direction == Description::Direction::RecvOnly ||
-	    direction == Description::Direction::Inactive)
-		throw std::runtime_error("Track media direction does not allow sending");
 
-	if (mIsClosed)
-		throw std::runtime_error("Track is closed");
+    if (mRtcpHandler) {
+        message = mRtcpHandler->outgoing(message);
+        if (!message)
+            return false;
+    }
+
+    auto direction = mMediaDescription.direction();
+    if ((direction == Description::Direction::SendOnly ||
+         direction == Description::Direction::Inactive) &&
+        message->type != Message::Control) {
+        PLOG_WARNING << "Track media direction does not allow reception, dropping";
+    }
+
+	return sendControl(message);
+}
 
-	if (message->size() > maxMessageSize())
-		throw std::runtime_error("Message size exceeds limit");
+bool Track::sendControl(message_ptr message) {
+
+    if (mIsClosed)
+        throw std::runtime_error("Track is closed");
+
+    if (message->size() > maxMessageSize())
+        throw std::runtime_error("Message size exceeds limit");
 
 #if RTC_ENABLE_MEDIA
-	auto transport = mDtlsSrtpTransport.lock();
-	if (!transport)
-		throw std::runtime_error("Track transport is not open");
+    auto transport = mDtlsSrtpTransport.lock();
+    if (!transport)
+        throw std::runtime_error("Track transport is not open");
 
-	return transport->sendMedia(message);
+    return transport->sendMedia(message);
 #else
-	PLOG_WARNING << "Ignoring track send (not compiled with SRTP support)";
+    PLOG_WARNING << "Ignoring track send (not compiled with SRTP support)";
 	return false;
 #endif
 }
@@ -105,12 +119,10 @@ void Track::incoming(message_ptr message) {
 		return;
 
 	if (mRtcpHandler) {
-		auto opt = mRtcpHandler->incoming(message);
-		if (!opt)
-			return;
-
-		message = *opt;
-	}
+        message = mRtcpHandler->incoming(message);
+        if (!message)
+            return ;
+    }
 
 	auto direction = mMediaDescription.direction();
 	if ((direction == Description::Direction::SendOnly ||
@@ -128,20 +140,7 @@ void Track::incoming(message_ptr message) {
 }
 
 void Track::setRtcpHandler(std::shared_ptr<RtcpHandler> handler) {
-	if (mRtcpHandler)
-		mRtcpHandler->onOutgoing(nullptr);
-
 	mRtcpHandler = std::move(handler);
-	if (mRtcpHandler) {
-		mRtcpHandler->onOutgoing([&]([[maybe_unused]] message_ptr message) {
-#if RTC_ENABLE_MEDIA
-			if (auto transport = mDtlsSrtpTransport.lock())
-				transport->sendMedia(message);
-#else
-			PLOG_WARNING << "Ignoring RTCP send (not compiled with SRTP support)";
-#endif
-		});
-	}
 }
 
 } // namespace rtc