Przeglądaj źródła

Fixed state callback and revised synchronization and deletion

Paul-Louis Ageneau 5 lat temu
rodzic
commit
e04113f3f1

+ 2 - 0
include/rtc/channel.hpp

@@ -60,6 +60,8 @@ protected:
 	virtual void triggerAvailable(size_t count);
 	virtual void triggerBufferedAmount(size_t amount);
 
+	void resetCallbacks();
+
 private:
 	synchronized_callback<> mOpenCallback;
 	synchronized_callback<> mClosedCallback;

+ 3 - 3
include/rtc/datachannel.hpp

@@ -40,7 +40,7 @@ class DataChannel : public std::enable_shared_from_this<DataChannel>, public Cha
 public:
 	DataChannel(std::weak_ptr<PeerConnection> pc, unsigned int stream, string label,
 	            string protocol, Reliability reliability);
-	DataChannel(std::weak_ptr<PeerConnection> pc, std::shared_ptr<SctpTransport> transport,
+	DataChannel(std::weak_ptr<PeerConnection> pc, std::weak_ptr<SctpTransport> transport,
 	            unsigned int stream);
 	~DataChannel();
 
@@ -65,13 +65,13 @@ public:
 
 private:
 	void remoteClose();
-	void open(std::shared_ptr<SctpTransport> sctpTransport);
+	void open(std::shared_ptr<SctpTransport> transport);
 	bool outgoing(mutable_message_ptr message);
 	void incoming(message_ptr message);
 	void processOpenMessage(message_ptr message);
 
 	const std::weak_ptr<PeerConnection> mPeerConnection;
-	std::shared_ptr<SctpTransport> mSctpTransport;
+	std::weak_ptr<SctpTransport> mSctpTransport;
 
 	unsigned int mStream;
 	string mLabel;

+ 7 - 5
include/rtc/peerconnection.hpp

@@ -52,14 +52,13 @@ public:
 		Connected = RTC_CONNECTED,
 		Disconnected = RTC_DISCONNECTED,
 		Failed = RTC_FAILED,
-		Closed = RTC_CLOSED,
-		Destroying = RTC_DESTROYING
+		Closed = RTC_CLOSED
 	};
 
 	enum class GatheringState : int {
 		New = RTC_GATHERING_NEW,
 		InProgress = RTC_GATHERING_INPROGRESS,
-		Complete = RTC_GATHERING_COMPLETE,
+		Complete = RTC_GATHERING_COMPLETE
 	};
 
 	PeerConnection(void);
@@ -94,6 +93,7 @@ private:
 	std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
 	std::shared_ptr<DtlsTransport> initDtlsTransport();
 	std::shared_ptr<SctpTransport> initSctpTransport();
+	void closeTransports();
 
 	void endLocalCandidates();
 	bool checkFingerprint(const std::string &fingerprint) const;
@@ -112,8 +112,10 @@ private:
 	void processLocalDescription(Description description);
 	void processLocalCandidate(Candidate candidate);
 	void triggerDataChannel(std::weak_ptr<DataChannel> weakDataChannel);
-	void changeState(State state);
-	void changeGatheringState(GatheringState state);
+	bool changeState(State state);
+	bool changeGatheringState(GatheringState state);
+
+	void resetCallbacks();
 
 	const Configuration mConfig;
 	const std::shared_ptr<Certificate> mCertificate;

+ 1 - 2
include/rtc/rtc.h

@@ -31,8 +31,7 @@ typedef enum {
 	RTC_CONNECTED = 2,
 	RTC_DISCONNECTED = 3,
 	RTC_FAILED = 4,
-	RTC_CLOSED = 5,
-	RTC_DESTROYING = 6 // internal
+	RTC_CLOSED = 5
 } rtcState;
 
 typedef enum {

+ 9 - 0
src/channel.cpp

@@ -88,5 +88,14 @@ void Channel::triggerBufferedAmount(size_t amount) {
 		mBufferedAmountLowCallback();
 }
 
+void Channel::resetCallbacks() {
+	mOpenCallback = nullptr;
+	mClosedCallback = nullptr;
+	mErrorCallback = nullptr;
+	mMessageCallback = nullptr;
+	mAvailableCallback = nullptr;
+	mBufferedAmountLowCallback = nullptr;
+}
+
 } // namespace rtc
 

+ 20 - 9
src/datachannel.cpp

@@ -74,7 +74,7 @@ DataChannel::DataChannel(weak_ptr<PeerConnection> pc, unsigned int stream, strin
       mReliability(std::make_shared<Reliability>(std::move(reliability))),
       mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
 
-DataChannel::DataChannel(weak_ptr<PeerConnection> pc, shared_ptr<SctpTransport> transport,
+DataChannel::DataChannel(weak_ptr<PeerConnection> pc, weak_ptr<SctpTransport> transport,
                          unsigned int stream)
     : mPeerConnection(pc), mSctpTransport(transport), mStream(stream),
       mReliability(std::make_shared<Reliability>()),
@@ -93,10 +93,13 @@ string DataChannel::protocol() const { return mProtocol; }
 Reliability DataChannel::reliability() const { return *mReliability; }
 
 void DataChannel::close() {
-	if (mIsOpen.exchange(false) && mSctpTransport)
-		mSctpTransport->reset(mStream);
+	if (mIsOpen.exchange(false))
+		if (auto transport = mSctpTransport.lock())
+			transport->reset(mStream);
 	mIsClosed = true;
 	mSctpTransport.reset();
+
+	resetCallbacks();
 }
 
 void DataChannel::remoteClose() {
@@ -158,8 +161,8 @@ size_t DataChannel::maxMessageSize() const {
 
 size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); }
 
-void DataChannel::open(shared_ptr<SctpTransport> sctpTransport) {
-	mSctpTransport = sctpTransport;
+void DataChannel::open(shared_ptr<SctpTransport> transport) {
+	mSctpTransport = transport;
 
 	uint8_t channelType = static_cast<uint8_t>(mReliability->type);
 	if (mReliability->unordered)
@@ -186,20 +189,24 @@ void DataChannel::open(shared_ptr<SctpTransport> sctpTransport) {
 	std::copy(mLabel.begin(), mLabel.end(), end);
 	std::copy(mProtocol.begin(), mProtocol.end(), end + mLabel.size());
 
-	mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
+	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
 }
 
 bool DataChannel::outgoing(mutable_message_ptr message) {
-	if (mIsClosed || !mSctpTransport)
+	if (mIsClosed)
 		throw std::runtime_error("DataChannel is closed");
 
 	if (message->size() > maxMessageSize())
 		throw std::runtime_error("Message size exceeds limit");
 
+	auto transport = mSctpTransport.lock();
+	if (!transport)
+		throw std::runtime_error("DataChannel has no transport");
+
 	// Before the ACK has been received on a DataChannel, all messages must be sent ordered
 	message->reliability = mIsOpen ? mReliability : nullptr;
 	message->stream = mStream;
-	return mSctpTransport->send(message);
+	return transport->send(message);
 }
 
 void DataChannel::incoming(message_ptr message) {
@@ -238,6 +245,10 @@ void DataChannel::incoming(message_ptr message) {
 }
 
 void DataChannel::processOpenMessage(message_ptr message) {
+	auto transport = mSctpTransport.lock();
+	if (!transport)
+		throw std::runtime_error("DataChannel has no transport");
+
 	if (message->size() < sizeof(OpenMessage))
 		throw std::invalid_argument("DataChannel open message too small");
 
@@ -274,7 +285,7 @@ void DataChannel::processOpenMessage(message_ptr message) {
 	auto &ack = *reinterpret_cast<AckMessage *>(buffer.data());
 	ack.type = MESSAGE_ACK;
 
-	mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
+	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
 
 	mIsOpen = true;
 	triggerOpen();

+ 71 - 28
src/peerconnection.cpp

@@ -24,6 +24,7 @@
 #include "sctptransport.hpp"
 
 #include <iostream>
+#include <thread>
 
 namespace rtc {
 
@@ -38,28 +39,11 @@ PeerConnection::PeerConnection() : PeerConnection(Configuration()) {
 PeerConnection::PeerConnection(const Configuration &config)
     : mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {}
 
-PeerConnection::~PeerConnection() {
-	changeState(State::Destroying);
-	close();
-	mSctpTransport.reset();
-	mDtlsTransport.reset();
-	mIceTransport.reset();
-}
+PeerConnection::~PeerConnection() { close(); }
 
 void PeerConnection::close() {
-	// Close DataChannels
 	closeDataChannels();
-
-	// Close Transports
-	for (int i = 0; i < 2; ++i) { // Make sure a transport wasn't spawn behind our back
-		if (auto transport = std::atomic_load(&mSctpTransport))
-			transport->stop();
-		if (auto transport = std::atomic_load(&mDtlsTransport))
-			transport->stop();
-		if (auto transport = std::atomic_load(&mIceTransport))
-			transport->stop();
-	}
-	changeState(State::Closed);
+	closeTransports();
 }
 
 const Configuration *PeerConnection::config() const { return &mConfig; }
@@ -241,8 +225,15 @@ shared_ptr<IceTransport> PeerConnection::initIceTransport(Description::Role role
 				    break;
 			    }
 		    });
+
 		std::atomic_store(&mIceTransport, transport);
+		if (mState == State::Closed) {
+			mIceTransport.reset();
+			transport->stop();
+			throw std::runtime_error("Connection is closed");
+		}
 		return transport;
+
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 		changeState(State::Failed);
@@ -274,8 +265,15 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 				    break;
 			    }
 		    });
+
 		std::atomic_store(&mDtlsTransport, transport);
+		if (mState == State::Closed) {
+			mDtlsTransport.reset();
+			transport->stop();
+			throw std::runtime_error("Connection is closed");
+		}
 		return transport;
+
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 		changeState(State::Failed);
@@ -312,8 +310,15 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 				    break;
 			    }
 		    });
+
 		std::atomic_store(&mSctpTransport, transport);
+		if (mState == State::Closed) {
+			mSctpTransport.reset();
+			transport->stop();
+			throw std::runtime_error("Connection is closed");
+		}
 		return transport;
+
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 		changeState(State::Failed);
@@ -321,6 +326,34 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 	}
 }
 
+void PeerConnection::closeTransports() {
+	// Change state to sink state Closed to block init methods
+	changeState(State::Closed);
+
+	// Reset callbacks now that state is changed
+	resetCallbacks();
+
+	// Pass the references to a thread, allowing to terminate a transport from its own thread
+	auto sctp = std::atomic_exchange(&mSctpTransport, decltype(mSctpTransport)(nullptr));
+	auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr));
+	auto ice = std::atomic_exchange(&mIceTransport, decltype(mIceTransport)(nullptr));
+	if (sctp || dtls || ice) {
+		std::thread t([sctp, dtls, ice]() mutable {
+			if (sctp)
+				sctp->stop();
+			if (dtls)
+				dtls->stop();
+			if (ice)
+				ice->stop();
+
+			sctp.reset();
+			dtls.reset();
+			ice.reset();
+		});
+		t.detach();
+	}
+}
+
 void PeerConnection::endLocalCandidates() {
 	std::lock_guard lock(mLocalDescriptionMutex);
 	if (mLocalDescription)
@@ -467,21 +500,34 @@ void PeerConnection::triggerDataChannel(weak_ptr<DataChannel> weakDataChannel) {
 	mDataChannelCallback(dataChannel);
 }
 
-void PeerConnection::changeState(State state) {
+bool PeerConnection::changeState(State state) {
 	State current;
 	do {
 		current = mState.load();
-		if (current == state || current == State::Destroying)
-			return;
+		if (current == state)
+			return true;
+		if (current == State::Closed)
+			return false;
+
 	} while (!mState.compare_exchange_weak(current, state));
 
-	if (state != State::Destroying)
-		mStateChangeCallback(state);
+	mStateChangeCallback(state);
+	return true;
 }
 
-void PeerConnection::changeGatheringState(GatheringState state) {
+bool PeerConnection::changeGatheringState(GatheringState state) {
 	if (mGatheringState.exchange(state) != state)
 		mGatheringStateChangeCallback(state);
+	return true;
+}
+
+void PeerConnection::resetCallbacks() {
+	// Unregister all callbacks
+	mDataChannelCallback = nullptr;
+	mLocalDescriptionCallback = nullptr;
+	mLocalCandidateCallback = nullptr;
+	mStateChangeCallback = nullptr;
+	mGatheringStateChangeCallback = nullptr;
 }
 
 } // namespace rtc
@@ -508,9 +554,6 @@ std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::State &st
 	case State::Closed:
 		str = "closed";
 		break;
-	case State::Destroying:
-		str = "destroying";
-		break;
 	default:
 		str = "unknown";
 		break;

+ 9 - 6
src/rtc.cpp

@@ -53,6 +53,14 @@ void *getUserPointer(int id) {
 	return it != userPointerMap.end() ? it->second : nullptr;
 }
 
+void setUserPointer(int i, void *ptr) {
+	std::lock_guard lock(mutex);
+	if (ptr)
+		userPointerMap.insert(std::make_pair(i, ptr));
+	else
+		userPointerMap.erase(i);
+}
+
 shared_ptr<PeerConnection> getPeerConnection(int id) {
 	std::lock_guard lock(mutex);
 	auto it = peerConnectionMap.find(id);
@@ -99,12 +107,7 @@ bool eraseDataChannel(int dc) {
 
 void rtcInitLogger(rtcLogLevel level) { InitLogger(static_cast<LogLevel>(level)); }
 
-void rtcSetUserPointer(int i, void *ptr) {
-	if (ptr)
-		userPointerMap.insert(std::make_pair(i, ptr));
-	else
-		userPointerMap.erase(i);
-}
+void rtcSetUserPointer(int i, void *ptr) { setUserPointer(i, ptr); }
 
 int rtcCreatePeerConnection(const rtcConfiguration *config) {
 	Configuration c;

+ 16 - 7
test/capi.cpp

@@ -157,6 +157,16 @@ int test_capi_main() {
 
 	sleep(3);
 
+	if (peer1->state != RTC_CONNECTED || peer2->state != RTC_CONNECTED) {
+		fprintf(stderr, "PeerConnection is not connected\n");
+		goto error;
+	}
+
+	if (!peer1->connected || !peer2->connected) {
+		fprintf(stderr, "DataChannel is not connected\n");
+		goto error;
+	}
+
 	char buffer[256];
 	if (rtcGetLocalAddress(peer1->pc, buffer, 256) >= 0)
 		printf("Local address 1:  %s\n", buffer);
@@ -167,13 +177,12 @@ int test_capi_main() {
 	if (rtcGetRemoteAddress(peer2->pc, buffer, 256) >= 0)
 		printf("Remote address 2: %s\n", buffer);
 
-	if (peer1->connected && peer2->connected) {
-		deletePeer(peer1);
-		deletePeer(peer2);
-		sleep(1);
-		printf("Success\n");
-		return 0;
-	}
+	deletePeer(peer1);
+	sleep(1);
+	deletePeer(peer2);
+
+	printf("Success\n");
+	return 0;
 
 error:
 	deletePeer(peer1);

+ 9 - 5
test/connectivity.cpp

@@ -108,6 +108,13 @@ void test_connectivity() {
 
 	this_thread::sleep_for(3s);
 
+	if (pc1->state() != PeerConnection::State::Connected &&
+	    pc2->state() != PeerConnection::State::Connected)
+		throw runtime_error("PeerConnection is not connected");
+
+	if (!dc1->isOpen() || !dc2->isOpen())
+		throw runtime_error("DataChannel is not open");
+
 	if (auto addr = pc1->localAddress())
 		cout << "Local address 1:  " << *addr << endl;
 	if (auto addr = pc1->remoteAddress())
@@ -117,13 +124,10 @@ void test_connectivity() {
 	if (auto addr = pc2->remoteAddress())
 		cout << "Remote address 2: " << *addr << endl;
 
-	if (!dc1->isOpen() || !dc2->isOpen())
-		throw runtime_error("DataChannel is not open");
-
+	// Delay close of peer 2 to check closing works properly
 	pc1->close();
-	pc2->close();
-
 	this_thread::sleep_for(1s);
+	pc2->close();
 
 	cout << "Success" << endl;
 }