Jelajahi Sumber

Fixed deferred function call

Brian Fiete 5 bulan lalu
induk
melakukan
be0733d37c

+ 20 - 11
IDEHelper/Compiler/BfExprEvaluator.cpp

@@ -3356,8 +3356,7 @@ BfExprEvaluator::BfExprEvaluator(BfModule* module)
 	mExpectingType = NULL;
 	mFunctionBindResult = NULL;
 	mExplicitCast = false;
-	mDeferCallRef = NULL;
-	mDeferScopeAlloc = NULL;
+	mDeferCallData = NULL;	
 	mPrefixedAttributeState = NULL;
 	mResolveGenericParam = true;
 	mNoBind = false;
@@ -6938,7 +6937,7 @@ BfTypedValue BfExprEvaluator::CreateCall(BfAstNode* targetSrc, BfMethodInstance*
 
 	if (methodInstance->mVirtualTableIdx != -1)
 	{
-		if ((!bypassVirtual) && (mDeferCallRef == NULL))
+		if ((!bypassVirtual) && (mDeferCallData == NULL))
 		{
 			if ((methodDef->mIsOverride) && (mModule->mCurMethodInstance->mIsReified))
 			{
@@ -7107,9 +7106,12 @@ BfTypedValue BfExprEvaluator::CreateCall(BfAstNode* targetSrc, BfMethodInstance*
 		return _GetDefaultReturnValue();
 	}
 
-	if (mDeferCallRef != NULL)
+	if (mDeferCallData != NULL)
 	{
-		mModule->AddDeferredCall(BfModuleMethodInstance(methodInstance, func), irArgs, mDeferScopeAlloc, mDeferCallRef, bypassVirtual);
+		if (mDeferCallData->mFuncAlloca_Orig == func)
+			mModule->AddDeferredCall(BfModuleMethodInstance(methodInstance, mDeferCallData->mFuncAlloca), irArgs, mDeferCallData->mScopeAlloc, mDeferCallData->mRefNode, bypassVirtual, false, true);			
+		else
+			mModule->AddDeferredCall(BfModuleMethodInstance(methodInstance, func), irArgs, mDeferCallData->mScopeAlloc, mDeferCallData->mRefNode, bypassVirtual);
 		return mModule->GetFakeTypedValue(returnType);
 	}
 
@@ -7861,6 +7863,13 @@ BfTypedValue BfExprEvaluator::CreateCall(BfAstNode* targetSrc, const BfTypedValu
 			auto funcType = mModule->mBfIRBuilder->MapMethod(moduleMethodInstance.mMethodInstance);
 			auto funcPtrType = mModule->mBfIRBuilder->GetPointerTo(funcType);
 			moduleMethodInstance.mFunc = mModule->mBfIRBuilder->CreateIntToPtr(target.mValue, funcPtrType);
+
+			if (mDeferCallData != NULL)
+			{
+				mDeferCallData->mFuncAlloca_Orig = moduleMethodInstance.mFunc;
+				mDeferCallData->mFuncAlloca = mModule->CreateAlloca(funcPtrType, target.mType->mAlign, false, "FuncAlloca");
+				mModule->mBfIRBuilder->CreateStore(mDeferCallData->mFuncAlloca_Orig, mDeferCallData->mFuncAlloca);
+			}
 		}
 		else if (!methodDef->mIsStatic)
 		{
@@ -8070,7 +8079,7 @@ BfTypedValue BfExprEvaluator::CreateCall(BfAstNode* targetSrc, const BfTypedValu
 				autoComplete->mIsCapturingMethodMatchInfo = wasCapturingMatchInfo;
 		});
 
-	BfScopeData* boxScopeData = mDeferScopeAlloc;
+	BfScopeData* boxScopeData = (mDeferCallData != NULL) ? mDeferCallData->mScopeAlloc : NULL;
 	if ((boxScopeData == NULL) && (mModule->mCurMethodState != NULL))
 		boxScopeData = mModule->mCurMethodState->mCurScope;
 
@@ -9254,9 +9263,9 @@ BfTypedValue BfExprEvaluator::ResolveArgValue(BfResolvedArg& resolvedArg, BfType
 
 				if ((argValue) && (argValue.mType != wantType) && (wantType != NULL))
 				{
-					if ((mDeferScopeAlloc != NULL) && (wantType == mModule->mContext->mBfObjectType))
+					if ((mDeferCallData != NULL) && (wantType == mModule->mContext->mBfObjectType))
 					{
-						BfAllocTarget allocTarget(mDeferScopeAlloc);
+						BfAllocTarget allocTarget(mDeferCallData->mScopeAlloc);
 						argValue = mModule->BoxValue(expr, argValue, wantType, allocTarget, ((mBfEvalExprFlags & BfEvalExprFlags_Comptime) != 0) ? BfCastFlags_WantsConst : BfCastFlags_None);
 					}
 					else
@@ -17731,7 +17740,7 @@ void BfExprEvaluator::InjectMixin(BfAstNode* targetSrc, BfTypedValue target, boo
 	if (mModule->mCurMethodState == NULL)
 		return;
 
-	if (mDeferCallRef != NULL)
+	if (mDeferCallData != NULL)
 	{
 		mModule->Fail("Mixins cannot be directly deferred. Consider wrapping in a block.", targetSrc);
 	}
@@ -20013,7 +20022,7 @@ BfTypedValue BfExprEvaluator::GetResult(bool clearResult, bool resolveGenericTyp
 		if (!handled)
 		{
 			SetAndRestoreValue<BfFunctionBindResult*> prevFunctionBindResult(mFunctionBindResult, NULL);
-			SetAndRestoreValue<BfAstNode*> prevDeferCallRef(mDeferCallRef, NULL);
+			SetAndRestoreValue<BfDeferCallData*> prevDeferCallRef(mDeferCallData, NULL);
 
 			BfMethodDef* matchedMethod = GetPropertyMethodDef(mPropDef, BfMethodType_PropertyGetter, mPropCheckedKind, mPropTarget);
 			if (matchedMethod == NULL)
@@ -23198,7 +23207,7 @@ BfTypedValue BfExprEvaluator::PerformUnaryOperation_TryOperator(const BfTypedVal
 	else
 	{
 		SetAndRestoreValue<BfEvalExprFlags> prevFlags(mBfEvalExprFlags, (BfEvalExprFlags)(mBfEvalExprFlags | BfEvalExprFlags_NoAutoComplete));
-		SetAndRestoreValue<BfAstNode*> prevDeferCallRef(mDeferCallRef, NULL);
+		SetAndRestoreValue<BfDeferCallData*> prevDeferCallRef(mDeferCallData, NULL);
 		result = CreateCall(&methodMatcher, callTarget);
 	}
 

+ 10 - 3
IDEHelper/Compiler/BfExprEvaluator.h

@@ -401,6 +401,14 @@ enum BfBinOpFlags
 	BfBinOpFlag_DeferRight = 0x20
 };
 
+struct BfDeferCallData
+{
+	BfAstNode* mRefNode;
+	BfScopeData* mScopeAlloc;
+	BfIRValue mFuncAlloca; // When we need to load
+	BfIRValue mFuncAlloca_Orig;
+};
+
 class BfExprEvaluator : public BfStructuralVisitor
 {
 public:
@@ -422,9 +430,8 @@ public:
 	BfAttributeState* mPrefixedAttributeState;
 	BfTypedValue* mReceivingValue;
 	BfFunctionBindResult* mFunctionBindResult;
-	SizedArray<BfResolvedArg, 2> mIndexerValues;
-	BfAstNode* mDeferCallRef;
-	BfScopeData* mDeferScopeAlloc;
+	SizedArray<BfResolvedArg, 2> mIndexerValues;	
+	BfDeferCallData* mDeferCallData;	
 	bool mUsedAsStatement;
 	bool mPropDefBypassVirtual;
 	bool mExplicitCast;

+ 16 - 6
IDEHelper/Compiler/BfModule.cpp

@@ -2080,23 +2080,23 @@ void BfModule::RestoreScopeState()
 	mCurMethodState->mTailScope = mCurMethodState->mCurScope;
 }
 
-BfIRValue BfModule::CreateAlloca(BfType* type, bool addLifetime, const char* name, BfIRValue arraySize)
+BfIRValue BfModule::CreateAlloca(BfIRType irType, int align, bool addLifetime, const char* name, BfIRValue arraySize)
 {
 	if (mBfIRBuilder->mIgnoreWrites)
 		return mBfIRBuilder->GetFakeVal();
 
-	BF_ASSERT((*(int8*)&addLifetime == 1) || (*(int8*)&addLifetime == 0));
-	mBfIRBuilder->PopulateType(type);
+	BF_ASSERT((*(int8*)&addLifetime == 1) || (*(int8*)&addLifetime == 0));	
 	auto prevInsertBlock = mBfIRBuilder->GetInsertBlock();
 	if (!mBfIRBuilder->mIgnoreWrites)
 		BF_ASSERT(!prevInsertBlock.IsFake());
 	mBfIRBuilder->SetInsertPoint(mCurMethodState->mIRHeadBlock);
 	BfIRValue allocaInst;
 	if (arraySize)
-		allocaInst = mBfIRBuilder->CreateAlloca(mBfIRBuilder->MapType(type), arraySize);
+		allocaInst = mBfIRBuilder->CreateAlloca(irType, arraySize);
 	else
-		allocaInst = mBfIRBuilder->CreateAlloca(mBfIRBuilder->MapType(type));
-	mBfIRBuilder->SetAllocaAlignment(allocaInst, type->mAlign);
+		allocaInst = mBfIRBuilder->CreateAlloca(irType);
+	if (align > 0)
+		mBfIRBuilder->SetAllocaAlignment(allocaInst, align);
 	mBfIRBuilder->ClearDebugLocation(allocaInst);
 	if (name != NULL)
 		mBfIRBuilder->SetName(allocaInst, name);
@@ -2110,6 +2110,16 @@ BfIRValue BfModule::CreateAlloca(BfType* type, bool addLifetime, const char* nam
 	return allocaInst;
 }
 
+BfIRValue BfModule::CreateAlloca(BfType* type, bool addLifetime, const char* name, BfIRValue arraySize)
+{
+	if (mBfIRBuilder->mIgnoreWrites)
+		return mBfIRBuilder->GetFakeVal();
+
+	BF_ASSERT((*(int8*)&addLifetime == 1) || (*(int8*)&addLifetime == 0));
+	mBfIRBuilder->PopulateType(type);
+	return CreateAlloca(mBfIRBuilder->MapType(type), type->mAlign, addLifetime, name, arraySize);
+}
+
 BfIRValue BfModule::CreateAllocaInst(BfTypeInstance* typeInst, bool addLifetime, const char* name)
 {
 	if (mBfIRBuilder->mIgnoreWrites)

+ 4 - 1
IDEHelper/Compiler/BfModule.h

@@ -323,6 +323,7 @@ public:
 	bool mCastThis;
 	bool mArgsNeedLoad;
 	bool mIgnored;
+	bool mIsAllocaFunc;
 
 	SLIList<BfDeferredCallEntry*> mDynList;
 	BfIRValue mDynCallTail;
@@ -1481,6 +1482,7 @@ enum BfDeferredBlockFlags
 	BfDeferredBlockFlag_DoNullChecks = 2,
 	BfDeferredBlockFlag_SkipObjectAccessCheck = 4,
 	BfDeferredBlockFlag_MoveNewBlocksToEnd = 8,
+	BfDeferredBlockFlag_IsAllocaFunc = 0x10
 };
 
 enum BfGetCustomAttributesFlags
@@ -1696,6 +1698,7 @@ public:
 
 	BfTypedValue FlushNullConditional(BfTypedValue result, bool ignoreNullable = false);
 	void NewScopeState(bool createLexicalBlock = true, bool flushValueScope = true); // returns prev scope data
+	BfIRValue CreateAlloca(BfIRType irType, int align, bool addLifetime = true, const char* name = NULL, BfIRValue arraySize = BfIRValue());
 	BfIRValue CreateAlloca(BfType* type, bool addLifetime = true, const char* name = NULL, BfIRValue arraySize = BfIRValue());
 	BfIRValue CreateAllocaInst(BfTypeInstance* typeInst, bool addLifetime = true, const char* name = NULL);
 	BfDeferredCallEntry* AddStackAlloc(BfTypedValue val, BfIRValue arraySize, BfAstNode* refNode, BfScopeData* scope, bool condAlloca = false, bool mayEscape = false, BfIRBlock valBlock = BfIRBlock());
@@ -1721,7 +1724,7 @@ public:
 	void EmitDeferredCall(BfModuleMethodInstance moduleMethodInstance, SizedArrayImpl<BfIRValue>& llvmArgs, BfDeferredBlockFlags flags = BfDeferredBlockFlag_None);
 	bool AddDeferredCallEntry(BfDeferredCallEntry* deferredCallEntry, BfScopeData* scope);
 	BfDeferredCallEntry* AddDeferredBlock(BfBlock* block, BfScopeData* scope, Array<BfDeferredCapture>* captures = NULL);
-	BfDeferredCallEntry* AddDeferredCall(const BfModuleMethodInstance& moduleMethodInstance, SizedArrayImpl<BfIRValue>& llvmArgs, BfScopeData* scope, BfAstNode* srcNode = NULL, bool bypassVirtual = false, bool doNullCheck = false);
+	BfDeferredCallEntry* AddDeferredCall(const BfModuleMethodInstance& moduleMethodInstance, SizedArrayImpl<BfIRValue>& llvmArgs, BfScopeData* scope, BfAstNode* srcNode = NULL, bool bypassVirtual = false, bool doNullCheck = false, bool isAllocaFunc = false);
 	void EmitDeferredCall(BfScopeData* scopeData, BfDeferredCallEntry& deferredCallEntry, bool moveBlocks);
 	void EmitDeferredCallProcessor(BfScopeData* scopeData, SLIList<BfDeferredCallEntry*>& callEntries, BfIRValue callTail);
 	void EmitDeferredCallProcessorInstances(BfScopeData* scopeData);

+ 20 - 4
IDEHelper/Compiler/BfStmtEvaluator.cpp

@@ -543,11 +543,12 @@ BfDeferredCallEntry* BfModule::AddDeferredBlock(BfBlock* block, BfScopeData* sco
 	return deferredCallEntry;
 }
 
-BfDeferredCallEntry* BfModule::AddDeferredCall(const BfModuleMethodInstance& moduleMethodInstance, SizedArrayImpl<BfIRValue>& llvmArgs, BfScopeData* scopeData, BfAstNode* srcNode, bool bypassVirtual, bool doNullCheck)
+BfDeferredCallEntry* BfModule::AddDeferredCall(const BfModuleMethodInstance& moduleMethodInstance, SizedArrayImpl<BfIRValue>& llvmArgs, BfScopeData* scopeData, BfAstNode* srcNode, bool bypassVirtual, bool doNullCheck, bool isAllocaFunc)
 {
 	BfDeferredCallEntry* deferredCallEntry = new BfDeferredCallEntry();
 	BF_ASSERT(moduleMethodInstance);
 	deferredCallEntry->mModuleMethodInstance = moduleMethodInstance;
+	deferredCallEntry->mIsAllocaFunc = isAllocaFunc;
 
 	for (auto arg : llvmArgs)
 	{
@@ -783,7 +784,12 @@ void BfModule::EmitDeferredCall(BfModuleMethodInstance moduleMethodInstance, Siz
 	}
 
 	BfExprEvaluator expressionEvaluator(this);
-	expressionEvaluator.CreateCall(NULL, moduleMethodInstance.mMethodInstance, moduleMethodInstance.mFunc, ((flags & BfDeferredBlockFlag_BypassVirtual) != 0), llvmArgs);
+
+	auto func = moduleMethodInstance.mFunc;
+	if ((flags & BfDeferredBlockFlag_IsAllocaFunc) != 0)
+		func = mBfIRBuilder->CreateLoad(func);
+
+	expressionEvaluator.CreateCall(NULL, moduleMethodInstance.mMethodInstance, func, ((flags & BfDeferredBlockFlag_BypassVirtual) != 0), llvmArgs);
 
 	if ((flags & BfDeferredBlockFlag_DoNullChecks) != 0)
 	{
@@ -914,6 +920,8 @@ void BfModule::EmitDeferredCall(BfScopeData* scopeData, BfDeferredCallEntry& def
 		flags = (BfDeferredBlockFlags)(flags | BfDeferredBlockFlag_DoNullChecks | BfDeferredBlockFlag_SkipObjectAccessCheck | BfDeferredBlockFlag_MoveNewBlocksToEnd);
 	if (moveBlocks)
 		flags = (BfDeferredBlockFlags)(flags | BfDeferredBlockFlag_MoveNewBlocksToEnd);
+	if (deferredCallEntry.mIsAllocaFunc)
+		flags = (BfDeferredBlockFlags)(flags | BfDeferredBlockFlag_IsAllocaFunc);
 
 	EmitDeferredCall(deferredCallEntry.mModuleMethodInstance, args, flags);
 }
@@ -926,6 +934,7 @@ void BfModule::EmitDeferredCallProcessor(BfScopeData* scopeData, SLIList<BfDefer
 	{
 		BfModuleMethodInstance mModuleMethodInstance;
 		bool mBypassVirtual;
+		bool mIsAllocaFunc;
 	};
 
 	//typedef std::map<int64, _CallInfo> MapType;
@@ -951,6 +960,7 @@ void BfModule::EmitDeferredCallProcessor(BfScopeData* scopeData, SLIList<BfDefer
 			{
 				callInfo->mModuleMethodInstance = moduleMethodInstance;
 				callInfo->mBypassVirtual = deferredCallEntry->mBypassVirtual;
+				callInfo->mIsAllocaFunc = deferredCallEntry->mIsAllocaFunc;
 			}
 			else
 			{
@@ -1118,6 +1128,7 @@ void BfModule::EmitDeferredCallProcessor(BfScopeData* scopeData, SLIList<BfDefer
 	{
 		auto moduleMethodInstance = callInfoKV.mValue.mModuleMethodInstance;
 		bool bypassVirtual = callInfoKV.mValue.mBypassVirtual;
+		bool isAllocaFunc = callInfoKV.mValue.mIsAllocaFunc;
 		auto methodInstance = moduleMethodInstance.mMethodInstance;
 		auto methodDef = methodInstance->mMethodDef;
 		auto methodOwner = methodInstance->mMethodInstanceGroup->mOwner;
@@ -1204,6 +1215,8 @@ void BfModule::EmitDeferredCallProcessor(BfScopeData* scopeData, SLIList<BfDefer
 			flags = (BfDeferredBlockFlags)(flags | BfDeferredBlockFlag_DoNullChecks | BfDeferredBlockFlag_SkipObjectAccessCheck);
 		if (bypassVirtual)
 			flags = (BfDeferredBlockFlags)(flags | BfDeferredBlockFlag_BypassVirtual);
+		if (isAllocaFunc)
+			flags = (BfDeferredBlockFlags)(flags | BfDeferredBlockFlag_IsAllocaFunc);
 		EmitDeferredCall(moduleMethodInstance, llvmArgs, flags);
 		ValueScopeEnd(valueScopeStart);
 		mBfIRBuilder->CreateBr(condBB);
@@ -7336,9 +7349,12 @@ void BfModule::Visit(BfDeferStatement* deferStmt)
 	}
 	else if (auto exprStmt = BfNodeDynCast<BfExpressionStatement>(deferStmt->mTargetNode))
 	{
+		BfDeferCallData deferCallData;
+		deferCallData.mRefNode = exprStmt->mExpression;
+		deferCallData.mScopeAlloc = scope;
+
 		BfExprEvaluator expressionEvaluator(this);
-		expressionEvaluator.mDeferCallRef = exprStmt->mExpression;
-		expressionEvaluator.mDeferScopeAlloc = scope;
+		expressionEvaluator.mDeferCallData = &deferCallData;		
 		expressionEvaluator.VisitChild(exprStmt->mExpression);
 		if (mCurMethodState->mPendingNullConditional != NULL)
 			FlushNullConditional(expressionEvaluator.mResult, true);

+ 11 - 0
IDEHelper/Tests/src/Functions.bf

@@ -182,6 +182,13 @@ namespace Tests
 			{
 				sVal = 123;
 			}
+
+			public static void TestDefer()
+			{
+				function void() func = => Func;
+				if (func != null)
+					defer:: func.Invoke();
+			}
 		}
 
 		public static int UseFunc0<T>(function int (T this, float f) func, T a, float b)
@@ -254,6 +261,10 @@ namespace Tests
 
 			ClassC<Zoop>.Test();
 			Test.Assert(Zoop.sVal == 123);
+
+			Zoop.sVal = 0;
+			Zoop.TestDefer();
+			Test.Assert(Zoop.sVal == 123);
 		}
 	}
 }