Browse Source

Fix SortArray crashing with bad comparison functions

poke1024 7 years ago
parent
commit
9d27bd3c3b
2 changed files with 31 additions and 5 deletions
  1. 1 1
      core/array.cpp
  2. 30 4
      core/sort.h

+ 1 - 1
core/array.cpp

@@ -259,7 +259,7 @@ Array &Array::sort_custom(Object *p_obj, const StringName &p_function) {
 
 	ERR_FAIL_NULL_V(p_obj, *this);
 
-	SortArray<Variant, _ArrayVariantSortCustom> avs;
+	SortArray<Variant, _ArrayVariantSortCustom, true> avs;
 	avs.compare.obj = p_obj;
 	avs.compare.func = p_function;
 	avs.sort(_p->array.ptrw(), _p->array.size());

+ 30 - 4
core/sort.h

@@ -36,13 +36,25 @@
 	@author ,,, <red@lunatea>
 */
 
+#define ERR_BAD_COMPARE(cond)                                         \
+	if (unlikely(cond)) {                                             \
+		ERR_PRINT("bad comparison function; sorting will be broken"); \
+		break;                                                        \
+	}
+
 template <class T>
 struct _DefaultComparator {
 
-	inline bool operator()(const T &a, const T &b) const { return (a < b); }
+	_FORCE_INLINE_ bool operator()(const T &a, const T &b) const { return (a < b); }
 };
 
-template <class T, class Comparator = _DefaultComparator<T> >
+#ifdef DEBUG_ENABLED
+#define SORT_ARRAY_VALIDATE_ENABLED true
+#else
+#define SORT_ARRAY_VALIDATE_ENABLED false
+#endif
+
+template <class T, class Comparator = _DefaultComparator<T>, bool Validate = SORT_ARRAY_VALIDATE_ENABLED>
 class SortArray {
 
 	enum {
@@ -164,12 +176,23 @@ public:
 
 	inline int partitioner(int p_first, int p_last, T p_pivot, T *p_array) const {
 
+		const int unmodified_first = p_first;
+		const int unmodified_last = p_last;
+
 		while (true) {
-			while (compare(p_array[p_first], p_pivot))
+			while (compare(p_array[p_first], p_pivot)) {
+				if (Validate) {
+					ERR_BAD_COMPARE(p_first == unmodified_last - 1)
+				}
 				p_first++;
+			}
 			p_last--;
-			while (compare(p_pivot, p_array[p_last]))
+			while (compare(p_pivot, p_array[p_last])) {
+				if (Validate) {
+					ERR_BAD_COMPARE(p_last == unmodified_first)
+				}
 				p_last--;
+			}
 
 			if (!(p_first < p_last))
 				return p_first;
@@ -238,6 +261,9 @@ public:
 
 		int next = p_last - 1;
 		while (compare(p_value, p_array[next])) {
+			if (Validate) {
+				ERR_BAD_COMPARE(next == 0)
+			}
 			p_array[p_last] = p_array[next];
 			p_last = next;
 			next--;