Browse Source

Unify logic for `slice.sort*` related procedures

gingerBill 3 years ago
parent
commit
1a9ec776cb
2 changed files with 182 additions and 488 deletions
  1. 5 488
      core/slice/sort.odin
  2. 177 0
      core/slice/sort_private.odin

+ 5 - 488
core/slice/sort.odin

@@ -1,10 +1,5 @@
 package slice
 package slice
 
 
-import "core:intrinsics"
-_ :: intrinsics
-
-ORD :: intrinsics.type_is_ordered
-
 Ordering :: enum {
 Ordering :: enum {
 	Less    = -1,
 	Less    = -1,
 	Equal   =  0,
 	Equal   =  0,
@@ -38,7 +33,7 @@ cmp_proc :: proc($E: typeid) -> (proc(E, E) -> Ordering) where ORD(E) {
 sort :: proc(data: $T/[]$E) where ORD(E) {
 sort :: proc(data: $T/[]$E) where ORD(E) {
 	when size_of(E) != 0 {
 	when size_of(E) != 0 {
 		if n := len(data); n > 1 {
 		if n := len(data); n > 1 {
-			_quick_sort(data, 0, n, _max_depth(n))
+			_quick_sort_general(data, 0, n, _max_depth(n), struct{}{}, .Ordered)
 		}
 		}
 	}
 	}
 }
 }
@@ -48,7 +43,7 @@ sort :: proc(data: $T/[]$E) where ORD(E) {
 sort_by :: proc(data: $T/[]$E, less: proc(i, j: E) -> bool) {
 sort_by :: proc(data: $T/[]$E, less: proc(i, j: E) -> bool) {
 	when size_of(E) != 0 {
 	when size_of(E) != 0 {
 		if n := len(data); n > 1 {
 		if n := len(data); n > 1 {
-			_quick_sort_less(data, 0, n, _max_depth(n), less)
+			_quick_sort_general(data, 0, n, _max_depth(n), less, .Less)
 		}
 		}
 	}
 	}
 }
 }
@@ -56,7 +51,7 @@ sort_by :: proc(data: $T/[]$E, less: proc(i, j: E) -> bool) {
 sort_by_cmp :: proc(data: $T/[]$E, cmp: proc(i, j: E) -> Ordering) {
 sort_by_cmp :: proc(data: $T/[]$E, cmp: proc(i, j: E) -> Ordering) {
 	when size_of(E) != 0 {
 	when size_of(E) != 0 {
 		if n := len(data); n > 1 {
 		if n := len(data); n > 1 {
-			_quick_sort_cmp(data, 0, n, _max_depth(n), cmp)
+			_quick_sort_general(data, 0, n, _max_depth(n), cmp, .Cmp)
 		}
 		}
 	}
 	}
 }
 }
@@ -79,6 +74,7 @@ is_sorted_by :: proc(array: $T/[]$E, less: proc(i, j: E) -> bool) -> bool {
 	return true
 	return true
 }
 }
 
 
+is_sorted_by_cmp :: is_sorted_cmp
 is_sorted_cmp :: proc(array: $T/[]$E, cmp: proc(i, j: E) -> Ordering) -> bool {
 is_sorted_cmp :: proc(array: $T/[]$E, cmp: proc(i, j: E) -> Ordering) -> bool {
 	for i := len(array)-1; i > 0; i -= 1 {
 	for i := len(array)-1; i > 0; i -= 1 {
 		if cmp(array[i], array[i-1]) == .Equal {
 		if cmp(array[i], array[i-1]) == .Equal {
@@ -140,489 +136,10 @@ is_sorted_by_key :: proc(array: $T/[]$E, key: proc(E) -> $K) -> bool where ORD(K
 	return true
 	return true
 }
 }
 
 
-
-
 @(private)
 @(private)
-_max_depth :: proc(n: int) -> int { // 2*ceil(log2(n+1))
-	depth: int
+_max_depth :: proc(n: int) -> (depth: int) { // 2*ceil(log2(n+1))
 	for i := n; i > 0; i >>= 1 {
 	for i := n; i > 0; i >>= 1 {
 		depth += 1
 		depth += 1
 	}
 	}
 	return depth * 2
 	return depth * 2
 }
 }
-
-@(private)
-_quick_sort :: proc(data: $T/[]$E, a, b, max_depth: int) where ORD(E) #no_bounds_check {
-	median3 :: proc(data: T, m1, m0, m2: int) #no_bounds_check {
-		if data[m1] < data[m0] {
-			swap(data, m1, m0)
-		}
-		if data[m2] < data[m1] {
-			swap(data, m2, m1)
-			if data[m1] < data[m0] {
-				swap(data, m1, m0)
-			}
-		}
-	}
-
-	do_pivot :: proc(data: T, lo, hi: int) -> (midlo, midhi: int) #no_bounds_check {
-		m := int(uint(lo+hi)>>1)
-		if hi-lo > 40 {
-			s := (hi-lo)/8
-			median3(data, lo, lo+s, lo+s*2)
-			median3(data, m, m-s, m+s)
-			median3(data, hi-1, hi-1-s, hi-1-s*2)
-		}
-		median3(data, lo, m, hi-1)
-
-
-		pivot := lo
-		a, c := lo+1, hi-1
-
-		for ; a < c && data[a] < data[pivot]; a += 1 {
-		}
-		b := a
-
-		for {
-			for ; b < c && !(data[pivot] < data[b]); b += 1 { // data[b] <= pivot
-			}
-			for ; b < c && data[pivot] < data[c-1]; c -=1 { // data[c-1] > pivot
-			}
-			if b >= c {
-				break
-			}
-
-			swap(data, b, c-1)
-			b += 1
-			c -= 1
-		}
-
-		protect := hi-c < 5
-		if !protect && hi-c < (hi-lo)/4 {
-			dups := 0
-			if !(data[pivot] < data[hi-1]) {
-				swap(data, c, hi-1)
-				c += 1
-				dups += 1
-			}
-			if !(data[b-1] < data[pivot]) {
-				b -= 1
-				dups += 1
-			}
-
-			if !(data[m] < data[pivot]) {
-				swap(data, m, b-1)
-				b -= 1
-				dups += 1
-			}
-			protect = dups > 1
-		}
-		if protect {
-			for {
-				for ; a < b && !(data[b-1] < data[pivot]); b -= 1 {
-				}
-				for ; a < b && data[a] < data[pivot]; a += 1 {
-				}
-				if a >= b {
-					break
-				}
-				swap(data, a, b-1)
-				a += 1
-				b -= 1
-			}
-		}
-		swap(data, pivot, b-1)
-		return b-1, c
-	}
-
-
-	a, b, max_depth := a, b, max_depth
-
-	if b-a > 12 { // only use shell sort for lengths <= 12
-		if max_depth == 0 {
-			_heap_sort(data, a, b)
-			return
-		}
-		max_depth -= 1
-		mlo, mhi := do_pivot(data, a, b)
-		if mlo-a < b-mhi {
-			_quick_sort(data, a, mlo, max_depth)
-			a = mhi
-		} else {
-			_quick_sort(data, mhi, b, max_depth)
-			b = mlo
-		}
-	}
-	if b-a > 1 {
-		// Shell short with gap 6
-		for i in a+6..<b {
-			if data[i] < data[i-6] {
-				swap(data, i, i-6)
-			}
-		}
-		_insertion_sort(data, a, b)
-	}
-}
-
-@(private)
-_insertion_sort :: proc(data: $T/[]$E, a, b: int) where ORD(E) #no_bounds_check {
-	for i in a+1..<b {
-		for j := i; j > a && data[j] < data[j-1]; j -= 1 {
-			swap(data, j, j-1)
-		}
-	}
-}
-
-@(private)
-_heap_sort :: proc(data: $T/[]$E, a, b: int) where ORD(E) #no_bounds_check {
-	sift_down :: proc(data: T, lo, hi, first: int) #no_bounds_check {
-		root := lo
-		for {
-			child := 2*root + 1
-			if child >= hi {
-				break
-			}
-			if child+1 < hi && data[first+child] < data[first+child+1] {
-				child += 1
-			}
-			if !(data[first+root] < data[first+child]) {
-				return
-			}
-			swap(data, first+root, first+child)
-			root = child
-		}
-	}
-
-
-	first, lo, hi := a, 0, b-a
-
-	for i := (hi-1)/2; i >= 0; i -= 1 {
-		sift_down(data, i, hi, first)
-	}
-
-	for i := hi-1; i >= 0; i -= 1 {
-		swap(data, first, first+i)
-		sift_down(data, lo, i, first)
-	}
-}
-
-
-
-
-
-
-@(private)
-_quick_sort_less :: proc(data: $T/[]$E, a, b, max_depth: int, less: proc(i, j: E) -> bool) #no_bounds_check {
-	median3 :: proc(data: T, m1, m0, m2: int, less: proc(i, j: E) -> bool) #no_bounds_check {
-		if less(data[m1], data[m0]) {
-			swap(data, m1, m0)
-		}
-		if less(data[m2], data[m1]) {
-			swap(data, m2, m1)
-			if less(data[m1], data[m0]) {
-				swap(data, m1, m0)
-			}
-		}
-	}
-
-	do_pivot :: proc(data: T, lo, hi: int, less: proc(i, j: E) -> bool) -> (midlo, midhi: int) #no_bounds_check {
-		m := int(uint(lo+hi)>>1)
-		if hi-lo > 40 {
-			s := (hi-lo)/8
-			median3(data, lo, lo+s, lo+s*2, less)
-			median3(data, m, m-s, m+s, less)
-			median3(data, hi-1, hi-1-s, hi-1-s*2, less)
-		}
-		median3(data, lo, m, hi-1, less)
-
-		pivot := lo
-		a, c := lo+1, hi-1
-
-		for ; a < c && less(data[a], data[pivot]); a += 1 {
-		}
-		b := a
-
-		for {
-			for ; b < c && !less(data[pivot], data[b]); b += 1 { // data[b] <= pivot
-			}
-			for ; b < c && less(data[pivot], data[c-1]); c -=1 { // data[c-1] > pivot
-			}
-			if b >= c {
-				break
-			}
-
-			swap(data, b, c-1)
-			b += 1
-			c -= 1
-		}
-
-		protect := hi-c < 5
-		if !protect && hi-c < (hi-lo)/4 {
-			dups := 0
-			if !less(data[pivot], data[hi-1]) {
-				swap(data, c, hi-1)
-				c += 1
-				dups += 1
-			}
-			if !less(data[b-1], data[pivot]) {
-				b -= 1
-				dups += 1
-			}
-
-			if !less(data[m], data[pivot]) {
-				swap(data, m, b-1)
-				b -= 1
-				dups += 1
-			}
-			protect = dups > 1
-		}
-		if protect {
-			for {
-				for ; a < b && !less(data[b-1], data[pivot]); b -= 1 {
-				}
-				for ; a < b && less(data[a], data[pivot]); a += 1 {
-				}
-				if a >= b {
-					break
-				}
-				swap(data, a, b-1)
-				a += 1
-				b -= 1
-			}
-		}
-		swap(data, pivot, b-1)
-		return b-1, c
-	}
-
-
-	a, b, max_depth := a, b, max_depth
-
-	if b-a > 12 { // only use shell sort for lengths <= 12
-		if max_depth == 0 {
-			_heap_sort_less(data, a, b, less)
-			return
-		}
-		max_depth -= 1
-		mlo, mhi := do_pivot(data, a, b, less)
-		if mlo-a < b-mhi {
-			_quick_sort_less(data, a, mlo, max_depth, less)
-			a = mhi
-		} else {
-			_quick_sort_less(data, mhi, b, max_depth, less)
-			b = mlo
-		}
-	}
-	if b-a > 1 {
-		// Shell short with gap 6
-		for i in a+6..<b {
-			if less(data[i], data[i-6]) {
-				swap(data, i, i-6)
-			}
-		}
-		_insertion_sort_less(data, a, b, less)
-	}
-}
-
-@(private)
-_insertion_sort_less :: proc(data: $T/[]$E, a, b: int, less: proc(i, j: E) -> bool) #no_bounds_check {
-	for i in a+1..<b {
-		for j := i; j > a && less(data[j], data[j-1]); j -= 1 {
-			swap(data, j, j-1)
-		}
-	}
-}
-
-@(private)
-_heap_sort_less :: proc(data: $T/[]$E, a, b: int, less: proc(i, j: E) -> bool) #no_bounds_check {
-	sift_down :: proc(data: T, lo, hi, first: int, less: proc(i, j: E) -> bool) #no_bounds_check {
-		root := lo
-		for {
-			child := 2*root + 1
-			if child >= hi {
-				break
-			}
-			if child+1 < hi && less(data[first+child], data[first+child+1]) {
-				child += 1
-			}
-			if !less(data[first+root], data[first+child]) {
-				return
-			}
-			swap(data, first+root, first+child)
-			root = child
-		}
-	}
-
-
-	first, lo, hi := a, 0, b-a
-
-	for i := (hi-1)/2; i >= 0; i -= 1 {
-		sift_down(data, i, hi, first, less)
-	}
-
-	for i := hi-1; i >= 0; i -= 1 {
-		swap(data, first, first+i)
-		sift_down(data, lo, i, first, less)
-	}
-}
-
-
-
-
-
-
-@(private)
-_quick_sort_cmp :: proc(data: $T/[]$E, a, b, max_depth: int, cmp: proc(i, j: E) -> Ordering) #no_bounds_check {
-	median3 :: proc(data: T, m1, m0, m2: int, cmp: proc(i, j: E) -> Ordering) #no_bounds_check {
-		if cmp(data[m1], data[m0]) == .Less {
-			swap(data, m1, m0)
-		}
-		if cmp(data[m2], data[m1]) == .Less {
-			swap(data, m2, m1)
-			if cmp(data[m1], data[m0]) == .Less {
-				swap(data, m1, m0)
-			}
-		}
-	}
-
-	do_pivot :: proc(data: T, lo, hi: int, cmp: proc(i, j: E) -> Ordering) -> (midlo, midhi: int) #no_bounds_check {
-		m := int(uint(lo+hi)>>1)
-		if hi-lo > 40 {
-			s := (hi-lo)/8
-			median3(data, lo, lo+s, lo+s*2, cmp)
-			median3(data, m, m-s, m+s, cmp)
-			median3(data, hi-1, hi-1-s, hi-1-s*2, cmp)
-		}
-		median3(data, lo, m, hi-1, cmp)
-
-		pivot := lo
-		a, c := lo+1, hi-1
-
-		for ; a < c && cmp(data[a], data[pivot]) == .Less; a += 1 {
-		}
-		b := a
-
-		for {
-			for ; b < c && cmp(data[pivot], data[b]) >= .Equal; b += 1 { // data[b] <= pivot
-			}
-			for ; b < c && cmp(data[pivot], data[c-1]) == .Less; c -=1 { // data[c-1] > pivot
-			}
-			if b >= c {
-				break
-			}
-
-			swap(data, b, c-1)
-			b += 1
-			c -= 1
-		}
-
-		protect := hi-c < 5
-		if !protect && hi-c < (hi-lo)/4 {
-			dups := 0
-			if cmp(data[pivot], data[hi-1]) != .Less {
-				swap(data, c, hi-1)
-				c += 1
-				dups += 1
-			}
-			if cmp(data[b-1], data[pivot]) != .Less {
-				b -= 1
-				dups += 1
-			}
-
-			if cmp(data[m], data[pivot]) != .Less {
-				swap(data, m, b-1)
-				b -= 1
-				dups += 1
-			}
-			protect = dups > 1
-		}
-		if protect {
-			for {
-				for ; a < b && cmp(data[b-1], data[pivot]) >= .Equal; b -= 1 {
-				}
-				for ; a < b && cmp(data[a], data[pivot]) == .Less; a += 1 {
-				}
-				if a >= b {
-					break
-				}
-				swap(data, a, b-1)
-				a += 1
-				b -= 1
-			}
-		}
-		swap(data, pivot, b-1)
-		return b-1, c
-	}
-
-
-	a, b, max_depth := a, b, max_depth
-
-	if b-a > 12 { // only use shell sort for lengths <= 12
-		if max_depth == 0 {
-			_heap_sort_cmp(data, a, b, cmp)
-			return
-		}
-		max_depth -= 1
-		mlo, mhi := do_pivot(data, a, b, cmp)
-		if mlo-a < b-mhi {
-			_quick_sort_cmp(data, a, mlo, max_depth, cmp)
-			a = mhi
-		} else {
-			_quick_sort_cmp(data, mhi, b, max_depth, cmp)
-			b = mlo
-		}
-	}
-	if b-a > 1 {
-		// Shell short with gap 6
-		for i in a+6..<b {
-			if cmp(data[i], data[i-6]) == .Less {
-				swap(data, i, i-6)
-			}
-		}
-		_insertion_sort_cmp(data, a, b, cmp)
-	}
-}
-
-@(private)
-_insertion_sort_cmp :: proc(data: $T/[]$E, a, b: int, cmp: proc(i, j: E) -> Ordering) #no_bounds_check {
-	for i in a+1..<b {
-		for j := i; j > a && cmp(data[j], data[j-1]) == .Less; j -= 1 {
-			swap(data, j, j-1)
-		}
-	}
-}
-
-@(private)
-_heap_sort_cmp :: proc(data: $T/[]$E, a, b: int, cmp: proc(i, j: E) -> Ordering) #no_bounds_check {
-	sift_down :: proc(data: T, lo, hi, first: int, cmp: proc(i, j: E) -> Ordering) #no_bounds_check {
-		root := lo
-		for {
-			child := 2*root + 1
-			if child >= hi {
-				break
-			}
-			if child+1 < hi && cmp(data[first+child], data[first+child+1]) == .Less {
-				child += 1
-			}
-			if cmp(data[first+root], data[first+child]) >= .Equal {
-				return
-			}
-			swap(data, first+root, first+child)
-			root = child
-		}
-	}
-
-
-	first, lo, hi := a, 0, b-a
-
-	for i := (hi-1)/2; i >= 0; i -= 1 {
-		sift_down(data, i, hi, first, cmp)
-	}
-
-	for i := hi-1; i >= 0; i -= 1 {
-		swap(data, first, first+i)
-		sift_down(data, lo, i, first, cmp)
-	}
-}
-
-
-

+ 177 - 0
core/slice/sort_private.odin

@@ -0,0 +1,177 @@
+//+private
+package slice
+
+import "core:intrinsics"
+_ :: intrinsics
+
+ORD :: intrinsics.type_is_ordered
+
+Sort_Kind :: enum {
+	Ordered,
+	Less,
+	Cmp,
+}
+
+_quick_sort_general :: proc(data: $T/[]$E, a, b, max_depth: int, call: $P, $KIND: Sort_Kind) where (ORD(E) && KIND == .Ordered) || (KIND != .Ordered) #no_bounds_check {
+	less :: #force_inline proc(a, b: $E, call: $P) -> bool {
+		when KIND == .Ordered {
+			return a < b
+		} else when KIND == .Less {
+			return call(a, b)
+		} else when KIND == .Cmp {
+			return call(a, b) == .Less
+		} else {
+			#panic("unhandled Sort_Kind")
+		}
+	}
+
+	insertion_sort :: proc(data: $T/[]$E, a, b: int, call: P) #no_bounds_check {
+		for i in a+1..<b {
+			for j := i; j > a && less(data[j], data[j-1], call); j -= 1 {
+				swap(data, j, j-1)
+			}
+		}
+	}
+
+	heap_sort :: proc(data: $T/[]$E, a, b: int, call: P) #no_bounds_check {
+		sift_down :: proc(data: T, lo, hi, first: int, call: P) #no_bounds_check {
+			root := lo
+			for {
+				child := 2*root + 1
+				if child >= hi {
+					break
+				}
+				if child+1 < hi && less(data[first+child], data[first+child+1], call) {
+					child += 1
+				}
+				if !less(data[first+root], data[first+child], call) {
+					return
+				}
+				swap(data, first+root, first+child)
+				root = child
+			}
+		}
+
+
+		first, lo, hi := a, 0, b-a
+
+		for i := (hi-1)/2; i >= 0; i -= 1 {
+			sift_down(data, i, hi, first, call)
+		}
+
+		for i := hi-1; i >= 0; i -= 1 {
+			swap(data, first, first+i)
+			sift_down(data, lo, i, first, call)
+		}
+	}
+
+	median3 :: proc(data: T, m1, m0, m2: int, call: P) #no_bounds_check {
+		if less(data[m1], data[m0], call) {
+			swap(data, m1, m0)
+		}
+		if less(data[m2], data[m1], call) {
+			swap(data, m2, m1)
+			if less(data[m1], data[m0], call) {
+				swap(data, m1, m0)
+			}
+		}
+	}
+
+	do_pivot :: proc(data: T, lo, hi: int, call: P) -> (midlo, midhi: int) #no_bounds_check {
+		m := int(uint(lo+hi)>>1)
+		if hi-lo > 40 {
+			s := (hi-lo)/8
+			median3(data, lo, lo+s, lo+s*2, call)
+			median3(data, m, m-s, m+s, call)
+			median3(data, hi-1, hi-1-s, hi-1-s*2, call)
+		}
+		median3(data, lo, m, hi-1, call)
+
+		pivot := lo
+		a, c := lo+1, hi-1
+
+
+		for ; a < c && less(data[a], data[pivot], call); a += 1 {
+		}
+		b := a
+
+		for {
+			for ; b < c && !less(data[pivot], data[b], call); b += 1 { // data[b] <= pivot
+			}
+			for ; b < c && less(data[pivot], data[c-1], call); c -=1 { // data[c-1] > pivot
+			}
+			if b >= c {
+				break
+			}
+
+			swap(data, b, c-1)
+			b += 1
+			c -= 1
+		}
+
+		protect := hi-c < 5
+		if !protect && hi-c < (hi-lo)/4 {
+			dups := 0
+			if !less(data[pivot], data[hi-1], call) {
+				swap(data, c, hi-1)
+				c += 1
+				dups += 1
+			}
+			if !less(data[b-1], data[pivot], call) {
+				b -= 1
+				dups += 1
+			}
+
+			if !less(data[m], data[pivot], call) {
+				swap(data, m, b-1)
+				b -= 1
+				dups += 1
+			}
+			protect = dups > 1
+		}
+		if protect {
+			for {
+				for ; a < b && !less(data[b-1], data[pivot], call); b -= 1 {
+				}
+				for ; a < b && less(data[a], data[pivot], call); a += 1 {
+				}
+				if a >= b {
+					break
+				}
+				swap(data, a, b-1)
+				a += 1
+				b -= 1
+			}
+		}
+		swap(data, pivot, b-1)
+		return b-1, c
+	}
+
+
+	a, b, max_depth := a, b, max_depth
+
+	if b-a > 12 { // only use shell sort for lengths <= 12
+		if max_depth == 0 {
+			heap_sort(data, a, b, call)
+			return
+		}
+		max_depth -= 1
+		mlo, mhi := do_pivot(data, a, b, call)
+		if mlo-a < b-mhi {
+			_quick_sort_general(data, a, mlo, max_depth, call, KIND)
+			a = mhi
+		} else {
+			_quick_sort_general(data, mhi, b, max_depth, call, KIND)
+			b = mlo
+		}
+	}
+	if b-a > 1 {
+		// Shell short with gap 6
+		for i in a+6..<b {
+			if less(data[i], data[i-6], call) {
+				swap(data, i, i-6)
+			}
+		}
+		insertion_sort(data, a, b, call)
+	}
+}