Browse Source

Added 'interface' and 'enum' constraints

Brian Fiete 4 years ago
parent
commit
f41365a58e

+ 4 - 4
BeefLibs/Beefy2D/src/utils/StructuredData.bf

@@ -423,7 +423,7 @@ namespace Beefy.utils
 			}
 			}
 		}
 		}
 
 
-		public void Get<T>(StringView name, ref T val) where T : Enum
+		public void Get<T>(StringView name, ref T val) where T : enum
 		{
 		{
 			Object obj = Get(name);
 			Object obj = Get(name);
 			if (obj == null)
 			if (obj == null)
@@ -547,7 +547,7 @@ namespace Beefy.utils
             return (bool)aVal;
             return (bool)aVal;
         }
         }
 
 
-        public T GetEnum<T>(String name, T defaultVal = default(T)) where T : Enum
+        public T GetEnum<T>(String name, T defaultVal = default(T)) where T : enum
         {
         {
             Object obj = Get(name);
             Object obj = Get(name);
 			if (obj == null)
 			if (obj == null)
@@ -566,7 +566,7 @@ namespace Beefy.utils
 			return defaultVal;
 			return defaultVal;
         }
         }
 
 
-		public bool GetEnum<T>(String name, ref T val) where T : Enum
+		public bool GetEnum<T>(String name, ref T val) where T : enum
 		{
 		{
 			Object obj = Get(name);
 			Object obj = Get(name);
 			if (obj == null)
 			if (obj == null)
@@ -614,7 +614,7 @@ namespace Beefy.utils
 			return;
 			return;
         }
         }
 
 
-		public T GetCurEnum<T>(T theDefault = default) where T : Enum
+		public T GetCurEnum<T>(T theDefault = default) where T : enum
 		{
 		{
 			Object obj = GetCurrent();
 			Object obj = GetCurrent();
 			
 			

+ 1 - 1
BeefLibs/corlib/src/Enum.bf

@@ -18,7 +18,7 @@ namespace System
 			((int32)iVal).ToString(strBuffer);
 			((int32)iVal).ToString(strBuffer);
 		}
 		}
 
 
-		public static Result<T> Parse<T>(StringView str, bool ignoreCase = false) where T : Enum
+		public static Result<T> Parse<T>(StringView str, bool ignoreCase = false) where T : enum
 		{
 		{
 			var typeInst = (TypeInstance)typeof(T);
 			var typeInst = (TypeInstance)typeof(T);
 			for (var field in typeInst.GetFields())
 			for (var field in typeInst.GetFields())

+ 12 - 3
IDEHelper/Compiler/BfDefBuilder.cpp

@@ -285,9 +285,10 @@ void BfDefBuilder::ParseGenericParams(BfGenericParamsDeclaration* genericParamsD
 
 
 			if (!name.empty())
 			if (!name.empty())
 			{
 			{
-				if ((name == "class") || (name == "struct") || (name == "struct*") || (name == "const") || (name == "var"))
+				if ((name == "class") || (name == "struct") || (name == "struct*") || (name == "const") || (name == "var") || (name == "interface") || (name == "enum"))
 				{
 				{
-					int prevFlags = constraintDef->mGenericParamFlags & (BfGenericParamFlag_Class | BfGenericParamFlag_Struct | BfGenericParamFlag_StructPtr);
+					int prevFlags = constraintDef->mGenericParamFlags & 
+						(BfGenericParamFlag_Class | BfGenericParamFlag_Struct | BfGenericParamFlag_StructPtr | BfGenericParamFlag_Interface | BfGenericParamFlag_Enum);
 					if (prevFlags != 0)					
 					if (prevFlags != 0)					
 					{
 					{
 						String prevFlagName;
 						String prevFlagName;
@@ -295,8 +296,12 @@ void BfDefBuilder::ParseGenericParams(BfGenericParamsDeclaration* genericParamsD
 							prevFlagName = "class";
 							prevFlagName = "class";
 						else if (prevFlags & BfGenericParamFlag_Struct)
 						else if (prevFlags & BfGenericParamFlag_Struct)
 							prevFlagName = "struct";
 							prevFlagName = "struct";
-						else //
+						else if (prevFlags & BfGenericParamFlag_StructPtr)
 							prevFlagName = "struct*";
 							prevFlagName = "struct*";
+						else if (prevFlags & BfGenericParamFlag_Enum)
+							prevFlagName = "enum";
+						else // interface
+							prevFlagName = "interface";
 
 
 						if (prevFlagName == name)
 						if (prevFlagName == name)
 							Fail(StrFormat("Cannot specify '%s' twice", prevFlagName.c_str()), constraintNode);
 							Fail(StrFormat("Cannot specify '%s' twice", prevFlagName.c_str()), constraintNode);
@@ -313,6 +318,10 @@ void BfDefBuilder::ParseGenericParams(BfGenericParamsDeclaration* genericParamsD
 						constraintDef->mGenericParamFlags = (BfGenericParamFlags)(constraintDef->mGenericParamFlags | BfGenericParamFlag_StructPtr);
 						constraintDef->mGenericParamFlags = (BfGenericParamFlags)(constraintDef->mGenericParamFlags | BfGenericParamFlag_StructPtr);
 					else if (name == "const")
 					else if (name == "const")
 						constraintDef->mGenericParamFlags = (BfGenericParamFlags)(constraintDef->mGenericParamFlags | BfGenericParamFlag_Const);
 						constraintDef->mGenericParamFlags = (BfGenericParamFlags)(constraintDef->mGenericParamFlags | BfGenericParamFlag_Const);
+					else if (name == "interface")
+						constraintDef->mGenericParamFlags = (BfGenericParamFlags)(constraintDef->mGenericParamFlags | BfGenericParamFlag_Interface);
+					else if (name == "enum")
+						constraintDef->mGenericParamFlags = (BfGenericParamFlags)(constraintDef->mGenericParamFlags | BfGenericParamFlag_Enum);
 					else //if (name == "var")
 					else //if (name == "var")
 						constraintDef->mGenericParamFlags = (BfGenericParamFlags)(constraintDef->mGenericParamFlags | BfGenericParamFlag_Var);
 						constraintDef->mGenericParamFlags = (BfGenericParamFlags)(constraintDef->mGenericParamFlags | BfGenericParamFlag_Var);
 										
 										

+ 23 - 0
IDEHelper/Compiler/BfModule.cpp

@@ -7247,6 +7247,29 @@ bool BfModule::CheckGenericConstraints(const BfGenericParamSource& genericParamS
 		return false;
 		return false;
 	}
 	}
 
 
+	if (genericParamInst->mGenericParamFlags & BfGenericParamFlag_Enum)
+	{
+		bool isEnum = checkArgType->IsEnum();
+		if ((origCheckArgType->IsGenericParam()) && (checkArgType->IsInstanceOf(mCompiler->mEnumTypeDef)))
+			isEnum = true;
+		if (((checkGenericParamFlags & (BfGenericParamFlag_Enum | BfGenericParamFlag_Var)) == 0) && (!isEnum))
+		{
+			if (!ignoreErrors)
+				*errorOut = Fail(StrFormat("The type '%s' must be an enum type in order to use it as parameter '%s' for '%s'",
+					TypeToString(origCheckArgType).c_str(), genericParamInst->GetName().c_str(), GenericParamSourceToString(genericParamSource).c_str()), checkArgTypeRef);
+			return false;
+		}
+	}
+
+	if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Interface) &&
+		((checkGenericParamFlags & (BfGenericParamFlag_Interface | BfGenericParamFlag_Var)) == 0) && (!checkArgType->IsInterface()))
+	{
+		if (!ignoreErrors)
+			*errorOut = Fail(StrFormat("The type '%s' must be an interface type in order to use it as parameter '%s' for '%s'",
+				TypeToString(origCheckArgType).c_str(), genericParamInst->GetName().c_str(), GenericParamSourceToString(genericParamSource).c_str()), checkArgTypeRef);
+		return false;
+	}
+
 	if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Const) != 0)
 	if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Const) != 0)
 	{
 	{
 		if (((checkGenericParamFlags & BfGenericParamFlag_Const) == 0) && (!checkArgType->IsConstExprValue()))
 		if (((checkGenericParamFlags & BfGenericParamFlag_Const) == 0) && (!checkArgType->IsConstExprValue()))

+ 42 - 5
IDEHelper/Compiler/BfModuleTypeUtils.cpp

@@ -363,6 +363,9 @@ bool BfModule::AreConstraintsSubset(BfGenericParamInstance* checkInner, BfGeneri
 	{
 	{
 		// If the outer had a type flag and the inner has a specific type constraint, then see if those are compatible
 		// If the outer had a type flag and the inner has a specific type constraint, then see if those are compatible
 		auto outerFlags = checkOuter->mGenericParamFlags;
 		auto outerFlags = checkOuter->mGenericParamFlags;
+		if ((outerFlags & BfGenericParamFlag_Enum) != 0)
+			outerFlags |= BfGenericParamFlag_Struct;
+
 		if (checkOuter->mTypeConstraint != NULL)
 		if (checkOuter->mTypeConstraint != NULL)
 		{
 		{
 			if (checkOuter->mTypeConstraint->IsStruct())
 			if (checkOuter->mTypeConstraint->IsStruct())
@@ -371,9 +374,17 @@ bool BfModule::AreConstraintsSubset(BfGenericParamInstance* checkInner, BfGeneri
 				outerFlags |= BfGenericParamFlag_StructPtr;
 				outerFlags |= BfGenericParamFlag_StructPtr;
 			else if (checkOuter->mTypeConstraint->IsObject())
 			else if (checkOuter->mTypeConstraint->IsObject())
 				outerFlags |= BfGenericParamFlag_Class;
 				outerFlags |= BfGenericParamFlag_Class;
+			else if (checkOuter->mTypeConstraint->IsEnum())
+				outerFlags |= BfGenericParamFlag_Enum | BfGenericParamFlag_Struct;
+			else if (checkOuter->mTypeConstraint->IsInterface())
+				outerFlags |= BfGenericParamFlag_Interface;
 		}
 		}
 
 
-		if (((checkInner->mGenericParamFlags | outerFlags) & ~BfGenericParamFlag_Var) != (outerFlags & ~BfGenericParamFlag_Var))
+		auto innerFlags = checkInner->mGenericParamFlags;
+		if ((innerFlags & BfGenericParamFlag_Enum) != 0)
+			innerFlags |= BfGenericParamFlag_Struct;
+
+		if (((innerFlags | outerFlags) & ~BfGenericParamFlag_Var) != (outerFlags & ~BfGenericParamFlag_Var))
 			return false;
 			return false;
 	}
 	}
 
 
@@ -8481,7 +8492,7 @@ BfType* BfModule::ResolveTypeRef(BfTypeReference* typeRef, BfPopulateType popula
 					{
 					{
 						auto genericParam = GetGenericParamInstance((BfGenericParamType*)resolvedType);
 						auto genericParam = GetGenericParamInstance((BfGenericParamType*)resolvedType);
 						if (((genericParam->mTypeConstraint != NULL) && (genericParam->mTypeConstraint->IsValueType())) ||
 						if (((genericParam->mTypeConstraint != NULL) && (genericParam->mTypeConstraint->IsValueType())) ||
-							((genericParam->mGenericParamFlags & (BfGenericParamFlag_Struct | BfGenericParamFlag_StructPtr)) != 0))
+							((genericParam->mGenericParamFlags & (BfGenericParamFlag_Struct | BfGenericParamFlag_StructPtr | BfGenericParamFlag_Enum)) != 0))
 						{
 						{
 							resolvedType = CreatePointerType(resolvedType);
 							resolvedType = CreatePointerType(resolvedType);
 						}
 						}
@@ -9970,7 +9981,7 @@ BfIRValue BfModule::CastToValue(BfAstNode* srcNode, BfTypedValue typedVal, BfTyp
 			// Generic constrained with class or pointer type -> void*
 			// Generic constrained with class or pointer type -> void*
 			if (toType->IsVoidPtr())
 			if (toType->IsVoidPtr())
 			{
 			{
-				if ((genericParamInst->mGenericParamFlags & (BfGenericParamFlag_Class | BfGenericParamFlag_StructPtr)) ||
+				if (((genericParamInst->mGenericParamFlags & (BfGenericParamFlag_Class | BfGenericParamFlag_StructPtr | BfGenericParamFlag_Interface)) != 0) ||
 					((genericParamInst->mTypeConstraint != NULL) &&
 					((genericParamInst->mTypeConstraint != NULL) &&
 					((genericParamInst->mTypeConstraint->IsPointer()) || 
 					((genericParamInst->mTypeConstraint->IsPointer()) || 
 						(genericParamInst->mTypeConstraint->IsInstanceOf(mCompiler->mFunctionTypeDef)) || 
 						(genericParamInst->mTypeConstraint->IsInstanceOf(mCompiler->mFunctionTypeDef)) || 
@@ -9980,6 +9991,14 @@ BfIRValue BfModule::CastToValue(BfAstNode* srcNode, BfTypedValue typedVal, BfTyp
 				}
 				}
 			}
 			}
 
 
+			if (toType->IsInteger())
+			{
+				if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Enum) != 0)
+				{
+					return mBfIRBuilder->GetFakeVal();
+				}
+			}
+
 			return BfIRValue();
 			return BfIRValue();
 		};
 		};
 
 
@@ -10029,7 +10048,7 @@ BfIRValue BfModule::CastToValue(BfAstNode* srcNode, BfTypedValue typedVal, BfTyp
 		
 		
 		if (typedVal.mType->IsNull())
 		if (typedVal.mType->IsNull())
 		{
 		{
-			bool allowCast = (genericParamInst->mGenericParamFlags & BfGenericParamFlag_Class) || (genericParamInst->mGenericParamFlags & BfGenericParamFlag_StructPtr);
+			bool allowCast = (genericParamInst->mGenericParamFlags & (BfGenericParamFlag_Class | BfGenericParamFlag_StructPtr | BfGenericParamFlag_Interface)) != 0;
 			if ((!allowCast) && (genericParamInst->mTypeConstraint != NULL))
 			if ((!allowCast) && (genericParamInst->mTypeConstraint != NULL))
 				allowCast = genericParamInst->mTypeConstraint->IsObject() || genericParamInst->mTypeConstraint->IsPointer();
 				allowCast = genericParamInst->mTypeConstraint->IsObject() || genericParamInst->mTypeConstraint->IsPointer();
 			if (allowCast)
 			if (allowCast)
@@ -10052,7 +10071,7 @@ BfIRValue BfModule::CastToValue(BfAstNode* srcNode, BfTypedValue typedVal, BfTyp
 		
 		
 		if (explicitCast)
 		if (explicitCast)
 		{
 		{
-			if ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_StructPtr) ||
+			if (((genericParamInst->mGenericParamFlags & BfGenericParamFlag_StructPtr) != 0) ||
 				((genericParamInst->mTypeConstraint != NULL) && genericParamInst->mTypeConstraint->IsInstanceOf(mCompiler->mFunctionTypeDef)))
 				((genericParamInst->mTypeConstraint != NULL) && genericParamInst->mTypeConstraint->IsInstanceOf(mCompiler->mFunctionTypeDef)))
 			{
 			{
 				auto voidPtrType = CreatePointerType(GetPrimitiveType(BfTypeCode_None));
 				auto voidPtrType = CreatePointerType(GetPrimitiveType(BfTypeCode_None));
@@ -10061,6 +10080,24 @@ BfIRValue BfModule::CastToValue(BfAstNode* srcNode, BfTypedValue typedVal, BfTyp
 					return castedVal;
 					return castedVal;
 			}
 			}
 		}
 		}
+
+		if ((typedVal.mType->IsIntegral()) && ((genericParamInst->mGenericParamFlags & BfGenericParamFlag_Enum) != 0))
+		{
+			bool allowCast = explicitCast;			
+			if ((!allowCast) && (typedVal.mType->IsIntegral()))
+			{
+				// Allow implicit cast of zero
+				auto constant = mBfIRBuilder->GetConstant(typedVal.mValue);
+				if ((constant != NULL) && (mBfIRBuilder->IsInt(constant->mTypeCode)))
+				{
+					allowCast = constant->mInt64 == 0;
+				}
+			}
+			if (allowCast)
+			{
+				return mBfIRBuilder->GetFakeVal();
+			}
+		}
 	}
 	}
 
 
 	if ((typedVal.mType->IsTypeInstance()) && (toType->IsTypeInstance()))
 	if ((typedVal.mType->IsTypeInstance()) && (toType->IsTypeInstance()))

+ 2 - 0
IDEHelper/Compiler/BfReducer.cpp

@@ -9496,6 +9496,8 @@ BfGenericConstraintsDeclaration* BfReducer::CreateGenericConstraintsDeclaration(
 				case BfToken_Var:
 				case BfToken_Var:
 				case BfToken_New:
 				case BfToken_New:
 				case BfToken_Delete:
 				case BfToken_Delete:
+				case BfToken_Enum:
+				case BfToken_Interface:
 					addToConstraint = true;
 					addToConstraint = true;
 					break;
 					break;
 				case BfToken_Operator:
 				case BfToken_Operator:

+ 14 - 12
IDEHelper/Compiler/BfSystem.h

@@ -592,18 +592,20 @@ public:
 
 
 enum BfGenericParamFlags : uint16
 enum BfGenericParamFlags : uint16
 {
 {
-	BfGenericParamFlag_None      = 0,
-	BfGenericParamFlag_Class     = 1,
-	BfGenericParamFlag_Struct    = 2,
-	BfGenericParamFlag_StructPtr = 4,
-	BfGenericParamFlag_New       = 8,
-	BfGenericParamFlag_Delete    = 0x10,
-	BfGenericParamFlag_Var       = 0x20,
-	BfGenericParamFlag_Const     = 0x40,
-	BfGenericParamFlag_Equals    = 0x80,
-	BfGenericParamFlag_Equals_Op    = 0x100,
-	BfGenericParamFlag_Equals_Type  = 0x200,
-	BfGenericParamFlag_Equals_IFace = 0x400
+	BfGenericParamFlag_None			= 0,
+	BfGenericParamFlag_Class		= 1,
+	BfGenericParamFlag_Struct		= 2,
+	BfGenericParamFlag_StructPtr	= 4,
+	BfGenericParamFlag_Enum			= 8,
+	BfGenericParamFlag_Interface	= 0x10,	
+	BfGenericParamFlag_New			= 0x20,
+	BfGenericParamFlag_Delete		= 0x40,
+	BfGenericParamFlag_Var			= 0x80,
+	BfGenericParamFlag_Const		= 0x100,
+	BfGenericParamFlag_Equals		= 0x200,
+	BfGenericParamFlag_Equals_Op    = 0x400,
+	BfGenericParamFlag_Equals_Type  = 0x800,
+	BfGenericParamFlag_Equals_IFace = 0x1000
 };
 };
 
 
 class BfConstraintDef
 class BfConstraintDef

+ 12 - 2
IDEHelper/Tests/src/Generics.bf

@@ -132,16 +132,23 @@ namespace Tests
 			return 1;
 			return 1;
 		}
 		}
 
 
-		public static int MethodA<T>(T val) where T : ValueType
+		public static int MethodA<T>(T val) where T : struct
 		{
 		{
 			return 2;
 			return 2;
 		}
 		}
 
 
-		public static int MethodA<T>(T val) where T : Enum
+		public static int MethodA<T>(T val) where T : enum
 		{
 		{
+			int val2 = (int)val;
+			T val3 = 0;
 			return 3;
 			return 3;
 		}
 		}
 
 
+		public static int MethodA<T>(T val) where T : interface
+		{
+			return 4;
+		}
+
 		public struct Entry
 		public struct Entry
 		{
 		{
 			public static int operator<=>(Entry lhs, Entry rhs)
 			public static int operator<=>(Entry lhs, Entry rhs)
@@ -200,9 +207,12 @@ namespace Tests
 			LibA.LibA0.Alloc<ClassA>();
 			LibA.LibA0.Alloc<ClassA>();
 			LibA.LibA0.Alloc<ClassB>();
 			LibA.LibA0.Alloc<ClassB>();
 
 
+			IDisposable iDisp = null;
+
 			Test.Assert(MethodA("") == 1);
 			Test.Assert(MethodA("") == 1);
 			Test.Assert(MethodA(1.2f) == 2);
 			Test.Assert(MethodA(1.2f) == 2);
 			Test.Assert(MethodA(TypeCode.Boolean) == 3);
 			Test.Assert(MethodA(TypeCode.Boolean) == 3);
+			Test.Assert(MethodA(iDisp) == 4);
 
 
 			ClassC cc = scope .();
 			ClassC cc = scope .();
 			Test.Assert(ClassC.mInstance == cc);
 			Test.Assert(ClassC.mInstance == cc);