Browse Source

Implemented buffered amount for WebSocket

Paul-Louis Ageneau 3 years ago
parent
commit
80329260e6

+ 4 - 0
src/impl/sctptransport.cpp

@@ -333,6 +333,10 @@ SctpTransport::~SctpTransport() {
 	Instances->erase(this);
 }
 
+void SctpTransport::onBufferedAmount(amount_callback callback) {
+	mBufferedAmountCallback = std::move(callback);
+}
+
 void SctpTransport::start() {
 	Transport::start();
 

+ 2 - 4
src/impl/sctptransport.hpp

@@ -53,6 +53,8 @@ public:
 	              state_callback stateChangeCallback);
 	~SctpTransport();
 
+	void onBufferedAmount(amount_callback callback);
+
 	void start() override;
 	bool stop() override;
 	bool send(message_ptr message) override; // false if buffered
@@ -61,10 +63,6 @@ public:
 
 	unsigned int maxStream() const;
 
-	void onBufferedAmount(amount_callback callback) {
-		mBufferedAmountCallback = std::move(callback);
-	}
-
 	// Stats
 	void clearStats();
 	size_t bytesSent();

+ 35 - 3
src/impl/tcptransport.cpp

@@ -71,6 +71,14 @@ TcpTransport::TcpTransport(socket_t sock, state_callback callback)
 
 TcpTransport::~TcpTransport() { stop(); }
 
+void TcpTransport::onBufferedAmount(amount_callback callback) {
+	mBufferedAmountCallback = std::move(callback);
+}
+
+void TcpTransport::setReadTimeout(std::chrono::milliseconds readTimeout) {
+	mReadTimeout = readTimeout;
+}
+
 void TcpTransport::start() {
 	Transport::start();
 
@@ -117,14 +125,13 @@ bool TcpTransport::outgoing(message_ptr message) {
 		return true;
 
 	mSendQueue.push(message);
+	updateBufferedAmount(ptrdiff_t(message->size()));
 	setPoll(PollService::Direction::Both);
 	return false;
 }
 
 string TcpTransport::remoteAddress() const { return mHostname + ':' + mService; }
 
-void TcpTransport::setReadTimeout(std::chrono::milliseconds readTimeout) { mReadTimeout = readTimeout; }
-
 void TcpTransport::connect() {
 	PLOG_DEBUG << "Connecting to " << mHostname << ":" << mService;
 	changeState(State::Connecting);
@@ -269,11 +276,15 @@ bool TcpTransport::trySendQueue() {
 	// mSendMutex must be locked
 	while (auto next = mSendQueue.peek()) {
 		message_ptr message = std::move(*next);
-		if (!trySendMessage(message)) {
+		size_t size = message->size();
+		if (!trySendMessage(message)) { // replaces message
 			mSendQueue.exchange(message);
+			updateBufferedAmount(-ptrdiff_t(size) + ptrdiff_t(message->size()));
 			return false;
 		}
+
 		mSendQueue.pop();
+		updateBufferedAmount(-ptrdiff_t(size));
 	}
 
 	return true;
@@ -281,6 +292,7 @@ bool TcpTransport::trySendQueue() {
 
 bool TcpTransport::trySendMessage(message_ptr &message) {
 	// mSendMutex must be locked
+
 	auto data = reinterpret_cast<const char *>(message->data());
 	auto size = message->size();
 	while (size) {
@@ -307,6 +319,26 @@ bool TcpTransport::trySendMessage(message_ptr &message) {
 	return true;
 }
 
+void TcpTransport::updateBufferedAmount(ptrdiff_t delta) {
+	// Requires mSendMutex to be locked
+
+	if (delta == 0)
+		return;
+
+	mBufferedAmount = size_t(std::max(ptrdiff_t(mBufferedAmount) + delta, ptrdiff_t(0)));
+
+	// Synchronously call the buffered amount callback
+	triggerBufferedAmount(mBufferedAmount);
+}
+
+void TcpTransport::triggerBufferedAmount(size_t amount) {
+	try {
+		mBufferedAmountCallback(amount);
+	} catch (const std::exception &e) {
+		PLOG_WARNING << "TCP buffered amount callback: " << e.what();
+	}
+}
+
 void TcpTransport::process(PollService::Event event) {
 	try {
 		switch (event) {

+ 9 - 2
src/impl/tcptransport.hpp

@@ -34,10 +34,15 @@ namespace rtc::impl {
 
 class TcpTransport : public Transport {
 public:
+	using amount_callback = std::function<void(size_t amount)>;
+
 	TcpTransport(string hostname, string service, state_callback callback); // active
 	TcpTransport(socket_t sock, state_callback callback);                   // passive
 	~TcpTransport();
 
+	void onBufferedAmount(amount_callback callback);
+	void setReadTimeout(std::chrono::milliseconds readTimeout);
+
 	void start() override;
 	bool stop() override;
 	bool send(message_ptr message) override;
@@ -48,8 +53,6 @@ public:
 	bool isActive() const { return mIsActive; }
 	string remoteAddress() const;
 
-	void setReadTimeout(std::chrono::milliseconds readTimeout);
-
 private:
 	void connect();
 	void prepare(const sockaddr *addr, socklen_t addrlen);
@@ -58,15 +61,19 @@ private:
 
 	bool trySendQueue();
 	bool trySendMessage(message_ptr &message);
+	void updateBufferedAmount(ptrdiff_t delta);
+	void triggerBufferedAmount(size_t amount);
 
 	void process(PollService::Event event);
 
 	const bool mIsActive;
 	string mHostname, mService;
+	amount_callback mBufferedAmountCallback;
 	optional<std::chrono::milliseconds> mReadTimeout;
 
 	socket_t mSock;
 	Queue<message_ptr> mSendQueue;
+	size_t mBufferedAmount = 0;
 	std::mutex mSendMutex;
 };
 

+ 5 - 0
src/impl/websocket.cpp

@@ -221,6 +221,8 @@ shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> tra
 		if (std::atomic_load(&mTcpTransport))
 			throw std::logic_error("TCP transport is already set");
 
+		transport->onBufferedAmount(weak_bind(&WebSocket::triggerBufferedAmount, this, _1));
+
 		transport->onStateChange([this, weak_this = weak_from_this()](State transportState) {
 			auto shared_this = weak_this.lock();
 			if (!shared_this)
@@ -410,6 +412,9 @@ void WebSocket::closeTransports() {
 	if (ws)
 		ws->onRecv(nullptr);
 
+	if (tcp)
+		tcp->onBufferedAmount(nullptr);
+
 	using array = std::array<shared_ptr<Transport>, 3>;
 	array transports{std::move(ws), std::move(tls), std::move(tcp)};