Browse Source

Try to support the matrix multiplication LLVM intrinsics

gingerBill 3 years ago
parent
commit
35111b39b8
1 changed files with 249 additions and 24 deletions
  1. 249 24
      src/llvm_backend_expr.cpp

+ 249 - 24
src/llvm_backend_expr.cpp

@@ -476,6 +476,254 @@ lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lhs, lbValue r
 	}
 }
 
+bool lb_matrix_elem_simple(Type *t) {
+	Type *mt = base_type(t);
+	GB_ASSERT(mt->kind == Type_Matrix);
+	
+	Type *elem = core_type(mt->Matrix.elem);
+	if (is_type_complex(elem)) {
+		return false;
+	}
+	
+	if (is_type_different_to_arch_endianness(elem)) {
+		return false;
+	}
+	
+	if (elem->kind == Type_Basic) {
+		switch (elem->Basic.kind) {
+		case Basic_f16:
+		case Basic_f16le:
+		case Basic_f16be:
+			// TODO(bill): determine when this is fine
+			return false;
+		}
+	}
+	
+	return true;
+}
+
+LLVMValueRef llvm_matrix_column_major_load(lbProcedure *p, lbValue lhs) {
+	lbModule *m = p->module;
+	
+	Type *mt = base_type(lhs.type);
+	GB_ASSERT(mt->kind == Type_Matrix);
+	GB_ASSERT(lb_matrix_elem_simple(mt));
+	
+	unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt);
+	
+	Type *elem = mt->Matrix.elem;
+	LLVMTypeRef elem_type = lb_type(m, elem);
+	
+	LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count);
+	LLVMTypeRef types[] = {vector_type};
+	
+	char const *name = "llvm.matrix.column.major.load";
+	unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
+	GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
+	LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
+	
+	lbValue ptr = lb_address_from_load_or_generate_local(p, lhs);
+	ptr = lb_emit_matrix_epi(p, ptr, 0, 0);
+	
+	LLVMValueRef values[5] = {};
+	values[0] = ptr.value;
+	values[1] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width
+	values[2] = LLVMConstNull(lb_type(m, t_llvm_bool));
+	values[3] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
+	values[4] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
+	
+	return LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
+}
+LLVMValueRef llvm_matrix_column_major_load_from_ptr(lbProcedure *p, lbValue ptr) {
+	lbModule *m = p->module;
+
+	Type *mt = base_type(type_deref(ptr.type));
+	GB_ASSERT(mt->kind == Type_Matrix);
+	GB_ASSERT(lb_matrix_elem_simple(mt));
+	
+	unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt);
+	
+	Type *elem = mt->Matrix.elem;
+	LLVMTypeRef elem_type = lb_type(m, elem);
+	
+	LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count);
+	LLVMTypeRef types[] = {vector_type};
+	
+	char const *name = "llvm.matrix.column.major.load";
+	unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
+	GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
+	LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
+	
+	LLVMValueRef values[5] = {};
+	values[0] = lb_emit_matrix_epi(p, ptr, 0, 0).value;
+	values[1] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width
+	values[2] = LLVMConstNull(lb_type(m, t_llvm_bool));
+	values[3] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
+	values[4] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
+	
+	return LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
+}
+
+void llvm_matrix_column_major_store(lbProcedure *p, lbAddr addr, LLVMValueRef vector_value) {
+	lbModule *m = p->module;
+	
+	Type *mt = base_type(lb_addr_type(addr));
+	GB_ASSERT(mt->kind == Type_Matrix);
+	GB_ASSERT(lb_matrix_elem_simple(mt));
+	
+	unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt);
+	
+	Type *elem = mt->Matrix.elem;
+	LLVMTypeRef elem_type = lb_type(m, elem);
+	
+	LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count);
+	LLVMTypeRef types[] = {vector_type};
+	
+	char const *name = "llvm.matrix.column.major.store";
+	unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
+	GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
+	LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
+	
+	lbValue ptr = lb_addr_get_ptr(p, addr);
+	ptr = lb_emit_matrix_epi(p, ptr, 0, 0);
+	
+	GB_ASSERT(LLVMTypeOf(vector_value) == vector_type);
+	unsigned vector_size = LLVMGetVectorSize(vector_type);
+	GB_ASSERT((mt->Matrix.row_count*mt->Matrix.column_count) == cast(i64)vector_size);
+	
+	LLVMValueRef values[6] = {};
+	values[0] = vector_value;
+	values[1] = ptr.value;
+	values[2] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width
+	values[3] = LLVMConstNull(lb_type(m, t_llvm_bool));
+	values[4] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
+	values[5] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
+	
+	LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
+}
+
+void llvm_matrix_column_major_store_to_raw_ptr(lbProcedure *p, Type *mt, lbValue ptr, LLVMValueRef vector_value) {
+	lbModule *m = p->module;
+	
+	mt = base_type(mt);
+	GB_ASSERT(mt->kind == Type_Matrix);
+	GB_ASSERT(lb_matrix_elem_simple(mt));
+	
+	unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt);
+	
+	Type *elem = mt->Matrix.elem;
+	LLVMTypeRef elem_type = lb_type(m, elem);
+	
+	LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count);
+	LLVMTypeRef types[] = {vector_type};
+	
+	char const *name = "llvm.matrix.column.major.store";
+	unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
+	GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
+	LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
+	
+	GB_ASSERT(LLVMTypeOf(vector_value) == vector_type);
+	unsigned vector_size = LLVMGetVectorSize(vector_type);
+	GB_ASSERT((mt->Matrix.row_count*mt->Matrix.column_count) == cast(i64)vector_size);
+	
+	LLVMValueRef values[6] = {};
+	values[0] = vector_value;
+	values[1] = ptr.value;
+	values[2] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width
+	values[3] = LLVMConstNull(lb_type(m, t_llvm_bool));
+	values[4] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
+	values[5] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
+	
+	LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
+}
+
+LLVMValueRef llvm_matrix_multiply(lbProcedure *p, LLVMValueRef a, LLVMValueRef b, i64 outer_rows, i64 inner, i64 outer_columns) {
+	lbModule *m = p->module;
+	
+	LLVMTypeRef a_type = LLVMTypeOf(a);
+	LLVMTypeRef b_type = LLVMTypeOf(b);
+	
+	GB_ASSERT(LLVMGetElementType(a_type) == LLVMGetElementType(b_type));
+	
+	LLVMTypeRef elem_type = LLVMGetElementType(a_type);
+	
+	LLVMTypeRef res_vector_type = LLVMVectorType(elem_type, cast(unsigned)(outer_rows*outer_columns));
+	LLVMTypeRef types[] = {res_vector_type, a_type, b_type};
+	
+	char const *name = "llvm.matrix.multiply";
+	unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
+	GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
+	LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
+
+	LLVMValueRef values[5] = {};
+	values[0] = a;
+	values[1] = b;
+	values[2] = lb_const_int(m, t_u32, outer_rows).value;
+	values[3] = lb_const_int(m, t_u32, inner).value;
+	values[4] = lb_const_int(m, t_u32, outer_columns).value;
+	
+	return LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
+}
+
+
+lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
+	Type *xt = base_type(lhs.type);
+	Type *yt = base_type(rhs.type);
+	
+	GB_ASSERT(is_type_matrix(type));
+	GB_ASSERT(is_type_matrix(xt));
+	GB_ASSERT(is_type_matrix(yt));
+	GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count);
+	GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem));
+		
+	if (!lb_matrix_elem_simple(xt)) {
+		goto slow_form;
+	}
+	
+	if (false) {
+		// TODO(bill): LLVM ERROR: Do not know how to split the result of this operator!
+		lbAddr res = lb_add_local_generated(p, type, true);
+		
+		lbValue res_ptr = lb_addr_get_ptr(p, res);
+		res_ptr = lb_emit_matrix_epi(p, res_ptr, 0, 0);
+		
+		lbValue lhs_ptr = lb_address_from_load_or_generate_local(p, lhs);
+		lbValue rhs_ptr = lb_address_from_load_or_generate_local(p, rhs);
+		LLVMValueRef a = llvm_matrix_column_major_load_from_ptr(p, lhs_ptr);
+		LLVMValueRef b = llvm_matrix_column_major_load_from_ptr(p, rhs_ptr);
+		LLVMValueRef c = llvm_matrix_multiply(p, a, b, xt->Matrix.row_count, xt->Matrix.column_count, yt->Matrix.column_count);
+		
+		llvm_matrix_column_major_store_to_raw_ptr(p, type, res_ptr, c);
+		
+		return lb_addr_load(p, res);
+	}
+		
+slow_form:
+	{
+		Type *elem = xt->Matrix.elem;	
+		
+		lbAddr res = lb_add_local_generated(p, type, true);
+		
+		for (i64 i = 0; i < xt->Matrix.row_count; i++) {
+			for (i64 j = 0; j < yt->Matrix.column_count; j++) {
+				for (i64 k = 0; k < xt->Matrix.column_count; k++) {
+					lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
+					
+					lbValue a = lb_emit_matrix_ev(p, lhs, i, k);
+					lbValue b = lb_emit_matrix_ev(p, rhs, k, j);
+					lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem);
+					lbValue d = lb_emit_load(p, dst);
+					lbValue e = lb_emit_arith(p, Token_Add, d, c, elem);
+					lb_emit_store(p, dst, e);
+					
+				}
+			}
+		}
+		
+		return lb_addr_load(p, res);
+	}
+}
+
 
 lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) {
 	GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type));
@@ -486,30 +734,7 @@ lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue
 	if (op == Token_Mul) {
 		if (xt->kind == Type_Matrix) {
 			if (yt->kind == Type_Matrix) {
-				GB_ASSERT(is_type_matrix(type));
-				GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count);
-				GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem));
-				
-				Type *elem = xt->Matrix.elem;
-				
-				lbAddr res = lb_add_local_generated(p, type, true);
-				for (i64 i = 0; i < xt->Matrix.row_count; i++) {
-					for (i64 j = 0; j < yt->Matrix.column_count; j++) {
-						for (i64 k = 0; k < xt->Matrix.column_count; k++) {
-							lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
-							
-							lbValue a = lb_emit_matrix_ev(p, lhs, i, k);
-							lbValue b = lb_emit_matrix_ev(p, rhs, k, j);
-							lbValue c = lb_emit_arith(p, op, a, b, elem);
-							lbValue d = lb_emit_load(p, dst);
-							lbValue e = lb_emit_arith(p, Token_Add, d, c, elem);
-							lb_emit_store(p, dst, e);
-							
-						}
-					}
-				}
-				
-				return lb_addr_load(p, res);
+				return lb_emit_matrix_mul(p, lhs, rhs, type);
 			}
 		}