Browse Source

Add `hadamard_product`

gingerBill 3 years ago
parent
commit
cee45c1b15
6 changed files with 84 additions and 13 deletions
  1. 55 1
      src/check_builtin.cpp
  2. 3 9
      src/check_type.cpp
  3. 2 0
      src/checker_builtin_procs.hpp
  4. 3 3
      src/llvm_backend_expr.cpp
  5. 10 0
      src/llvm_backend_proc.cpp
  6. 11 0
      src/types.cpp

+ 55 - 1
src/check_builtin.cpp

@@ -2056,6 +2056,14 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
 			return false;
 		}
 		
+		Type *elem = xt->Array.elem;
+		
+		if (!is_type_valid_for_matrix_elems(elem)) {
+			gbString s = type_to_string(elem);
+			error(call, "Matrix elements types are limited to integers, floats, and complex, got %s", s);
+			gb_string_free(s);
+		}
+		
 		if (xt->Array.count == 0 || yt->Array.count == 0) {
 			gbString s1 = type_to_string(x.type);
 			gbString s2 = type_to_string(y.type);
@@ -2072,7 +2080,53 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
 		}
 		
 		operand->mode = Addressing_Value;
-		operand->type = alloc_type_matrix(xt->Array.elem, xt->Array.count, yt->Array.count);	
+		operand->type = alloc_type_matrix(elem, xt->Array.count, yt->Array.count);	
+		operand->type = check_matrix_type_hint(operand->type, type_hint);
+		break;
+	}
+	
+	case BuiltinProc_hadamard_product: {
+		Operand x = {};
+		Operand y = {};
+		check_expr(c, &x, ce->args[0]);
+		if (x.mode == Addressing_Invalid) {
+			return false;
+		}
+		check_expr(c, &y, ce->args[1]);
+		if (y.mode == Addressing_Invalid) {
+			return false;
+		}
+		if (!is_operand_value(x) || !is_operand_value(y)) {
+			error(call, "'%.*s' expects a matrix or array types", LIT(builtin_name));
+			return false;
+		}
+		if (!is_type_matrix(x.type) && !is_type_array(y.type)) {
+			gbString s1 = type_to_string(x.type);
+			gbString s2 = type_to_string(y.type);
+			error(call, "'%.*s' expects matrix or array values, got %s and %s", LIT(builtin_name), s1, s2);
+			gb_string_free(s2);
+			gb_string_free(s1);
+			return false;
+		}
+		
+		if (!are_types_identical(x.type, y.type)) {
+			gbString s1 = type_to_string(x.type);
+			gbString s2 = type_to_string(y.type);
+			error(call, "'%.*s' values of the same type, got %s and %s", LIT(builtin_name), s1, s2);
+			gb_string_free(s2);
+			gb_string_free(s1);
+			return false;
+		}
+		
+		Type *elem = core_array_type(x.type);
+		if (!is_type_valid_for_matrix_elems(elem)) {
+			gbString s = type_to_string(elem);
+			error(call, "'%.*s' expects elements to be types are limited to integers, floats, and complex, got %s", LIT(builtin_name), s);
+			gb_string_free(s);
+		}
+		
+		operand->mode = Addressing_Value;
+		operand->type = x.type;
 		operand->type = check_matrix_type_hint(operand->type, type_hint);
 		break;
 	}

+ 3 - 9
src/check_type.cpp

@@ -997,8 +997,8 @@ void check_bit_set_type(CheckerContext *c, Type *type, Type *named_type, Ast *no
 
 				GB_ASSERT(lower <= upper);
 
-				i64 bits = MAX_BITS;
-				if (bs->underlying != nullptr) {
+				i64 bits = MAX_BITS
+;				if (bs->underlying != nullptr) {
 					Type *u = check_type(c, bs->underlying);
 					if (!is_type_integer(u)) {
 						gbString ts = type_to_string(u);
@@ -2239,13 +2239,7 @@ void check_matrix_type(CheckerContext *ctx, Type **type, Ast *node) {
 		error(column.expr, "Matrix types are limited to a maximum of %d elements, got %lld", MAX_MATRIX_ELEMENT_COUNT, cast(long long)element_count);
 	}
 	
-	if (is_type_integer(elem)) {
-		// okay
-	} else if (is_type_float(elem)) {
-		// okay
-	} else if (is_type_complex(elem)) {
-		// okay
-	} else {
+	if (!is_type_valid_for_matrix_elems(elem)) {
 		gbString s = type_to_string(elem);
 		error(column.expr, "Matrix elements types are limited to integers, floats, and complex, got %s", s);
 		gb_string_free(s);

+ 2 - 0
src/checker_builtin_procs.hpp

@@ -37,6 +37,7 @@ enum BuiltinProcId {
 	
 	BuiltinProc_transpose,
 	BuiltinProc_outer_product,
+	BuiltinProc_hadamard_product,
 
 	BuiltinProc_DIRECTIVE, // NOTE(bill): This is used for specialized hash-prefixed procedures
 
@@ -280,6 +281,7 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = {
 	
 	{STR_LIT("transpose"),        1, false, Expr_Expr, BuiltinProcPkg_builtin},
 	{STR_LIT("outer_product"),    2, false, Expr_Expr, BuiltinProcPkg_builtin},
+	{STR_LIT("hadamard_product"), 2, false, Expr_Expr, BuiltinProcPkg_builtin},
 
 	{STR_LIT(""),                 0, true,  Expr_Expr, BuiltinProcPkg_builtin}, // DIRECTIVE
 

+ 3 - 3
src/llvm_backend_expr.cpp

@@ -672,13 +672,13 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 
 
 
-lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) {
+lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type, bool component_wise=false) {
 	GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type));
 	
 	Type *xt = base_type(lhs.type);
 	Type *yt = base_type(rhs.type);
 	
-	if (op == Token_Mul) {
+	if (op == Token_Mul && !component_wise) {
 		if (xt->kind == Type_Matrix) {
 			if (yt->kind == Type_Matrix) {
 				return lb_emit_matrix_mul(p, lhs, rhs, type);
@@ -703,7 +703,7 @@ lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue
 		array_lhs.type = array_type; 
 		array_rhs.type = array_type;
 
-		lbValue array = lb_emit_arith_array(p, op, array_lhs, array_rhs, type);
+		lbValue array = lb_emit_arith_array(p, op, array_lhs, array_rhs, array_type);
 		array.type = type;
 		return array;
 	}

+ 10 - 0
src/llvm_backend_proc.cpp

@@ -1270,6 +1270,16 @@ lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv,
 			lbValue b = lb_build_expr(p, ce->args[1]);
 			return lb_emit_outer_product(p, a, b, tv.type);
 		}
+	case BuiltinProc_hadamard_product:
+		{
+			lbValue a = lb_build_expr(p, ce->args[0]);
+			lbValue b = lb_build_expr(p, ce->args[1]);
+			if (is_type_array(tv.type)) {
+				return lb_emit_arith(p, Token_Mul, a, b, tv.type);
+			}
+			GB_ASSERT(is_type_matrix(tv.type));
+			return lb_emit_arith_matrix(p, Token_Mul, a, b, tv.type, true);
+		}
 
 
 	// "Intrinsics"

+ 11 - 0
src/types.cpp

@@ -1333,6 +1333,17 @@ i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 column_index) {
 	return stride_elems*column_index + row_index;
 }
 
+bool is_type_valid_for_matrix_elems(Type *t) {
+	if (is_type_integer(t)) {
+		return true;
+	} else if (is_type_float(t)) {
+		return true;
+	} else if (is_type_complex(t)) {
+		return true;
+	}
+	return false;
+}
+
 bool is_type_dynamic_array(Type *t) {
 	t = base_type(t);
 	return t->kind == Type_DynamicArray;