Browse Source

First step towards constant unions

gingerBill 2 weeks ago
parent
commit
a974c51d57
4 changed files with 81 additions and 32 deletions
  1. 1 1
      src/check_expr.cpp
  2. 1 1
      src/llvm_backend.hpp
  3. 60 27
      src/llvm_backend_const.cpp
  4. 19 3
      src/llvm_backend_expr.cpp

+ 1 - 1
src/check_expr.cpp

@@ -3520,7 +3520,7 @@ gb_internal bool is_type_union_constantable(Type *type) {
 			return false;
 		}
 	}
-	return false;
+	return true;
 }
 
 gb_internal bool check_cast_internal(CheckerContext *c, Operand *x, Type *type) {

+ 1 - 1
src/llvm_backend.hpp

@@ -455,7 +455,7 @@ static lbConstContext const LB_CONST_CONTEXT_DEFAULT_NO_LOCAL = {false, false, {
 
 gb_internal lbValue lb_const_nil(lbModule *m, Type *type);
 gb_internal lbValue lb_const_undef(lbModule *m, Type *type);
-gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lbConstContext cc = LB_CONST_CONTEXT_DEFAULT);
+gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lbConstContext cc = LB_CONST_CONTEXT_DEFAULT, Type *value_type=nullptr);
 gb_internal lbValue lb_const_bool(lbModule *m, Type *type, bool value);
 gb_internal lbValue lb_const_int(lbModule *m, Type *type, u64 value);
 

+ 60 - 27
src/llvm_backend_const.cpp

@@ -787,7 +787,7 @@ gb_internal bool lb_try_construct_const_union(lbModule *m, lbValue *value, Type
 }
 
 
-gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lbConstContext cc) {
+gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lbConstContext cc, Type *value_type) {
 	if (cc.allow_local) {
 		cc.is_rodata = false;
 	}
@@ -838,7 +838,6 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 	if (is_type_union(type) && is_type_union_constantable(type)) {
 		Type *bt = base_type(type);
 		GB_ASSERT(bt->kind == Type_Union);
-		GB_ASSERT(bt->Union.variants.count <= 1);
 		if (bt->Union.variants.count == 0) {
 			return lb_const_nil(m, original_type);
 		} else if (bt->Union.variants.count == 1) {
@@ -872,6 +871,37 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 				res.type = original_type;
 				return res;
 			}
+		} else {
+			GB_ASSERT(value_type != nullptr);
+
+			i64 block_size = bt->Union.variant_block_size;
+
+			lbValue cv = lb_const_value(m, value_type, value, cc, value_type);
+			Type *variant_type = cv.type;
+
+			LLVMValueRef values[4] = {};
+			unsigned value_count = 0;
+
+			values[value_count++] = cv.value;
+			if (type_size_of(variant_type) != block_size) {
+				LLVMTypeRef padding_type = lb_type_padding_filler(m, block_size - type_size_of(variant_type), 1);
+				values[value_count++] = LLVMConstNull(padding_type);
+			}
+
+			Type *tag_type = union_tag_type(bt);
+			LLVMTypeRef llvm_tag_type = lb_type(m, tag_type);
+			i64 tag_index = union_variant_index(bt, variant_type);
+			values[value_count++] = LLVMConstInt(llvm_tag_type, tag_index, false);
+			i64 used_size = block_size + type_size_of(tag_type);
+			i64 union_size = type_size_of(bt);
+			i64 padding = union_size - used_size;
+			if (padding > 0) {
+				LLVMTypeRef padding_type = lb_type_padding_filler(m, padding, 1);
+				values[value_count++] = LLVMConstNull(padding_type);
+			}
+
+			res.value = LLVMConstStructInContext(m->ctx, values, value_count, true);
+			return res;
 		}
 
 	}
@@ -909,7 +939,11 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 
 				array_data = llvm_alloca(p, llvm_type, 16);
 
-				LLVMBuildStore(p->builder, backing_array.value, array_data);
+				{
+					LLVMValueRef ptr = array_data;
+					ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(LLVMTypeOf(backing_array.value), 0), "");
+					LLVMBuildStore(p->builder, backing_array.value, ptr);
+				}
 
 				{
 					LLVMValueRef indices[2] = {llvm_zero(m), llvm_zero(m)};
@@ -931,7 +965,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 				String name = make_string(cast(u8 const *)str, gb_string_length(str));
 
 				Entity *e = alloc_entity_constant(nullptr, make_token_ident(name), t, value);
-				array_data = LLVMAddGlobal(m->mod, lb_type(m, t), str);
+				array_data = LLVMAddGlobal(m->mod, LLVMTypeOf(backing_array.value), str);
 				LLVMSetInitializer(array_data, backing_array.value);
 
 				if (cc.link_section.len > 0) {
@@ -942,15 +976,14 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 				}
 
 				lbValue g = {};
-				g.value = array_data;
+				g.value = LLVMConstPointerCast(array_data, LLVMPointerType(lb_type(m, t), 0));
 				g.type = t;
 
 				lb_add_entity(m, e, g);
 				lb_add_member(m, name, g);
 
 				{
-					LLVMValueRef indices[2] = {llvm_zero(m), llvm_zero(m)};
-					LLVMValueRef ptr = LLVMConstInBoundsGEP2(lb_type(m, t), array_data, indices, 2);
+					LLVMValueRef ptr = g.value;
 					LLVMValueRef len = LLVMConstInt(lb_type(m, t_int), count, true);
 					LLVMValueRef values[2] = {ptr, len};
 
@@ -1272,7 +1305,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 							}
 							if (lo == i) {
 								TypeAndValue tav = fv->value->tav;
-								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc).value;
+								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 								for (i64 k = lo; k < hi; k++) {
 									aos_values[value_index++] = val;
 								}
@@ -1287,7 +1320,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 							i64 index = exact_value_to_i64(index_tav.value);
 							if (index == i) {
 								TypeAndValue tav = fv->value->tav;
-								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc).value;
+								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 								aos_values[value_index++] = val;
 								found = true;
 								break;
@@ -1340,7 +1373,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 				for (isize i = 0; i < elem_count; i++) {
 					TypeAndValue tav = cl->elems[i]->tav;
 					GB_ASSERT(tav.mode != Addressing_Invalid);
-					aos_values[i] = lb_const_value(m, elem_type, tav.value, cc).value;
+					aos_values[i] = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 				}
 				for (isize i = elem_count; i < type->Struct.soa_count; i++) {
 					aos_values[i] = nullptr;
@@ -1407,7 +1440,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 							}
 							if (lo == i) {
 								TypeAndValue tav = fv->value->tav;
-								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc).value;
+								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 								for (i64 k = lo; k < hi; k++) {
 									values[value_index++] = val;
 								}
@@ -1422,7 +1455,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 							i64 index = exact_value_to_i64(index_tav.value);
 							if (index == i) {
 								TypeAndValue tav = fv->value->tav;
-								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc).value;
+								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 								values[value_index++] = val;
 								found = true;
 								break;
@@ -1437,12 +1470,12 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 
 				res.value = lb_build_constant_array_values(m, type, elem_type, cast(isize)type->Array.count, values, cc);
 				return res;
-			} else if (value.value_compound->tav.type == elem_type) {
+			} else if (are_types_identical(value.value_compound->tav.type, elem_type)) {
 				// Compound is of array item type; expand its value to all items in array.
 				LLVMValueRef* values = gb_alloc_array(temporary_allocator(), LLVMValueRef, cast(isize)type->Array.count);
 
 				for (isize i = 0; i < type->Array.count; i++) {
-					values[i] = lb_const_value(m, elem_type, value, cc).value;
+					values[i] = lb_const_value(m, elem_type, value, cc, elem_type).value;
 				}
 
 				res.value = lb_build_constant_array_values(m, type, elem_type, cast(isize)type->Array.count, values, cc);
@@ -1456,7 +1489,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 				for (isize i = 0; i < elem_count; i++) {
 					TypeAndValue tav = cl->elems[i]->tav;
 					GB_ASSERT(tav.mode != Addressing_Invalid);
-					values[i] = lb_const_value(m, elem_type, tav.value, cc).value;
+					values[i] = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 				}
 				for (isize i = elem_count; i < type->Array.count; i++) {
 					values[i] = LLVMConstNull(lb_type(m, elem_type));
@@ -1502,7 +1535,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 							}
 							if (lo == i) {
 								TypeAndValue tav = fv->value->tav;
-								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc).value;
+								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 								for (i64 k = lo; k < hi; k++) {
 									values[value_index++] = val;
 								}
@@ -1517,7 +1550,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 							i64 index = exact_value_to_i64(index_tav.value);
 							if (index == i) {
 								TypeAndValue tav = fv->value->tav;
-								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc).value;
+								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 								values[value_index++] = val;
 								found = true;
 								break;
@@ -1540,7 +1573,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 				for (isize i = 0; i < elem_count; i++) {
 					TypeAndValue tav = cl->elems[i]->tav;
 					GB_ASSERT(tav.mode != Addressing_Invalid);
-					values[i] = lb_const_value(m, elem_type, tav.value, cc).value;
+					values[i] = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 				}
 				for (isize i = elem_count; i < type->EnumeratedArray.count; i++) {
 					values[i] = LLVMConstNull(lb_type(m, elem_type));
@@ -1585,7 +1618,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 							}
 							if (lo == i) {
 								TypeAndValue tav = fv->value->tav;
-								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc).value;
+								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 								for (i64 k = lo; k < hi; k++) {
 									values[value_index++] = val;
 								}
@@ -1600,7 +1633,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 							i64 index = exact_value_to_i64(index_tav.value);
 							if (index == i) {
 								TypeAndValue tav = fv->value->tav;
-								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc).value;
+								LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 								values[value_index++] = val;
 								found = true;
 								break;
@@ -1619,7 +1652,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 				for (isize i = 0; i < elem_count; i++) {
 					TypeAndValue tav = cl->elems[i]->tav;
 					GB_ASSERT(tav.mode != Addressing_Invalid);
-					values[i] = lb_const_value(m, elem_type, tav.value, cc).value;
+					values[i] = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 				}
 				LLVMTypeRef et = lb_type(m, elem_type);
 
@@ -1668,7 +1701,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 					i32 index = field_remapping[f->Variable.field_index];
 					if (elem_type_can_be_constant(f->type)) {
 						if (sel.index.count == 1) {
-							values[index]  = lb_const_value(m, f->type, tav.value, cc).value;
+							values[index]  = lb_const_value(m, f->type, tav.value, cc, tav.type).value;
 							visited[index] = true;
 						} else {
 							if (!visited[index]) {
@@ -1714,7 +1747,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 									}
 								}
 								if (is_constant) {
-									LLVMValueRef elem_value = lb_const_value(m, tav.type, tav.value, cc).value;
+									LLVMValueRef elem_value = lb_const_value(m, tav.type, tav.value, cc, tav.type).value;
 									if (LLVMIsConstant(elem_value) && LLVMIsConstant(values[index])) {
 										values[index] = llvm_const_insert_value(m, values[index], elem_value, idx_list, idx_list_len);
 									} else if (is_local) {
@@ -1768,7 +1801,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 
 					i32 index = field_remapping[f->Variable.field_index];
 					if (elem_type_can_be_constant(f->type)) {
-						values[index]  = lb_const_value(m, f->type, val, cc).value;
+						values[index]  = lb_const_value(m, f->type, val, cc, tav.type).value;
 						visited[index] = true;
 					}
 				}
@@ -1902,7 +1935,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 						
 						
 						TypeAndValue tav = fv->value->tav;
-						LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc).value;
+						LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 						for (i64 k = lo; k < hi; k++) {
 							i64 offset = matrix_row_major_index_to_offset(type, k);
 							GB_ASSERT(values[offset] == nullptr);
@@ -1914,7 +1947,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 						i64 index = exact_value_to_i64(index_tav.value);
 						GB_ASSERT(index < max_count);
 						TypeAndValue tav = fv->value->tav;
-						LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc).value;
+						LLVMValueRef val = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 						i64 offset = matrix_row_major_index_to_offset(type, index);
 						GB_ASSERT(values[offset] == nullptr);
 						values[offset] = val;
@@ -1938,7 +1971,7 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, lb
 					GB_ASSERT(tav.mode != Addressing_Invalid);
 					i64 offset = 0;
 					offset = matrix_row_major_index_to_offset(type, i);
-					values[offset] = lb_const_value(m, elem_type, tav.value, cc).value;
+					values[offset] = lb_const_value(m, elem_type, tav.value, cc, tav.type).value;
 				}
 				for (isize i = 0; i < total_count; i++) {
 					if (values[i] == nullptr) {

+ 19 - 3
src/llvm_backend_expr.cpp

@@ -3947,6 +3947,20 @@ gb_internal lbValue lb_build_expr(lbProcedure *p, Ast *expr) {
 	return res;
 }
 
+gb_internal Type *lb_build_expr_original_const_type(Ast *expr) {
+	expr = unparen_expr(expr);
+	Type *type = type_of_expr(expr);
+	if (is_type_union(type)) {
+		if (expr->kind == Ast_CallExpr) {
+			if (expr->CallExpr.proc->tav.mode == Addressing_Type) {
+				Type *res = lb_build_expr_original_const_type(expr->CallExpr.args[0]);
+				return res;
+			}
+		}
+	}
+	return type_of_expr(expr);
+}
+
 gb_internal lbValue lb_build_expr_internal(lbProcedure *p, Ast *expr) {
 	lbModule *m = p->module;
 
@@ -3958,9 +3972,11 @@ gb_internal lbValue lb_build_expr_internal(lbProcedure *p, Ast *expr) {
 	GB_ASSERT_MSG(tv.mode != Addressing_Invalid, "invalid expression '%s' (tv.mode = %d, tv.type = %s) @ %s\n Current Proc: %.*s : %s", expr_to_string(expr), tv.mode, type_to_string(tv.type), token_pos_to_string(expr_pos), LIT(p->name), type_to_string(p->type));
 
 
+
 	if (tv.value.kind != ExactValue_Invalid) {
+		Type *original_type = lb_build_expr_original_const_type(expr);
 		// NOTE(bill): Short on constant values
-		return lb_const_value(p->module, type, tv.value, LB_CONST_CONTEXT_DEFAULT_ALLOW_LOCAL);
+		return lb_const_value(p->module, type, tv.value, LB_CONST_CONTEXT_DEFAULT_ALLOW_LOCAL, original_type);
 	} else if (tv.mode == Addressing_Type) {
 		// NOTE(bill, 2023-01-16): is this correct? I hope so at least
 		return lb_typeid(m, tv.type);
@@ -4041,7 +4057,7 @@ gb_internal lbValue lb_build_expr_internal(lbProcedure *p, Ast *expr) {
 		TypeAndValue tav = type_and_value_of_expr(expr);
 		GB_ASSERT(tav.mode == Addressing_Constant);
 
-		return lb_const_value(p->module, type, tv.value);
+		return lb_const_value(p->module, type, tv.value, LB_CONST_CONTEXT_DEFAULT_ALLOW_LOCAL, tv.type);
 	case_end;
 
 	case_ast_node(se, SelectorCallExpr, expr);
@@ -4322,7 +4338,7 @@ gb_internal lbAddr lb_build_addr_from_entity(lbProcedure *p, Entity *e, Ast *exp
 	GB_ASSERT(e != nullptr);
 	if (e->kind == Entity_Constant) {
 		Type *t = default_type(type_of_expr(expr));
-		lbValue v = lb_const_value(p->module, t, e->Constant.value);
+		lbValue v = lb_const_value(p->module, t, e->Constant.value, LB_CONST_CONTEXT_DEFAULT_NO_LOCAL, e->type);
 		if (LLVMIsConstant(v.value)) {
 			lbAddr g = lb_add_global_generated_from_procedure(p, t, v);
 			return g;