Browse Source

Improve vector arithmetic generation for array programming operations

gingerBill 4 years ago
parent
commit
582f423b67
3 changed files with 218 additions and 73 deletions
  1. 1 0
      src/llvm_abi.cpp
  2. 215 72
      src/llvm_backend.cpp
  3. 2 1
      src/llvm_backend.hpp

+ 1 - 0
src/llvm_abi.cpp

@@ -272,6 +272,7 @@ i64 lb_alignof(LLVMTypeRef type) {
 		return 8;
 	case LLVMVectorTypeKind:
 		{
+			// TODO(bill): This appears to be correct but LLVM isn't necessarily "great" with regards to documentation
 			LLVMTypeRef elem = LLVMGetElementType(type);
 			i64 elem_size = lb_sizeof(elem);
 			i64 count = LLVMGetVectorSize(type);

+ 215 - 72
src/llvm_backend.cpp

@@ -291,16 +291,27 @@ void lb_emit_slice_bounds_check(lbProcedure *p, Token token, lbValue low, lbValu
 	}
 }
 
-bool lb_try_vector_cast(lbProcedure *p, lbValue ptr, LLVMTypeRef *vector_type_) {
+// Tries to guarantee that the value behind `ptr` has an alignment of at least `alignment`.
+//
+// Allocas and globals carry their own alignment, so it is raised in place when it is too small.
+// A load instruction is only queried and never modified (same as before).
+// Any other kind of value (GEPs, casts, parameters, ...) has no alignment that LLVMGetAlignment
+// can report, so those are conservatively answered with `false` instead of being queried.
+//
+// Returns true only when the value reports an alignment >= `alignment` afterwards.
+bool lb_try_update_alignment(lbValue ptr, unsigned alignment)  {
+	LLVMValueRef addr_ptr = ptr.value;
+
+	bool can_set   = LLVMIsAAllocaInst(addr_ptr) || LLVMIsAGlobalValue(addr_ptr);
+	bool can_query = can_set || LLVMIsALoadInst(addr_ptr);
+	if (!can_query) {
+		// NOTE: LLVMGetAlignment is only defined for globals, allocas, loads and stores
+		return false;
+	}
+
+	if (can_set && LLVMGetAlignment(addr_ptr) < alignment) {
+		LLVMSetAlignment(addr_ptr, alignment);
+	}
+	// NOTE(review): for a load this is the alignment recorded on the load instruction itself;
+	// kept as before - confirm that this is the intended meaning for the loaded pointer
+	return LLVMGetAlignment(addr_ptr) >= alignment;
+}
+
+bool lb_try_vector_cast(lbModule *m, lbValue ptr, LLVMTypeRef *vector_type_) {
 	Type *array_type = base_type(type_deref(ptr.type));
 	GB_ASSERT(array_type->kind == Type_Array);
 	Type *elem_type = base_type(array_type->Array.elem);
 
+	// TODO(bill): Determine what is the correct limit for doing vector arithmetic
 	if (type_size_of(array_type) <= build_context.max_align &&
 	    is_type_valid_vector_elem(elem_type)) {
 		// Try to treat it like a vector if possible
 		bool possible = false;
-		LLVMTypeRef vector_type = LLVMVectorType(lb_type(p->module, elem_type), cast(unsigned)array_type->Array.count);
+		LLVMTypeRef vector_type = LLVMVectorType(lb_type(m, elem_type), cast(unsigned)array_type->Array.count);
 		unsigned vector_alignment = cast(unsigned)lb_alignof(vector_type);
 
 		LLVMValueRef addr_ptr = ptr.value;
@@ -408,7 +419,7 @@ void lb_addr_store(lbProcedure *p, lbAddr addr, lbValue value) {
 		lb_insert_dynamic_map_key_and_value(p, addr, addr.map.type, addr.map.key, value, p->curr_stmt);
 		return;
 	} else if (addr.kind == lbAddr_Context) {
-		lbAddr old_addr = lb_find_or_generate_context_ptr(p);
+		lbAddr old_addr = lb_find_context_ptr(p);
 
 
 		// IMPORTANT NOTE(bill, 2021-04-22): reuse unused 'context' variables to minimize stack usage
@@ -801,14 +812,7 @@ lbValue lb_addr_load(lbProcedure *p, lbAddr const &addr) {
 
 		static u8 const ordered_indices[4] = {0, 1, 2, 3};
 		if (gb_memcompare(ordered_indices, addr.swizzle.indices, addr.swizzle.count) == 0) {
-			LLVMValueRef addr_ptr = addr.addr.value;
-			if (LLVMGetAlignment(addr.addr.value) < res_align) {
-				if (LLVMIsAAllocaInst(addr_ptr) || LLVMIsAGlobalValue(addr_ptr)) {
-					LLVMSetAlignment(addr_ptr, res_align);
-				}
-			}
-
-			if (LLVMGetAlignment(addr.addr.value) >= res_align) {
+			if (lb_try_update_alignment(addr.addr, res_align)) {
 				Type *pt = alloc_type_pointer(addr.swizzle.type);
 				lbValue res = {};
 				res.value = LLVMBuildPointerCast(p->builder, addr.addr.value, lb_type(p->module, pt), "");
@@ -823,7 +827,7 @@ lbValue lb_addr_load(lbProcedure *p, lbAddr const &addr) {
 		GB_ASSERT(is_type_pointer(ptr.type));
 
 		LLVMTypeRef vector_type = nullptr;
-		if (lb_try_vector_cast(p, addr.addr, &vector_type)) {
+		if (lb_try_vector_cast(p->module, addr.addr, &vector_type)) {
 			LLVMSetAlignment(res.addr.value, cast(unsigned)lb_alignof(vector_type));
 
 			LLVMValueRef vp = LLVMBuildPointerCast(p->builder, addr.addr.value, LLVMPointerType(vector_type, 0), "");
@@ -3626,7 +3630,7 @@ void lb_mem_zero_ptr(lbProcedure *p, LLVMValueRef ptr, Type *type, unsigned alig
 	}
 }
 
-lbAddr lb_add_local(lbProcedure *p, Type *type, Entity *e, bool zero_init, i32 param_index) {
+lbAddr lb_add_local(lbProcedure *p, Type *type, Entity *e, bool zero_init, i32 param_index, bool force_no_init) {
 	GB_ASSERT(p->decl_block != p->curr_block);
 	LLVMPositionBuilderAtEnd(p->builder, p->decl_block->block);
 
@@ -3645,7 +3649,7 @@ lbAddr lb_add_local(lbProcedure *p, Type *type, Entity *e, bool zero_init, i32 p
 	LLVMPositionBuilderAtEnd(p->builder, p->curr_block->block);
 
 
-	if (!zero_init) {
+	if (!zero_init && !force_no_init) {
 		// If there is any padding of any kind, just zero init regardless of zero_init parameter
 		LLVMTypeKind kind = LLVMGetTypeKind(llvm_type);
 		if (kind == LLVMStructTypeKind) {
@@ -3678,6 +3682,12 @@ lbAddr lb_add_local_generated(lbProcedure *p, Type *type, bool zero_init) {
 	return lb_add_local(p, type, nullptr, zero_init);
 }
 
+// Creates an entity-less local of `type` and asks for it to be aligned to at least
+// `min_alignment`.
+//
+// The local is requested from lb_add_local with zero_init=false and force_no_init=true.
+// The alignment request is best effort: the result of lb_try_update_alignment is not checked.
+lbAddr lb_add_local_generated_temp(lbProcedure *p, Type *type, i64 min_alignment) {
+	Entity *no_entity = nullptr;
+	bool zero_init = false;
+	i32 param_index = 0;
+	bool force_no_init = true;
+	unsigned wanted_alignment = cast(unsigned)min_alignment;
+
+	lbAddr temp = lb_add_local(p, type, no_entity, zero_init, param_index, force_no_init);
+	lb_try_update_alignment(temp.addr, wanted_alignment);
+	return temp;
+}
+
 
 void lb_build_nested_proc(lbProcedure *p, AstProcLit *pd, Entity *e) {
 	GB_ASSERT(pd->body != nullptr);
@@ -7023,13 +7033,40 @@ lbValue lb_emit_unary_arith(lbProcedure *p, TokenKind op, lbValue x, Type *type)
 		Type *elem_type = base_array_type(type);
 
 		// NOTE(bill): Doesn't need to be zero because it will be initialized in the loops
-		lbAddr res_addr = lb_add_local_generated(p, type, false);
+		lbAddr res_addr = lb_add_local(p, type, nullptr, false, 0, true);
 		lbValue res = lb_addr_get_ptr(p, res_addr);
 
 		bool inline_array_arith = type_size_of(type) <= build_context.max_align;
 
 		i32 count = cast(i32)tl->Array.count;
 
+		LLVMTypeRef vector_type = nullptr;
+		if (op != Token_Not && lb_try_vector_cast(p->module, val, &vector_type)) {
+			LLVMValueRef vp = LLVMBuildPointerCast(p->builder, val.value, LLVMPointerType(vector_type, 0), "");
+			LLVMValueRef v = LLVMBuildLoad2(p->builder, vector_type, vp, "");
+
+			LLVMValueRef opv = nullptr;
+			switch (op) {
+			case Token_Xor:
+				opv = LLVMBuildNot(p->builder, v, "");
+				break;
+			case Token_Sub:
+				if (is_type_float(elem_type)) {
+					opv = LLVMBuildFNeg(p->builder, v, "");
+				} else {
+					opv = LLVMBuildNeg(p->builder, v, "");
+				}
+				break;
+			}
+
+			if (opv != nullptr) {
+				LLVMSetAlignment(res.value, cast(unsigned)lb_alignof(vector_type));
+				LLVMValueRef res_ptr = LLVMBuildPointerCast(p->builder, res.value, LLVMPointerType(vector_type, 0), "");
+				LLVMBuildStore(p->builder, opv, res_ptr);
+				return lb_emit_conv(p, lb_emit_load(p, res), type);
+			}
+		}
+
 		if (inline_array_arith) {
 			// inline
 			for (i32 i = 0; i < count; i++) {
@@ -7132,6 +7169,140 @@ lbValue lb_emit_unary_arith(lbProcedure *p, TokenKind op, lbValue x, Type *type)
 	return res;
 }
 
+bool lb_try_direct_vector_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type, lbValue *res_) { // Tries to emit `lhs op rhs` for the array type `type` as a single LLVM vector operation; on success writes the result to *res_ (when non-null) and returns true, otherwise returns false
+	GB_ASSERT(is_type_array(type));
+	Type *elem_type = base_array_type(type);
+
+	// NOTE(bill): Shift operations cannot be easily dealt with due to Odin's semantics
+	if (op == Token_Shl || op == Token_Shr) {
+		return false;
+	}
+
+	if (!LLVMIsALoadInst(lhs.value) || !LLVMIsALoadInst(rhs.value)) { // only operands that are themselves load instructions are handled, so their source pointers can be recovered below
+		return false;
+	}
+
+	lbValue lhs_ptr = {};
+	lbValue rhs_ptr = {};
+	lhs_ptr.value = LLVMGetOperand(lhs.value, 0); // operand 0 of the load is used as the pointer that was loaded from
+	lhs_ptr.type = alloc_type_pointer(lhs.type);
+	rhs_ptr.value = LLVMGetOperand(rhs.value, 0);
+	rhs_ptr.type = alloc_type_pointer(rhs.type);
+
+	LLVMTypeRef vector_type0 = nullptr;
+	LLVMTypeRef vector_type1 = nullptr;
+	if (lb_try_vector_cast(p->module, lhs_ptr, &vector_type0) && // both source pointers must be accepted by lb_try_vector_cast
+	    lb_try_vector_cast(p->module, rhs_ptr, &vector_type1)) {
+		GB_ASSERT(vector_type0 == vector_type1); // a single vector type is used for both operands and the result
+		LLVMTypeRef vector_type = vector_type0;
+
+		LLVMValueRef lhs_vp = LLVMBuildPointerCast(p->builder, lhs_ptr.value, LLVMPointerType(vector_type, 0), "");
+		LLVMValueRef rhs_vp = LLVMBuildPointerCast(p->builder, rhs_ptr.value, LLVMPointerType(vector_type, 0), "");
+		LLVMValueRef x = LLVMBuildLoad2(p->builder, vector_type, lhs_vp, ""); // the operands are loaded again here, as vectors, at the current builder position
+		LLVMValueRef y = LLVMBuildLoad2(p->builder, vector_type, rhs_vp, "");
+		LLVMValueRef z = nullptr; // stays null until one of the switches below produces a result
+
+		Type *integral_type = base_type(elem_type); // element type that decides float/integer and signed/unsigned instruction selection
+		if (is_type_simd_vector(integral_type)) {
+			integral_type = core_array_type(integral_type);
+		}
+		if (is_type_bit_set(integral_type)) { // for bit_set elements `+` is emitted as OR and `-` as AND-NOT
+			switch (op) {
+			case Token_Add: op = Token_Or;     break;
+			case Token_Sub: op = Token_AndNot; break;
+			}
+		}
+
+		if (is_type_float(integral_type)) { // floating point elements use the F* instructions
+			switch (op) {
+			case Token_Add:
+				z = LLVMBuildFAdd(p->builder, x, y, "");
+				break;
+			case Token_Sub:
+				z = LLVMBuildFSub(p->builder, x, y, "");
+				break;
+			case Token_Mul:
+				z = LLVMBuildFMul(p->builder, x, y, "");
+				break;
+			case Token_Quo:
+				z = LLVMBuildFDiv(p->builder, x, y, "");
+				break;
+			case Token_Mod:
+				z = LLVMBuildFRem(p->builder, x, y, "");
+				break;
+			default:
+				GB_PANIC("Unsupported vector operation"); // any other operator on float elements is treated as a compiler error
+				break;
+			}
+
+		} else {
+
+			switch (op) { // every non-float element type goes through the integer instructions
+			case Token_Add:
+				z = LLVMBuildAdd(p->builder, x, y, "");
+				break;
+			case Token_Sub:
+				z = LLVMBuildSub(p->builder, x, y, "");
+				break;
+			case Token_Mul:
+				z = LLVMBuildMul(p->builder, x, y, "");
+				break;
+			case Token_Quo:
+				if (is_type_unsigned(integral_type)) { // division and remainder pick the unsigned or signed instruction from the element type
+					z = LLVMBuildUDiv(p->builder, x, y, "");
+				} else {
+					z = LLVMBuildSDiv(p->builder, x, y, "");
+				}
+				break;
+			case Token_Mod:
+				if (is_type_unsigned(integral_type)) {
+					z = LLVMBuildURem(p->builder, x, y, "");
+				} else {
+					z = LLVMBuildSRem(p->builder, x, y, "");
+				}
+				break;
+			case Token_ModMod:
+				if (is_type_unsigned(integral_type)) {
+					z = LLVMBuildURem(p->builder, x, y, "");
+				} else {
+					LLVMValueRef a = LLVMBuildSRem(p->builder, x, y, ""); // signed case computes ((x srem y) + y) srem y
+					LLVMValueRef b = LLVMBuildAdd(p->builder, a, y, "");
+					z = LLVMBuildSRem(p->builder, b, y, "");
+				}
+				break;
+			case Token_And:
+				z = LLVMBuildAnd(p->builder, x, y, "");
+				break;
+			case Token_AndNot:
+				z = LLVMBuildAnd(p->builder, x, LLVMBuildNot(p->builder, y, ""), ""); // x & ~y
+				break;
+			case Token_Or:
+				z = LLVMBuildOr(p->builder, x, y, "");
+				break;
+			case Token_Xor:
+				z = LLVMBuildXor(p->builder, x, y, "");
+				break;
+			default:
+				GB_PANIC("Unsupported vector operation");
+				break;
+			}
+		}
+
+
+		if (z != nullptr) {
+			lbAddr res = lb_add_local_generated_temp(p, type, lb_alignof(vector_type)); // temporary of the array type, requested with the vector type's alignment
+
+			LLVMValueRef vp = LLVMBuildPointerCast(p->builder, res.addr.value, LLVMPointerType(vector_type, 0), "");
+			LLVMBuildStore(p->builder, z, vp); // the vector result is stored through a vector-typed pointer to the temporary
+			lbValue v = lb_addr_load(p, res); // and read back through the lbAddr of the temporary
+			if (res_) *res_ = v; // res_ is optional
+			return true;
+		}
+	}
+
+	return false; // fast path not taken; the caller continues with its own code path
+}
+
 
 lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) {
 	GB_ASSERT(is_type_array(lhs.type) || is_type_array(rhs.type));
@@ -7143,18 +7314,22 @@ lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lhs, lbValue r
 	Type *elem_type = base_array_type(type);
 
 	i64 count = base_type(type)->Array.count;
+	unsigned n = cast(unsigned)count;
 
-	bool inline_array_arith = type_size_of(type) <= build_context.max_align;
+	// NOTE(bill, 2021-06-12): Try to do a direct operation as a vector, if possible
+	lbValue direct_vector_res = {};
+	if (lb_try_direct_vector_arith(p, op, lhs, rhs, type, &direct_vector_res)) {
+		return direct_vector_res;
+	}
 
+	bool inline_array_arith = type_size_of(type) <= build_context.max_align;
 	if (inline_array_arith) {
-	#if 1
-		#if 1
-		unsigned n = cast(unsigned)count;
-		auto dst_ptrs = slice_make<lbValue>(temporary_allocator(), count);
 
-		auto a_loads = slice_make<lbValue>(temporary_allocator(), count);
-		auto b_loads = slice_make<lbValue>(temporary_allocator(), count);
-		auto c_ops = slice_make<lbValue>(temporary_allocator(), count);
+		auto dst_ptrs = slice_make<lbValue>(temporary_allocator(), n);
+
+		auto a_loads = slice_make<lbValue>(temporary_allocator(), n);
+		auto b_loads = slice_make<lbValue>(temporary_allocator(), n);
+		auto c_ops = slice_make<lbValue>(temporary_allocator(), n);
 
 		for (unsigned i = 0; i < n; i++) {
 			a_loads[i].value = LLVMBuildExtractValue(p->builder, lhs.value, i, "");
@@ -7175,54 +7350,7 @@ lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lhs, lbValue r
 		for (unsigned i = 0; i < n; i++) {
 			lb_emit_store(p, dst_ptrs[i], c_ops[i]);
 		}
-		#else
-		lbValue x = lb_address_from_load_or_generate_local(p, lhs);
-		lbValue y = lb_address_from_load_or_generate_local(p, rhs);
-
-		auto a_ptrs = slice_make<lbValue>(temporary_allocator(), count);
-		auto b_ptrs = slice_make<lbValue>(temporary_allocator(), count);
-		auto dst_ptrs = slice_make<lbValue>(temporary_allocator(), count);
-
-		auto a_loads = slice_make<lbValue>(temporary_allocator(), count);
-		auto b_loads = slice_make<lbValue>(temporary_allocator(), count);
-		auto c_ops = slice_make<lbValue>(temporary_allocator(), count);
-
-		for (i64 i = 0; i < count; i++) {
-			a_ptrs[i] = lb_emit_array_epi(p, x, i);
-		}
-		for (i64 i = 0; i < count; i++) {
-			b_ptrs[i] = lb_emit_array_epi(p, y, i);
-		}
-		for (i64 i = 0; i < count; i++) {
-			a_loads[i] = lb_emit_load(p, a_ptrs[i]);
-		}
-		for (i64 i = 0; i < count; i++) {
-			b_loads[i] = lb_emit_load(p, b_ptrs[i]);
-		}
-		for (i64 i = 0; i < count; i++) {
-			c_ops[i] = lb_emit_arith(p, op, a_loads[i], b_loads[i], elem_type);
-		}
-
-		lbAddr res = lb_add_local_generated(p, type, false);
-		for (i64 i = 0; i < count; i++) {
-			dst_ptrs[i] = lb_emit_array_epi(p, res.addr, i);
-		}
-		for (i64 i = 0; i < count; i++) {
-			lb_emit_store(p, dst_ptrs[i], c_ops[i]);
-		}
-		#endif
-	#else
-		for (i64 i = 0; i < count; i++) {
-			lbValue a_ptr = lb_emit_array_epi(p, x, i);
-			lbValue b_ptr = lb_emit_array_epi(p, y, i);
-			lbValue dst_ptr = lb_emit_array_epi(p, res.addr, i);
 
-			lbValue a = lb_emit_load(p, a_ptr);
-			lbValue b = lb_emit_load(p, b_ptr);
-			lbValue c = lb_emit_arith(p, op, a, b, elem_type);
-			lb_emit_store(p, dst_ptr, c);
-		}
-	#endif
 
 		return lb_addr_load(p, res);
 	} else {
@@ -8390,6 +8518,15 @@ lbContextData *lb_push_context_onto_stack(lbProcedure *p, lbAddr ctx) {
 	return cd;
 }
 
+// Returns the `ctx` address of the last entry on the procedure's context stack.
+//
+// Unlike lb_find_or_generate_context_ptr this never creates a context: an empty
+// stack is reported through GB_PANIC.
+lbAddr lb_find_context_ptr(lbProcedure *p) {
+	auto depth = p->context_stack.count;
+	if (depth <= 0) {
+		GB_PANIC("Unable to get `context` ptr");
+		return {};
+	}
+
+	return p->context_stack[depth-1].ctx;
+}
+
 
 lbAddr lb_find_or_generate_context_ptr(lbProcedure *p) {
 	if (p->context_stack.count > 0) {
@@ -8957,7 +9094,7 @@ lbValue lb_emit_call(lbProcedure *p, lbValue value, Array<lbValue> const &args,
 
 	lbAddr context_ptr = {};
 	if (pt->Proc.calling_convention == ProcCC_Odin) {
-		context_ptr = lb_find_or_generate_context_ptr(p);
+		context_ptr = lb_find_context_ptr(p);
 	}
 
 	defer (if (pt->Proc.diverging) {
@@ -12739,7 +12876,7 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) {
 		lbAddr v = {};
 		switch (i->kind) {
 		case Token_context:
-			v = lb_find_or_generate_context_ptr(p);
+			v = lb_find_context_ptr(p);
 			break;
 		}
 
@@ -15054,6 +15191,12 @@ lbProcedure *lb_create_main_procedure(lbModule *m, lbProcedure *startup_runtime)
 	} else {
 		if (m->info->entry_point != nullptr) {
 			lbValue entry_point = lb_find_procedure_value_from_entity(m, m->info->entry_point);
+			Type *pt = base_type(entry_point.type);
+			GB_ASSERT(pt->kind == Type_Proc);
+			if (pt->kind == Type_Proc) {
+				lbAddr ctx = lb_find_or_generate_context_ptr(p);
+			}
+
 			lb_emit_call(p, entry_point, {});
 		}
 	}

+ 2 - 1
src/llvm_backend.hpp

@@ -342,13 +342,14 @@ void lb_start_block(lbProcedure *p, lbBlock *b);
 lbValue lb_build_call_expr(lbProcedure *p, Ast *expr);
 
 
+lbAddr lb_find_context_ptr(lbProcedure *p);
 lbAddr lb_find_or_generate_context_ptr(lbProcedure *p);
 lbContextData *lb_push_context_onto_stack(lbProcedure *p, lbAddr ctx);
 lbContextData *lb_push_context_onto_stack_from_implicit_parameter(lbProcedure *p);
 
 
 lbAddr lb_add_global_generated(lbModule *m, Type *type, lbValue value={});
-lbAddr lb_add_local(lbProcedure *p, Type *type, Entity *e=nullptr, bool zero_init=true, i32 param_index=0);
+lbAddr lb_add_local(lbProcedure *p, Type *type, Entity *e=nullptr, bool zero_init=true, i32 param_index=0, bool force_no_init=false);
 
 void lb_add_foreign_library_path(lbModule *m, Entity *e);