Parcourir la source

Improve use of vector muladd operations

gingerBill il y a 4 ans
Parent
commit
a440d8d812
5 fichiers modifiés avec 64 ajouts et 18 suppressions
  1. 12 1
      src/llvm_abi.cpp
  2. 3 1
      src/llvm_backend.hpp
  3. 4 6
      src/llvm_backend_expr.cpp
  4. 1 10
      src/llvm_backend_proc.cpp
  5. 44 0
      src/llvm_backend_utility.cpp

+ 12 - 1
src/llvm_abi.cpp

@@ -153,7 +153,18 @@ void lb_add_function_type_attributes(LLVMValueRef fn, lbFunctionType *ft, ProcCa
 	// TODO(bill): Clean up this logic
 	if (!is_arch_wasm()) {
 		cc_kind = lb_calling_convention_map[calling_convention];
-	}
+	} 
+	// if (build_context.metrics.arch == TargetArch_amd64) {
+	// 	if (build_context.metrics.os == TargetOs_windows) {
+	// 		if (cc_kind == lbCallingConvention_C) {
+	// 			cc_kind = lbCallingConvention_Win64;
+	// 		}
+	// 	} else {
+	// 		if (cc_kind == lbCallingConvention_C) {
+	// 			cc_kind = lbCallingConvention_X86_64_SysV;
+	// 		}
+	// 	}
+	// } 
 	LLVMSetFunctionCallConv(fn, cc_kind);
 	if (calling_convention == ProcCC_Odin) {
 		unsigned context_index = offset+arg_count;

+ 3 - 1
src/llvm_backend.hpp

@@ -472,7 +472,7 @@ LLVMTypeRef lb_type_padding_filler(lbModule *m, i64 padding, i64 padding_align);
 
 
 
-enum lbCallingConventionKind {
+enum lbCallingConventionKind : unsigned {
 	lbCallingConvention_C = 0,
 	lbCallingConvention_Fast = 8,
 	lbCallingConvention_Cold = 9,
@@ -517,6 +517,8 @@ enum lbCallingConventionKind {
 	lbCallingConvention_AMDGPU_LS = 95,
 	lbCallingConvention_AMDGPU_ES = 96,
 	lbCallingConvention_AArch64_VectorCall = 97,
+	lbCallingConvention_AArch64_SVE_VectorCall = 98,
+	lbCallingConvention_WASM_EmscriptenInvoke = 99,
 	lbCallingConvention_MaxID = 1023,
 };
 

+ 4 - 6
src/llvm_backend_expr.cpp

@@ -837,11 +837,10 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 		
 		LLVMValueRef vector = nullptr;
 		for (i64 i = 0; i < column_count; i++) {
-			LLVMValueRef product = llvm_vector_mul(p, m_columns[i], v_rows[i]);
 			if (i == 0) {
-				vector = product;
+				vector = llvm_vector_mul(p, m_columns[i], v_rows[i]);
 			} else {
-				vector = llvm_vector_add(p, vector, product);
+				vector = llvm_vector_mul_add(p, m_columns[i], v_rows[i], vector);
 			}
 		}
 		
@@ -914,11 +913,10 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 		
 		LLVMValueRef vector = nullptr;
 		for (i64 i = 0; i < row_count; i++) {
-			LLVMValueRef product = llvm_vector_mul(p, v_rows[i], m_columns[i]);
 			if (i == 0) {
-				vector = product;
+				vector = llvm_vector_mul(p, v_rows[i], m_columns[i]);
 			} else {
-				vector = llvm_vector_add(p, vector, product);
+				vector = llvm_vector_mul_add(p, v_rows[i], m_columns[i], vector);
 			}
 		}
 

+ 1 - 10
src/llvm_backend_proc.cpp

@@ -127,16 +127,7 @@ lbProcedure *lb_create_procedure(lbModule *m, Entity *entity, bool ignore_body)
 
 	lb_ensure_abi_function_type(m, p);
 	lb_add_function_type_attributes(p->value, p->abi_function_type, p->abi_function_type->calling_convention);
-	if (false) {
-		lbCallingConventionKind cc_kind = lbCallingConvention_C;
-		// TODO(bill): Clean up this logic
-		if (!is_arch_wasm()) {
-			cc_kind = lb_calling_convention_map[pt->Proc.calling_convention];
-		}
-		LLVMSetFunctionCallConv(p->value, cc_kind);
-	}
-
-
+	
 	if (pt->Proc.diverging) {
 		lb_add_attribute_to_proc(m, p->value, "noreturn");
 	}

+ 44 - 0
src/llvm_backend_utility.cpp

@@ -1631,4 +1631,48 @@ LLVMValueRef llvm_vector_mul(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) {
 
 LLVMValueRef llvm_vector_dot(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) { // Dot product: multiplies a and b with llvm_vector_mul, then sums the result with llvm_vector_reduce_add.
 	return llvm_vector_reduce_add(p, llvm_vector_mul(p, a, b));
+}
+
+LLVMValueRef llvm_vector_mul_add(lbProcedure *p, LLVMValueRef a, LLVMValueRef b, LLVMValueRef c) { // Builds a*b + c on vectors: a call to the "llvm.fmuladd" intrinsic for half/float/double elements, otherwise a separate multiply followed by an add.
+	lbModule *m = p->module;
+	// a, b and c must all have the same LLVM type, and that type must be a vector (asserted below).
+	LLVMTypeRef t = LLVMTypeOf(a);
+	GB_ASSERT(t == LLVMTypeOf(b));
+	GB_ASSERT(t == LLVMTypeOf(c));
+	GB_ASSERT(LLVMGetTypeKind(t) == LLVMVectorTypeKind);
+	
+	LLVMTypeRef elem = LLVMGetElementType(t);
+	
+	bool is_possible = false; // set to true only when the element kind is one of the cases in the switch below
+	
+	switch (LLVMGetTypeKind(elem)) {
+	case LLVMHalfTypeKind:
+		is_possible = true;
+		break;
+	case LLVMFloatTypeKind:
+	case LLVMDoubleTypeKind:
+		is_possible = true;
+		break;
+	} // any other element kind leaves is_possible false and takes the mul-then-add path
+
+	if (is_possible) {
+		char const *name = "llvm.fmuladd";
+		unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
+		GB_ASSERT_MSG(id != 0, "Unable to find %s", name); // an id of 0 is treated as "intrinsic not found"
+		
+		LLVMTypeRef types[1] = {};
+		types[0] = t; // the vector type is the single type argument passed when requesting the intrinsic declaration
+		
+		LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
+		LLVMValueRef values[3] = {};
+		values[0] = a;
+		values[1] = b;
+		values[2] = c;
+		LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
+		return call;
+	} else { // fallback: compute a*b with llvm_vector_mul, then add c with llvm_vector_add
+		LLVMValueRef x = llvm_vector_mul(p, a, b);
+		LLVMValueRef y = llvm_vector_add(p, x, c);
+		return y;
+	}
 }