Przeglądaj źródła

Fixed comptime reflected static field accesses

Brian Fiete 1 rok temu
rodzic
commit
a356186514

+ 8 - 0
BeefLibs/corlib/src/Reflection/FieldInfo.bf

@@ -327,6 +327,14 @@ namespace System.Reflection
 				if (!mFieldData.mFlags.HasFlag(FieldFlags.Static))
 					return .Err(.InvalidTargetType);
 
+				if (Compiler.IsComptime)
+				{
+					void* dataPtr = Type.[Friend]Comptime_Field_GetStatic((int32)mTypeInstance.TypeId, (int32)mFieldData.mData);
+					if (dataPtr != null)
+						value = *(TMember*)dataPtr;
+					return .Ok;
+				}
+
 				targetDataAddr = null;
 			}
 			else

+ 1 - 0
BeefLibs/corlib/src/Type.bf

@@ -567,6 +567,7 @@ namespace System
 		static extern Type Comptime_Method_GetGenericArg(int64 methodHandle, int32 genericArgIdx);
 		static extern String Comptime_Field_GetName(int64 fieldHandle);
 		static extern ComptimeFieldInfo Comptime_Field_GetInfo(int64 fieldHandle);
+		static extern void* Comptime_Field_GetStatic(int32 typeId, int32 fieldIdx);
 
         protected static Type GetType(TypeId typeId)
         {

+ 8 - 1
IDEHelper/Compiler/BfModule.cpp

@@ -5859,8 +5859,15 @@ BfIRValue BfModule::CreateFieldData(BfFieldInstance* fieldInstance, int customAt
 	else if (fieldInstance->GetFieldDef()->mIsStatic)
 	{
 		BfTypedValue refVal;
-		if (!mIsComptimeModule) // This can create circular reference issues for a `Self` static
+		if (mIsComptimeModule)
+		{
+			constValue = mBfIRBuilder->CreateConst(BfTypeCode_IntPtr, fieldInstance->mFieldIdx);
+		}
+		else
+		{
 			refVal = ReferenceStaticField(fieldInstance);
+		}
+
 		if (refVal.mValue.IsConst())
 		{
 			auto constant = mBfIRBuilder->GetConstant(refVal.mValue);

+ 12 - 0
IDEHelper/Compiler/BfModuleTypeUtils.cpp

@@ -2140,6 +2140,15 @@ BfCEParseContext BfModule::CEEmitParse(BfTypeInstance* typeInstance, BfTypeDef*
 	ceParseContext.mFailIdx = mCompiler->mPassInstance->mFailedIdx;
 	ceParseContext.mWarnIdx = mCompiler->mPassInstance->mWarnIdx;
 
+	if (typeInstance->mTypeDef->mEmitParent == NULL)
+	{
+		if (typeInstance->mTypeDef->mNextRevision != NULL)
+		{
+			InternalError("CEEmitParse preconditions failed");
+			return ceParseContext;
+		}
+	}
+
 	bool createdParser = false;
 	int startSrcIdx = 0;
 
@@ -5203,6 +5212,9 @@ void BfModule::DoPopulateType(BfType* resolvedTypeRef, BfPopulateType populateTy
 
 			if (hadNewMembers)
 			{
+				// Avoid getting stale cached comptime reflection info
+				mCompiler->mCeMachine->mCeModule->mTypeDataRefs.Remove(resolvedTypeRef);
+
 				// We need to avoid passing in BfPopulateType_Interfaces_All because it could cause us to miss out on new member processing,
 				//  including resizing the method group table
 				DoPopulateType(resolvedTypeRef, BF_MAX(populateType, BfPopulateType_Data));

+ 127 - 15
IDEHelper/Compiler/CeMachine.cpp

@@ -7,6 +7,7 @@
 #include "BfReducer.h"
 #include "BfExprEvaluator.h"
 #include "BfResolvePass.h"
+#include "BfMangler.h"
 #include "../Backend/BeIRCodeGen.h"
 #include "BeefySysLib/platform/PlatformHelper.h"
 #include "../DebugManager.h"
@@ -5883,7 +5884,7 @@ bool CeContext::Execute(CeFunction* startFunction, uint8* startStackPtr, uint8*
 		return false;
 	};
 
-	auto _CheckFunction = [&](CeFunction* checkFunction, bool& handled)
+	std::function<bool(CeFunction* checkFunction, bool& handled)> _CheckFunction = [&](CeFunction* checkFunction, bool& handled)
 	{
 		if (checkFunction == NULL)
 		{
@@ -6326,6 +6327,106 @@ bool CeContext::Execute(CeFunction* startFunction, uint8* startStackPtr, uint8*
 				_FixVariables();
 				CeSetAddrVal(stackPtr + 0, reflectType, ptrSize);
 			}
+			else if (checkFunction->mFunctionKind == CeFunctionKind_Field_GetStatic)
+			{
+				int32 typeId = *(int32*)((uint8*)stackPtr + ptrSize);
+				int32 fieldIdx = *(int32*)((uint8*)stackPtr + ptrSize + 4);
+
+				CeFunction* ctorCallFunction = NULL;
+
+				BfType* bfType = GetBfType(typeId);
+				bool success = false;
+				if (bfType != NULL)
+				{
+					auto typeInst = bfType->ToTypeInstance();
+					if (typeInst != NULL)
+					{
+						if (typeInst->mDefineState < BfTypeDefineState_CETypeInit)
+							mCurModule->PopulateType(typeInst);
+						if ((fieldIdx >= 0) && (fieldIdx < typeInst->mFieldInstances.mSize))
+						{
+							auto& fieldInstance = typeInst->mFieldInstances[fieldIdx];
+
+							auto fieldType = fieldInstance.mResolvedType;
+							ceModule->PopulateType(fieldType, BfPopulateType_Full_Force);
+
+							int64 fieldId = ((int64)typeId << 32) | fieldIdx;
+
+							CeStaticFieldInfo* staticFieldInfo = NULL;
+							if (mStaticFieldIdMap.TryAdd(fieldId, NULL, &staticFieldInfo))
+							{
+								if (mStaticCtorExecSet.TryAdd(typeId, NULL))
+								{
+									BfTypeInstance* bfTypeInstance = NULL;
+									if (bfType != NULL)
+										bfTypeInstance = bfType->ToTypeInstance();
+									if (bfTypeInstance == NULL)
+									{
+										_Fail("Invalid type");
+										return false;
+									}
+
+									auto methodDef = bfTypeInstance->mTypeDef->GetMethodByName("__BfStaticCtor");
+									if (methodDef == NULL)
+									{
+										_Fail("No static ctor found");
+										return false;
+									}
+
+									auto moduleMethodInstance = ceModule->GetMethodInstance(bfTypeInstance, methodDef, BfTypeVector());
+									if (!moduleMethodInstance)
+									{
+										_Fail("No static ctor instance found");
+										return false;
+									}
+
+									bool added = false;
+									ctorCallFunction = mCeMachine->GetFunction(moduleMethodInstance.mMethodInstance, moduleMethodInstance.mFunc, added);
+									if (ctorCallFunction->mInitializeState < CeFunction::InitializeState_Initialized)
+										mCeMachine->PrepareFunction(ctorCallFunction, NULL);
+								}
+
+								_FixVariables();
+
+								StringT<4096> staticVarName;
+								BfMangler::Mangle(staticVarName, ceModule->mCompiler->GetMangleKind(), &fieldInstance);
+
+								CeStaticFieldInfo* nameStaticFieldInfo = NULL;
+								mStaticFieldMap.TryAdd(staticVarName, NULL, &nameStaticFieldInfo);
+
+								if (nameStaticFieldInfo->mAddr == 0)
+								{
+									int fieldSize = fieldInstance.mResolvedType->mSize;
+									CE_CHECKALLOC(fieldSize);
+									uint8* ptr = CeMalloc(fieldSize);
+									_FixVariables();
+									if (fieldSize > 0)
+										memset(ptr, 0, fieldSize);
+									nameStaticFieldInfo->mAddr = (addr_ce)(ptr - memStart);
+								}
+
+								staticFieldInfo->mAddr = nameStaticFieldInfo->mAddr;
+							}
+
+							CeSetAddrVal(stackPtr + 0, staticFieldInfo->mAddr, ptrSize);
+						}
+						else if (fieldIdx != -1)
+						{
+							_Fail("Invalid field");
+							return false;
+						}
+					}
+				}
+
+				if (ctorCallFunction != NULL)
+				{
+					bool handled = false;
+					if (!_CheckFunction(ctorCallFunction, handled))
+						return false;
+					if (!handled)
+						CE_CALL(ctorCallFunction);
+				}
+			}
 			else if (checkFunction->mFunctionKind == CeFunctionKind_SetReturnType)
 			{
 				int32 typeId = *(int32*)((uint8*)stackPtr);
@@ -7920,24 +8021,30 @@ bool CeContext::Execute(CeFunction* startFunction, uint8* startStackPtr, uint8*
 						return false;
 					}
 
-					auto methodDef = bfTypeInstance->mTypeDef->GetMethodByName("__BfStaticCtor");
-					if (methodDef == NULL)
+					if (bfType->mDefineState == BfTypeDefineState_CETypeInit)
 					{
-						_Fail("No static ctor found");
-						return false;
+						// Don't create circular references
 					}
-
-					auto moduleMethodInstance = ceModule->GetMethodInstance(bfTypeInstance, methodDef, BfTypeVector());
-					if (!moduleMethodInstance)
+					else
 					{
-						_Fail("No static ctor instance found");
-						return false;
-					}
+						auto methodDef = bfTypeInstance->mTypeDef->GetMethodByName("__BfStaticCtor");
+						if (methodDef != NULL)
+						{
+							auto moduleMethodInstance = ceModule->GetMethodInstance(bfTypeInstance, methodDef, BfTypeVector());
+							if (!moduleMethodInstance)
+							{
+								_Fail("No static ctor instance found");
+								return false;
+							}
+
+							ceModule->PopulateType(bfTypeInstance, BfPopulateType_DataAndMethods);
 
-					bool added = false;
-					ctorCallFunction = mCeMachine->GetFunction(moduleMethodInstance.mMethodInstance, moduleMethodInstance.mFunc, added);
-					if (ctorCallFunction->mInitializeState < CeFunction::InitializeState_Initialized)
-						mCeMachine->PrepareFunction(ctorCallFunction, NULL);
+							bool added = false;
+							ctorCallFunction = mCeMachine->GetFunction(moduleMethodInstance.mMethodInstance, moduleMethodInstance.mFunc, added);
+							if (ctorCallFunction->mInitializeState < CeFunction::InitializeState_Initialized)
+								mCeMachine->PrepareFunction(ctorCallFunction, NULL);
+						}
+					}
 				}
 
 				CeStaticFieldInfo* staticFieldInfo = NULL;
@@ -9549,6 +9656,10 @@ void CeMachine::CheckFunctionKind(CeFunction* ceFunction)
 				{
 					ceFunction->mFunctionKind = CeFunctionKind_Method_GetGenericArg;
 				}
+				else if (methodDef->mName == "Comptime_Field_GetStatic")
+				{
+					ceFunction->mFunctionKind = CeFunctionKind_Field_GetStatic;
+				}
 			}
 			else if (owner->IsInstanceOf(mCeModule->mCompiler->mCompilerTypeDef))
 			{
@@ -10089,6 +10200,7 @@ void CeMachine::ReleaseContext(CeContext* ceContext)
 		ceContext->mMemory.Dispose();
 	ceContext->mStaticCtorExecSet.Clear();
 	ceContext->mStaticFieldMap.Clear();
+	ceContext->mStaticFieldIdMap.Clear();
 	ceContext->mHeap->Clear(BF_CE_MAX_CARRYOVER_HEAP);
 	ceContext->mReflectTypeIdOffset = -1;
 	mCurEmitContext = ceContext->mCurEmitContext;

+ 2 - 0
IDEHelper/Compiler/CeMachine.h

@@ -447,6 +447,7 @@ enum CeFunctionKind
 	CeFunctionKind_Method_GetInfo,
 	CeFunctionKind_Method_GetParamInfo,
 	CeFunctionKind_Method_GetGenericArg,
+	CeFunctionKind_Field_GetStatic,
 
 	CeFunctionKind_SetReturnType,
 	CeFunctionKind_Align,
@@ -1106,6 +1107,7 @@ public:
 	Dictionary<Val128, addr_ce> mConstDataMap;
 	HashSet<int> mStaticCtorExecSet;
 	Dictionary<String, CeStaticFieldInfo> mStaticFieldMap;
+	Dictionary<int64, CeStaticFieldInfo> mStaticFieldIdMap;
 	Dictionary<int, CeInternalData*> mInternalDataMap;
 	int mCurHandleId;
 

+ 19 - 0
IDEHelper/Tests/src/Comptime.bf

@@ -481,6 +481,21 @@ namespace Tests
 			}
 		}
 
+		class ClassB
+		{
+			public static int mA = 123;
+		}
+
+		class ClassC
+		{
+			[OnCompile(.TypeInit), Comptime]
+			static void Init()
+			{
+				typeof(ClassB).GetField("mA").Value.GetValue<int>(null, var value);
+				Compiler.EmitTypeBody(typeof(Self), scope $"public static int sA = {1000 + value};");
+			}
+		}
+
 		[Test]
 		public static void TestBasics()
 		{
@@ -499,6 +514,8 @@ namespace Tests
 			Test.Assert(sa.mC == 345);
 			Test.Assert(sa.GetValC() == 345);
 
+			Test.Assert(ClassC.sA == 1123);
+
 			Compiler.Mixin("int val = 99;");
 			Test.Assert(val == 99);
 
@@ -532,6 +549,8 @@ namespace Tests
 			Test.Assert(typeof(decltype(f)) == typeof(float));
 			Test.Assert(ClassB<const 3>.cTimesTen == 30);
 
+
+
 			DictWrapper<Dictionary<int, float>> dictWrap = scope .();
 			dictWrap.[Friend]mValue.Add(1, 2.3f);
 			dictWrap.[Friend]mValue.Add(2, 3.4f);