Browse Source

Allow conversions between matrices of the same element count

gingerBill 4 years ago
parent
commit
48d277a3c4
4 changed files with 46 additions and 34 deletions
  1. 3 1
      src/check_expr.cpp
  2. 2 2
      src/llvm_backend_const.cpp
  3. 28 10
      src/llvm_backend_expr.cpp
  4. 13 21
      src/types.cpp

+ 3 - 1
src/check_expr.cpp

@@ -2469,7 +2469,9 @@ bool check_is_castable_to(CheckerContext *c, Operand *operand, Type *y) {
 		}
 		}
 		
 		
 		if (src->Matrix.row_count != src->Matrix.column_count) {
 		if (src->Matrix.row_count != src->Matrix.column_count) {
-			return false;
+			i64 src_count = src->Matrix.row_count*src->Matrix.column_count;
+			i64 dst_count = dst->Matrix.row_count*dst->Matrix.column_count;
+			return src_count == dst_count;
 		}
 		}
 		
 		
 		if (dst->Matrix.row_count != dst->Matrix.column_count) {
 		if (dst->Matrix.row_count != dst->Matrix.column_count) {

+ 2 - 2
src/llvm_backend_const.cpp

@@ -524,7 +524,7 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc
 		lbValue single_elem = lb_const_value(m, elem, value, allow_local);
 		lbValue single_elem = lb_const_value(m, elem, value, allow_local);
 		single_elem.value = llvm_const_cast(single_elem.value, lb_type(m, elem));
 		single_elem.value = llvm_const_cast(single_elem.value, lb_type(m, elem));
 				
 				
-		i64 total_elem_count = matrix_type_total_elems(type);
+		i64 total_elem_count = matrix_type_total_internal_elems(type);
 		LLVMValueRef *elems = gb_alloc_array(permanent_allocator(), LLVMValueRef, cast(isize)total_elem_count);		
 		LLVMValueRef *elems = gb_alloc_array(permanent_allocator(), LLVMValueRef, cast(isize)total_elem_count);		
 		for (i64 i = 0; i < row; i++) {
 		for (i64 i = 0; i < row; i++) {
 			elems[matrix_indices_to_offset(type, i, i)] = single_elem.value;
 			elems[matrix_indices_to_offset(type, i, i)] = single_elem.value;
@@ -990,7 +990,7 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc
 			}
 			}
 			
 			
 			i64 max_count = type->Matrix.row_count*type->Matrix.column_count;
 			i64 max_count = type->Matrix.row_count*type->Matrix.column_count;
-			i64 total_count = matrix_type_total_elems(type);
+			i64 total_count = matrix_type_total_internal_elems(type);
 			
 			
 			LLVMValueRef *values = gb_alloc_array(temporary_allocator(), LLVMValueRef, cast(isize)total_count);
 			LLVMValueRef *values = gb_alloc_array(temporary_allocator(), LLVMValueRef, cast(isize)total_count);
 			if (cl->elems[0]->kind == Ast_FieldValue) {
 			if (cl->elems[0]->kind == Ast_FieldValue) {

+ 28 - 10
src/llvm_backend_expr.cpp

@@ -508,7 +508,7 @@ LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) {
 	GB_ASSERT(mt->kind == Type_Matrix);
 	GB_ASSERT(mt->kind == Type_Matrix);
 	LLVMTypeRef elem_type = lb_type(p->module, mt->Matrix.elem);
 	LLVMTypeRef elem_type = lb_type(p->module, mt->Matrix.elem);
 	
 	
-	unsigned total_count = cast(unsigned)matrix_type_total_elems(mt);
+	unsigned total_count = cast(unsigned)matrix_type_total_internal_elems(mt);
 	LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count);
 	LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count);
 	
 	
 	LLVMValueRef ptr = lb_address_from_load_or_generate_local(p, matrix).value;
 	LLVMValueRef ptr = lb_address_from_load_or_generate_local(p, matrix).value;
@@ -948,7 +948,7 @@ lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue
 		// pretend it is an array
 		// pretend it is an array
 		lbValue array_lhs = lhs;
 		lbValue array_lhs = lhs;
 		lbValue array_rhs = rhs;
 		lbValue array_rhs = rhs;
-		Type *array_type = alloc_type_array(xt->Matrix.elem, matrix_type_total_elems(xt));
+		Type *array_type = alloc_type_array(xt->Matrix.elem, matrix_type_total_internal_elems(xt));
 		GB_ASSERT(type_size_of(array_type) == type_size_of(xt));
 		GB_ASSERT(type_size_of(array_type) == type_size_of(xt));
 		
 		
 		array_lhs.type = array_type; 
 		array_lhs.type = array_type; 
@@ -1941,15 +1941,33 @@ lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) {
 		GB_ASSERT(dst->kind == Type_Matrix);
 		GB_ASSERT(dst->kind == Type_Matrix);
 		GB_ASSERT(src->kind == Type_Matrix);
 		GB_ASSERT(src->kind == Type_Matrix);
 		lbAddr v = lb_add_local_generated(p, t, true);
 		lbAddr v = lb_add_local_generated(p, t, true);
-		for (i64 j = 0; j < dst->Matrix.column_count; j++) {
-			for (i64 i = 0; i < dst->Matrix.row_count; i++) {
-				if (i < src->Matrix.row_count && j < src->Matrix.column_count) {
-					lbValue d = lb_emit_matrix_epi(p, v.addr, i, j);
+		
+		if (is_matrix_square(dst) && is_matrix_square(dst)) {
+			for (i64 j = 0; j < dst->Matrix.column_count; j++) {
+				for (i64 i = 0; i < dst->Matrix.row_count; i++) {
+					if (i < src->Matrix.row_count && j < src->Matrix.column_count) {
+						lbValue d = lb_emit_matrix_epi(p, v.addr, i, j);
+						lbValue s = lb_emit_matrix_ev(p, value, i, j);
+						lb_emit_store(p, d, s);
+					} else if (i == j) {
+						lbValue d = lb_emit_matrix_epi(p, v.addr, i, j);
+						lbValue s = lb_const_value(p->module, dst->Matrix.elem, exact_value_i64(1), true);
+						lb_emit_store(p, d, s);
+					}
+				}
+			}
+		} else {
+			i64 dst_count = dst->Matrix.row_count*dst->Matrix.column_count;
+			i64 src_count = src->Matrix.row_count*src->Matrix.column_count;
+			GB_ASSERT(dst_count == src_count);
+			
+			for (i64 j = 0; j < src->Matrix.column_count; j++) {
+				for (i64 i = 0; i < src->Matrix.row_count; i++) {
 					lbValue s = lb_emit_matrix_ev(p, value, i, j);
 					lbValue s = lb_emit_matrix_ev(p, value, i, j);
-					lb_emit_store(p, d, s);
-				} else if (i == j) {
-					lbValue d = lb_emit_matrix_epi(p, v.addr, i, j);
-					lbValue s = lb_const_value(p->module, dst->Matrix.elem, exact_value_i64(1), true);
+					i64 index = i + j*src->Matrix.row_count;					
+					i64 dst_i = index%dst->Matrix.row_count;
+					i64 dst_j = index/dst->Matrix.row_count;
+					lbValue d = lb_emit_matrix_epi(p, v.addr, dst_i, dst_j);
 					lb_emit_store(p, d, s);
 					lb_emit_store(p, d, s);
 				}
 				}
 			}
 			}

+ 13 - 21
src/types.cpp

@@ -1293,7 +1293,7 @@ i64 matrix_type_stride_in_elems(Type *t) {
 }
 }
 
 
 
 
-i64 matrix_type_total_elems(Type *t) {
+i64 matrix_type_total_internal_elems(Type *t) {
 	t = base_type(t);
 	t = base_type(t);
 	GB_ASSERT(t->kind == Type_Matrix);
 	GB_ASSERT(t->kind == Type_Matrix);
 	i64 size = type_size_of(t);
 	i64 size = type_size_of(t);
@@ -1301,37 +1301,29 @@ i64 matrix_type_total_elems(Type *t) {
 	return size/gb_max(elem_size, 1);
 	return size/gb_max(elem_size, 1);
 }
 }
 
 
-void matrix_indices_from_index(Type *t, i64 index, i64 *row_index_, i64 *column_index_) {
+i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 column_index) {
 	t = base_type(t);
 	t = base_type(t);
 	GB_ASSERT(t->kind == Type_Matrix);
 	GB_ASSERT(t->kind == Type_Matrix);
-	i64 row_count = t->Matrix.row_count;
-	i64 column_count = t->Matrix.column_count;
-	GB_ASSERT(0 <= index && index < row_count*column_count);
-	
-	i64 row_index = index / column_count;
-	i64 column_index = index % column_count;
-	
-	if (row_index_)    *row_index_    = row_index;
-	if (column_index_) *column_index_ = column_index;
+	GB_ASSERT(0 <= row_index && row_index < t->Matrix.row_count);
+	GB_ASSERT(0 <= column_index && column_index < t->Matrix.column_count);
+	i64 stride_elems = matrix_type_stride_in_elems(t);
+	return stride_elems*column_index + row_index;
 }
 }
-
 i64 matrix_index_to_offset(Type *t, i64 index) {
 i64 matrix_index_to_offset(Type *t, i64 index) {
 	t = base_type(t);
 	t = base_type(t);
 	GB_ASSERT(t->kind == Type_Matrix);
 	GB_ASSERT(t->kind == Type_Matrix);
 	
 	
-	i64 row_index, column_index;
-	matrix_indices_from_index(t, index, &row_index, &column_index);
-	i64 stride_elems = matrix_type_stride_in_elems(t);
-	return stride_elems*column_index + row_index;
+	i64 row_index    = index%t->Matrix.row_count;
+	i64 column_index = index/t->Matrix.row_count;
+	return matrix_indices_to_offset(t, row_index, column_index);
 }
 }
 
 
-i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 column_index) {
+
+
+bool is_matrix_square(Type *t) {
 	t = base_type(t);
 	t = base_type(t);
 	GB_ASSERT(t->kind == Type_Matrix);
 	GB_ASSERT(t->kind == Type_Matrix);
-	GB_ASSERT(0 <= row_index && row_index < t->Matrix.row_count);
-	GB_ASSERT(0 <= column_index && column_index < t->Matrix.column_count);
-	i64 stride_elems = matrix_type_stride_in_elems(t);
-	return stride_elems*column_index + row_index;
+	return t->Matrix.row_count == t->Matrix.column_count;
 }
 }
 
 
 bool is_type_valid_for_matrix_elems(Type *t) {
 bool is_type_valid_for_matrix_elems(Type *t) {