Browse Source

Merge pull request #70 from paullouisageneau/fix-incoming-race

Fix possible race condition on Transport::incoming()
Paul-Louis Ageneau 5 years ago
parent
commit
86b9bace53
5 changed files with 21 additions and 30 deletions
  1. 2 0
      src/dtlstransport.cpp
  2. 9 21
      src/icetransport.cpp
  3. 1 3
      src/icetransport.hpp
  4. 1 0
      src/sctptransport.cpp
  5. 8 6
      src/transport.hpp

+ 2 - 0
src/dtlstransport.cpp

@@ -100,6 +100,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 	gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
 
 	mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
+	registerIncoming();
 }
 
 DtlsTransport::~DtlsTransport() {
@@ -410,6 +411,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 	SSL_set_tmp_ecdh(mSsl, ecdh.get());
 
 	mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
+	registerIncoming();
 }
 
 DtlsTransport::~DtlsTransport() {

+ 9 - 21
src/icetransport.cpp

@@ -103,7 +103,6 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
 IceTransport::~IceTransport() { stop(); }
 
 bool IceTransport::stop() {
-	onRecv(nullptr);
 	return Transport::stop();
 }
 
@@ -169,15 +168,6 @@ bool IceTransport::send(message_ptr message) {
 	return outgoing(message);
 }
 
-void IceTransport::incoming(message_ptr message) {
-	PLOG_VERBOSE << "Incoming size=" << message->size();
-	recv(message);
-}
-
-void IceTransport::incoming(const byte *data, int size) {
-	incoming(make_message(data, data + size));
-}
-
 bool IceTransport::outgoing(message_ptr message) {
 	return juice_send(mAgent.get(), reinterpret_cast<const char *>(message->data()),
 	                  message->size()) >= 0;
@@ -234,7 +224,9 @@ void IceTransport::RecvCallback(juice_agent_t *agent, const char *data, size_t s
                                 void *user_ptr) {
 	auto iceTransport = static_cast<rtc::IceTransport *>(user_ptr);
 	try {
-		iceTransport->incoming(reinterpret_cast<const byte *>(data), size);
+		PLOG_VERBOSE << "Incoming size=" << size;
+		auto b = reinterpret_cast<const byte *>(data);
+		iceTransport->incoming(make_message(b, b + size));
 	} catch (const std::exception &e) {
 		PLOG_WARNING << e.what();
 	}
@@ -455,6 +447,9 @@ bool IceTransport::stop() {
 		return false;
 
 	PLOG_DEBUG << "Stopping ICE thread";
+	nice_agent_attach_recv(mNiceAgent.get(), mStreamId, 1, g_main_loop_get_context(mMainLoop.get()),
+	                       NULL, NULL);
+	nice_agent_remove_stream(mNiceAgent.get(), mStreamId);
 	g_main_loop_quit(mMainLoop.get());
 	mMainLoopThread.join();
 	return true;
@@ -541,15 +536,6 @@ bool IceTransport::send(message_ptr message) {
 	return outgoing(message);
 }
 
-void IceTransport::incoming(message_ptr message) {
-	PLOG_VERBOSE << "Incoming size=" << message->size();
-	recv(message);
-}
-
-void IceTransport::incoming(const byte *data, int size) {
-	incoming(make_message(data, data + size));
-}
-
 bool IceTransport::outgoing(message_ptr message) {
 	return nice_agent_send(mNiceAgent.get(), mStreamId, 1, message->size(),
 	                       reinterpret_cast<const char *>(message->data())) >= 0;
@@ -637,7 +623,9 @@ void IceTransport::RecvCallback(NiceAgent *agent, guint streamId, guint componen
                                 gchar *buf, gpointer userData) {
 	auto iceTransport = static_cast<rtc::IceTransport *>(userData);
 	try {
-		iceTransport->incoming(reinterpret_cast<byte *>(buf), len);
+		PLOG_VERBOSE << "Incoming size=" << len;
+		auto b = reinterpret_cast<byte *>(buf);
+		iceTransport->incoming(make_message(b, b + len));
 	} catch (const std::exception &e) {
 		PLOG_WARNING << e.what();
 	}

+ 1 - 3
src/icetransport.hpp

@@ -37,7 +37,7 @@
 #include <thread>
 
 namespace rtc {
-	
+
 class IceTransport : public Transport {
 public:
 #if USE_JUICE
@@ -85,8 +85,6 @@ public:
 	bool send(message_ptr message) override; // false if dropped
 
 private:
-	void incoming(message_ptr message) override;
-	void incoming(const byte *data, int size);
 	bool outgoing(message_ptr message) override;
 
 	void changeState(State state);

+ 1 - 0
src/sctptransport.cpp

@@ -163,6 +163,7 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
 		throw std::runtime_error("Could not set SCTP send buffer size, errno=" +
 		                         std::to_string(errno));
 
+	registerIncoming();
 	connect();
 }
 

+ 8 - 6
src/transport.hpp

@@ -32,10 +32,7 @@ using namespace std::placeholders;
 
 class Transport {
 public:
-	Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {
-		if (mLower)
-			mLower->onRecv(std::bind(&Transport::incoming, this, _1));
-	}
+	Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {}
 	virtual ~Transport() { stop(); }
 
 	virtual bool stop() {
@@ -44,14 +41,19 @@ public:
 		return !mShutdown.exchange(true);
 	}
 
-	virtual bool send(message_ptr message) = 0;
+	void registerIncoming() {
+		if (mLower)
+			mLower->onRecv(std::bind(&Transport::incoming, this, _1));
+	}
 
 	void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
 
+	virtual bool send(message_ptr message) { return outgoing(message); }
+
 protected:
 	void recv(message_ptr message) { mRecvCallback(message); }
 
-	virtual void incoming(message_ptr message) = 0;
+	virtual void incoming(message_ptr message) { recv(message); }
 	virtual bool outgoing(message_ptr message) {
 		if (mLower)
 			return mLower->send(message);