Selaa lähdekoodia

Merge pull request #105453 from reduz/signals-thread-safe

Add thread safety to Object signals
Thaddeus Crews 4 kuukautta sitten
vanhempi
commit
c5c1cd4440
4 muutettua tiedostoa jossa 124 lisäystä ja 56 poistoa
  1. 117 55
      core/object/object.cpp
  2. 4 1
      core/object/object.h
  3. 1 0
      doc/classes/Object.xml
  4. 2 0
      scene/main/node.h

+ 117 - 55
core/object/object.cpp

@@ -65,6 +65,23 @@ struct _ObjectDebugLock {
 
 #endif
 
+struct _ObjectSignalLock {
+	Mutex *mutex;
+	_ObjectSignalLock(const Object *const p_obj) {
+		mutex = p_obj->signal_mutex;
+		if (mutex) {
+			mutex->lock();
+		}
+	}
+	~_ObjectSignalLock() {
+		if (mutex) {
+			mutex->unlock();
+		}
+	}
+};
+
+#define OBJ_SIGNAL_LOCK _ObjectSignalLock _signal_lock(this);
+
 PropertyInfo::operator Dictionary() const {
 	Dictionary d;
 	d["name"] = name;
@@ -255,6 +272,9 @@ void Object::_initialize() {
 }
 
 void Object::_postinitialize() {
+	if (_uses_signal_mutex()) {
+		signal_mutex = memnew(Mutex);
+	}
 	notification(NOTIFICATION_POSTINITIALIZE);
 }
 
@@ -1123,6 +1143,9 @@ void Object::get_meta_list(List<StringName> *p_list) const {
 void Object::add_user_signal(const MethodInfo &p_signal) {
 	ERR_FAIL_COND_MSG(p_signal.name.is_empty(), "Signal name cannot be empty.");
 	ERR_FAIL_COND_MSG(ClassDB::has_signal(get_class_name(), p_signal.name), vformat("User signal's name conflicts with a built-in signal of '%s'.", get_class_name()));
+
+	OBJ_SIGNAL_LOCK
+
 	ERR_FAIL_COND_MSG(signal_map.has(p_signal.name), vformat("Trying to add already existing signal '%s'.", p_signal.name));
 	SignalData s;
 	s.user = p_signal;
@@ -1130,6 +1153,8 @@ void Object::add_user_signal(const MethodInfo &p_signal) {
 }
 
 bool Object::_has_user_signal(const StringName &p_name) const {
+	OBJ_SIGNAL_LOCK
+
 	if (!signal_map.has(p_name)) {
 		return false;
 	}
@@ -1137,6 +1162,8 @@ bool Object::_has_user_signal(const StringName &p_name) const {
 }
 
 void Object::_remove_user_signal(const StringName &p_name) {
+	OBJ_SIGNAL_LOCK
+
 	SignalData *s = signal_map.getptr(p_name);
 	ERR_FAIL_NULL_MSG(s, "Provided signal does not exist.");
 	ERR_FAIL_COND_MSG(!s->removable, "Signal is not removable (not added with add_user_signal).");
@@ -1183,46 +1210,53 @@ Error Object::emit_signalp(const StringName &p_name, const Variant **p_args, int
 		return ERR_CANT_ACQUIRE_RESOURCE; //no emit, signals blocked
 	}
 
-	SignalData *s = signal_map.getptr(p_name);
-	if (!s) {
+	Callable *slot_callables = nullptr;
+	uint32_t *slot_flags = nullptr;
+	uint32_t slot_count = 0;
+
+	{
+		OBJ_SIGNAL_LOCK
+
+		SignalData *s = signal_map.getptr(p_name);
+		if (!s) {
 #ifdef DEBUG_ENABLED
-		bool signal_is_valid = ClassDB::has_signal(get_class_name(), p_name);
-		//check in script
-		ERR_FAIL_COND_V_MSG(!signal_is_valid && !script.is_null() && !Ref<Script>(script)->has_script_signal(p_name), ERR_UNAVAILABLE, vformat("Can't emit non-existing signal \"%s\".", p_name));
+			bool signal_is_valid = ClassDB::has_signal(get_class_name(), p_name);
+			//check in script
+			ERR_FAIL_COND_V_MSG(!signal_is_valid && !script.is_null() && !Ref<Script>(script)->has_script_signal(p_name), ERR_UNAVAILABLE, vformat("Can't emit non-existing signal \"%s\".", p_name));
 #endif
-		//not connected? just return
-		return ERR_UNAVAILABLE;
-	}
+			//not connected? just return
+			return ERR_UNAVAILABLE;
+		}
 
-	// If this is a ref-counted object, prevent it from being destroyed during signal emission,
-	// which is needed in certain edge cases; e.g., https://github.com/godotengine/godot/issues/73889.
-	Ref<RefCounted> rc = Ref<RefCounted>(Object::cast_to<RefCounted>(this));
+		// If this is a ref-counted object, prevent it from being destroyed during signal emission,
+		// which is needed in certain edge cases; e.g., https://github.com/godotengine/godot/issues/73889.
+		Ref<RefCounted> rc = Ref<RefCounted>(Object::cast_to<RefCounted>(this));
 
-	// Ensure that disconnecting the signal or even deleting the object
-	// will not affect the signal calling.
-	Callable *slot_callables = (Callable *)alloca(sizeof(Callable) * s->slot_map.size());
-	uint32_t *slot_flags = (uint32_t *)alloca(sizeof(uint32_t) * s->slot_map.size());
-	uint32_t slot_count = 0;
+		// Ensure that disconnecting the signal or even deleting the object
+		// will not affect the signal calling.
+		slot_callables = (Callable *)alloca(sizeof(Callable) * s->slot_map.size());
+		slot_flags = (uint32_t *)alloca(sizeof(uint32_t) * s->slot_map.size());
 
-	for (const KeyValue<Callable, SignalData::Slot> &slot_kv : s->slot_map) {
-		memnew_placement(&slot_callables[slot_count], Callable(slot_kv.value.conn.callable));
-		slot_flags[slot_count] = slot_kv.value.conn.flags;
-		++slot_count;
-	}
+		for (const KeyValue<Callable, SignalData::Slot> &slot_kv : s->slot_map) {
+			memnew_placement(&slot_callables[slot_count], Callable(slot_kv.value.conn.callable));
+			slot_flags[slot_count] = slot_kv.value.conn.flags;
+			++slot_count;
+		}
 
-	DEV_ASSERT(slot_count == s->slot_map.size());
+		DEV_ASSERT(slot_count == s->slot_map.size());
 
-	// Disconnect all one-shot connections before emitting to prevent recursion.
-	for (uint32_t i = 0; i < slot_count; ++i) {
-		bool disconnect = slot_flags[i] & CONNECT_ONE_SHOT;
+		// Disconnect all one-shot connections before emitting to prevent recursion.
+		for (uint32_t i = 0; i < slot_count; ++i) {
+			bool disconnect = slot_flags[i] & CONNECT_ONE_SHOT;
 #ifdef TOOLS_ENABLED
-		if (disconnect && (slot_flags[i] & CONNECT_PERSIST) && Engine::get_singleton()->is_editor_hint()) {
-			// This signal was connected from the editor, and is being edited. Just don't disconnect for now.
-			disconnect = false;
-		}
+			if (disconnect && (slot_flags[i] & CONNECT_PERSIST) && Engine::get_singleton()->is_editor_hint()) {
+				// This signal was connected from the editor, and is being edited. Just don't disconnect for now.
+				disconnect = false;
+			}
 #endif
-		if (disconnect) {
-			_disconnect(p_name, slot_callables[i]);
+			if (disconnect) {
+				_disconnect(p_name, slot_callables[i]);
+			}
 		}
 	}
 
@@ -1280,6 +1314,8 @@ void Object::_add_user_signal(const String &p_name, const Array &p_args) {
 	// without access to ADD_SIGNAL in bind_methods
 	// added events are per instance, as opposed to the other ones, which are global
 
+	OBJ_SIGNAL_LOCK
+
 	MethodInfo mi;
 	mi.name = p_name;
 
@@ -1360,6 +1396,8 @@ bool Object::has_signal(const StringName &p_name) const {
 }
 
 void Object::get_signal_list(List<MethodInfo> *p_signals) const {
+	OBJ_SIGNAL_LOCK
+
 	if (!script.is_null()) {
 		Ref<Script> scr = script;
 		if (scr.is_valid()) {
@@ -1379,6 +1417,8 @@ void Object::get_signal_list(List<MethodInfo> *p_signals) const {
 }
 
 void Object::get_all_signal_connections(List<Connection> *p_connections) const {
+	OBJ_SIGNAL_LOCK
+
 	for (const KeyValue<StringName, SignalData> &E : signal_map) {
 		const SignalData *s = &E.value;
 
@@ -1389,6 +1429,8 @@ void Object::get_all_signal_connections(List<Connection> *p_connections) const {
 }
 
 void Object::get_signal_connection_list(const StringName &p_signal, List<Connection> *p_connections) const {
+	OBJ_SIGNAL_LOCK
+
 	const SignalData *s = signal_map.getptr(p_signal);
 	if (!s) {
 		return; //nothing
@@ -1400,6 +1442,7 @@ void Object::get_signal_connection_list(const StringName &p_signal, List<Connect
 }
 
 int Object::get_persistent_signal_connection_count() const {
+	OBJ_SIGNAL_LOCK
 	int count = 0;
 
 	for (const KeyValue<StringName, SignalData> &E : signal_map) {
@@ -1416,6 +1459,8 @@ int Object::get_persistent_signal_connection_count() const {
 }
 
 void Object::get_signals_connected_to_this(List<Connection> *p_connections) const {
+	OBJ_SIGNAL_LOCK
+
 	for (const Connection &E : connections) {
 		p_connections->push_back(E);
 	}
@@ -1423,6 +1468,7 @@ void Object::get_signals_connected_to_this(List<Connection> *p_connections) cons
 
 Error Object::connect(const StringName &p_signal, const Callable &p_callable, uint32_t p_flags) {
 	ERR_FAIL_COND_V_MSG(p_callable.is_null(), ERR_INVALID_PARAMETER, vformat("Cannot connect to '%s': the provided callable is null.", p_signal));
+	OBJ_SIGNAL_LOCK
 
 	if (p_callable.is_standard()) {
 		// FIXME: This branch should probably removed in favor of the `is_valid()` branch, but there exist some classes
@@ -1491,6 +1537,8 @@ Error Object::connect(const StringName &p_signal, const Callable &p_callable, ui
 
 bool Object::is_connected(const StringName &p_signal, const Callable &p_callable) const {
 	ERR_FAIL_COND_V_MSG(p_callable.is_null(), false, vformat("Cannot determine if connected to '%s': the provided callable is null.", p_signal)); // Should use `is_null`, see note in `connect` about the use of `is_valid`.
+	OBJ_SIGNAL_LOCK
+
 	const SignalData *s = signal_map.getptr(p_signal);
 	if (!s) {
 		bool signal_is_valid = ClassDB::has_signal(get_class_name(), p_signal);
@@ -1509,6 +1557,8 @@ bool Object::is_connected(const StringName &p_signal, const Callable &p_callable
 }
 
 bool Object::has_connections(const StringName &p_signal) const {
+	OBJ_SIGNAL_LOCK
+
 	const SignalData *s = signal_map.getptr(p_signal);
 	if (!s) {
 		bool signal_is_valid = ClassDB::has_signal(get_class_name(), p_signal);
@@ -1532,6 +1582,7 @@ void Object::disconnect(const StringName &p_signal, const Callable &p_callable)
 
 bool Object::_disconnect(const StringName &p_signal, const Callable &p_callable, bool p_force) {
 	ERR_FAIL_COND_V_MSG(p_callable.is_null(), false, vformat("Cannot disconnect from '%s': the provided callable is null.", p_signal)); // Should use `is_null`, see note in `connect` about the use of `is_valid`.
+	OBJ_SIGNAL_LOCK
 
 	SignalData *s = signal_map.getptr(p_signal);
 	if (!s) {
@@ -1569,6 +1620,10 @@ bool Object::_disconnect(const StringName &p_signal, const Callable &p_callable,
 	return true;
 }
 
+bool Object::_uses_signal_mutex() const {
+	return true;
+}
+
 void Object::_set_bind(const StringName &p_set, const Variant &p_value) {
 	set(p_set, p_value);
 }
@@ -2204,33 +2259,36 @@ Object::~Object() {
 		ERR_PRINT(vformat("Object '%s' was freed or unreferenced while a signal is being emitted from it. Try connecting to the signal using 'CONNECT_DEFERRED' flag, or use queue_free() to free the object (if this object is a Node) to avoid this error and potential crashes.", to_string()));
 	}
 
-	// Drop all connections to the signals of this object.
-	while (signal_map.size()) {
-		// Avoid regular iteration so erasing is safe.
-		KeyValue<StringName, SignalData> &E = *signal_map.begin();
-		SignalData *s = &E.value;
-
-		for (const KeyValue<Callable, SignalData::Slot> &slot_kv : s->slot_map) {
-			Object *target = slot_kv.value.conn.callable.get_object();
-			if (likely(target)) {
-				target->connections.erase(slot_kv.value.cE);
+	{
+		OBJ_SIGNAL_LOCK
+		// Drop all connections to the signals of this object.
+		while (signal_map.size()) {
+			// Avoid regular iteration so erasing is safe.
+			KeyValue<StringName, SignalData> &E = *signal_map.begin();
+			SignalData *s = &E.value;
+
+			for (const KeyValue<Callable, SignalData::Slot> &slot_kv : s->slot_map) {
+				Object *target = slot_kv.value.conn.callable.get_object();
+				if (likely(target)) {
+					target->connections.erase(slot_kv.value.cE);
+				}
 			}
-		}
 
-		signal_map.erase(E.key);
-	}
-
-	// Disconnect signals that connect to this object.
-	while (connections.size()) {
-		Connection c = connections.front()->get();
-		Object *obj = c.callable.get_object();
-		bool disconnected = false;
-		if (likely(obj)) {
-			disconnected = c.signal.get_object()->_disconnect(c.signal.get_name(), c.callable, true);
+			signal_map.erase(E.key);
 		}
-		if (unlikely(!disconnected)) {
-			// If the disconnect has failed, abandon the connection to avoid getting trapped in an infinite loop here.
-			connections.pop_front();
+
+		// Disconnect signals that connect to this object.
+		while (connections.size()) {
+			Connection c = connections.front()->get();
+			Object *obj = c.callable.get_object();
+			bool disconnected = false;
+			if (likely(obj)) {
+				disconnected = c.signal.get_object()->_disconnect(c.signal.get_name(), c.callable, true);
+			}
+			if (unlikely(!disconnected)) {
+				// If the disconnect has failed, abandon the connection to avoid getting trapped in an infinite loop here.
+				connections.pop_front();
+			}
 		}
 	}
 
@@ -2248,6 +2306,10 @@ Object::~Object() {
 		}
 		memfree(_instance_bindings);
 	}
+
+	if (signal_mutex) {
+		memdelete(signal_mutex);
+	}
 }
 
 bool predelete_handler(Object *p_object) {

+ 4 - 1
core/object/object.h

@@ -608,7 +608,8 @@ private:
 		HashMap<Callable, Slot, HashableHasher<Callable>> slot_map;
 		bool removable = false;
 	};
-
+	friend struct _ObjectSignalLock;
+	mutable Mutex *signal_mutex = nullptr;
 	HashMap<StringName, SignalData> signal_map;
 	List<Connection> connections;
 #ifdef DEBUG_ENABLED
@@ -756,6 +757,8 @@ protected:
 
 	bool _disconnect(const StringName &p_signal, const Callable &p_callable, bool p_force = false);
 
+	virtual bool _uses_signal_mutex() const;
+
 #ifdef TOOLS_ENABLED
 	struct VirtualMethodTracker {
 		void **method;

+ 1 - 0
doc/classes/Object.xml

@@ -521,6 +521,7 @@
 				A signal can only be connected once to the same [Callable]. If the signal is already connected, this method returns [constant ERR_INVALID_PARAMETER] and generates an error, unless the signal is connected with [constant CONNECT_REFERENCE_COUNTED]. To prevent this, use [method is_connected] first to check for existing connections.
 				[b]Note:[/b] If the [param callable]'s object is freed, the connection will be lost.
 				[b]Note:[/b] In GDScript, it is generally recommended to connect signals with [method Signal.connect] instead.
+				[b]Note:[/b] This operation (and all other signal related operations) is thread-safe.
 			</description>
 		</method>
 		<method name="disconnect">

+ 2 - 0
scene/main/node.h

@@ -387,6 +387,8 @@ protected:
 	void _validate_property(PropertyInfo &p_property) const;
 
 protected:
+	virtual bool _uses_signal_mutex() const override { return false; } // Node uses thread guards instead.
+
 	virtual void input(const Ref<InputEvent> &p_event);
 	virtual void shortcut_input(const Ref<InputEvent> &p_key_event);
 	virtual void unhandled_input(const Ref<InputEvent> &p_event);