Ver Fonte

Fixed generic assignment operators (ie +=)

Brian Fiete há 2 meses atrás
pai
commit
5b18e380a5

+ 27 - 1
IDEHelper/Compiler/BfExprEvaluator.cpp

@@ -21040,6 +21040,31 @@ BfTypedValue BfExprEvaluator::PerformAssignment_CheckOp(BfAssignmentExpression*
 				continue;
 
 			auto paramType = methodInst->GetParamType(0);
+
+			BfModuleMethodInstance moduleMethodInstance;
+
+			if (methodInst->mIsUnspecialized)
+			{
+				BfTypeVector checkMethodGenericArguments;
+				checkMethodGenericArguments.resize(methodInst->GetNumGenericArguments());
+
+				BfGenericInferContext genericInferContext;
+				genericInferContext.mModule = mModule;
+				genericInferContext.mCheckMethodGenericArguments = &checkMethodGenericArguments;
+
+				if (!genericInferContext.InferGenericArgument(methodInst, rightValue.mType, paramType, rightValue.mValue))
+					continue;
+				bool genericsInferred = true;
+				for (int i = 0; i < checkMethodGenericArguments.mSize; i++)
+					if ((checkMethodGenericArguments[i] == NULL) || (checkMethodGenericArguments[i]->IsVar()))
+						genericsInferred = false;
+				if (!genericsInferred)
+					continue;
+
+				moduleMethodInstance = mModule->GetMethodInstance(checkTypeInst, operatorDef, checkMethodGenericArguments);
+				paramType = moduleMethodInstance.mMethodInstance->GetParamType(0);
+			}
+
 			if (deferBinop)
 			{
 				if (argValues.mArguments == NULL)
@@ -21073,7 +21098,8 @@ BfTypedValue BfExprEvaluator::PerformAssignment_CheckOp(BfAssignmentExpression*
 					autoComplete->SetDefinitionLocation(operatorDef->mOperatorDeclaration->mOpTypeToken);
 			}
 
-			auto moduleMethodInstance = mModule->GetMethodInstance(checkTypeInst, operatorDef, BfTypeVector());
+			if (!moduleMethodInstance)
+				moduleMethodInstance = mModule->GetMethodInstance(checkTypeInst, operatorDef, BfTypeVector());
 
 			BfExprEvaluator exprEvaluator(mModule);
 			SizedArray<BfIRValue, 1> args;

+ 3 - 1
IDEHelper/Compiler/BfModule.cpp

@@ -8746,7 +8746,9 @@ bool BfModule::CheckGenericConstraints(const BfGenericParamSource& genericParamS
 	int checkGenericParamFlags = 0;
 	if (checkArgType->IsGenericParam())
 	{
-		BfGenericParamInstance* checkGenericParamInst = GetGenericParamInstance((BfGenericParamType*)checkArgType);
+		BfGenericParamInstance* checkGenericParamInst = GetGenericParamInstance((BfGenericParamType*)checkArgType, false, BfFailHandleKind_Soft);
+		if (checkGenericParamInst == NULL)
+			return false;
 		checkGenericParamFlags = checkGenericParamInst->mGenericParamFlags;
 		if (checkGenericParamInst->mTypeConstraint != NULL)
 			checkArgType = checkGenericParamInst->mTypeConstraint;

+ 6 - 2
IDEHelper/Compiler/BfModuleTypeUtils.cpp

@@ -13786,7 +13786,9 @@ BfIRValue BfModule::CastToValue(BfAstNode* srcNode, BfTypedValue typedVal, BfTyp
 			// For these casts, it's just important we get *A* value to work with here,
 			//  as this is just use for unspecialized parsing.  We don't use the generated code
 			{
-				auto genericParamInst = GetGenericParamInstance((BfGenericParamType*)typedVal.mType);
+				auto genericParamInst = GetGenericParamInstance((BfGenericParamType*)typedVal.mType, false, BfFailHandleKind_Soft);
+				if (genericParamInst == NULL)
+					return BfIRValue();
 				retVal = _CheckGenericParamInstance(genericParamInst);
 				if (retVal)
 					return retVal;
@@ -13822,7 +13824,9 @@ BfIRValue BfModule::CastToValue(BfAstNode* srcNode, BfTypedValue typedVal, BfTyp
 			}
 		}
 
-		auto genericParamInst = GetGenericParamInstance((BfGenericParamType*)toType);
+		auto genericParamInst = GetGenericParamInstance((BfGenericParamType*)toType, false, BfFailHandleKind_Soft);
+		if (genericParamInst == NULL)
+			return GetDefaultValue(toType);
 		if (genericParamInst->mGenericParamFlags & BfGenericParamFlag_Var)
 			return GetDefaultValue(toType);
 

+ 20 - 0
IDEHelper/Tests/src/Generics.bf

@@ -428,6 +428,21 @@ namespace Tests
 			}
 		}
 
+		struct A<T>
+		{
+			public int mA;
+
+			public static Self operator implicit<U>(U value)
+			{
+				return default;
+			}
+
+			public void operator +=<U>(A<U> r) mut
+			{
+				mA += r.mA;
+			}
+		}
+
 		[Test]
 		public static void TestBasics()
 		{
@@ -521,6 +536,11 @@ namespace Tests
 
 			int iVal = 123;
 			Test.Assert(IntPtrTest(&iVal) == 123);
+
+			A<int> a = .() { mA = 10 };
+			A<float> b = .() { mA = 2 };
+			a += b;
+			Test.Assert(a.mA == 12);
 		}
 	}