فهرست منبع

Disallow object-to-void* casting. Lowering fixes. Variant fixes.

Brian Fiete 5 سال پیش
والد
کامیت
16be83ceda

+ 1 - 1
BeefLibs/corlib/src/IO/OpenFileDialog.bf

@@ -421,7 +421,7 @@ namespace System.IO
 	        	ofn.mTitle = mTitle.ToScopedNativeWChar!::();
 	        ofn.mFlags = Options | (Windows.OFN_EXPLORER | Windows.OFN_ENABLEHOOK | Windows.OFN_ENABLESIZING);
 	        ofn.mHook = hookProcPtr;
-			ofn.mCustData = (int)(void*)this;
+			ofn.mCustData = (int)Internal.UnsafeCastToPtr(this);
 	        ofn.mFlagsEx = Windows.OFN_USESHELLITEM;
 	        if (mDefaultExt != null && AddExtension)
 	            ofn.mDefExt = mDefaultExt;

+ 2 - 2
BeefLibs/corlib/src/Object.bf

@@ -89,7 +89,7 @@ namespace System
 
         int IHashable.GetHashCode()
         {
-            return (int)(void*)this;
+            return (int)Internal.UnsafeCastToPtr(this);
         }
         
         public virtual void ToString(String strBuffer)
@@ -97,7 +97,7 @@ namespace System
             //strBuffer.Set(stack string(GetType().mName));
             RawGetType().GetName(strBuffer);
 			strBuffer.Append("@");
-			((int)(void*)this).ToString(strBuffer, "X", null);
+			((int)Internal.UnsafeCastToPtr(this)).ToString(strBuffer, "X", null);
         }
                 
         [SkipCall, NoShow]

+ 5 - 5
BeefLibs/corlib/src/Reflection/FieldInfo.bf

@@ -64,10 +64,10 @@ namespace System.Reflection
 	        }
 
 	        Type fieldType = Type.[Friend]GetType(mFieldData.mFieldTypeId);
-	        void* fieldDataAddr = ((uint8*)(void*)obj) + mFieldData.mDataOffset + dataOffsetAdjust;
+	        void* fieldDataAddr = ((uint8*)Internal.UnsafeCastToPtr(obj)) + mFieldData.mDataOffset + dataOffsetAdjust;
 
 			Type rawValueType = value.[Friend]RawGetType();
-			void* valueDataAddr = ((uint8*)(void*)value) + rawValueType.[Friend]mMemberDataOffset;
+			void* valueDataAddr = ((uint8*)Internal.UnsafeCastToPtr(value)) + rawValueType.[Friend]mMemberDataOffset;
 			
 			Type valueType = value.GetType();
 
@@ -126,7 +126,7 @@ namespace System.Reflection
 
 		    Type fieldType = Type.[Friend]GetType(mFieldData.mFieldTypeId);
 		    
-		    void* dataAddr = ((uint8*)(void*)obj) + mFieldData.mDataOffset + dataOffsetAdjust;
+		    void* dataAddr = ((uint8*)Internal.UnsafeCastToPtr(obj)) + mFieldData.mDataOffset + dataOffsetAdjust;
 
 			if (value.VariantType != fieldType)
 				return .Err;//("Invalid type");
@@ -154,7 +154,7 @@ namespace System.Reflection
 
 	        if (type.IsBoxed)
 	            return ((uint8*)(void*)value) + type.[Friend]mMemberDataOffset;
-	        return ((uint8*)(void*)value);
+	        return ((uint8*)Internal.UnsafeCastToPtr(value));
 	    }
 
 	    public Result<void> GetValue<TMember>(Object target, out TMember value)
@@ -237,7 +237,7 @@ namespace System.Reflection
 			void* targetDataAddr = (void*)(int)mFieldData.mConstValue;
 
 			Type fieldType = Type.[Friend]GetType(mFieldData.mFieldTypeId);
-			value.[Friend]mStructType = (int)(void*)fieldType;
+			value.[Friend]mStructType = (int)Internal.UnsafeCastToPtr(fieldType);
 
 			TypeCode typeCode = fieldType.[Friend]mTypeCode;
 			if (typeCode == TypeCode.Enum)

+ 2 - 2
BeefLibs/corlib/src/Reflection/MethodInfo.bf

@@ -238,7 +238,7 @@ namespace System.Reflection
 			if (args.Length != mMethodData.mParamCount)
 				return .Err(.ParamCountMismatch);
 
-			var (retVal, variantData) = Variant.Alloc(retType);
+			var variantData = Variant.Alloc(retType, var retVal);
 			void* retData = variantData;
 
 			// Struct return? Manually add it as an arg after 'this'.  Revisit this - this is architecture-dependent.
@@ -500,7 +500,7 @@ namespace System.Reflection
 			if (args.Count != mMethodData.mParamCount)
 				return .Err(.ParamCountMismatch);
 
-			var (retVal, variantData) = Variant.Alloc(retType);
+			var variantData = Variant.Alloc(retType, var retVal);
 			void* retData = variantData;
 
 			// Struct return? Manually add it as an arg after 'this'.  Revisit this - this is architecture-dependent.

+ 49 - 11
BeefLibs/corlib/src/Variant.bf

@@ -98,12 +98,12 @@ namespace System
 			if (val == null)
 			{
 				variant.mStructType = 2;
-				variant.mData = (int)(void*)typeof(T);
+				variant.mData = (int)Internal.UnsafeCastToPtr(typeof(T));
 			}
 			else
 			{
 				variant.mStructType = (int)(owns ? 1 : 0);
-				variant.mData = (int)(void*)val;
+				variant.mData = (int)Internal.UnsafeCastToPtr(val);
 			}
 			return variant;
 		}
@@ -112,7 +112,7 @@ namespace System
 		{
 			Variant variant;
 			Type type = typeof(T);
-			variant.mStructType = (int)(void*)type;
+			variant.mStructType = (int)Internal.UnsafeCastToPtr(type);
 			if (sizeof(T) <= sizeof(int))
 			{
 				variant.mData = 0;
@@ -131,7 +131,7 @@ namespace System
 		{
 			Variant variant;
 			Type type = typeof(T);
-			variant.mStructType = (int)(void*)type;
+			variant.mStructType = (int)Internal.UnsafeCastToPtr(type);
 			if (type.Size <= sizeof(int))
 			{
 				variant.mData = 0;
@@ -150,8 +150,7 @@ namespace System
 		{
 			Variant variant;
 			Debug.Assert(!type.IsObject);
-			//Debug.Assert((type.GetUnderlyingType() == typeof(T)) || (type == typeof(T)));
-			variant.mStructType = (int)(void*)type;
+			variant.mStructType = (int)Internal.UnsafeCastToPtr(type);
 			if (type.Size <= sizeof(int))
 			{
 				variant.mData = 0;
@@ -176,7 +175,7 @@ namespace System
 			}
 			else
 			{
-				variant.mStructType = (int)(void*)type;
+				variant.mStructType = (int)Internal.UnsafeCastToPtr(type);
 				if (type.Size <= sizeof(int))
 				{
 					variant.mData = 0;
@@ -304,18 +303,57 @@ namespace System
 			v1.Get<T>() == v2.Get<T>()
 		}
 
-		public static Result<Variant> CreateFromVariant(Variant varFrom, bool reference = true)
+		public static Result<Variant> CreateFromVariant(Variant varFrom)
 		{
 			Variant varTo = varFrom;
 			if (varTo.mStructType == 1)
 				varTo.mStructType = 0;
+			if (varTo.mStructType > 2)
+			{
+				let type = (Type)Internal.UnsafeCastToObject((void*)varFrom.mStructType);
+				if (type.[Friend]mSize > sizeof(int))
+				{
+					void* data = new uint8[type.[Friend]mSize]*;
+					Internal.MemCpy(data, (void*)varFrom.mData, type.[Friend]mSize);
+					varTo.mData = (int)data;
+				}
+			}
+
 			return varTo;
 		}
 
-		/*public static Result<Variant> CreateFromObject(Object objectFrom, bool reference = true)
+		public static Result<Variant> CreateFromBoxed(Object objectFrom)
 		{
+			if (objectFrom == null)
+				return default;
+
+			Variant variant = ?;
 			Type objType = objectFrom.[Friend]RawGetType();
-			
-		}*/
+			if (objType.IsBoxed)
+			{
+				void* srcDataPtr = (uint8*)Internal.UnsafeCastToPtr(objectFrom) + objType.[Friend]mMemberDataOffset;
+
+				var underlying = objType.UnderlyingType;
+				variant.mStructType = (int)Internal.UnsafeCastToPtr(underlying);
+				if (underlying.Size <= sizeof(int))
+				{
+					variant.mData = 0;
+					*(int*)&variant.mData = *(int*)srcDataPtr;
+				}
+				else
+				{
+					void* data = new uint8[underlying.[Friend]mSize]*;
+					Internal.MemCpy(data, srcDataPtr, underlying.[Friend]mSize);
+					variant.mData = (int)data;
+				}
+			}
+			else
+			{
+				variant.mStructType = 0;
+				variant.mData = (int)Internal.UnsafeCastToPtr(objectFrom);
+			}
+
+			return variant;
+		}
 	}
 }

+ 1 - 1
IDE/mintest/minlib/src/System/Reflection/FieldInfo.bf

@@ -246,7 +246,7 @@ namespace System.Reflection
 			void* targetDataAddr = (void*)(int)mFieldData.mConstValue;
 
 			Type fieldType = Type.[Friend]GetType(mFieldData.mFieldTypeId);
-			value.[Friend]mStructType = (int)(void*)fieldType;
+			value.[Friend]mStructType = (int)Internal.UnsafeCastToPtr(fieldType);
 
 			TypeCode typeCode = fieldType.[Friend]mTypeCode;
 			if (typeCode == TypeCode.Enum)

+ 87 - 14
IDE/mintest/minlib/src/System/Variant.bf

@@ -49,6 +49,26 @@ namespace System
 			}
 		}
 
+		public void* DataPtr
+		{
+			get mut
+			{
+				if (IsObject)
+				{
+					if (mStructType == 2)
+						return null;
+					Object obj = Internal.UnsafeCastToObject((void*)mData);
+					return (uint8*)Internal.UnsafeCastToPtr(obj) + obj.GetType().[Friend]mMemberDataOffset;
+				}
+
+				var type = VariantType;
+				if (type.Size <= sizeof(int))
+					return (void*)&mData;
+				else
+					return (void*)mData;
+			}
+		}
+
 		protected override void GCMarkMembers()
  		{
 			if ((mStructType == 1) || (mStructType == 0))
@@ -78,12 +98,12 @@ namespace System
 			if (val == null)
 			{
 				variant.mStructType = 2;
-				variant.mData = (int)(void*)typeof(T);
+				variant.mData = (int)Internal.UnsafeCastToPtr(typeof(T));
 			}
 			else
 			{
 				variant.mStructType = (int)(owns ? 1 : 0);
-				variant.mData = (int)(void*)val;
+				variant.mData = (int)Internal.UnsafeCastToPtr(val);
 			}
 			return variant;
 		}
@@ -92,7 +112,7 @@ namespace System
 		{
 			Variant variant;
 			Type type = typeof(T);
-			variant.mStructType = (int)(void*)type;
+			variant.mStructType = (int)Internal.UnsafeCastToPtr(type);
 			if (sizeof(T) <= sizeof(int))
 			{
 				variant.mData = 0;
@@ -111,7 +131,7 @@ namespace System
 		{
 			Variant variant;
 			Type type = typeof(T);
-			variant.mStructType = (int)(void*)type;
+			variant.mStructType = (int)Internal.UnsafeCastToPtr(type);
 			if (type.Size <= sizeof(int))
 			{
 				variant.mData = 0;
@@ -130,8 +150,7 @@ namespace System
 		{
 			Variant variant;
 			Debug.Assert(!type.IsObject);
-			//Debug.Assert((type.GetUnderlyingType() == typeof(T)) || (type == typeof(T)));
-			variant.mStructType = (int)(void*)type;
+			variant.mStructType = (int)Internal.UnsafeCastToPtr(type);
 			if (type.Size <= sizeof(int))
 			{
 				variant.mData = 0;
@@ -146,27 +165,27 @@ namespace System
 			return variant;
 		}
 
-		public static void* Alloc(Type type, out Variant variant)
+		public static (Variant, void*) Alloc(Type type)
 		{
-			variant = .();
+			Variant variant = .();
 
 			if (type.IsObject)
 			{
-				return &variant.mData;
+				return (variant, &variant.mData);
 			}
 			else
 			{
-				variant.mStructType = (int)(void*)type;
+				variant.mStructType = (int)Internal.UnsafeCastToPtr(type);
 				if (type.Size <= sizeof(int))
 				{
 					variant.mData = 0;
-					return &variant.mData;
+					return (variant, &variant.mData);
 				}
 				else
 				{
 					void* data = new uint8[type.[Friend]mSize]*;
 					variant.mData = (int)data;
-					return data;
+					return (variant, data);
 				}
 			}
 		}
@@ -222,8 +241,9 @@ namespace System
 			if (IsObject)
 			{
 				if (mStructType == 2)
-					*((Object*)dest) =null;
-				*((Object*)dest) = Internal.UnsafeCastToObject((void*)mData);
+					*((Object*)dest) = null;
+				else
+					*((Object*)dest) = Internal.UnsafeCastToObject((void*)mData);
 				return;
 			}
 			
@@ -282,5 +302,58 @@ namespace System
 		{
 			v1.Get<T>() == v2.Get<T>()
 		}
+
+		public static Result<Variant> CreateFromVariant(Variant varFrom)
+		{
+			Variant varTo = varFrom;
+			if (varTo.mStructType == 1)
+				varTo.mStructType = 0;
+			if (varTo.mStructType > 2)
+			{
+				let type = (Type)Internal.UnsafeCastToObject((void*)varFrom.mStructType);
+				if (type.[Friend]mSize > sizeof(int))
+				{
+					void* data = new uint8[type.[Friend]mSize]*;
+					Internal.MemCpy(data, (void*)varFrom.mData, type.[Friend]mSize);
+					varTo.mData = (int)data;
+				}
+			}
+
+			return varTo;
+		}
+
+		public static Result<Variant> CreateFromBoxed(Object objectFrom)
+		{
+			if (objectFrom == null)
+				return default;
+
+			Variant variant = ?;
+			Type objType = objectFrom.[Friend]RawGetType();
+			if (objType.IsBoxed)
+			{
+				void* srcDataPtr = (uint8*)Internal.UnsafeCastToPtr(objectFrom) + objType.[Friend]mMemberDataOffset;
+
+				var underlying = objType.UnderlyingType;
+				variant.mStructType = (int)Internal.UnsafeCastToPtr(underlying);
+				if (underlying.Size <= sizeof(int))
+				{
+					variant.mData = 0;
+					*(int*)&variant.mData = *(int*)srcDataPtr;
+				}
+				else
+				{
+					void* data = new uint8[underlying.[Friend]mSize]*;
+					Internal.MemCpy(data, srcDataPtr, underlying.[Friend]mSize);
+					variant.mData = (int)data;
+				}
+			}
+			else
+			{
+				variant.mStructType = 0;
+				variant.mData = (int)Internal.UnsafeCastToPtr(objectFrom);
+			}
+
+			return variant;
+		}
 	}
 }

+ 3 - 22
IDEHelper/Compiler/BfExprEvaluator.cpp

@@ -5162,15 +5162,11 @@ void BfExprEvaluator::PushArg(BfTypedValue argVal, SizedArrayImpl<BfIRValue>& ir
 		return;
 
 	bool wantSplat = false;
-	if (argVal.mType->IsSplattable())
+	if ((argVal.mType->IsSplattable()) && (!disableSplat))
 	{
+		disableLowering = true;
 		auto argTypeInstance = argVal.mType->ToTypeInstance();
-		if ((argTypeInstance != NULL) && (argTypeInstance->mIsCRepr))
-		{
-			// Always splat for crepr splattables
-			wantSplat = true;
-		}
-		else if ((!disableSplat) && (int)irArgs.size() + argVal.mType->GetSplatCount() <= mModule->mCompiler->mOptions.mMaxSplatRegs)
+		if ((!disableSplat) && (int)irArgs.size() + argVal.mType->GetSplatCount() <= mModule->mCompiler->mOptions.mMaxSplatRegs)
 			wantSplat = true;
 	}
 
@@ -5190,21 +5186,6 @@ void BfExprEvaluator::PushArg(BfTypedValue argVal, SizedArrayImpl<BfIRValue>& ir
 				BfTypeCode loweredTypeCode2 = BfTypeCode_None;				
 				if (argVal.mType->GetLoweredType(BfTypeUsage_Parameter, &loweredTypeCode, &loweredTypeCode2))
 				{
-// 					auto primType = mModule->GetPrimitiveType(loweredTypeCode);
-// 					auto ptrType = mModule->CreatePointerType(primType);
-// 					BfIRValue primPtrVal = mModule->mBfIRBuilder->CreateBitCast(argVal.mValue, mModule->mBfIRBuilder->MapType(ptrType));
-// 					auto primVal = mModule->mBfIRBuilder->CreateLoad(primPtrVal);
-// 					irArgs.push_back(primVal);
-// 
-// 					if (loweredTypeCode2 != BfTypeCode_None)
-// 					{
-// 						auto primType2 = mModule->GetPrimitiveType(loweredTypeCode2);
-// 						auto ptrType2 = mModule->CreatePointerType(primType2);
-// 						BfIRValue primPtrVal2 = mModule->mBfIRBuilder->CreateBitCast(mModule->mBfIRBuilder->CreateInBoundsGEP(primPtrVal, 1), mModule->mBfIRBuilder->MapType(ptrType2));
-// 						auto primVal2 = mModule->mBfIRBuilder->CreateLoad(primPtrVal2);
-// 						irArgs.push_back(primVal2);
-// 					}
-
 					auto primType = mModule->mBfIRBuilder->GetPrimitiveType(loweredTypeCode);
 					auto ptrType = mModule->mBfIRBuilder->GetPointerTo(primType);
 					BfIRValue primPtrVal = mModule->mBfIRBuilder->CreateBitCast(argVal.mValue, ptrType);

+ 15 - 13
IDEHelper/Compiler/BfModule.cpp

@@ -15579,13 +15579,12 @@ void BfModule::ProcessMethod_SetupParams(BfMethodInstance* methodInstance, BfTyp
 				BfTypeUtils::SplatIterate([&](BfType* checkType) { argIdx++; }, paramVar->mResolvedType);
 			}
 			else
-			{
+			{								
 				argIdx++;
+				if (loweredTypeCode2 != BfTypeCode_None)
+					argIdx++;
 			}
-		}
-
-		if (loweredTypeCode2 != BfTypeCode_None)
-			argIdx++;
+		}		
 	}
 
 	if (argIdx == methodInstance->GetStructRetIdx())
@@ -15630,13 +15629,7 @@ void BfModule::ProcessMethod_SetupParams(BfMethodInstance* methodInstance, BfTyp
 			}
 			else if (resolvedType->IsComposite() && resolvedType->IsSplattable())
 			{
-				auto resolvedTypeInst = resolvedType->ToTypeInstance();
-				if ((resolvedTypeInst != NULL) && (resolvedTypeInst->mIsCRepr))
-				{
-					// crepr splat is always splat
-					paramVar->mIsSplat = true;
-				}
-				else if (methodInstance->AllowsSplatting())
+				if (methodInstance->AllowsSplatting())
 				{
 					int splatCount = resolvedType->GetSplatCount();
 					if (argIdx + splatCount <= mCompiler->mOptions.mMaxSplatRegs)
@@ -20083,10 +20076,19 @@ void BfModule::DoMethodDeclaration(BfMethodDeclaration* methodDeclaration, bool
 	PopulateType(methodInstance->mReturnType, BfPopulateType_Data);	
 	if (!methodDef->mIsStatic)
     {
+		auto thisType = methodInstance->GetOwner();
 		if (methodInstance->GetParamIsSplat(-1))
 			argIdx += methodInstance->GetParamType(-1)->GetSplatCount();
-		else if (!methodInstance->GetOwner()->IsValuelessType())
+		else if (!thisType->IsValuelessType())
+		{
+			BfTypeCode loweredTypeCode = BfTypeCode_None;
+			BfTypeCode loweredTypeCode2 = BfTypeCode_None;
+			if (!methodDef->mIsMutating)
+				thisType->GetLoweredType(BfTypeUsage_Parameter, &loweredTypeCode, &loweredTypeCode2);
 			argIdx++;
+			if (loweredTypeCode2 != BfTypeCode_None)
+				argIdx++;
+		}
 	}
 
 	if (methodInstance->GetStructRetIdx() != -1)

+ 0 - 6
IDEHelper/Compiler/BfModuleTypeUtils.cpp

@@ -8881,12 +8881,6 @@ BfIRValue BfModule::CastToValue(BfAstNode* srcNode, BfTypedValue typedVal, BfTyp
 
 	if (explicitCast)
 	{
-		// Object -> void*
-		if ((typedVal.mType->IsObject()) && (toType->IsVoidPtr()))
-		{
-			return mBfIRBuilder->CreateBitCast(typedVal.mValue, mBfIRBuilder->MapType(toType));
-		}
-
 		// Func -> void*
 		if ((typedVal.mType->IsFunction()) && (toType->IsVoidPtr()))
 		{

+ 8 - 4
IDEHelper/Compiler/BfReducer.cpp

@@ -2183,7 +2183,8 @@ BfExpression* BfReducer::CreateExpression(BfAstNode* node, CreateExprFlags creat
 
 					if (auto condExpr = BfNodeDynCast<BfConditionalExpression>(unaryOpExpr->mExpression))
 					{
-						exprLeft = condExpr;
+						if (exprLeft == unaryOpExpr)
+							exprLeft = condExpr;
 						ApplyToFirstExpression(unaryOpExpr, condExpr);
 					}
 
@@ -2197,7 +2198,8 @@ BfExpression* BfReducer::CreateExpression(BfAstNode* node, CreateExprFlags creat
 							MEMBER_SET(unaryOpExpr, mExpression, assignmentExpr->mLeft);
 							unaryOpExpr->SetSrcEnd(assignmentExpr->mLeft->GetSrcEnd());
 							MEMBER_SET(assignmentExpr, mLeft, unaryOpExpr);
-							exprLeft = assignmentExpr;
+							if (exprLeft == unaryOpExpr)
+								exprLeft = assignmentExpr;
 						}
 					}
 
@@ -2211,7 +2213,8 @@ BfExpression* BfReducer::CreateExpression(BfAstNode* node, CreateExprFlags creat
 							MEMBER_SET(unaryOpExpr, mExpression, dynCastExpr->mTarget);
 							unaryOpExpr->SetSrcEnd(dynCastExpr->mTarget->GetSrcEnd());
 							MEMBER_SET(dynCastExpr, mTarget, unaryOpExpr);
-							exprLeft = dynCastExpr;
+							if (exprLeft == unaryOpExpr)
+								exprLeft = dynCastExpr;
 						}
 					}
 
@@ -2225,7 +2228,8 @@ BfExpression* BfReducer::CreateExpression(BfAstNode* node, CreateExprFlags creat
 							MEMBER_SET(unaryOpExpr, mExpression, caseExpr->mValueExpression);
 							unaryOpExpr->SetSrcEnd(caseExpr->mValueExpression->GetSrcEnd());
 							MEMBER_SET(caseExpr, mValueExpression, unaryOpExpr);
-							exprLeft = caseExpr;
+							if (exprLeft == unaryOpExpr)
+								exprLeft = caseExpr;
 						}
 					}
 				}

+ 1 - 1
IDEHelper/Compiler/BfResolvedTypeUtils.cpp

@@ -924,7 +924,7 @@ int BfMethodInstance::DbgGetVirtualMethodNum()
 void BfMethodInstance::GetIRFunctionInfo(BfModule* module, BfIRType& returnType, SizedArrayImpl<BfIRType>& paramTypes, bool forceStatic)
 {
 	module->PopulateType(mReturnType);
-	
+
 	BfTypeCode loweredReturnTypeCode = BfTypeCode_None;
 	BfTypeCode loweredReturnTypeCode2 = BfTypeCode_None;	
 	if (GetLoweredReturnType(&loweredReturnTypeCode, &loweredReturnTypeCode2))