فهرست منبع

Merge pull request #5442 from jon-lipstate/table_lookup

table lookup simd intrinsic
gingerBill 1 ماه پیش
والد
کامیت
19a075211f
5فایلهای تغییر یافته به همراه375 افزوده شده و 0 حذف شده
  1. 1 0
      base/intrinsics/intrinsics.odin
  2. 51 0
      core/simd/simd.odin
  3. 52 0
      src/check_builtin.cpp
  4. 2 0
      src/checker_builtin_procs.hpp
  5. 269 0
      src/llvm_backend_proc.cpp

+ 1 - 0
base/intrinsics/intrinsics.odin

@@ -314,6 +314,7 @@ simd_indices :: proc($T: typeid/#simd[$N]$E) -> T where type_is_numeric(T) ---
 
 simd_shuffle :: proc(a, b: #simd[N]T, indices: ..int) -> #simd[len(indices)]T ---
 simd_select  :: proc(cond: #simd[N]boolean_or_integer, true, false: #simd[N]T) -> #simd[N]T ---
+simd_runtime_swizzle :: proc(table: #simd[N]T, indices: #simd[N]T) -> #simd[N]T where type_is_integer(T) ---
 
 // Lane-wise operations
 simd_ceil    :: proc(a: #simd[N]any_float) -> #simd[N]any_float ---

+ 51 - 0
core/simd/simd.odin

@@ -2440,6 +2440,57 @@ Graphically, the operation looks as follows. The `t` and `f` represent the
 */
 select :: intrinsics.simd_select
 
+/*
+Runtime Equivalent to Shuffle.
+
+Performs element-wise table lookups using runtime indices.
+Each element in the indices vector selects an element from the table vector.
+The indices are automatically masked to prevent out-of-bounds access.
+
+This operation is hardware-accelerated on most platforms when using 8-bit
+integer vectors. For other element types or unsupported vector sizes, it
+falls back to software emulation.
+
+Inputs:
+- `table`: The lookup table vector (should be power-of-2 size for correct masking).
+- `indices`: The indices vector (automatically masked to valid range).
+
+Returns:
+- A vector where `result[i] = table[indices[i] & (table_size-1)]`.
+
+Operation:
+
+	for i in 0 ..< len(indices) {
+		masked_index := indices[i] & (len(table) - 1)
+		result[i] = table[masked_index]
+	}
+	return result
+
+Implementation:
+
+	| Platform    | Lane Size                                 | Implementation      |
+	|-------------|-------------------------------------------|---------------------|
+	| x86-64      | pshufb (16B), vpshufb (32B), AVX512 (64B) | Single vector       |
+	| ARM64       | tbl1 (16B), tbl2 (32B), tbl4 (64B)        | Automatic splitting |
+	| ARM32       | vtbl1 (8B), vtbl2 (16B), vtbl4 (32B)      | Automatic splitting |
+	| WebAssembly | i8x16.swizzle (16B), Emulation (>16B)     | Mixed               |
+	| Other       | Emulation                                 | Software            |
+
+Example:
+
+	import "core:simd"
+	import "core:fmt"
+
+	runtime_swizzle_example :: proc() {
+		table := simd.u8x16{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+		indices := simd.u8x16{15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}
+		result := simd.runtime_swizzle(table, indices)
+		fmt.println(result) // Expected: {15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}
+	}
+
+*/
+runtime_swizzle :: intrinsics.simd_runtime_swizzle
+
 /*
 Compute the square root of each lane in a SIMD vector.
 */

+ 52 - 0
src/check_builtin.cpp

@@ -1159,6 +1159,58 @@ gb_internal bool check_builtin_simd_operation(CheckerContext *c, Operand *operan
 			return true;
 		}
 
+	case BuiltinProc_simd_runtime_swizzle:
+		{
+			if (ce->args.count != 2) {
+				error(call, "'%.*s' expected 2 arguments, got %td", LIT(builtin_name), ce->args.count);
+				return false;
+			}
+			
+			Operand src = {};
+			Operand indices = {};
+			check_expr(c, &src, ce->args[0]); if (src.mode == Addressing_Invalid) return false;
+			check_expr_with_type_hint(c, &indices, ce->args[1], src.type); if (indices.mode == Addressing_Invalid) return false;
+			
+			if (!is_type_simd_vector(src.type)) {
+				error(src.expr, "'%.*s' expected first argument to be a simd vector", LIT(builtin_name));
+				return false;
+			}
+			if (!is_type_simd_vector(indices.type)) {
+				error(indices.expr, "'%.*s' expected second argument (indices) to be a simd vector", LIT(builtin_name));
+				return false;
+			}
+			
+			Type *src_elem = base_array_type(src.type);
+			Type *indices_elem = base_array_type(indices.type);
+			
+			if (!is_type_integer(src_elem)) {
+				gbString src_str = type_to_string(src.type);
+				error(src.expr, "'%.*s' expected first argument to be a simd vector of integers, got '%s'", LIT(builtin_name), src_str);
+				gb_string_free(src_str);
+				return false;
+			}
+			
+			if (!is_type_integer(indices_elem)) {
+				gbString indices_str = type_to_string(indices.type);
+				error(indices.expr, "'%.*s' expected indices to be a simd vector of integers, got '%s'", LIT(builtin_name), indices_str);
+				gb_string_free(indices_str);
+				return false;
+			}
+			
+			if (!are_types_identical(src.type, indices.type)) {
+				gbString src_str = type_to_string(src.type);
+				gbString indices_str = type_to_string(indices.type);
+				error(indices.expr, "'%.*s' expected both arguments to have the same type, got '%s' vs '%s'", LIT(builtin_name), src_str, indices_str);
+				gb_string_free(indices_str);
+				gb_string_free(src_str);
+				return false;
+			}
+			
+			operand->mode = Addressing_Value;
+			operand->type = src.type;
+			return true;
+		}
+
 	case BuiltinProc_simd_ceil:
 	case BuiltinProc_simd_floor:
 	case BuiltinProc_simd_trunc:

+ 2 - 0
src/checker_builtin_procs.hpp

@@ -191,6 +191,7 @@ BuiltinProc__simd_begin,
 
 	BuiltinProc_simd_shuffle,
 	BuiltinProc_simd_select,
+	BuiltinProc_simd_runtime_swizzle,
 
 	BuiltinProc_simd_ceil,
 	BuiltinProc_simd_floor,
@@ -552,6 +553,7 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = {
 
 	{STR_LIT("simd_shuffle"), 2, true,  Expr_Expr, BuiltinProcPkg_intrinsics},
 	{STR_LIT("simd_select"),  3, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("simd_runtime_swizzle"), 2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
 
 	{STR_LIT("simd_ceil") , 1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
 	{STR_LIT("simd_floor"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics},

+ 269 - 0
src/llvm_backend_proc.cpp

@@ -1721,6 +1721,275 @@ gb_internal lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAn
 			return res;
 		}
 
+	case BuiltinProc_simd_runtime_swizzle:
+		{
+			LLVMValueRef src = arg0.value;
+			LLVMValueRef indices = lb_build_expr(p, ce->args[1]).value;
+			
+			Type *vt = arg0.type;
+			GB_ASSERT(vt->kind == Type_SimdVector);
+			i64 count = vt->SimdVector.count;
+			Type *elem_type = vt->SimdVector.elem;
+			i64 elem_size = type_size_of(elem_type);
+			
+			// Determine strategy based on element size and target architecture
+			char const *intrinsic_name = nullptr;
+			bool use_hardware_runtime_swizzle = false;
+			
+			// 8-bit elements: Use dedicated table lookup instructions
+			if (elem_size == 1) {
+				use_hardware_runtime_swizzle = true;
+				
+				if (build_context.metrics.arch == TargetArch_amd64 || build_context.metrics.arch == TargetArch_i386) {
+					// x86/x86-64: Use pshufb intrinsics
+					switch (count) {
+					case 16:
+						intrinsic_name = "llvm.x86.ssse3.pshuf.b.128";
+						break;
+					case 32:
+						intrinsic_name = "llvm.x86.avx2.pshuf.b";
+						break;
+					case 64:
+						intrinsic_name = "llvm.x86.avx512.pshuf.b.512";
+						break;
+					default:
+						use_hardware_runtime_swizzle = false;
+						break;
+					}
+				} else if (build_context.metrics.arch == TargetArch_arm64) {
+					// ARM64: Use NEON tbl intrinsics with automatic table splitting
+					switch (count) {
+					case 16:
+						intrinsic_name = "llvm.aarch64.neon.tbl1";
+						break;
+					case 32:
+						intrinsic_name = "llvm.aarch64.neon.tbl2";
+						break;
+					case 48:
+						intrinsic_name = "llvm.aarch64.neon.tbl3";
+						break;
+					case 64:
+						intrinsic_name = "llvm.aarch64.neon.tbl4";
+						break;
+					default:
+						use_hardware_runtime_swizzle = false;
+						break;
+					}
+				} else if (build_context.metrics.arch == TargetArch_arm32) {
+					// ARM32: Use NEON vtbl intrinsics with automatic table splitting
+					switch (count) {
+					case 8:
+						intrinsic_name = "llvm.arm.neon.vtbl1";
+						break;
+					case 16:
+						intrinsic_name = "llvm.arm.neon.vtbl2";
+						break;
+					case 24:
+						intrinsic_name = "llvm.arm.neon.vtbl3";
+						break;
+					case 32:
+						intrinsic_name = "llvm.arm.neon.vtbl4";
+						break;
+					default:
+						use_hardware_runtime_swizzle = false;
+						break;
+					}
+				} else if (build_context.metrics.arch == TargetArch_wasm32 || build_context.metrics.arch == TargetArch_wasm64p32) {
+					// WebAssembly: Use swizzle (only supports 16-byte vectors)
+					if (count == 16) {
+						intrinsic_name = "llvm.wasm.swizzle";
+					} else {
+						use_hardware_runtime_swizzle = false;
+					}
+				} else {
+					use_hardware_runtime_swizzle = false;
+				}
+			}
+			
+			if (use_hardware_runtime_swizzle && intrinsic_name != nullptr) {
+				// Use dedicated hardware swizzle instruction
+				
+				// Check if required target features are enabled
+				bool features_enabled = true;
+				if (build_context.metrics.arch == TargetArch_amd64 || build_context.metrics.arch == TargetArch_i386) {
+					// x86/x86-64 feature checking
+					if (count == 16) {
+						// SSE/SSSE3 for 128-bit vectors
+						if (!check_target_feature_is_enabled(str_lit("ssse3"), nullptr)) {
+							features_enabled = false;
+						}
+					} else if (count == 32) {
+						// AVX2 requires ssse3 + avx2 features
+						if (!check_target_feature_is_enabled(str_lit("ssse3"), nullptr) || 
+							!check_target_feature_is_enabled(str_lit("avx2"), nullptr)) {
+							features_enabled = false;
+						}
+					} else if (count == 64) {
+						// AVX512 requires ssse3 + avx2 + avx512f + avx512bw features
+						if (!check_target_feature_is_enabled(str_lit("ssse3"), nullptr) ||
+							!check_target_feature_is_enabled(str_lit("avx2"), nullptr) ||
+							!check_target_feature_is_enabled(str_lit("avx512f"), nullptr) ||
+							!check_target_feature_is_enabled(str_lit("avx512bw"), nullptr)) {
+							features_enabled = false;
+						}
+					}
+				} else if (build_context.metrics.arch == TargetArch_arm64 || build_context.metrics.arch == TargetArch_arm32) {
+					// ARM/ARM64 feature checking - NEON is required for all table/swizzle ops
+					if (!check_target_feature_is_enabled(str_lit("neon"), nullptr)) {
+						features_enabled = false;
+					}
+				}
+				
+				if (features_enabled) {
+					// Add target features to function attributes for LLVM instruction selection
+					if (build_context.metrics.arch == TargetArch_amd64 || build_context.metrics.arch == TargetArch_i386) {
+						// x86/x86-64 function attributes
+						if (count == 16) {
+							// SSE/SSSE3 for 128-bit vectors
+							lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("target-features"), str_lit("+ssse3"));
+							lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("min-legal-vector-width"), str_lit("128"));
+						} else if (count == 32) {
+							lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("target-features"), str_lit("+avx,+avx2,+ssse3"));
+							lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("min-legal-vector-width"), str_lit("256"));
+						} else if (count == 64) {
+							lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("target-features"), str_lit("+avx,+avx2,+avx512f,+avx512bw,+ssse3"));
+							lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("min-legal-vector-width"), str_lit("512"));
+						}
+					} else if (build_context.metrics.arch == TargetArch_arm64) {
+						// ARM64 function attributes - enable NEON for swizzle instructions
+						lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("target-features"), str_lit("+neon"));
+						// Set appropriate vector width for multi-swizzle operations
+						if (count >= 32) {
+							lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("min-legal-vector-width"), str_lit("256"));
+						}
+					} else if (build_context.metrics.arch == TargetArch_arm32) {
+						// ARM32 function attributes - enable NEON for swizzle instructions
+						lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("target-features"), str_lit("+neon"));
+					}
+					
+					// Handle ARM's multi-swizzle intrinsics by splitting the src vector
+					if (build_context.metrics.arch == TargetArch_arm64 && count > 16) {
+						// ARM64 TBL2/TBL3/TBL4: Split src into multiple 16-byte vectors
+						int num_tables = cast(int)(count / 16);
+						GB_ASSERT_MSG(count % 16 == 0, "ARM64 src size must be multiple of 16 bytes, got %lld bytes", count);
+						GB_ASSERT_MSG(num_tables <= 4, "ARM64 NEON supports maximum 4 tables (tbl4), got %d tables for %lld-byte vector", num_tables, count);
+						
+						LLVMValueRef src_parts[4]; // Max 4 tables for tbl4
+						for (int i = 0; i < num_tables; i++) {
+							// Extract 16-byte slice from the larger src
+							LLVMValueRef indices_for_extract[16];
+							for (int j = 0; j < 16; j++) {
+								indices_for_extract[j] = LLVMConstInt(LLVMInt32TypeInContext(p->module->ctx), i * 16 + j, false);
+							}
+							LLVMValueRef extract_mask = LLVMConstVector(indices_for_extract, 16);
+							src_parts[i] = LLVMBuildShuffleVector(p->builder, src, LLVMGetUndef(LLVMTypeOf(src)), extract_mask, "");
+						}
+						
+						// Call appropriate ARM64 tbl intrinsic
+						if (count == 32) {
+							LLVMValueRef args[3] = { src_parts[0], src_parts[1], indices };
+							res.value = lb_call_intrinsic(p, intrinsic_name, args, 3, nullptr, 0);
+						} else if (count == 48) {
+							LLVMValueRef args[4] = { src_parts[0], src_parts[1], src_parts[2], indices };
+							res.value = lb_call_intrinsic(p, intrinsic_name, args, 4, nullptr, 0);
+						} else if (count == 64) {
+							LLVMValueRef args[5] = { src_parts[0], src_parts[1], src_parts[2], src_parts[3], indices };
+							res.value = lb_call_intrinsic(p, intrinsic_name, args, 5, nullptr, 0);
+						}
+					} else if (build_context.metrics.arch == TargetArch_arm32 && count > 8) {
+						// ARM32 VTBL2/VTBL3/VTBL4: Split src into multiple 8-byte vectors
+						int num_tables = cast(int)count / 8;
+						GB_ASSERT_MSG(count % 8 == 0, "ARM32 src size must be multiple of 8 bytes, got %lld bytes", count);
+						GB_ASSERT_MSG(num_tables <= 4, "ARM32 NEON supports maximum 4 tables (vtbl4), got %d tables for %lld-byte vector", num_tables, count);
+						
+						LLVMValueRef src_parts[4]; // Max 4 tables for vtbl4
+						for (int i = 0; i < num_tables; i++) {
+							// Extract 8-byte slice from the larger src
+							LLVMValueRef indices_for_extract[8];
+							for (int j = 0; j < 8; j++) {
+								indices_for_extract[j] = LLVMConstInt(LLVMInt32TypeInContext(p->module->ctx), i * 8 + j, false);
+							}
+							LLVMValueRef extract_mask = LLVMConstVector(indices_for_extract, 8);
+							src_parts[i] = LLVMBuildShuffleVector(p->builder, src, LLVMGetUndef(LLVMTypeOf(src)), extract_mask, "");
+						}
+						
+						// Call appropriate ARM32 vtbl intrinsic
+						if (count == 16) {
+							LLVMValueRef args[3] = { src_parts[0], src_parts[1], indices };
+							res.value = lb_call_intrinsic(p, intrinsic_name, args, 3, nullptr, 0);
+						} else if (count == 24) {
+							LLVMValueRef args[4] = { src_parts[0], src_parts[1], src_parts[2], indices };
+							res.value = lb_call_intrinsic(p, intrinsic_name, args, 4, nullptr, 0);
+						} else if (count == 32) {
+							LLVMValueRef args[5] = { src_parts[0], src_parts[1], src_parts[2], src_parts[3], indices };
+							res.value = lb_call_intrinsic(p, intrinsic_name, args, 5, nullptr, 0);
+						}
+					} else {
+						// Single runtime swizzle case (x86, WebAssembly, ARM single-table)
+						LLVMValueRef args[2] = { src, indices };
+						res.value = lb_call_intrinsic(p, intrinsic_name, args, gb_count_of(args), nullptr, 0);
+					}
+					return res;
+				} else {
+					// Features not enabled, fall back to emulation
+					use_hardware_runtime_swizzle = false;
+				}
+			}
+			
+			// Fallback: Emulate with extracts and inserts for all element sizes
+			GB_ASSERT(count > 0 && count <= 64); // Sanity check
+			
+			LLVMValueRef *values = gb_alloc_array(temporary_allocator(), LLVMValueRef, count);
+			LLVMTypeRef i32_type = LLVMInt32TypeInContext(p->module->ctx);
+			LLVMTypeRef elem_llvm_type = lb_type(p->module, elem_type);
+			
+			// Calculate mask based on element size and vector count
+			i64 max_index = count - 1;
+			LLVMValueRef index_mask;
+			
+			if (elem_size == 1) {
+				// 8-bit: mask to src size (like pshufb behavior)
+				index_mask = LLVMConstInt(elem_llvm_type, max_index, false);
+			} else if (elem_size == 2) {
+				// 16-bit: mask to src size 
+				index_mask = LLVMConstInt(elem_llvm_type, max_index, false);
+			} else if (elem_size == 4) {
+				// 32-bit: mask to src size
+				index_mask = LLVMConstInt(elem_llvm_type, max_index, false);
+			} else {
+				// 64-bit: mask to src size
+				index_mask = LLVMConstInt(elem_llvm_type, max_index, false);
+			}
+			
+			for (i64 i = 0; i < count; i++) {
+				LLVMValueRef idx_i = LLVMConstInt(i32_type, cast(unsigned)i, false);
+				LLVMValueRef index_elem = LLVMBuildExtractElement(p->builder, indices, idx_i, "");
+				
+				// Mask index to valid range
+				LLVMValueRef masked_index = LLVMBuildAnd(p->builder, index_elem, index_mask, "");
+				
+				// Convert to i32 for extractelement
+				LLVMValueRef index_i32;
+				if (LLVMGetIntTypeWidth(LLVMTypeOf(masked_index)) < 32) {
+					index_i32 = LLVMBuildZExt(p->builder, masked_index, i32_type, "");
+				} else if (LLVMGetIntTypeWidth(LLVMTypeOf(masked_index)) > 32) {
+					index_i32 = LLVMBuildTrunc(p->builder, masked_index, i32_type, "");
+				} else {
+					index_i32 = masked_index;
+				}
+				
+				values[i] = LLVMBuildExtractElement(p->builder, src, index_i32, "");
+			}
+			
+			// Build result vector
+			res.value = LLVMGetUndef(LLVMTypeOf(src));
+			for (i64 i = 0; i < count; i++) {
+				LLVMValueRef idx_i = LLVMConstInt(i32_type, cast(unsigned)i, false);
+				res.value = LLVMBuildInsertElement(p->builder, res.value, values[i], idx_i, "");
+			}
+			return res;
+		}
+
 	case BuiltinProc_simd_ceil:
 	case BuiltinProc_simd_floor:
 	case BuiltinProc_simd_trunc: