Browse Source

Support `for in` with `bit_set`

gingerBill 1 year ago
parent
commit
b862691d75
3 changed files with 152 additions and 57 deletions
  1. 13 0
      src/check_stmt.cpp
  2. 52 50
      src/llvm_backend_expr.cpp
  3. 87 7
      src/llvm_backend_stmt.cpp

+ 13 - 0
src/check_stmt.cpp

@@ -1554,6 +1554,19 @@ gb_internal void check_range_stmt(CheckerContext *ctx, Ast *node, u32 mod_flags)
 				}
 				break;
 
+			case Type_BitSet:
+				array_add(&vals, t->BitSet.elem);
+				if (rs->vals.count > 1) {
+					error(rs->vals[1], "Expected 1 name when iterating over a bit_set, got %td", rs->vals.count);
+				}
+				if (rs->vals.count == 1 &&
+				    rs->vals[0]->kind == Ast_UnaryExpr &&
+				    rs->vals[0]->UnaryExpr.op.kind == Token_And) {
+					error(rs->vals[0], "When iteraing across a bit_set, you cannot modify the value with '&' as that does not make much sense");
+				}
+				add_type_info_type(ctx, operand.type);
+				break;
+
 			case Type_EnumeratedArray:
 				if (is_ptr) use_by_reference_for_value = true;
 				array_add(&vals, t->EnumeratedArray.elem);

+ 52 - 50
src/llvm_backend_expr.cpp

@@ -1373,6 +1373,57 @@ gb_internal bool lb_is_empty_string_constant(Ast *expr) {
 	return false;
 }
 
+gb_internal lbValue lb_build_binary_in(lbProcedure *p, lbValue left, lbValue right, TokenKind op) {
+	Type *rt = base_type(right.type);
+	if (is_type_pointer(rt)) {
+		right = lb_emit_load(p, right);
+		rt = base_type(type_deref(rt));
+	}
+
+	switch (rt->kind) {
+	case Type_Map:
+		{
+			lbValue map_ptr = lb_address_from_load_or_generate_local(p, right);
+			lbValue key = left;
+			lbValue ptr = lb_internal_dynamic_map_get_ptr(p, map_ptr, key);
+			if (op == Token_in) {
+				return lb_emit_conv(p, lb_emit_comp_against_nil(p, Token_NotEq, ptr), t_bool);
+			} else {
+				return lb_emit_conv(p, lb_emit_comp_against_nil(p, Token_CmpEq, ptr), t_bool);
+			}
+		}
+		break;
+	case Type_BitSet:
+		{
+			Type *key_type = rt->BitSet.elem;
+			GB_ASSERT(are_types_identical(left.type, key_type));
+
+			Type *it = bit_set_to_int(rt);
+			left = lb_emit_conv(p, left, it);
+			if (is_type_different_to_arch_endianness(it)) {
+				left = lb_emit_byte_swap(p, left, integer_endian_type_to_platform_type(it));
+			}
+
+			lbValue lower = lb_const_value(p->module, left.type, exact_value_i64(rt->BitSet.lower));
+			lbValue key = lb_emit_arith(p, Token_Sub, left, lower, left.type);
+			lbValue bit = lb_emit_arith(p, Token_Shl, lb_const_int(p->module, left.type, 1), key, left.type);
+			bit = lb_emit_conv(p, bit, it);
+
+			lbValue old_value = lb_emit_transmute(p, right, it);
+			lbValue new_value = lb_emit_arith(p, Token_And, old_value, bit, it);
+
+			if (op == Token_in) {
+				return lb_emit_conv(p, lb_emit_comp(p, Token_NotEq, new_value, lb_const_int(p->module, new_value.type, 0)), t_bool);
+			} else {
+				return lb_emit_conv(p, lb_emit_comp(p, Token_CmpEq, new_value, lb_const_int(p->module, new_value.type, 0)), t_bool);
+			}
+		}
+		break;
+	}
+	GB_PANIC("Invalid 'in' type");
+	return {};
+}
+
 gb_internal lbValue lb_build_binary_expr(lbProcedure *p, Ast *expr) {
 	ast_node(be, BinaryExpr, expr);
 
@@ -1480,57 +1531,8 @@ gb_internal lbValue lb_build_binary_expr(lbProcedure *p, Ast *expr) {
 		{
 			lbValue left = lb_build_expr(p, be->left);
 			lbValue right = lb_build_expr(p, be->right);
-			Type *rt = base_type(right.type);
-			if (is_type_pointer(rt)) {
-				right = lb_emit_load(p, right);
-				rt = base_type(type_deref(rt));
-			}
-
-			switch (rt->kind) {
-			case Type_Map:
-				{
-					lbValue map_ptr = lb_address_from_load_or_generate_local(p, right);
-					lbValue key = left;
-					lbValue ptr = lb_internal_dynamic_map_get_ptr(p, map_ptr, key);
-					if (be->op.kind == Token_in) {
-						return lb_emit_conv(p, lb_emit_comp_against_nil(p, Token_NotEq, ptr), t_bool);
-					} else {
-						return lb_emit_conv(p, lb_emit_comp_against_nil(p, Token_CmpEq, ptr), t_bool);
-					}
-				}
-				break;
-			case Type_BitSet:
-				{
-					Type *key_type = rt->BitSet.elem;
-					GB_ASSERT(are_types_identical(left.type, key_type));
-
-					Type *it = bit_set_to_int(rt);
-					left = lb_emit_conv(p, left, it);
-					if (is_type_different_to_arch_endianness(it)) {
-						left = lb_emit_byte_swap(p, left, integer_endian_type_to_platform_type(it));
-					}
-
-					lbValue lower = lb_const_value(p->module, left.type, exact_value_i64(rt->BitSet.lower));
-					lbValue key = lb_emit_arith(p, Token_Sub, left, lower, left.type);
-					lbValue bit = lb_emit_arith(p, Token_Shl, lb_const_int(p->module, left.type, 1), key, left.type);
-					bit = lb_emit_conv(p, bit, it);
-
-					lbValue old_value = lb_emit_transmute(p, right, it);
-					lbValue new_value = lb_emit_arith(p, Token_And, old_value, bit, it);
-
-					if (be->op.kind == Token_in) {
-						return lb_emit_conv(p, lb_emit_comp(p, Token_NotEq, new_value, lb_const_int(p->module, new_value.type, 0)), t_bool);
-					} else {
-						return lb_emit_conv(p, lb_emit_comp(p, Token_CmpEq, new_value, lb_const_int(p->module, new_value.type, 0)), t_bool);
-					}
-				}
-				break;
-			default:
-				GB_PANIC("Invalid 'in' type");
-			}
-			break;
+			return lb_build_binary_in(p, left, right, be->op.kind);
 		}
-		break;
 	default:
 		GB_PANIC("Invalid binary expression");
 		break;

+ 87 - 7
src/llvm_backend_stmt.cpp

@@ -737,6 +737,22 @@ gb_internal void lb_build_range_interval(lbProcedure *p, AstBinaryExpr *node,
 	lb_start_block(p, done);
 }
 
+gb_internal lbValue lb_enum_values_slice(lbProcedure *p, Type *enum_type, i64 *enum_count_) {
+	Type *t = enum_type;
+	GB_ASSERT(is_type_enum(t));
+	t = base_type(t);
+	GB_ASSERT(t->kind == Type_Enum);
+	i64 enum_count = t->Enum.fields.count;
+
+	if (enum_count_) *enum_count_ = enum_count;
+
+	lbValue ti       = lb_type_info(p, t);
+	lbValue variant  = lb_emit_struct_ep(p, ti, 4);
+	lbValue eti_ptr  = lb_emit_conv(p, variant, t_type_info_enum_ptr);
+	lbValue values   = lb_emit_load(p, lb_emit_struct_ep(p, eti_ptr, 2));
+	return values;
+}
+
 gb_internal void lb_build_range_enum(lbProcedure *p, Type *enum_type, Type *val_type, lbValue *val_, lbValue *idx_, lbBlock **loop_, lbBlock **done_) {
 	lbModule *m = p->module;
 
@@ -744,15 +760,11 @@ gb_internal void lb_build_range_enum(lbProcedure *p, Type *enum_type, Type *val_
 	GB_ASSERT(is_type_enum(t));
 	t = base_type(t);
 	Type *core_elem = core_type(t);
-	GB_ASSERT(t->kind == Type_Enum);
-	i64 enum_count = t->Enum.fields.count;
-	lbValue max_count = lb_const_int(m, t_int, enum_count);
+	i64 enum_count = 0;
 
-	lbValue ti          = lb_type_info(p, t);
-	lbValue variant     = lb_emit_struct_ep(p, ti, 4);
-	lbValue eti_ptr     = lb_emit_conv(p, variant, t_type_info_enum_ptr);
-	lbValue values      = lb_emit_load(p, lb_emit_struct_ep(p, eti_ptr, 2));
+	lbValue values      = lb_enum_values_slice(p, enum_type, &enum_count);
 	lbValue values_data = lb_slice_elem(p, values);
+	lbValue max_count   = lb_const_int(m, t_int, enum_count);
 
 	lbAddr offset_ = lb_add_local_generated(p, t_int, false);
 	lb_addr_store(p, offset_, lb_const_int(m, t_int, 0));
@@ -1052,6 +1064,74 @@ gb_internal void lb_build_range_stmt(lbProcedure *p, AstRangeStmt *rs, Scope *sc
 		case Type_Tuple:
 			lb_build_range_tuple(p, expr, val0_type, val1_type, &val, &key, &loop, &done);
 			break;
+
+		case Type_BitSet: {
+			lbModule *m = p->module;
+
+			lbValue the_set = lb_build_expr(p, expr);
+			if (is_type_pointer(type_deref(the_set.type))) {
+				the_set = lb_emit_load(p, the_set);
+			}
+
+			Type *elem = et->BitSet.elem;
+			if (is_type_enum(elem)) {
+				i64 enum_count = 0;
+				lbValue values      = lb_enum_values_slice(p, elem, &enum_count);
+				lbValue values_data = lb_slice_elem(p, values);
+				lbValue max_count   = lb_const_int(m, t_int, enum_count);
+
+				lbAddr offset_ = lb_add_local_generated(p, t_int, false);
+				lb_addr_store(p, offset_, lb_const_int(m, t_int, 0));
+
+				loop = lb_create_block(p, "for.bit_set.enum.loop");
+				lb_emit_jump(p, loop);
+				lb_start_block(p, loop);
+
+				lbBlock *body_check = lb_create_block(p, "for.bit_set.enum.body-check");
+				lbBlock *body = lb_create_block(p, "for.bit_set.enum.body");
+				done = lb_create_block(p, "for.bit_set.enum.done");
+
+				lbValue offset = lb_addr_load(p, offset_);
+				lbValue cond = lb_emit_comp(p, Token_Lt, offset, max_count);
+				lb_emit_if(p, cond, body_check, done);
+				lb_start_block(p, body_check);
+
+				lbValue val_ptr = lb_emit_ptr_offset(p, values_data, offset);
+				lb_emit_increment(p, offset_.addr);
+				val = lb_emit_load(p, val_ptr);
+				val = lb_emit_conv(p, val, elem);
+
+				lbValue check = lb_build_binary_in(p, val, the_set, Token_in);
+				lb_emit_if(p, check, body, loop);
+				lb_start_block(p, body);
+			} else {
+				lbAddr offset_ = lb_add_local_generated(p, t_int, false);
+				lb_addr_store(p, offset_, lb_const_int(m, t_int, et->BitSet.lower));
+
+				lbValue max_count = lb_const_int(m, t_int, et->BitSet.upper);
+
+				loop = lb_create_block(p, "for.bit_set.range.loop");
+				lb_emit_jump(p, loop);
+				lb_start_block(p, loop);
+
+				lbBlock *body_check = lb_create_block(p, "for.bit_set.range.body-check");
+				lbBlock *body = lb_create_block(p, "for.bit_set.range.body");
+				done = lb_create_block(p, "for.bit_set.range.done");
+
+				lbValue offset = lb_addr_load(p, offset_);
+				lbValue cond = lb_emit_comp(p, Token_LtEq, offset, max_count);
+				lb_emit_if(p, cond, body_check, done);
+				lb_start_block(p, body_check);
+
+				val = lb_emit_conv(p, offset, elem);
+				lb_emit_increment(p, offset_.addr);
+
+				lbValue check = lb_build_binary_in(p, val, the_set, Token_in);
+				lb_emit_if(p, check, body, loop);
+				lb_start_block(p, body);
+			}
+			break;
+		}
 		default:
 			GB_PANIC("Cannot range over %s", type_to_string(expr_type));
 			break;