Browse Source

Merge pull request #4023 from Feoramund/simd-index

Vectorize `index_byte`
Jeroen van Rijn 1 year ago
parent
commit
e226d37803

+ 151 - 5
core/bytes/bytes.odin

@@ -1,9 +1,38 @@
 package bytes
 
+import "base:intrinsics"
 import "core:mem"
 import "core:unicode"
 import "core:unicode/utf8"
 
+
+@private SIMD_SCAN_WIDTH :: 8 * size_of(uintptr)
+
+when SIMD_SCAN_WIDTH == 32 {
+	@(private, rodata)
+	simd_scanner_indices := #simd[SIMD_SCAN_WIDTH]u8 {
+		 0,  1,  2,  3,  4,  5,  6,  7,
+		 8,  9, 10, 11, 12, 13, 14, 15,
+		16, 17, 18, 19, 20, 21, 22, 23,
+		24, 25, 26, 27, 28, 29, 30, 31,
+	}
+} else when SIMD_SCAN_WIDTH == 64 {
+	@(private, rodata)
+	simd_scanner_indices := #simd[SIMD_SCAN_WIDTH]u8 {
+		 0,  1,  2,  3,  4,  5,  6,  7,
+		 8,  9, 10, 11, 12, 13, 14, 15,
+		16, 17, 18, 19, 20, 21, 22, 23,
+		24, 25, 26, 27, 28, 29, 30, 31,
+		32, 33, 34, 35, 36, 37, 38, 39,
+		40, 41, 42, 43, 44, 45, 46, 47,
+		48, 49, 50, 51, 52, 53, 54, 55,
+		56, 57, 58, 59, 60, 61, 62, 63,
+	}
+} else {
+	#panic("Invalid SIMD_SCAN_WIDTH. Must be 32 or 64.")
+}
+
+
 clone :: proc(s: []byte, allocator := context.allocator, loc := #caller_location) -> []byte {
 	c := make([]byte, len(s), allocator, loc)
 	copy(c, s)
@@ -293,23 +322,140 @@ split_after_iterator :: proc(s: ^[]byte, sep: []byte) -> ([]byte, bool) {
 	return _split_iterator(s, sep, len(sep))
 }
 
+/*
+Scan a slice of bytes for a specific byte.
+
+This procedure safely handles slices of any length, including empty slices.
+
+Inputs:
+- data: A slice of bytes.
+- c: The byte to search for.
+
+Returns:
+- index: The index of the byte `c`, or -1 if it was not found.
+*/
+index_byte :: proc(s: []byte, c: byte) -> (index: int) #no_bounds_check {
+	length := len(s)
+	i := 0
+
+	// Guard against small strings.
+	if length < SIMD_SCAN_WIDTH {
+		for /**/; i < length; i += 1 {
+			if s[i] == c {
+				return i
+			}
+		}
+		return -1
+	}
+
+	ptr := cast(int)cast(uintptr)raw_data(s)
+
+	alignment_start := (SIMD_SCAN_WIDTH - ptr % SIMD_SCAN_WIDTH) % SIMD_SCAN_WIDTH
 
-index_byte :: proc(s: []byte, c: byte) -> int {
-	for i := 0; i < len(s); i += 1 {
+	// Iterate as a scalar until the data is aligned on a `SIMD_SCAN_WIDTH` boundary.
+	//
+	// This way, every load in the vector loop will be aligned, which should be
+	// the fastest possible scenario.
+	for /**/; i < alignment_start; i += 1 {
 		if s[i] == c {
 			return i
 		}
 	}
+
+	// Iterate as a vector over every aligned chunk, evaluating each byte simultaneously at the CPU level.
+	scanner: #simd[SIMD_SCAN_WIDTH]u8 = c
+	tail := length - (length - alignment_start) % SIMD_SCAN_WIDTH
+
+	for /**/; i < tail; i += SIMD_SCAN_WIDTH {
+		load := (cast(^#simd[SIMD_SCAN_WIDTH]u8)(&s[i]))^
+		comparison := intrinsics.simd_lanes_eq(load, scanner)
+		match := intrinsics.simd_reduce_or(comparison)
+		if match > 0 {
+			sentinel: #simd[SIMD_SCAN_WIDTH]u8 = u8(0xFF)
+			index_select := intrinsics.simd_select(comparison, simd_scanner_indices, sentinel)
+			index_reduce := intrinsics.simd_reduce_min(index_select)
+			return i + cast(int)index_reduce
+		}
+	}
+
+	// Iterate as a scalar over the remaining unaligned portion.
+	for /**/; i < length; i += 1 {
+		if s[i] == c {
+			return i
+		}
+	}
+
 	return -1
 }
 
-// Returns -1 if c is not present
-last_index_byte :: proc(s: []byte, c: byte) -> int {
-	for i := len(s)-1; i >= 0; i -= 1 {
+/*
+Scan a slice of bytes for a specific byte, starting from the end and working
+backwards to the start.
+
+This procedure safely handles slices of any length, including empty slices.
+
+Inputs:
+- data: A slice of bytes.
+- c: The byte to search for.
+
+Returns:
+- index: The index of the byte `c`, or -1 if it was not found.
+*/
+last_index_byte :: proc(s: []byte, c: byte) -> int #no_bounds_check {
+	length := len(s)
+	i := length - 1
+
+	// Guard against small strings.
+	if length < SIMD_SCAN_WIDTH {
+		for /**/; i >= 0; i -= 1 {
+			if s[i] == c {
+				return i
+			}
+		}
+		return -1
+	}
+
+	ptr := cast(int)cast(uintptr)raw_data(s)
+
+	tail := length - (ptr + length) % SIMD_SCAN_WIDTH
+
+	// Iterate as a scalar until the data is aligned on a `SIMD_SCAN_WIDTH` boundary.
+	//
+	// This way, every load in the vector loop will be aligned, which should be
+	// the fastest possible scenario.
+	for /**/; i >= tail; i -= 1 {
 		if s[i] == c {
 			return i
 		}
 	}
+
+	// Iterate as a vector over every aligned chunk, evaluating each byte simultaneously at the CPU level.
+	scanner: #simd[SIMD_SCAN_WIDTH]u8 = c
+	alignment_start := (SIMD_SCAN_WIDTH - ptr % SIMD_SCAN_WIDTH) % SIMD_SCAN_WIDTH
+
+	i -= SIMD_SCAN_WIDTH - 1
+
+	for /**/; i >= alignment_start; i -= SIMD_SCAN_WIDTH {
+		load := (cast(^#simd[SIMD_SCAN_WIDTH]u8)(&s[i]))^
+		comparison := intrinsics.simd_lanes_eq(load, scanner)
+		match := intrinsics.simd_reduce_or(comparison)
+		if match > 0 {
+			sentinel: #simd[SIMD_SCAN_WIDTH]u8
+			index_select := intrinsics.simd_select(comparison, simd_scanner_indices, sentinel)
+			index_reduce := intrinsics.simd_reduce_max(index_select)
+			return i + cast(int)index_reduce
+		}
+	}
+
+	// Iterate as a scalar over the remaining unaligned portion.
+	i += SIMD_SCAN_WIDTH - 1
+
+	for /**/; i >= 0; i -= 1 {
+		if s[i] == c {
+			return i
+		}
+	}
+
 	return -1
 }
 

+ 4 - 12
core/strings/strings.odin

@@ -1,6 +1,8 @@
 // Procedures to manipulate UTF-8 encoded strings
 package strings
 
+import "base:intrinsics"
+import "core:bytes"
 import "core:io"
 import "core:mem"
 import "core:unicode"
@@ -1424,12 +1426,7 @@ Output:
 
 */
 index_byte :: proc(s: string, c: byte) -> (res: int) {
-	for i := 0; i < len(s); i += 1 {
-		if s[i] == c {
-			return i
-		}
-	}
-	return -1
+	return #force_inline bytes.index_byte(transmute([]u8)s, c)
 }
 /*
 Returns the byte offset of the last byte `c` in the string `s`, -1 when not found.
@@ -1464,12 +1461,7 @@ Output:
 
 */
 last_index_byte :: proc(s: string, c: byte) -> (res: int) {
-	for i := len(s)-1; i >= 0; i -= 1 {
-		if s[i] == c {
-			return i
-		}
-	}
-	return -1
+	return #force_inline bytes.last_index_byte(transmute([]u8)s, c)
 }
 /*
 Returns the byte offset of the first rune `r` in the string `s` it finds, -1 when not found.

+ 1 - 0
tests/benchmark/all.odin

@@ -1,4 +1,5 @@
 package benchmarks
 
+@(require) import "bytes"
 @(require) import "crypto"
 @(require) import "hash"

+ 116 - 0
tests/benchmark/bytes/benchmark_bytes.odin

@@ -0,0 +1,116 @@
+package benchmark_bytes
+
+import "core:bytes"
+import "core:fmt"
+import "core:log"
+import "core:testing"
+import "core:time"
+
+
+// These are the normal, unoptimized algorithms.
+
+plain_index_byte :: proc(s: []u8, c: byte) -> (res: int) #no_bounds_check {
+	for i := 0; i < len(s); i += 1 {
+		if s[i] == c {
+			return i
+		}
+	}
+	return -1
+}
+
+plain_last_index_byte :: proc(s: []u8, c: byte) -> (res: int) #no_bounds_check {
+	for i := len(s)-1; i >= 0; i -= 1 {
+		if s[i] == c {
+			return i
+		}
+	}
+	return -1
+}
+
+sizes := [?]int {
+	15, 16, 17,
+	31, 32, 33,
+	256,
+	512,
+	1024,
+	1024 * 1024,
+	1024 * 1024 * 1024,
+}
+
+run_trial_size :: proc(p: proc([]u8, byte) -> int, size: int, idx: int, warmup: int, runs: int) -> (timing: time.Duration) {
+	data := make([]u8, size)
+	defer delete(data)
+
+	for i in 0..<size {
+		data[i] = u8('0' + i % 10)
+	}
+	data[idx] = 'z'
+
+	accumulator: int
+
+	for _ in 0..<warmup {
+		accumulator += p(data, 'z')
+	}
+
+	for _ in 0..<runs {
+		start := time.now()
+		accumulator += p(data, 'z')
+		done := time.since(start)
+		timing += done
+	}
+
+	timing /= time.Duration(runs)
+
+	log.debug(accumulator)
+	return
+}
+
+HOT :: 3
+
+@test
+benchmark_plain_index_cold :: proc(t: ^testing.T) {
+	report: string
+	for size in sizes {
+		timing := run_trial_size(plain_index_byte, size, size - 1, 0, 1)
+		report = fmt.tprintf("%s\n        +++ % 8M | %v", report, size, timing)
+		timing = run_trial_size(plain_last_index_byte, size, 0, 0, 1)
+		report = fmt.tprintf("%s\n (last) +++ % 8M | %v", report, size, timing)
+	}
+	log.info(report)
+}
+
+@test
+benchmark_plain_index_hot :: proc(t: ^testing.T) {
+	report: string
+	for size in sizes {
+		timing := run_trial_size(plain_index_byte, size, size - 1, HOT, HOT)
+		report = fmt.tprintf("%s\n        +++ % 8M | %v", report, size, timing)
+		timing = run_trial_size(plain_last_index_byte, size, 0, HOT, HOT)
+		report = fmt.tprintf("%s\n (last) +++ % 8M | %v", report, size, timing)
+	}
+	log.info(report)
+}
+
+@test
+benchmark_simd_index_cold :: proc(t: ^testing.T) {
+	report: string
+	for size in sizes {
+		timing := run_trial_size(bytes.index_byte, size, size - 1, 0, 1)
+		report = fmt.tprintf("%s\n        +++ % 8M | %v", report, size, timing)
+		timing = run_trial_size(bytes.last_index_byte, size, 0, 0, 1)
+		report = fmt.tprintf("%s\n (last) +++ % 8M | %v", report, size, timing)
+	}
+	log.info(report)
+}
+
+@test
+benchmark_simd_index_hot :: proc(t: ^testing.T) {
+	report: string
+	for size in sizes {
+		timing := run_trial_size(bytes.index_byte, size, size - 1, HOT, HOT)
+		report = fmt.tprintf("%s\n        +++ % 8M | %v", report, size, timing)
+		timing = run_trial_size(bytes.last_index_byte, size, 0, HOT, HOT)
+		report = fmt.tprintf("%s\n (last) +++ % 8M | %v", report, size, timing)
+	}
+	log.info(report)
+}

+ 89 - 0
tests/core/bytes/test_core_bytes.odin

@@ -0,0 +1,89 @@
+package test_core_bytes
+
+import "core:bytes"
+import "core:slice"
+import "core:testing"
+
+@private SIMD_SCAN_WIDTH :: 8 * size_of(uintptr)
+
+@test
+test_index_byte_sanity :: proc(t: ^testing.T) {
+	// We must be able to find the byte at the correct index.
+	data := make([]u8, 2 * SIMD_SCAN_WIDTH)
+	defer delete(data)
+	slice.fill(data, '-')
+
+	INDEX_MAX :: SIMD_SCAN_WIDTH - 1
+
+	for offset in 0..<INDEX_MAX {
+		for idx in 0..<INDEX_MAX {
+			sub := data[offset:]
+			sub[idx] = 'o'
+			if !testing.expect_value(t, bytes.index_byte(sub, 'o'), idx) {
+				return
+			}
+			if !testing.expect_value(t, bytes.last_index_byte(sub, 'o'), idx) {
+				return
+			}
+			sub[idx] = '-'
+		}
+	}
+}
+
+@test
+test_index_byte_empty :: proc(t: ^testing.T) {
+	a: [1]u8
+	testing.expect_value(t, bytes.index_byte(a[0:0], 'o'), -1)
+	testing.expect_value(t, bytes.last_index_byte(a[0:0], 'o'), -1)
+}
+
+@test
+test_index_byte_multiple_hits :: proc(t: ^testing.T) {
+	for n in 5..<256 {
+		data := make([]u8, n)
+		defer delete(data)
+		slice.fill(data, '-')
+
+		data[n-1] = 'o'
+		data[n-3] = 'o'
+		data[n-5] = 'o'
+
+		// Find the first one.
+		if !testing.expect_value(t, bytes.index_byte(data, 'o'), n-5) {
+			return
+		}
+
+		// Find the last one.
+		if !testing.expect_value(t, bytes.last_index_byte(data, 'o'), n-1) {
+			return
+		}
+	}
+}
+
+@test
+test_index_byte_zero :: proc(t: ^testing.T) {
+	// This test protects against false positives in uninitialized memory.
+	for n in 1..<256 {
+		data := make([]u8, n + 64)
+		defer delete(data)
+		slice.fill(data, '-')
+
+		// Positive hit.
+		data[n-1] = 0
+		if !testing.expect_value(t, bytes.index_byte(data[:n], 0), n-1) {
+			return
+		}
+		if !testing.expect_value(t, bytes.last_index_byte(data[:n], 0), n-1) {
+			return
+		}
+
+		// Test for false positives.
+		data[n-1] = '-'
+		if !testing.expect_value(t, bytes.index_byte(data[:n], 0), -1) {
+			return
+		}
+		if !testing.expect_value(t, bytes.last_index_byte(data[:n], 0), -1) {
+			return
+		}
+	}
+}

+ 1 - 0
tests/core/normal.odin

@@ -9,6 +9,7 @@ download_assets :: proc() {
 	}
 }
 
+@(require) import "bytes"
 @(require) import "c/libc"
 @(require) import "compress"
 @(require) import "container"