Browse Source

Merge pull request #5378 from laytan/fix-wasm-c-abi-raw-unions

Fix WASM C ABI for raw unions
gingerBill 1 month ago
parent
commit
8d37f9de09
3 changed files with 134 additions and 15 deletions
  1. 56 14
      src/llvm_abi.cpp
  2. 1 1
      src/llvm_backend_general.cpp
  3. 77 0
      src/types.cpp

+ 56 - 14
src/llvm_abi.cpp

@@ -1313,7 +1313,7 @@ namespace lbAbiWasm {
 		            registers/arguments if possible rather than by pointer.
 		            registers/arguments if possible rather than by pointer.
 	*/
 	*/
 	gb_internal Array<lbArgType> compute_arg_types(LLVMContextRef c, LLVMTypeRef *arg_types, unsigned arg_count, ProcCallingConvention calling_convention, Type *original_type);
 	gb_internal Array<lbArgType> compute_arg_types(LLVMContextRef c, LLVMTypeRef *arg_types, unsigned arg_count, ProcCallingConvention calling_convention, Type *original_type);
-	gb_internal LB_ABI_COMPUTE_RETURN_TYPE(compute_return_type);
+	gb_internal lbArgType compute_return_type(lbFunctionType *ft, LLVMContextRef c, LLVMTypeRef return_type, bool return_is_defined, bool return_is_tuple, Type* original_type);
 
 
 	enum {MAX_DIRECT_STRUCT_SIZE = 32};
 	enum {MAX_DIRECT_STRUCT_SIZE = 32};
 
 
@@ -1323,7 +1323,9 @@ namespace lbAbiWasm {
 		ft->ctx = c;
 		ft->ctx = c;
 		ft->calling_convention = calling_convention;
 		ft->calling_convention = calling_convention;
 		ft->args = compute_arg_types(c, arg_types, arg_count, calling_convention, original_type);
 		ft->args = compute_arg_types(c, arg_types, arg_count, calling_convention, original_type);
-		ft->ret = compute_return_type(ft, c, return_type, return_is_defined, return_is_tuple);
+
+		GB_ASSERT(original_type->kind == Type_Proc);
+		ft->ret = compute_return_type(ft, c, return_type, return_is_defined, return_is_tuple, original_type->Proc.results);
 		return ft;
 		return ft;
 	}
 	}
 
 
@@ -1359,7 +1361,7 @@ namespace lbAbiWasm {
 		return false;
 		return false;
 	}
 	}
 
 
-	gb_internal bool type_can_be_direct(LLVMTypeRef type, ProcCallingConvention calling_convention) {
+	gb_internal bool type_can_be_direct(LLVMTypeRef type, Type *original_type, ProcCallingConvention calling_convention) {
 		LLVMTypeKind kind = LLVMGetTypeKind(type);
 		LLVMTypeKind kind = LLVMGetTypeKind(type);
 		i64 sz = lb_sizeof(type);
 		i64 sz = lb_sizeof(type);
 		if (sz == 0) {
 		if (sz == 0) {
@@ -1372,9 +1374,21 @@ namespace lbAbiWasm {
 				return false;
 				return false;
 			} else if (kind == LLVMStructTypeKind) {
 			} else if (kind == LLVMStructTypeKind) {
 				unsigned count = LLVMCountStructElementTypes(type);
 				unsigned count = LLVMCountStructElementTypes(type);
+
+				// NOTE(laytan): raw unions are always structs with 1 field in LLVM, need to check our own def.
+				Type *bt = base_type(original_type);
+				if (bt->kind == Type_Struct && bt->Struct.is_raw_union) {
+					count = cast(unsigned)bt->Struct.fields.count;
+				}
+
 				if (count == 1) {
 				if (count == 1) {
-					return type_can_be_direct(LLVMStructGetTypeAtIndex(type, 0), calling_convention);
+					return type_can_be_direct(
+						LLVMStructGetTypeAtIndex(type, 0),
+						type_internal_index(original_type, 0),
+						calling_convention
+					);
 				}
 				}
+
 			} else if (is_basic_register_type(type)) {
 			} else if (is_basic_register_type(type)) {
 				return true;
 				return true;
 			}
 			}
@@ -1398,7 +1412,7 @@ namespace lbAbiWasm {
 		return false;
 		return false;
 	}
 	}
 
 
-	gb_internal lbArgType is_struct(LLVMContextRef c, LLVMTypeRef type, ProcCallingConvention calling_convention) {
+	gb_internal lbArgType is_struct(LLVMContextRef c, LLVMTypeRef type, Type *original_type, ProcCallingConvention calling_convention) {
 		LLVMTypeKind kind = LLVMGetTypeKind(type);
 		LLVMTypeKind kind = LLVMGetTypeKind(type);
 		GB_ASSERT(kind == LLVMArrayTypeKind || kind == LLVMStructTypeKind);
 		GB_ASSERT(kind == LLVMArrayTypeKind || kind == LLVMStructTypeKind);
 		
 		
@@ -1406,15 +1420,15 @@ namespace lbAbiWasm {
 		if (sz == 0) {
 		if (sz == 0) {
 			return lb_arg_type_ignore(type);
 			return lb_arg_type_ignore(type);
 		}
 		}
-		if (type_can_be_direct(type, calling_convention)) {
+		if (type_can_be_direct(type, original_type, calling_convention)) {
 			return lb_arg_type_direct(type);
 			return lb_arg_type_direct(type);
 		}
 		}
 		return lb_arg_type_indirect(type, nullptr);
 		return lb_arg_type_indirect(type, nullptr);
 	}
 	}
 	
 	
-	gb_internal lbArgType pseudo_slice(LLVMContextRef c, LLVMTypeRef type, ProcCallingConvention calling_convention) {
+	gb_internal lbArgType pseudo_slice(LLVMContextRef c, LLVMTypeRef type, Type *original_type, ProcCallingConvention calling_convention) {
 		if (build_context.metrics.ptr_size < build_context.metrics.int_size &&
 		if (build_context.metrics.ptr_size < build_context.metrics.int_size &&
-		    type_can_be_direct(type, calling_convention)) {
+		    type_can_be_direct(type, original_type, calling_convention)) {
 			LLVMTypeRef types[2] = {
 			LLVMTypeRef types[2] = {
 				LLVMStructGetTypeAtIndex(type, 0),
 				LLVMStructGetTypeAtIndex(type, 0),
 				// ignore padding
 				// ignore padding
@@ -1423,7 +1437,7 @@ namespace lbAbiWasm {
 			LLVMTypeRef new_type = LLVMStructTypeInContext(c, types, gb_count_of(types), false);
 			LLVMTypeRef new_type = LLVMStructTypeInContext(c, types, gb_count_of(types), false);
 			return lb_arg_type_direct(type, new_type, nullptr, nullptr);
 			return lb_arg_type_direct(type, new_type, nullptr, nullptr);
 		} else {
 		} else {
-			return is_struct(c, type, calling_convention);
+			return is_struct(c, type, original_type, calling_convention);
 		}
 		}
 	}
 	}
 
 
@@ -1444,9 +1458,9 @@ namespace lbAbiWasm {
 			LLVMTypeKind kind = LLVMGetTypeKind(t);
 			LLVMTypeKind kind = LLVMGetTypeKind(t);
 			if (kind == LLVMStructTypeKind || kind == LLVMArrayTypeKind) {
 			if (kind == LLVMStructTypeKind || kind == LLVMArrayTypeKind) {
 				if (is_type_slice(ptype) || is_type_string(ptype)) {
 				if (is_type_slice(ptype) || is_type_string(ptype)) {
-					args[i] = pseudo_slice(c, t, calling_convention);
+					args[i] = pseudo_slice(c, t, ptype, calling_convention);
 				} else {
 				} else {
-					args[i] = is_struct(c, t, calling_convention);
+					args[i] = is_struct(c, t, ptype, calling_convention);
 				}
 				}
 			} else {
 			} else {
 				args[i] = non_struct(c, t, false);
 				args[i] = non_struct(c, t, false);
@@ -1455,11 +1469,11 @@ namespace lbAbiWasm {
 		return args;
 		return args;
 	}
 	}
 
 
-	gb_internal LB_ABI_COMPUTE_RETURN_TYPE(compute_return_type) {
+	gb_internal lbArgType compute_return_type(lbFunctionType *ft, LLVMContextRef c, LLVMTypeRef return_type, bool return_is_defined, bool return_is_tuple, Type* original_type) {
 		if (!return_is_defined) {
 		if (!return_is_defined) {
 			return lb_arg_type_direct(LLVMVoidTypeInContext(c));
 			return lb_arg_type_direct(LLVMVoidTypeInContext(c));
 		} else if (lb_is_type_kind(return_type, LLVMStructTypeKind) || lb_is_type_kind(return_type, LLVMArrayTypeKind)) {
 		} else if (lb_is_type_kind(return_type, LLVMStructTypeKind) || lb_is_type_kind(return_type, LLVMArrayTypeKind)) {
-			if (type_can_be_direct(return_type, ft->calling_convention)) {
+			if (type_can_be_direct(return_type, original_type, ft->calling_convention)) {
 				return lb_arg_type_direct(return_type);
 				return lb_arg_type_direct(return_type);
 			} else if (ft->calling_convention != ProcCC_CDecl) {
 			} else if (ft->calling_convention != ProcCC_CDecl) {
 				i64 sz = lb_sizeof(return_type);
 				i64 sz = lb_sizeof(return_type);
@@ -1471,7 +1485,35 @@ namespace lbAbiWasm {
 				}
 				}
 			}
 			}
 
 
-			LB_ABI_MODIFY_RETURN_IF_TUPLE_MACRO();
+			// Multiple returns.
+			if (return_is_tuple) {                                                                                      \
+				lbArgType return_arg = {};
+				if (lb_is_type_kind(return_type, LLVMStructTypeKind)) {
+					unsigned field_count = LLVMCountStructElementTypes(return_type);
+					if (field_count > 1) {
+						ft->original_arg_count = ft->args.count;
+						ft->multiple_return_original_type = return_type;
+
+						for (unsigned i = 0; i < field_count-1; i++) {
+							LLVMTypeRef field_type = LLVMStructGetTypeAtIndex(return_type, i);
+							LLVMTypeRef field_pointer_type = LLVMPointerType(field_type, 0);
+							lbArgType ret_partial = lb_arg_type_direct(field_pointer_type);
+							array_add(&ft->args, ret_partial);
+						}
+
+						return_arg = compute_return_type(
+							ft,
+							c,
+							LLVMStructGetTypeAtIndex(return_type, field_count-1),
+							true, false,
+							type_internal_index(original_type, field_count-1)
+						);
+					}
+				}
+				if (return_arg.type != nullptr) {
+					return return_arg;
+				}
+			}
 
 
 			LLVMAttributeRef attr = lb_create_enum_attribute_with_type(c, "sret", return_type);
 			LLVMAttributeRef attr = lb_create_enum_attribute_with_type(c, "sret", return_type);
 			return lb_arg_type_indirect(return_type, attr);
 			return lb_arg_type_indirect(return_type, attr);

+ 1 - 1
src/llvm_backend_general.cpp

@@ -2206,7 +2206,7 @@ gb_internal LLVMTypeRef lb_type_internal(lbModule *m, Type *type) {
 				field_count = 3;
 				field_count = 3;
 			}
 			}
 			LLVMTypeRef *fields = gb_alloc_array(permanent_allocator(), LLVMTypeRef, field_count);
 			LLVMTypeRef *fields = gb_alloc_array(permanent_allocator(), LLVMTypeRef, field_count);
-			fields[0] = LLVMPointerType(lb_type(m, type->Pointer.elem), 0);
+			fields[0] = LLVMPointerType(lb_type(m, type->SoaPointer.elem), 0);
 			if (bigger_int) {
 			if (bigger_int) {
 				fields[1] = lb_type_padding_filler(m, build_context.ptr_size, build_context.ptr_size);
 				fields[1] = lb_type_padding_filler(m, build_context.ptr_size, build_context.ptr_size);
 				fields[2] = LLVMIntTypeInContext(ctx, 8*cast(unsigned)build_context.int_size);
 				fields[2] = LLVMIntTypeInContext(ctx, 8*cast(unsigned)build_context.int_size);

+ 77 - 0
src/types.cpp

@@ -4618,6 +4618,83 @@ gb_internal Type *alloc_type_proc_from_types(Type **param_types, unsigned param_
 // 	return type;
 // 	return type;
 // }
 // }
 
 
+// Index a type that is internally a struct or array.
+gb_internal Type *type_internal_index(Type *t, isize index) {
+	Type *bt = base_type(t);
+	if (bt == nullptr) {
+		return nullptr;
+	}
+
+	switch (bt->kind) {
+	case Type_Basic:
+		{
+			switch (bt->Basic.kind) {
+			case Basic_complex32:     return t_f16;
+			case Basic_complex64:     return t_f32;
+			case Basic_complex128:    return t_f64;
+			case Basic_quaternion64:  return t_f16;
+			case Basic_quaternion128: return t_f32;
+			case Basic_quaternion256: return t_f64;
+			case Basic_string:
+				{
+					GB_ASSERT(index == 0 || index == 1);
+					return index == 0 ? t_u8_ptr : t_int;
+				}
+			case Basic_any:
+				{
+					GB_ASSERT(index == 0 || index == 1);
+					return index == 0 ? t_rawptr : t_typeid;
+				}
+			}
+		}
+		break;
+
+	case Type_Array:           return bt->Array.elem;
+	case Type_EnumeratedArray: return bt->EnumeratedArray.elem;
+	case Type_SimdVector:      return bt->SimdVector.elem;
+	case Type_Slice:
+		{
+			GB_ASSERT(index == 0 || index == 1);
+			return index == 0 ? t_rawptr : t_typeid;
+		}
+	case Type_DynamicArray:
+		{
+			switch (index) {
+			case 0:  return t_rawptr;
+			case 1:  return t_int;
+			case 2:  return t_int;
+			case 3:  return t_allocator;
+			default: GB_PANIC("invalid raw dynamic array index");
+			};
+		}
+	case Type_Struct:
+		return get_struct_field_type(bt, index);
+	case Type_Union:
+		if (index < bt->Union.variants.count) {
+			return bt->Union.variants[index];
+		}
+		return union_tag_type(bt);
+	case Type_Tuple:
+		return bt->Tuple.variables[index]->type;
+	case Type_Matrix:
+		return bt->Matrix.elem;
+	case Type_SoaPointer:
+		{
+			GB_ASSERT(index == 0 || index == 1);
+			return index == 0 ? t_rawptr : t_int;
+		}
+	case Type_Map:
+		return type_internal_index(bt->Map.debug_metadata_type, index);
+	case Type_BitField:
+		return type_internal_index(bt->BitField.backing_type, index);
+	case Type_Generic:
+		return type_internal_index(bt->Generic.specialized, index);
+	};
+
+	GB_PANIC("Unhandled type %s", type_to_string(bt));
+	return nullptr;
+};
+
 gb_internal gbString write_type_to_string(gbString str, Type *type, bool shorthand=false, bool allow_polymorphic=false) {
 gb_internal gbString write_type_to_string(gbString str, Type *type, bool shorthand=false, bool allow_polymorphic=false) {
 	if (type == nullptr) {
 	if (type == nullptr) {
 		return gb_string_appendc(str, "<no type>");
 		return gb_string_appendc(str, "<no type>");