Browse Source

Added WebSocket

Paul-Louis Ageneau 5 years ago
parent
commit
c1f91b2fff

+ 2 - 1
CMakeLists.txt

@@ -36,10 +36,11 @@ set(LIBDATACHANNEL_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtc.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp
 
+	${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.cpp
-	${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp
 )
 
 set(LIBDATACHANNEL_HEADERS

+ 2 - 2
include/rtc/peerconnection.hpp

@@ -96,8 +96,6 @@ public:
 	std::optional<std::chrono::milliseconds> rtt();
 
 private:
-	init_token mInitToken = Init::Token();
-
 	std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
 	std::shared_ptr<DtlsTransport> initDtlsTransport();
 	std::shared_ptr<SctpTransport> initSctpTransport();
@@ -128,6 +126,8 @@ private:
 	const Configuration mConfig;
 	const std::shared_ptr<Certificate> mCertificate;
 
+	init_token mInitToken = Init::Token();
+
 	std::optional<Description> mLocalDescription, mRemoteDescription;
 	mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex;
 

+ 87 - 0
include/rtc/websocket.hpp

@@ -0,0 +1,87 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_WEBSOCKET_H
+#define RTC_WEBSOCKET_H
+
+#if ENABLE_WEBSOCKET
+
+#include "channel.hpp"
+#include "include.hpp"
+#include "init.hpp"
+#include "message.hpp"
+#include "queue.hpp"
+
+#include <atomic>
+#include <optional>
+#include <thread>
+#include <variant>
+
+namespace rtc {
+
+class TcpTransport;
+class TlsTransport;
+class WsTransport;
+
+class WebSocket final : public Channel, public std::enable_shared_from_this<WebSocket> {
+public:
+	WebSocket();
+	WebSocket(const string &url);
+	~WebSocket();
+
+	void open(const string &url);
+	void close() override;
+	bool send(const std::variant<binary, string> &data) override;
+
+	bool isOpen() const override;
+	bool isClosed() const override;
+	size_t maxMessageSize() const override;
+
+	// Extended API
+	std::optional<std::variant<binary, string>> receive() override;
+	size_t availableAmount() const override; // total size available to receive
+
+private:
+	void remoteClose();
+	bool outgoing(mutable_message_ptr message);
+
+	std::shared_ptr<TcpTransport> initTcpTransport();
+	std::shared_ptr<TlsTransport> initTlsTransport();
+	std::shared_ptr<WsTransport> initWsTransport();
+	void closeTransports();
+
+	init_token mInitToken = Init::Token();
+
+	std::shared_ptr<TcpTransport> mTcpTransport;
+	std::shared_ptr<TlsTransport> mTlsTransport;
+	std::shared_ptr<WsTransport> mWsTransport;
+	std::recursive_mutex mInitMutex;
+
+	string mScheme, mHost, mHostname, mService, mPath;
+
+	std::atomic<bool> mIsOpen = false;
+	std::atomic<bool> mIsClosed = false;
+
+	Queue<message_ptr> mRecvQueue;
+	std::atomic<size_t> mRecvAmount = 0;
+};
+} // namespace rtc
+
+#endif
+
+#endif // NET_WEBSOCKET_H

+ 11 - 28
src/dtlstransport.cpp

@@ -62,11 +62,9 @@ void DtlsTransport::Cleanup() {
 }
 
 DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
-                             verifier_callback verifierCallback,
-                             state_callback stateChangeCallback)
-    : Transport(lower), mCertificate(certificate), mState(State::Disconnected),
-      mVerifierCallback(std::move(verifierCallback)),
-      mStateChangeCallback(std::move(stateChangeCallback)) {
+                             verifier_callback verifierCallback, state_callback stateChangeCallback)
+    : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
+      mVerifierCallback(std::move(verifierCallback)) {
 
 	PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
 
@@ -113,8 +111,6 @@ DtlsTransport::~DtlsTransport() {
 	gnutls_deinit(mSession);
 }
 
-DtlsTransport::State DtlsTransport::state() const { return mState; }
-
 bool DtlsTransport::stop() {
 	if (!Transport::stop())
 		return false;
@@ -126,7 +122,7 @@ bool DtlsTransport::stop() {
 }
 
 bool DtlsTransport::send(message_ptr message) {
-	if (!message || mState != State::Connected)
+	if (!message || state() != State::Connected)
 		return false;
 
 	PLOG_VERBOSE << "Send size=" << message->size();
@@ -152,11 +148,6 @@ void DtlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 }
 
-void DtlsTransport::changeState(State state) {
-	if (mState.exchange(state) != state)
-		mStateChangeCallback(state);
-}
-
 void DtlsTransport::runRecvLoop() {
 	const size_t maxMtu = 4096;
 
@@ -362,9 +353,8 @@ void DtlsTransport::Cleanup() {
 
 DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
                              verifier_callback verifierCallback, state_callback stateChangeCallback)
-    : Transport(lower), mCertificate(certificate), mState(State::Disconnected),
-      mVerifierCallback(std::move(verifierCallback)),
-      mStateChangeCallback(std::move(stateChangeCallback)) {
+    : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
+      mVerifierCallback(std::move(verifierCallback)) {
 
 	PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
 
@@ -445,10 +435,8 @@ bool DtlsTransport::stop() {
 	return true;
 }
 
-DtlsTransport::State DtlsTransport::state() const { return mState; }
-
 bool DtlsTransport::send(message_ptr message) {
-	if (!message || mState != State::Connected)
+	if (!message || state() != State::Connected)
 		return false;
 
 	PLOG_VERBOSE << "Send size=" << message->size();
@@ -467,11 +455,6 @@ void DtlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 }
 
-void DtlsTransport::changeState(State state) {
-	if (mState.exchange(state) != state)
-		mStateChangeCallback(state);
-}
-
 void DtlsTransport::runRecvLoop() {
 	const size_t maxMtu = 4096;
 	try {
@@ -490,7 +473,7 @@ void DtlsTransport::runRecvLoop() {
 				auto message = *mIncomingQueue.pop();
 				BIO_write(mInBio, message->data(), message->size());
 
-				if (mState == State::Connecting) {
+				if (state() == State::Connecting) {
 					// Continue the handshake
 					int ret = SSL_do_handshake(mSsl);
 					if (!check_openssl_ret(mSsl, ret, "Handshake failed"))
@@ -515,7 +498,7 @@ void DtlsTransport::runRecvLoop() {
 
 			// No more messages pending, retransmit and rearm timeout if connecting
 			std::optional<milliseconds> duration;
-			if (mState == State::Connecting) {
+			if (state() == State::Connecting) {
 				// Warning: This function breaks the usual return value convention
 				int ret = DTLSv1_handle_timeout(mSsl);
 				if (ret < 0) {
@@ -525,7 +508,7 @@ void DtlsTransport::runRecvLoop() {
 				}
 
 				struct timeval timeout = {};
-				if (mState == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
+				if (state() == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
 					duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
 					// Also handle handshake timeout manually because OpenSSL actually doesn't...
 					// OpenSSL backs off exponentially in base 2 starting from the recommended 1s
@@ -546,7 +529,7 @@ void DtlsTransport::runRecvLoop() {
 		PLOG_ERROR << "DTLS recv: " << e.what();
 	}
 
-	if (mState == State::Connected) {
+	if (state() == State::Connected) {
 		PLOG_INFO << "DTLS disconnected";
 		changeState(State::Disconnected);
 		recv(nullptr);

+ 0 - 8
src/dtlstransport.hpp

@@ -46,33 +46,25 @@ public:
 	static void Init();
 	static void Cleanup();
 
-	enum class State { Disconnected, Connecting, Connected, Failed };
-
 	using verifier_callback = std::function<bool(const std::string &fingerprint)>;
-	using state_callback = std::function<void(State state)>;
 
 	DtlsTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate,
 	              verifier_callback verifierCallback, state_callback stateChangeCallback);
 	~DtlsTransport();
 
-	State state() const;
-
 	bool stop() override;
 	bool send(message_ptr message) override; // false if dropped
 
 private:
 	void incoming(message_ptr message) override;
-	void changeState(State state);
 	void runRecvLoop();
 
 	const std::shared_ptr<Certificate> mCertificate;
 
 	Queue<message_ptr> mIncomingQueue;
-	std::atomic<State> mState;
 	std::thread mRecvThread;
 
 	verifier_callback mVerifierCallback;
-	state_callback mStateChangeCallback;
 
 #if USE_GNUTLS
 	gnutls_session_t mSession;

+ 42 - 24
src/icetransport.cpp

@@ -48,9 +48,8 @@ namespace rtc {
 IceTransport::IceTransport(const Configuration &config, Description::Role role,
                            candidate_callback candidateCallback, state_callback stateChangeCallback,
                            gathering_state_callback gatheringStateChangeCallback)
-    : mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New),
-      mCandidateCallback(std::move(candidateCallback)),
-      mStateChangeCallback(std::move(stateChangeCallback)),
+    : Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"),
+      mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)),
       mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
       mAgent(nullptr, nullptr) {
 
@@ -108,8 +107,6 @@ bool IceTransport::stop() {
 
 Description::Role IceTransport::role() const { return mRole; }
 
-IceTransport::State IceTransport::state() const { return mState; }
-
 Description IceTransport::getLocalDescription(Description::Type type) const {
 	char sdp[JUICE_MAX_SDP_STRING_LEN];
 	if (juice_get_local_description(mAgent.get(), sdp, JUICE_MAX_SDP_STRING_LEN) < 0)
@@ -161,7 +158,8 @@ std::optional<string> IceTransport::getRemoteAddress() const {
 }
 
 bool IceTransport::send(message_ptr message) {
-	if (!message || (mState != State::Connected && mState != State::Completed))
+	auto s = state();
+	if (!message || (s != State::Connected && s != State::Completed))
 		return false;
 
 	PLOG_VERBOSE << "Send size=" << message->size();
@@ -173,18 +171,29 @@ bool IceTransport::outgoing(message_ptr message) {
 	                  message->size()) >= 0;
 }
 
-void IceTransport::changeState(State state) {
-	if (mState.exchange(state) != state)
-		mStateChangeCallback(mState);
-}
-
 void IceTransport::changeGatheringState(GatheringState state) {
 	if (mGatheringState.exchange(state) != state)
 		mGatheringStateChangeCallback(mGatheringState);
 }
 
 void IceTransport::processStateChange(unsigned int state) {
-	changeState(static_cast<State>(state));
+	switch (state) {
+	case JUICE_STATE_DISCONNECTED:
+		changeState(State::Disconnected);
+		break;
+	case JUICE_STATE_CONNECTING:
+		changeState(State::Connecting);
+		break;
+	case JUICE_STATE_CONNECTED:
+		changeState(State::Connected);
+		break;
+	case JUICE_STATE_COMPLETED:
+		changeState(State::Completed);
+		break;
+	case JUICE_STATE_FAILED:
+		changeState(State::Failed);
+		break;
+	};
 }
 
 void IceTransport::processCandidate(const string &candidate) {
@@ -263,9 +272,8 @@ namespace rtc {
 IceTransport::IceTransport(const Configuration &config, Description::Role role,
                            candidate_callback candidateCallback, state_callback stateChangeCallback,
                            gathering_state_callback gatheringStateChangeCallback)
-    : mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New),
-      mCandidateCallback(std::move(candidateCallback)),
-      mStateChangeCallback(std::move(stateChangeCallback)),
+    : Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"),
+      mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)),
       mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
       mNiceAgent(nullptr, nullptr), mMainLoop(nullptr, nullptr) {
 
@@ -457,8 +465,6 @@ bool IceTransport::stop() {
 
 Description::Role IceTransport::role() const { return mRole; }
 
-IceTransport::State IceTransport::state() const { return mState; }
-
 Description IceTransport::getLocalDescription(Description::Type type) const {
 	// RFC 8445: The initiating agent that started the ICE processing MUST take the controlling
 	// role, and the other MUST take the controlled role.
@@ -529,7 +535,8 @@ std::optional<string> IceTransport::getRemoteAddress() const {
 }
 
 bool IceTransport::send(message_ptr message) {
-	if (!message || (mState != State::Connected && mState != State::Completed))
+	auto s = state();
+	if (!message || (s != State::Connected && s != State::Completed))
 		return false;
 
 	PLOG_VERBOSE << "Send size=" << message->size();
@@ -541,11 +548,6 @@ bool IceTransport::outgoing(message_ptr message) {
 	                       reinterpret_cast<const char *>(message->data())) >= 0;
 }
 
-void IceTransport::changeState(State state) {
-	if (mState.exchange(state) != state)
-		mStateChangeCallback(mState);
-}
-
 void IceTransport::changeGatheringState(GatheringState state) {
 	if (mGatheringState.exchange(state) != state)
 		mGatheringStateChangeCallback(mGatheringState);
@@ -576,7 +578,23 @@ void IceTransport::processStateChange(unsigned int state) {
 		mTimeoutId = 0;
 	}
 
-	changeState(static_cast<State>(state));
+	switch (state) {
+	case NICE_COMPONENT_STATE_DISCONNECTED:
+		changeState(State::Disconnected);
+		break;
+	case NICE_COMPONENT_STATE_CONNECTING:
+		changeState(State::Connecting);
+		break;
+	case NICE_COMPONENT_STATE_CONNECTED:
+		changeState(State::Connected);
+		break;
+	case NICE_COMPONENT_STATE_READY:
+		changeState(State::Completed);
+		break;
+	case NICE_COMPONENT_STATE_FAILED:
+		changeState(State::Failed);
+		break;
+	};
 }
 
 string IceTransport::AddressToString(const NiceAddress &addr) {

+ 4 - 24
src/icetransport.hpp

@@ -40,29 +40,9 @@ namespace rtc {
 
 class IceTransport : public Transport {
 public:
-#if USE_JUICE
-	enum class State : unsigned int{
-	    Disconnected = JUICE_STATE_DISCONNECTED,
-	    Connecting = JUICE_STATE_CONNECTING,
-	    Connected = JUICE_STATE_CONNECTED,
-	    Completed = JUICE_STATE_COMPLETED,
-	    Failed = JUICE_STATE_FAILED,
-	};
-#else
-	enum class State : unsigned int {
-		Disconnected = NICE_COMPONENT_STATE_DISCONNECTED,
-		Connecting = NICE_COMPONENT_STATE_CONNECTING,
-		Connected = NICE_COMPONENT_STATE_CONNECTED,
-		Completed = NICE_COMPONENT_STATE_READY,
-		Failed = NICE_COMPONENT_STATE_FAILED,
-	};
-
-	bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
-#endif
 	enum class GatheringState { New = 0, InProgress = 1, Complete = 2 };
 
 	using candidate_callback = std::function<void(const Candidate &candidate)>;
-	using state_callback = std::function<void(State state)>;
 	using gathering_state_callback = std::function<void(GatheringState state)>;
 
 	IceTransport(const Configuration &config, Description::Role role,
@@ -71,7 +51,6 @@ public:
 	~IceTransport();
 
 	Description::Role role() const;
-	State state() const;
 	GatheringState gatheringState() const;
 	Description getLocalDescription(Description::Type type) const;
 	void setRemoteDescription(const Description &description);
@@ -84,10 +63,13 @@ public:
 	bool stop() override;
 	bool send(message_ptr message) override; // false if dropped
 
+#if !USE_JUICE
+	bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
+#endif
+
 private:
 	bool outgoing(message_ptr message) override;
 
-	void changeState(State state);
 	void changeGatheringState(GatheringState state);
 
 	void processStateChange(unsigned int state);
@@ -98,11 +80,9 @@ private:
 	Description::Role mRole;
 	string mMid;
 	std::chrono::milliseconds mTrickleTimeout;
-	std::atomic<State> mState;
 	std::atomic<GatheringState> mGatheringState;
 
 	candidate_callback mCandidateCallback;
-	state_callback mStateChangeCallback;
 	gathering_state_callback mGatheringStateChangeCallback;
 
 #if USE_JUICE

+ 6 - 14
src/sctptransport.cpp

@@ -67,9 +67,8 @@ void SctpTransport::Cleanup() { usrsctp_finish(); }
 SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
                              message_callback recvCallback, amount_callback bufferedAmountCallback,
                              state_callback stateChangeCallback)
-    : Transport(lower), mPort(port), mSendQueue(0, message_size_func),
-      mBufferedAmountCallback(std::move(bufferedAmountCallback)),
-      mStateChangeCallback(std::move(stateChangeCallback)), mState(State::Disconnected) {
+    : Transport(lower, std::move(stateChangeCallback)), mPort(port),
+      mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
 	onRecv(recvCallback);
 
 	PLOG_DEBUG << "Initializing SCTP transport";
@@ -176,8 +175,6 @@ SctpTransport::~SctpTransport() {
 	usrsctp_deregister_address(this);
 }
 
-SctpTransport::State SctpTransport::state() const { return mState; }
-
 bool SctpTransport::stop() {
 	if (!Transport::stop())
 		return false;
@@ -265,7 +262,7 @@ void SctpTransport::incoming(message_ptr message) {
 	// to be sent on our side (i.e. the local INIT) before proceeding.
 	{
 		std::unique_lock lock(mWriteMutex);
-		mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || mState != State::Connected; });
+		mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || state() != State::Connected; });
 	}
 
 	if (!message) {
@@ -279,11 +276,6 @@ void SctpTransport::incoming(message_ptr message) {
 	usrsctp_conninput(this, message->data(), message->size(), 0);
 }
 
-void SctpTransport::changeState(State state) {
-	if (mState.exchange(state) != state)
-		mStateChangeCallback(state);
-}
-
 bool SctpTransport::trySendQueue() {
 	// Requires mSendMutex to be locked
 	while (auto next = mSendQueue.peek()) {
@@ -298,7 +290,7 @@ bool SctpTransport::trySendQueue() {
 
 bool SctpTransport::trySendMessage(message_ptr message) {
 	// Requires mSendMutex to be locked
-	if (!mSock || mState != State::Connected)
+	if (!mSock || state() != State::Connected)
 		return false;
 
 	uint32_t ppid;
@@ -410,7 +402,7 @@ void SctpTransport::sendReset(uint16_t streamId) {
 	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) {
 		std::unique_lock lock(mWriteMutex); // locking before setsockopt might deadlock usrsctp...
 		mWrittenCondition.wait_for(lock, 1000ms,
-		                           [&]() { return mWritten || mState != State::Connected; });
+		                           [&]() { return mWritten || state() != State::Connected; });
 	} else if (errno == EINVAL) {
 		PLOG_VERBOSE << "SCTP stream " << streamId << " already reset";
 	} else {
@@ -567,7 +559,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 			PLOG_INFO << "SCTP connected";
 			changeState(State::Connected);
 		} else {
-			if (mState == State::Connecting) {
+			if (state() == State::Connecting) {
 				PLOG_ERROR << "SCTP connection failed";
 				changeState(State::Failed);
 			} else {

+ 1 - 10
src/sctptransport.hpp

@@ -38,17 +38,12 @@ public:
 	static void Init();
 	static void Cleanup();
 
-	enum class State { Disconnected, Connecting, Connected, Failed };
-
 	using amount_callback = std::function<void(uint16_t streamId, size_t amount)>;
-	using state_callback = std::function<void(State state)>;
 
 	SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recvCallback,
 	              amount_callback bufferedAmountCallback, state_callback stateChangeCallback);
 	~SctpTransport();
 
-	State state() const;
-
 	bool stop() override;
 	bool send(message_ptr message) override; // false if buffered
 	void close(unsigned int stream);
@@ -76,7 +71,6 @@ private:
 	void connect();
 	void shutdown();
 	void incoming(message_ptr message) override;
-	void changeState(State state);
 
 	bool trySendQueue();
 	bool trySendMessage(message_ptr message);
@@ -105,14 +99,11 @@ private:
 	std::atomic<bool> mWritten = false; // written outside lock
 	bool mWrittenOnce = false;
 
-	state_callback mStateChangeCallback;
-	std::atomic<State> mState;
+	binary mPartialRecv, mPartialStringData, mPartialBinaryData;
 
 	// Stats
 	std::atomic<size_t> mBytesSent = 0, mBytesReceived = 0;
 
-	binary mPartialRecv, mPartialStringData, mPartialBinaryData;
-
 	static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
 	                        struct sctp_rcvinfo recv_info, int flags, void *user_data);
 	static int SendCallback(struct socket *sock, uint32_t sb_free);

+ 2 - 2
src/tcptransport.cpp

@@ -24,8 +24,8 @@ namespace rtc {
 
 using std::to_string;
 
-TcpTransport::TcpTransport(const string &hostname, const string &service)
-    : mHostname(hostname), mService(service) {
+TcpTransport::TcpTransport(const string &hostname, const string &service, state_callback callback)
+    : Transport(nullptr, std::move(callback)), mHostname(hostname), mService(service) {
 	mThread = std::thread(&TcpTransport::runLoop, this);
 }
 

+ 1 - 1
src/tcptransport.hpp

@@ -34,7 +34,7 @@ namespace rtc {
 
 class TcpTransport : public Transport {
 public:
-	TcpTransport(const string &hostname, const string &service);
+	TcpTransport(const string &hostname, const string &service, state_callback callback);
 	~TcpTransport();
 
 	bool stop() override;

+ 5 - 3
src/tlstransport.cpp

@@ -61,7 +61,8 @@ void TlsTransport::Cleanup() {
 	// Nothing to do
 }
 
-TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transport(lower) {
+TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
+    : Transport(lower, std::move(callback)) {
 
 	PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
 
@@ -82,6 +83,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transp
 		gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size());
 
 		mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
+		registerIncoming();
 
 	} catch (...) {
 
@@ -271,10 +273,10 @@ void TlsTransport::Cleanup() {
 	// Nothing to do
 }
 
-TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transport(lower) {
+TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
+    : Transport(lower, std::move(callback)) {
 
 	PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
-	GlobalInit();
 
 	if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
 		throw std::runtime_error("Failed to create SSL context");

+ 1 - 1
src/tlstransport.hpp

@@ -41,7 +41,7 @@ class TcpTransport;
 
 class TlsTransport : public Transport {
 public:
-	TlsTransport(std::shared_ptr<TcpTransport> lower, string host);
+	TlsTransport(std::shared_ptr<TcpTransport> lower, string host, state_callback callback);
 	~TlsTransport();
 
 	bool stop() override;

+ 15 - 1
src/transport.hpp

@@ -32,7 +32,13 @@ using namespace std::placeholders;
 
 class Transport {
 public:
-	Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {}
+	enum class State { Disconnected, Connecting, Connected, Completed, Failed };
+	using state_callback = std::function<void(State state)>;
+
+	Transport(std::shared_ptr<Transport> lower = nullptr, state_callback callback = nullptr)
+	    : mLower(std::move(lower)), mStateChangeCallback(std::move(callback)) {
+	}
+
 	virtual ~Transport() {
 		stop();
 		if (mLower)
@@ -49,11 +55,16 @@ public:
 	}
 
 	void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
+	State state() const { return mState; }
 
 	virtual bool send(message_ptr message) { return outgoing(message); }
 
 protected:
 	void recv(message_ptr message) { mRecvCallback(message); }
+	void changeState(State state) {
+		if (mState.exchange(state) != state)
+			mStateChangeCallback(state);
+	}
 
 	virtual void incoming(message_ptr message) { recv(message); }
 	virtual bool outgoing(message_ptr message) {
@@ -65,7 +76,10 @@ protected:
 
 private:
 	std::shared_ptr<Transport> mLower;
+	synchronized_callback<State> mStateChangeCallback;
 	synchronized_callback<message_ptr> mRecvCallback;
+
+	std::atomic<State> mState = State::Disconnected;
 	std::atomic<bool> mShutdown = false;
 };
 

+ 212 - 74
src/websocket.cpp

@@ -1,100 +1,238 @@
-/*************************************************************************
- *   Copyright (C) 2017-2018 by Paul-Louis Ageneau                       *
- *   paul-louis (at) ageneau (dot) org                                   *
- *                                                                       *
- *   This file is part of Plateform.                                     *
- *                                                                       *
- *   Plateform is free software: you can redistribute it and/or modify   *
- *   it under the terms of the GNU Affero General Public License as      *
- *   published by the Free Software Foundation, either version 3 of      *
- *   the License, or (at your option) any later version.                 *
- *                                                                       *
- *   Plateform is distributed in the hope that it will be useful, but    *
- *   WITHOUT ANY WARRANTY; without even the implied warranty of          *
- *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the        *
- *   GNU Affero General Public License for more details.                 *
- *                                                                       *
- *   You should have received a copy of the GNU Affero General Public    *
- *   License along with Plateform.                                       *
- *   If not, see <http://www.gnu.org/licenses/>.                         *
- *************************************************************************/
-
-#include "net/websocket.hpp"
-
-#include <exception>
-#include <iostream>
-
-const size_t DEFAULT_MAX_PAYLOAD_SIZE = 16384; // 16 KB
-
-namespace net {
-
-WebSocket::WebSocket(void) : mMaxPayloadSize(DEFAULT_MAX_PAYLOAD_SIZE) {}
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#if ENABLE_WEBSOCKET
+
+#include "include.hpp"
+#include "websocket.hpp"
+
+#include "tcptransport.hpp"
+#include "tlstransport.hpp"
+#include "wstransport.hpp"
+
+#include <regex>
+
+namespace rtc {
+
+WebSocket::WebSocket() {}
 
 WebSocket::WebSocket(const string &url) : WebSocket() { open(url); }
 
-WebSocket::~WebSocket(void) {}
+WebSocket::~WebSocket() { close(); }
 
 void WebSocket::open(const string &url) {
-	close();
+	static const char *rs = R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)";
+	static std::regex regex(rs, std::regex::extended);
+
+	std::smatch match;
+	if (!std::regex_match(url, match, regex))
+		throw std::invalid_argument("Malformed WebSocket URL: " + url);
+
+	mScheme = match[2];
+	if (mScheme != "ws" && mScheme != "wss")
+		throw std::invalid_argument("Invalid WebSocket scheme: " + mScheme);
+
+	mHost = match[4];
+	if (auto pos = mHost.find(':'); pos != string::npos) {
+		mHostname = mHost.substr(0, pos);
+		mService = mHost.substr(pos + 1);
+	} else {
+		mHostname = mHost;
+		mService = mScheme == "ws" ? "80" : "443";
+	}
+
+	mPath = match[5];
+	if (string query = match[7]; !query.empty())
+		mPath += "?" + query;
 
-	mUrl = url;
-	mThread = std::thread(&WebSocket::run, this);
+	initTcpTransport();
 }
 
-void WebSocket::close(void) {
-	mWebSocket.close();
-	if (mThread.joinable())
-		mThread.join();
-	mConnected = false;
+void WebSocket::close() {
+	resetCallbacks();
+	closeTransports();
 }
 
-bool WebSocket::isOpen(void) const { return mConnected; }
+void WebSocket::remoteClose() {
+	mIsOpen = false;
+	if (!mIsClosed.exchange(true))
+		triggerClosed();
+}
 
-bool WebSocket::isClosed(void) const { return !mThread.joinable(); }
+bool WebSocket::send(const std::variant<binary, string> &data) {
+	return std::visit(
+	    [&](const auto &d) {
+		    using T = std::decay_t<decltype(d)>;
+		    constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
+		    auto *b = reinterpret_cast<const byte *>(d.data());
+		    return outgoing(std::make_shared<Message>(b, b + d.size(), type));
+	    },
+	    data);
+}
 
-void WebSocket::setMaxPayloadSize(size_t size) { mMaxPayloadSize = size; }
+bool WebSocket::isOpen() const { return mIsOpen; }
 
-bool WebSocket::send(const std::variant<binary, string> &data) {
-	if (!std::holds_alternative<binary>(data))
-		throw std::runtime_error("WebSocket string messages are not supported");
+bool WebSocket::isClosed() const { return mIsClosed; }
 
-	mWebSocket.write(std::get<binary>(data));
-	return true;
-}
+size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
 
-std::optional<std::variant<binary, string>> WebSocket::receive() {
-	if (!mQueue.empty())
-		return mQueue.pop();
-	else
-		return std::nullopt;
-}
+std::optional<std::variant<binary, string>> WebSocket::receive() { return nullopt; }
 
-void WebSocket::run(void) {
-	if (mUrl.empty())
-		return;
+size_t WebSocket::availableAmount() const { return 0; }
 
-	try {
-		mWebSocket.connect(mUrl);
+bool WebSocket::outgoing(mutable_message_ptr message) {
+	if (mIsClosed || !mWsTransport)
+		throw std::runtime_error("WebSocket is closed");
 
-		mConnected = true;
-		triggerOpen();
+	if (message->size() > maxMessageSize())
+		throw std::runtime_error("Message size exceeds limit");
 
-		while (true) {
-			binary payload;
-			if (!mWebSocket.read(payload, mMaxPayloadSize))
+	return mWsTransport->send(message);
+}
+
+std::shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
+	using State = TcpTransport::State;
+	try {
+		std::lock_guard lock(mInitMutex);
+		if (auto transport = std::atomic_load(&mTcpTransport))
+			return transport;
+
+		auto transport = std::make_shared<TcpTransport>(mHostname, mService, [this](State state) {
+			switch (state) {
+			case State::Connected:
+				if (mScheme == "ws")
+					initWsTransport();
+				else
+					initTlsTransport();
+				break;
+			case State::Failed:
+				// TODO
+				break;
+			case State::Disconnected:
+				// TODO
 				break;
-			mQueue.push(std::move(payload));
-			triggerAvailable(mQueue.size());
-		}
+			default:
+				// Ignore
+				break;
+			}
+		});
+		std::atomic_store(&mTcpTransport, transport);
+		return transport;
 	} catch (const std::exception &e) {
-		triggerError(e.what());
+		PLOG_ERROR << e.what();
+		// TODO
+		throw std::runtime_error("TCP transport initialization failed");
 	}
+}
 
-	mWebSocket.close();
+std::shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
+	using State = TlsTransport::State;
+	try {
+		std::lock_guard lock(mInitMutex);
+		if (auto transport = std::atomic_load(&mTlsTransport))
+			return transport;
+
+		auto lower = std::atomic_load(&mTcpTransport);
+		auto transport = std::make_shared<TlsTransport>(lower, mHost, [this](State state) {
+			switch (state) {
+			case State::Connected:
+				initWsTransport();
+				break;
+			case State::Failed:
+				// TODO
+				break;
+			case State::Disconnected:
+				// TODO
+				break;
+			default:
+				// Ignore
+				break;
+			}
+		});
+		std::atomic_store(&mTlsTransport, transport);
+		return transport;
+	} catch (const std::exception &e) {
+		PLOG_ERROR << e.what();
+		// TODO
+		throw std::runtime_error("TLS transport initialization failed");
+	}
+}
 
-	if (mConnected)
-		triggerClosed();
-	mConnected = false;
+std::shared_ptr<WsTransport> WebSocket::initWsTransport() {
+	using State = WsTransport::State;
+	try {
+		std::lock_guard lock(mInitMutex);
+		if (auto transport = std::atomic_load(&mWsTransport))
+			return transport;
+
+		std::shared_ptr<Transport> lower = std::atomic_load(&mTlsTransport);
+		if (!lower)
+			lower = std::atomic_load(&mTcpTransport);
+		auto transport = std::make_shared<WsTransport>(lower, mHost, mPath, [this](State state) {
+			switch (state) {
+			case State::Connected:
+				triggerOpen();
+				break;
+			case State::Failed:
+				// TODO
+				break;
+			case State::Disconnected:
+				// TODO
+				break;
+			default:
+				// Ignore
+				break;
+			}
+		});
+		std::atomic_store(&mWsTransport, transport);
+		return transport;
+	} catch (const std::exception &e) {
+		PLOG_ERROR << e.what();
+		// TODO
+		throw std::runtime_error("WebSocket transport initialization failed");
+	}
 }
 
-} // namespace net
+void closeTransports() {
+	mIsOpen = false;
+	mIsClosed = true;
+
+	// Pass the references to a thread, allowing to terminate a transport from its own thread
+	auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
+	auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr));
+	auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
+	if (ws || dtls || tcp) {
+		std::thread t([ws, dtls, tcp]() mutable {
+			if (ws)
+				ws->stop();
+			if (dtls)
+				dtls->stop();
+			if (tcp)
+				tcp->stop();
+
+			ws.reset();
+			dtls.reset();
+			tcp.reset();
+		});
+		t.detach();
+	}
+}
+
+} // namespace rtc
+
+#endif

+ 5 - 4
src/wstransport.cpp

@@ -53,11 +53,12 @@ using std::to_string;
 using random_bytes_engine =
     std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned char>;
 
-WsTransport::WsTransport(std::shared_ptr<TcpTransport> lower, string host, string path)
-    : Transport(lower), mHost(std::move(host)), mPath(std::move(path)) {}
+WsTransport::WsTransport(std::shared_ptr<Transport> lower, string host, string path,
+                         state_callback callback)
+    : Transport(lower, std::move(callback)), mHost(std::move(host)), mPath(std::move(path)) {
 
-WsTransport::WsTransport(std::shared_ptr<TlsTransport> lower, string host, string path)
-    : Transport(lower), mHost(std::move(host)), mPath(std::move(path)) {}
+	registerIncoming();
+}
 
 WsTransport::~WsTransport() {}
 

+ 2 - 2
src/wstransport.hpp

@@ -31,8 +31,8 @@ class TlsTransport;
 
 class WsTransport : public Transport {
 public:
-	WsTransport(std::shared_ptr<TcpTransport> lower, string host, string path);
-	WsTransport(std::shared_ptr<TlsTransport> lower, string host, string path);
+	WsTransport(std::shared_ptr<Transport> lower, string host, string path,
+	            state_callback callback);
 	~WsTransport();
 
 	void stop() override;