Explorar el Código

Lambda return type inference

Brian Fiete hace 4 años
padre
commit
bb12a4ec20

+ 154 - 38
IDEHelper/Compiler/BfExprEvaluator.cpp

@@ -588,6 +588,32 @@ bool BfGenericInferContext::InferGenericArgument(BfMethodInstance* methodInstanc
 	return true;
 }
 
+void BfGenericInferContext::InferGenericArguments(BfMethodInstance* methodInstance)
+{
+	// Attempt to infer from other generic args
+	for (int srcGenericIdx = 0; srcGenericIdx < (int)mCheckMethodGenericArguments->size(); srcGenericIdx++)
+	{
+		auto& srcGenericArg = (*mCheckMethodGenericArguments)[srcGenericIdx];
+		if (srcGenericArg == NULL)
+			continue;
+
+		auto srcGenericParam = methodInstance->mMethodInfoEx->mGenericParams[srcGenericIdx];
+		for (auto ifaceConstraint : srcGenericParam->mInterfaceConstraints)
+		{
+			if ((ifaceConstraint->IsUnspecializedTypeVariation()) && (ifaceConstraint->IsGenericTypeInstance()))
+			{
+				InferGenericArgument(methodInstance, srcGenericArg, ifaceConstraint, BfIRValue());
+				auto typeInstance = srcGenericArg->ToTypeInstance();
+				if (typeInstance != NULL)
+				{
+					for (auto ifaceEntry : typeInstance->mInterfaces)
+						InferGenericArgument(methodInstance, ifaceEntry.mInterfaceType, ifaceConstraint, BfIRValue());
+				}
+			}
+		}
+	}
+}
+
 // void BfGenericInferContext::PropogateInference(BfType* resolvedType, BfType* unresovledType)
 // {
 // 	if (!unresovledType->IsUnspecializedTypeVariation())
@@ -1207,7 +1233,7 @@ BfTypedValue BfMethodMatcher::ResolveArgTypedValue(BfResolvedArg& resolvedArg, B
 						if (!resolvedArg.mTypedValue)							
 						{
 							// Resolve for real
-							resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, checkType, BfEvalExprFlags_NoCast);
+							resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, checkType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_NoAutoComplete));
 						}
 						argTypedValue = resolvedArg.mTypedValue;
 					}
@@ -1217,6 +1243,73 @@ BfTypedValue BfMethodMatcher::ResolveArgTypedValue(BfResolvedArg& resolvedArg, B
 				}
 			}
 		}
+		else if ((checkType == NULL) && (origCheckType != NULL) && (origCheckType->IsUnspecializedTypeVariation()) && (genericArgumentsSubstitute != NULL))
+		{
+			BfMethodInstance* methodInstance = mModule->GetRawMethodInstanceAtIdx(origCheckType->ToTypeInstance(), 0, "Invoke");
+			if (methodInstance != NULL)
+			{				
+				if ((methodInstance->mReturnType->IsGenericParam()) && (((BfGenericParamType*)methodInstance->mReturnType)->mGenericParamKind == BfGenericParamKind_Method))
+				{
+					bool isValid = true;
+
+					int returnMethodGenericArgIdx = ((BfGenericParamType*)methodInstance->mReturnType)->mGenericParamIdx;
+					if ((*genericArgumentsSubstitute)[returnMethodGenericArgIdx] != NULL)
+					{
+						isValid = false;
+					}
+
+					if (methodInstance->mParams.size() != (int)lambdaBindExpr->mParams.size())
+						isValid = false;
+
+					for (auto& param : methodInstance->mParams)
+					{
+						if (param.mResolvedType->IsGenericParam())
+						{
+							auto genericParamType = (BfGenericParamType*)param.mResolvedType;
+							if ((genericParamType->mGenericParamKind == BfGenericParamKind_Method) && ((*genericArgumentsSubstitute)[genericParamType->mGenericParamIdx] == NULL))
+							{
+								isValid = false;
+							}
+						}
+					}
+
+					if (isValid)
+					{
+						bool success = false;
+
+						(*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = mModule->GetPrimitiveType(BfTypeCode_None);
+						auto tryType = mModule->ResolveGenericType(origCheckType, NULL, genericArgumentsSubstitute);
+						if (tryType != NULL)
+						{
+							auto inferredReturnType = mModule->CreateValueFromExpression(lambdaBindExpr, tryType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_InferReturnType | BfEvalExprFlags_NoAutoComplete));
+							if (inferredReturnType.mType != NULL)
+							{								
+								(*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = inferredReturnType.mType;
+								
+								if (((flags & BfResolveArgFlag_FromGenericParam) != 0) && (lambdaBindExpr->mNewToken == NULL))
+								{
+									auto resolvedType = mModule->ResolveGenericType(origCheckType, NULL, genericArgumentsSubstitute);
+									if (resolvedType != NULL)
+									{										
+										// Resolve for real
+										resolvedArg.mTypedValue = mModule->CreateValueFromExpression(lambdaBindExpr, resolvedType, (BfEvalExprFlags)(BfEvalExprFlags_NoCast | BfEvalExprFlags_NoAutoComplete));
+										argTypedValue = resolvedArg.mTypedValue;
+									}
+								}
+
+								success = true;
+							}
+						}
+						
+						if (!success)
+						{
+							// Put back
+							(*genericArgumentsSubstitute)[returnMethodGenericArgIdx] = NULL;
+						}
+					}
+				}
+			}
+		}
 	}
 	else if ((resolvedArg.mArgFlags & BfArgFlag_UnqualifiedDotAttempt) != 0)
 	{
@@ -1672,7 +1765,12 @@ bool BfMethodMatcher::CheckMethod(BfTypeInstance* targetTypeInstance, BfTypeInst
 						argTypedValue = mTarget;
 				}
 				else
-					argTypedValue = ResolveArgTypedValue(mArguments[argIdx], checkType, genericArgumentsSubstitute, origCheckType, BfResolveArgFlag_FromGeneric);
+				{
+					BfResolveArgFlags flags = BfResolveArgFlag_FromGeneric;
+					if (wantType->IsGenericParam())
+						flags = (BfResolveArgFlags)(flags | BfResolveArgFlag_FromGenericParam);
+					argTypedValue = ResolveArgTypedValue(mArguments[argIdx], checkType, genericArgumentsSubstitute, origCheckType, flags);
+				}
 				if (!argTypedValue.IsUntypedValue())
 				{
 					auto type = argTypedValue.mType;
@@ -1709,6 +1807,11 @@ bool BfMethodMatcher::CheckMethod(BfTypeInstance* targetTypeInstance, BfTypeInst
 				goto NoMatch;
 		}
 
+		if (!deferredArgs.IsEmpty())
+		{
+			genericInferContext.InferGenericArguments(methodInstance);
+		}
+
 		while (!deferredArgs.IsEmpty())
 		{						
 			int prevDeferredSize = (int)deferredArgs.size();
@@ -11504,7 +11607,12 @@ void BfExprEvaluator::VisitLambdaBodies(BfAstNode* body, BfFieldDtorDeclaration*
 	if (auto blockBody = BfNodeDynCast<BfBlock>(body))
 		mModule->VisitChild(blockBody);
 	else if (auto bodyExpr = BfNodeDynCast<BfExpression>(body))
-		mModule->CreateValueFromExpression(bodyExpr);
+	{
+		auto result = mModule->CreateValueFromExpression(bodyExpr);
+		if ((result) && (mModule->mCurMethodState->mClosureState != NULL) &&
+			(mModule->mCurMethodState->mClosureState->mReturnTypeInferState == BfReturnTypeInferState_Inferring))
+			mModule->mCurMethodState->mClosureState->mReturnType = result.mType;
+	}
 	
 	while (fieldDtor != NULL)
 	{
@@ -11530,8 +11638,10 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
 		}
 	}
 
+	bool isInferReturnType = (mBfEvalExprFlags & BfEvalExprFlags_InferReturnType) != 0;
+
 	BfLambdaInstance* lambdaInstance = NULL;
-	if (rootMethodState->mLambdaCache.TryGetValue(cacheNodeList, &lambdaInstance))
+	if ((!isInferReturnType) && (rootMethodState->mLambdaCache.TryGetValue(cacheNodeList, &lambdaInstance)))
 		return lambdaInstance;	
 	
 	static int sBindCount = 0;
@@ -11683,7 +11793,11 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
 		return NULL;
 	}
 
-	if ((lambdaBindExpr->mNewToken == NULL) || (isFunctionBind))
+	if ((lambdaBindExpr->mNewToken == NULL) && (isInferReturnType))
+	{
+		// Method ref, but let this follow infer route
+	}
+	else if ((lambdaBindExpr->mNewToken == NULL) || (isFunctionBind))
 	{
 		if ((lambdaBindExpr->mNewToken != NULL) && (isFunctionBind))
 			mModule->Fail("Binds to functions should do not require allocations.", lambdaBindExpr->mNewToken);
@@ -11951,8 +12065,14 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
 	closureState.mCaptureVisitingBody = true;
 	closureState.mClosureInstanceInfo = closureInstanceInfo;
 	
-	VisitLambdaBodies(lambdaBindExpr->mBody, lambdaBindExpr->mDtor);
+	if ((mBfEvalExprFlags & BfEvalExprFlags_InferReturnType) != 0)
+	{
+		closureState.mReturnType = NULL;
+		closureState.mReturnTypeInferState = BfReturnTypeInferState_Inferring;
+	}
 
+	VisitLambdaBodies(lambdaBindExpr->mBody, lambdaBindExpr->mDtor);
+	
 	if (hasExplicitCaptureNames)	
 		_SetNotCapturedFlag(false);
 
@@ -11964,12 +12084,32 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
 			prevClosureState->mCaptureStartAccessId = closureState.mCaptureStartAccessId;
 	}
 
-	if (mModule->mCurMethodInstance->mIsUnspecialized)
+	bool earlyExit = false;
+	if (isInferReturnType)
+	{
+		if ((closureState.mReturnTypeInferState == BfReturnTypeInferState_Fail) ||
+			(closureState.mReturnType == NULL))
+		{
+			mResult = BfTypedValue();
+		}
+		else
+		{
+			mResult = BfTypedValue(closureState.mReturnType);
+		}
+		
+		earlyExit = true;
+	}
+	else if (mModule->mCurMethodInstance->mIsUnspecialized)
+	{
+		earlyExit = true;
+		mResult = mModule->GetDefaultTypedValue(delegateTypeInstance);
+	}
+
+	if (earlyExit)
 	{
 		prevIgnoreWrites.Restore();
 		mModule->mBfIRBuilder->RestoreDebugLocation();
-
-		mResult = mModule->GetDefaultTypedValue(delegateTypeInstance);
+		
 		mModule->mBfIRBuilder->SetActiveFunction(prevActiveFunction);
 		if (!prevInsertBlock.IsFake())
 			mModule->mBfIRBuilder->SetInsertPoint(prevInsertBlock);
@@ -11981,7 +12121,7 @@ BfLambdaInstance* BfExprEvaluator::GetLambdaInstance(BfLambdaBindExpression* lam
 	closureState.mCaptureVisitingBody = false;
 	
 	prevIgnoreWrites.Restore();
-	mModule->mBfIRBuilder->RestoreDebugLocation();
+	mModule->mBfIRBuilder->RestoreDebugLocation();	
 
 	auto _GetCaptureType = [&](const StringImpl& str)
 	{
@@ -14189,34 +14329,10 @@ BfModuleMethodInstance BfExprEvaluator::GetSelectedMethod(BfAstNode* targetSrc,
 
 			if (genericArg == NULL)
 			{
-				// Attempt to infer from other generic args
-				for (int srcGenericIdx = 0; srcGenericIdx < (int)methodMatcher.mBestMethodGenericArguments.size(); srcGenericIdx++)
-				{
-					auto& srcGenericArg = methodMatcher.mBestMethodGenericArguments[srcGenericIdx];
-					if (srcGenericArg == NULL)
-						continue;
-
-					auto srcGenericParam = unspecializedMethod->mMethodInfoEx->mGenericParams[srcGenericIdx];
-					
-					BfGenericInferContext genericInferContext;
-					genericInferContext.mModule = mModule;
-					genericInferContext.mCheckMethodGenericArguments = &methodMatcher.mBestMethodGenericArguments;
-					
-					for (auto ifaceConstraint : srcGenericParam->mInterfaceConstraints)
-					{
-						if ((ifaceConstraint->IsUnspecializedTypeVariation()) && (ifaceConstraint->IsGenericTypeInstance()))
-						{	
-							genericInferContext.InferGenericArgument(unspecializedMethod, srcGenericArg, ifaceConstraint, BfIRValue());
-
-							auto typeInstance = srcGenericArg->ToTypeInstance();
-							if (typeInstance != NULL)
-							{
-								for (auto ifaceEntry : typeInstance->mInterfaces)
-									genericInferContext.InferGenericArgument(unspecializedMethod, ifaceEntry.mInterfaceType, ifaceConstraint, BfIRValue());
-							}							
-						}						
-					}					
-				}
+				BfGenericInferContext genericInferContext;
+				genericInferContext.mModule = mModule;
+				genericInferContext.mCheckMethodGenericArguments = &methodMatcher.mBestMethodGenericArguments;
+				genericInferContext.InferGenericArguments(unspecializedMethod);				
 			}
 
 			if (genericArg == NULL)

+ 3 - 1
IDEHelper/Compiler/BfExprEvaluator.h

@@ -37,7 +37,8 @@ enum BfResolveArgsFlags
 enum BfResolveArgFlags
 {
 	BfResolveArgFlag_None = 0,
-	BfResolveArgFlag_FromGeneric = 1
+	BfResolveArgFlag_FromGeneric = 1,
+	BfResolveArgFlag_FromGenericParam = 2
 };
 
 class BfResolvedArg
@@ -141,6 +142,7 @@ public:
 	{
 		return (int)mCheckMethodGenericArguments->size() - mInferredCount;
 	}
+	void InferGenericArguments(BfMethodInstance* methodInstance);
 };
 
 class BfMethodMatcher

+ 2 - 0
IDEHelper/Compiler/BfModule.cpp

@@ -8037,6 +8037,8 @@ BfTypedValue BfModule::CreateValueFromExpression(BfExprEvaluator& exprEvaluator,
 
 	if (!exprEvaluator.mResult)
 	{
+		if ((flags & BfEvalExprFlags_InferReturnType) != 0)
+			return exprEvaluator.mResult;
 		if (!mCompiler->mPassInstance->HasFailed())
 			Fail("INTERNAL ERROR: No expression result returned but no error caught in expression evaluator", expr);
 		return BfTypedValue();

+ 11 - 0
IDEHelper/Compiler/BfModule.h

@@ -74,6 +74,8 @@ enum BfEvalExprFlags
 	BfEvalExprFlags_NoLookupError = 0x40000,
 	BfEvalExprFlags_Comptime = 0x80000,
 	BfEvalExprFlags_InCascade = 0x100000,
+	BfEvalExprFlags_InferReturnType = 0x200000,
+	BfEvalExprFlags_WasMethodRef = 0x400000
 };
 
 enum BfCastFlags
@@ -652,6 +654,13 @@ public:
 	Array<BfMixinRecord> mMixinStateRecords;
 };
 
+enum BfReturnTypeInferState
+{
+	BfReturnTypeInferState_None,
+	BfReturnTypeInferState_Inferring,
+	BfReturnTypeInferState_Fail,
+};
+
 class BfClosureState
 {
 public:
@@ -661,6 +670,7 @@ public:
 	// When we need to look into another local method to determine captures, but we don't want to process local variable declarations or cause infinite recursion
 	bool mBlindCapturing;
 	bool mDeclaringMethodIsMutating;	
+	BfReturnTypeInferState mReturnTypeInferState;
 	BfLocalMethod* mLocalMethod;
 	BfClosureInstanceInfo* mClosureInstanceInfo;
 	BfMethodDef* mClosureMethodDef;
@@ -684,6 +694,7 @@ public:
 		mCaptureStartAccessId = -1;
 		mBlindCapturing = false;
 		mDeclaringMethodIsMutating = false;
+		mReturnTypeInferState = BfReturnTypeInferState_None;
 		mActiveDeferredLocalMethod = NULL;		
 		mReturnType = NULL;
 		mClosureType = NULL;		

+ 37 - 2
IDEHelper/Compiler/BfStmtEvaluator.cpp

@@ -4900,8 +4900,12 @@ void BfModule::Visit(BfReturnStatement* returnStmt)
 	if (mCurMethodInstance->IsMixin())
 		retType = NULL;
 
-	if (mCurMethodState->mClosureState != NULL)	
+	bool inferReturnType = false;
+	if (mCurMethodState->mClosureState != NULL)
+	{
 		retType = mCurMethodState->mClosureState->mReturnType;
+		inferReturnType = (mCurMethodState->mClosureState->mReturnTypeInferState != BfReturnTypeInferState_None);
+	}
 
 	auto checkScope = mCurMethodState->mCurScope;
 	while (checkScope != NULL)
@@ -4931,7 +4935,7 @@ void BfModule::Visit(BfReturnStatement* returnStmt)
 		checkLocalAssignData = checkLocalAssignData->mChainedAssignData;
 	}
 
-	if (retType == NULL)
+	if ((retType == NULL) && (!inferReturnType))
 	{
 		if (returnStmt->mExpression != NULL)
 		{
@@ -4972,7 +4976,38 @@ void BfModule::Visit(BfReturnStatement* returnStmt)
 		exprEvaluator.mReceivingValue = &mCurMethodState->mRetVal;	
 	if (mCurMethodInstance->mMethodDef->mIsReadOnly)
 		exprEvaluator.mAllowReadOnlyReference = true;
+
+	if (inferReturnType)
+		expectingReturnType = NULL;
+
 	auto retValue = CreateValueFromExpression(exprEvaluator, returnStmt->mExpression, expectingReturnType, BfEvalExprFlags_AllowRefExpr, &origType);	
+	
+	if ((retValue) && (inferReturnType))
+	{
+		if (mCurMethodState->mClosureState->mReturnType == NULL)
+			mCurMethodState->mClosureState->mReturnType = retValue.mType;
+		else
+		{
+			if ((retValue.mType == mCurMethodState->mClosureState->mReturnType) ||
+				(CanCast(retValue, mCurMethodState->mClosureState->mReturnType)))
+			{
+				// Leave as-is
+			}
+			else if (CanCast(GetFakeTypedValue(mCurMethodState->mClosureState->mReturnType), retValue.mType))
+			{
+				mCurMethodState->mClosureState->mReturnType = retValue.mType;
+			}
+			else
+			{
+				mCurMethodState->mClosureState->mReturnTypeInferState = BfReturnTypeInferState_Fail;
+			}
+		}
+	}
+	if ((retType == NULL) && (inferReturnType))
+		retType = mCurMethodState->mClosureState->mReturnType;
+	if (retType == NULL)
+		retType = GetPrimitiveType(BfTypeCode_None);
+
 	if ((!mIsComptimeModule) && (mCurMethodInstance->GetStructRetIdx() != -1))
 		alreadyWritten = exprEvaluator.mReceivingValue == NULL;
 	MarkScopeLeft(&mCurMethodState->mHeadScope);

+ 13 - 0
IDEHelper/Tests/src/Generics.bf

@@ -241,6 +241,17 @@ namespace Tests
 			return 0;
 		}
 
+		public static TResult Sum<T, TElem, TDlg, TResult>(this T it, TDlg dlg)
+		    where T: concrete, IEnumerable<TElem>
+		    where TDlg: delegate TResult(TElem)
+		    where TResult: operator TResult + TResult
+		{
+		    var result = default(TResult);
+		    for(var elem in it)
+		        result += dlg(elem);
+		    return result;
+		}
+
 		[Test]
 		public static void TestBasics()
 		{
@@ -293,6 +304,8 @@ namespace Tests
 				} == false);*/
 			Test.Assert(MethodE(floatList) == 6);
 			Test.Assert(MethodF(floatList) == 0);
+
+			Test.Assert(floatList.Sum((x) => x * 2) == 12);
 		}
 	}