소스 검색

Fixed rtcReceiveMessage() and refactor C API buffer handling

Paul-Louis Ageneau 4 년 전
부모
커밋
2bcdab027c
9개의 변경된 파일134개의 추가작업 그리고 164개의 파일을 삭제
  1. 2 2
      include/rtc/channel.hpp
  2. 1 0
      include/rtc/datachannel.hpp
  3. 4 2
      include/rtc/rtc.h
  4. 1 0
      include/rtc/track.hpp
  5. 1 0
      include/rtc/websocket.hpp
  6. 83 152
      src/capi.cpp
  7. 23 7
      src/datachannel.cpp
  8. 7 0
      src/track.cpp
  9. 12 1
      src/websocket.cpp

+ 2 - 2
include/rtc/channel.hpp

@@ -54,7 +54,8 @@ public:
 
 	// Extended API
 	virtual std::optional<message_variant> receive() = 0; // only if onMessage unset
-	virtual size_t availableAmount() const; // total size available to receive
+	virtual std::optional<message_variant> peek() = 0;    // only if onMessage unset
+	virtual size_t availableAmount() const;               // total size available to receive
 	void onAvailable(std::function<void()> callback);
 
 protected:
@@ -81,4 +82,3 @@ private:
 } // namespace rtc
 
 #endif // RTC_CHANNEL_H
-

+ 1 - 0
include/rtc/datachannel.hpp

@@ -62,6 +62,7 @@ public:
 	// Extended API
 	size_t availableAmount() const override;
 	std::optional<message_variant> receive() override;
+	std::optional<message_variant> peek() override;
 
 private:
 	void remoteClose();

+ 4 - 2
include/rtc/rtc.h

@@ -78,8 +78,10 @@ typedef enum { // Don't change, it must match plog severity
 } rtcLogLevel;
 
 #define RTC_ERR_SUCCESS 0
-#define RTC_ERR_INVALID -1 // invalid argument
-#define RTC_ERR_FAILURE -2 // runtime error
+#define RTC_ERR_INVALID -1   // invalid argument
+#define RTC_ERR_FAILURE -2   // runtime error
+#define RTC_ERR_NOT_AVAIL -3 // element not available
+#define RTC_ERR_TOO_SMALL -4 // buffer too small
 
 typedef struct {
 	const char **iceServers;

+ 1 - 0
include/rtc/track.hpp

@@ -56,6 +56,7 @@ public:
 	// Extended API
 	size_t availableAmount() const override;
 	std::optional<message_variant> receive() override;
+	std::optional<message_variant> peek() override;
 
 	// RTCP handler
 	void setRtcpHandler(std::shared_ptr<RtcpHandler> handler);

+ 1 - 0
include/rtc/websocket.hpp

@@ -66,6 +66,7 @@ public:
 
 	// Extended API
 	std::optional<message_variant> receive() override;
+	std::optional<message_variant> peek() override;
 	size_t availableAmount() const override; // total size available to receive
 
 private:

+ 83 - 152
src/capi.cpp

@@ -195,6 +195,31 @@ template <typename F> int wrap(F func) {
 		return RTC_ERR_SUCCESS;                                                                    \
 	})
 
+int copyAndReturn(string s, char *buffer, int size) {
+	if (!buffer)
+		return int(s.size() + 1);
+
+	if (size < int(s.size()))
+		return RTC_ERR_TOO_SMALL;
+
+	std::copy(s.begin(), s.end(), buffer);
+	buffer[s.size()] = '\0';
+	return int(s.size() + 1);
+}
+
+int copyAndReturn(binary b, char *buffer, int size) {
+	if (!buffer)
+		return int(b.size());
+
+	if (size < int(b.size()))
+		return RTC_ERR_TOO_SMALL;
+
+	auto data = reinterpret_cast<const char *>(b.data());
+	std::copy(data, data + b.size(), buffer);
+	buffer[b.size()] = '\0';
+	return int(b.size());
+}
+
 class plogAppender : public plog::IAppender {
 public:
 	plogAppender(rtcLogCallbackFunc cb = nullptr) { setCallback(cb); }
@@ -366,19 +391,7 @@ int rtcDeleteTrack(int tr) {
 int rtcGetTrackDescription(int tr, char *buffer, int size) {
 	return WRAP({
 		auto track = getTrack(tr);
-
-		if (size <= 0)
-			return 0;
-
-		if (!buffer)
-			throw std::invalid_argument("Unexpected null pointer for buffer");
-
-		string description(track->description());
-		const char *data = description.data();
-		size = std::min(size - 1, int(description.size()));
-		std::copy(data, data + size, buffer);
-		buffer[size] = '\0';
-		return int(size + 1);
+		return copyAndReturn(track->description(), buffer, size);
 	});
 }
 
@@ -528,7 +541,7 @@ int rtcSetRemoteDescription(int pc, const char *sdp, const char *type) {
 		if (!sdp)
 			throw std::invalid_argument("Unexpected null pointer for remote description");
 
-		peerConnection->setRemoteDescription({string(sdp), type ? string(type) : "" });
+		peerConnection->setRemoteDescription({string(sdp), type ? string(type) : ""});
 	});
 }
 
@@ -547,22 +560,10 @@ int rtcGetLocalDescription(int pc, char *buffer, int size) {
 	return WRAP({
 		auto peerConnection = getPeerConnection(pc);
 
-		if (size <= 0)
-			return 0;
-
-		if (!buffer)
-			throw std::invalid_argument("Unexpected null pointer for buffer");
-
-		if (auto desc = peerConnection->localDescription()) {
-			auto sdp = string(*desc);
-			const char *data = sdp.data();
-			size = std::min(size - 1, int(sdp.size()));
-			std::copy(data, data + size, buffer);
-			buffer[size] = '\0';
-			return size + 1;
-		}
-
-		return RTC_ERR_FAILURE;
+		if (auto desc = peerConnection->localDescription())
+			return copyAndReturn(string(*desc), buffer, size);
+		else
+			return RTC_ERR_NOT_AVAIL;
 	});
 }
 
@@ -570,22 +571,10 @@ int rtcGetRemoteDescription(int pc, char *buffer, int size) {
 	return WRAP({
 		auto peerConnection = getPeerConnection(pc);
 
-		if (size <= 0)
-			return 0;
-
-		if (!buffer)
-			throw std::invalid_argument("Unexpected null pointer for buffer");
-
-		if (auto desc = peerConnection->remoteDescription()) {
-			auto sdp = string(*desc);
-			const char *data = sdp.data();
-			size = std::min(size - 1, int(sdp.size()));
-			std::copy(data, data + size, buffer);
-			buffer[size] = '\0';
-			return size + 1;
-		}
-
-		return RTC_ERR_FAILURE;
+		if (auto desc = peerConnection->remoteDescription())
+			return copyAndReturn(string(*desc), buffer, size);
+		else
+			return RTC_ERR_NOT_AVAIL;
 	});
 }
 
@@ -593,21 +582,10 @@ int rtcGetLocalAddress(int pc, char *buffer, int size) {
 	return WRAP({
 		auto peerConnection = getPeerConnection(pc);
 
-		if (size <= 0)
-			return 0;
-
-		if (!buffer)
-			throw std::invalid_argument("Unexpected null pointer for buffer");
-
-		if (auto addr = peerConnection->localAddress()) {
-			const char *data = addr->data();
-			size = std::min(size - 1, int(addr->size()));
-			std::copy(data, data + size, buffer);
-			buffer[size] = '\0';
-			return size + 1;
-		}
-
-		return RTC_ERR_FAILURE;
+		if (auto addr = peerConnection->localAddress())
+			return copyAndReturn(std::move(*addr), buffer, size);
+		else
+			return RTC_ERR_NOT_AVAIL;
 	});
 }
 
@@ -615,21 +593,10 @@ int rtcGetRemoteAddress(int pc, char *buffer, int size) {
 	return WRAP({
 		auto peerConnection = getPeerConnection(pc);
 
-		if (size <= 0)
-			return 0;
-
-		if (!buffer)
-			throw std::invalid_argument("Unexpected null pointer for buffer");
-
-		if (auto addr = peerConnection->remoteAddress()) {
-			const char *data = addr->data();
-			size = std::min(size - 1, int(addr->size()));
-			std::copy(data, data + size, buffer);
-			buffer[size] = '\0';
-			return int(size + 1);
-		}
-
-		return RTC_ERR_FAILURE;
+		if (auto addr = peerConnection->remoteAddress())
+			return copyAndReturn(std::move(*addr), buffer, size);
+		else
+			return RTC_ERR_NOT_AVAIL;
 	});
 }
 
@@ -637,68 +604,28 @@ int rtcGetSelectedCandidatePair(int pc, char *local, int localSize, char *remote
 	return WRAP({
 		auto peerConnection = getPeerConnection(pc);
 
-		if (!local)
-			localSize = 0;
-		if (!remote)
-			remoteSize = 0;
-
 		Candidate localCand;
 		Candidate remoteCand;
-		if (peerConnection->getSelectedCandidatePair(&localCand, &remoteCand)) {
-			if (localSize > 0) {
-				string localSdp = string(localCand);
-				localSize = std::min(localSize - 1, int(localSdp.size()));
-				std::copy(localSdp.begin(), localSdp.begin() + localSize, local);
-				local[localSize] = '\0';
-			}
-			if (remoteSize > 0) {
-				string remoteSdp = string(remoteCand);
-				remoteSize = std::min(remoteSize - 1, int(remoteSdp.size()));
-				std::copy(remoteSdp.begin(), remoteSdp.begin() + remoteSize, remote);
-				remote[remoteSize] = '\0';
-			}
-			return localSize + remoteSize;
-		}
+		if (!peerConnection->getSelectedCandidatePair(&localCand, &remoteCand))
+			return RTC_ERR_NOT_AVAIL;
 
-		return RTC_ERR_FAILURE;
+		copyAndReturn(string(localCand), local, localSize);
+		copyAndReturn(string(remoteCand), remote, remoteSize);
+		return RTC_ERR_SUCCESS;
 	});
 }
 
 int rtcGetDataChannelLabel(int dc, char *buffer, int size) {
 	return WRAP({
 		auto dataChannel = getDataChannel(dc);
-
-		if (size <= 0)
-			return 0;
-
-		if (!buffer)
-			throw std::invalid_argument("Unexpected null pointer for buffer");
-
-		string label = dataChannel->label();
-		const char *data = label.data();
-		size = std::min(size - 1, int(label.size()));
-		std::copy(data, data + size, buffer);
-		buffer[size] = '\0';
-		return int(size + 1);
+		return copyAndReturn(dataChannel->label(), buffer, size);
 	});
 }
 
 int rtcGetDataChannelProtocol(int dc, char *buffer, int size) {
 	return WRAP({
 		auto dataChannel = getDataChannel(dc);
-
-		if (size <= 0)
-			return 0;
-
-		if (!buffer)
-			throw std::invalid_argument("Unexpected null pointer for buffer");
-
-		string protocol = dataChannel->protocol();
-		const char *data = protocol.data();
-		size = std::min(size - 1, int(protocol.size()));
-		std::copy(data, data + size, buffer);
-		buffer[size] = '\0';
-		return int(size + 1);
+		return copyAndReturn(dataChannel->protocol(), buffer, size);
 	});
 }
 
@@ -721,7 +648,7 @@ int rtcGetDataChannelReliability(int dc, rtcReliability *reliability) {
 		} else {
 			reliability->unreliable = false;
 		}
-		return 0;
+		return RTC_ERR_SUCCESS;
 	});
 }
 
@@ -853,34 +780,38 @@ int rtcReceiveMessage(int id, char *buffer, int *size) {
 		if (!size)
 			throw std::invalid_argument("Unexpected null pointer for size");
 
-		if (!buffer && *size != 0)
-			throw std::invalid_argument("Unexpected null pointer for buffer");
-
-		if (auto message = channel->receive())
-			return std::visit( //
-			    overloaded{    //
-			               [&](binary b) {
-				               if (*size > 0) {
-					               *size = std::min(*size, int(b.size()));
-					               auto data = reinterpret_cast<const char *>(b.data());
-					               std::copy(data, data + *size, buffer);
-				               }
-				               return 1;
-			               },
-			               [&](string s) {
-				               if (*size > 0) {
-					               int len = std::min(*size - 1, int(s.size()));
-					               if (len >= 0) {
-						               std::copy(s.data(), s.data() + len, buffer);
-						               buffer[len] = '\0';
-					               }
-					               *size = -(len + 1);
-				               }
-				               return 1;
-			               }},
-			    *message);
-		else
-			return 0;
+		*size = std::abs(*size);
+
+		auto message = channel->peek();
+		if (!message)
+			return RTC_ERR_NOT_AVAIL;
+
+		return std::visit( //
+		    overloaded{
+		        [&](binary b) {
+			        int ret = copyAndReturn(std::move(b), buffer, *size);
+			        if (ret >= 0) {
+				        channel->receive(); // discard
+				        *size = ret;
+				        return RTC_ERR_SUCCESS;
+			        } else {
+				        *size = int(b.size());
+				        return ret;
+			        }
+		        },
+		        [&](string s) {
+			        int ret = copyAndReturn(std::move(s), buffer, *size);
+			        if (ret >= 0) {
+				        channel->receive(); // discard
+				        *size = -ret;
+				        return RTC_ERR_SUCCESS;
+			        } else {
+				        *size = -int(s.size() + 1);
+				        return ret;
+			        }
+		        },
+		    },
+		    *message);
 	});
 }
 

+ 23 - 7
src/datachannel.cpp

@@ -123,19 +123,35 @@ bool DataChannel::send(const byte *data, size_t size) {
 
 std::optional<message_variant> DataChannel::receive() {
 	while (auto next = mRecvQueue.tryPop()) {
-		message_ptr message = std::move(*next);
-		if (message->type == Message::Control) {
-			auto raw = reinterpret_cast<const uint8_t *>(message->data());
-			if (!message->empty() && raw[0] == MESSAGE_CLOSE)
-				remoteClose();
-		} else {
+		message_ptr message = *next;
+		if (message->type != Message::Control)
 			return to_variant(std::move(*message));
-		}
+
+		auto raw = reinterpret_cast<const uint8_t *>(message->data());
+		if (!message->empty() && raw[0] == MESSAGE_CLOSE)
+			remoteClose();
 	}
 
 	return nullopt;
 }
 
+std::optional<message_variant> DataChannel::peek() {
+	while (auto next = mRecvQueue.peek()) {
+		message_ptr message = *next;
+		if (message->type != Message::Control)
+			return to_variant(std::move(*message));
+
+		auto raw = reinterpret_cast<const uint8_t *>(message->data());
+		if (!message->empty() && raw[0] == MESSAGE_CLOSE)
+			remoteClose();
+
+		mRecvQueue.tryPop();
+	}
+
+	return nullopt;
+}
+
+
 bool DataChannel::isOpen(void) const { return mIsOpen; }
 
 bool DataChannel::isClosed(void) const { return mIsClosed; }

+ 7 - 0
src/track.cpp

@@ -58,6 +58,13 @@ std::optional<message_variant> Track::receive() {
 	return nullopt;
 }
 
+std::optional<message_variant> Track::peek() {
+	if (auto next = mRecvQueue.peek())
+		return to_variant(std::move(**next));
+
+	return nullopt;
+}
+
 bool Track::isOpen(void) const {
 #if RTC_ENABLE_MEDIA
 	return !mIsClosed && mDtlsSrtpTransport.lock();

+ 12 - 1
src/websocket.cpp

@@ -125,13 +125,24 @@ size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
 
 std::optional<message_variant> WebSocket::receive() {
 	while (auto next = mRecvQueue.tryPop()) {
-		message_ptr message = std::move(*next);
+		message_ptr message = *next;
 		if (message->type != Message::Control)
 			return to_variant(std::move(*message));
 	}
 	return nullopt;
 }
 
+std::optional<message_variant> WebSocket::peek() {
+	while (auto next = mRecvQueue.peek()) {
+		message_ptr message = *next;
+		if (message->type != Message::Control)
+			return to_variant(std::move(*message));
+
+		mRecvQueue.tryPop();
+	}
+	return nullopt;
+}
+
 size_t WebSocket::availableAmount() const { return mRecvQueue.amount(); }
 
 bool WebSocket::changeState(State state) { return mState.exchange(state) != state; }