2
0
Эх сурвалжийг харах

Allow `=> funcPtr` function binding

Brian Fiete 4 жил өмнө
parent
commit
eaeb5ab6f8

+ 18 - 10
IDEHelper/Compiler/BfExprEvaluator.cpp

@@ -1079,14 +1079,21 @@ BfTypedValue BfMethodMatcher::ResolveArgTypedValue(BfResolvedArg& resolvedArg, B
 		{
 		{
 			if (delegateBindExpr->mNewToken == NULL)
 			if (delegateBindExpr->mNewToken == NULL)
 			{
 			{
-				resolvedArg.mExpectedType = checkType;
-				auto methodRefType = mModule->CreateMethodRefType(boundMethodInstance);
-				mModule->AddDependency(methodRefType, mModule->mCurTypeInstance, BfDependencyMap::DependencyFlag_Calls);				
-				mModule->AddCallDependency(boundMethodInstance);
-				argTypedValue = BfTypedValue(mModule->mBfIRBuilder->GetFakeVal(), methodRefType);
+				if (boundMethodInstance->GetOwner()->IsFunction())
+				{
+					return BfTypedValue(mModule->mBfIRBuilder->GetFakeVal(), boundMethodInstance->GetOwner());
+				}
+				else
+				{
+					resolvedArg.mExpectedType = checkType;
+					auto methodRefType = mModule->CreateMethodRefType(boundMethodInstance);
+					mModule->AddDependency(methodRefType, mModule->mCurTypeInstance, BfDependencyMap::DependencyFlag_Calls);
+					mModule->AddCallDependency(boundMethodInstance);
+					argTypedValue = BfTypedValue(mModule->mBfIRBuilder->GetFakeVal(), methodRefType);
+				}
 			}
 			}
 			else
 			else
-				argTypedValue = BfTypedValue(BfTypedValueKind_UntypedValue);				
+				argTypedValue = BfTypedValue(BfTypedValueKind_UntypedValue);
 		}
 		}
 	}
 	}
 	else if ((resolvedArg.mArgFlags & BfArgFlag_LambdaBindAttempt) != 0)
 	else if ((resolvedArg.mArgFlags & BfArgFlag_LambdaBindAttempt) != 0)
@@ -5728,7 +5735,7 @@ BfTypedValue BfExprEvaluator::CreateCall(BfAstNode* targetSrc, const BfTypedValu
 
 
 				BfError* error;
 				BfError* error;
 				if ((prevBindResult.mPrevVal != NULL) && (prevBindResult.mPrevVal->mBindType != NULL))
 				if ((prevBindResult.mPrevVal != NULL) && (prevBindResult.mPrevVal->mBindType != NULL))
- 					error = mModule->Fail(StrFormat("Method '%s' has too many parameters to bind to '%s'.", mModule->MethodToString(methodInstance).c_str(), mModule->TypeToString(prevBindResult.mPrevVal->mBindType).c_str()), errorRef);
+ 					error = mModule->Fail(StrFormat("Method '%s' has too few parameters to bind to '%s'.", mModule->MethodToString(methodInstance).c_str(), mModule->TypeToString(prevBindResult.mPrevVal->mBindType).c_str()), errorRef);
 				else
 				else
 					error = mModule->Fail(StrFormat("Too many arguments, expected %d fewer.", (int)argValues.size() - argExprIdx), errorRef);
 					error = mModule->Fail(StrFormat("Too many arguments, expected %d fewer.", (int)argValues.size() - argExprIdx), errorRef);
 				if ((error != NULL) && (methodInstance->mMethodDef->mMethodDeclaration != NULL))
 				if ((error != NULL) && (methodInstance->mMethodDef->mMethodDeclaration != NULL))
@@ -10295,7 +10302,8 @@ void BfExprEvaluator::Visit(BfDelegateBindExpression* delegateBindExpr)
 
 
 		if (isDirectFunction)
 		if (isDirectFunction)
 		{
 		{
-			if (delegateTypeInstance->IsFunction())
+			//if ((delegateTypeInstance != NULL) && (delegateTypeInstance->IsFunction()))
+			if (mExpectingType->IsFunction())
 			{
 			{
 				auto intPtrVal = mModule->mBfIRBuilder->CreatePtrToInt(bindResult.mFunc, BfTypeCode_IntPtr);
 				auto intPtrVal = mModule->mBfIRBuilder->CreatePtrToInt(bindResult.mFunc, BfTypeCode_IntPtr);
 				mResult = BfTypedValue(intPtrVal, mExpectingType);
 				mResult = BfTypedValue(intPtrVal, mExpectingType);
@@ -10311,7 +10319,7 @@ void BfExprEvaluator::Visit(BfDelegateBindExpression* delegateBindExpr)
 				//
 				//
 				{
 				{
 					SetAndRestoreValue<bool> prevIgnore(mModule->mBfIRBuilder->mIgnoreWrites, true);
 					SetAndRestoreValue<bool> prevIgnore(mModule->mBfIRBuilder->mIgnoreWrites, true);
-					result = mModule->CastToFunction(delegateBindExpr->mTarget, bindResult.mMethodInstance, mExpectingType);
+					result = mModule->CastToFunction(delegateBindExpr->mTarget, bindResult.mOrigTarget, bindResult.mMethodInstance, mExpectingType);
 				}
 				}
 
 
 				if (result)
 				if (result)
@@ -10363,7 +10371,7 @@ void BfExprEvaluator::Visit(BfDelegateBindExpression* delegateBindExpr)
 			}
 			}
 			else
 			else
 			{
 			{
-				result = mModule->CastToFunction(delegateBindExpr->mTarget, bindResult.mMethodInstance, mExpectingType);
+				result = mModule->CastToFunction(delegateBindExpr->mTarget, bindResult.mOrigTarget, bindResult.mMethodInstance, mExpectingType);
 			}			
 			}			
 			if (result)
 			if (result)
 				mResult = BfTypedValue(result, mExpectingType);
 				mResult = BfTypedValue(result, mExpectingType);

+ 9 - 2
IDEHelper/Compiler/BfIRBuilder.cpp

@@ -2239,8 +2239,15 @@ void BfIRBuilder::CreateTypeDeclaration(BfType* type, bool forceDbgDefine)
 		name += BfTypeUtils::HashEncode64(methodInstance->mIdHash).c_str();
 		name += BfTypeUtils::HashEncode64(methodInstance->mIdHash).c_str();
 
 
 		if (wantDIData)
 		if (wantDIData)
-		{												
-			auto bfFileInstance = mModule->GetFileFromNode(methodInstance->GetOwner()->mTypeDef->mTypeDeclaration);
+		{	
+			auto typeDeclaration = methodInstance->GetOwner()->mTypeDef->mTypeDeclaration;			
+
+			BfFileInstance* bfFileInstance;
+			if (typeDeclaration != NULL)
+				bfFileInstance = mModule->GetFileFromNode(typeDeclaration);
+			else
+				bfFileInstance = mModule->GetFileFromNode(mModule->mContext->mBfObjectType->mTypeDef->mTypeDeclaration);
+
 			auto namespaceScope = DbgCreateNameSpace(bfFileInstance->mDIFile, "_bf", bfFileInstance->mDIFile, 0);				
 			auto namespaceScope = DbgCreateNameSpace(bfFileInstance->mDIFile, "_bf", bfFileInstance->mDIFile, 0);				
 	
 	
 			StringT<128> mangledName;
 			StringT<128> mangledName;

+ 14 - 4
IDEHelper/Compiler/BfModule.cpp

@@ -2683,7 +2683,7 @@ BfError* BfModule::Fail(const StringImpl& error, BfAstNode* refNode, bool isPers
 
 
 	//BF_ASSERT(refNode != NULL);
 	//BF_ASSERT(refNode != NULL);
 
 
-	if (mCurMethodInstance != NULL)
+ 	if (mCurMethodInstance != NULL)
 		mCurMethodInstance->mHasFailed = true;
 		mCurMethodInstance->mHasFailed = true;
 
 
 	if ((mCurTypeInstance != NULL) && (mCurTypeInstance->IsUnspecializedTypeVariation()))
 	if ((mCurTypeInstance != NULL) && (mCurTypeInstance->IsUnspecializedTypeVariation()))
@@ -7397,13 +7397,23 @@ bool BfModule::CheckGenericConstraints(const BfGenericParamSource& genericParamS
 				convCheckConstraint = ResolveGenericType(convCheckConstraint, NULL, methodGenericArgs);
 				convCheckConstraint = ResolveGenericType(convCheckConstraint, NULL, methodGenericArgs);
 			if (convCheckConstraint == NULL)
 			if (convCheckConstraint == NULL)
 				return false;
 				return false;
-			if ((checkArgType->IsMethodRef()) && (convCheckConstraint->IsDelegate()))
+			if (((checkArgType->IsMethodRef()) || (checkArgType->IsFunction())) && (convCheckConstraint->IsDelegate()))
 			{
 			{
-				auto methodRefType = (BfMethodRefType*)checkArgType;
+				BfMethodInstance* checkMethodInstance;
+				if (checkArgType->IsMethodRef())
+				{
+					auto methodRefType = (BfMethodRefType*)checkArgType;
+					checkMethodInstance = methodRefType->mMethodRef;
+				}
+				else
+				{
+					checkMethodInstance = GetRawMethodInstanceAtIdx(checkArgType->ToTypeInstance(), 0, "Invoke");
+				}
+				
 				auto invokeMethod = GetRawMethodInstanceAtIdx(convCheckConstraint->ToTypeInstance(), 0, "Invoke");
 				auto invokeMethod = GetRawMethodInstanceAtIdx(convCheckConstraint->ToTypeInstance(), 0, "Invoke");
 
 
 				BfExprEvaluator exprEvaluator(this);
 				BfExprEvaluator exprEvaluator(this);
-				if (exprEvaluator.IsExactMethodMatch(methodRefType->mMethodRef, invokeMethod))
+				if (exprEvaluator.IsExactMethodMatch(checkMethodInstance, invokeMethod))
 					constraintMatched = true;
 					constraintMatched = true;
 
 
 			}
 			}

+ 1 - 1
IDEHelper/Compiler/BfModule.h

@@ -1547,7 +1547,7 @@ public:
 	bool CanCast(BfTypedValue typedVal, BfType* toType, BfCastFlags castFlags = BfCastFlags_None);
 	bool CanCast(BfTypedValue typedVal, BfType* toType, BfCastFlags castFlags = BfCastFlags_None);
 	bool AreSplatsCompatible(BfType* fromType, BfType* toType, bool* outNeedsMemberCasting);
 	bool AreSplatsCompatible(BfType* fromType, BfType* toType, bool* outNeedsMemberCasting);
 	BfTypedValue BoxValue(BfAstNode* srcNode, BfTypedValue typedVal, BfType* toType /*Can be System.Object or interface*/, const BfAllocTarget& allocTarget, bool callDtor = true);
 	BfTypedValue BoxValue(BfAstNode* srcNode, BfTypedValue typedVal, BfType* toType /*Can be System.Object or interface*/, const BfAllocTarget& allocTarget, bool callDtor = true);
-	BfIRValue CastToFunction(BfAstNode* srcNode, BfMethodInstance* methodInstance, BfType* toType, BfCastFlags castFlags = BfCastFlags_None);
+	BfIRValue CastToFunction(BfAstNode* srcNode, const BfTypedValue& targetValue, BfMethodInstance* methodInstance, BfType* toType, BfCastFlags castFlags = BfCastFlags_None);
 	BfIRValue CastToValue(BfAstNode* srcNode, BfTypedValue val, BfType* toType, BfCastFlags castFlags = BfCastFlags_None, BfCastResultFlags* resultFlags = NULL);
 	BfIRValue CastToValue(BfAstNode* srcNode, BfTypedValue val, BfType* toType, BfCastFlags castFlags = BfCastFlags_None, BfCastResultFlags* resultFlags = NULL);
 	BfTypedValue Cast(BfAstNode* srcNode, const BfTypedValue& val, BfType* toType, BfCastFlags castFlags = BfCastFlags_None);
 	BfTypedValue Cast(BfAstNode* srcNode, const BfTypedValue& val, BfType* toType, BfCastFlags castFlags = BfCastFlags_None);
 	BfPrimitiveType* GetIntCoercibleType(BfType* type);
 	BfPrimitiveType* GetIntCoercibleType(BfType* type);

+ 8 - 2
IDEHelper/Compiler/BfModuleTypeUtils.cpp

@@ -9650,12 +9650,18 @@ bool BfModule::AreSplatsCompatible(BfType* fromType, BfType* toType, bool* outNe
 	return true;
 	return true;
 }
 }
 
 
-BfIRValue BfModule::CastToFunction(BfAstNode* srcNode, BfMethodInstance* methodInstance, BfType* toType, BfCastFlags castFlags)
+BfIRValue BfModule::CastToFunction(BfAstNode* srcNode, const BfTypedValue& targetValue, BfMethodInstance* methodInstance, BfType* toType, BfCastFlags castFlags)
 {	
 {	
 	auto invokeMethodInstance = GetDelegateInvokeMethod(toType->ToTypeInstance());
 	auto invokeMethodInstance = GetDelegateInvokeMethod(toType->ToTypeInstance());
 
 
 	if (invokeMethodInstance->IsExactMatch(methodInstance, false, true))
 	if (invokeMethodInstance->IsExactMatch(methodInstance, false, true))
 	{
 	{
+		if (methodInstance->GetOwner()->IsFunction())
+		{
+			BF_ASSERT(targetValue);
+			return targetValue.mValue;
+		}
+
 		BfModuleMethodInstance methodRefMethod;
 		BfModuleMethodInstance methodRefMethod;
 		if (methodInstance->mDeclModule == this)
 		if (methodInstance->mDeclModule == this)
 			methodRefMethod = methodInstance;
 			methodRefMethod = methodInstance;
@@ -10087,7 +10093,7 @@ BfIRValue BfModule::CastToValue(BfAstNode* srcNode, BfTypedValue typedVal, BfTyp
 	if ((typedVal.mType->IsMethodRef()) && (toType->IsFunction()))
 	if ((typedVal.mType->IsMethodRef()) && (toType->IsFunction()))
 	{
 	{
 		BfMethodInstance* methodInstance = ((BfMethodRefType*)typedVal.mType)->mMethodRef;
 		BfMethodInstance* methodInstance = ((BfMethodRefType*)typedVal.mType)->mMethodRef;
-		auto result = CastToFunction(srcNode, methodInstance, toType, castFlags);
+		auto result = CastToFunction(srcNode, BfTypedValue(), methodInstance, toType, castFlags);
 		if (result)
 		if (result)
 			return result;
 			return result;
 	}
 	}

+ 12 - 0
IDEHelper/Tests/src/FuncRefs.bf

@@ -279,6 +279,11 @@ namespace Tests
 				bind.Dispose();
 				bind.Dispose();
 			}
 			}
 
 
+			public static int StaticMethod(int a)
+			{
+				return a+1000;
+			}
+
 			public void TestDlg() mut
 			public void TestDlg() mut
 			{
 			{
 				int a = 0;
 				int a = 0;
@@ -306,6 +311,13 @@ namespace Tests
 				Test.Assert(mA == 100+300+300 + 300);
 				Test.Assert(mA == 100+300+300 + 300);
 
 
 				bind.Dispose();
 				bind.Dispose();
+
+				Test.Assert(Use(scope => dlg, 10) == 400);
+
+				function int(int num) func = => StaticMethod;
+				Test.Assert(Use(=> StaticMethod, 123) == 1123);
+				Test.Assert(Use(func, 123) == 1123);
+				Test.Assert(Use(=> func, 123) == 1123);
 			}
 			}
 		}
 		}