Browse Source

Add `intrinsics.simd_masked_load` and `intrinsics.simd_masked_store`

gingerBill 1 year ago
parent
commit
84ac56f778

+ 5 - 2
base/intrinsics/intrinsics.odin

@@ -283,8 +283,11 @@ simd_reduce_any         :: proc(a: #simd[N]T) -> T where type_is_boolean(T) ---
 simd_reduce_all         :: proc(a: #simd[N]T) -> T where type_is_boolean(T) ---
 
 
-simd_gather  :: proc(ptr: #simd[N]rawptr, val: #simd[N]T, mask: #simd[N]U) -> #simd[N]T where type_is_integer(U) || type_is_boolean(U) ---
-simd_scatter :: proc(ptr: #simd[N]rawptr, val: #simd[N]T, mask: #simd[N]U)              where type_is_integer(U) || type_is_boolean(U) ---
+simd_gather       :: proc(ptr: #simd[N]rawptr, val: #simd[N]T, mask: #simd[N]U) -> #simd[N]T where type_is_integer(U) || type_is_boolean(U) ---
+simd_scatter      :: proc(ptr: #simd[N]rawptr, val: #simd[N]T, mask: #simd[N]U)              where type_is_integer(U) || type_is_boolean(U) ---
+
+simd_masked_load  :: proc(ptr: rawptr, val: #simd[N]T, mask: #simd[N]U) -> #simd[N]T where type_is_integer(U) || type_is_boolean(U) ---
+simd_masked_store :: proc(ptr: rawptr, val: #simd[N]T, mask: #simd[N]U)              where type_is_integer(U) || type_is_boolean(U) ---
 
 
 simd_shuffle :: proc(a, b: #simd[N]T, indices: ..int) -> #simd[len(indices)]T ---

+ 3 - 0
core/simd/simd.odin

@@ -106,6 +106,9 @@ lanes_ge :: intrinsics.simd_lanes_ge
 // Gather and Scatter intrinsics
 gather  :: intrinsics.simd_gather
 scatter :: intrinsics.simd_scatter
+masked_load  :: intrinsics.simd_gather
+masked_store :: intrinsics.simd_scatter
+
 
 // extract :: proc(a: #simd[N]T, idx: uint) -> T
 extract :: intrinsics.simd_extract

+ 44 - 18
src/check_builtin.cpp

@@ -665,26 +665,40 @@ gb_internal bool check_builtin_simd_operation(CheckerContext *c, Operand *operan
 
 	case BuiltinProc_simd_gather:
 	case BuiltinProc_simd_scatter:
+	case BuiltinProc_simd_masked_load:
+	case BuiltinProc_simd_masked_store:
 		{
 			// gather (ptr: #simd[N]rawptr, values: #simd[N]T, mask: #simd[N]int_or_bool) -> #simd[N]T
 			// scatter(ptr: #simd[N]rawptr, values: #simd[N]T, mask: #simd[N]int_or_bool)
 
+			// masked_load (ptr: rawptr, values: #simd[N]T, mask: #simd[N]int_or_bool) -> #simd[N]T
+			// masked_store(ptr: rawptr, values: #simd[N]T, mask: #simd[N]int_or_bool)
+
 			Operand ptr    = {};
 			Operand values = {};
 			Operand mask   = {};
 			check_expr(c, &ptr,    ce->args[0]); if (ptr.mode    == Addressing_Invalid) return false;
 			check_expr(c, &values, ce->args[1]); if (values.mode == Addressing_Invalid) return false;
 			check_expr(c, &mask,   ce->args[2]); if (mask.mode   == Addressing_Invalid) return false;
-			if (!is_type_simd_vector(ptr.type))    { error(ptr.expr,    "'%.*s' expected a simd vector type", LIT(builtin_name)); return false; }
 			if (!is_type_simd_vector(values.type)) { error(values.expr, "'%.*s' expected a simd vector type", LIT(builtin_name)); return false; }
 			if (!is_type_simd_vector(mask.type))   { error(mask.expr,   "'%.*s' expected a simd vector type", LIT(builtin_name)); return false; }
 
-			Type *ptr_elem = base_array_type(ptr.type);
-			if (!is_type_rawptr(ptr_elem)) {
-				gbString s = type_to_string(ptr.type);
-				error(ptr.expr, "Expected a simd vector of 'rawptr' for the addresses, got %s", s);
-				gb_string_free(s);
-				return false;
+			if (id == BuiltinProc_simd_gather || id == BuiltinProc_simd_scatter) {
+				if (!is_type_simd_vector(ptr.type))    { error(ptr.expr,    "'%.*s' expected a simd vector type", LIT(builtin_name)); return false; }
+				Type *ptr_elem = base_array_type(ptr.type);
+				if (!is_type_rawptr(ptr_elem)) {
+					gbString s = type_to_string(ptr.type);
+					error(ptr.expr, "Expected a simd vector of 'rawptr' for the addresses, got %s", s);
+					gb_string_free(s);
+					return false;
+				}
+			} else {
+				if (!is_type_pointer(ptr.type)) {
+					gbString s = type_to_string(ptr.type);
+					error(ptr.expr, "Expected a pointer type for the address, got %s", s);
+					gb_string_free(s);
+					return false;
+				}
 			}
 			Type *mask_elem = base_array_type(mask.type);
 
@@ -695,19 +709,31 @@ gb_internal bool check_builtin_simd_operation(CheckerContext *c, Operand *operan
 				return false;
 			}
 
-			i64 ptr_count    = get_array_type_count(ptr.type);
-			i64 values_count = get_array_type_count(values.type);
-			i64 mask_count   = get_array_type_count(mask.type);
-			if (ptr_count != values_count ||
-			    values_count != mask_count ||
-			    mask_count != ptr_count) {
-				gbString s = type_to_string(mask.type);
-				error(mask.expr, "All simd vectors must be of the same length, got %lld vs %lld vs %lld", cast(long long)ptr_count, cast(long long)values_count, cast(long long)mask_count);
-				gb_string_free(s);
-				return false;
+			if (id == BuiltinProc_simd_gather || id == BuiltinProc_simd_scatter) {
+				i64 ptr_count    = get_array_type_count(ptr.type);
+				i64 values_count = get_array_type_count(values.type);
+				i64 mask_count   = get_array_type_count(mask.type);
+				if (ptr_count != values_count ||
+				    values_count != mask_count ||
+				    mask_count != ptr_count) {
+					gbString s = type_to_string(mask.type);
+					error(mask.expr, "All simd vectors must be of the same length, got %lld vs %lld vs %lld", cast(long long)ptr_count, cast(long long)values_count, cast(long long)mask_count);
+					gb_string_free(s);
+					return false;
+				}
+			} else {
+				i64 values_count = get_array_type_count(values.type);
+				i64 mask_count   = get_array_type_count(mask.type);
+				if (values_count != mask_count) {
+					gbString s = type_to_string(mask.type);
+					error(mask.expr, "All simd vectors must be of the same length, got %lld vs %lld", cast(long long)values_count, cast(long long)mask_count);
+					gb_string_free(s);
+					return false;
+				}
 			}
 
-			if (id == BuiltinProc_simd_gather) {
+			if (id == BuiltinProc_simd_gather ||
+			    id == BuiltinProc_simd_masked_load) {
 				operand->mode = Addressing_Value;
 				operand->type = values.type;
 			} else {

+ 2 - 2
src/checker.cpp

@@ -1651,9 +1651,9 @@ gb_internal void add_type_and_value(CheckerContext *ctx, Ast *expr, AddressingMo
 
 		if (mode == Addressing_Constant || mode == Addressing_Invalid) {
 			expr->tav.value = value;
-		} else if (mode == Addressing_Value && is_type_typeid(type)) {
+		} else if (mode == Addressing_Value && type != nullptr && is_type_typeid(type)) {
 			expr->tav.value = value;
-		} else if (mode == Addressing_Value && is_type_proc(type)) {
+		} else if (mode == Addressing_Value && type != nullptr && is_type_proc(type)) {
 			expr->tav.value = value;
 		}
 

+ 6 - 3
src/checker_builtin_procs.hpp

@@ -193,7 +193,8 @@ BuiltinProc__simd_begin,
 
 	BuiltinProc_simd_gather,
 	BuiltinProc_simd_scatter,
-
+	BuiltinProc_simd_masked_load,
+	BuiltinProc_simd_masked_store,
 
 	// Platform specific SIMD intrinsics
 	BuiltinProc_simd_x86__MM_SHUFFLE,
@@ -525,8 +526,10 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = {
 	{STR_LIT("simd_lanes_rotate_left"), 2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
 	{STR_LIT("simd_lanes_rotate_right"), 2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
 
-	{STR_LIT("simd_gather"),  3, false, Expr_Expr, BuiltinProcPkg_intrinsics},
-	{STR_LIT("simd_scatter"), 3, false, Expr_Stmt, BuiltinProcPkg_intrinsics},
+	{STR_LIT("simd_gather"),       3, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("simd_scatter"),      3, false, Expr_Stmt, BuiltinProcPkg_intrinsics},
+	{STR_LIT("simd_masked_load"),  3, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("simd_masked_store"), 3, false, Expr_Stmt, BuiltinProcPkg_intrinsics},
 
 	{STR_LIT("simd_x86__MM_SHUFFLE"), 4, false, Expr_Expr, BuiltinProcPkg_intrinsics},
 

+ 15 - 3
src/llvm_backend_proc.cpp

@@ -1691,20 +1691,24 @@ gb_internal lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAn
 
 	case BuiltinProc_simd_gather:
 	case BuiltinProc_simd_scatter:
+	case BuiltinProc_simd_masked_load:
+	case BuiltinProc_simd_masked_store:
 		{
 			LLVMValueRef ptr = arg0.value;
 			LLVMValueRef val = arg1.value;
 			LLVMValueRef mask = arg2.value;
 
-			unsigned count = cast(unsigned)get_array_type_count(arg0.type);
+			unsigned count = cast(unsigned)get_array_type_count(arg1.type);
 
 			LLVMTypeRef mask_type = LLVMVectorType(LLVMInt1TypeInContext(p->module->ctx), count);
 			mask = LLVMBuildTrunc(p->builder, mask, mask_type, "");
 
 			char const *name = nullptr;
 			switch (builtin_id) {
-			case BuiltinProc_simd_gather:  name = "llvm.masked.gather";  break;
-			case BuiltinProc_simd_scatter: name = "llvm.masked.scatter"; break;
+			case BuiltinProc_simd_gather:       name = "llvm.masked.gather";  break;
+			case BuiltinProc_simd_scatter:      name = "llvm.masked.scatter"; break;
+			case BuiltinProc_simd_masked_load:  name = "llvm.masked.load";    break;
+			case BuiltinProc_simd_masked_store: name = "llvm.masked.store";   break;
 			}
 			LLVMTypeRef types[2] = {
 				lb_type(p->module, arg1.type),
@@ -1716,12 +1720,20 @@ gb_internal lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAn
 
 			LLVMValueRef args[4] = {};
 			switch (builtin_id) {
+			case BuiltinProc_simd_masked_load:
+				types[1] = lb_type(p->module, t_rawptr);
+				/*fallthrough*/
 			case BuiltinProc_simd_gather:
 				args[0] = ptr;
 				args[1] = align;
 				args[2] = mask;
 				args[3] = val;
+				// res.type = arg1.type;
 				break;
+
+			case BuiltinProc_simd_masked_store:
+				types[1] = lb_type(p->module, t_rawptr);
+				/*fallthrough*/
 			case BuiltinProc_simd_scatter:
 				args[0] = val;
 				args[1] = ptr;