Browse Source

Merge pull request #5249 from Kelimion/simd_prefix

Vectorize `strings.prefix_length`.
Jeroen van Rijn 3 months ago
parent
commit
c80f3db3a6

+ 68 - 0
base/runtime/internal.odin

@@ -405,6 +405,74 @@ memory_compare_zero :: proc "contextless" (a: rawptr, n: int) -> int #no_bounds_
 	return 0
 }
 
+memory_prefix_length :: proc "contextless" (x, y: rawptr, n: int) -> (idx: int) #no_bounds_check {
+	switch {
+	case x == y:   return n
+	case x == nil: return 0
+	case y == nil: return 0
+	}
+	a, b := cast([^]byte)x, cast([^]byte)y
+
+	n := uint(n)
+	i := uint(0)
+	m := uint(0)
+
+	when HAS_HARDWARE_SIMD {
+		when ODIN_ARCH == .amd64 && intrinsics.has_target_feature("avx2") {
+			m = n / 32 * 32
+			for /**/; i < m; i += 32 {
+				load_a := intrinsics.unaligned_load(cast(^#simd[32]u8)&a[i])
+				load_b := intrinsics.unaligned_load(cast(^#simd[32]u8)&b[i])
+				comparison := intrinsics.simd_lanes_ne(load_a, load_b)
+				if intrinsics.simd_reduce_or(comparison) != 0 {
+					sentinel: #simd[32]u8 = u8(0xFF)
+					indices := intrinsics.simd_indices(#simd[32]u8)
+					index_select := intrinsics.simd_select(comparison, indices, sentinel)
+					index_reduce := cast(uint)intrinsics.simd_reduce_min(index_select)
+					return i + index_reduce
+				}
+			}
+		}
+	}
+
+	m = (n-i) / 16 * 16
+	for /**/; i < m; i += 16 {
+		load_a := intrinsics.unaligned_load(cast(^#simd[16]u8)&a[i])
+		load_b := intrinsics.unaligned_load(cast(^#simd[16]u8)&b[i])
+		comparison := intrinsics.simd_lanes_ne(load_a, load_b)
+		if intrinsics.simd_reduce_or(comparison) != 0 {
+			sentinel: #simd[16]u8 = u8(0xFF)
+			indices := intrinsics.simd_indices(#simd[16]u8)
+			index_select := intrinsics.simd_select(comparison, indices, sentinel)
+			index_reduce := cast(uint)intrinsics.simd_reduce_min(index_select)
+			return int(i + index_reduce)
+		}
+	}
+
+	// 64-bit SIMD is faster than using a `uintptr` to detect a difference then
+	// re-iterating with the byte-by-byte loop, at least on AMD64.
+	m = (n-i) / 8 * 8
+	for /**/; i < m; i += 8 {
+		load_a := intrinsics.unaligned_load(cast(^#simd[8]u8)&a[i])
+		load_b := intrinsics.unaligned_load(cast(^#simd[8]u8)&b[i])
+		comparison := intrinsics.simd_lanes_ne(load_a, load_b)
+		if intrinsics.simd_reduce_or(comparison) != 0 {
+			sentinel: #simd[8]u8 = u8(0xFF)
+			indices := intrinsics.simd_indices(#simd[8]u8)
+			index_select := intrinsics.simd_select(comparison, indices, sentinel)
+			index_reduce := cast(uint)intrinsics.simd_reduce_min(index_select)
+			return int(i + index_reduce)
+		}
+	}
+
+	for /**/; i < n; i += 1 {
+		if a[i] ~ b[i] != 0 {
+			return int(i)
+		}
+	}
+	return int(n)
+}
+
 string_eq :: proc "contextless" (lhs, rhs: string) -> bool {
 	x := transmute(Raw_String)lhs
 	y := transmute(Raw_String)rhs

+ 46 - 17
core/strings/strings.odin

@@ -2,6 +2,7 @@
 package strings
 
 import "base:intrinsics"
+import "base:runtime"
 import "core:bytes"
 import "core:io"
 import "core:mem"
@@ -458,6 +459,7 @@ equal_fold :: proc(u, v: string) -> (res: bool) {
 
 	return s == t
 }
+
 /*
 Returns the prefix length common between strings `a` and `b`
 
@@ -488,30 +490,57 @@ Output:
 	0
 
 */
-prefix_length :: proc(a, b: string) -> (n: int) {
-	_len := min(len(a), len(b))
-
-	// Scan for matches including partial codepoints.
-	#no_bounds_check for n < _len && a[n] == b[n] {
-		n += 1
-	}
+prefix_length :: proc "contextless" (a, b: string) -> (n: int) {
+	RUNE_ERROR :: '\ufffd'
+	RUNE_SELF  :: 0x80
+	UTF_MAX    :: 4
 
-	// Now scan to ignore partial codepoints.
-	if n > 0 {
-		s := a[:n]
-		n = 0
-		for {
-			r0, w := utf8.decode_rune(s[n:])
-			if r0 != utf8.RUNE_ERROR {
-				n += w
-			} else {
-				break
+	n    = runtime.memory_prefix_length(raw_data(a), raw_data(b), min(len(a), len(b)))
+	lim := max(n - UTF_MAX + 1, 0)
+	for l := n; l > lim; l -= 1 {
+		r, _ := runtime.string_decode_rune(a[l - 1:])
+		if r != RUNE_ERROR {
+			if l > 0 && (a[l - 1] & 0xc0 == 0xc0) {
+				return l - 1
 			}
+			return l
 		}
 	}
 	return
 }
 /*
+Returns the common prefix between strings `a` and `b`
+
+Inputs:
+- a: The first input string
+- b: The second input string
+
+Returns:
+- n: The string prefix common between strings `a` and `b`
+
+Example:
+
+	import "core:fmt"
+	import "core:strings"
+
+	common_prefix_example :: proc() {
+		fmt.println(strings.common_prefix("testing", "test"))
+		fmt.println(strings.common_prefix("testing", "te"))
+		fmt.println(strings.common_prefix("telephone", "te"))
+	}
+
+Output:
+
+	test
+	te
+	te
+
+
+*/
+common_prefix :: proc(a, b: string) -> string {
+	return a[:prefix_length(a, b)]
+}
+/*
 Determines if a string `s` starts with a given `prefix`
 
 Inputs:

+ 1 - 0
tests/benchmark/all.odin

@@ -4,3 +4,4 @@ package benchmarks
 @(require) import "crypto"
 @(require) import "hash"
 @(require) import "text/regex"
+@(require) import "strings"

+ 131 - 0
tests/benchmark/strings/benchmark_strings.odin

@@ -0,0 +1,131 @@
+package benchmark_strings
+
+import "base:intrinsics"
+import "core:fmt"
+import "core:log"
+import "core:testing"
+import "core:strings"
+import "core:text/table"
+import "core:time"
+import "core:unicode/utf8"
+
+RUNS_PER_SIZE :: 2500
+
+sizes := [?]int {
+	7, 8, 9,
+	15, 16, 17,
+	31, 32, 33,
+	63, 64, 65,
+	95, 96, 97,
+	128,
+	256,
+	512,
+	1024,
+	4096,
+}
+
+// These are the normal, unoptimized algorithms.
+
+plain_prefix_length :: proc "contextless" (a, b: string) -> (n: int) {
+	_len := min(len(a), len(b))
+
+	// Scan for matches including partial codepoints.
+	#no_bounds_check for n < _len && a[n] == b[n] {
+		n += 1
+	}
+
+	// Now scan to ignore partial codepoints.
+	if n > 0 {
+		s := a[:n]
+		n = 0
+		for {
+			r0, w := utf8.decode_rune(s[n:])
+			if r0 != utf8.RUNE_ERROR {
+				n += w
+			} else {
+				break
+			}
+		}
+	}
+	return
+}
+
+run_trial_size_prefix :: proc(p: proc "contextless" (string, string) -> $R, suffix: string, size: int, idx: int, runs: int, loc := #caller_location) -> (timing: time.Duration) {
+	left  := make([]u8, size)
+	right := make([]u8, size)
+	defer {
+		delete(left)
+		delete(right)
+	}
+
+	if len(suffix) > 0 {
+		copy(left [idx:], suffix)
+		copy(right[idx:], suffix)
+
+	} else {
+		right[idx] = 'A'
+	}
+
+	accumulator: int
+
+	watch: time.Stopwatch
+
+	time.stopwatch_start(&watch)
+	for _ in 0..<runs {
+		result := p(string(left[:size]), string(right[:size]))
+		accumulator += result
+	}
+	time.stopwatch_stop(&watch)
+	timing = time.stopwatch_duration(watch)
+
+	log.debug(accumulator)
+	return
+}
+
+run_trial_size :: proc {
+	run_trial_size_prefix,
+}
+
+bench_table_size :: proc(algo_name: string, plain, simd: $P, suffix := "") {
+	string_buffer := strings.builder_make()
+	defer strings.builder_destroy(&string_buffer)
+
+	tbl: table.Table
+	table.init(&tbl)
+	defer table.destroy(&tbl)
+
+	table.aligned_header_of_values(&tbl, .Right, "Algorithm", "Size", "Iterations", "Scalar", "SIMD", "SIMD Relative (%)", "SIMD Relative (x)")
+
+	for size in sizes {
+		// Place the non-zero byte somewhere in the middle.
+		needle_index := size / 2
+
+		plain_timing := run_trial_size(plain, suffix, size, needle_index, RUNS_PER_SIZE)
+		simd_timing  := run_trial_size(simd,  suffix, size, needle_index, RUNS_PER_SIZE)
+
+		_plain := fmt.tprintf("%8M",  plain_timing)
+		_simd  := fmt.tprintf("%8M",  simd_timing)
+		_relp  := fmt.tprintf("%.3f %%", f64(simd_timing) / f64(plain_timing) * 100.0)
+		_relx  := fmt.tprintf("%.3f x",  1 / (f64(simd_timing) / f64(plain_timing)))
+
+		table.aligned_row_of_values(
+			&tbl,
+			.Right,
+			algo_name,
+			size, RUNS_PER_SIZE, _plain, _simd, _relp, _relx)
+	}
+
+	builder_writer := strings.to_writer(&string_buffer)
+
+	fmt.sbprintln(&string_buffer)
+	table.write_plain_table(builder_writer, &tbl)
+
+	my_table_string := strings.to_string(string_buffer)
+	log.info(my_table_string)
+}
+
+@test
+benchmark_memory_procs :: proc(t: ^testing.T) {
+	bench_table_size("prefix_length ascii",   plain_prefix_length, strings.prefix_length)
+	bench_table_size("prefix_length unicode", plain_prefix_length, strings.prefix_length, "🦉")
+}

+ 50 - 1
tests/core/strings/test_core_strings.odin

@@ -1,9 +1,10 @@
 package test_core_strings
 
+import "base:runtime"
 import "core:mem"
 import "core:strings"
 import "core:testing"
-import "base:runtime"
+import "core:unicode/utf8"
 
 @test
 test_index_any_small_string_not_found :: proc(t: ^testing.T) {
@@ -218,3 +219,51 @@ test_builder_to_cstring :: proc(t: ^testing.T) {
 		testing.expect(t, err == .Out_Of_Memory)
 	}
 }
+
+@test
+test_prefix_length :: proc(t: ^testing.T) {
+	prefix_length :: proc "contextless" (a, b: string) -> (n: int) {
+		_len := min(len(a), len(b))
+
+		// Scan for matches including partial codepoints.
+		#no_bounds_check for n < _len && a[n] == b[n] {
+			n += 1
+		}
+
+		// Now scan to ignore partial codepoints.
+		if n > 0 {
+			s := a[:n]
+			n = 0
+			for {
+				r0, w := utf8.decode_rune(s[n:])
+				if r0 != utf8.RUNE_ERROR {
+					n += w
+				} else {
+					break
+				}
+			}
+		}
+		return
+	}
+
+	cases := [][2]string{
+		{"Hellope, there!", "Hellope, world!"},
+		{"Hellope, there!", "Foozle"},
+		{"Hellope, there!", "Hell"},
+		{"Hellope! 🦉",     "Hellope! 🦉"},
+	}
+
+	for v in cases {
+		p_scalar := prefix_length(v[0], v[1])
+		p_simd   := strings.prefix_length(v[0], v[1])
+		testing.expect_value(t, p_simd, p_scalar)
+
+		s := v[0]
+		for len(s) > 0 {
+			p_scalar = prefix_length(v[0], s)
+			p_simd   = strings.prefix_length(v[0], s)
+			testing.expect_value(t, p_simd, p_scalar)
+			s = s[:len(s) - 1]
+		}
+	}
+}