Browse Source

Fixing transport link

eric.gressman 2 years ago
parent
commit
9248db44b5

+ 3 - 1
src/impl/tcpproxytransport.cpp

@@ -14,12 +14,12 @@
 
 namespace rtc::impl {
 
-using std::to_integer;
 using std::to_string;
 using std::chrono::system_clock;
 
 TcpProxyTransport::TcpProxyTransport(shared_ptr<TcpTransport> lower, std::string hostname, std::string service, state_callback stateCallback)
     : Transport(lower, std::move(stateCallback))
+	, mIsActive( lower->isActive() )
 	, mHostname( std::move(hostname) )
 	, mService( std::move(service) )
 {
@@ -49,6 +49,8 @@ bool TcpProxyTransport::send(message_ptr message) {
 	return outgoing(message);
 }
 
+bool TcpProxyTransport::isActive() const { return mIsActive; }
+
 void TcpProxyTransport::incoming(message_ptr message) {
 	auto s = state();
 	if (s != State::Connecting && s != State::Connected)

+ 3 - 3
src/impl/tcpproxytransport.hpp

@@ -15,12 +15,9 @@
 
 #if RTC_ENABLE_WEBSOCKET
 
-#include <atomic>
-
 namespace rtc::impl {
 
 class TcpTransport;
-class TlsTransport;
 
 class TcpProxyTransport final : public Transport, public std::enable_shared_from_this<TcpProxyTransport> {
 public:
@@ -32,12 +29,15 @@ public:
 	void stop() override;
 	bool send(message_ptr message) override;
 
+	bool isActive() const;
+
 private:
 	void incoming(message_ptr message) override;
 	bool sendHttpRequest();
 	std::string generateHttpRequest();
 	size_t parseHttpResponse( std::byte* buffer, size_t size );
 
+	const bool mIsActive;
 	std::string mHostname;
 	std::string mService;
 	binary mBuffer;

+ 17 - 6
src/impl/tlstransport.cpp

@@ -8,6 +8,7 @@
 
 #include "tlstransport.hpp"
 #include "tcptransport.hpp"
+#include "tcpproxytransport.hpp"
 #include "threadpool.hpp"
 
 #if RTC_ENABLE_WEBSOCKET
@@ -58,10 +59,15 @@ void TlsTransport::Cleanup() {
 	// Nothing to do
 }
 
-TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host,
+TlsTransport::TlsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TcpProxyTransport>> lower, optional<string> host,
                            certificate_ptr certificate, state_callback callback)
-    : Transport(lower, std::move(callback)), mHost(std::move(host)), mIsClient(lower->isActive()),
-      mIncomingQueue(RECV_QUEUE_LIMIT, message_size_func) {
+    : Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
+			    std::move(callback)), mHost(std::move(host))
+	, mIsClient(
+		std::visit(rtc::overloaded{[](shared_ptr<TcpTransport> l) { return l->isActive(); },
+                                   [](shared_ptr<TcpProxyTransport> l) { return l->isClient(); }},
+                   lower))
+	, mIncomingQueue(RECV_QUEUE_LIMIT, message_size_func) {
 
 	PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
 
@@ -308,10 +314,15 @@ void TlsTransport::Cleanup() {
 	// Nothing to do
 }
 
-TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host,
+TlsTransport::TlsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TcpProxyTransport>> lower, optional<string> host,
                            certificate_ptr certificate, state_callback callback)
-    : Transport(lower, std::move(callback)), mHost(std::move(host)), mIsClient(lower->isActive()),
-      mIncomingQueue(RECV_QUEUE_LIMIT, message_size_func) {
+    : Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
+				std::move(callback)), mHost(std::move(host))
+	, mIsClient(
+          std::visit(rtc::overloaded{[](shared_ptr<TcpTransport> l) { return l->isActive(); },
+                                     [](shared_ptr<TcpProxyTransport> l) { return l->isActive(); }},
+                     lower))
+	, mIncomingQueue(RECV_QUEUE_LIMIT, message_size_func) {
 
 	PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
 

+ 2 - 1
src/impl/tlstransport.hpp

@@ -23,13 +23,14 @@
 namespace rtc::impl {
 
 class TcpTransport;
+class TcpProxyTransport;
 
 class TlsTransport : public Transport, public std::enable_shared_from_this<TlsTransport> {
 public:
 	static void Init();
 	static void Cleanup();
 
-	TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host, certificate_ptr certificate,
+	TlsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TcpProxyTransport>> lower, optional<string> host, certificate_ptr certificate,
 	             state_callback callback);
 	virtual ~TlsTransport();
 

+ 1 - 1
src/impl/verifiedtlstransport.cpp

@@ -13,7 +13,7 @@
 
 namespace rtc::impl {
 
-VerifiedTlsTransport::VerifiedTlsTransport(shared_ptr<TcpTransport> lower, string host,
+VerifiedTlsTransport::VerifiedTlsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TcpProxyTransport>> lower, string host,
                                            certificate_ptr certificate, state_callback callback)
     : TlsTransport(std::move(lower), std::move(host), std::move(certificate), std::move(callback)) {
 

+ 1 - 1
src/impl/verifiedtlstransport.hpp

@@ -17,7 +17,7 @@ namespace rtc::impl {
 
 class VerifiedTlsTransport final : public TlsTransport {
 public:
-	VerifiedTlsTransport(shared_ptr<TcpTransport> lower, string host, certificate_ptr certificate,
+	VerifiedTlsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TcpProxyTransport>> lower, string host, certificate_ptr certificate,
 	                     state_callback callback);
 	~VerifiedTlsTransport();
 };

+ 14 - 3
src/impl/websocket.cpp

@@ -316,9 +316,20 @@ shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 		if (auto transport = std::atomic_load(&mTlsTransport))
 			return transport;
 
-		auto lower = std::atomic_load(&mTcpTransport);
-		if (!lower)
-			throw std::logic_error("No underlying TCP transport for TLS transport");
+		variant<shared_ptr<TcpTransport>, shared_ptr<TcpProxyTransport>> lower;
+		if (mIsProxied) {
+			auto transport = std::atomic_load(&mProxyTransport);
+			if (!transport)
+				throw std::logic_error("No underlying TLS transport for WebSocket transport");
+
+			lower = transport;
+		} else {
+			auto transport = std::atomic_load(&mTcpTransport);
+			if (!transport)
+				throw std::logic_error("No underlying TCP transport for WebSocket transport");
+
+			lower = transport;
+		}
 
 		auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
 			auto shared_this = weak_this.lock();