Browse Source

Initial implementation of a proxy for the websocket

eric.gressman 2 years ago
parent
commit
a609e3a9df

+ 2 - 0
CMakeLists.txt

@@ -121,6 +121,7 @@ set(LIBDATACHANNEL_IMPL_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/sha.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/sha.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/pollinterrupter.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/pollinterrupter.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/pollservice.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/pollservice.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcpproxytransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcpserver.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcpserver.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.cpp
@@ -153,6 +154,7 @@ set(LIBDATACHANNEL_IMPL_HEADERS
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/sha.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/sha.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/pollinterrupter.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/pollinterrupter.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/pollservice.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/pollservice.hpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcpproxytransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcpserver.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcpserver.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.hpp

+ 1 - 1
include/rtc/websocket.hpp

@@ -34,7 +34,7 @@ public:
 
 
 	struct Configuration {
 	struct Configuration {
 		bool disableTlsVerification = false; // if true, don't verify the TLS certificate
 		bool disableTlsVerification = false; // if true, don't verify the TLS certificate
-		optional<ProxyServer> proxyServer;   // unsupported for now
+		optional<ProxyServer> proxyServer;
 		std::vector<string> protocols;
 		std::vector<string> protocols;
 		optional<std::chrono::milliseconds> pingInterval; // zero to disable
 		optional<std::chrono::milliseconds> pingInterval; // zero to disable
 		optional<int> maxOutstandingPings;
 		optional<int> maxOutstandingPings;

+ 172 - 0
src/impl/tcpproxytransport.cpp

@@ -0,0 +1,172 @@
+/**
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/.
+ */
+
+#include "tcpproxytransport.hpp"
+#include "tcptransport.hpp"
+#include "utils.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+// #include <algorithm>
+// #include <chrono>
+// #include <iostream>
+// #include <numeric>
+// #include <random>
+// #include <regex>
+// #include <sstream>
+
+// #ifdef _WIN32
+// #include <winsock2.h>
+// #else
+// #include <arpa/inet.h>
+// #endif
+
+#ifndef htonll
+#define htonll(x)                                                                                  \
+	((uint64_t)(((uint64_t)htonl((uint32_t)(x))) << 32) | (uint64_t)htonl((uint32_t)((x) >> 32)))
+#endif
+#ifndef ntohll
+#define ntohll(x) htonll(x)
+#endif
+
+namespace rtc::impl {
+
+using std::to_integer;
+using std::to_string;
+using std::chrono::system_clock;
+
+TcpProxyTransport::TcpProxyTransport(shared_ptr<TcpTransport> lower, std::string hostname, std::string service, state_callback stateCallback)
+    : Transport(lower, std::move(stateCallback))
+	, mHostname( std::move(hostname) )
+	, mService( std::move(service) )
+{
+	PLOG_DEBUG << "Initializing TCP Proxy transport";
+}
+
+TcpProxyTransport::~TcpProxyTransport() { unregisterIncoming(); }
+
+void TcpProxyTransport::start() {
+	registerIncoming();
+
+	changeState(State::Connecting);
+	sendHttpRequest();
+}
+
+void TcpProxyTransport::stop() {
+	unregisterIncoming();
+}
+
+bool TcpProxyTransport::send(message_ptr message) {
+	std::lock_guard lock(mSendMutex);
+
+	if (state() != State::Connected)
+		throw std::runtime_error("Tcp proxy connection is not open");
+
+	PLOG_VERBOSE << "Send size=" << message->size();
+	return outgoing(message);
+}
+
+void TcpProxyTransport::incoming(message_ptr message) {
+	auto s = state();
+	if (s != State::Connecting && s != State::Connected)
+		return; // Drop
+
+	if (message) {
+		PLOG_VERBOSE << "Incoming size=" << message->size();
+
+		try {
+			mBuffer.insert(mBuffer.end(), message->begin(), message->end());
+
+			if (state() == State::Connecting) {
+				if (size_t len = parseHttpResponse(mBuffer.data(), mBuffer.size())) {
+					PLOG_INFO << "Tcp proxy connection open";
+					changeState(State::Connected);
+					mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+				}
+			}
+
+			return;
+		} catch (const std::exception &e) {
+			PLOG_ERROR << e.what();
+		}
+	}
+
+	if (state() == State::Connected) {
+		PLOG_INFO << "TCP Proxy disconnected";
+		changeState(State::Disconnected);
+		recv(nullptr);
+	} else {
+		PLOG_ERROR << "TCP Proxy failed";
+		changeState(State::Failed);
+	}
+}
+
+bool TcpProxyTransport::sendHttpRequest() {
+	PLOG_DEBUG << "Sending TcpProxy HTTP request";
+
+	const string request = generateHttpRequest();
+	auto data = reinterpret_cast<const byte *>(request.data());
+	return outgoing(make_message(data, data + request.size()));
+}
+
+std::string TcpProxyTransport::generateHttpRequest()
+{
+	std::string out =
+		"CONNECT " +
+		mHostname + ":" + mService +
+		" HTTP/1.1\r\nHost: " +
+		mHostname + "\r\n\r\n";
+	return out;
+}
+
+//TODO move to utils?
+size_t parseHttpLines(const byte *buffer, size_t size, std::list<string> &lines) {
+	lines.clear();
+	auto begin = reinterpret_cast<const char *>(buffer);
+	auto end = begin + size;
+	auto cur = begin;
+	while (true) {
+		auto last = cur;
+		cur = std::find(cur, end, '\n');
+		if (cur == end)
+			return 0;
+		string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++);
+		if (line.empty())
+			break;
+		lines.emplace_back(std::move(line));
+	}
+
+	return cur - begin;
+}
+
+size_t TcpProxyTransport::parseHttpResponse( std::byte* buffer, size_t size )
+{
+	std::list<string> lines;
+	size_t length = parseHttpLines(buffer, size, lines);
+	if (length == 0)
+		return 0;
+
+	if (lines.empty())
+		throw std::runtime_error("Invalid HTTP request for Tcp Proxy");
+
+	std::istringstream status(std::move(lines.front()));
+	lines.pop_front();
+
+	string protocol;
+	unsigned int code = 0;
+	status >> protocol >> code;
+
+	if (code != 200)
+		throw std::runtime_error("Unexpected response code " + to_string(code) + " for Tcp Proxy");
+
+	return length;
+}
+
+} // namespace rtc::impl
+
+#endif

+ 51 - 0
src/impl/tcpproxytransport.hpp

@@ -0,0 +1,51 @@
+/**
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/.
+ */
+
+#ifndef RTC_IMPL_TCP_PROXY_TRANSPORT_H
+#define RTC_IMPL_TCP_PROXY_TRANSPORT_H
+
+#include "common.hpp"
+#include "transport.hpp"
+#include "wshandshake.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include <atomic>
+
+namespace rtc::impl {
+
+class TcpTransport;
+class TlsTransport;
+
+class TcpProxyTransport final : public Transport, public std::enable_shared_from_this<TcpProxyTransport> {
+public:
+	TcpProxyTransport(shared_ptr<TcpTransport> lower, std::string hostname, std::string service,
+				state_callback stateCallback);
+	~TcpProxyTransport();
+
+	void start() override;
+	void stop() override;
+	bool send(message_ptr message) override;
+
+private:
+	void incoming(message_ptr message) override;
+	bool sendHttpRequest();
+	std::string generateHttpRequest();
+	size_t parseHttpResponse( std::byte* buffer, size_t size );
+
+	std::string mHostname;
+	std::string mService;
+	binary mBuffer;
+	std::mutex mSendMutex;
+};
+
+} // namespace rtc::impl
+
+#endif
+
+#endif

+ 63 - 3
src/impl/websocket.cpp

@@ -15,6 +15,7 @@
 #include "utils.hpp"
 #include "utils.hpp"
 
 
 #include "tcptransport.hpp"
 #include "tcptransport.hpp"
+#include "tcpproxytransport.hpp"
 #include "tlstransport.hpp"
 #include "tlstransport.hpp"
 #include "verifiedtlstransport.hpp"
 #include "verifiedtlstransport.hpp"
 #include "wstransport.hpp"
 #include "wstransport.hpp"
@@ -48,7 +49,7 @@ void WebSocket::open(const string &url) {
 		throw std::logic_error("WebSocket must be closed before opening");
 		throw std::logic_error("WebSocket must be closed before opening");
 
 
 	if (config.proxyServer) {
 	if (config.proxyServer) {
-		PLOG_WARNING << "Proxy server support for WebSocket is not implemented";
+		mIsProxied = true;
 	}
 	}
 
 
 	// Modified regex from RFC 3986, see https://www.rfc-editor.org/rfc/rfc3986.html#appendix-B
 	// Modified regex from RFC 3986, see https://www.rfc-editor.org/rfc/rfc3986.html#appendix-B
@@ -102,10 +103,20 @@ void WebSocket::open(const string &url) {
 		path += "?" + query;
 		path += "?" + query;
 
 
 	mHostname = hostname; // for TLS SNI
 	mHostname = hostname; // for TLS SNI
+	mService = service; //For proxy
 	std::atomic_store(&mWsHandshake, std::make_shared<WsHandshake>(host, path, config.protocols));
 	std::atomic_store(&mWsHandshake, std::make_shared<WsHandshake>(host, path, config.protocols));
 
 
 	changeState(State::Connecting);
 	changeState(State::Connecting);
-	setTcpTransport(std::make_shared<TcpTransport>(hostname, service, nullptr));
+
+	if (mIsProxied)
+	{
+		//TODO catch bad convert
+		setTcpTransport(std::make_shared<TcpTransport>(mProxy.value().hostname, std::to_string(mProxy.value().port), nullptr));
+	}
+	else
+	{
+		setTcpTransport(std::make_shared<TcpTransport>(hostname, service, nullptr));
+	}
 }
 }
 
 
 void WebSocket::close() {
 void WebSocket::close() {
@@ -218,7 +229,9 @@ shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> tra
 				return;
 				return;
 			switch (transportState) {
 			switch (transportState) {
 			case State::Connected:
 			case State::Connected:
-				if (mIsSecure)
+				if (mIsProxied)
+					initProxyTransport();
+				else if (mIsSecure)
 					initTlsTransport();
 					initTlsTransport();
 				else
 				else
 					initWsTransport();
 					initWsTransport();
@@ -250,6 +263,53 @@ shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> tra
 	}
 	}
 }
 }
 
 
+shared_ptr<TcpProxyTransport> WebSocket::initProxyTransport() {
+	PLOG_VERBOSE << "Starting Tcp Proxy transport";
+	using State = TcpProxyTransport::State;
+	try {
+		if (auto transport = std::atomic_load(&mProxyTransport))
+			return transport;
+
+		auto lower = std::atomic_load(&mTcpTransport);
+		if (!lower)
+			throw std::logic_error("No underlying TCP transport for Proxy transport");
+
+		auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
+			auto shared_this = weak_this.lock();
+			if (!shared_this)
+				return;
+			switch (transportState) {
+			case State::Connected:
+				if (mIsSecure)
+					initTlsTransport();
+				else
+					initWsTransport();
+				break;
+			case State::Failed:
+				triggerError("Proxy connection failed");
+				remoteClose();
+				break;
+			case State::Disconnected:
+				remoteClose();
+				break;
+			default:
+				// Ignore
+				break;
+			}
+		};
+
+		//TODO check optionals?
+		auto transport = std::make_shared<TcpProxyTransport>( lower, mHostname.value(), mService.value(), stateChangeCallback );
+
+		return emplaceTransport(this, &mProxyTransport, std::move(transport));
+
+	} catch (const std::exception &e) {
+		PLOG_ERROR << e.what();
+		remoteClose();
+		throw std::runtime_error("Tcp Proxy transport initialization failed");
+	}
+}
+
 shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 	PLOG_VERBOSE << "Starting TLS transport";
 	PLOG_VERBOSE << "Starting TLS transport";
 	using State = TlsTransport::State;
 	using State = TlsTransport::State;

+ 6 - 0
src/impl/websocket.hpp

@@ -17,6 +17,7 @@
 #include "message.hpp"
 #include "message.hpp"
 #include "queue.hpp"
 #include "queue.hpp"
 #include "tcptransport.hpp"
 #include "tcptransport.hpp"
+#include "tcpproxytransport.hpp"
 #include "tlstransport.hpp"
 #include "tlstransport.hpp"
 #include "wstransport.hpp"
 #include "wstransport.hpp"
 
 
@@ -51,6 +52,7 @@ struct WebSocket final : public Channel, public std::enable_shared_from_this<Web
 	bool changeState(State state);
 	bool changeState(State state);
 
 
 	shared_ptr<TcpTransport> setTcpTransport(shared_ptr<TcpTransport> transport);
 	shared_ptr<TcpTransport> setTcpTransport(shared_ptr<TcpTransport> transport);
+	shared_ptr<TcpProxyTransport> initProxyTransport();
 	shared_ptr<TlsTransport> initTlsTransport();
 	shared_ptr<TlsTransport> initTlsTransport();
 	shared_ptr<WsTransport> initWsTransport();
 	shared_ptr<WsTransport> initWsTransport();
 	shared_ptr<TcpTransport> getTcpTransport() const;
 	shared_ptr<TcpTransport> getTcpTransport() const;
@@ -69,10 +71,14 @@ private:
 
 
 	const certificate_ptr mCertificate;
 	const certificate_ptr mCertificate;
 	bool mIsSecure;
 	bool mIsSecure;
+	bool mIsProxied{false};
 
 
+	optional<ProxyServer> mProxy;
 	optional<string> mHostname; // for TLS SNI
 	optional<string> mHostname; // for TLS SNI
+	optional<string> mService; // for Proxy
 
 
 	shared_ptr<TcpTransport> mTcpTransport;
 	shared_ptr<TcpTransport> mTcpTransport;
+	shared_ptr<TcpProxyTransport> mProxyTransport;
 	shared_ptr<TlsTransport> mTlsTransport;
 	shared_ptr<TlsTransport> mTlsTransport;
 	shared_ptr<WsTransport> mWsTransport;
 	shared_ptr<WsTransport> mWsTransport;
 	shared_ptr<WsHandshake> mWsHandshake;
 	shared_ptr<WsHandshake> mWsHandshake;