Browse Source

Support indexing matrices

gingerBill 3 years ago
parent
commit
662cbaf425

+ 23 - 0
core/runtime/error_checks.odin

@@ -96,6 +96,29 @@ dynamic_array_expr_error :: proc "contextless" (file: string, line, column: i32,
 }
 }
 
 
 
 
+matrix_bounds_check_error :: proc "contextless" (file: string, line, column: i32, row_index, column_index, row_count, column_count: int) {
+	if 0 <= row_index && row_index < row_count && 
+	   0 <= column_index && column_index < column_count {
+		return
+	}
+	handle_error :: proc "contextless" (file: string, line, column: i32, row_index, column_index, row_count, column_count: int) {
+		print_caller_location(Source_Code_Location{file, line, column, ""})
+		print_string(" Matrix indices [")
+		print_i64(i64(row_index))
+		print_string(", ")
+		print_i64(i64(column_index))
+		print_string(" is out of bounds range [0..<")
+		print_i64(i64(row_count))
+		print_string(", 0..<")
+		print_i64(i64(column_count))
+		print_string("]")
+		print_byte('\n')
+		bounds_trap()
+	}
+	handle_error(file, line, column, row_index, column_index, row_count, column_count)
+}
+
+
 type_assertion_check :: proc "contextless" (ok: bool, file: string, line, column: i32, from, to: typeid) {
 type_assertion_check :: proc "contextless" (ok: bool, file: string, line, column: i32, from, to: typeid) {
 	if ok {
 	if ok {
 		return
 		return

+ 63 - 3
src/check_expr.cpp

@@ -6367,8 +6367,7 @@ bool check_set_index_data(Operand *o, Type *t, bool indirection, i64 *max_count,
 		*max_count = t->Matrix.column_count;
 		*max_count = t->Matrix.column_count;
 		if (indirection) {
 		if (indirection) {
 			o->mode = Addressing_Variable;
 			o->mode = Addressing_Variable;
-		} else if (o->mode != Addressing_Variable &&
-		           o->mode != Addressing_Constant) {
+		} else if (o->mode != Addressing_Variable) {
 			o->mode = Addressing_Value;
 			o->mode = Addressing_Value;
 		}
 		}
 		o->type = alloc_type_array(t->Matrix.elem, t->Matrix.row_count);
 		o->type = alloc_type_array(t->Matrix.elem, t->Matrix.row_count);
@@ -6672,7 +6671,68 @@ void check_promote_optional_ok(CheckerContext *c, Operand *x, Type **val_type_,
 
 
 
 
 void check_matrix_index_expr(CheckerContext *c, Operand *o, Ast *node, Type *type_hint) {
 void check_matrix_index_expr(CheckerContext *c, Operand *o, Ast *node, Type *type_hint) {
-	error(node, "TODO: matrix index expressions");
+	ast_node(ie, MatrixIndexExpr, node);
+	
+	check_expr(c, o, ie->expr);
+	node->viral_state_flags |= ie->expr->viral_state_flags;
+	if (o->mode == Addressing_Invalid) {
+		o->expr = node;
+		return;
+	}
+	
+	Type *t = base_type(type_deref(o->type));
+	bool is_ptr = is_type_pointer(o->type);
+	bool is_const = o->mode == Addressing_Constant;
+	
+	if (t->kind != Type_Matrix) {
+		gbString str = expr_to_string(o->expr);
+		gbString type_str = type_to_string(o->type);
+		defer (gb_string_free(str));
+		defer (gb_string_free(type_str));
+		if (is_const) {
+			error(o->expr, "Cannot use matrix indexing on constant '%s' of type '%s'", str, type_str);
+		} else {
+			error(o->expr, "Cannot use matrix indexing on '%s' of type '%s'", str, type_str);
+		}
+		o->mode = Addressing_Invalid;
+		o->expr = node;
+		return;
+	}
+	o->type = t->Matrix.elem;
+	if (is_ptr) {
+		o->mode = Addressing_Variable;
+	} else if (o->mode != Addressing_Variable) {
+		o->mode = Addressing_Value;
+	}
+	
+	if (ie->row_index == nullptr) {
+		gbString str = expr_to_string(o->expr);
+		error(o->expr, "Missing row index for '%s'", str);
+		gb_string_free(str);
+		o->mode = Addressing_Invalid;
+		o->expr = node;
+		return;
+	}
+	if (ie->column_index == nullptr) {
+		gbString str = expr_to_string(o->expr);
+		error(o->expr, "Missing column index for '%s'", str);
+		gb_string_free(str);
+		o->mode = Addressing_Invalid;
+		o->expr = node;
+		return;
+	}
+	
+	i64 row_count = t->Matrix.row_count;
+	i64 column_count = t->Matrix.column_count;
+	
+	i64 row_index = 0;
+	i64 column_index = 0;
+	bool row_ok = check_index_value(c, t, false, ie->row_index, row_count, &row_index, nullptr);
+	bool column_ok = check_index_value(c, t, false, ie->column_index, column_count, &column_index, nullptr);
+	
+	
+	gb_unused(row_ok);
+	gb_unused(column_ok);
 }
 }
 
 
 
 

+ 1 - 0
src/checker.cpp

@@ -2022,6 +2022,7 @@ void generate_minimum_dependency_set(Checker *c, Entity *start) {
 		String bounds_check_entities[] = {
 		String bounds_check_entities[] = {
 			// Bounds checking related procedures
 			// Bounds checking related procedures
 			str_lit("bounds_check_error"),
 			str_lit("bounds_check_error"),
+			str_lit("matrix_bounds_check_error"),
 			str_lit("slice_expr_error_hi"),
 			str_lit("slice_expr_error_hi"),
 			str_lit("slice_expr_error_lo_hi"),
 			str_lit("slice_expr_error_lo_hi"),
 			str_lit("multi_pointer_slice_expr_error"),
 			str_lit("multi_pointer_slice_expr_error"),

+ 1 - 0
src/llvm_backend.hpp

@@ -333,6 +333,7 @@ lbValue lb_emit_array_ep(lbProcedure *p, lbValue s, lbValue index);
 lbValue lb_emit_deep_field_gep(lbProcedure *p, lbValue e, Selection sel);
 lbValue lb_emit_deep_field_gep(lbProcedure *p, lbValue e, Selection sel);
 lbValue lb_emit_deep_field_ev(lbProcedure *p, lbValue e, Selection sel);
 lbValue lb_emit_deep_field_ev(lbProcedure *p, lbValue e, Selection sel);
 
 
+lbValue lb_emit_matrix_ep(lbProcedure *p, lbValue s, lbValue row, lbValue column);
 lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column);
 lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column);
 lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column);
 lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column);
 
 

+ 53 - 1
src/llvm_backend_expr.cpp

@@ -1727,7 +1727,7 @@ lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) {
 	}
 	}
 	
 	
 	if (is_type_matrix(dst) && !is_type_matrix(src)) {
 	if (is_type_matrix(dst) && !is_type_matrix(src)) {
-		GB_ASSERT(dst->Matrix.row_count == dst->Matrix.column_count);
+		GB_ASSERT_MSG(dst->Matrix.row_count == dst->Matrix.column_count, "%s <- %s", type_to_string(dst), type_to_string(src));
 		
 		
 		Type *elem = base_array_type(dst);
 		Type *elem = base_array_type(dst);
 		lbValue e = lb_emit_conv(p, value, elem);
 		lbValue e = lb_emit_conv(p, value, elem);
@@ -2805,6 +2805,10 @@ lbValue lb_build_expr(lbProcedure *p, Ast *expr) {
 	case_ast_node(ie, IndexExpr, expr);
 	case_ast_node(ie, IndexExpr, expr);
 		return lb_addr_load(p, lb_build_addr(p, expr));
 		return lb_addr_load(p, lb_build_addr(p, expr));
 	case_end;
 	case_end;
+	
+	case_ast_node(ie, MatrixIndexExpr, expr);
+		return lb_addr_load(p, lb_build_addr(p, expr));
+	case_end;
 
 
 	case_ast_node(ia, InlineAsmExpr, expr);
 	case_ast_node(ia, InlineAsmExpr, expr);
 		Type *t = type_of_expr(expr);
 		Type *t = type_of_expr(expr);
@@ -3304,6 +3308,25 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) {
 			lbValue v = lb_emit_ptr_offset(p, elem, index);
 			lbValue v = lb_emit_ptr_offset(p, elem, index);
 			return lb_addr(v);
 			return lb_addr(v);
 		}
 		}
+		
+		case Type_Matrix: {
+			lbValue matrix = {};
+			matrix = lb_build_addr_ptr(p, ie->expr);
+			if (deref) {
+				matrix = lb_emit_load(p, matrix);
+			}
+			lbValue index = lb_build_expr(p, ie->index);
+			index = lb_emit_conv(p, index, t_int);
+			lbValue elem = lb_emit_matrix_ep(p, matrix, lb_const_int(p->module, t_int, 0), index);
+			elem = lb_emit_conv(p, elem, alloc_type_pointer(type_of_expr(expr)));
+
+			auto index_tv = type_and_value_of_expr(ie->index);
+			if (index_tv.mode != Addressing_Constant) {
+				lbValue len = lb_const_int(p->module, t_int, t->Matrix.column_count);
+				lb_emit_bounds_check(p, ast_token(ie->index), index, len);
+			}
+			return lb_addr(elem);
+		}
 
 
 
 
 		case Type_Basic: { // Basic_string
 		case Type_Basic: { // Basic_string
@@ -3326,6 +3349,35 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) {
 		}
 		}
 		}
 		}
 	case_end;
 	case_end;
+	
+	case_ast_node(ie, MatrixIndexExpr, expr);
+		Type *t = base_type(type_of_expr(ie->expr));
+
+		bool deref = is_type_pointer(t);
+		t = base_type(type_deref(t));
+		
+		lbValue m = {};
+		m = lb_build_addr_ptr(p, ie->expr);
+		if (deref) {
+			m = lb_emit_load(p, m);
+		}
+		lbValue row_index = lb_build_expr(p, ie->row_index);
+		lbValue column_index = lb_build_expr(p, ie->column_index);
+		row_index = lb_emit_conv(p, row_index, t_int);
+		column_index = lb_emit_conv(p, column_index, t_int);
+		lbValue elem = lb_emit_matrix_ep(p, m, row_index, column_index);
+
+		auto row_index_tv = type_and_value_of_expr(ie->row_index);
+		auto column_index_tv = type_and_value_of_expr(ie->column_index);
+		if (row_index_tv.mode != Addressing_Constant || column_index_tv.mode != Addressing_Constant) {
+			lbValue row_count = lb_const_int(p->module, t_int, t->Matrix.row_count);
+			lbValue column_count = lb_const_int(p->module, t_int, t->Matrix.column_count);
+			lb_emit_matrix_bounds_check(p, ast_token(ie->row_index), row_index, column_index, row_count, column_count);
+		}
+		return lb_addr(elem);
+		
+		
+	case_end;
 
 
 	case_ast_node(se, SliceExpr, expr);
 	case_ast_node(se, SliceExpr, expr);
 
 

+ 30 - 0
src/llvm_backend_general.cpp

@@ -419,6 +419,36 @@ void lb_emit_bounds_check(lbProcedure *p, Token token, lbValue index, lbValue le
 	lb_emit_runtime_call(p, "bounds_check_error", args);
 	lb_emit_runtime_call(p, "bounds_check_error", args);
 }
 }
 
 
+void lb_emit_matrix_bounds_check(lbProcedure *p, Token token, lbValue row_index, lbValue column_index, lbValue row_count, lbValue column_count) {
+	if (build_context.no_bounds_check) {
+		return;
+	}
+	if ((p->state_flags & StateFlag_no_bounds_check) != 0) {
+		return;
+	}
+
+	row_index = lb_emit_conv(p, row_index, t_int);
+	column_index = lb_emit_conv(p, column_index, t_int);
+	row_count = lb_emit_conv(p, row_count, t_int);
+	column_count = lb_emit_conv(p, column_count, t_int);
+
+	lbValue file = lb_find_or_add_entity_string(p->module, get_file_path_string(token.pos.file_id));
+	lbValue line = lb_const_int(p->module, t_i32, token.pos.line);
+	lbValue column = lb_const_int(p->module, t_i32, token.pos.column);
+
+	auto args = array_make<lbValue>(permanent_allocator(), 7);
+	args[0] = file;
+	args[1] = line;
+	args[2] = column;
+	args[3] = row_index;
+	args[4] = column_index;
+	args[5] = row_count;
+	args[6] = column_count;
+
+	lb_emit_runtime_call(p, "matrix_bounds_check_error", args);
+}
+
+
 void lb_emit_multi_pointer_slice_bounds_check(lbProcedure *p, Token token, lbValue low, lbValue high) {
 void lb_emit_multi_pointer_slice_bounds_check(lbProcedure *p, Token token, lbValue low, lbValue high) {
 	if (build_context.no_bounds_check) {
 	if (build_context.no_bounds_check) {
 		return;
 		return;

+ 31 - 0
src/llvm_backend_utility.cpp

@@ -1249,6 +1249,37 @@ lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column) {
 	return res;
 	return res;
 }
 }
 
 
+lbValue lb_emit_matrix_ep(lbProcedure *p, lbValue s, lbValue row, lbValue column) {
+	Type *t = s.type;
+	GB_ASSERT(is_type_pointer(t));
+	Type *mt = base_type(type_deref(t));
+	GB_ASSERT_MSG(is_type_matrix(mt), "%s", type_to_string(mt));
+
+	Type *ptr = base_array_type(mt);
+	
+	LLVMValueRef stride_elems = lb_const_int(p->module, t_int, matrix_type_stride_in_elems(mt)).value;
+	
+	row = lb_emit_conv(p, row, t_int);
+	column = lb_emit_conv(p, column, t_int);
+	
+	LLVMValueRef index = LLVMBuildAdd(p->builder, row.value, LLVMBuildMul(p->builder, column.value, stride_elems, ""), "");
+
+	LLVMValueRef indices[2] = {
+		LLVMConstInt(lb_type(p->module, t_int), 0, false),
+		index,
+	};
+
+	lbValue res = {};
+	if (lb_is_const(s)) {
+		res.value = LLVMConstGEP(s.value, indices, gb_count_of(indices));
+	} else {
+		res.value = LLVMBuildGEP(p->builder, s.value, indices, gb_count_of(indices), "");
+	}
+	res.type = alloc_type_pointer(ptr);
+	return res;
+}
+
+
 lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column) {
 lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column) {
 	Type *st = base_type(s.type);
 	Type *st = base_type(s.type);
 	GB_ASSERT_MSG(is_type_matrix(st), "%s", type_to_string(st));
 	GB_ASSERT_MSG(is_type_matrix(st), "%s", type_to_string(st));

+ 4 - 0
src/types.cpp

@@ -1726,6 +1726,8 @@ bool is_type_indexable(Type *t) {
 		return true;
 		return true;
 	case Type_RelativeSlice:
 	case Type_RelativeSlice:
 		return true;
 		return true;
+	case Type_Matrix:
+		return true;
 	}
 	}
 	return false;
 	return false;
 }
 }
@@ -1743,6 +1745,8 @@ bool is_type_sliceable(Type *t) {
 		return false;
 		return false;
 	case Type_RelativeSlice:
 	case Type_RelativeSlice:
 		return true;
 		return true;
+	case Type_Matrix:
+		return false;
 	}
 	}
 	return false;
 	return false;
 }
 }