소스 검색

Prevent message copy with move semantics

Paul-Louis Ageneau 5 년 전
부모
커밋
3c971e05dd
10개의 변경된 파일73개의 추가작업 그리고 84개의 파일을 삭제
  1. 10 9
      include/rtc/channel.hpp
  2. 2 2
      include/rtc/datachannel.hpp
  3. 27 0
      include/rtc/message.hpp
  4. 3 3
      include/rtc/peerconnection.hpp
  5. 2 2
      include/rtc/websocket.hpp
  6. 7 9
      src/channel.cpp
  7. 7 24
      src/datachannel.cpp
  8. 4 6
      src/peerconnection.cpp
  9. 5 5
      src/rtc.cpp
  10. 6 24
      src/websocket.cpp

+ 10 - 9
include/rtc/channel.hpp

@@ -20,6 +20,7 @@
 #define RTC_CHANNEL_H
 #define RTC_CHANNEL_H
 
 
 #include "include.hpp"
 #include "include.hpp"
+#include "message.hpp"
 
 
 #include <atomic>
 #include <atomic>
 #include <functional>
 #include <functional>
@@ -33,7 +34,7 @@ public:
 	virtual ~Channel() = default;
 	virtual ~Channel() = default;
 
 
 	virtual void close() = 0;
 	virtual void close() = 0;
-	virtual bool send(const std::variant<binary, string> &data) = 0; // returns false if buffered
+	virtual bool send(message_variant data) = 0; // returns false if buffered
 
 
 	virtual bool isOpen() const = 0;
 	virtual bool isOpen() const = 0;
 	virtual bool isClosed() const = 0;
 	virtual bool isClosed() const = 0;
@@ -42,24 +43,24 @@ public:
 
 
 	void onOpen(std::function<void()> callback);
 	void onOpen(std::function<void()> callback);
 	void onClosed(std::function<void()> callback);
 	void onClosed(std::function<void()> callback);
-	void onError(std::function<void(const string &error)> callback);
+	void onError(std::function<void(string error)> callback);
 
 
-	void onMessage(std::function<void(const std::variant<binary, string> &data)> callback);
-	void onMessage(std::function<void(const binary &data)> binaryCallback,
-	               std::function<void(const string &data)> stringCallback);
+	void onMessage(std::function<void(message_variant data)> callback);
+	void onMessage(std::function<void(binary data)> binaryCallback,
+	               std::function<void(string data)> stringCallback);
 
 
 	void onBufferedAmountLow(std::function<void()> callback);
 	void onBufferedAmountLow(std::function<void()> callback);
 	void setBufferedAmountLowThreshold(size_t amount);
 	void setBufferedAmountLowThreshold(size_t amount);
 
 
 	// Extended API
 	// Extended API
-	virtual std::optional<std::variant<binary, string>> receive() = 0; // only if onMessage unset
+	virtual std::optional<message_variant> receive() = 0; // only if onMessage unset
 	virtual size_t availableAmount() const; // total size available to receive
 	virtual size_t availableAmount() const; // total size available to receive
 	void onAvailable(std::function<void()> callback);
 	void onAvailable(std::function<void()> callback);
 
 
 protected:
 protected:
 	virtual void triggerOpen();
 	virtual void triggerOpen();
 	virtual void triggerClosed();
 	virtual void triggerClosed();
-	virtual void triggerError(const string &error);
+	virtual void triggerError(string error);
 	virtual void triggerAvailable(size_t count);
 	virtual void triggerAvailable(size_t count);
 	virtual void triggerBufferedAmount(size_t amount);
 	virtual void triggerBufferedAmount(size_t amount);
 
 
@@ -68,8 +69,8 @@ protected:
 private:
 private:
 	synchronized_callback<> mOpenCallback;
 	synchronized_callback<> mOpenCallback;
 	synchronized_callback<> mClosedCallback;
 	synchronized_callback<> mClosedCallback;
-	synchronized_callback<const string &> mErrorCallback;
-	synchronized_callback<const std::variant<binary, string> &> mMessageCallback;
+	synchronized_callback<string> mErrorCallback;
+	synchronized_callback<message_variant> mMessageCallback;
 	synchronized_callback<> mAvailableCallback;
 	synchronized_callback<> mAvailableCallback;
 	synchronized_callback<> mBufferedAmountLowCallback;
 	synchronized_callback<> mBufferedAmountLowCallback;
 
 

+ 2 - 2
include/rtc/datachannel.hpp

@@ -50,7 +50,7 @@ public:
 	Reliability reliability() const;
 	Reliability reliability() const;
 
 
 	void close(void) override;
 	void close(void) override;
-	bool send(const std::variant<binary, string> &data) override;
+	bool send(message_variant data) override;
 	bool send(const byte *data, size_t size);
 	bool send(const byte *data, size_t size);
 	template <typename Buffer> bool sendBuffer(const Buffer &buf);
 	template <typename Buffer> bool sendBuffer(const Buffer &buf);
 	template <typename Iterator> bool sendBuffer(Iterator first, Iterator last);
 	template <typename Iterator> bool sendBuffer(Iterator first, Iterator last);
@@ -61,7 +61,7 @@ public:
 
 
 	// Extended API
 	// Extended API
 	size_t availableAmount() const override;
 	size_t availableAmount() const override;
-	std::optional<std::variant<binary, string>> receive() override;
+	std::optional<message_variant> receive() override;
 
 
 private:
 private:
 	void remoteClose();
 	void remoteClose();

+ 27 - 0
include/rtc/message.hpp

@@ -24,6 +24,8 @@
 
 
 #include <functional>
 #include <functional>
 #include <memory>
 #include <memory>
+#include <optional>
+#include <variant>
 
 
 namespace rtc {
 namespace rtc {
 
 
@@ -46,6 +48,7 @@ struct Message : binary {
 
 
 using message_ptr = std::shared_ptr<Message>;
 using message_ptr = std::shared_ptr<Message>;
 using message_callback = std::function<void(message_ptr message)>;
 using message_callback = std::function<void(message_ptr message)>;
+using message_variant = std::variant<binary, string>;
 
 
 constexpr auto message_size_func = [](const message_ptr &m) -> size_t {
 constexpr auto message_size_func = [](const message_ptr &m) -> size_t {
 	return m->type == Message::Binary || m->type == Message::String ? m->size() : 0;
 	return m->type == Message::Binary || m->type == Message::String ? m->size() : 0;
@@ -79,6 +82,30 @@ inline message_ptr make_message(binary &&data, Message::Type type = Message::Bin
 	return message;
 	return message;
 }
 }
 
 
+inline message_ptr make_message(message_variant data) {
+	return std::visit( //
+	    overloaded{
+	        [&](binary data) { return make_message(std::move(data), Message::Binary); },
+	        [&](string data) {
+		        auto b = reinterpret_cast<const byte *>(data.data());
+		        return make_message(b, b + data.size(), Message::String);
+	        },
+	    },
+	    std::move(data));
+}
+
+inline std::optional<message_variant> to_variant(Message &&message) {
+	switch (message.type) {
+	case Message::String:
+		return std::make_optional(
+		    string(reinterpret_cast<const char *>(message.data()), message.size()));
+	case Message::Binary:
+		return std::make_optional(std::move(message));
+	default:
+		return nullopt;
+	}
+}
+
 } // namespace rtc
 } // namespace rtc
 
 
 #endif
 #endif

+ 3 - 3
include/rtc/peerconnection.hpp

@@ -102,10 +102,10 @@ public:
 
 
 	// Media
 	// Media
 	bool hasMedia() const;
 	bool hasMedia() const;
-	void sendMedia(const binary &packet);
+	void sendMedia(binary packet);
 	void sendMedia(const byte *packet, size_t size);
 	void sendMedia(const byte *packet, size_t size);
 
 
-	void onMedia(std::function<void(const binary &packet)> callback);
+	void onMedia(std::function<void(binary)> callback);
 
 
 	// libnice only
 	// libnice only
 	bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
 	bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
@@ -164,7 +164,7 @@ private:
 	synchronized_callback<const Candidate &> mLocalCandidateCallback;
 	synchronized_callback<const Candidate &> mLocalCandidateCallback;
 	synchronized_callback<State> mStateChangeCallback;
 	synchronized_callback<State> mStateChangeCallback;
 	synchronized_callback<GatheringState> mGatheringStateChangeCallback;
 	synchronized_callback<GatheringState> mGatheringStateChangeCallback;
-	synchronized_callback<const binary &> mMediaCallback;
+	synchronized_callback<binary> mMediaCallback;
 };
 };
 
 
 } // namespace rtc
 } // namespace rtc

+ 2 - 2
include/rtc/websocket.hpp

@@ -58,14 +58,14 @@ public:
 
 
 	void open(const string &url);
 	void open(const string &url);
 	void close() override;
 	void close() override;
-	bool send(const std::variant<binary, string> &data) override;
+	bool send(const message_variant data) override;
 
 
 	bool isOpen() const override;
 	bool isOpen() const override;
 	bool isClosed() const override;
 	bool isClosed() const override;
 	size_t maxMessageSize() const override;
 	size_t maxMessageSize() const override;
 
 
 	// Extended API
 	// Extended API
-	std::optional<std::variant<binary, string>> receive() override;
+	std::optional<message_variant> receive() override;
 	size_t availableAmount() const override; // total size available to receive
 	size_t availableAmount() const override; // total size available to receive
 
 
 private:
 private:

+ 7 - 9
src/channel.cpp

@@ -34,11 +34,9 @@ void Channel::onClosed(std::function<void()> callback) {
 	mClosedCallback = callback;
 	mClosedCallback = callback;
 }
 }
 
 
-void Channel::onError(std::function<void(const string &error)> callback) {
-	mErrorCallback = callback;
-}
+void Channel::onError(std::function<void(string error)> callback) { mErrorCallback = callback; }
 
 
-void Channel::onMessage(std::function<void(const std::variant<binary, string> &data)> callback) {
+void Channel::onMessage(std::function<void(message_variant data)> callback) {
 	mMessageCallback = callback;
 	mMessageCallback = callback;
 
 
 	// Pass pending messages
 	// Pass pending messages
@@ -46,10 +44,10 @@ void Channel::onMessage(std::function<void(const std::variant<binary, string> &d
 		mMessageCallback(*message);
 		mMessageCallback(*message);
 }
 }
 
 
-void Channel::onMessage(std::function<void(const binary &data)> binaryCallback,
-                        std::function<void(const string &data)> stringCallback) {
-	onMessage([binaryCallback, stringCallback](const std::variant<binary, string> &data) {
-		std::visit(overloaded{binaryCallback, stringCallback}, data);
+void Channel::onMessage(std::function<void(binary data)> binaryCallback,
+                        std::function<void(string data)> stringCallback) {
+	onMessage([binaryCallback, stringCallback](std::variant<binary, string> data) {
+		std::visit(overloaded{binaryCallback, stringCallback}, std::move(data));
 	});
 	});
 }
 }
 
 
@@ -67,7 +65,7 @@ void Channel::triggerOpen() { mOpenCallback(); }
 
 
 void Channel::triggerClosed() { mClosedCallback(); }
 void Channel::triggerClosed() { mClosedCallback(); }
 
 
-void Channel::triggerError(const string &error) { mErrorCallback(error); }
+void Channel::triggerError(string error) { mErrorCallback(error); }
 
 
 void Channel::triggerAvailable(size_t count) {
 void Channel::triggerAvailable(size_t count) {
 	if (count == 1)
 	if (count == 1)

+ 7 - 24
src/datachannel.cpp

@@ -117,40 +117,23 @@ void DataChannel::remoteClose() {
 	mSctpTransport.reset();
 	mSctpTransport.reset();
 }
 }
 
 
-bool DataChannel::send(const std::variant<binary, string> &data) {
-	return std::visit(
-	    [&](const auto &d) {
-		    using T = std::decay_t<decltype(d)>;
-		    constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
-		    auto *b = reinterpret_cast<const byte *>(d.data());
-		    return outgoing(std::make_shared<Message>(b, b + d.size(), type));
-	    },
-	    data);
-}
+bool DataChannel::send(message_variant data) { return outgoing(make_message(std::move(data))); }
 
 
 bool DataChannel::send(const byte *data, size_t size) {
 bool DataChannel::send(const byte *data, size_t size) {
 	return outgoing(std::make_shared<Message>(data, data + size, Message::Binary));
 	return outgoing(std::make_shared<Message>(data, data + size, Message::Binary));
 }
 }
 
 
-std::optional<std::variant<binary, string>> DataChannel::receive() {
+std::optional<message_variant> DataChannel::receive() {
 	while (!mRecvQueue.empty()) {
 	while (!mRecvQueue.empty()) {
 		auto message = *mRecvQueue.pop();
 		auto message = *mRecvQueue.pop();
-		switch (message->type) {
-		case Message::Control: {
+		if (message->type == Message::Control) {
 			auto raw = reinterpret_cast<const uint8_t *>(message->data());
 			auto raw = reinterpret_cast<const uint8_t *>(message->data());
-			if (raw[0] == MESSAGE_CLOSE)
+			if (!message->empty() && raw[0] == MESSAGE_CLOSE)
 				remoteClose();
 				remoteClose();
-			break;
-		}
-		case Message::String:
-			return std::make_optional(
-			    string(reinterpret_cast<const char *>(message->data()), message->size()));
-		case Message::Binary:
-			return std::make_optional(std::move(*message));
-		default:
-			// Ignore
-			break;
+			continue;
 		}
 		}
+		if (auto variant = to_variant(std::move(*message)))
+			return variant;
 	}
 	}
 
 
 	return nullopt;
 	return nullopt;

+ 4 - 6
src/peerconnection.cpp

@@ -240,17 +240,15 @@ bool PeerConnection::hasMedia() const {
 	return (local && local->hasMedia()) || (remote && remote->hasMedia());
 	return (local && local->hasMedia()) || (remote && remote->hasMedia());
 }
 }
 
 
-void PeerConnection::sendMedia(const binary &packet) {
-	outgoingMedia(make_message(packet.begin(), packet.end(), Message::Binary));
+void PeerConnection::sendMedia(binary packet) {
+	outgoingMedia(make_message(std::move(packet), Message::Binary));
 }
 }
 
 
 void PeerConnection::sendMedia(const byte *packet, size_t size) {
 void PeerConnection::sendMedia(const byte *packet, size_t size) {
 	outgoingMedia(make_message(packet, packet + size, Message::Binary));
 	outgoingMedia(make_message(packet, packet + size, Message::Binary));
 }
 }
 
 
-void PeerConnection::onMedia(std::function<void(const binary &packet)> callback) {
-	mMediaCallback = callback;
-}
+void PeerConnection::onMedia(std::function<void(binary)> callback) { mMediaCallback = callback; }
 
 
 void PeerConnection::outgoingMedia([[maybe_unused]] message_ptr message) {
 void PeerConnection::outgoingMedia([[maybe_unused]] message_ptr message) {
 	if (!hasMedia())
 	if (!hasMedia())
@@ -529,7 +527,7 @@ void PeerConnection::forwardMessage(message_ptr message) {
 
 
 void PeerConnection::forwardMedia(message_ptr message) {
 void PeerConnection::forwardMedia(message_ptr message) {
 	if (message)
 	if (message)
-		mMediaCallback(*message);
+		mMediaCallback(std::move(*message));
 }
 }
 
 
 void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
 void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {

+ 5 - 5
src/rtc.cpp

@@ -542,7 +542,7 @@ int rtcSetErrorCallback(int id, rtcErrorCallbackFunc cb) {
 	return WRAP({
 	return WRAP({
 		auto channel = getChannel(id);
 		auto channel = getChannel(id);
 		if (cb)
 		if (cb)
-			channel->onError([id, cb](const string &error) {
+			channel->onError([id, cb](string error) {
 				if (auto ptr = getUserPointer(id))
 				if (auto ptr = getUserPointer(id))
 					cb(error.c_str(), *ptr);
 					cb(error.c_str(), *ptr);
 			});
 			});
@@ -556,11 +556,11 @@ int rtcSetMessageCallback(int id, rtcMessageCallbackFunc cb) {
 		auto channel = getChannel(id);
 		auto channel = getChannel(id);
 		if (cb)
 		if (cb)
 			channel->onMessage(
 			channel->onMessage(
-			    [id, cb](const binary &b) {
+			    [id, cb](binary b) {
 				    if (auto ptr = getUserPointer(id))
 				    if (auto ptr = getUserPointer(id))
 					    cb(reinterpret_cast<const char *>(b.data()), int(b.size()), *ptr);
 					    cb(reinterpret_cast<const char *>(b.data()), int(b.size()), *ptr);
 			    },
 			    },
-			    [id, cb](const string &s) {
+			    [id, cb](string s) {
 				    if (auto ptr = getUserPointer(id))
 				    if (auto ptr = getUserPointer(id))
 					    cb(s.c_str(), -int(s.size() + 1), *ptr);
 					    cb(s.c_str(), -int(s.size() + 1), *ptr);
 			    });
 			    });
@@ -643,13 +643,13 @@ int rtcReceiveMessage(int id, char *buffer, int *size) {
 		if (auto message = channel->receive())
 		if (auto message = channel->receive())
 			return std::visit( //
 			return std::visit( //
 			    overloaded{    //
 			    overloaded{    //
-			               [&](const binary &b) {
+			               [&](binary b) {
 				               *size = std::min(*size, int(b.size()));
 				               *size = std::min(*size, int(b.size()));
 				               auto data = reinterpret_cast<const char *>(b.data());
 				               auto data = reinterpret_cast<const char *>(b.data());
 				               std::copy(data, data + *size, buffer);
 				               std::copy(data, data + *size, buffer);
 				               return 1;
 				               return 1;
 			               },
 			               },
-			               [&](const string &s) {
+			               [&](string s) {
 				               int len = std::min(*size - 1, int(s.size()));
 				               int len = std::min(*size - 1, int(s.size()));
 				               if (len >= 0) {
 				               if (len >= 0) {
 					               std::copy(s.data(), s.data() + len, buffer);
 					               std::copy(s.data(), s.data() + len, buffer);

+ 6 - 24
src/websocket.cpp

@@ -100,16 +100,7 @@ void WebSocket::remoteClose() {
 	}
 	}
 }
 }
 
 
-bool WebSocket::send(const std::variant<binary, string> &data) {
-	return std::visit(
-	    [&](const auto &d) {
-		    using T = std::decay_t<decltype(d)>;
-		    constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
-		    auto *b = reinterpret_cast<const byte *>(d.data());
-		    return outgoing(std::make_shared<Message>(b, b + d.size(), type));
-	    },
-	    data);
-}
+bool WebSocket::send(message_variant data) { return outgoing(make_message(std::move(data))); }
 
 
 bool WebSocket::isOpen() const { return mState == State::Open; }
 bool WebSocket::isOpen() const { return mState == State::Open; }
 
 
@@ -117,20 +108,11 @@ bool WebSocket::isClosed() const { return mState == State::Closed; }
 
 
 size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
 size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
 
 
-std::optional<std::variant<binary, string>> WebSocket::receive() {
-	while (!mRecvQueue.empty()) {
-		auto message = *mRecvQueue.pop();
-		switch (message->type) {
-		case Message::String:
-			return std::make_optional(
-			    string(reinterpret_cast<const char *>(message->data()), message->size()));
-		case Message::Binary:
-			return std::make_optional(std::move(*message));
-		default:
-			// Ignore
-			break;
-		}
-	}
+std::optional<message_variant> WebSocket::receive() {
+	while (!mRecvQueue.empty())
+		if (auto variant = to_variant(std::move(**mRecvQueue.pop())))
+			return variant;
+
 	return nullopt;
 	return nullopt;
 }
 }