Browse Source

Merge pull request #147 from paullouisageneau/transport-init

Redesign transport initialization
Paul-Louis Ageneau 5 years ago
parent
commit
672287bbcf

+ 13 - 14
src/dtlssrtptransport.cpp

@@ -45,6 +45,18 @@ DtlsSrtpTransport::DtlsSrtpTransport(std::shared_ptr<IceTransport> lower,
 
 	PLOG_DEBUG << "Initializing DTLS-SRTP transport";
 
+#if USE_GNUTLS
+	PLOG_DEBUG << "Setting SRTP profile (GnuTLS)";
+	gnutls::check(gnutls_srtp_set_profile(mSession, GNUTLS_SRTP_AES128_CM_HMAC_SHA1_80),
+	              "Failed to set SRTP profile");
+#else
+	PLOG_DEBUG << "Setting SRTP profile (OpenSSL)";
+	// returns 0 on success, 1 on error
+	if (SSL_set_tlsext_use_srtp(mSsl, "SRTP_AES128_CM_SHA1_80"), "Failed to set SRTP profile")
+		throw std::runtime_error("Failed to set SRTP profile: " +
+		                         openssl::error_string(ERR_get_error()));
+#endif
+
 	if (srtp_err_status_t err = srtp_create(&mSrtpIn, nullptr)) {
 		throw std::runtime_error("SRTP create failed, status=" + to_string(static_cast<int>(err)));
 	}
@@ -55,7 +67,7 @@ DtlsSrtpTransport::DtlsSrtpTransport(std::shared_ptr<IceTransport> lower,
 }
 
 DtlsSrtpTransport::~DtlsSrtpTransport() {
-	stop();
+	stop(); // stop before deallocating
 
 	srtp_dealloc(mSrtpIn);
 	srtp_dealloc(mSrtpOut);
@@ -181,19 +193,6 @@ void DtlsSrtpTransport::incoming(message_ptr message) {
 	}
 }
 
-void DtlsSrtpTransport::postCreation() {
-#if USE_GNUTLS
-	PLOG_DEBUG << "Setting SRTP profile (GnuTLS)";
-	gnutls::check(gnutls_srtp_set_profile(mSession, GNUTLS_SRTP_AES128_CM_HMAC_SHA1_80),
-	              "Failed to set SRTP profile");
-#else
-	PLOG_DEBUG << "Setting SRTP profile (OpenSSL)";
-	// returns 0 on success, 1 on error
-	if (SSL_set_tlsext_use_srtp(mSsl, "SRTP_AES128_CM_SHA1_80"), "Failed to set SRTP profile")
-		throw std::runtime_error("Failed to set SRTP profile: " + openssl::error_string(ERR_get_error()));
-#endif
-}
-
 void DtlsSrtpTransport::postHandshake() {
 	if (mInitDone)
 		return;

+ 0 - 1
src/dtlssrtptransport.hpp

@@ -42,7 +42,6 @@ public:
 
 private:
 	void incoming(message_ptr message) override;
-	void postCreation() override;
 	void postHandshake() override;
 
 	message_callback mSrtpRecvCallback;

+ 18 - 16
src/dtlstransport.cpp

@@ -85,11 +85,6 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 		gnutls_transport_set_pull_function(mSession, ReadCallback);
 		gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
 
-		postCreation();
-
-		mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
-		registerIncoming();
-
 	} catch (...) {
 		gnutls_deinit(mSession);
 		throw;
@@ -102,6 +97,15 @@ DtlsTransport::~DtlsTransport() {
 	gnutls_deinit(mSession);
 }
 
+void DtlsTransport::start() {
+	Transport::start();
+
+	registerIncoming();
+
+	PLOG_DEBUG << "Starting DTLS recv thread";
+	mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
+}
+
 bool DtlsTransport::stop() {
 	if (!Transport::stop())
 		return false;
@@ -139,10 +143,6 @@ void DtlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 }
 
-void DtlsTransport::postCreation() {
-	// Dummy
-}
-
 void DtlsTransport::postHandshake() {
 	// Dummy
 }
@@ -364,9 +364,6 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 		SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
 		SSL_set_tmp_ecdh(mSsl, ecdh.get());
 
-		mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
-		registerIncoming();
-
 	} catch (...) {
 		if (mSsl)
 			SSL_free(mSsl);
@@ -383,6 +380,15 @@ DtlsTransport::~DtlsTransport() {
 	SSL_CTX_free(mCtx);
 }
 
+void DtlsTransport::start() {
+	Transport::start();
+
+	registerIncoming();
+
+	PLOG_DEBUG << "Starting DTLS recv thread";
+	mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
+}
+
 bool DtlsTransport::stop() {
 	if (!Transport::stop())
 		return false;
@@ -414,10 +420,6 @@ void DtlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 }
 
-void DtlsTransport::postCreation() {
-	// Dummy
-}
-
 void DtlsTransport::postHandshake() {
 	// Dummy
 }

+ 1 - 1
src/dtlstransport.hpp

@@ -47,12 +47,12 @@ public:
 	              verifier_callback verifierCallback, state_callback stateChangeCallback);
 	~DtlsTransport();
 
+	virtual void start() override;
 	virtual bool stop() override;
 	virtual bool send(message_ptr message) override; // false if dropped
 
 protected:
 	virtual void incoming(message_ptr message) override;
-	virtual void postCreation();
 	virtual void postHandshake();
 	void runRecvLoop();
 

+ 3 - 3
src/peerconnection.cpp

@@ -314,9 +314,9 @@ shared_ptr<IceTransport> PeerConnection::initIceTransport(Description::Role role
 		std::atomic_store(&mIceTransport, transport);
 		if (mState == State::Closed) {
 			mIceTransport.reset();
-			transport->stop();
 			throw std::runtime_error("Connection is closed");
 		}
+		transport->start();
 		return transport;
 
 	} catch (const std::exception &e) {
@@ -379,9 +379,9 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 		std::atomic_store(&mDtlsTransport, transport);
 		if (mState == State::Closed) {
 			mDtlsTransport.reset();
-			transport->stop();
 			throw std::runtime_error("Connection is closed");
 		}
+		transport->start();
 		return transport;
 
 	} catch (const std::exception &e) {
@@ -434,9 +434,9 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 		std::atomic_store(&mSctpTransport, transport);
 		if (mState == State::Closed) {
 			mSctpTransport.reset();
-			transport->stop();
 			throw std::runtime_error("Connection is closed");
 		}
+		transport->start();
 		return transport;
 
 	} catch (const std::exception &e) {

+ 8 - 4
src/sctptransport.cpp

@@ -187,9 +187,6 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
 	if (usrsctp_setsockopt(mSock, SOL_SOCKET, SO_SNDBUF, &bufferSize, sizeof(bufferSize)))
 		throw std::runtime_error("Could not set SCTP send buffer size, errno=" +
 		                         std::to_string(errno));
-
-	registerIncoming();
-	connect();
 }
 
 SctpTransport::~SctpTransport() {
@@ -203,6 +200,13 @@ SctpTransport::~SctpTransport() {
 	}
 }
 
+void SctpTransport::start() {
+	Transport::start();
+
+	registerIncoming();
+	connect();
+}
+
 bool SctpTransport::stop() {
 	// Transport::stop() will unregister incoming() from the lower layer, therefore we need to make
 	// sure the thread from lower layers is not blocked in incoming() by the WrittenOnce condition.
@@ -230,7 +234,7 @@ void SctpTransport::connect() {
 	if (!mSock)
 		return;
 
-	PLOG_DEBUG << "SCTP connect";
+	PLOG_DEBUG << "SCTP connecting";
 	changeState(State::Connecting);
 
 	struct sockaddr_conn sconn = {};

+ 1 - 0
src/sctptransport.hpp

@@ -46,6 +46,7 @@ public:
 	              amount_callback bufferedAmountCallback, state_callback stateChangeCallback);
 	~SctpTransport();
 
+	void start() override;
 	bool stop() override;
 	bool send(message_ptr message) override; // false if buffered
 	void closeStream(unsigned int stream);

+ 7 - 3
src/tcptransport.cpp

@@ -86,11 +86,15 @@ TcpTransport::TcpTransport(const string &hostname, const string &service, state_
     : Transport(nullptr, std::move(callback)), mHostname(hostname), mService(service) {
 
 	PLOG_DEBUG << "Initializing TCP transport";
-	mThread = std::thread(&TcpTransport::runLoop, this);
 }
 
-TcpTransport::~TcpTransport() {
-	stop();
+TcpTransport::~TcpTransport() { stop(); }
+
+void TcpTransport::start() {
+	Transport::start();
+
+	PLOG_DEBUG << "Starting TCP recv thread";
+	mThread = std::thread(&TcpTransport::runLoop, this);
 }
 
 bool TcpTransport::stop() {

+ 1 - 0
src/tcptransport.hpp

@@ -56,6 +56,7 @@ public:
 	TcpTransport(const string &hostname, const string &service, state_callback callback);
 	~TcpTransport();
 
+	void start() override;
 	bool stop() override;
 	bool send(message_ptr message) override;
 

+ 18 - 18
src/tlstransport.cpp

@@ -71,11 +71,6 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 		gnutls_transport_set_pull_function(mSession, ReadCallback);
 		gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
 
-		postCreation();
-
-		mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
-		registerIncoming();
-
 	} catch (...) {
 		gnutls_deinit(mSession);
 		gnutls_certificate_free_credentials(mCreds);
@@ -90,6 +85,15 @@ TlsTransport::~TlsTransport() {
 	gnutls_certificate_free_credentials(mCreds);
 }
 
+void TlsTransport::start() {
+	Transport::start();
+
+	registerIncoming();
+
+	PLOG_DEBUG << "Starting TLS recv thread";
+	mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
+}
+
 bool TlsTransport::stop() {
 	if (!Transport::stop())
 		return false;
@@ -124,10 +128,6 @@ void TlsTransport::incoming(message_ptr message) {
 		mIncomingQueue.stop();
 }
 
-void TlsTransport::postCreation() {
-	// Dummy
-}
-
 void TlsTransport::postHandshake() {
 	// Dummy
 }
@@ -309,11 +309,6 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 		SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
 		SSL_set_tmp_ecdh(mSsl, ecdh.get());
 
-		postCreation();
-
-		mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
-		registerIncoming();
-
 	} catch (...) {
 		if (mSsl)
 			SSL_free(mSsl);
@@ -330,6 +325,15 @@ TlsTransport::~TlsTransport() {
 	SSL_CTX_free(mCtx);
 }
 
+void TlsTransport::start() {
+	Transport::start();
+
+	registerIncoming();
+
+	PLOG_DEBUG << "Starting TLS recv thread";
+	mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
+}
+
 bool TlsTransport::stop() {
 	if (!Transport::stop())
 		return false;
@@ -369,10 +373,6 @@ void TlsTransport::incoming(message_ptr message) {
 		mIncomingQueue.stop();
 }
 
-void TlsTransport::postCreation() {
-	// Dummy
-}
-
 void TlsTransport::postHandshake() {
 	// Dummy
 }

+ 1 - 1
src/tlstransport.hpp

@@ -40,12 +40,12 @@ public:
 	TlsTransport(std::shared_ptr<TcpTransport> lower, string host, state_callback callback);
 	virtual ~TlsTransport();
 
+	void start() override;
 	bool stop() override;
 	bool send(message_ptr message) override;
 
 protected:
 	virtual void incoming(message_ptr message) override;
-	virtual void postCreation();
 	virtual void postHandshake();
 	void runRecvLoop();
 

+ 5 - 5
src/transport.hpp

@@ -39,12 +39,12 @@ public:
 	    : mLower(std::move(lower)), mStateChangeCallback(std::move(callback)) {
 	}
 
-	virtual ~Transport() {
-		stop();
-	}
+	virtual ~Transport() { stop(); }
+
+	virtual void start() { mStopped = false; }
 
 	virtual bool stop() {
-		if (mShutdown.exchange(true))
+		if (mStopped.exchange(true))
 			return false;
 
 		// We don't want incoming() to be called by the lower layer anymore
@@ -95,7 +95,7 @@ private:
 	synchronized_callback<message_ptr> mRecvCallback;
 
 	std::atomic<State> mState = State::Disconnected;
-	std::atomic<bool> mShutdown = false;
+	std::atomic<bool> mStopped = true;
 };
 
 } // namespace rtc

+ 6 - 25
src/verifiedtlstransport.cpp

@@ -28,43 +28,24 @@ using std::weak_ptr;
 
 namespace rtc {
 
-#if USE_GNUTLS
 
 VerifiedTlsTransport::VerifiedTlsTransport(shared_ptr<TcpTransport> lower, string host,
                                            state_callback callback)
-    : TlsTransport(std::move(lower), std::move(host), std::move(callback)) {}
-
-VerifiedTlsTransport::~VerifiedTlsTransport() {}
+    : TlsTransport(std::move(lower), std::move(host), std::move(callback)) {
 
-void VerifiedTlsTransport::postCreation() {
+#if USE_GNUTLS
 	PLOG_DEBUG << "Setting up TLS certificate verification";
 	gnutls_session_set_verify_cert(mSession, mHost.c_str(), 0);
-}
-
-void VerifiedTlsTransport::postHandshake() {
-	// Nothing to do
-}
-
-#else // USE_GNUTLS==0
-
-VerifiedTlsTransport::VerifiedTlsTransport(shared_ptr<TcpTransport> lower, string host,
-                                           state_callback callback)
-    : TlsTransport(std::move(lower), std::move(host), std::move(callback)) {}
-
-VerifiedTlsTransport::~VerifiedTlsTransport() {}
-
-void VerifiedTlsTransport::postCreation() {
+#else
 	PLOG_DEBUG << "Setting up TLS certificate verification";
 	SSL_set_verify(mSsl, SSL_VERIFY_PEER, NULL);
 	SSL_set_verify_depth(mSsl, 4);
+#endif
 }
 
-void VerifiedTlsTransport::postHandshake() {
-	// Nothing to do
-}
-
-#endif
+VerifiedTlsTransport::~VerifiedTlsTransport() {}
 
 } // namespace rtc
 
 #endif
+

+ 0 - 4
src/verifiedtlstransport.hpp

@@ -29,10 +29,6 @@ class VerifiedTlsTransport final : public TlsTransport {
 public:
 	VerifiedTlsTransport(std::shared_ptr<TcpTransport> lower, string host, state_callback callback);
 	~VerifiedTlsTransport();
-
-protected:
-	void postCreation() override;
-	void postHandshake() override;
 };
 
 } // namespace rtc

+ 5 - 3
src/websocket.cpp

@@ -189,10 +189,11 @@ shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
 		std::atomic_store(&mTcpTransport, transport);
 		if (mState == WebSocket::State::Closed) {
 			mTcpTransport.reset();
-			transport->stop();
 			throw std::runtime_error("Connection is closed");
 		}
+		transport->start();
 		return transport;
+
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 		remoteClose();
@@ -245,10 +246,11 @@ shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 		std::atomic_store(&mTlsTransport, transport);
 		if (mState == WebSocket::State::Closed) {
 			mTlsTransport.reset();
-			transport->stop();
 			throw std::runtime_error("Connection is closed");
 		}
+		transport->start();
 		return transport;
+
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 		remoteClose();
@@ -295,9 +297,9 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 		std::atomic_store(&mWsTransport, transport);
 		if (mState == WebSocket::State::Closed) {
 			mWsTransport.reset();
-			transport->stop();
 			throw std::runtime_error("Connection is closed");
 		}
+		transport->start();
 		return transport;
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();

+ 7 - 2
src/wstransport.cpp

@@ -58,13 +58,17 @@ WsTransport::WsTransport(std::shared_ptr<Transport> lower, string host, string p
 	onRecv(recvCallback);
 
 	PLOG_DEBUG << "Initializing WebSocket transport";
+}
+
+WsTransport::~WsTransport() { stop(); }
+
+void WsTransport::start() {
+	Transport::start();
 
 	registerIncoming();
 	sendHttpRequest();
 }
 
-WsTransport::~WsTransport() { stop(); }
-
 bool WsTransport::stop() {
 	if (!Transport::stop())
 		return false;
@@ -143,6 +147,7 @@ void WsTransport::close() {
 }
 
 bool WsTransport::sendHttpRequest() {
+	PLOG_DEBUG << "Sending WebSocket HTTP request";
 	changeState(State::Connecting);
 
 	auto seed = static_cast<unsigned int>(system_clock::now().time_since_epoch().count());

+ 1 - 0
src/wstransport.hpp

@@ -35,6 +35,7 @@ public:
 	            message_callback recvCallback, state_callback stateCallback);
 	~WsTransport();
 
+	void start() override;
 	bool stop() override;
 	bool send(message_ptr message) override;