Pārlūkot izejas kodu

Merge pull request #105629 from aaronp64/list_sort

Reuse and optimize sorting logic for `List`, `SelfList`, and `HashMap`
Thaddeus Crews 3 mēneši atpakaļ
vecāks
revīzija
0d88e17143

+ 12 - 39
core/templates/hash_map.h

@@ -33,6 +33,7 @@
 #include "core/os/memory.h"
 #include "core/templates/hashfuncs.h"
 #include "core/templates/pair.h"
+#include "core/templates/sort_list.h"
 
 #include <initializer_list>
 
@@ -60,8 +61,6 @@ struct HashMapElement {
 			data(p_key, p_value) {}
 };
 
-bool _hashmap_variant_less_than(const Variant &p_left, const Variant &p_right);
-
 template <typename TKey, typename TValue,
 		typename Hasher = HashMapHasherDefault,
 		typename Comparator = HashMapComparatorDefault<TKey>,
@@ -265,44 +264,18 @@ public:
 	}
 
 	void sort() {
-		if (elements == nullptr || num_elements < 2) {
-			return; // An empty or single element HashMap is already sorted.
-		}
-		// Use insertion sort because we want this operation to be fast for the
-		// common case where the input is already sorted or nearly sorted.
-		HashMapElement<TKey, TValue> *inserting = head_element->next;
-		while (inserting != nullptr) {
-			HashMapElement<TKey, TValue> *after = nullptr;
-			for (HashMapElement<TKey, TValue> *current = inserting->prev; current != nullptr; current = current->prev) {
-				if (_hashmap_variant_less_than(inserting->data.key, current->data.key)) {
-					after = current;
-				} else {
-					break;
-				}
-			}
-			HashMapElement<TKey, TValue> *next = inserting->next;
-			if (after != nullptr) {
-				// Modify the elements around `inserting` to remove it from its current position.
-				inserting->prev->next = next;
-				if (next == nullptr) {
-					tail_element = inserting->prev;
-				} else {
-					next->prev = inserting->prev;
-				}
-				// Modify `before` and `after` to insert `inserting` between them.
-				HashMapElement<TKey, TValue> *before = after->prev;
-				if (before == nullptr) {
-					head_element = inserting;
-				} else {
-					before->next = inserting;
-				}
-				after->prev = inserting;
-				// Point `inserting` to its new surroundings.
-				inserting->prev = before;
-				inserting->next = after;
-			}
-			inserting = next;
+		sort_custom<KeyValueSort<TKey, TValue>>();
+	}
+
+	template <typename C>
+	void sort_custom() {
+		if (size() < 2) {
+			return;
 		}
+
+		using E = HashMapElement<TKey, TValue>;
+		SortList<E, KeyValue<TKey, TValue>, &E::data, &E::prev, &E::next, C> sorter;
+		sorter.sort(head_element, tail_element);
 	}
 
 	TValue &get(const TKey &p_key) {

+ 4 - 90
core/templates/list.h

@@ -32,7 +32,7 @@
 
 #include "core/error/error_macros.h"
 #include "core/os/memory.h"
-#include "core/templates/sort_array.h"
+#include "core/templates/sort_list.h"
 
 #include <initializer_list>
 
@@ -656,104 +656,18 @@ public:
 		where->prev_ptr = value;
 	}
 
-	/**
-	 * simple insertion sort
-	 */
-
 	void sort() {
 		sort_custom<Comparator<T>>();
 	}
 
-	template <typename C>
-	void sort_custom_inplace() {
-		if (size() < 2) {
-			return;
-		}
-
-		Element *from = front();
-		Element *current = from;
-		Element *to = from;
-
-		while (current) {
-			Element *next = current->next_ptr;
-
-			if (from != current) {
-				current->prev_ptr = nullptr;
-				current->next_ptr = from;
-
-				Element *find = from;
-				C less;
-				while (find && less(find->value, current->value)) {
-					current->prev_ptr = find;
-					current->next_ptr = find->next_ptr;
-					find = find->next_ptr;
-				}
-
-				if (current->prev_ptr) {
-					current->prev_ptr->next_ptr = current;
-				} else {
-					from = current;
-				}
-
-				if (current->next_ptr) {
-					current->next_ptr->prev_ptr = current;
-				} else {
-					to = current;
-				}
-			} else {
-				current->prev_ptr = nullptr;
-				current->next_ptr = nullptr;
-			}
-
-			current = next;
-		}
-		_data->first = from;
-		_data->last = to;
-	}
-
-	template <typename C>
-	struct AuxiliaryComparator {
-		C compare;
-		_FORCE_INLINE_ bool operator()(const Element *a, const Element *b) const {
-			return compare(a->value, b->value);
-		}
-	};
-
 	template <typename C>
 	void sort_custom() {
-		//this version uses auxiliary memory for speed.
-		//if you don't want to use auxiliary memory, use the in_place version
-
-		int s = size();
-		if (s < 2) {
+		if (size() < 2) {
 			return;
 		}
 
-		Element **aux_buffer = memnew_arr(Element *, s);
-
-		int idx = 0;
-		for (Element *E = front(); E; E = E->next_ptr) {
-			aux_buffer[idx] = E;
-			idx++;
-		}
-
-		SortArray<Element *, AuxiliaryComparator<C>> sort;
-		sort.sort(aux_buffer, s);
-
-		_data->first = aux_buffer[0];
-		aux_buffer[0]->prev_ptr = nullptr;
-		aux_buffer[0]->next_ptr = aux_buffer[1];
-
-		_data->last = aux_buffer[s - 1];
-		aux_buffer[s - 1]->prev_ptr = aux_buffer[s - 2];
-		aux_buffer[s - 1]->next_ptr = nullptr;
-
-		for (int i = 1; i < s - 1; i++) {
-			aux_buffer[i]->prev_ptr = aux_buffer[i - 1];
-			aux_buffer[i]->next_ptr = aux_buffer[i + 1];
-		}
-
-		memdelete_arr(aux_buffer);
+		SortList<Element, T, &Element::value, &Element::prev_ptr, &Element::next_ptr, C> sorter;
+		sorter.sort(_data->first, _data->last);
 	}
 
 	const void *id() const {

+ 8 - 39
core/templates/self_list.h

@@ -31,6 +31,7 @@
 #pragma once
 
 #include "core/error/error_macros.h"
+#include "core/templates/sort_list.h"
 #include "core/typedefs.h"
 
 template <typename T>
@@ -114,45 +115,13 @@ public:
 				return;
 			}
 
-			SelfList<T> *from = _first;
-			SelfList<T> *current = from;
-			SelfList<T> *to = from;
-
-			while (current) {
-				SelfList<T> *next = current->_next;
-
-				if (from != current) {
-					current->_prev = nullptr;
-					current->_next = from;
-
-					SelfList<T> *find = from;
-					C less;
-					while (find && less(*find->_self, *current->_self)) {
-						current->_prev = find;
-						current->_next = find->_next;
-						find = find->_next;
-					}
-
-					if (current->_prev) {
-						current->_prev->_next = current;
-					} else {
-						from = current;
-					}
-
-					if (current->_next) {
-						current->_next->_prev = current;
-					} else {
-						to = current;
-					}
-				} else {
-					current->_prev = nullptr;
-					current->_next = nullptr;
-				}
-
-				current = next;
-			}
-			_first = from;
-			_last = to;
+			struct PtrComparator {
+				C compare;
+				_FORCE_INLINE_ bool operator()(const T *p_a, const T *p_b) const { return compare(*p_a, *p_b); }
+			};
+			using Element = SelfList<T>;
+			SortList<Element, T *, &Element::_self, &Element::_prev, &Element::_next, PtrComparator> sorter;
+			sorter.sort(_first, _last);
 		}
 
 		_FORCE_INLINE_ SelfList<T> *first() { return _first; }

+ 148 - 0
core/templates/sort_list.h

@@ -0,0 +1,148 @@
+/**************************************************************************/
+/*  sort_list.h                                                           */
+/**************************************************************************/
+/*                         This file is part of:                          */
+/*                             GODOT ENGINE                               */
+/*                        https://godotengine.org                         */
+/**************************************************************************/
+/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
+/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur.                  */
+/*                                                                        */
+/* Permission is hereby granted, free of charge, to any person obtaining  */
+/* a copy of this software and associated documentation files (the        */
+/* "Software"), to deal in the Software without restriction, including    */
+/* without limitation the rights to use, copy, modify, merge, publish,    */
+/* distribute, sublicense, and/or sell copies of the Software, and to     */
+/* permit persons to whom the Software is furnished to do so, subject to  */
+/* the following conditions:                                              */
+/*                                                                        */
+/* The above copyright notice and this permission notice shall be         */
+/* included in all copies or substantial portions of the Software.        */
+/*                                                                        */
+/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,        */
+/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF     */
+/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
+/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY   */
+/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,   */
+/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE      */
+/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.                 */
+/**************************************************************************/
+
+#pragma once
+
+#include "core/typedefs.h"
+
+template <typename Element, typename T, T Element::*value, Element *Element::*prev, Element *Element::*next, typename Comparator = Comparator<T>>
+class SortList {
+public:
+	Comparator compare;
+
+	void sort(Element *&r_head, Element *&r_tail) {
+		Element *sorted_until;
+		if (_is_sorted(r_head, r_tail, sorted_until)) {
+			return;
+		}
+
+		// In case we're sorting only part of a larger list.
+		Element *head_prev = r_head->*prev;
+		r_head->*prev = nullptr;
+		Element *tail_next = r_tail->*next;
+		r_tail->*next = nullptr;
+
+		// Sort unsorted section and merge.
+		Element *head2 = sorted_until->*next;
+		_split(sorted_until, head2);
+		_merge_sort(head2, r_tail);
+		_merge(r_head, sorted_until, head2, r_tail, r_head, r_tail);
+
+		// Reconnect to larger list if needed.
+		if (head_prev) {
+			_connect(head_prev, r_head);
+		}
+		if (tail_next) {
+			_connect(r_tail, tail_next);
+		}
+	}
+
+private:
+	bool _is_sorted(Element *p_head, Element *p_tail, Element *&r_sorted_until) {
+		r_sorted_until = p_head;
+		while (r_sorted_until != p_tail) {
+			if (compare(r_sorted_until->*next->*value, r_sorted_until->*value)) {
+				return false;
+			}
+
+			r_sorted_until = r_sorted_until->*next;
+		}
+
+		return true;
+	}
+
+	void _merge_sort(Element *&r_head, Element *&r_tail) {
+		if (r_head == r_tail) {
+			return;
+		}
+
+		Element *tail1 = _get_mid(r_head);
+		Element *head2 = tail1->*next;
+		_split(tail1, head2);
+
+		_merge_sort(r_head, tail1);
+		_merge_sort(head2, r_tail);
+		_merge(r_head, tail1, head2, r_tail, r_head, r_tail);
+	}
+
+	void _merge(
+			Element *p_head1, Element *p_tail1,
+			Element *p_head2, Element *p_tail2,
+			Element *&r_head, Element *&r_tail) {
+		if (compare(p_head2->*value, p_head1->*value)) {
+			r_head = p_head2;
+			p_head2 = p_head2->*next;
+		} else {
+			r_head = p_head1;
+			p_head1 = p_head1->*next;
+		}
+
+		Element *curr = r_head;
+		while (p_head1 && p_head2) {
+			if (compare(p_head2->*value, p_head1->*value)) {
+				_connect(curr, p_head2);
+				p_head2 = p_head2->*next;
+			} else {
+				_connect(curr, p_head1);
+				p_head1 = p_head1->*next;
+			}
+			curr = curr->*next;
+		}
+
+		if (p_head1) {
+			_connect(curr, p_head1);
+			r_tail = p_tail1;
+		} else {
+			_connect(curr, p_head2);
+			r_tail = p_tail2;
+		}
+	}
+
+	Element *_get_mid(Element *p_head) {
+		Element *end = p_head;
+		Element *mid = p_head;
+		while (end->*next && end->*next->*next) {
+			end = end->*next->*next;
+			mid = mid->*next;
+		}
+
+		return mid;
+	}
+
+	_FORCE_INLINE_ void _connect(Element *p_a, Element *p_b) {
+		p_a->*next = p_b;
+		p_b->*prev = p_a;
+	}
+
+	_FORCE_INLINE_ void _split(Element *p_a, Element *p_b) {
+		p_a->*next = nullptr;
+		p_b->*prev = nullptr;
+	}
+};

+ 13 - 1
core/variant/dictionary.cpp

@@ -304,9 +304,21 @@ void Dictionary::clear() {
 	_p->variant_map.clear();
 }
 
+struct _DictionaryVariantSort {
+	_FORCE_INLINE_ bool operator()(const KeyValue<Variant, Variant> &p_l, const KeyValue<Variant, Variant> &p_r) const {
+		bool valid = false;
+		Variant res;
+		Variant::evaluate(Variant::OP_LESS, p_l.key, p_r.key, res, valid);
+		if (!valid) {
+			res = false;
+		}
+		return res;
+	}
+};
+
 void Dictionary::sort() {
 	ERR_FAIL_COND_MSG(_p->read_only, "Dictionary is in read-only state.");
-	_p->variant_map.sort();
+	_p->variant_map.sort_custom<_DictionaryVariantSort>();
 }
 
 void Dictionary::merge(const Dictionary &p_dictionary, bool p_overwrite) {

+ 1 - 1
modules/mono/editor/bindings_generator.cpp

@@ -4009,7 +4009,7 @@ bool BindingsGenerator::_populate_object_type_interfaces() {
 
 		List<Pair<MethodInfo, uint32_t>> method_list_with_hashes;
 		ClassDB::get_method_list_with_compatibility(type_cname, &method_list_with_hashes, true);
-		method_list_with_hashes.sort_custom_inplace<SortMethodWithHashes>();
+		method_list_with_hashes.sort_custom<SortMethodWithHashes>();
 
 		List<MethodInterface> compat_methods;
 		for (const Pair<MethodInfo, uint32_t> &E : method_list_with_hashes) {

+ 28 - 0
tests/core/templates/test_hash_map.h

@@ -145,4 +145,32 @@ TEST_CASE("[HashMap] Const iteration") {
 		++idx;
 	}
 }
+
+TEST_CASE("[HashMap] Sort") {
+	HashMap<int, int> hashmap;
+	int shuffled_ints[]{ 6, 1, 9, 8, 3, 0, 4, 5, 7, 2 };
+
+	for (int i : shuffled_ints) {
+		hashmap[i] = i;
+	}
+	hashmap.sort();
+
+	int i = 0;
+	for (const KeyValue<int, int> &kv : hashmap) {
+		CHECK_EQ(kv.key, i);
+		i++;
+	}
+
+	struct ReverseSort {
+		bool operator()(const KeyValue<int, int> &p_a, const KeyValue<int, int> &p_b) {
+			return p_a.key > p_b.key;
+		}
+	};
+	hashmap.sort_custom<ReverseSort>();
+
+	for (const KeyValue<int, int> &kv : hashmap) {
+		i--;
+		CHECK_EQ(kv.key, i);
+	}
+}
 } // namespace TestHashMap

+ 57 - 12
tests/core/templates/test_list.h

@@ -310,19 +310,64 @@ TEST_CASE("[List] Move before") {
 	CHECK(list.front()->next()->get() == n[3]->get());
 }
 
-TEST_CASE("[List] Sort") {
-	List<String> list;
-	list.push_back("D");
-	list.push_back("B");
-	list.push_back("A");
-	list.push_back("C");
-
-	list.sort();
+template <typename T>
+static void compare_lists(const List<T> &p_result, const List<T> &p_expected) {
+	CHECK_EQ(p_result.size(), p_expected.size());
+	const typename List<T>::Element *result_it = p_result.front();
+	const typename List<T>::Element *expected_it = p_expected.front();
+	for (int i = 0; i < p_result.size(); i++) {
+		CHECK(result_it);
+		CHECK(expected_it);
+		CHECK_EQ(result_it->get(), expected_it->get());
+		result_it = result_it->next();
+		expected_it = expected_it->next();
+	}
+	CHECK(!result_it);
+	CHECK(!expected_it);
+
+	result_it = p_result.back();
+	expected_it = p_expected.back();
+	for (int i = 0; i < p_result.size(); i++) {
+		CHECK(result_it);
+		CHECK(expected_it);
+		CHECK_EQ(result_it->get(), expected_it->get());
+		result_it = result_it->prev();
+		expected_it = expected_it->prev();
+	}
+	CHECK(!result_it);
+	CHECK(!expected_it);
+}
 
-	CHECK(list.front()->get() == "A");
-	CHECK(list.front()->next()->get() == "B");
-	CHECK(list.back()->prev()->get() == "C");
-	CHECK(list.back()->get() == "D");
+TEST_CASE("[List] Sort") {
+	List<String> result{ "D", "B", "A", "C" };
+	result.sort();
+	List<String> expected{ "A", "B", "C", "D" };
+	compare_lists(result, expected);
+
+	List<int> empty_result{};
+	empty_result.sort();
+	List<int> empty_expected{};
+	compare_lists(empty_result, empty_expected);
+
+	List<int> one_result{ 1 };
+	one_result.sort();
+	List<int> one_expected{ 1 };
+	compare_lists(one_result, one_expected);
+
+	List<float> reversed_result{ 2.0, 1.5, 1.0 };
+	reversed_result.sort();
+	List<float> reversed_expected{ 1.0, 1.5, 2.0 };
+	compare_lists(reversed_result, reversed_expected);
+
+	List<int> already_sorted_result{ 1, 2, 3, 4, 5 };
+	already_sorted_result.sort();
+	List<int> already_sorted_expected{ 1, 2, 3, 4, 5 };
+	compare_lists(already_sorted_result, already_sorted_expected);
+
+	List<int> with_duplicates_result{ 1, 2, 3, 1, 2, 3 };
+	with_duplicates_result.sort();
+	List<int> with_duplicates_expected{ 1, 1, 2, 2, 3, 3 };
+	compare_lists(with_duplicates_result, with_duplicates_expected);
 }
 
 TEST_CASE("[List] Swap adjacent front and back") {

+ 39 - 10
core/templates/hash_map.cpp → tests/core/templates/test_self_list.h

@@ -1,5 +1,5 @@
 /**************************************************************************/
-/*  hash_map.cpp                                                          */
+/*  test_self_list.h                                                      */
 /**************************************************************************/
 /*                         This file is part of:                          */
 /*                             GODOT ENGINE                               */
@@ -28,16 +28,45 @@
 /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.                 */
 /**************************************************************************/
 
-#include "hash_map.h"
+#pragma once
 
-#include "core/variant/variant.h"
+#include "core/templates/self_list.h"
 
-bool _hashmap_variant_less_than(const Variant &p_left, const Variant &p_right) {
-	bool valid = false;
-	Variant res;
-	Variant::evaluate(Variant::OP_LESS, p_left, p_right, res, valid);
-	if (!valid) {
-		res = false;
+#include "tests/test_macros.h"
+
+namespace TestSelfList {
+
+TEST_CASE("[SelfList] Sort") {
+	const int SIZE = 5;
+	int numbers[SIZE]{ 3, 2, 5, 1, 4 };
+	SelfList<int> elements[SIZE]{
+		SelfList<int>(&numbers[0]),
+		SelfList<int>(&numbers[1]),
+		SelfList<int>(&numbers[2]),
+		SelfList<int>(&numbers[3]),
+		SelfList<int>(&numbers[4]),
+	};
+
+	SelfList<int>::List list;
+	for (int i = 0; i < SIZE; i++) {
+		list.add_last(&elements[i]);
+	}
+
+	SelfList<int> *it = list.first();
+	for (int i = 0; i < SIZE; i++) {
+		CHECK_EQ(numbers[i], *it->self());
+		it = it->next();
+	}
+
+	list.sort();
+	it = list.first();
+	for (int i = 1; i <= SIZE; i++) {
+		CHECK_EQ(i, *it->self());
+		it = it->next();
+	}
+
+	for (SelfList<int> &element : elements) {
+		element.remove_from_list();
 	}
-	return res;
 }
+} // namespace TestSelfList

+ 1 - 0
tests/test_main.cpp

@@ -106,6 +106,7 @@
 #include "tests/core/templates/test_oa_hash_map.h"
 #include "tests/core/templates/test_paged_array.h"
 #include "tests/core/templates/test_rid.h"
+#include "tests/core/templates/test_self_list.h"
 #include "tests/core/templates/test_span.h"
 #include "tests/core/templates/test_vector.h"
 #include "tests/core/templates/test_vset.h"