Przeglądaj źródła

Improved constraint checks where generic param type constraint passes

Brian Fiete 2 miesięcy temu
rodzic
commit
b7725d0ed0

+ 52 - 1
IDEHelper/Compiler/BfExprEvaluator.cpp

@@ -23367,6 +23367,15 @@ void BfExprEvaluator::PerformUnaryOperation_OnResult(BfExpression* unaryOpExpr,
 			mResult = opResult;
 			return;
 		}
+
+		auto typeConstraint = mModule->GetGenericParamInstanceTypeConstraint(mResult.mType);
+		if ((typeConstraint != NULL) && (!typeConstraint->IsGenericParam()))
+		{
+			// Handle cases such as 'where T : float'
+			mResult.mType = typeConstraint;
+			PerformUnaryOperation_OnResult(unaryOpExpr, unaryOp, opToken, opFlags);
+			return;
+		}
 	}
 
 	switch (unaryOp)
@@ -24947,7 +24956,7 @@ void BfExprEvaluator::PerformBinaryOperation(BfAstNode* leftExpression, BfAstNod
 			BfBinaryOp findBinaryOp = binaryOp;
 
 			bool isComparison = (binaryOp >= BfBinaryOp_Equality) && (binaryOp <= BfBinaryOp_LessThanOrEqual);
-
+			
 			for (int pass = 0; pass < 2; pass++)
 			{
 				BfBinaryOp oppositeBinaryOp = BfGetOppositeBinaryOp(findBinaryOp);
@@ -25316,6 +25325,48 @@ void BfExprEvaluator::PerformBinaryOperation(BfAstNode* leftExpression, BfAstNod
 					findBinaryOp = flippedBinaryOp;
 			}
 
+			auto _FixOpCheckGenericParam = [&](BfTypedValue& typedVal)
+				{
+					if ((typedVal.mType != NULL) && (typedVal.mType->IsGenericParam()))
+					{
+						auto genericParamInstance = mModule->GetGenericParamInstance((BfGenericParamType*)typedVal.mType);
+						if (genericParamInstance->mTypeConstraint != NULL)
+						{
+							typedVal.mType = genericParamInstance->mTypeConstraint;
+							return true;
+						}
+					}
+					return false;
+				};
+
+			auto leftTypeConstraint = mModule->GetGenericParamInstanceTypeConstraint(leftValue.mType);
+			auto rightTypeConstraint = mModule->GetGenericParamInstanceTypeConstraint(rightValue.mType);
+			if ((leftTypeConstraint != NULL) || (rightTypeConstraint != NULL))
+			{
+				// Handle cases such as 'where T : float'
+				bool needNewCheck = false;
+
+				BfTypedValue newLeftValue = leftValue;
+				if ((leftTypeConstraint != NULL) && (!leftTypeConstraint->IsGenericParam()))
+				{
+					newLeftValue.mType = leftTypeConstraint;
+					needNewCheck = true;
+				}
+
+				BfTypedValue newRightValue = rightValue;
+				if ((rightTypeConstraint != NULL) && (!rightTypeConstraint->IsGenericParam()))
+				{
+					newRightValue.mType = rightTypeConstraint;
+					needNewCheck = true;
+				}
+
+				if (needNewCheck)
+				{
+					PerformBinaryOperation(leftExpression, rightExpression, binaryOp, opToken, flags, newLeftValue, newRightValue);
+					return;
+				}
+			}
+
 			bool resultHandled = false;
 			if (((origLeftType != NULL) && (origLeftType->IsIntUnknown())) ||
 				((origRightType != NULL) && (origRightType->IsIntUnknown())))

+ 1 - 0
IDEHelper/Compiler/BfModule.h

@@ -1972,6 +1972,7 @@ public:
 	bool IsUnboundGeneric(BfType* type);
 	BfGenericParamInstance* GetGenericTypeParamInstance(int paramIdx, BfFailHandleKind failHandleKind = BfFailHandleKind_Normal);
 	BfGenericParamInstance* GetGenericParamInstance(BfGenericParamType* type, bool checkMixinBind = false, BfFailHandleKind failHandleKind = BfFailHandleKind_Normal);
+	BfType* GetGenericParamInstanceTypeConstraint(BfType* type, bool checkMixinBind = false, BfFailHandleKind failHandleKind = BfFailHandleKind_Normal);
 	void GetActiveTypeGenericParamInstances(SizedArray<BfGenericParamInstance*, 4>& genericParamInstance);
 	BfGenericParamInstance* GetMergedGenericParamData(BfType* type, BfGenericParamFlags& outFlags, BfType*& outTypeConstraint);
 	BfTypeInstance* GetBaseType(BfTypeInstance* typeInst);

+ 10 - 0
IDEHelper/Compiler/BfModuleTypeUtils.cpp

@@ -9861,6 +9861,16 @@ BfGenericParamInstance* BfModule::GetGenericParamInstance(BfGenericParamType* ty
 	return GetGenericTypeParamInstance(type->mGenericParamIdx, failHandleKind);
 }
 
+BfType* BfModule::GetGenericParamInstanceTypeConstraint(BfType* type, bool checkMixinBind, BfFailHandleKind failHandleKind)
+{
+	if (!type->IsGenericParam())
+		return NULL;
+	auto genericParamInstance = GetGenericParamInstance((BfGenericParamType*)type, checkMixinBind, failHandleKind);
+	if (genericParamInstance != NULL)
+		return genericParamInstance->mTypeConstraint;
+	return NULL;
+}
+
 bool BfModule::ResolveTypeResult_Validate(BfAstNode* typeRef, BfType* resolvedTypeRef)
 {
 	if ((typeRef == NULL) || (resolvedTypeRef == NULL))

+ 13 - 0
IDEHelper/Tests/src/Constraints.bf

@@ -7,6 +7,19 @@ namespace Tests
 {
 	class Constraints
 	{
+		struct Vector2<T>
+		{
+			public T mX;
+			public T mY;
+		}
+
+		extension Vector2<T> where T : float
+		{
+			public T LengthSquared => mX * mX + mY * mY;
+		    public T Length => Math.Sqrt(LengthSquared);
+			public T NegX = -mX;
+		}
+
 		class Dicto : Dictionary<int, float>
 		{