2
0
Эх сурвалжийг харах

Added separate PeerConnection gathering state

Paul-Louis Ageneau 5 жил өмнө
parent
commit
cd47c31f3f

+ 14 - 4
include/rtc/peerconnection.hpp

@@ -43,8 +43,6 @@ class PeerConnection {
 public:
 	enum class State : int {
 		New = RTC_NEW,
-		Gathering = RTC_GATHERING,
-		Finished = RTC_FINISHED,
 		Connecting = RTC_CONNECTING,
 		Connected = RTC_CONNECTED,
 		Disconnected = RTC_DISCONNECTED,
@@ -52,12 +50,19 @@ public:
 		Closed = RTC_CLOSED
 	};
 
+	enum class GatheringState : int {
+		New = RTC_GATHERING_NEW,
+		InProgress = RTC_GATHERING_INPROGRESS,
+		Complete = RTC_GATHERING_COMPLETE,
+	};
+
 	PeerConnection(void);
 	PeerConnection(const Configuration &config);
 	~PeerConnection();
 
 	const Configuration *config() const;
 	State state() const;
+	GatheringState gatheringState() const;
 	std::optional<Description> localDescription() const;
 	std::optional<Description> remoteDescription() const;
 
@@ -70,7 +75,8 @@ public:
 	void onDataChannel(std::function<void(std::shared_ptr<DataChannel> dataChannel)> callback);
 	void onLocalDescription(std::function<void(const Description &description)> callback);
 	void onLocalCandidate(std::function<void(const Candidate &candidate)> callback);
-	void onStateChanged(std::function<void(State state)> callback);
+	void onStateChange(std::function<void(State state)> callback);
+	void onGatheringStateChange(std::function<void(GatheringState state)> callback);
 
 private:
 	void initIceTransport(Description::Role role);
@@ -87,6 +93,7 @@ private:
 	void processLocalCandidate(Candidate candidate);
 	void triggerDataChannel(std::shared_ptr<DataChannel> dataChannel);
 	void changeState(State state);
+	void changeGatheringState(GatheringState state);
 
 	const Configuration mConfig;
 	const std::shared_ptr<Certificate> mCertificate;
@@ -101,15 +108,18 @@ private:
 	std::unordered_map<unsigned int, std::weak_ptr<DataChannel>> mDataChannels;
 
 	std::atomic<State> mState;
+	std::atomic<GatheringState> mGatheringState;
 
 	std::function<void(std::shared_ptr<DataChannel> dataChannel)> mDataChannelCallback;
 	std::function<void(const Description &description)> mLocalDescriptionCallback;
 	std::function<void(const Candidate &candidate)> mLocalCandidateCallback;
-	std::function<void(State state)> mStateChangedCallback;
+	std::function<void(State state)> mStateChangeCallback;
+	std::function<void(GatheringState state)> mGatheringStateChangeCallback;
 };
 
 } // namespace rtc
 
 std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::State &state);
+std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::GatheringState &state);
 
 #endif

+ 15 - 8
include/rtc/rtc.h

@@ -27,15 +27,19 @@ extern "C" {
 
 typedef enum {
 	RTC_NEW = 0,
-	RTC_GATHERING = 1,
-	RTC_FINISHED = 2,
-	RTC_CONNECTING = 3,
-	RTC_CONNECTED = 4,
-	RTC_DISCONNECTED = 5,
-	RTC_FAILED = 6,
-	RTC_CLOSED = 7
+	RTC_CONNECTING = 1,
+	RTC_CONNECTED = 2,
+	RTC_DISCONNECTED = 3,
+	RTC_FAILED = 4,
+	RTC_CLOSED = 5
 } rtc_state_t;
 
+typedef enum {
+	RTC_GATHERING_NEW = 0,
+	RTC_GATHERING_INPROGRESS = 1,
+	RTC_GATHERING_COMPLETE = 2
+} rtc_gathering_state_t;
+
 int rtcCreatePeerConnection(const char **iceServers, int iceServersCount);
 void rtcDeletePeerConnection(int pc);
 int rtcCreateDataChannel(int pc, const char *label);
@@ -45,7 +49,10 @@ void rtcSetLocalDescriptionCallback(int pc, void (*descriptionCallback)(const ch
                                                                         void *));
 void rtcSetLocalCandidateCallback(int pc,
                                   void (*candidateCallback)(const char *, const char *, void *));
-void rtcSetStateChangedCallback(int pc, void (*stateCallback)(rtc_state_t state, void *));
+void rtcSetStateChangeCallback(int pc, void (*stateCallback)(rtc_state_t state, void *));
+void rtcSetGatheringStateChangeCallback(int pc,
+                                         void (*gatheringStateCallback)(rtc_gathering_state_t state,
+                                                                        void *));
 void rtcSetRemoteDescription(int pc, const char *sdp, const char *type);
 void rtcAddRemoteCandidate(int pc, const char *candidate, const char *mid);
 int rtcGetDataChannelLabel(int dc, char *data, int size);

+ 3 - 3
src/dtlstransport.cpp

@@ -48,10 +48,10 @@ using std::shared_ptr;
 
 DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
                              verifier_callback verifierCallback,
-                             state_callback stateChangedCallback)
+                             state_callback stateChangeCallback)
     : Transport(lower), mCertificate(certificate), mState(State::Disconnected),
       mVerifierCallback(std::move(verifierCallback)),
-      mStateChangedCallback(std::move(stateChangedCallback)) {
+      mStateChangeCallback(std::move(stateChangeCallback)) {
 	gnutls_certificate_set_verify_function(mCertificate->credentials(), CertificateCallback);
 
 	bool active = lower->role() == Description::Role::Active;
@@ -102,7 +102,7 @@ void DtlsTransport::incoming(message_ptr message) { mIncomingQueue.push(message)
 
 void DtlsTransport::changeState(State state) {
 	mState = state;
-	mStateChangedCallback(state);
+	mStateChangeCallback(state);
 }
 
 void DtlsTransport::runRecvLoop() {

+ 2 - 2
src/dtlstransport.hpp

@@ -44,7 +44,7 @@ public:
 	using state_callback = std::function<void(State state)>;
 
 	DtlsTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate,
-	              verifier_callback verifierCallback, state_callback stateChangedCallback);
+	              verifier_callback verifierCallback, state_callback stateChangeCallback);
 	~DtlsTransport();
 
 	State state() const;
@@ -64,7 +64,7 @@ private:
 	std::thread mRecvThread;
 
 	verifier_callback mVerifierCallback;
-	state_callback mStateChangedCallback;
+	state_callback mStateChangeCallback;
 
 	static int CertificateCallback(gnutls_session_t session);
 	static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);

+ 29 - 16
src/icetransport.cpp

@@ -35,10 +35,13 @@ using std::weak_ptr;
 
 IceTransport::IceTransport(const Configuration &config, Description::Role role,
                            candidate_callback candidateCallback,
-                           state_callback stateChangedCallback)
-    : mRole(role), mMid("0"), mState(State::Disconnected), mNiceAgent(nullptr, nullptr),
-      mMainLoop(nullptr, nullptr), mCandidateCallback(std::move(candidateCallback)),
-      mStateChangedCallback(std::move(stateChangedCallback)) {
+                           state_callback stateChangeCallback,
+                           gathering_state_callback gatheringStateChangeCallback)
+    : mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New),
+      mNiceAgent(nullptr, nullptr), mMainLoop(nullptr, nullptr),
+      mCandidateCallback(std::move(candidateCallback)),
+      mStateChangeCallback(std::move(stateChangeCallback)),
+      mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)) {
 
 	auto logLevelFlags = GLogLevelFlags(G_LOG_LEVEL_MASK | G_LOG_FLAG_FATAL | G_LOG_FLAG_RECURSION);
 	g_log_set_handler(nullptr, logLevelFlags, LogCallback, this);
@@ -103,7 +106,7 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
 	}
 
 	g_signal_connect(G_OBJECT(mNiceAgent.get()), "component-state-changed",
-	                 G_CALLBACK(StateChangedCallback), this);
+	                 G_CALLBACK(StateChangeCallback), this);
 	g_signal_connect(G_OBJECT(mNiceAgent.get()), "new-candidate-full",
 	                 G_CALLBACK(CandidateCallback), this);
 	g_signal_connect(G_OBJECT(mNiceAgent.get()), "candidate-gathering-done",
@@ -147,8 +150,12 @@ void IceTransport::setRemoteDescription(const Description &description) {
 }
 
 void IceTransport::gatherLocalCandidates() {
-	if (!nice_agent_gather_candidates(mNiceAgent.get(), mStreamId))
+	// Change state now as candidates calls can be synchronous
+	changeGatheringState(GatheringState::InProgress);
+
+	if (!nice_agent_gather_candidates(mNiceAgent.get(), mStreamId)) {
 		throw std::runtime_error("Failed to gather local ICE candidates");
+	}
 }
 
 bool IceTransport::addRemoteCandidate(const Candidate &candidate) {
@@ -186,19 +193,25 @@ void IceTransport::outgoing(message_ptr message) {
 	                reinterpret_cast<const char *>(message->data()));
 }
 
+void IceTransport::changeState(State state) {
+	mState = state;
+	mStateChangeCallback(mState);
+}
+
+void IceTransport::changeGatheringState(GatheringState state) {
+	mGatheringState = state;
+	mGatheringStateChangeCallback(mGatheringState);
+}
+
 void IceTransport::processCandidate(const string &candidate) {
 	mCandidateCallback(Candidate(candidate, mMid));
 }
 
-void IceTransport::processGatheringDone() {
-	if (mState == State::Gathering) {
-		mState = State::Finished;
-	}
-}
+void IceTransport::processGatheringDone() { changeGatheringState(GatheringState::Complete); }
 
-void IceTransport::changeState(uint32_t state) {
-	mState = static_cast<State>(state);
-	mStateChangedCallback(mState);
+void IceTransport::processStateChange(uint32_t state) {
+	if (state != NICE_COMPONENT_STATE_GATHERING)
+		changeState(static_cast<State>(state));
 }
 
 void IceTransport::CandidateCallback(NiceAgent *agent, NiceCandidate *candidate,
@@ -222,11 +235,11 @@ void IceTransport::GatheringDoneCallback(NiceAgent *agent, guint streamId, gpoin
 	}
 }
 
-void IceTransport::StateChangedCallback(NiceAgent *agent, guint streamId, guint componentId,
+void IceTransport::StateChangeCallback(NiceAgent *agent, guint streamId, guint componentId,
                                         guint state, gpointer userData) {
 	auto iceTransport = static_cast<rtc::IceTransport *>(userData);
 	try {
-		iceTransport->changeState(state);
+		iceTransport->processStateChange(state);
 	} catch (const std::exception &e) {
 		std::cerr << "ICE change state: " << e.what() << std::endl;
 	}

+ 14 - 6
src/icetransport.hpp

@@ -39,23 +39,26 @@ class IceTransport : public Transport {
 public:
 	enum class State : uint32_t {
 		Disconnected = NICE_COMPONENT_STATE_DISCONNECTED,
-		Gathering = NICE_COMPONENT_STATE_GATHERING,
-		Finished = static_cast<uint32_t>(NICE_COMPONENT_STATE_LAST) + 1,
 		Connecting = NICE_COMPONENT_STATE_CONNECTING,
 		Connected = NICE_COMPONENT_STATE_CONNECTED,
 		Ready = NICE_COMPONENT_STATE_READY,
 		Failed = NICE_COMPONENT_STATE_FAILED
 	};
 
+	enum class GatheringState { New = 0, InProgress = 1, Complete = 2 };
+
 	using candidate_callback = std::function<void(const Candidate &candidate)>;
 	using state_callback = std::function<void(State state)>;
+	using gathering_state_callback = std::function<void(GatheringState state)>;
 
 	IceTransport(const Configuration &config, Description::Role role,
-	             candidate_callback candidateCallback, state_callback stateChangedCallback);
+	             candidate_callback candidateCallback, state_callback stateChangeCallback,
+	             gathering_state_callback gatheringStateChangeCallback);
 	~IceTransport();
 
 	Description::Role role() const;
 	State state() const;
+	GatheringState gyyatheringState() const;
 	Description getLocalDescription(Description::Type type) const;
 	void setRemoteDescription(const Description &description);
 	void gatherLocalCandidates();
@@ -68,13 +71,17 @@ private:
 	void incoming(const byte *data, int size);
 	void outgoing(message_ptr message);
 
-	void changeState(uint32_t state);
+	void changeState(State state);
+	void changeGatheringState(GatheringState state);
+
 	void processCandidate(const string &candidate);
 	void processGatheringDone();
+	void processStateChange(uint32_t state);
 
 	Description::Role mRole;
 	string mMid;
 	std::atomic<State> mState;
+	std::atomic<GatheringState> mGatheringState;
 
 	uint32_t mStreamId = 0;
 	std::unique_ptr<NiceAgent, void (*)(gpointer)> mNiceAgent;
@@ -82,11 +89,12 @@ private:
 	std::thread mMainLoopThread;
 
 	candidate_callback mCandidateCallback;
-	state_callback mStateChangedCallback;
+	state_callback mStateChangeCallback;
+	gathering_state_callback mGatheringStateChangeCallback;
 
 	static void CandidateCallback(NiceAgent *agent, NiceCandidate *candidate, gpointer userData);
 	static void GatheringDoneCallback(NiceAgent *agent, guint streamId, gpointer userData);
-	static void StateChangedCallback(NiceAgent *agent, guint streamId, guint componentId,
+	static void StateChangeCallback(NiceAgent *agent, guint streamId, guint componentId,
 	                                 guint state, gpointer userData);
 	static void RecvCallback(NiceAgent *agent, guint stream_id, guint component_id, guint len,
 	                         gchar *buf, gpointer userData);

+ 51 - 18
src/peerconnection.cpp

@@ -42,6 +42,8 @@ const Configuration *PeerConnection::config() const { return &mConfig; }
 
 PeerConnection::State PeerConnection::state() const { return mState; }
 
+PeerConnection::GatheringState PeerConnection::gatheringState() const { return mGatheringState; }
+
 std::optional<Description> PeerConnection::localDescription() const { return mLocalDescription; }
 
 std::optional<Description> PeerConnection::remoteDescription() const { return mRemoteDescription; }
@@ -108,8 +110,12 @@ void PeerConnection::onLocalCandidate(std::function<void(const Candidate &candid
 	mLocalCandidateCallback = callback;
 }
 
-void PeerConnection::onStateChanged(std::function<void(State state)> callback) {
-	mStateChangedCallback = callback;
+void PeerConnection::onStateChange(std::function<void(State state)> callback) {
+	mStateChangeCallback = callback;
+}
+
+void PeerConnection::onGatheringStateChange(std::function<void(GatheringState state)> callback) {
+	mGatheringStateChangeCallback = callback;
 }
 
 void PeerConnection::initIceTransport(Description::Role role) {
@@ -117,14 +123,6 @@ void PeerConnection::initIceTransport(Description::Role role) {
 	    mConfig, role, std::bind(&PeerConnection::processLocalCandidate, this, _1),
 	    [this](IceTransport::State state) {
 		    switch (state) {
-		    case IceTransport::State::Gathering:
-			    changeState(State::Gathering);
-			    break;
-		    case IceTransport::State::Finished:
-			    if (mLocalDescription)
-				    mLocalDescription->endCandidates();
-			    changeState(State::Finished);
-			    break;
 		    case IceTransport::State::Connecting:
 			    changeState(State::Connecting);
 			    break;
@@ -138,6 +136,21 @@ void PeerConnection::initIceTransport(Description::Role role) {
 			    // Ignore
 			    break;
 		    }
+	    },
+	    [this](IceTransport::GatheringState state) {
+		    switch (state) {
+		    case IceTransport::GatheringState::InProgress:
+			    changeGatheringState(GatheringState::InProgress);
+			    break;
+		    case IceTransport::GatheringState::Complete:
+			    if (mLocalDescription)
+				    mLocalDescription->endCandidates();
+			    changeGatheringState(GatheringState::Complete);
+			    break;
+		    default:
+			    // Ignore
+			    break;
+		    }
 	    });
 }
 
@@ -276,8 +289,14 @@ void PeerConnection::triggerDataChannel(std::shared_ptr<DataChannel> dataChannel
 
 void PeerConnection::changeState(State state) {
 	mState = state;
-	if (mStateChangedCallback)
-		mStateChangedCallback(state);
+	if (mStateChangeCallback)
+		mStateChangeCallback(state);
+}
+
+void PeerConnection::changeGatheringState(GatheringState state) {
+	mGatheringState = state;
+	if (mGatheringStateChangeCallback)
+		mGatheringStateChangeCallback(state);
 }
 
 } // namespace rtc
@@ -289,12 +308,6 @@ std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::State &st
 	case State::New:
 		str = "new";
 		break;
-	case State::Gathering:
-		str = "gathering";
-		break;
-	case State::Finished:
-		str = "finished";
-		break;
 	case State::Connecting:
 		str = "connecting";
 		break;
@@ -314,3 +327,23 @@ std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::State &st
 	return out << str;
 }
 
+std::ostream &operator<<(std::ostream &out, const rtc::PeerConnection::GatheringState &state) {
+	using GatheringState = rtc::PeerConnection::GatheringState;
+	std::string str;
+	switch (state) {
+	case GatheringState::New:
+		str = "new";
+		break;
+	case GatheringState::InProgress:
+		str = "in_progress";
+		break;
+	case GatheringState::Complete:
+		str = "complete";
+		break;
+	default:
+		str = "unknown";
+		break;
+	}
+	return out << str;
+}
+

+ 15 - 2
src/rtc.cpp

@@ -101,16 +101,29 @@ void rtcSetLocalCandidateCallback(int pc,
 	});
 }
 
-void rtcSetStateChangedCallback(int pc, void (*stateCallback)(rtc_state_t state, void *)) {
+void rtcSetStateChangeCallback(int pc, void (*stateCallback)(rtc_state_t state, void *)) {
 	auto it = peerConnectionMap.find(pc);
 	if (it == peerConnectionMap.end())
 		return;
 
-	it->second->onStateChanged([pc, stateCallback](PeerConnection::State state) {
+	it->second->onStateChange([pc, stateCallback](PeerConnection::State state) {
 		stateCallback(static_cast<rtc_state_t>(state), getUserPointer(pc));
 	});
 }
 
+void rtcSetGatheringStateChangeCallback(int pc,
+                                         void (*gatheringStateCallback)(rtc_gathering_state_t state,
+                                                                        void *)) {
+	auto it = peerConnectionMap.find(pc);
+	if (it == peerConnectionMap.end())
+		return;
+
+	it->second->onGatheringStateChange(
+	    [pc, gatheringStateCallback](PeerConnection::GatheringState state) {
+		    gatheringStateCallback(static_cast<rtc_gathering_state_t>(state), getUserPointer(pc));
+	    });
+}
+
 void rtcSetRemoteDescription(int pc, const char *sdp, const char *type) {
 	auto it = peerConnectionMap.find(pc);
 	if (it == peerConnectionMap.end())

+ 3 - 3
src/sctptransport.cpp

@@ -48,9 +48,9 @@ void SctpTransport::GlobalCleanup() {
 }
 
 SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recv,
-                             state_callback stateChangedCallback)
+                             state_callback stateChangeCallback)
     : Transport(lower), mPort(port), mState(State::Disconnected),
-      mStateChangedCallback(std::move(stateChangedCallback)) {
+      mStateChangeCallback(std::move(stateChangeCallback)) {
 
 	onRecv(recv);
 
@@ -227,7 +227,7 @@ void SctpTransport::incoming(message_ptr message) {
 
 void SctpTransport::changeState(State state) {
 	mState = state;
-	mStateChangedCallback(state);
+	mStateChangeCallback(state);
 }
 
 void SctpTransport::runConnect() {

+ 2 - 2
src/sctptransport.hpp

@@ -41,7 +41,7 @@ public:
 	using state_callback = std::function<void(State state)>;
 
 	SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recv,
-	              state_callback stateChangedCallback);
+	              state_callback stateChangeCallback);
 	~SctpTransport();
 
 	State state() const;
@@ -81,7 +81,7 @@ private:
 
 	std::atomic<State> mState;
 
-	state_callback mStateChangedCallback;
+	state_callback mStateChangeCallback;
 
 	static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df);
 	static int ReadCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,

+ 9 - 3
test/main.cpp

@@ -28,7 +28,7 @@ using namespace std;
 
 int main(int argc, char **argv) {
 	rtc::Configuration config;
-	config.iceServers.emplace_back("stun.l.google.com:19302");
+	// config.iceServers.emplace_back("stun.l.google.com:19302");
 
 	auto pc1 = std::make_shared<PeerConnection>(config);
 	auto pc2 = std::make_shared<PeerConnection>(config);
@@ -43,7 +43,10 @@ int main(int argc, char **argv) {
 		pc2->addRemoteCandidate(candidate);
 	});
 
-	pc1->onStateChanged([](PeerConnection::State state) { cout << "State 1: " << state << endl; });
+	pc1->onStateChange([](PeerConnection::State state) { cout << "State 1: " << state << endl; });
+	pc1->onGatheringStateChange([](PeerConnection::GatheringState state) {
+		cout << "Gathering state 1: " << state << endl;
+	});
 
 	pc2->onLocalDescription([pc1](const Description &sdp) {
 		cout << "Description 2: " << sdp << endl;
@@ -55,7 +58,10 @@ int main(int argc, char **argv) {
 		pc1->addRemoteCandidate(candidate);
 	});
 
-	pc2->onStateChanged([](PeerConnection::State state) { cout << "State 2: " << state << endl; });
+	pc2->onStateChange([](PeerConnection::State state) { cout << "State 2: " << state << endl; });
+	pc2->onGatheringStateChange([](PeerConnection::GatheringState state) {
+		cout << "Gathering state 2: " << state << endl;
+	});
 
 	shared_ptr<DataChannel> dc2;
 	pc2->onDataChannel([&dc2](shared_ptr<DataChannel> dc) {