Selaa lähdekoodia

Move bisect to Span and deduplicate code.

Co-authored-by: Lukas Tenbrink <[email protected]>
Yufeng Ying 2 kuukautta sitten
vanhempi
commit
3bf400ffae

+ 31 - 0
core/templates/span.h

@@ -86,6 +86,10 @@ public:
 	constexpr int64_t rfind(const T &p_val, uint64_t p_from) const;
 	_FORCE_INLINE_ constexpr int64_t rfind(const T &p_val) const { return rfind(p_val, size() - 1); }
 	constexpr uint64_t count(const T &p_val) const;
+	/// Find the index of the given value using binary search.
+	/// Note: Assumes that elements in the span are sorted. Otherwise, use find() instead.
+	template <typename Comparator = Comparator<T>>
+	constexpr uint64_t bisect(const T &p_value, bool p_before, Comparator compare = Comparator()) const;
 };
 
 template <typename T>
@@ -119,6 +123,33 @@ constexpr uint64_t Span<T>::count(const T &p_val) const {
 	return amount;
 }
 
+template <typename T>
+template <typename Comparator>
+constexpr uint64_t Span<T>::bisect(const T &p_value, bool p_before, Comparator compare) const {
+	uint64_t lo = 0;
+	uint64_t hi = size();
+	if (p_before) {
+		while (lo < hi) {
+			const uint64_t mid = (lo + hi) / 2;
+			if (compare(ptr()[mid], p_value)) {
+				lo = mid + 1;
+			} else {
+				hi = mid;
+			}
+		}
+	} else {
+		while (lo < hi) {
+			const uint64_t mid = (lo + hi) / 2;
+			if (compare(p_value, ptr()[mid])) {
+				hi = mid;
+			} else {
+				lo = mid + 1;
+			}
+		}
+	}
+	return lo;
+}
+
 // Zero-constructing Span initializes _ptr and _len to 0 (and thus empty).
 template <typename T>
 struct is_zero_constructible<Span<T>> : std::true_type {};

+ 1 - 3
core/templates/vector.h

@@ -40,7 +40,6 @@
 
 #include "core/error/error_macros.h"
 #include "core/templates/cowdata.h"
-#include "core/templates/search_array.h"
 #include "core/templates/sort_array.h"
 
 #include <initializer_list>
@@ -152,8 +151,7 @@ public:
 
 	template <typename Comparator, typename Value, typename... Args>
 	Size bsearch_custom(const Value &p_value, bool p_before, Args &&...args) {
-		SearchArray<T, Comparator> search{ args... };
-		return search.bisect(ptrw(), size(), p_value, p_before);
+		return span().bisect(p_value, p_before, Comparator{ args... });
 	}
 
 	Vector<T> duplicate() {

+ 9 - 43
core/templates/vset.h

@@ -37,41 +37,19 @@ template <typename T>
 class VSet {
 	Vector<T> _data;
 
+protected:
 	_FORCE_INLINE_ int _find(const T &p_val, bool &r_exact) const {
 		r_exact = false;
 		if (_data.is_empty()) {
 			return 0;
 		}
 
-		int low = 0;
-		int high = _data.size() - 1;
-		const T *a = &_data[0];
-		int middle = 0;
+		int64_t pos = _data.span().bisect(p_val, true);
 
-#ifdef DEBUG_ENABLED
-		if (low > high) {
-			ERR_PRINT("low > high, this may be a bug");
+		if (pos < _data.size() && !(p_val < _data[pos]) && !(_data[pos] < p_val)) {
+			r_exact = true;
 		}
-#endif
-
-		while (low <= high) {
-			middle = (low + high) / 2;
-
-			if (p_val < a[middle]) {
-				high = middle - 1; //search low end of array
-			} else if (a[middle] < p_val) {
-				low = middle + 1; //search high end of array
-			} else {
-				r_exact = true;
-				return middle;
-			}
-		}
-
-		//return the position where this would be inserted
-		if (a[middle] < p_val) {
-			middle++;
-		}
-		return middle;
+		return pos;
 	}
 
 	_FORCE_INLINE_ int _find_exact(const T &p_val) const {
@@ -79,23 +57,11 @@ class VSet {
 			return -1;
 		}
 
-		int low = 0;
-		int high = _data.size() - 1;
-		int middle;
-		const T *a = &_data[0];
-
-		while (low <= high) {
-			middle = (low + high) / 2;
-
-			if (p_val < a[middle]) {
-				high = middle - 1; //search low end of array
-			} else if (a[middle] < p_val) {
-				low = middle + 1; //search high end of array
-			} else {
-				return middle;
-			}
-		}
+		int64_t pos = _data.span().bisect(p_val, true);
 
+		if (pos < _data.size() && !(p_val < _data[pos]) && !(_data[pos] < p_val)) {
+			return pos;
+		}
 		return -1;
 	}
 

+ 1 - 3
core/variant/array.cpp

@@ -34,7 +34,6 @@
 #include "core/math/math_funcs.h"
 #include "core/object/script_language.h"
 #include "core/templates/hashfuncs.h"
-#include "core/templates/search_array.h"
 #include "core/templates/vector.h"
 #include "core/variant/callable.h"
 #include "core/variant/dictionary.h"
@@ -737,8 +736,7 @@ void Array::shuffle() {
 int Array::bsearch(const Variant &p_value, bool p_before) const {
 	Variant value = p_value;
 	ERR_FAIL_COND_V(!_p->typed.validate(value, "binary search"), -1);
-	SearchArray<Variant, _ArrayVariantSort> avs;
-	return avs.bisect(_p->array.ptr(), _p->array.size(), value, p_before);
+	return _p->array.span().bisect<_ArrayVariantSort>(value, p_before);
 }
 
 int Array::bsearch_custom(const Variant &p_value, const Callable &p_callable, bool p_before) const {

+ 1 - 2
drivers/metal/metal_objects.mm

@@ -1571,14 +1571,13 @@ BoundUniformSet &MDUniformSet::bound_uniform_set(MDShader *p_shader, id<MTLDevic
 		}
 	}
 
-	SearchArray<__unsafe_unretained id<MTLResource>> search;
 	ResourceUsageMap usage_to_resources;
 	for (KeyValue<id<MTLResource>, StageResourceUsage> const &keyval : bound_resources) {
 		ResourceVector *resources = usage_to_resources.getptr(keyval.value);
 		if (resources == nullptr) {
 			resources = &usage_to_resources.insert(keyval.value, ResourceVector())->value;
 		}
-		int64_t pos = search.bisect(resources->ptr(), resources->size(), keyval.key, true);
+		int64_t pos = resources->span().bisect(keyval.key, true);
 		if (pos == resources->size() || (*resources)[pos] != keyval.key) {
 			resources->insert(pos, keyval.key);
 		}

+ 5 - 15
editor/doc_tools.cpp

@@ -178,10 +178,8 @@ static void merge_methods(Vector<DocData::MethodDoc> &p_to, const Vector<DocData
 	DocData::MethodDoc *to_ptrw = p_to.ptrw();
 	int64_t to_size = p_to.size();
 
-	SearchArray<DocData::MethodDoc, MethodCompare> search_array;
-
 	for (const DocData::MethodDoc &from : p_from) {
-		int64_t found = search_array.bisect(to_ptrw, to_size, from, true);
+		int64_t found = p_to.span().bisect<MethodCompare>(from, true);
 
 		if (found >= to_size) {
 			continue;
@@ -206,10 +204,8 @@ static void merge_constants(Vector<DocData::ConstantDoc> &p_to, const Vector<Doc
 	const DocData::ConstantDoc *from_ptr = p_from.ptr();
 	int64_t from_size = p_from.size();
 
-	SearchArray<DocData::ConstantDoc> search_array;
-
 	for (DocData::ConstantDoc &to : p_to) {
-		int64_t found = search_array.bisect(from_ptr, from_size, to, true);
+		int64_t found = p_from.span().bisect(to, true);
 
 		if (found >= from_size) {
 			continue;
@@ -234,10 +230,8 @@ static void merge_properties(Vector<DocData::PropertyDoc> &p_to, const Vector<Do
 	DocData::PropertyDoc *to_ptrw = p_to.ptrw();
 	int64_t to_size = p_to.size();
 
-	SearchArray<DocData::PropertyDoc> search_array;
-
 	for (const DocData::PropertyDoc &from : p_from) {
-		int64_t found = search_array.bisect(to_ptrw, to_size, from, true);
+		int64_t found = p_to.span().bisect(from, true);
 
 		if (found >= to_size) {
 			continue;
@@ -262,10 +256,8 @@ static void merge_theme_properties(Vector<DocData::ThemeItemDoc> &p_to, const Ve
 	DocData::ThemeItemDoc *to_ptrw = p_to.ptrw();
 	int64_t to_size = p_to.size();
 
-	SearchArray<DocData::ThemeItemDoc> search_array;
-
 	for (const DocData::ThemeItemDoc &from : p_from) {
-		int64_t found = search_array.bisect(to_ptrw, to_size, from, true);
+		int64_t found = p_to.span().bisect(from, true);
 
 		if (found >= to_size) {
 			continue;
@@ -290,10 +282,8 @@ static void merge_operators(Vector<DocData::MethodDoc> &p_to, const Vector<DocDa
 	DocData::MethodDoc *to_ptrw = p_to.ptrw();
 	int64_t to_size = p_to.size();
 
-	SearchArray<DocData::MethodDoc, OperatorCompare> search_array;
-
 	for (const DocData::MethodDoc &from : p_from) {
-		int64_t found = search_array.bisect(to_ptrw, to_size, from, true);
+		int64_t found = p_to.span().bisect(from, true);
 
 		if (found >= to_size) {
 			continue;

+ 1 - 14
scene/gui/item_list.cpp

@@ -1414,20 +1414,7 @@ void ItemList::_notification(int p_what) {
 			const Rect2 clip(-base_ofs, size);
 
 			// Do a binary search to find the first separator that is below clip_position.y.
-			int first_visible_separator = 0;
-			{
-				int lo = 0;
-				int hi = separators.size();
-				while (lo < hi) {
-					const int mid = (lo + hi) / 2;
-					if (separators[mid] < clip.position.y) {
-						lo = mid + 1;
-					} else {
-						hi = mid;
-					}
-				}
-				first_visible_separator = lo;
-			}
+			int64_t first_visible_separator = separators.span().bisect(clip.position.y, true);
 
 			// If not in thumbnails mode, draw visible separators.
 			if (icon_mode != ICON_MODE_TOP) {

+ 64 - 29
core/templates/search_array.h → tests/core/templates/test_vset.h

@@ -1,5 +1,5 @@
 /**************************************************************************/
-/*  search_array.h                                                        */
+/*  test_vset.h                                                           */
 /**************************************************************************/
 /*                         This file is part of:                          */
 /*                             GODOT ENGINE                               */
@@ -30,35 +30,70 @@
 
 #pragma once
 
-#include "core/typedefs.h"
+#include "core/templates/vset.h"
 
-template <typename T, typename Comparator = Comparator<T>>
-class SearchArray {
+#include "tests/test_macros.h"
+
+namespace TestVSet {
+
+template <typename T>
+class TestClass : public VSet<T> {
 public:
-	Comparator compare;
-
-	inline int64_t bisect(const T *p_array, int64_t p_len, const T &p_value, bool p_before) const {
-		int64_t lo = 0;
-		int64_t hi = p_len;
-		if (p_before) {
-			while (lo < hi) {
-				const int64_t mid = (lo + hi) / 2;
-				if (compare(p_array[mid], p_value)) {
-					lo = mid + 1;
-				} else {
-					hi = mid;
-				}
-			}
-		} else {
-			while (lo < hi) {
-				const int64_t mid = (lo + hi) / 2;
-				if (compare(p_value, p_array[mid])) {
-					hi = mid;
-				} else {
-					lo = mid + 1;
-				}
-			}
-		}
-		return lo;
+	int _find(const T &p_val, bool &r_exact) const {
+		return VSet<T>::_find(p_val, r_exact);
 	}
 };
+
+TEST_CASE("[VSet] _find and _find_exact correctness.") {
+	TestClass<int> set;
+
+	// insert some values
+	set.insert(10);
+	set.insert(20);
+	set.insert(30);
+	set.insert(40);
+	set.insert(50);
+
+	// data should be sorted
+	CHECK(set.size() == 5);
+	CHECK(set[0] == 10);
+	CHECK(set[1] == 20);
+	CHECK(set[2] == 30);
+	CHECK(set[3] == 40);
+	CHECK(set[4] == 50);
+
+	// _find_exact return exact position for existing elements
+	CHECK(set.find(10) == 0);
+	CHECK(set.find(30) == 2);
+	CHECK(set.find(50) == 4);
+
+	// _find_exact return -1 for non-existing elements
+	CHECK(set.find(15) == -1);
+	CHECK(set.find(0) == -1);
+	CHECK(set.find(60) == -1);
+
+	// test _find
+	bool exact;
+
+	// existing elements
+	CHECK(set._find(10, exact) == 0);
+	CHECK(exact == true);
+
+	CHECK(set._find(30, exact) == 2);
+	CHECK(exact == true);
+
+	// non-existing elements
+	CHECK(set._find(25, exact) == 2);
+	CHECK(exact == false);
+
+	CHECK(set._find(35, exact) == 3);
+	CHECK(exact == false);
+
+	CHECK(set._find(5, exact) == 0);
+	CHECK(exact == false);
+
+	CHECK(set._find(60, exact) == 5);
+	CHECK(exact == false);
+}
+
+} // namespace TestVSet

+ 1 - 0
tests/test_main.cpp

@@ -106,6 +106,7 @@
 #include "tests/core/templates/test_rid.h"
 #include "tests/core/templates/test_span.h"
 #include "tests/core/templates/test_vector.h"
+#include "tests/core/templates/test_vset.h"
 #include "tests/core/test_crypto.h"
 #include "tests/core/test_hashing_context.h"
 #include "tests/core/test_time.h"