Browse Source

Improve matrix code generation for all supported platforms

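This is done through assembly optimization: SIMD matrix lowering and the `llvm.fma` intrinsic are now enabled only for amd64 and arm64 targets (not 386 or wasm32), with 16-bit floats excluded on amd64.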
gingerBill 3 years ago
parent
commit
d62c701a43
2 changed files with 47 additions and 3 deletions
  1. 27 2
      src/llvm_backend_expr.cpp
  2. 20 1
      src/llvm_backend_utility.cpp

+ 27 - 2
src/llvm_backend_expr.cpp

@@ -489,13 +489,32 @@ bool lb_is_matrix_simdable(Type *t) {
 		return false;
 	}
 	
+	switch (build_context.metrics.arch) {
+	case TargetArch_amd64:
+	case TargetArch_arm64:
+		// possible
+		break;
+	case TargetArch_386:
+	case TargetArch_wasm32:
+		// nope
+		return false;
+	}
+	
 	if (elem->kind == Type_Basic) {
 		switch (elem->Basic.kind) {
 		case Basic_f16:
 		case Basic_f16le:
 		case Basic_f16be:
-			// TODO(bill): determine when this is fine
-			return true;
+			switch (build_context.metrics.arch) {
+			case TargetArch_amd64:
+				return false;
+			case TargetArch_arm64:
+				// TODO(bill): determine when this is fine
+				return true;
+			case TargetArch_386:
+			case TargetArch_wasm32:
+				return false;
+			}
 		}
 	}
 	
@@ -690,6 +709,8 @@ lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type)
 }
 
 lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
+	// TODO(bill): Handle edge case for f16 types on x86(-64) platforms
+	
 	Type *xt = base_type(lhs.type);
 	Type *yt = base_type(rhs.type);
 	
@@ -775,6 +796,8 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type)
 }
 
 lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
+	// TODO(bill): Handle edge case for f16 types on x86(-64) platforms
+	
 	Type *mt = base_type(lhs.type);
 	Type *vt = base_type(rhs.type);
 	
@@ -843,6 +866,8 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 }
 
 lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
+	// TODO(bill): Handle edge case for f16 types on x86(-64) platforms
+	
 	Type *mt = base_type(rhs.type);
 	Type *vt = base_type(lhs.type);
 	

+ 20 - 1
src/llvm_backend_utility.cpp

@@ -1492,7 +1492,26 @@ lbValue lb_emit_mul_add(lbProcedure *p, lbValue a, lbValue b, lbValue c, Type *t
 	b = lb_emit_conv(p, b, t);
 	c = lb_emit_conv(p, c, t);
 	
-	if (!is_type_different_to_arch_endianness(t) && is_type_float(t)) {
+	bool is_possible = !is_type_different_to_arch_endianness(t) && is_type_float(t);
+	
+	if (is_possible) {
+		switch (build_context.metrics.arch) {
+		case TargetArch_amd64:
+			if (type_size_of(t) == 2) {
+				is_possible = false;
+			}
+			break;
+		case TargetArch_arm64:
+			// possible
+			break;
+		case TargetArch_386:
+		case TargetArch_wasm32:
+			is_possible = false;
+			break;
+		}
+	}
+
+	if (is_possible) {
 		char const *name = "llvm.fma";
 		unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
 		GB_ASSERT_MSG(id != 0, "Unable to find %s", name);