Browse Source

Minor clean up for `lb_matrix_trimmed_vector_mask`

gingerBill 3 years ago
parent
commit
30c141ceb9
1 changed files with 8 additions and 6 deletions
  1. 8 6
      src/llvm_backend_expr.cpp

+ 8 - 6
src/llvm_backend_expr.cpp

@@ -517,18 +517,14 @@ LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) {
 	return matrix_vector;
 	return matrix_vector;
 }
 }
 
 
-LLVMValueRef lb_matrix_to_trimmed_vector(lbProcedure *p, lbValue m) {
-	Type *mt = base_type(m.type);
+LLVMValueRef lb_matrix_trimmed_vector_mask(lbProcedure *p, Type *mt) {
+	mt = base_type(mt);
 	GB_ASSERT(mt->kind == Type_Matrix);
 	GB_ASSERT(mt->kind == Type_Matrix);
 	
 	
 	unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
 	unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
 	unsigned row_count = cast(unsigned)mt->Matrix.row_count;
 	unsigned row_count = cast(unsigned)mt->Matrix.row_count;
 	unsigned column_count = cast(unsigned)mt->Matrix.column_count;
 	unsigned column_count = cast(unsigned)mt->Matrix.column_count;
 	
 	
-	auto columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
-	
-	LLVMValueRef vector = lb_matrix_to_vector(p, m);
-	
 	unsigned mask_elems_index = 0;
 	unsigned mask_elems_index = 0;
 	auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), row_count*column_count);
 	auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), row_count*column_count);
 	for (unsigned j = 0; j < column_count; j++) {
 	for (unsigned j = 0; j < column_count; j++) {
@@ -539,6 +535,12 @@ LLVMValueRef lb_matrix_to_trimmed_vector(lbProcedure *p, lbValue m) {
 	}
 	}
 	
 	
 	LLVMValueRef mask = LLVMConstVector(mask_elems.data, cast(unsigned)mask_elems.count);
 	LLVMValueRef mask = LLVMConstVector(mask_elems.data, cast(unsigned)mask_elems.count);
+	return mask;
+}
+
+LLVMValueRef lb_matrix_to_trimmed_vector(lbProcedure *p, lbValue m) {
+	LLVMValueRef vector = lb_matrix_to_vector(p, m);
+	LLVMValueRef mask = lb_matrix_trimmed_vector_mask(p, m.type);
 	LLVMValueRef trimmed_vector = LLVMBuildShuffleVector(p->builder, vector, LLVMGetUndef(LLVMTypeOf(vector)), mask, "");
 	LLVMValueRef trimmed_vector = LLVMBuildShuffleVector(p->builder, vector, LLVMGetUndef(LLVMTypeOf(vector)), mask, "");
 	return trimmed_vector;
 	return trimmed_vector;
 }
 }