Browse Source

Refactor ref-counting code and fix ref counted releasing before aquiring

rune-scape 1 year ago
parent
commit
cee0e6667a

+ 36 - 66
core/object/ref_counted.h

@@ -57,24 +57,30 @@ template <typename T>
 class Ref {
 	T *reference = nullptr;
 
-	void ref(const Ref &p_from) {
-		if (p_from.reference == reference) {
+	_FORCE_INLINE_ void ref(const Ref &p_from) {
+		ref_pointer<false>(p_from.reference);
+	}
+
+	template <bool Init>
+	_FORCE_INLINE_ void ref_pointer(T *p_refcounted) {
+		if (p_refcounted == reference) {
 			return;
 		}
 
-		unref();
-
-		reference = p_from.reference;
+		// This will go out of scope and get unref'd.
+		Ref cleanup_ref;
+		cleanup_ref.reference = reference;
+		reference = p_refcounted;
 		if (reference) {
-			reference->reference();
-		}
-	}
-
-	void ref_pointer(T *p_ref) {
-		ERR_FAIL_NULL(p_ref);
-
-		if (p_ref->init_ref()) {
-			reference = p_ref;
+			if constexpr (Init) {
+				if (!reference->init_ref()) {
+					reference = nullptr;
+				}
+			} else {
+				if (!reference->reference()) {
+					reference = nullptr;
+				}
+			}
 		}
 	}
 
@@ -124,15 +130,11 @@ public:
 
 	template <typename T_Other>
 	void operator=(const Ref<T_Other> &p_from) {
-		RefCounted *refb = const_cast<RefCounted *>(static_cast<const RefCounted *>(p_from.ptr()));
-		if (!refb) {
-			unref();
-			return;
-		}
-		Ref r;
-		r.reference = Object::cast_to<T>(refb);
-		ref(r);
-		r.reference = nullptr;
+		ref_pointer<false>(Object::cast_to<T>(p_from.ptr()));
+	}
+
+	void operator=(T *p_from) {
+		ref_pointer<true>(p_from);
 	}
 
 	void operator=(const Variant &p_variant) {
@@ -142,16 +144,7 @@ public:
 			return;
 		}
 
-		unref();
-
-		if (!object) {
-			return;
-		}
-
-		T *r = Object::cast_to<T>(object);
-		if (r && r->reference()) {
-			reference = r;
-		}
+		ref_pointer<false>(Object::cast_to<T>(object));
 	}
 
 	template <typename T_Other>
@@ -159,48 +152,25 @@ public:
 		if (reference == p_ptr) {
 			return;
 		}
-		unref();
 
-		T *r = Object::cast_to<T>(p_ptr);
-		if (r) {
-			ref_pointer(r);
-		}
+		ref_pointer<true>(Object::cast_to<T>(p_ptr));
 	}
 
 	Ref(const Ref &p_from) {
-		ref(p_from);
+		this->operator=(p_from);
 	}
 
 	template <typename T_Other>
 	Ref(const Ref<T_Other> &p_from) {
-		RefCounted *refb = const_cast<RefCounted *>(static_cast<const RefCounted *>(p_from.ptr()));
-		if (!refb) {
-			unref();
-			return;
-		}
-		Ref r;
-		r.reference = Object::cast_to<T>(refb);
-		ref(r);
-		r.reference = nullptr;
+		this->operator=(p_from);
 	}
 
-	Ref(T *p_reference) {
-		if (p_reference) {
-			ref_pointer(p_reference);
-		}
+	Ref(T *p_from) {
+		this->operator=(p_from);
 	}
 
-	Ref(const Variant &p_variant) {
-		Object *object = p_variant.get_validated_object();
-
-		if (!object) {
-			return;
-		}
-
-		T *r = Object::cast_to<T>(object);
-		if (r && r->reference()) {
-			reference = r;
-		}
+	Ref(const Variant &p_from) {
+		this->operator=(p_from);
 	}
 
 	inline bool is_valid() const { return reference != nullptr; }
@@ -222,7 +192,7 @@ public:
 		ref(memnew(T(p_params...)));
 	}
 
-	Ref() {}
+	Ref() = default;
 
 	~Ref() {
 		unref();
@@ -299,13 +269,13 @@ struct GetTypeInfo<const Ref<T> &> {
 template <typename T>
 struct VariantInternalAccessor<Ref<T>> {
 	static _FORCE_INLINE_ Ref<T> get(const Variant *v) { return Ref<T>(*VariantInternal::get_object(v)); }
-	static _FORCE_INLINE_ void set(Variant *v, const Ref<T> &p_ref) { VariantInternal::refcounted_object_assign(v, p_ref.ptr()); }
+	static _FORCE_INLINE_ void set(Variant *v, const Ref<T> &p_ref) { VariantInternal::object_assign(v, p_ref); }
 };
 
 template <typename T>
 struct VariantInternalAccessor<const Ref<T> &> {
 	static _FORCE_INLINE_ Ref<T> get(const Variant *v) { return Ref<T>(*VariantInternal::get_object(v)); }
-	static _FORCE_INLINE_ void set(Variant *v, const Ref<T> &p_ref) { VariantInternal::refcounted_object_assign(v, p_ref.ptr()); }
+	static _FORCE_INLINE_ void set(Variant *v, const Ref<T> &p_ref) { VariantInternal::object_assign(v, p_ref); }
 };
 
 #endif // REF_COUNTED_H

+ 10 - 9
core/variant/callable.cpp

@@ -315,31 +315,32 @@ bool Callable::operator<(const Callable &p_callable) const {
 }
 
 void Callable::operator=(const Callable &p_callable) {
+	CallableCustom *cleanup_ref = nullptr;
 	if (is_custom()) {
 		if (p_callable.is_custom()) {
 			if (custom == p_callable.custom) {
 				return;
 			}
 		}
-
-		if (custom->ref_count.unref()) {
-			memdelete(custom);
-			custom = nullptr;
-		}
+		cleanup_ref = custom;
+		custom = nullptr;
 	}
 
 	if (p_callable.is_custom()) {
 		method = StringName();
-		if (!p_callable.custom->ref_count.ref()) {
-			object = 0;
-		} else {
-			object = 0;
+		object = 0;
+		if (p_callable.custom->ref_count.ref()) {
 			custom = p_callable.custom;
 		}
 	} else {
 		method = p_callable.method;
 		object = p_callable.object;
 	}
+
+	if (cleanup_ref != nullptr && cleanup_ref->ref_count.unref()) {
+		memdelete(cleanup_ref);
+	}
+	cleanup_ref = nullptr;
 }
 
 Callable::operator String() const {

+ 65 - 67
core/variant/variant.cpp

@@ -1072,17 +1072,69 @@ bool Variant::is_null() const {
 	}
 }
 
+void Variant::ObjData::ref(const ObjData &p_from) {
+	// Mirrors Ref::ref in refcounted.h
+	if (p_from.id == id) {
+		return;
+	}
+
+	ObjData cleanup_ref = *this;
+
+	*this = p_from;
+	if (id.is_ref_counted()) {
+		RefCounted *reference = static_cast<RefCounted *>(obj);
+		// Assuming reference is not null because id.is_ref_counted() was true.
+		if (!reference->reference()) {
+			*this = ObjData();
+		}
+	}
+
+	cleanup_ref.unref();
+}
+
+void Variant::ObjData::ref_pointer(Object *p_object) {
+	// Mirrors Ref::ref_pointer in refcounted.h
+	if (p_object == obj) {
+		return;
+	}
+
+	ObjData cleanup_ref = *this;
+
+	if (p_object) {
+		*this = ObjData{ p_object->get_instance_id(), p_object };
+		if (p_object->is_ref_counted()) {
+			RefCounted *reference = static_cast<RefCounted *>(p_object);
+			if (!reference->init_ref()) {
+				*this = ObjData();
+			}
+		}
+	} else {
+		*this = ObjData();
+	}
+
+	cleanup_ref.unref();
+}
+
+void Variant::ObjData::unref() {
+	// Mirrors Ref::unref in refcounted.h
+	if (id.is_ref_counted()) {
+		RefCounted *reference = static_cast<RefCounted *>(obj);
+		// Assuming reference is not null because id.is_ref_counted() was true.
+		if (reference->unreference()) {
+			memdelete(reference);
+		}
+	}
+	*this = ObjData();
+}
+
 void Variant::reference(const Variant &p_variant) {
-	switch (type) {
-		case NIL:
-		case BOOL:
-		case INT:
-		case FLOAT:
-			break;
-		default:
-			clear();
+	if (type == OBJECT && p_variant.type == OBJECT) {
+		_get_obj().ref(p_variant._get_obj());
+		return;
 	}
 
+	clear();
+
 	type = p_variant.type;
 
 	switch (p_variant.type) {
@@ -1165,18 +1217,7 @@ void Variant::reference(const Variant &p_variant) {
 		} break;
 		case OBJECT: {
 			memnew_placement(_data._mem, ObjData);
-
-			if (p_variant._get_obj().obj && p_variant._get_obj().id.is_ref_counted()) {
-				RefCounted *ref_counted = static_cast<RefCounted *>(p_variant._get_obj().obj);
-				if (!ref_counted->reference()) {
-					_get_obj().obj = nullptr;
-					_get_obj().id = ObjectID();
-					break;
-				}
-			}
-
-			_get_obj().obj = const_cast<Object *>(p_variant._get_obj().obj);
-			_get_obj().id = p_variant._get_obj().id;
+			_get_obj().ref(p_variant._get_obj());
 		} break;
 		case CALLABLE: {
 			memnew_placement(_data._mem, Callable(*reinterpret_cast<const Callable *>(p_variant._data._mem)));
@@ -1375,15 +1416,7 @@ void Variant::_clear_internal() {
 			reinterpret_cast<NodePath *>(_data._mem)->~NodePath();
 		} break;
 		case OBJECT: {
-			if (_get_obj().id.is_ref_counted()) {
-				// We are safe that there is a reference here.
-				RefCounted *ref_counted = static_cast<RefCounted *>(_get_obj().obj);
-				if (ref_counted->unreference()) {
-					memdelete(ref_counted);
-				}
-			}
-			_get_obj().obj = nullptr;
-			_get_obj().id = ObjectID();
+			_get_obj().unref();
 		} break;
 		case RID: {
 			// Not much need probably.
@@ -2589,24 +2622,8 @@ Variant::Variant(const ::RID &p_rid) :
 
 Variant::Variant(const Object *p_object) :
 		type(OBJECT) {
-	memnew_placement(_data._mem, ObjData);
-
-	if (p_object) {
-		if (p_object->is_ref_counted()) {
-			RefCounted *ref_counted = const_cast<RefCounted *>(static_cast<const RefCounted *>(p_object));
-			if (!ref_counted->init_ref()) {
-				_get_obj().obj = nullptr;
-				_get_obj().id = ObjectID();
-				return;
-			}
-		}
-
-		_get_obj().obj = const_cast<Object *>(p_object);
-		_get_obj().id = p_object->get_instance_id();
-	} else {
-		_get_obj().obj = nullptr;
-		_get_obj().id = ObjectID();
-	}
+	_get_obj() = ObjData();
+	_get_obj().ref_pointer(const_cast<Object *>(p_object));
 }
 
 Variant::Variant(const Callable &p_callable) :
@@ -2828,26 +2845,7 @@ void Variant::operator=(const Variant &p_variant) {
 			*reinterpret_cast<::RID *>(_data._mem) = *reinterpret_cast<const ::RID *>(p_variant._data._mem);
 		} break;
 		case OBJECT: {
-			if (_get_obj().id.is_ref_counted()) {
-				//we are safe that there is a reference here
-				RefCounted *ref_counted = static_cast<RefCounted *>(_get_obj().obj);
-				if (ref_counted->unreference()) {
-					memdelete(ref_counted);
-				}
-			}
-
-			if (p_variant._get_obj().obj && p_variant._get_obj().id.is_ref_counted()) {
-				RefCounted *ref_counted = static_cast<RefCounted *>(p_variant._get_obj().obj);
-				if (!ref_counted->reference()) {
-					_get_obj().obj = nullptr;
-					_get_obj().id = ObjectID();
-					break;
-				}
-			}
-
-			_get_obj().obj = const_cast<Object *>(p_variant._get_obj().obj);
-			_get_obj().id = p_variant._get_obj().id;
-
+			_get_obj().ref(p_variant._get_obj());
 		} break;
 		case CALLABLE: {
 			*reinterpret_cast<Callable *>(_data._mem) = *reinterpret_cast<const Callable *>(p_variant._data._mem);

+ 18 - 0
core/variant/variant.h

@@ -62,6 +62,10 @@
 #include "core/variant/dictionary.h"
 
 class Object;
+class RefCounted;
+
+template <typename T>
+class Ref;
 
 struct PropertyInfo;
 struct MethodInfo;
@@ -175,6 +179,20 @@ private:
 	struct ObjData {
 		ObjectID id;
 		Object *obj = nullptr;
+
+		void ref(const ObjData &p_from);
+		void ref_pointer(Object *p_object);
+		void ref_pointer(RefCounted *p_object);
+		void unref();
+
+		template <typename T>
+		_ALWAYS_INLINE_ void ref(const Ref<T> &p_from) {
+			if (p_from.is_valid()) {
+				ref(ObjData{ p_from->get_instance_id(), p_from.ptr() });
+			} else {
+				unref();
+			}
+		}
 	};
 
 	/* array helpers */

+ 0 - 30
core/variant/variant_construct.cpp

@@ -323,36 +323,6 @@ String Variant::get_constructor_argument_name(Variant::Type p_type, int p_constr
 	return construct_data[p_type][p_constructor].arg_names[p_argument];
 }
 
-void VariantInternal::refcounted_object_assign(Variant *v, const RefCounted *rc) {
-	if (!rc || !const_cast<RefCounted *>(rc)->init_ref()) {
-		v->_get_obj().obj = nullptr;
-		v->_get_obj().id = ObjectID();
-		return;
-	}
-
-	v->_get_obj().obj = const_cast<RefCounted *>(rc);
-	v->_get_obj().id = rc->get_instance_id();
-}
-
-void VariantInternal::object_assign(Variant *v, const Object *o) {
-	if (o) {
-		if (o->is_ref_counted()) {
-			RefCounted *ref_counted = const_cast<RefCounted *>(static_cast<const RefCounted *>(o));
-			if (!ref_counted->init_ref()) {
-				v->_get_obj().obj = nullptr;
-				v->_get_obj().id = ObjectID();
-				return;
-			}
-		}
-
-		v->_get_obj().obj = const_cast<Object *>(o);
-		v->_get_obj().id = o->get_instance_id();
-	} else {
-		v->_get_obj().obj = nullptr;
-		v->_get_obj().id = ObjectID();
-	}
-}
-
 void Variant::get_constructor_list(Type p_type, List<MethodInfo> *r_list) {
 	ERR_FAIL_INDEX(p_type, Variant::VARIANT_MAX);
 

+ 4 - 5
core/variant/variant_construct.h

@@ -156,14 +156,14 @@ public:
 		if (p_args[0]->get_type() == Variant::NIL) {
 			VariantInternal::clear(&r_ret);
 			VariantTypeChanger<Object *>::change(&r_ret);
-			VariantInternal::object_assign_null(&r_ret);
+			VariantInternal::object_reset_data(&r_ret);
 			r_error.error = Callable::CallError::CALL_OK;
 		} else if (p_args[0]->get_type() == Variant::OBJECT) {
-			VariantInternal::clear(&r_ret);
 			VariantTypeChanger<Object *>::change(&r_ret);
 			VariantInternal::object_assign(&r_ret, p_args[0]);
 			r_error.error = Callable::CallError::CALL_OK;
 		} else {
+			VariantInternal::clear(&r_ret);
 			r_error.error = Callable::CallError::CALL_ERROR_INVALID_ARGUMENT;
 			r_error.argument = 0;
 			r_error.expected = Variant::OBJECT;
@@ -171,7 +171,6 @@ public:
 	}
 
 	static inline void validated_construct(Variant *r_ret, const Variant **p_args) {
-		VariantInternal::clear(r_ret);
 		VariantTypeChanger<Object *>::change(r_ret);
 		VariantInternal::object_assign(r_ret, p_args[0]);
 	}
@@ -203,13 +202,13 @@ public:
 
 		VariantInternal::clear(&r_ret);
 		VariantTypeChanger<Object *>::change(&r_ret);
-		VariantInternal::object_assign_null(&r_ret);
+		VariantInternal::object_reset_data(&r_ret);
 	}
 
 	static inline void validated_construct(Variant *r_ret, const Variant **p_args) {
 		VariantInternal::clear(r_ret);
 		VariantTypeChanger<Object *>::change(r_ret);
-		VariantInternal::object_assign_null(r_ret);
+		VariantInternal::object_reset_data(r_ret);
 	}
 	static void ptr_construct(void *base, const void **p_args) {
 		PtrConstruct<Object *>::construct(nullptr, base);

+ 19 - 10
core/variant/variant_internal.h

@@ -220,7 +220,7 @@ public:
 	// Should be in the same order as Variant::Type for consistency.
 	// Those primitive and vector types don't need an `init_` method:
 	// Nil, bool, float, Vector2/i, Rect2/i, Vector3/i, Plane, Quat, RID.
-	// Object is a special case, handled via `object_assign_null`.
+	// Object is a special case, handled via `object_reset_data`.
 	_FORCE_INLINE_ static void init_string(Variant *v) {
 		memnew_placement(v->_data._mem, String);
 		v->type = Variant::STRING;
@@ -319,7 +319,7 @@ public:
 		v->type = Variant::PACKED_VECTOR4_ARRAY;
 	}
 	_FORCE_INLINE_ static void init_object(Variant *v) {
-		object_assign_null(v);
+		object_reset_data(v);
 		v->type = Variant::OBJECT;
 	}
 
@@ -327,19 +327,28 @@ public:
 		v->clear();
 	}
 
-	static void object_assign(Variant *v, const Object *o); // Needs RefCounted, so it's implemented elsewhere.
-	static void refcounted_object_assign(Variant *v, const RefCounted *rc);
+	_FORCE_INLINE_ static void object_assign(Variant *v, const Variant *vo) {
+		v->_get_obj().ref(vo->_get_obj());
+	}
+
+	_FORCE_INLINE_ static void object_assign(Variant *v, Object *o) {
+		v->_get_obj().ref_pointer(o);
+	}
 
-	_FORCE_INLINE_ static void object_assign(Variant *v, const Variant *o) {
-		object_assign(v, o->_get_obj().obj);
+	_FORCE_INLINE_ static void object_assign(Variant *v, const Object *o) {
+		v->_get_obj().ref_pointer(const_cast<Object *>(o));
+	}
+
+	template <typename T>
+	_FORCE_INLINE_ static void object_assign(Variant *v, const Ref<T> &r) {
+		v->_get_obj().ref(r);
 	}
 
-	_FORCE_INLINE_ static void object_assign_null(Variant *v) {
-		v->_get_obj().obj = nullptr;
-		v->_get_obj().id = ObjectID();
+	_FORCE_INLINE_ static void object_reset_data(Variant *v) {
+		v->_get_obj() = Variant::ObjData();
 	}
 
-	static void update_object_id(Variant *v) {
+	_FORCE_INLINE_ static void update_object_id(Variant *v) {
 		const Object *o = v->_get_obj().obj;
 		if (o) {
 			v->_get_obj().id = o->get_instance_id();