Browse Source

Merge pull request #152 from paullouisageneau/prevent-message-copy

Prevent message copy
Paul-Louis Ageneau 5 years ago
parent
commit
09db03ba02

+ 1 - 0
CMakeLists.txt

@@ -49,6 +49,7 @@ set(LIBDATACHANNEL_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/icetransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/init.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/log.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/message.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/peerconnection.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtc.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp

+ 4 - 4
README.md

@@ -118,12 +118,12 @@ config.iceServers.emplace_back("mystunserver.org:3478");
 
 auto pc = make_shared<rtc::PeerConnection>(config);
 
-pc->onLocalDescription([](const rtc::Description &sdp) {
+pc->onLocalDescription([](rtc::Description sdp) {
     // Send the SDP to the remote peer
     MY_SEND_DESCRIPTION_TO_REMOTE(string(sdp));
 });
 
-pc->onLocalCandidate([](const rtc::Candidate &candidate) {
+pc->onLocalCandidate([](rtc::Candidate candidate) {
     // Send the candidate to the remote peer
     MY_SEND_CANDIDATE_TO_REMOTE(candidate.candidate(), candidate.mid());
 });
@@ -159,7 +159,7 @@ dc->onOpen([]() {
     cout << "Open" << endl;
 });
 
-dc->onMessage([](const variant<binary, string> &message) {
+dc->onMessage([](variant<binary, string> message) {
     if (holds_alternative<string>(message)) {
         cout << "Received: " << get<string>(message) << endl;
     }
@@ -186,7 +186,7 @@ ws->onOpen([]() {
 	cout << "WebSocket open" << endl;
 });
 
-ws->onMessage([](const variant<binary, string> &message) {
+ws->onMessage([](variant<binary, string> message) {
     if (holds_alternative<string>(message)) {
         cout << "WebSocket received: " << get<string>(message) << endl;
     }

+ 3 - 3
examples/client/main.cpp

@@ -65,7 +65,7 @@ int main(int argc, char **argv) {
 
 	ws->onError([](const string &error) { cout << "WebSocket failed: " << error << endl; });
 
-	ws->onMessage([&](const variant<binary, string> &data) {
+	ws->onMessage([&](variant<binary, string> data) {
 		if (!holds_alternative<string>(data))
 			return;
 
@@ -166,7 +166,7 @@ shared_ptr<PeerConnection> createPeerConnection(const Configuration &config,
 	pc->onGatheringStateChange(
 	    [](PeerConnection::GatheringState state) { cout << "Gathering State: " << state << endl; });
 
-	pc->onLocalDescription([wws, id](const Description &description) {
+	pc->onLocalDescription([wws, id](Description description) {
 		json message = {
 		    {"id", id}, {"type", description.typeString()}, {"description", string(description)}};
 
@@ -174,7 +174,7 @@ shared_ptr<PeerConnection> createPeerConnection(const Configuration &config,
 			ws->send(message.dump());
 	});
 
-	pc->onLocalCandidate([wws, id](const Candidate &candidate) {
+	pc->onLocalCandidate([wws, id](Candidate candidate) {
 		json message = {{"id", id},
 		                {"type", "candidate"},
 		                {"candidate", string(candidate)},

+ 3 - 3
examples/copy-paste/answerer.cpp

@@ -36,12 +36,12 @@ int main(int argc, char **argv) {
 
 	auto pc = std::make_shared<PeerConnection>(config);
 
-	pc->onLocalDescription([](const Description &description) {
+	pc->onLocalDescription([](Description description) {
 		cout << "Local Description (Paste this to the other peer):" << endl;
 		cout << string(description) << endl;
 	});
 
-	pc->onLocalCandidate([](const Candidate &candidate) {
+	pc->onLocalCandidate([](Candidate candidate) {
 		cout << "Local Candidate (Paste this to the other peer after the local description):"
 		     << endl;
 		cout << string(candidate) << endl << endl;
@@ -60,7 +60,7 @@ int main(int argc, char **argv) {
 
 		dc->onClosed([&]() { cout << "[DataChannel closed: " << dc->label() << "]" << endl; });
 
-		dc->onMessage([](const variant<binary, string> &message) {
+		dc->onMessage([](variant<binary, string> message) {
 			if (holds_alternative<string>(message)) {
 				cout << "[Received message: " << get<string>(message) << "]" << endl;
 			}

+ 3 - 3
examples/copy-paste/offerer.cpp

@@ -36,12 +36,12 @@ int main(int argc, char **argv) {
 
 	auto pc = std::make_shared<PeerConnection>(config);
 
-	pc->onLocalDescription([](const Description &description) {
+	pc->onLocalDescription([](Description description) {
 		cout << "Local Description (Paste this to the other peer):" << endl;
 		cout << string(description) << endl;
 	});
 
-	pc->onLocalCandidate([](const Candidate &candidate) {
+	pc->onLocalCandidate([](Candidate candidate) {
 		cout << "Local Candidate (Paste this to the other peer after the local description):"
 		     << endl;
 		cout << string(candidate) << endl << endl;
@@ -60,7 +60,7 @@ int main(int argc, char **argv) {
 
 	dc->onClosed([&]() { cout << "[DataChannel closed: " << dc->label() << "]" << endl; });
 
-	dc->onMessage([](const variant<binary, string> &message) {
+	dc->onMessage([](variant<binary, string> message) {
 		if (holds_alternative<string>(message)) {
 			cout << "[Received: " << get<string>(message) << "]" << endl;
 		}

+ 10 - 9
include/rtc/channel.hpp

@@ -20,6 +20,7 @@
 #define RTC_CHANNEL_H
 
 #include "include.hpp"
+#include "message.hpp"
 
 #include <atomic>
 #include <functional>
@@ -33,7 +34,7 @@ public:
 	virtual ~Channel() = default;
 
 	virtual void close() = 0;
-	virtual bool send(const std::variant<binary, string> &data) = 0; // returns false if buffered
+	virtual bool send(message_variant data) = 0; // returns false if buffered
 
 	virtual bool isOpen() const = 0;
 	virtual bool isClosed() const = 0;
@@ -42,24 +43,24 @@ public:
 
 	void onOpen(std::function<void()> callback);
 	void onClosed(std::function<void()> callback);
-	void onError(std::function<void(const string &error)> callback);
+	void onError(std::function<void(string error)> callback);
 
-	void onMessage(std::function<void(const std::variant<binary, string> &data)> callback);
-	void onMessage(std::function<void(const binary &data)> binaryCallback,
-	               std::function<void(const string &data)> stringCallback);
+	void onMessage(std::function<void(message_variant data)> callback);
+	void onMessage(std::function<void(binary data)> binaryCallback,
+	               std::function<void(string data)> stringCallback);
 
 	void onBufferedAmountLow(std::function<void()> callback);
 	void setBufferedAmountLowThreshold(size_t amount);
 
 	// Extended API
-	virtual std::optional<std::variant<binary, string>> receive() = 0; // only if onMessage unset
+	virtual std::optional<message_variant> receive() = 0; // only if onMessage unset
 	virtual size_t availableAmount() const; // total size available to receive
 	void onAvailable(std::function<void()> callback);
 
 protected:
 	virtual void triggerOpen();
 	virtual void triggerClosed();
-	virtual void triggerError(const string &error);
+	virtual void triggerError(string error);
 	virtual void triggerAvailable(size_t count);
 	virtual void triggerBufferedAmount(size_t amount);
 
@@ -68,8 +69,8 @@ protected:
 private:
 	synchronized_callback<> mOpenCallback;
 	synchronized_callback<> mClosedCallback;
-	synchronized_callback<const string &> mErrorCallback;
-	synchronized_callback<const std::variant<binary, string> &> mMessageCallback;
+	synchronized_callback<string> mErrorCallback;
+	synchronized_callback<message_variant> mMessageCallback;
 	synchronized_callback<> mAvailableCallback;
 	synchronized_callback<> mBufferedAmountLowCallback;
 

+ 2 - 2
include/rtc/datachannel.hpp

@@ -50,7 +50,7 @@ public:
 	Reliability reliability() const;
 
 	void close(void) override;
-	bool send(const std::variant<binary, string> &data) override;
+	bool send(message_variant data) override;
 	bool send(const byte *data, size_t size);
 	template <typename Buffer> bool sendBuffer(const Buffer &buf);
 	template <typename Iterator> bool sendBuffer(Iterator first, Iterator last);
@@ -61,7 +61,7 @@ public:
 
 	// Extended API
 	size_t availableAmount() const override;
-	std::optional<std::variant<binary, string>> receive() override;
+	std::optional<message_variant> receive() override;
 
 private:
 	void remoteClose();

+ 14 - 18
include/rtc/message.hpp

@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2019 Paul-Louis Ageneau
+ * Copyright (c) 2019-2020 Paul-Louis Ageneau
  *
  * This library is free software; you can redistribute it and/or
  * modify it under the terms of the GNU Lesser General Public
@@ -24,6 +24,8 @@
 
 #include <functional>
 #include <memory>
+#include <optional>
+#include <variant>
 
 namespace rtc {
 
@@ -46,8 +48,9 @@ struct Message : binary {
 
 using message_ptr = std::shared_ptr<Message>;
 using message_callback = std::function<void(message_ptr message)>;
+using message_variant = std::variant<binary, string>;
 
-constexpr auto message_size_func = [](const message_ptr &m) -> size_t {
+inline size_t message_size_func(const message_ptr &m) {
 	return m->type == Message::Binary || m->type == Message::String ? m->size() : 0;
 };
 
@@ -61,23 +64,16 @@ message_ptr make_message(Iterator begin, Iterator end, Message::Type type = Mess
 	return message;
 }
 
-inline message_ptr make_message(size_t size, Message::Type type = Message::Binary,
-                                unsigned int stream = 0,
-                                std::shared_ptr<Reliability> reliability = nullptr) {
-	auto message = std::make_shared<Message>(size, type);
-	message->stream = stream;
-	message->reliability = reliability;
-	return message;
-}
+message_ptr make_message(size_t size, Message::Type type = Message::Binary, unsigned int stream = 0,
+                         std::shared_ptr<Reliability> reliability = nullptr);
 
-inline message_ptr make_message(binary &&data, Message::Type type = Message::Binary,
-                                unsigned int stream = 0,
-                                std::shared_ptr<Reliability> reliability = nullptr) {
-	auto message = std::make_shared<Message>(std::move(data), type);
-	message->stream = stream;
-	message->reliability = reliability;
-	return message;
-}
+message_ptr make_message(binary &&data, Message::Type type = Message::Binary,
+                         unsigned int stream = 0,
+                         std::shared_ptr<Reliability> reliability = nullptr);
+
+message_ptr make_message(message_variant data);
+
+std::optional<message_variant> to_variant(Message &&message);
 
 } // namespace rtc
 

+ 7 - 7
include/rtc/peerconnection.hpp

@@ -89,8 +89,8 @@ public:
 	                                               const Reliability &reliability = {});
 
 	void onDataChannel(std::function<void(std::shared_ptr<DataChannel> dataChannel)> callback);
-	void onLocalDescription(std::function<void(const Description &description)> callback);
-	void onLocalCandidate(std::function<void(const Candidate &candidate)> callback);
+	void onLocalDescription(std::function<void(Description description)> callback);
+	void onLocalCandidate(std::function<void(Candidate candidate)> callback);
 	void onStateChange(std::function<void(State state)> callback);
 	void onGatheringStateChange(std::function<void(GatheringState state)> callback);
 
@@ -102,10 +102,10 @@ public:
 
 	// Media
 	bool hasMedia() const;
-	void sendMedia(const binary &packet);
+	void sendMedia(binary packet);
 	void sendMedia(const byte *packet, size_t size);
 
-	void onMedia(std::function<void(const binary &packet)> callback);
+	void onMedia(std::function<void(binary)> callback);
 
 	// libnice only
 	bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
@@ -160,11 +160,11 @@ private:
 	std::atomic<GatheringState> mGatheringState;
 
 	synchronized_callback<std::shared_ptr<DataChannel>> mDataChannelCallback;
-	synchronized_callback<const Description &> mLocalDescriptionCallback;
-	synchronized_callback<const Candidate &> mLocalCandidateCallback;
+	synchronized_callback<Description> mLocalDescriptionCallback;
+	synchronized_callback<Candidate> mLocalCandidateCallback;
 	synchronized_callback<State> mStateChangeCallback;
 	synchronized_callback<GatheringState> mGatheringStateChangeCallback;
-	synchronized_callback<const binary &> mMediaCallback;
+	synchronized_callback<binary> mMediaCallback;
 };
 
 } // namespace rtc

+ 2 - 2
include/rtc/websocket.hpp

@@ -58,14 +58,14 @@ public:
 
 	void open(const string &url);
 	void close() override;
-	bool send(const std::variant<binary, string> &data) override;
+	bool send(const message_variant data) override;
 
 	bool isOpen() const override;
 	bool isClosed() const override;
 	size_t maxMessageSize() const override;
 
 	// Extended API
-	std::optional<std::variant<binary, string>> receive() override;
+	std::optional<message_variant> receive() override;
 	size_t availableAmount() const override; // total size available to receive
 
 private:

+ 7 - 9
src/channel.cpp

@@ -34,11 +34,9 @@ void Channel::onClosed(std::function<void()> callback) {
 	mClosedCallback = callback;
 }
 
-void Channel::onError(std::function<void(const string &error)> callback) {
-	mErrorCallback = callback;
-}
+void Channel::onError(std::function<void(string error)> callback) { mErrorCallback = callback; }
 
-void Channel::onMessage(std::function<void(const std::variant<binary, string> &data)> callback) {
+void Channel::onMessage(std::function<void(message_variant data)> callback) {
 	mMessageCallback = callback;
 
 	// Pass pending messages
@@ -46,10 +44,10 @@ void Channel::onMessage(std::function<void(const std::variant<binary, string> &d
 		mMessageCallback(*message);
 }
 
-void Channel::onMessage(std::function<void(const binary &data)> binaryCallback,
-                        std::function<void(const string &data)> stringCallback) {
-	onMessage([binaryCallback, stringCallback](const std::variant<binary, string> &data) {
-		std::visit(overloaded{binaryCallback, stringCallback}, data);
+void Channel::onMessage(std::function<void(binary data)> binaryCallback,
+                        std::function<void(string data)> stringCallback) {
+	onMessage([binaryCallback, stringCallback](std::variant<binary, string> data) {
+		std::visit(overloaded{binaryCallback, stringCallback}, std::move(data));
 	});
 }
 
@@ -67,7 +65,7 @@ void Channel::triggerOpen() { mOpenCallback(); }
 
 void Channel::triggerClosed() { mClosedCallback(); }
 
-void Channel::triggerError(const string &error) { mErrorCallback(error); }
+void Channel::triggerError(string error) { mErrorCallback(error); }
 
 void Channel::triggerAvailable(size_t count) {
 	if (count == 1)

+ 7 - 24
src/datachannel.cpp

@@ -117,40 +117,23 @@ void DataChannel::remoteClose() {
 	mSctpTransport.reset();
 }
 
-bool DataChannel::send(const std::variant<binary, string> &data) {
-	return std::visit(
-	    [&](const auto &d) {
-		    using T = std::decay_t<decltype(d)>;
-		    constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
-		    auto *b = reinterpret_cast<const byte *>(d.data());
-		    return outgoing(std::make_shared<Message>(b, b + d.size(), type));
-	    },
-	    data);
-}
+bool DataChannel::send(message_variant data) { return outgoing(make_message(std::move(data))); }
 
 bool DataChannel::send(const byte *data, size_t size) {
 	return outgoing(std::make_shared<Message>(data, data + size, Message::Binary));
 }
 
-std::optional<std::variant<binary, string>> DataChannel::receive() {
+std::optional<message_variant> DataChannel::receive() {
 	while (!mRecvQueue.empty()) {
 		auto message = *mRecvQueue.pop();
-		switch (message->type) {
-		case Message::Control: {
+		if (message->type == Message::Control) {
 			auto raw = reinterpret_cast<const uint8_t *>(message->data());
-			if (raw[0] == MESSAGE_CLOSE)
+			if (!message->empty() && raw[0] == MESSAGE_CLOSE)
 				remoteClose();
-			break;
-		}
-		case Message::String:
-			return std::make_optional(
-			    string(reinterpret_cast<const char *>(message->data()), message->size()));
-		case Message::Binary:
-			return std::make_optional(std::move(*message));
-		default:
-			// Ignore
-			break;
+			continue;
 		}
+		if (auto variant = to_variant(std::move(*message)))
+			return variant;
 	}
 
 	return nullopt;

+ 63 - 0
src/message.cpp

@@ -0,0 +1,63 @@
+/**
+ * Copyright (c) 2019-2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "message.hpp"
+
+namespace rtc {
+
+message_ptr make_message(size_t size, Message::Type type, unsigned int stream,
+                         std::shared_ptr<Reliability> reliability) {
+	auto message = std::make_shared<Message>(size, type);
+	message->stream = stream;
+	message->reliability = reliability;
+	return message;
+}
+
+message_ptr make_message(binary &&data, Message::Type type, unsigned int stream,
+                         std::shared_ptr<Reliability> reliability) {
+	auto message = std::make_shared<Message>(std::move(data), type);
+	message->stream = stream;
+	message->reliability = reliability;
+	return message;
+}
+
+message_ptr make_message(message_variant data) {
+	return std::visit( //
+	    overloaded{
+	        [&](binary data) { return make_message(std::move(data), Message::Binary); },
+	        [&](string data) {
+		        auto b = reinterpret_cast<const byte *>(data.data());
+		        return make_message(b, b + data.size(), Message::String);
+	        },
+	    },
+	    std::move(data));
+}
+
+std::optional<message_variant> to_variant(Message &&message) {
+	switch (message.type) {
+	case Message::String:
+		return std::make_optional(
+		    string(reinterpret_cast<const char *>(message.data()), message.size()));
+	case Message::Binary:
+		return std::make_optional(std::move(message));
+	default:
+		return nullopt;
+	}
+}
+
+} // namespace rtc

+ 12 - 12
src/peerconnection.cpp

@@ -217,12 +217,11 @@ void PeerConnection::onDataChannel(
 	mDataChannelCallback = callback;
 }
 
-void PeerConnection::onLocalDescription(
-    std::function<void(const Description &description)> callback) {
+void PeerConnection::onLocalDescription(std::function<void(Description description)> callback) {
 	mLocalDescriptionCallback = callback;
 }
 
-void PeerConnection::onLocalCandidate(std::function<void(const Candidate &candidate)> callback) {
+void PeerConnection::onLocalCandidate(std::function<void(Candidate candidate)> callback) {
 	mLocalCandidateCallback = callback;
 }
 
@@ -240,17 +239,15 @@ bool PeerConnection::hasMedia() const {
 	return (local && local->hasMedia()) || (remote && remote->hasMedia());
 }
 
-void PeerConnection::sendMedia(const binary &packet) {
-	outgoingMedia(make_message(packet.begin(), packet.end(), Message::Binary));
+void PeerConnection::sendMedia(binary packet) {
+	outgoingMedia(make_message(std::move(packet), Message::Binary));
 }
 
 void PeerConnection::sendMedia(const byte *packet, size_t size) {
 	outgoingMedia(make_message(packet, packet + size, Message::Binary));
 }
 
-void PeerConnection::onMedia(std::function<void(const binary &packet)> callback) {
-	mMediaCallback = callback;
-}
+void PeerConnection::onMedia(std::function<void(binary)> callback) { mMediaCallback = callback; }
 
 void PeerConnection::outgoingMedia([[maybe_unused]] message_ptr message) {
 	if (!hasMedia())
@@ -529,7 +526,7 @@ void PeerConnection::forwardMessage(message_ptr message) {
 
 void PeerConnection::forwardMedia(message_ptr message) {
 	if (message)
-		mMediaCallback(*message);
+		mMediaCallback(std::move(*message));
 }
 
 void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
@@ -630,7 +627,9 @@ void PeerConnection::processLocalDescription(Description description) {
 		mLocalDescription->setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE);
 	}
 
-	mProcessor->enqueue([this]() { mLocalDescriptionCallback(*mLocalDescription); });
+	mProcessor->enqueue([this, description = *mLocalDescription]() {
+		mLocalDescriptionCallback(std::move(description));
+	});
 }
 
 void PeerConnection::processLocalCandidate(Candidate candidate) {
@@ -640,8 +639,9 @@ void PeerConnection::processLocalCandidate(Candidate candidate) {
 
 	mLocalDescription->addCandidate(candidate);
 
-	mProcessor->enqueue(
-	    [this, candidate = std::move(candidate)]() { mLocalCandidateCallback(candidate); });
+	mProcessor->enqueue([this, candidate = std::move(candidate)]() {
+		mLocalCandidateCallback(std::move(candidate));
+	});
 }
 
 void PeerConnection::triggerDataChannel(weak_ptr<DataChannel> weakDataChannel) {

+ 7 - 7
src/rtc.cpp

@@ -341,7 +341,7 @@ int rtcSetLocalDescriptionCallback(int pc, rtcDescriptionCallbackFunc cb) {
 	return WRAP({
 		auto peerConnection = getPeerConnection(pc);
 		if (cb)
-			peerConnection->onLocalDescription([pc, cb](const Description &desc) {
+			peerConnection->onLocalDescription([pc, cb](Description desc) {
 				if (auto ptr = getUserPointer(pc))
 					cb(string(desc).c_str(), desc.typeString().c_str(), *ptr);
 			});
@@ -354,7 +354,7 @@ int rtcSetLocalCandidateCallback(int pc, rtcCandidateCallbackFunc cb) {
 	return WRAP({
 		auto peerConnection = getPeerConnection(pc);
 		if (cb)
-			peerConnection->onLocalCandidate([pc, cb](const Candidate &cand) {
+			peerConnection->onLocalCandidate([pc, cb](Candidate cand) {
 				if (auto ptr = getUserPointer(pc))
 					cb(cand.candidate().c_str(), cand.mid().c_str(), *ptr);
 			});
@@ -542,7 +542,7 @@ int rtcSetErrorCallback(int id, rtcErrorCallbackFunc cb) {
 	return WRAP({
 		auto channel = getChannel(id);
 		if (cb)
-			channel->onError([id, cb](const string &error) {
+			channel->onError([id, cb](string error) {
 				if (auto ptr = getUserPointer(id))
 					cb(error.c_str(), *ptr);
 			});
@@ -556,11 +556,11 @@ int rtcSetMessageCallback(int id, rtcMessageCallbackFunc cb) {
 		auto channel = getChannel(id);
 		if (cb)
 			channel->onMessage(
-			    [id, cb](const binary &b) {
+			    [id, cb](binary b) {
 				    if (auto ptr = getUserPointer(id))
 					    cb(reinterpret_cast<const char *>(b.data()), int(b.size()), *ptr);
 			    },
-			    [id, cb](const string &s) {
+			    [id, cb](string s) {
 				    if (auto ptr = getUserPointer(id))
 					    cb(s.c_str(), -int(s.size() + 1), *ptr);
 			    });
@@ -643,13 +643,13 @@ int rtcReceiveMessage(int id, char *buffer, int *size) {
 		if (auto message = channel->receive())
 			return std::visit( //
 			    overloaded{    //
-			               [&](const binary &b) {
+			               [&](binary b) {
 				               *size = std::min(*size, int(b.size()));
 				               auto data = reinterpret_cast<const char *>(b.data());
 				               std::copy(data, data + *size, buffer);
 				               return 1;
 			               },
-			               [&](const string &s) {
+			               [&](string s) {
 				               int len = std::min(*size - 1, int(s.size()));
 				               if (len >= 0) {
 					               std::copy(s.data(), s.data() + len, buffer);

+ 6 - 24
src/websocket.cpp

@@ -100,16 +100,7 @@ void WebSocket::remoteClose() {
 	}
 }
 
-bool WebSocket::send(const std::variant<binary, string> &data) {
-	return std::visit(
-	    [&](const auto &d) {
-		    using T = std::decay_t<decltype(d)>;
-		    constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
-		    auto *b = reinterpret_cast<const byte *>(d.data());
-		    return outgoing(std::make_shared<Message>(b, b + d.size(), type));
-	    },
-	    data);
-}
+bool WebSocket::send(message_variant data) { return outgoing(make_message(std::move(data))); }
 
 bool WebSocket::isOpen() const { return mState == State::Open; }
 
@@ -117,20 +108,11 @@ bool WebSocket::isClosed() const { return mState == State::Closed; }
 
 size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
 
-std::optional<std::variant<binary, string>> WebSocket::receive() {
-	while (!mRecvQueue.empty()) {
-		auto message = *mRecvQueue.pop();
-		switch (message->type) {
-		case Message::String:
-			return std::make_optional(
-			    string(reinterpret_cast<const char *>(message->data()), message->size()));
-		case Message::Binary:
-			return std::make_optional(std::move(*message));
-		default:
-			// Ignore
-			break;
-		}
-	}
+std::optional<message_variant> WebSocket::receive() {
+	while (!mRecvQueue.empty())
+		if (auto variant = to_variant(std::move(**mRecvQueue.pop())))
+			return variant;
+
 	return nullopt;
 }
 

+ 22 - 23
test/benchmark.cpp

@@ -48,20 +48,20 @@ size_t benchmark(milliseconds duration) {
 
 	auto pc2 = std::make_shared<PeerConnection>(config2);
 
-	pc1->onLocalDescription([wpc2 = make_weak_ptr(pc2)](const Description &sdp) {
+	pc1->onLocalDescription([wpc2 = make_weak_ptr(pc2)](Description sdp) {
 		auto pc2 = wpc2.lock();
 		if (!pc2)
 			return;
 		cout << "Description 1: " << sdp << endl;
-		pc2->setRemoteDescription(sdp);
+		pc2->setRemoteDescription(std::move(sdp));
 	});
 
-	pc1->onLocalCandidate([wpc2 = make_weak_ptr(pc2)](const Candidate &candidate) {
+	pc1->onLocalCandidate([wpc2 = make_weak_ptr(pc2)](Candidate candidate) {
 		auto pc2 = wpc2.lock();
 		if (!pc2)
 			return;
 		cout << "Candidate 1: " << candidate << endl;
-		pc2->addRemoteCandidate(candidate);
+		pc2->addRemoteCandidate(std::move(candidate));
 	});
 
 	pc1->onStateChange([](PeerConnection::State state) { cout << "State 1: " << state << endl; });
@@ -69,20 +69,20 @@ size_t benchmark(milliseconds duration) {
 		cout << "Gathering state 1: " << state << endl;
 	});
 
-	pc2->onLocalDescription([wpc1 = make_weak_ptr(pc1)](const Description &sdp) {
+	pc2->onLocalDescription([wpc1 = make_weak_ptr(pc1)](Description sdp) {
 		auto pc1 = wpc1.lock();
 		if (!pc1)
 			return;
 		cout << "Description 2: " << sdp << endl;
-		pc1->setRemoteDescription(sdp);
+		pc1->setRemoteDescription(std::move(sdp));
 	});
 
-	pc2->onLocalCandidate([wpc1 = make_weak_ptr(pc1)](const Candidate &candidate) {
+	pc2->onLocalCandidate([wpc1 = make_weak_ptr(pc1)](Candidate candidate) {
 		auto pc1 = wpc1.lock();
 		if (!pc1)
 			return;
 		cout << "Candidate 2: " << candidate << endl;
-		pc1->addRemoteCandidate(candidate);
+		pc1->addRemoteCandidate(std::move(candidate));
 	});
 
 	pc2->onStateChange([](PeerConnection::State state) { cout << "State 2: " << state << endl; });
@@ -99,21 +99,20 @@ size_t benchmark(milliseconds duration) {
 	steady_clock::time_point startTime, openTime, receivedTime, endTime;
 
 	shared_ptr<DataChannel> dc2;
-	pc2->onDataChannel(
-	    [&dc2, &receivedSize, &receivedTime](shared_ptr<DataChannel> dc) {
-		    dc->onMessage([&receivedTime, &receivedSize](const variant<binary, string> &message) {
-			    if (holds_alternative<binary>(message)) {
-				    const auto &bin = get<binary>(message);
-				    if (receivedSize == 0)
-					    receivedTime = steady_clock::now();
-				    receivedSize += bin.size();
-			    }
-		    });
-
-		    dc->onClosed([]() { cout << "DataChannel closed." << endl; });
-
-		    std::atomic_store(&dc2, dc);
-	    });
+	pc2->onDataChannel([&dc2, &receivedSize, &receivedTime](shared_ptr<DataChannel> dc) {
+		dc->onMessage([&receivedTime, &receivedSize](variant<binary, string> message) {
+			if (holds_alternative<binary>(message)) {
+				const auto &bin = get<binary>(message);
+				if (receivedSize == 0)
+					receivedTime = steady_clock::now();
+				receivedSize += bin.size();
+			}
+		});
+
+		dc->onClosed([]() { cout << "DataChannel closed." << endl; });
+
+		std::atomic_store(&dc2, dc);
+	});
 
 	startTime = steady_clock::now();
 	auto dc1 = pc1->createDataChannel("benchmark");

+ 9 - 9
test/connectivity.cpp

@@ -47,20 +47,20 @@ void test_connectivity() {
 
 	auto pc2 = std::make_shared<PeerConnection>(config2);
 
-	pc1->onLocalDescription([wpc2 = make_weak_ptr(pc2)](const Description &sdp) {
+	pc1->onLocalDescription([wpc2 = make_weak_ptr(pc2)](Description sdp) {
 		auto pc2 = wpc2.lock();
 		if (!pc2)
 			return;
 		cout << "Description 1: " << sdp << endl;
-		pc2->setRemoteDescription(sdp);
+		pc2->setRemoteDescription(std::move(sdp));
 	});
 
-	pc1->onLocalCandidate([wpc2 = make_weak_ptr(pc2)](const Candidate &candidate) {
+	pc1->onLocalCandidate([wpc2 = make_weak_ptr(pc2)](Candidate candidate) {
 		auto pc2 = wpc2.lock();
 		if (!pc2)
 			return;
 		cout << "Candidate 1: " << candidate << endl;
-		pc2->addRemoteCandidate(candidate);
+		pc2->addRemoteCandidate(std::move(candidate));
 	});
 
 	pc1->onStateChange([](PeerConnection::State state) { cout << "State 1: " << state << endl; });
@@ -69,20 +69,20 @@ void test_connectivity() {
 		cout << "Gathering state 1: " << state << endl;
 	});
 
-	pc2->onLocalDescription([wpc1 = make_weak_ptr(pc1)](const Description &sdp) {
+	pc2->onLocalDescription([wpc1 = make_weak_ptr(pc1)](Description sdp) {
 		auto pc1 = wpc1.lock();
 		if (!pc1)
 			return;
 		cout << "Description 2: " << sdp << endl;
-		pc1->setRemoteDescription(sdp);
+		pc1->setRemoteDescription(std::move(sdp));
 	});
 
-	pc2->onLocalCandidate([wpc1 = make_weak_ptr(pc1)](const Candidate &candidate) {
+	pc2->onLocalCandidate([wpc1 = make_weak_ptr(pc1)](Candidate candidate) {
 		auto pc1 = wpc1.lock();
 		if (!pc1)
 			return;
 		cout << "Candidate 2: " << candidate << endl;
-		pc1->addRemoteCandidate(candidate);
+		pc1->addRemoteCandidate(std::move(candidate));
 	});
 
 	pc2->onStateChange([](PeerConnection::State state) { cout << "State 2: " << state << endl; });
@@ -95,7 +95,7 @@ void test_connectivity() {
 	pc2->onDataChannel([&dc2](shared_ptr<DataChannel> dc) {
 		cout << "DataChannel 2: Received with label \"" << dc->label() << "\"" << endl;
 
-		dc->onMessage([](const variant<binary, string> &message) {
+		dc->onMessage([](variant<binary, string> message) {
 			if (holds_alternative<string>(message)) {
 				cout << "Message 2: " << get<string>(message) << endl;
 			}

+ 2 - 2
test/websocket.cpp

@@ -53,9 +53,9 @@ void test_websocket() {
 	ws->onClosed([]() { cout << "WebSocket: Closed" << endl; });
 
 	std::atomic<bool> received = false;
-	ws->onMessage([&received, &myMessage](const variant<binary, string> &message) {
+	ws->onMessage([&received, &myMessage](variant<binary, string> message) {
 		if (holds_alternative<string>(message)) {
-			string str = get<string>(message);
+			string str = std::move(get<string>(message));
 			if((received = (str == myMessage)))
 				cout << "WebSocket: Received expected message" << endl;
 			else