Parcourir la source

Merge pull request #28 from paullouisageneau/ice-timeout

Add trickle ICE timeout
Paul-Louis Ageneau il y a 5 ans
Parent
commit
72a0e2fe07

+ 1 - 0
include/rtc/description.hpp

@@ -45,6 +45,7 @@ public:
 	std::optional<string> fingerprint() const;
 	std::optional<string> fingerprint() const;
 	std::optional<uint16_t> sctpPort() const;
 	std::optional<uint16_t> sctpPort() const;
 	std::optional<size_t> maxMessageSize() const;
 	std::optional<size_t> maxMessageSize() const;
+	bool trickleEnabled() const;
 
 
 	void setFingerprint(string fingerprint);
 	void setFingerprint(string fingerprint);
 	void setSctpPort(uint16_t port);
 	void setSctpPort(uint16_t port);

+ 2 - 0
src/description.cpp

@@ -107,6 +107,8 @@ std::optional<uint16_t> Description::sctpPort() const { return mSctpPort; }
 
 
 std::optional<size_t> Description::maxMessageSize() const { return mMaxMessageSize; }
 std::optional<size_t> Description::maxMessageSize() const { return mMaxMessageSize; }
 
 
+bool Description::trickleEnabled() const { return mTrickle; }
+
 void Description::setFingerprint(string fingerprint) {
 void Description::setFingerprint(string fingerprint) {
 	mFingerprint.emplace(std::move(fingerprint));
 	mFingerprint.emplace(std::move(fingerprint));
 }
 }

+ 20 - 7
src/dtlstransport.cpp

@@ -37,8 +37,11 @@ namespace {
 
 
 static bool check_gnutls(int ret, const string &message = "GnuTLS error") {
 static bool check_gnutls(int ret, const string &message = "GnuTLS error") {
 	if (ret < 0) {
 	if (ret < 0) {
-		if (!gnutls_error_is_fatal(ret))
+		if (!gnutls_error_is_fatal(ret)) {
+			PLOG_INFO << gnutls_strerror(ret);
 			return false;
 			return false;
+		}
+		PLOG_ERROR << gnutls_strerror(ret);
 		throw std::runtime_error(message + ": " + gnutls_strerror(ret));
 		throw std::runtime_error(message + ": " + gnutls_strerror(ret));
 	}
 	}
 	return true;
 	return true;
@@ -54,6 +57,9 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
     : Transport(lower), mCertificate(certificate), mState(State::Disconnected),
     : Transport(lower), mCertificate(certificate), mState(State::Disconnected),
       mVerifierCallback(std::move(verifierCallback)),
       mVerifierCallback(std::move(verifierCallback)),
       mStateChangeCallback(std::move(stateChangeCallback)) {
       mStateChangeCallback(std::move(stateChangeCallback)) {
+
+	PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
+
 	gnutls_certificate_set_verify_function(mCertificate->credentials(), CertificateCallback);
 	gnutls_certificate_set_verify_function(mCertificate->credentials(), CertificateCallback);
 
 
 	bool active = lower->role() == Description::Role::Active;
 	bool active = lower->role() == Description::Role::Active;
@@ -272,8 +278,10 @@ string openssl_error_string(unsigned long err) {
 bool check_openssl(int success, const string &message = "OpenSSL error") {
 bool check_openssl(int success, const string &message = "OpenSSL error") {
 	if (success)
 	if (success)
 		return true;
 		return true;
-	else
-		throw std::runtime_error(message + ": " + openssl_error_string(ERR_get_error()));
+
+	string str = openssl_error_string(ERR_get_error());
+	PLOG_ERROR << str;
+	throw std::runtime_error(message + ": " + str);
 }
 }
 
 
 bool check_openssl_ret(SSL *ssl, int ret, const string &message = "OpenSSL error") {
 bool check_openssl_ret(SSL *ssl, int ret, const string &message = "OpenSSL error") {
@@ -281,12 +289,16 @@ bool check_openssl_ret(SSL *ssl, int ret, const string &message = "OpenSSL error
 		return true;
 		return true;
 
 
 	unsigned long err = SSL_get_error(ssl, ret);
 	unsigned long err = SSL_get_error(ssl, ret);
-	if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
+	if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
 		return true;
 		return true;
-	else if (err == SSL_ERROR_ZERO_RETURN)
+	}
+	if (err == SSL_ERROR_ZERO_RETURN) {
+		PLOG_INFO << "The TLS connection has been cleanly closed";
 		return false;
 		return false;
-	else
-		throw std::runtime_error(message + ": " + openssl_error_string(err));
+	}
+	string str = openssl_error_string(err);
+	PLOG_ERROR << str;
+	throw std::runtime_error(message + ": " + str);
 }
 }
 
 
 } // namespace
 } // namespace
@@ -309,6 +321,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
       mVerifierCallback(std::move(verifierCallback)),
       mVerifierCallback(std::move(verifierCallback)),
       mStateChangeCallback(std::move(stateChangeCallback)) {
       mStateChangeCallback(std::move(stateChangeCallback)) {
 
 
+	PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
 	GlobalInit();
 	GlobalInit();
 
 
 	if (!(mCtx = SSL_CTX_new(DTLS_method())))
 	if (!(mCtx = SSL_CTX_new(DTLS_method())))

+ 34 - 4
src/icetransport.cpp

@@ -23,13 +23,14 @@
 #include <sys/socket.h>
 #include <sys/socket.h>
 #include <sys/types.h>
 #include <sys/types.h>
 
 
-#include <chrono>
 #include <iostream>
 #include <iostream>
 #include <random>
 #include <random>
 #include <sstream>
 #include <sstream>
 
 
 namespace rtc {
 namespace rtc {
 
 
+using namespace std::chrono_literals;
+
 using std::shared_ptr;
 using std::shared_ptr;
 using std::weak_ptr;
 using std::weak_ptr;
 
 
@@ -42,6 +43,8 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
       mStateChangeCallback(std::move(stateChangeCallback)),
       mStateChangeCallback(std::move(stateChangeCallback)),
       mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)) {
       mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)) {
 
 
+	PLOG_DEBUG << "Initializing ICE transport";
+
 	g_log_set_handler("libnice", G_LOG_LEVEL_MASK, LogCallback, this);
 	g_log_set_handler("libnice", G_LOG_LEVEL_MASK, LogCallback, this);
 
 
 	IF_PLOG(plog::verbose) {
 	IF_PLOG(plog::verbose) {
@@ -177,6 +180,10 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
 IceTransport::~IceTransport() { stop(); }
 IceTransport::~IceTransport() { stop(); }
 
 
 void IceTransport::stop() {
 void IceTransport::stop() {
+	if (mTimeoutId) {
+		g_source_remove(mTimeoutId);
+		mTimeoutId = 0;
+	}
 	if (mMainLoopThread.joinable()) {
 	if (mMainLoopThread.joinable()) {
 		g_main_loop_quit(mMainLoop.get());
 		g_main_loop_quit(mMainLoop.get());
 		mMainLoopThread.join();
 		mMainLoopThread.join();
@@ -202,6 +209,7 @@ void IceTransport::setRemoteDescription(const Description &description) {
 	mRole = description.role() == Description::Role::Active ? Description::Role::Passive
 	mRole = description.role() == Description::Role::Active ? Description::Role::Passive
 	                                                        : Description::Role::Active;
 	                                                        : Description::Role::Active;
 	mMid = description.mid();
 	mMid = description.mid();
+	mTrickleTimeout = description.trickleEnabled() ? 30s : 0s;
 
 
 	if (nice_agent_parse_remote_sdp(mNiceAgent.get(), string(description).c_str()) < 0)
 	if (nice_agent_parse_remote_sdp(mNiceAgent.get(), string(description).c_str()) < 0)
 		throw std::runtime_error("Failed to parse remote SDP");
 		throw std::runtime_error("Failed to parse remote SDP");
@@ -278,6 +286,8 @@ void IceTransport::changeState(State state) {
 		mStateChangeCallback(mState);
 		mStateChangeCallback(mState);
 }
 }
 
 
+void IceTransport::processTimeout() { changeState(State::Failed); }
+
 void IceTransport::changeGatheringState(GatheringState state) {
 void IceTransport::changeGatheringState(GatheringState state) {
 	mGatheringState = state;
 	mGatheringState = state;
 	mGatheringStateChangeCallback(mGatheringState);
 	mGatheringStateChangeCallback(mGatheringState);
@@ -290,8 +300,19 @@ void IceTransport::processCandidate(const string &candidate) {
 void IceTransport::processGatheringDone() { changeGatheringState(GatheringState::Complete); }
 void IceTransport::processGatheringDone() { changeGatheringState(GatheringState::Complete); }
 
 
 void IceTransport::processStateChange(uint32_t state) {
 void IceTransport::processStateChange(uint32_t state) {
-	if (state != NICE_COMPONENT_STATE_GATHERING)
-		changeState(static_cast<State>(state));
+	if (state == NICE_COMPONENT_STATE_FAILED && mTrickleTimeout.count() > 0) {
+		if (mTimeoutId)
+			g_source_remove(mTimeoutId);
+		mTimeoutId = g_timeout_add(mTrickleTimeout.count() /* ms */, TimeoutCallback, this);
+		return;
+	}
+
+	if (state == NICE_COMPONENT_STATE_CONNECTED && mTimeoutId) {
+		g_source_remove(mTimeoutId);
+		mTimeoutId = 0;
+	}
+
+	changeState(static_cast<State>(state));
 }
 }
 
 
 string IceTransport::AddressToString(const NiceAddress &addr) {
 string IceTransport::AddressToString(const NiceAddress &addr) {
@@ -344,9 +365,18 @@ void IceTransport::RecvCallback(NiceAgent *agent, guint streamId, guint componen
 	}
 	}
 }
 }
 
 
+gboolean IceTransport::TimeoutCallback(gpointer userData) {
+	auto iceTransport = static_cast<rtc::IceTransport *>(userData);
+	try {
+		iceTransport->processTimeout();
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+	}
+	return FALSE;
+}
+
 void IceTransport::LogCallback(const gchar *logDomain, GLogLevelFlags logLevel,
 void IceTransport::LogCallback(const gchar *logDomain, GLogLevelFlags logLevel,
                                const gchar *message, gpointer userData) {
                                const gchar *message, gpointer userData) {
-
 	plog::Severity severity;
 	plog::Severity severity;
 	unsigned int flags = logLevel & G_LOG_LEVEL_MASK;
 	unsigned int flags = logLevel & G_LOG_LEVEL_MASK;
 	if (flags & G_LOG_LEVEL_ERROR)
 	if (flags & G_LOG_LEVEL_ERROR)

+ 5 - 0
src/icetransport.hpp

@@ -31,6 +31,7 @@ extern "C" {
 }
 }
 
 
 #include <atomic>
 #include <atomic>
+#include <chrono>
 #include <thread>
 #include <thread>
 
 
 namespace rtc {
 namespace rtc {
@@ -81,9 +82,11 @@ private:
 	void processCandidate(const string &candidate);
 	void processCandidate(const string &candidate);
 	void processGatheringDone();
 	void processGatheringDone();
 	void processStateChange(uint32_t state);
 	void processStateChange(uint32_t state);
+	void processTimeout();
 
 
 	Description::Role mRole;
 	Description::Role mRole;
 	string mMid;
 	string mMid;
+	std::chrono::milliseconds mTrickleTimeout;
 	std::atomic<State> mState;
 	std::atomic<State> mState;
 	std::atomic<GatheringState> mGatheringState;
 	std::atomic<GatheringState> mGatheringState;
 
 
@@ -91,6 +94,7 @@ private:
 	std::unique_ptr<NiceAgent, void (*)(gpointer)> mNiceAgent;
 	std::unique_ptr<NiceAgent, void (*)(gpointer)> mNiceAgent;
 	std::unique_ptr<GMainLoop, void (*)(GMainLoop *)> mMainLoop;
 	std::unique_ptr<GMainLoop, void (*)(GMainLoop *)> mMainLoop;
 	std::thread mMainLoopThread;
 	std::thread mMainLoopThread;
+	guint mTimeoutId = 0;
 
 
 	candidate_callback mCandidateCallback;
 	candidate_callback mCandidateCallback;
 	state_callback mStateChangeCallback;
 	state_callback mStateChangeCallback;
@@ -104,6 +108,7 @@ private:
 	                                 guint state, gpointer userData);
 	                                 guint state, gpointer userData);
 	static void RecvCallback(NiceAgent *agent, guint stream_id, guint component_id, guint len,
 	static void RecvCallback(NiceAgent *agent, guint stream_id, guint component_id, guint len,
 	                         gchar *buf, gpointer userData);
 	                         gchar *buf, gpointer userData);
+	static gboolean TimeoutCallback(gpointer userData);
 	static void LogCallback(const gchar *log_domain, GLogLevelFlags log_level, const gchar *message,
 	static void LogCallback(const gchar *log_domain, GLogLevelFlags log_level, const gchar *message,
 	                        gpointer user_data);
 	                        gpointer user_data);
 };
 };

+ 1 - 0
src/sctptransport.cpp

@@ -55,6 +55,7 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
       mStateChangeCallback(std::move(stateChangeCallback)), mState(State::Disconnected) {
       mStateChangeCallback(std::move(stateChangeCallback)), mState(State::Disconnected) {
 	onRecv(recvCallback);
 	onRecv(recvCallback);
 
 
+	PLOG_DEBUG << "Initializing SCTP transport";
 	GlobalInit();
 	GlobalInit();
 
 
 	usrsctp_register_address(this);
 	usrsctp_register_address(this);