Pārlūkot izejas kodu

Merge pull request #2979 from rope-hmg/master

Binary search improvements
Jeroen van Rijn 1 gadu atpakaļ
vecāks
revīzija
e8e3501443
4 mainīti faili ar 174 papildinājumiem un 37 dzēšanām
  1. 83 34
      core/slice/slice.odin
  2. 23 3
      tests/core/Makefile
  3. 5 0
      tests/core/build.bat
  4. 63 0
      tests/core/slice/test_core_slice.odin

+ 83 - 34
core/slice/slice.odin

@@ -49,7 +49,7 @@ to_bytes :: proc "contextless" (s: []$T) -> []byte {
 	```
 	```
 	small_items := []byte{1, 0, 0, 0, 0, 0, 0, 0,
-	                      2, 0, 0, 0}
+						  2, 0, 0, 0}
 	large_items := slice.reinterpret([]i64, small_items)
 	assert(len(large_items) == 1) // only enough bytes to make 1 x i64; two would need at least 16 bytes.
 	```
@@ -78,7 +78,7 @@ swap_between :: proc(a, b: $T/[]$E) {
 	n := builtin.min(len(a), len(b))
 	if n >= 0 {
 		ptr_swap_overlapping(&a[0], &b[0], size_of(E)*n)
-	}	
+	}
 }
 
 
@@ -117,46 +117,95 @@ linear_search_proc :: proc(array: $A/[]$T, f: proc(T) -> bool) -> (index: int, f
 	return -1, false
 }
 
-@(require_results)
-binary_search :: proc(array: $A/[]$T, key: T) -> (index: int, found: bool)
-	where intrinsics.type_is_ordered(T) #no_bounds_check {
+/*
+	Binary search searches the given slice for the given element.
+	If the slice is not sorted, the returned index is unspecified and meaningless.
 
-	n := len(array)
-	switch n {
-	case 0:
-		return -1, false
-	case 1:
-		if array[0] == key {
-			return 0, true
-		}
-		return -1, false
-	}
+	If the value is found then the returned int is the index of the matching element.
+	If there are multiple matches, then any one of the matches could be returned.
 
-	lo, hi := 0, n-1
+	If the value is not found then the returned int is the index where a matching
+	element could be inserted while maintaining sorted order.
 
-	for array[hi] != array[lo] && key >= array[lo] && key <= array[hi] {
-		when intrinsics.type_is_ordered_numeric(T) {
-			// NOTE(bill): This is technically interpolation search
-			m := lo + int((key - array[lo]) * T(hi - lo) / (array[hi] - array[lo]))
-		} else {
-			m := lo + (hi - lo)/2
-		}
+	# Examples
+
+	Looks up a series of four elements. The first is found, with a
+	uniquely determined position; the second and third are not
+	found; the fourth could match any position in `[1, 4]`.
+
+	```
+	index: int
+	found: bool
+
+	s := []i32{0, 1, 1, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55}
+
+	index, found = slice.binary_search(s, 13)
+	assert(index == 9 && found == true)
+
+	index, found = slice.binary_search(s, 4)
+	assert(index == 7 && found == false)
+
+	index, found = slice.binary_search(s, 100)
+	assert(index == 13 && found == false)
+
+	index, found = slice.binary_search(s, 1)
+	assert(index >= 1 && index <= 4 && found == true)
+	```
+
+	For slices of more complex types see: binary_search_by
+*/
+@(require_results)
+binary_search :: proc(array: $A/[]$T, key: T) -> (index: int, found: bool)
+	where intrinsics.type_is_ordered(T) #no_bounds_check
+{
+	// I would like to use binary_search_by(array, key, cmp) here, but it doesn't like it:
+	// Cannot assign value 'cmp' of type 'proc($E, $E) -> Ordering' to 'proc(i32, i32) -> Ordering' in argument
+	//
+	// The inline comparator below is invoked by binary_search_by as f(key, element) and
+	// reports how `element` orders relative to `key`: .Less when the element is smaller
+	// than the key, .Greater when it is larger, and .Equal otherwise.
+	return binary_search_by(array, key, proc(key: T, element: T) -> Ordering {
 		switch {
-		case array[m] < key:
-			lo = m + 1
-		case key < array[m]:
-			hi = m - 1
-		case:
-			return m, true
+			case element < key: return .Less
+			case element > key: return .Greater
+			case:               return .Equal
+		}
+	})
+}
+
+/*
+	Binary search over a sorted slice using a caller-supplied comparator.
+
+	`f` is called as `f(key, element)` and must report how `element` orders
+	relative to `key`:
+	- `.Less`    when the element sorts before the key (the search continues to the right),
+	- `.Greater` when the element sorts after the key (the search continues to the left),
+	- `.Equal`   when the element matches the key.
+
+	The slice must be sorted consistently with `f`, otherwise the result is meaningless.
+
+	Returns the index of a matching element and `true`, or, when nothing matches,
+	the index at which `key` could be inserted while keeping the order and `false`.
+	If several elements match, any one of them may be returned.
+
+	All comparisons go through `f`, so `T` does not have to be an ordered type.
+*/
+@(require_results)
+binary_search_by :: proc(array: $A/[]$T, key: T, f: proc(T, T) -> Ordering) -> (index: int, found: bool) #no_bounds_check {
+	// INVARIANTS:
+	// - 0 <= left <= (left + size = right) <= len(array)
+	// - f returns .Less    for everything in array[:left]
+	// - f returns .Greater for everything in array[right:]
+	size  := len(array)
+	left  := 0
+	right := size
+
+	for left < right {
+		mid := left + size / 2
+
+		// Steps to verify this is in-bounds:
+		// 1. We note that `size` is strictly positive due to the loop condition
+		// 2. Therefore `size/2 < size`
+		// 3. Adding `left` to both sides yields `(left + size/2) < (left + size)`
+		// 4. We know from the invariant that `left + size <= len(array)`
+		// 5. Therefore `left + size/2 < len(array)`
+		cmp := f(key, array[mid])
+
+		// Narrow the window exactly once per iteration, or stop on a match.
+		switch cmp {
+			case .Equal:   return mid, true
+			case .Less:    left  = mid + 1
+			case .Greater: right = mid
 		}
-	}
 
-	if key == array[lo] {
-		return lo, true
+		size = right - left
 	}
-	return -1, false
-}
 
+	// Not found: `left` is the insertion point that keeps the slice sorted.
+	return left, false
+}
 
 @(require_results)
 equal :: proc(a, b: $T/[]$E) -> bool where intrinsics.type_is_comparable(E) {

+ 23 - 3
tests/core/Makefile

@@ -1,9 +1,26 @@
 ODIN=../../odin
 PYTHON=$(shell which python3)
 
-all: download_test_assets image_test compress_test strings_test hash_test crypto_test noise_test encoding_test \
-	 math_test linalg_glsl_math_test filepath_test reflect_test os_exit_test i18n_test match_test c_libc_test net_test \
-	 fmt_test thread_test
+# Aggregate target: lists every core test target as a prerequisite, one per
+# line, in alphabetical order.
+# NOTE(review): download_test_assets used to be the first prerequisite; it now
+# comes after c_libc_test, compress_test and crypto_test, so in a serial build
+# those three run before the assets are fetched - confirm none of them need
+# the downloaded files.
+all: c_libc_test \
+	 compress_test \
+	 crypto_test \
+	 download_test_assets \
+	 encoding_test \
+	 filepath_test \
+	 fmt_test \
+	 hash_test \
+	 i18n_test \
+	 image_test \
+	 linalg_glsl_math_test \
+	 match_test \
+	 math_test \
+	 net_test \
+	 noise_test \
+	 os_exit_test \
+	 reflect_test \
+	 slice_test \
+	 strings_test \
+	 thread_test
 
 download_test_assets:
 	$(PYTHON) download_assets.py
@@ -44,6 +61,9 @@ filepath_test:
 reflect_test:
 	$(ODIN) run reflect/test_core_reflect.odin -file -collection:tests=.. -out:test_core_reflect
 
+# Runs the core:slice tests from slice/test_core_slice.odin as a single-file
+# package (-file); no -collection flag is passed for this suite.
+slice_test:
+	$(ODIN) run slice/test_core_slice.odin -file -out:test_core_slice
+
 os_exit_test:
 	$(ODIN) run os/test_core_os_exit.odin -file -out:test_core_os_exit && exit 1 || exit 0
 

+ 5 - 0
tests/core/build.bat

@@ -66,6 +66,11 @@ echo Running core:reflect tests
 echo ---
 %PATH_TO_ODIN% run reflect %COMMON% %COLLECTION% -out:test_core_reflect.exe || exit /b
 
+rem Runs the core:slice test package; stops the script if the run fails.
+rem Unlike the reflect step above, the COLLECTION flags are not passed here.
+echo ---
+echo Running core:slice tests
+echo ---
+%PATH_TO_ODIN% run slice %COMMON% -out:test_core_slice.exe || exit /b
+
 echo ---
 echo Running core:text/i18n tests
 echo ---

+ 63 - 0
tests/core/slice/test_core_slice.odin

@@ -1,6 +1,7 @@
 package test_core_slice
 
 import "core:slice"
+import "core:strings"
 import "core:testing"
 import "core:fmt"
 import "core:os"
@@ -30,6 +31,7 @@ when ODIN_TEST {
 main :: proc() {
 	t := testing.T{}
 	test_sort_with_indices(&t)
+	test_binary_search(&t)
 
 	fmt.printf("%v/%v tests successful.\n", TEST_count - TEST_fail, TEST_count)
 	if TEST_fail > 0 {
@@ -180,3 +182,64 @@ test_sort_by_indices :: proc(t: ^testing.T) {
 		}
 	}
 }
+
+// Checks slice.binary_search against a sorted slice with duplicates, an empty
+// slice and a one-element slice, covering hits, misses and insertion points.
+@test
+test_binary_search :: proc(t: ^testing.T) {
+	sb := strings.Builder{}
+	defer strings.builder_destroy(&sb)
+
+	// Runs a single search, logging the query and its outcome through the shared builder.
+	search_and_log :: proc(t: ^testing.T, sb: ^strings.Builder, haystack: []i32, needle: i32) -> (int, bool) {
+		log(t, fmt.sbprintf(sb, "Searching for %v in %v", needle, haystack))
+		strings.builder_reset(sb)
+
+		at, ok := slice.binary_search(haystack, needle)
+
+		log(t, fmt.sbprintf(sb, "index: %v, found: %v", at, ok))
+		strings.builder_reset(sb)
+
+		return at, ok
+	}
+
+	at: int
+	ok: bool
+
+	sorted := []i32{0, 1, 1, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55}
+
+	// Present exactly once.
+	at, ok = search_and_log(t, &sb, sorted, 13)
+	expect(t, at == 9, "Expected index to be 9.")
+	expect(t, ok, "Expected found to be true.")
+
+	// Absent, insertion point in the middle.
+	at, ok = search_and_log(t, &sb, sorted, 4)
+	expect(t, at == 7, "Expected index to be 7.")
+	expect(t, !ok, "Expected found to be false.")
+
+	// Absent, larger than every element.
+	at, ok = search_and_log(t, &sb, sorted, 100)
+	expect(t, at == 13, "Expected index to be 13.")
+	expect(t, !ok, "Expected found to be false.")
+
+	// Present several times: any of the matching positions is acceptable.
+	at, ok = search_and_log(t, &sb, sorted, 1)
+	expect(t, 1 <= at && at <= 4, "Expected index to be 1, 2, 3, or 4.")
+	expect(t, ok, "Expected found to be true.")
+
+	// Absent, smaller than every element.
+	at, ok = search_and_log(t, &sb, sorted, -1)
+	expect(t, at == 0, "Expected index to be 0.")
+	expect(t, !ok, "Expected found to be false.")
+
+	// Empty slice.
+	empty := []i32{}
+
+	at, ok = search_and_log(t, &sb, empty, 13)
+	expect(t, at == 0, "Expected index to be 0.")
+	expect(t, !ok, "Expected found to be false.")
+
+	// Single-element slice: key above, equal to, and below the element.
+	single := []i32{1}
+
+	at, ok = search_and_log(t, &sb, single, 13)
+	expect(t, at == 1, "Expected index to be 1.")
+	expect(t, !ok, "Expected found to be false.")
+
+	at, ok = search_and_log(t, &sb, single, 1)
+	expect(t, at == 0, "Expected index to be 0.")
+	expect(t, ok, "Expected found to be true.")
+
+	at, ok = search_and_log(t, &sb, single, 0)
+	expect(t, at == 0, "Expected index to be 0.")
+	expect(t, !ok, "Expected found to be false.")
+}