Paul-Louis Ageneau 5 years ago
parent
commit
0d8aedfa01

+ 0 - 1
include/rtc/datachannel.hpp

@@ -82,7 +82,6 @@ private:
 	std::atomic<bool> mIsClosed = false;
 
 	Queue<message_ptr> mRecvQueue;
-	std::atomic<size_t> mRecvAmount = 0;
 
 	friend class PeerConnection;
 };

+ 13 - 5
include/rtc/websocket.hpp

@@ -40,10 +40,19 @@ class WsTransport;
 
 class WebSocket final : public Channel, public std::enable_shared_from_this<WebSocket> {
 public:
+	enum class State : int {
+		Connecting = 0,
+		Open = 1,
+		Closing = 2,
+		Closed = 3,
+	};
+
 	WebSocket();
 	WebSocket(const string &url);
 	~WebSocket();
 
+	State readyState() const;
+
 	void open(const string &url);
 	void close() override;
 	bool send(const std::variant<binary, string> &data) override;
@@ -57,8 +66,10 @@ public:
 	size_t availableAmount() const override; // total size available to receive
 
 private:
+	bool changeState(State state);
 	void remoteClose();
 	bool outgoing(mutable_message_ptr message);
+	void incoming(message_ptr message);
 
 	std::shared_ptr<TcpTransport> initTcpTransport();
 	std::shared_ptr<TlsTransport> initTlsTransport();
@@ -73,15 +84,12 @@ private:
 	std::recursive_mutex mInitMutex;
 
 	string mScheme, mHost, mHostname, mService, mPath;
-
-	std::atomic<bool> mIsOpen = false;
-	std::atomic<bool> mIsClosed = false;
+	std::atomic<State> mState = State::Closed;
 
 	Queue<message_ptr> mRecvQueue;
-	std::atomic<size_t> mRecvAmount = 0;
 };
 } // namespace rtc
 
 #endif
 
-#endif // NET_WEBSOCKET_H
+#endif // RTC_WEBSOCKET_H

+ 3 - 0
src/datachannel.cpp

@@ -214,6 +214,9 @@ bool DataChannel::outgoing(mutable_message_ptr message) {
 }
 
 void DataChannel::incoming(message_ptr message) {
+	if (!message)
+		return;
+
 	switch (message->type) {
 	case Message::Control: {
 		auto raw = reinterpret_cast<const uint8_t *>(message->data());

+ 1 - 1
src/init.cpp

@@ -76,7 +76,7 @@ Init::Init() {
 	SctpTransport::Init();
 	DtlsTransport::Init();
 #if RTC_ENABLE_WEBSOCKET
-	TlsTransport::Cleanup();
+	TlsTransport::Init();
 #endif
 }
 

+ 0 - 1
src/peerconnection.cpp

@@ -23,7 +23,6 @@
 #include "include.hpp"
 #include "sctptransport.hpp"
 
-#include <iostream>
 #include <thread>
 
 namespace rtc {

+ 14 - 2
src/tcptransport.cpp

@@ -26,6 +26,8 @@ using std::to_string;
 
 TcpTransport::TcpTransport(const string &hostname, const string &service, state_callback callback)
     : Transport(nullptr, std::move(callback)), mHostname(hostname), mService(service) {
+
+	PLOG_DEBUG << "Initializing TCP transport";
 	mThread = std::thread(&TcpTransport::runLoop, this);
 }
 
@@ -114,7 +116,10 @@ void TcpTransport::connect(const sockaddr *addr, socklen_t addrlen) {
 			throw std::runtime_error("Connection failed");
 
 	} catch (...) {
-		close();
+		if (mSock != INVALID_SOCKET) {
+			::closesocket(mSock);
+			mSock = INVALID_SOCKET;
+		}
 		throw;
 	}
 }
@@ -124,6 +129,7 @@ void TcpTransport::close() {
 		::closesocket(mSock);
 		mSock = INVALID_SOCKET;
 	}
+	changeState(State::Disconnected);
 }
 
 bool TcpTransport::trySendQueue() {
@@ -160,15 +166,20 @@ bool TcpTransport::trySendMessage(message_ptr &message) {
 void TcpTransport::runLoop() {
 	const size_t bufferSize = 4096;
 
+	changeState(State::Connecting);
+
 	// Connect
 	try {
 		connect(mHostname, mService);
 
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "TCP connect: " << e.what();
+		changeState(State::Failed);
 		return;
 	}
 
+	changeState(State::Connected);
+
 	// Receive loop
 	try {
 		while (true) {
@@ -188,7 +199,7 @@ void TcpTransport::runLoop() {
 					break; // clean close
 
 				auto *b = reinterpret_cast<byte *>(buffer);
-				incoming(make_message(b, b + ret));
+				incoming(make_message(b, b + len));
 			}
 
 			if (FD_ISSET(mSock, &writefds))
@@ -199,6 +210,7 @@ void TcpTransport::runLoop() {
 	}
 
 	PLOG_INFO << "TCP disconnected";
+	changeState(State::Disconnected);
 	recv(nullptr);
 }
 

+ 15 - 6
src/tlstransport.cpp

@@ -97,7 +97,7 @@ TlsTransport::~TlsTransport() {
 	gnutls_deinit(mSession);
 }
 
-bool DtlsTransport::stop() {
+bool TlsTransport::stop() {
 	if (!Transport::stop())
 		return false;
 
@@ -129,6 +129,8 @@ void TlsTransport::incoming(message_ptr message) {
 void TlsTransport::runRecvLoop() {
 	const size_t bufferSize = 4096;
 
+	changeState(State::Connecting);
+
 	// Handshake loop
 	try {
 		int ret;
@@ -139,9 +141,12 @@ void TlsTransport::runRecvLoop() {
 
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "TLS handshake: " << e.what();
+		changeState(State::Failed);
 		return;
 	}
 
+	changeState(State::Connected);
+
 	// Receive loop
 	try {
 		while (true) {
@@ -167,7 +172,6 @@ void TlsTransport::runRecvLoop() {
 				recv(make_message(b, b + ret));
 			}
 		}
-
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "TLS recv: " << e.what();
 	}
@@ -175,6 +179,7 @@ void TlsTransport::runRecvLoop() {
 	gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
 
 	PLOG_INFO << "TLS disconnected";
+	changeState(State::Disconnected);
 	recv(nullptr);
 }
 
@@ -368,13 +373,17 @@ void TlsTransport::runRecvLoop() {
 			outgoing(make_message(buffer, buffer + len));
 
 		while (auto next = mIncomingQueue.pop()) {
-			auto message = *next;
+			message_ptr message = *next;
+			message_ptr decrypted;
+
 			BIO_write(mInBio, message->data(), message->size());
+
 			int ret = SSL_read(mSsl, buffer, bufferSize);
 			if (!check_openssl_ret(mSsl, ret))
 				break;
 
-			auto received = ret > 0 ? make_message(buffer, buffer + ret) : nullptr;
+			if (ret > 0)
+				decrypted = make_message(buffer, buffer + ret);
 
 			while (int len = BIO_read(mOutBio, buffer, bufferSize))
 				outgoing(make_message(buffer, buffer + len));
@@ -382,8 +391,8 @@ void TlsTransport::runRecvLoop() {
 			if (!initFinished && SSL_is_init_finished(mSsl))
 				initFinished = true;
 
-			if (received)
-				recv(received);
+			if (decrypted)
+				recv(decrypted);
 		}
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "TLS recv: " << e.what();

+ 3 - 1
src/tlstransport.hpp

@@ -41,6 +41,9 @@ class TcpTransport;
 
 class TlsTransport : public Transport {
 public:
+	static void Init();
+	static void Cleanup();
+
 	TlsTransport(std::shared_ptr<TcpTransport> lower, string host, state_callback callback);
 	~TlsTransport();
 
@@ -68,7 +71,6 @@ protected:
 
 	static int TransportExIndex;
 
-	static void GlobalInit();
 	static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx);
 	static void InfoCallback(const SSL *ssl, int where, int ret);
 #endif

+ 105 - 44
src/websocket.cpp

@@ -27,15 +27,24 @@
 
 #include <regex>
 
+#ifdef _WIN32
+#include <winsock2.h>
+#endif
+
 namespace rtc {
 
 WebSocket::WebSocket() {}
 
 WebSocket::WebSocket(const string &url) : WebSocket() { open(url); }
 
-WebSocket::~WebSocket() { close(); }
+WebSocket::~WebSocket() { remoteClose(); }
+
+WebSocket::State WebSocket::readyState() const { return mState; }
 
 void WebSocket::open(const string &url) {
+	if (mState != State::Closed)
+		throw std::runtime_error("WebSocket must be closed before opening");
+
 	static const char *rs = R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)";
 	static std::regex regex(rs, std::regex::extended);
 
@@ -60,18 +69,24 @@ void WebSocket::open(const string &url) {
 	if (string query = match[7]; !query.empty())
 		mPath += "?" + query;
 
+	changeState(State::Connecting);
 	initTcpTransport();
 }
 
 void WebSocket::close() {
-	resetCallbacks();
-	closeTransports();
+	auto state = mState.load();
+	if (state == State::Connecting || state == State::Open) {
+		changeState(State::Closing);
+		if (auto transport = std::atomic_load(&mWsTransport))
+			transport->close();
+		else
+			changeState(State::Closed);
+	}
 }
 
 void WebSocket::remoteClose() {
-	mIsOpen = false;
-	if (!mIsClosed.exchange(true))
-		triggerClosed();
+	close();
+	closeTransports();
 }
 
 bool WebSocket::send(const std::variant<binary, string> &data) {
@@ -85,19 +100,36 @@ bool WebSocket::send(const std::variant<binary, string> &data) {
 	    data);
 }
 
-bool WebSocket::isOpen() const { return mIsOpen; }
+bool WebSocket::isOpen() const { return mState == State::Open; }
 
-bool WebSocket::isClosed() const { return mIsClosed; }
+bool WebSocket::isClosed() const { return mState == State::Closed; }
 
 size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
 
-std::optional<std::variant<binary, string>> WebSocket::receive() { return nullopt; }
+std::optional<std::variant<binary, string>> WebSocket::receive() {
+	while (!mRecvQueue.empty()) {
+		auto message = *mRecvQueue.pop();
+		switch (message->type) {
+		case Message::String:
+			return std::make_optional(
+			    string(reinterpret_cast<const char *>(message->data()), message->size()));
+		case Message::Binary:
+			return std::make_optional(std::move(*message));
+		default:
+			// Ignore
+			break;
+		}
+	}
+	return nullopt;
+}
+
+size_t WebSocket::availableAmount() const { return mRecvQueue.amount(); }
 
-size_t WebSocket::availableAmount() const { return 0; }
+bool WebSocket::changeState(State state) { return mState.exchange(state) != state; }
 
 bool WebSocket::outgoing(mutable_message_ptr message) {
-	if (mIsClosed || !mWsTransport)
-		throw std::runtime_error("WebSocket is closed");
+	if (mState != State::Open || !mWsTransport)
+		throw std::runtime_error("WebSocket is not open");
 
 	if (message->size() > maxMessageSize())
 		throw std::runtime_error("Message size exceeds limit");
@@ -105,6 +137,13 @@ bool WebSocket::outgoing(mutable_message_ptr message) {
 	return mWsTransport->send(message);
 }
 
+void WebSocket::incoming(message_ptr message) {
+	if (message->type == Message::String || message->type == Message::Binary) {
+		mRecvQueue.push(message);
+		triggerAvailable(mRecvQueue.size());
+	}
+}
+
 std::shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
 	using State = TcpTransport::State;
 	try {
@@ -121,10 +160,11 @@ std::shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
 					initTlsTransport();
 				break;
 			case State::Failed:
-				// TODO
+				triggerError("TCP connection failed");
+				remoteClose();
 				break;
 			case State::Disconnected:
-				// TODO
+				remoteClose();
 				break;
 			default:
 				// Ignore
@@ -132,10 +172,15 @@ std::shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
 			}
 		});
 		std::atomic_store(&mTcpTransport, transport);
+		if (mState == WebSocket::State::Closed) {
+			mTcpTransport.reset();
+			transport->stop();
+			throw std::runtime_error("Connection is closed");
+		}
 		return transport;
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
-		// TODO
+		remoteClose();
 		throw std::runtime_error("TCP transport initialization failed");
 	}
 }
@@ -154,10 +199,11 @@ std::shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 				initWsTransport();
 				break;
 			case State::Failed:
-				// TODO
+				triggerError("TCP connection failed");
+				remoteClose();
 				break;
 			case State::Disconnected:
-				// TODO
+				remoteClose();
 				break;
 			default:
 				// Ignore
@@ -165,10 +211,15 @@ std::shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 			}
 		});
 		std::atomic_store(&mTlsTransport, transport);
+		if (mState == WebSocket::State::Closed) {
+			mTlsTransport.reset();
+			transport->stop();
+			throw std::runtime_error("Connection is closed");
+		}
 		return transport;
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
-		// TODO
+		remoteClose();
 		throw std::runtime_error("TLS transport initialization failed");
 	}
 }
@@ -183,50 +234,60 @@ std::shared_ptr<WsTransport> WebSocket::initWsTransport() {
 		std::shared_ptr<Transport> lower = std::atomic_load(&mTlsTransport);
 		if (!lower)
 			lower = std::atomic_load(&mTcpTransport);
-		auto transport = std::make_shared<WsTransport>(lower, mHost, mPath, [this](State state) {
-			switch (state) {
-			case State::Connected:
-				triggerOpen();
-				break;
-			case State::Failed:
-				// TODO
-				break;
-			case State::Disconnected:
-				// TODO
-				break;
-			default:
-				// Ignore
-				break;
-			}
-		});
+		auto transport = std::make_shared<WsTransport>(
+		    lower, mHost, mPath, std::bind(&WebSocket::incoming, this, _1), [this](State state) {
+			    switch (state) {
+			    case State::Connected:
+				    if (mState == WebSocket::State::Connecting) {
+					    PLOG_DEBUG << "WebSocket open";
+					    changeState(WebSocket::State::Open);
+					    triggerOpen();
+				    }
+				    break;
+			    case State::Failed:
+				    triggerError("WebSocket connection failed");
+				    remoteClose();
+				    break;
+			    case State::Disconnected:
+				    remoteClose();
+				    break;
+			    default:
+				    // Ignore
+				    break;
+			    }
+		    });
 		std::atomic_store(&mWsTransport, transport);
+		if (mState == WebSocket::State::Closed) {
+			mWsTransport.reset();
+			transport->stop();
+			throw std::runtime_error("Connection is closed");
+		}
 		return transport;
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
-		// TODO
+		remoteClose();
 		throw std::runtime_error("WebSocket transport initialization failed");
 	}
 }
 
-void closeTransports() {
-	mIsOpen = false;
-	mIsClosed = true;
+void WebSocket::closeTransports() {
+	changeState(State::Closed);
 
 	// Pass the references to a thread, allowing to terminate a transport from its own thread
 	auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
-	auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr));
+	auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr));
 	auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
-	if (ws || dtls || tcp) {
-		std::thread t([ws, dtls, tcp]() mutable {
+	if (ws || tls || tcp) {
+		std::thread t([ws, tls, tcp]() mutable {
 			if (ws)
 				ws->stop();
-			if (dtls)
-				dtls->stop();
+			if (tls)
+				tls->stop();
 			if (tcp)
 				tcp->stop();
 
 			ws.reset();
-			dtls.reset();
+			tls.reset();
 			tcp.reset();
 		});
 		t.detach();

+ 50 - 24
src/wstransport.cpp

@@ -54,15 +54,25 @@ using random_bytes_engine =
     std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned char>;
 
 WsTransport::WsTransport(std::shared_ptr<Transport> lower, string host, string path,
-                         state_callback callback)
-    : Transport(lower, std::move(callback)), mHost(std::move(host)), mPath(std::move(path)) {
+                         message_callback recvCallback, state_callback stateCallback)
+    : Transport(lower, std::move(stateCallback)), mHost(std::move(host)), mPath(std::move(path)) {
+	onRecv(recvCallback);
+
+	PLOG_DEBUG << "Initializing WebSocket transport";
 
 	registerIncoming();
+	sendHttpRequest();
 }
 
-WsTransport::~WsTransport() {}
+WsTransport::~WsTransport() { stop(); }
+
+bool WsTransport::stop() {
+	if (!Transport::stop())
+		return false;
 
-void WsTransport::stop() {}
+	close();
+	return true;
+}
 
 bool WsTransport::send(message_ptr message) {
 	if (!message)
@@ -73,7 +83,7 @@ bool WsTransport::send(message_ptr message) {
 }
 
 bool WsTransport::send(mutable_message_ptr message) {
-	if (!message)
+	if (!message || state() != State::Connected)
 		return false;
 
 	return sendFrame({message->type == Message::String ? TEXT_FRAME : BINARY_FRAME, message->data(),
@@ -81,32 +91,39 @@ bool WsTransport::send(mutable_message_ptr message) {
 }
 
 void WsTransport::incoming(message_ptr message) {
-	mBuffer.insert(mBuffer.end(), message->begin(), message->end());
-
-	if (!mHandshakeDone) {
-		if (size_t len = readHttpResponse(mBuffer.data(), mBuffer.size())) {
-			mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
-			mHandshakeDone = true;
+	try {
+		mBuffer.insert(mBuffer.end(), message->begin(), message->end());
+
+		if (state() == State::Connecting) {
+			if (size_t len = readHttpResponse(mBuffer.data(), mBuffer.size())) {
+				mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+				changeState(State::Connected);
+			}
 		}
-	}
 
-	if (mHandshakeDone) {
-		Frame frame = {};
-		while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
-			mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
-			recvFrame(frame);
+		if (state() == State::Connected) {
+			Frame frame = {};
+			while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
+				mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+				recvFrame(frame);
+			}
 		}
+	} catch (const std::exception &e) {
+		PLOG_ERROR << e.what();
+		changeState(State::Failed);
 	}
 }
 
-void WsTransport::connect() { sendHttpRequest(); }
-
 void WsTransport::close() {
-	if (mHandshakeDone)
+	if (state() == State::Connected) {
 		sendFrame({CLOSE, NULL, 0, true, true});
+		changeState(State::Completed);
+	}
 }
 
 bool WsTransport::sendHttpRequest() {
+	changeState(State::Connecting);
+
 	auto seed = system_clock::now().time_since_epoch().count();
 	random_bytes_engine generator(seed);
 
@@ -133,19 +150,25 @@ bool WsTransport::sendHttpRequest() {
 }
 
 size_t WsTransport::readHttpResponse(const byte *buffer, size_t size) {
-
 	std::list<string> lines;
 	auto begin = reinterpret_cast<const char *>(buffer);
 	auto end = begin + size;
 	auto cur = begin;
-	while ((cur = std::find(cur, end, '\n')) != end) {
-		string line(begin, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++);
+	while (true) {
+		auto last = cur;
+		cur = std::find(cur, end, '\n');
+		if (cur == end)
+			return 0;
+		string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++);
 		if (line.empty())
 			break;
 		lines.emplace_back(std::move(line));
 	}
 	size_t length = cur - begin;
 
+	if (lines.empty())
+		throw std::runtime_error("Invalid HTTP response for WebSocket");
+
 	string status = std::move(lines.front());
 	lines.pop_front();
 
@@ -153,6 +176,7 @@ size_t WsTransport::readHttpResponse(const byte *buffer, size_t size) {
 	string protocol;
 	unsigned int code = 0;
 	ss >> protocol >> code;
+	PLOG_DEBUG << "WebSocket response code: " << code;
 	if (code != 101)
 		throw std::runtime_error("Unexpected response code for WebSocket: " + to_string(code));
 
@@ -174,7 +198,8 @@ size_t WsTransport::readHttpResponse(const byte *buffer, size_t size) {
 		throw std::runtime_error("WebSocket update header missing or mismatching");
 
 	h = headers.find("sec-websocket-accept");
-	throw std::runtime_error("WebSocket accept header missing");
+	if (h == headers.end())
+		throw std::runtime_error("WebSocket accept header missing");
 
 	// TODO: Verify Sec-WebSocket-Accept
 
@@ -284,6 +309,7 @@ void WsTransport::recvFrame(const Frame &frame) {
 	}
 	case CLOSE: {
 		close();
+		changeState(State::Disconnected);
 		break;
 	}
 	default: {

+ 5 - 7
src/wstransport.hpp

@@ -32,15 +32,17 @@ class TlsTransport;
 class WsTransport : public Transport {
 public:
 	WsTransport(std::shared_ptr<Transport> lower, string host, string path,
-	            state_callback callback);
+	            message_callback recvCallback, state_callback stateCallback);
 	~WsTransport();
 
-	void stop() override;
+	bool stop() override;
 	bool send(message_ptr message) override;
 	bool send(mutable_message_ptr message);
 
 	void incoming(message_ptr message) override;
 
+	void close();
+
 private:
 	enum Opcode : uint8_t {
 		CONTINUATION = 0,
@@ -59,9 +61,6 @@ private:
 		bool mask = true;
 	};
 
-	void connect();
-	void close();
-
 	bool sendHttpRequest();
 	size_t readHttpResponse(const byte *buffer, size_t size);
 
@@ -72,11 +71,10 @@ private:
 	const string mHost;
 	const string mPath;
 
-	bool mHandshakeDone = false;
 	binary mBuffer;
 	binary mPartial;
 	Opcode mPartialOpcode;
-	};
+};
 
 } // namespace rtc