Procházet zdrojové kódy

Implement backend for simd intrinsics

gingerBill před 3 roky
rodič
revize
81dd727f75
2 změnil soubory, kde provedl 255 přidání a 4 odebrání
  1. 54 4
      src/check_builtin.cpp
  2. 201 0
      src/llvm_backend_proc.cpp

+ 54 - 4
src/check_builtin.cpp

@@ -464,10 +464,6 @@ bool check_builtin_simd_operation(CheckerContext *c, Operand *operand, Ast *call
 
 	// Integer only
 	case BuiltinProc_simd_rem:
-	case BuiltinProc_simd_shl:
-	case BuiltinProc_simd_shr:
-	case BuiltinProc_simd_shl_masked:
-	case BuiltinProc_simd_shr_masked:
 	case BuiltinProc_simd_and:
 	case BuiltinProc_simd_or:
 	case BuiltinProc_simd_xor:
@@ -510,6 +506,60 @@ bool check_builtin_simd_operation(CheckerContext *c, Operand *operand, Ast *call
 			operand->type = x.type;
 			return true;
 		}
+
+	case BuiltinProc_simd_shl:
+	case BuiltinProc_simd_shr:
+	case BuiltinProc_simd_shl_masked:
+	case BuiltinProc_simd_shr_masked:
+		{
+			Operand x = {};
+			Operand y = {};
+			check_expr(c, &x, ce->args[0]);
+			check_expr(c, &y, ce->args[1]);
+			if (x.mode == Addressing_Invalid) {
+				return false;
+			}
+			if (y.mode == Addressing_Invalid) {
+				return false;
+			}
+			if (!is_type_simd_vector(x.type)) {
+				error(x.expr, "'%.*s' expected a simd vector type", LIT(builtin_name));
+				return false;
+			}
+			if (!is_type_simd_vector(y.type)) {
+				error(y.expr, "'%.*s' expected a simd vector type", LIT(builtin_name));
+				return false;
+			}
+			GB_ASSERT(x.type->kind == Type_SimdVector);
+			GB_ASSERT(y.type->kind == Type_SimdVector);
+			Type *xt = x.type;
+			Type *yt = y.type;
+
+			if (xt->SimdVector.count != yt->SimdVector.count) {
+				error(x.expr, "'%.*s' mismatched simd vector lengths, got '%lld' vs '%lld'",
+				      LIT(builtin_name),
+				      cast(long long)xt->SimdVector.count,
+				      cast(long long)yt->SimdVector.count);
+				return false;
+			}
+			if (!is_type_integer(base_array_type(x.type))) {
+				gbString xs = type_to_string(x.type);
+				error(x.expr, "'%.*s' expected a #simd type with an integer element, got '%s'", LIT(builtin_name), xs);
+				gb_string_free(xs);
+				return false;
+			}
+			if (!is_type_unsigned(base_array_type(y.type))) {
+				gbString ys = type_to_string(y.type);
+				error(y.expr, "'%.*s' expected a #simd type with an unsigned integer element as the shifting operand, got '%s'", LIT(builtin_name), ys);
+				gb_string_free(ys);
+				return false;
+			}
+
+			operand->mode = Addressing_Value;
+			operand->type = x.type;
+			return true;
+		}
+
 	// Unary
 	case BuiltinProc_simd_neg:
 	case BuiltinProc_simd_abs:

+ 201 - 0
src/llvm_backend_proc.cpp

@@ -981,10 +981,211 @@ lbValue lb_emit_call(lbProcedure *p, lbValue value, Array<lbValue> const &args,
 	return result;
 }
 
+lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv, BuiltinProcId id) {
+	ast_node(ce, CallExpr, expr);
+
+	lbModule *m = p->module;
+
+	lbValue res = {};
+	res.type = tv.type;
+
+	lbValue arg0 = lb_build_expr(p, ce->args[0]);
+	lbValue arg1 = {};
+
+	Type *elem = base_array_type(arg0.type);
+
+	bool is_float = is_type_float(elem);
+	bool is_signed = !is_type_unsigned(elem);
+
+	LLVMOpcode op_code = cast(LLVMOpcode)0;
+
+	switch (id) {
+	case BuiltinProc_simd_add:
+	case BuiltinProc_simd_sub:
+	case BuiltinProc_simd_mul:
+	case BuiltinProc_simd_div:
+	case BuiltinProc_simd_rem:
+		arg1 = lb_build_expr(p, ce->args[1]);
+		if (is_float) {
+			switch (id) {
+			case BuiltinProc_simd_add: op_code = LLVMFAdd; break;
+			case BuiltinProc_simd_sub: op_code = LLVMFSub; break;
+			case BuiltinProc_simd_mul: op_code = LLVMFMul; break;
+			case BuiltinProc_simd_div: op_code = LLVMFDiv; break;
+			}
+		} else {
+			switch (id) {
+			case BuiltinProc_simd_add: op_code = LLVMAdd; break;
+			case BuiltinProc_simd_sub: op_code = LLVMSub; break;
+			case BuiltinProc_simd_mul: op_code = LLVMMul; break;
+			case BuiltinProc_simd_div:
+				if (is_signed) {
+					op_code = LLVMSDiv;
+				} else {
+					op_code = LLVMUDiv;
+				}
+				break;
+			case BuiltinProc_simd_rem:
+				if (is_signed) {
+					op_code = LLVMSRem;
+				} else {
+					op_code = LLVMURem;
+				}
+				break;
+			}
+		}
+		if (op_code) {
+			res.value = LLVMBuildBinOp(p->builder, op_code, arg0.value, arg1.value, "");
+			return res;
+		}
+		break;
+	case BuiltinProc_simd_shl: // Odin logic
+	case BuiltinProc_simd_shr: // Odin logic
+	case BuiltinProc_simd_shl_masked: // C logic
+	case BuiltinProc_simd_shr_masked: // C logic
+		arg1 = lb_build_expr(p, ce->args[1]);
+		{
+			i64 sz = type_size_of(elem);
+			GB_ASSERT(arg0.type->kind == Type_SimdVector);
+
+			i64 count = arg0.type->SimdVector.count;
+			Type *elem1 = base_array_type(arg1.type);
+
+			bool is_masked = false;
+			switch (id) {
+			case BuiltinProc_simd_shl:        op_code = LLVMShl;                         is_masked = false; break;
+			case BuiltinProc_simd_shr:        op_code = is_signed ? LLVMAShr : LLVMLShr; is_masked = false; break;
+			case BuiltinProc_simd_shl_masked: op_code = LLVMShl;                         is_masked = true;  break;
+			case BuiltinProc_simd_shr_masked: op_code = is_signed ? LLVMAShr : LLVMLShr; is_masked = true;  break;
+			}
+			if (op_code) {
+				LLVMValueRef bit_value = lb_const_int(m, elem1, sz*8 - 1).value;
+				LLVMValueRef *values = gb_alloc_array(temporary_allocator(), LLVMValueRef, count);
+				for (i64 i = 0; i < count; i++) {
+					values[i] = bit_value;
+				}
+				LLVMValueRef bits = LLVMConstVector(values, cast(unsigned)count);
+				if (is_masked) {
+					// C logic
+					LLVMValueRef shift = LLVMBuildAnd(p->builder, arg1.value, bits, "");
+					res.value = LLVMBuildBinOp(p->builder, op_code, arg0.value, shift, "");
+				} else {
+					// Odin logic
+					LLVMValueRef zero = lb_const_nil(m, arg1.type).value;
+					LLVMValueRef mask = LLVMBuildICmp(p->builder, LLVMIntULE, arg1.value, bits, "");
+					LLVMValueRef shift = LLVMBuildBinOp(p->builder, op_code, arg0.value, arg1.value, "");
+					res.value = LLVMBuildSelect(p->builder, mask, shift, zero, "");
+				}
+
+				return res;
+			}
+		}
+		break;
+	case BuiltinProc_simd_and:
+	case BuiltinProc_simd_or:
+	case BuiltinProc_simd_xor:
+		arg1 = lb_build_expr(p, ce->args[1]);
+		switch (id) {
+		case BuiltinProc_simd_and: op_code = LLVMAnd; break;
+		case BuiltinProc_simd_or:  op_code = LLVMOr;  break;
+		case BuiltinProc_simd_xor: op_code = LLVMXor; break;
+		}
+		if (op_code) {
+			res.value = LLVMBuildBinOp(p->builder, op_code, arg0.value, arg1.value, "");
+			return res;
+		}
+		break;
+	case BuiltinProc_simd_neg:
+		if (is_float) {
+			res.value = LLVMBuildFNeg(p->builder, arg0.value, "");
+		} else {
+			res.value = LLVMBuildNeg(p->builder, arg0.value, "");
+		}
+		return res;
+	case BuiltinProc_simd_abs:
+		if (is_float) {
+			LLVMValueRef pos = arg0.value;
+			LLVMValueRef neg = LLVMBuildFNeg(p->builder, pos, "");
+			LLVMValueRef cond = LLVMBuildFCmp(p->builder, LLVMRealOGT, pos, neg, "");
+			res.value = LLVMBuildSelect(p->builder, cond, pos, neg, "");
+		} else {
+			LLVMValueRef pos = arg0.value;
+			LLVMValueRef neg = LLVMBuildNeg(p->builder, pos, "");
+			LLVMValueRef cond = LLVMBuildICmp(p->builder, is_signed ? LLVMIntSGT : LLVMIntUGT, pos, neg, "");
+			res.value = LLVMBuildSelect(p->builder, cond, pos, neg, "");
+		}
+		return res;
+	case BuiltinProc_simd_min:
+		if (is_float) {
+			LLVMValueRef cond = LLVMBuildFCmp(p->builder, LLVMRealOLT, arg0.value, arg1.value, "");
+			res.value = LLVMBuildSelect(p->builder, cond, arg0.value, arg1.value, "");
+		} else {
+			LLVMValueRef cond = LLVMBuildICmp(p->builder, is_signed ? LLVMIntSLT : LLVMIntULT, arg0.value, arg1.value, "");
+			res.value = LLVMBuildSelect(p->builder, cond, arg0.value, arg1.value, "");
+		}
+		return res;
+	case BuiltinProc_simd_max:
+		arg1 = lb_build_expr(p, ce->args[1]);
+		if (is_float) {
+			LLVMValueRef cond = LLVMBuildFCmp(p->builder, LLVMRealOGT, arg0.value, arg1.value, "");
+			res.value = LLVMBuildSelect(p->builder, cond, arg0.value, arg1.value, "");
+		} else {
+			LLVMValueRef cond = LLVMBuildICmp(p->builder, is_signed ? LLVMIntSGT : LLVMIntUGT, arg0.value, arg1.value, "");
+			res.value = LLVMBuildSelect(p->builder, cond, arg0.value, arg1.value, "");
+		}
+		return res;
+	case BuiltinProc_simd_eq:
+	case BuiltinProc_simd_ne:
+	case BuiltinProc_simd_lt:
+	case BuiltinProc_simd_le:
+	case BuiltinProc_simd_gt:
+	case BuiltinProc_simd_ge:
+		arg1 = lb_build_expr(p, ce->args[1]);
+		if (is_float) {
+			LLVMRealPredicate pred = cast(LLVMRealPredicate)0;
+			switch (id) {
+			case BuiltinProc_simd_eq: pred = LLVMRealOEQ; break;
+			case BuiltinProc_simd_ne: pred = LLVMRealONE; break;
+			case BuiltinProc_simd_lt: pred = LLVMRealOLT; break;
+			case BuiltinProc_simd_le: pred = LLVMRealOLE; break;
+			case BuiltinProc_simd_gt: pred = LLVMRealOGT; break;
+			case BuiltinProc_simd_ge: pred = LLVMRealOGE; break;
+			}
+			if (pred) {
+				res.value = LLVMBuildFCmp(p->builder, pred, arg0.value, arg1.value, "");
+				res.value = LLVMBuildSExtOrBitCast(p->builder, res.value, lb_type(m, tv.type), "");
+				return res;
+			}
+		} else {
+			LLVMIntPredicate pred = cast(LLVMIntPredicate)0;
+			switch (id) {
+			case BuiltinProc_simd_eq: pred = LLVMIntEQ; break;
+			case BuiltinProc_simd_ne: pred = LLVMIntNE; break;
+			case BuiltinProc_simd_lt: pred = is_signed ? LLVMIntSLT :LLVMIntULT; break;
+			case BuiltinProc_simd_le: pred = is_signed ? LLVMIntSLE :LLVMIntULE; break;
+			case BuiltinProc_simd_gt: pred = is_signed ? LLVMIntSGT :LLVMIntUGT; break;
+			case BuiltinProc_simd_ge: pred = is_signed ? LLVMIntSGE :LLVMIntUGE; break;
+			}
+			if (pred) {
+				res.value = LLVMBuildICmp(p->builder, pred, arg0.value, arg1.value, "");
+				res.value = LLVMBuildSExtOrBitCast(p->builder, res.value, lb_type(m, tv.type), "");
+				return res;
+			}
+		}
+		break;
+	}
+	GB_PANIC("Unhandled simd intrinsic: '%.*s'", LIT(builtin_procs[id].name));
+	return {};
+}
+
 
 lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv, BuiltinProcId id) {
 	ast_node(ce, CallExpr, expr);
 
+	if (BuiltinProc__simd_begin < id && id < BuiltinProc__simd_end) {
+		return lb_build_builtin_simd_proc(p, expr, tv, id);
+	}
+
 	switch (id) {
 	case BuiltinProc_DIRECTIVE: {
 		ast_node(bd, BasicDirective, ce->proc);