Browse Source

Basic support for matrix*vector, vector*matrix operations

gingerBill 4 năm trước cách đây
mục cha
commit
243e2e2b8a

+ 8 - 8
core/fmt/fmt.odin

@@ -1960,13 +1960,13 @@ fmt_value :: proc(fi: ^Info, v: any, verb: rune) {
 		
 		fi.indent += 1;  defer fi.indent -= 1
 		
-		if fi.hash	{ 
+		if fi.hash { 
 			io.write_byte(fi.writer, '\n')
 			// TODO(bill): Should this render it like in written form? e.g. tranposed
-			for col in 0..<info.column_count {
+			for row in 0..<info.row_count {
 				fmt_write_indent(fi)
-				for row in 0..<info.row_count {
-					if row > 0 { io.write_string(fi.writer, ", ") }
+				for col in 0..<info.column_count {
+					if col > 0 { io.write_string(fi.writer, ", ") }
 					
 					offset := row*info.elem_size + col*info.stride
 					
@@ -1976,10 +1976,10 @@ fmt_value :: proc(fi: ^Info, v: any, verb: rune) {
 				io.write_string(fi.writer, ";\n")
 			}
 		} else {
-			for col in 0..<info.column_count {
-				if col > 0 { io.write_string(fi.writer, "; ") }
-				for row in 0..<info.row_count {
-					if row > 0 { io.write_string(fi.writer, ", ") }
+			for row in 0..<info.row_count {
+				if row > 0 { io.write_string(fi.writer, ", ") }
+				for col in 0..<info.column_count {
+					if col > 0 { io.write_string(fi.writer, "; ") }
 					
 					offset := row*info.elem_size + col*info.stride
 					

+ 19 - 11
src/check_expr.cpp

@@ -2686,10 +2686,11 @@ void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand
 		x->mode = Addressing_Invalid;
 		return;
 	}
+		
+	Type *xt = base_type(x->type);
+	Type *yt = base_type(y->type);
 	
 	if (is_type_matrix(x->type)) {
-		Type *xt = base_type(x->type);
-		Type *yt = base_type(y->type);
 		GB_ASSERT(xt->kind == Type_Matrix);
 		if (op.kind == Token_Mul) {
 			if (yt->kind == Type_Matrix) {
@@ -2714,7 +2715,11 @@ void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand
 				
 				// Treat arrays as column vectors
 				x->mode = Addressing_Value;
-				x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, 1);
+				if (type_hint == nullptr && xt->Matrix.row_count == yt->Array.count) {
+					x->type = y->type;
+				} else {
+					x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, 1);
+				}
 				goto matrix_success;
 			}
 		}
@@ -2725,8 +2730,6 @@ void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand
 		x->type = xt;
 		goto matrix_success;
 	} else {
-		Type *xt = base_type(x->type);
-		Type *yt = base_type(y->type);
 		GB_ASSERT(is_type_matrix(yt));
 		GB_ASSERT(!is_type_matrix(xt));
 		
@@ -2743,7 +2746,11 @@ void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand
 				
 				// Treat arrays as row vectors
 				x->mode = Addressing_Value;
-				x->type = alloc_type_matrix(xt->Matrix.elem, 1, xt->Matrix.column_count);
+				if (type_hint == nullptr && yt->Matrix.column_count == xt->Array.count) {
+					x->type = x->type;
+				} else {
+					x->type = alloc_type_matrix(yt->Matrix.elem, 1, yt->Matrix.column_count);
+				}
 				goto matrix_success;
 			}
 		}
@@ -2775,13 +2782,13 @@ matrix_success:
 	
 	
 matrix_error:
-	gbString xt = type_to_string(x->type);
-	gbString yt = type_to_string(y->type);
+	gbString xts = type_to_string(x->type);
+	gbString yts = type_to_string(y->type);
 	gbString expr_str = expr_to_string(x->expr);
-	error(op, "Mismatched types in binary matrix expression '%s' for operator '%.*s' : '%s' vs '%s'", expr_str, LIT(op.string), xt, yt);
+	error(op, "Mismatched types in binary matrix expression '%s' for operator '%.*s' : '%s' vs '%s'", expr_str, LIT(op.string), xts, yts);
 	gb_string_free(expr_str);
-	gb_string_free(yt);
-	gb_string_free(xt);
+	gb_string_free(yts);
+	gb_string_free(xts);
 	x->type = t_invalid;
 	x->mode = Addressing_Invalid;
 	return;
@@ -2994,6 +3001,7 @@ void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint
 	}
 	if (is_type_matrix(x->type) || is_type_matrix(y->type)) {
 		check_binary_matrix(c, op, x, y, type_hint, use_lhs_as_type_hint);
+		x->expr = node;
 		return;
 	}
 

+ 49 - 8
src/llvm_backend.cpp

@@ -1135,13 +1135,46 @@ void lb_generate_code(lbGenerator *gen) {
 
 	auto *min_dep_set = &info->minimum_dependency_set;
 
-	LLVMInitializeAllTargetInfos();
-	LLVMInitializeAllTargets();
-	LLVMInitializeAllTargetMCs();
-	LLVMInitializeAllAsmPrinters();
-	LLVMInitializeAllAsmParsers();
-	LLVMInitializeAllDisassemblers();
-	LLVMInitializeNativeTarget();
+	switch (build_context.metrics.arch) {
+	case TargetArch_amd64: 
+	case TargetArch_386:
+		LLVMInitializeX86TargetInfo();
+		LLVMInitializeX86Target();
+		LLVMInitializeX86TargetMC();
+		LLVMInitializeX86AsmPrinter();
+		LLVMInitializeX86AsmParser();
+		LLVMInitializeX86Disassembler();
+		break;
+	case TargetArch_arm64:
+		LLVMInitializeAArch64TargetInfo();
+		LLVMInitializeAArch64Target();
+		LLVMInitializeAArch64TargetMC();
+		LLVMInitializeAArch64AsmPrinter();
+		LLVMInitializeAArch64AsmParser();
+		LLVMInitializeAArch64Disassembler();
+		break;
+	case TargetArch_wasm32:
+		LLVMInitializeWebAssemblyTargetInfo();
+		LLVMInitializeWebAssemblyTarget();
+		LLVMInitializeWebAssemblyTargetMC();
+		LLVMInitializeWebAssemblyAsmPrinter();
+		LLVMInitializeWebAssemblyAsmParser();
+		LLVMInitializeWebAssemblyDisassembler();
+		break;
+	default:
+		LLVMInitializeAllTargetInfos();
+		LLVMInitializeAllTargets();
+		LLVMInitializeAllTargetMCs();
+		LLVMInitializeAllAsmPrinters();
+		LLVMInitializeAllAsmParsers();
+		LLVMInitializeAllDisassemblers();
+		break;
+	}
+
+	
+	if (build_context.microarch == "native") {
+		LLVMInitializeNativeTarget();
+	}
 
 	char const *target_triple = alloc_cstring(permanent_allocator(), build_context.metrics.target_triplet);
 	for_array(i, gen->modules.entries) {
@@ -1174,6 +1207,14 @@ void lb_generate_code(lbGenerator *gen) {
 		if (gb_strcmp(llvm_cpu, host_cpu_name) == 0) {
 			llvm_features = LLVMGetHostCPUFeatures();
 		}
+	} else if (build_context.metrics.arch == TargetArch_amd64) {
+		// NOTE(bill): x86-64-v2 is more than enough for everyone
+		//
+		// x86-64: CMOV, CMPXCHG8B, FPU, FXSR, MMX, FXSR, SCE, SSE, SSE2
+		// x86-64-v2: (close to Nehalem) CMPXCHG16B, LAHF-SAHF, POPCNT, SSE3, SSE4.1, SSE4.2, SSSE3
+		// x86-64-v3: (close to Haswell) AVX, AVX2, BMI1, BMI2, F16C, FMA, LZCNT, MOVBE, XSAVE
+		// x86-64-v4: AVX512F, AVX512BW, AVX512CD, AVX512DQ, AVX512VL
+		llvm_cpu = "x86-64-v2";
 	}
 
 	// GB_ASSERT_MSG(LLVMTargetHasAsmBackend(target));
@@ -1640,6 +1681,7 @@ void lb_generate_code(lbGenerator *gen) {
 		code_gen_file_type = LLVMAssemblyFile;
 	}
 
+
 	for_array(j, gen->modules.entries) {
 		lbModule *m = gen->modules.entries[j].value;
 		if (LLVMVerifyModule(m->mod, LLVMReturnStatusAction, &llvm_error)) {
@@ -1684,7 +1726,6 @@ void lb_generate_code(lbGenerator *gen) {
 		}
 	}
 
-
 	TIME_SECTION("LLVM Add Foreign Library Paths");
 
 	for_array(j, gen->modules.entries) {

+ 117 - 96
src/llvm_backend_expr.cpp

@@ -509,12 +509,16 @@ LLVMValueRef llvm_matrix_column_major_load(lbProcedure *p, lbValue lhs) {
 	GB_ASSERT(mt->kind == Type_Matrix);
 	GB_ASSERT(lb_matrix_elem_simple(mt));
 	
-	unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt);
+	
+	i64 stride = matrix_type_stride_in_elems(mt);
+	i64 rows = mt->Matrix.row_count;
+	i64 columns = mt->Matrix.column_count;
+	unsigned elem_count = cast(unsigned)(rows*columns);
 	
 	Type *elem = mt->Matrix.elem;
 	LLVMTypeRef elem_type = lb_type(m, elem);
 	
-	LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count);
+	LLVMTypeRef vector_type = LLVMVectorType(elem_type, elem_count);
 	LLVMTypeRef types[] = {vector_type};
 	
 	char const *name = "llvm.matrix.column.major.load";
@@ -524,44 +528,18 @@ LLVMValueRef llvm_matrix_column_major_load(lbProcedure *p, lbValue lhs) {
 	
 	lbValue ptr = lb_address_from_load_or_generate_local(p, lhs);
 	ptr = lb_emit_matrix_epi(p, ptr, 0, 0);
-	
+		
 	LLVMValueRef values[5] = {};
 	values[0] = ptr.value;
-	values[1] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width
-	values[2] = LLVMConstNull(lb_type(m, t_llvm_bool));
-	values[3] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
-	values[4] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
-	
-	return LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
-}
-LLVMValueRef llvm_matrix_column_major_load_from_ptr(lbProcedure *p, lbValue ptr) {
-	lbModule *m = p->module;
-
-	Type *mt = base_type(type_deref(ptr.type));
-	GB_ASSERT(mt->kind == Type_Matrix);
-	GB_ASSERT(lb_matrix_elem_simple(mt));
-	
-	unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt);
-	
-	Type *elem = mt->Matrix.elem;
-	LLVMTypeRef elem_type = lb_type(m, elem);
-	
-	LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count);
-	LLVMTypeRef types[] = {vector_type};
-	
-	char const *name = "llvm.matrix.column.major.load";
-	unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
-	GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
-	LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
-	
-	LLVMValueRef values[5] = {};
-	values[0] = lb_emit_matrix_epi(p, ptr, 0, 0).value;
-	values[1] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width
+	values[1] = lb_const_int(m, t_u64, stride).value; 
 	values[2] = LLVMConstNull(lb_type(m, t_llvm_bool));
 	values[3] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
 	values[4] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
 	
-	return LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
+	LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
+	gb_printf_err("%s\n", LLVMPrintValueToString(call));
+	// LLVMAddAttributeAtIndex(call, 0, lb_create_enum_attribute(p->module->ctx, "align", cast(u64)type_align_of(mt)));
+	return call;
 }
 
 void llvm_matrix_column_major_store(lbProcedure *p, lbAddr addr, LLVMValueRef vector_value) {
@@ -571,12 +549,7 @@ void llvm_matrix_column_major_store(lbProcedure *p, lbAddr addr, LLVMValueRef ve
 	GB_ASSERT(mt->kind == Type_Matrix);
 	GB_ASSERT(lb_matrix_elem_simple(mt));
 	
-	unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt);
-	
-	Type *elem = mt->Matrix.elem;
-	LLVMTypeRef elem_type = lb_type(m, elem);
-	
-	LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count);
+	LLVMTypeRef vector_type = LLVMTypeOf(vector_value);
 	LLVMTypeRef types[] = {vector_type};
 	
 	char const *name = "llvm.matrix.column.major.store";
@@ -587,56 +560,26 @@ void llvm_matrix_column_major_store(lbProcedure *p, lbAddr addr, LLVMValueRef ve
 	lbValue ptr = lb_addr_get_ptr(p, addr);
 	ptr = lb_emit_matrix_epi(p, ptr, 0, 0);
 	
-	GB_ASSERT(LLVMTypeOf(vector_value) == vector_type);
 	unsigned vector_size = LLVMGetVectorSize(vector_type);
 	GB_ASSERT((mt->Matrix.row_count*mt->Matrix.column_count) == cast(i64)vector_size);
 	
-	LLVMValueRef values[6] = {};
-	values[0] = vector_value;
-	values[1] = ptr.value;
-	values[2] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width
-	values[3] = LLVMConstNull(lb_type(m, t_llvm_bool));
-	values[4] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
-	values[5] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
-	
-	LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
-}
-
-void llvm_matrix_column_major_store_to_raw_ptr(lbProcedure *p, Type *mt, lbValue ptr, LLVMValueRef vector_value) {
-	lbModule *m = p->module;
-	
-	mt = base_type(mt);
-	GB_ASSERT(mt->kind == Type_Matrix);
-	GB_ASSERT(lb_matrix_elem_simple(mt));
-	
-	unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt);
-	
-	Type *elem = mt->Matrix.elem;
-	LLVMTypeRef elem_type = lb_type(m, elem);
-	
-	LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count);
-	LLVMTypeRef types[] = {vector_type};
-	
-	char const *name = "llvm.matrix.column.major.store";
-	unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
-	GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
-	LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
-	
-	GB_ASSERT(LLVMTypeOf(vector_value) == vector_type);
-	unsigned vector_size = LLVMGetVectorSize(vector_type);
-	GB_ASSERT((mt->Matrix.row_count*mt->Matrix.column_count) == cast(i64)vector_size);
+	i64 stride = matrix_type_stride_in_elems(mt);
 	
 	LLVMValueRef values[6] = {};
 	values[0] = vector_value;
 	values[1] = ptr.value;
-	values[2] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width
+	values[2] = lb_const_int(m, t_u64, stride).value;
 	values[3] = LLVMConstNull(lb_type(m, t_llvm_bool));
 	values[4] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
 	values[5] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
 	
-	LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
+	LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
+	gb_printf_err("%s\n", LLVMPrintValueToString(call));
+	// LLVMAddAttributeAtIndex(call, 1, lb_create_enum_attribute(p->module->ctx, "align", cast(u64)type_align_of(mt)));
+	gb_unused(call);
 }
 
+
 LLVMValueRef llvm_matrix_multiply(lbProcedure *p, LLVMValueRef a, LLVMValueRef b, i64 outer_rows, i64 inner, i64 outer_columns) {
 	lbModule *m = p->module;
 	
@@ -648,6 +591,7 @@ LLVMValueRef llvm_matrix_multiply(lbProcedure *p, LLVMValueRef a, LLVMValueRef b
 	LLVMTypeRef elem_type = LLVMGetElementType(a_type);
 	
 	LLVMTypeRef res_vector_type = LLVMVectorType(elem_type, cast(unsigned)(outer_rows*outer_columns));
+	
 	LLVMTypeRef types[] = {res_vector_type, a_type, b_type};
 	
 	char const *name = "llvm.matrix.multiply";
@@ -662,7 +606,9 @@ LLVMValueRef llvm_matrix_multiply(lbProcedure *p, LLVMValueRef a, LLVMValueRef b
 	values[3] = lb_const_int(m, t_u32, inner).value;
 	values[4] = lb_const_int(m, t_u32, outer_columns).value;
 	
-	return LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
+	LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
+	gb_printf_err("%s\n", LLVMPrintValueToString(call));
+	return call;
 }
 
 
@@ -684,19 +630,13 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type)
 		// TODO(bill): LLVM ERROR: Do not know how to split the result of this operator!
 		lbAddr res = lb_add_local_generated(p, type, true);
 		
-		lbValue res_ptr = lb_addr_get_ptr(p, res);
-		res_ptr = lb_emit_matrix_epi(p, res_ptr, 0, 0);
-		
-		lbValue lhs_ptr = lb_address_from_load_or_generate_local(p, lhs);
-		lbValue rhs_ptr = lb_address_from_load_or_generate_local(p, rhs);
-		LLVMValueRef a = llvm_matrix_column_major_load_from_ptr(p, lhs_ptr);
-		LLVMValueRef b = llvm_matrix_column_major_load_from_ptr(p, rhs_ptr);
-		LLVMValueRef c = llvm_matrix_multiply(p, a, b, xt->Matrix.row_count, xt->Matrix.column_count, yt->Matrix.column_count);
-		
-		llvm_matrix_column_major_store_to_raw_ptr(p, type, res_ptr, c);
+		LLVMValueRef a = llvm_matrix_column_major_load(p, lhs); gb_unused(a);
+		LLVMValueRef b = llvm_matrix_column_major_load(p, rhs); gb_unused(b);
+		LLVMValueRef c = llvm_matrix_multiply(p, a, b, xt->Matrix.row_count, xt->Matrix.column_count, yt->Matrix.column_count); gb_unused(c);
+		llvm_matrix_column_major_store(p, res, c);
 		
 		return lb_addr_load(p, res);
-	}
+	} 
 		
 slow_form:
 	{
@@ -704,18 +644,21 @@ slow_form:
 		
 		lbAddr res = lb_add_local_generated(p, type, true);
 		
-		for (i64 i = 0; i < xt->Matrix.row_count; i++) {
-			for (i64 j = 0; j < yt->Matrix.column_count; j++) {
-				for (i64 k = 0; k < xt->Matrix.column_count; k++) {
+		i64 outer_rows    = xt->Matrix.row_count;
+		i64 inner         = xt->Matrix.column_count;
+		i64 outer_columns = yt->Matrix.column_count;
+		
+		for (i64 j = 0; j < outer_columns; j++) {
+			for (i64 i = 0; i < outer_rows; i++) {
+				for (i64 k = 0; k < inner; k++) {
 					lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
+					lbValue d0 = lb_emit_load(p, dst);
 					
 					lbValue a = lb_emit_matrix_ev(p, lhs, i, k);
 					lbValue b = lb_emit_matrix_ev(p, rhs, k, j);
 					lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem);
-					lbValue d = lb_emit_load(p, dst);
-					lbValue e = lb_emit_arith(p, Token_Add, d, c, elem);
-					lb_emit_store(p, dst, e);
-					
+					lbValue d = lb_emit_arith(p, Token_Add, d0, c, elem);
+					lb_emit_store(p, dst, d);
 				}
 			}
 		}
@@ -724,6 +667,72 @@ slow_form:
 	}
 }
 
+lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
+	Type *mt = base_type(lhs.type);
+	Type *vt = base_type(rhs.type);
+	
+	GB_ASSERT(is_type_matrix(mt));
+	GB_ASSERT(is_type_array_like(vt));
+	
+	i64 vector_count = get_array_type_count(vt);
+	
+	GB_ASSERT(mt->Matrix.column_count == vector_count);
+	GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
+	
+	Type *elem = mt->Matrix.elem;
+	
+	lbAddr res = lb_add_local_generated(p, type, true);
+	
+	for (i64 i = 0; i < mt->Matrix.row_count; i++) {
+		for (i64 j = 0; j < mt->Matrix.column_count; j++) {
+			lbValue dst = lb_emit_matrix_epi(p, res.addr, i, 0);
+			lbValue d0 = lb_emit_load(p, dst);
+			
+			lbValue a = lb_emit_matrix_ev(p, lhs, i, j);
+			lbValue b = lb_emit_struct_ev(p, rhs, cast(i32)j);
+			lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem);
+			lbValue d = lb_emit_arith(p, Token_Add, d0, c, elem);
+			lb_emit_store(p, dst, d);
+		}
+	}
+	
+	return lb_addr_load(p, res);
+}
+
+lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
+	Type *mt = base_type(rhs.type);
+	Type *vt = base_type(lhs.type);
+	
+	GB_ASSERT(is_type_matrix(mt));
+	GB_ASSERT(is_type_array_like(vt));
+	
+	i64 vector_count = get_array_type_count(vt);
+	
+	GB_ASSERT(mt->Matrix.row_count == vector_count);
+	GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
+	
+	Type *elem = mt->Matrix.elem;
+	
+	lbAddr res = lb_add_local_generated(p, type, true);
+		
+	for (i64 j = 0; j < mt->Matrix.column_count; j++) {
+		for (i64 k = 0; k < mt->Matrix.row_count; k++) {
+			lbValue dst = lb_emit_matrix_epi(p, res.addr, 0, j);
+			lbValue d0 = lb_emit_load(p, dst);
+			
+			lbValue a = lb_emit_struct_ev(p, lhs, cast(i32)k);
+			lbValue b = lb_emit_matrix_ev(p, rhs, k, j);
+			lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem);
+			lbValue d = lb_emit_arith(p, Token_Add, d0, c, elem);
+			lb_emit_store(p, dst, d);
+		}
+	}
+	
+	return lb_addr_load(p, res);
+}
+
+
+
 
 lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) {
 	GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type));
@@ -735,7 +744,12 @@ lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue
 		if (xt->kind == Type_Matrix) {
 			if (yt->kind == Type_Matrix) {
 				return lb_emit_matrix_mul(p, lhs, rhs, type);
+			} else if (is_type_array_like(yt)) {
+				return lb_emit_matrix_mul_vector(p, lhs, rhs, type);
 			}
+		} else if (is_type_array_like(xt)) {
+			GB_ASSERT(yt->kind == Type_Matrix);
+			return lb_emit_vector_mul_matrix(p, lhs, rhs, type);
 		}
 		
 	} else {
@@ -1036,6 +1050,13 @@ lbValue lb_build_binary_expr(lbProcedure *p, Ast *expr) {
 	ast_node(be, BinaryExpr, expr);
 
 	TypeAndValue tv = type_and_value_of_expr(expr);
+	
+	if (is_type_matrix(be->left->tav.type) || is_type_matrix(be->right->tav.type)) {
+		lbValue left = lb_build_expr(p, be->left);
+		lbValue right = lb_build_expr(p, be->right);
+		return lb_emit_arith_matrix(p, be->op.kind, left, right, default_type(tv.type));
+	}
+	
 
 	switch (be->op.kind) {
 	case Token_Add:

+ 5 - 3
src/llvm_backend_general.cpp

@@ -1937,7 +1937,7 @@ LLVMTypeRef lb_type_internal(lbModule *m, Type *type) {
 			i64 elem_size = type_size_of(type->Matrix.elem);
 			GB_ASSERT(elem_size > 0);
 			i64 elem_count = size/elem_size;
-			GB_ASSERT(elem_count > 0);
+			GB_ASSERT_MSG(elem_count > 0, "%s", type_to_string(type));
 			
 			m->internal_type_level -= 1;
 			
@@ -2611,8 +2611,10 @@ lbAddr lb_add_local(lbProcedure *p, Type *type, Entity *e, bool zero_init, i32 p
 	LLVMTypeRef llvm_type = lb_type(p->module, type);
 	LLVMValueRef ptr = LLVMBuildAlloca(p->builder, llvm_type, name);
 
-	// unsigned alignment = 16; // TODO(bill): Make this configurable
-	unsigned alignment = cast(unsigned)lb_alignof(llvm_type);
+	unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(llvm_type));
+	if (is_type_matrix(type)) {
+		alignment *= 2; // NOTE(bill): Just in case
+	}
 	LLVMSetAlignment(ptr, alignment);
 
 	LLVMPositionBuilderAtEnd(p->builder, p->curr_block->block);

+ 6 - 4
src/llvm_backend_utility.cpp

@@ -1224,12 +1224,14 @@ lbValue lb_emit_ptr_offset(lbProcedure *p, lbValue ptr, lbValue index) {
 lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column) {
 	Type *t = s.type;
 	GB_ASSERT(is_type_pointer(t));
-	Type *st = base_type(type_deref(t));
-	GB_ASSERT_MSG(is_type_matrix(st), "%s", type_to_string(st));
+	Type *mt = base_type(type_deref(t));
+	GB_ASSERT_MSG(is_type_matrix(mt), "%s", type_to_string(mt));
 
-	Type *ptr = base_array_type(st);
+	Type *ptr = base_array_type(mt);
+	
+	i64 stride_elems = matrix_type_stride_in_elems(mt);
 	
-	isize index = row*column;
+	isize index = row + column*stride_elems;
 	GB_ASSERT(0 <= index);
 
 	LLVMValueRef indices[2] = {

+ 3 - 0
src/types.cpp

@@ -1249,6 +1249,7 @@ bool is_type_matrix(Type *t) {
 }
 
 i64 matrix_type_stride(Type *t) {
+	// TODO(bill): precompute matrix stride
 	t = base_type(t);
 	GB_ASSERT(t->kind == Type_Matrix);
 	i64 align = type_align_of(t);
@@ -1258,6 +1259,7 @@ i64 matrix_type_stride(Type *t) {
 }
 
 i64 matrix_type_stride_in_elems(Type *t) {
+	// TODO(bill): precompute matrix stride
 	t = base_type(t);
 	GB_ASSERT(t->kind == Type_Matrix);
 	i64 stride = matrix_type_stride(t);
@@ -1266,6 +1268,7 @@ i64 matrix_type_stride_in_elems(Type *t) {
 
 
 i64 matrix_type_total_elems(Type *t) {
+	// TODO(bill): precompute matrix total elems
 	t = base_type(t);
 	GB_ASSERT(t->kind == Type_Matrix);
 	i64 size = type_size_of(t);