浏览代码

Added ability to dynamically cast delegates with compatible signatures

Brian Fiete 5 月之前
父节点
当前提交
37f72cd3b6

+ 6 - 0
BeefLibs/corlib/src/Object.bf

@@ -130,6 +130,12 @@ namespace System
 		{
 		    return null;
 		}
+
+		[NoShow]
+		public virtual Object DynamicCastToSignature(int32 sig)
+		{
+		    return null;
+		}
 #endif
 
         int IHashable.GetHashCode()

+ 6 - 0
IDEHelper/Compiler/BfCompiler.cpp

@@ -5622,6 +5622,12 @@ void BfCompiler::MarkStringPool(BfModule* module)
 		stringPoolEntry.mLastUsedRevision = mRevision;
 	}
 
+	for (int stringId : module->mSignatureIdRefs)
+	{
+		BfStringPoolEntry& stringPoolEntry = module->mContext->mStringObjectIdMap[stringId];
+		stringPoolEntry.mLastUsedRevision = mRevision;
+	}
+
 	/*if (module->mOptModule != NULL)
 		MarkStringPool(module->mOptModule);*/
 	auto altModule = module->mNextAltModule;

+ 21 - 0
IDEHelper/Compiler/BfDefBuilder.cpp

@@ -1398,6 +1398,27 @@ void BfDefBuilder::AddDynamicCastMethods(BfTypeDef* typeDef)
 		methodDef->mReturnTypeRef = typeDef->mSystem->mDirectObjectTypeRef;
 		methodDef->mIsNoReflect = true;
 	}
+
+	if ((typeDef->mIsDelegate) && (!typeDef->mIsClosure))
+	{
+		auto methodDef = new BfMethodDef();
+		methodDef->mIdx = (int)typeDef->mMethods.size();
+		typeDef->mMethods.push_back(methodDef);
+		methodDef->mDeclaringType = typeDef;
+		methodDef->mName = BF_METHODNAME_DYNAMICCAST_SIGNATURE;
+		methodDef->mProtection = BfProtection_Protected;
+		methodDef->mIsStatic = false;
+		methodDef->mMethodType = BfMethodType_Normal;
+		methodDef->mIsVirtual = true;
+		methodDef->mIsOverride = true;
+
+		auto paramDef = new BfParameterDef();
+		paramDef->mName = "sig";
+		paramDef->mTypeRef = typeDef->mSystem->mDirectInt32TypeRef;
+		methodDef->mParams.push_back(paramDef);
+		methodDef->mReturnTypeRef = typeDef->mSystem->mDirectObjectTypeRef;
+		methodDef->mIsNoReflect = true;
+	}
 }
 
 void BfDefBuilder::AddParam(BfMethodDef* methodDef, BfTypeReference* typeRef, const StringImpl& paramName)

+ 144 - 60
IDEHelper/Compiler/BfModule.cpp

@@ -5053,91 +5053,106 @@ void BfModule::CreateDynamicCastMethod()
 	}
 
 	bool isInterfacePass = mCurMethodInstance->mMethodDef->mName == BF_METHODNAME_DYNAMICCAST_INTERFACE;
-
+	bool isSignaturePass = mCurMethodInstance->mMethodDef->mName == BF_METHODNAME_DYNAMICCAST_SIGNATURE;
+	
 	auto func = mCurMethodState->mIRFunction;
 	auto thisValue = mBfIRBuilder->GetArgument(0);
 	auto typeIdValue = mBfIRBuilder->GetArgument(1);
 
-	auto intPtrType = GetPrimitiveType(BfTypeCode_IntPtr);
-	auto int32Type = GetPrimitiveType(BfTypeCode_Int32);
-	typeIdValue = CastToValue(NULL, BfTypedValue(typeIdValue, intPtrType), int32Type, (BfCastFlags)(BfCastFlags_Explicit | BfCastFlags_SilentFail));
-
 	auto thisObject = mBfIRBuilder->CreateBitCast(thisValue, mBfIRBuilder->MapType(objType));
 
 	auto trueBB = mBfIRBuilder->CreateBlock("check.true");
 	//auto falseBB = mBfIRBuilder->CreateBlock("check.false");
 	auto exitBB = mBfIRBuilder->CreateBlock("exit");
 
-	SizedArray<int, 8> typeMatches;
-	SizedArray<BfTypeInstance*, 8> exChecks;
-	FindSubTypes(mCurTypeInstance, &typeMatches, &exChecks, isInterfacePass);
+	Array<BfIRValue> incomingFalses;
+	BfIRBlock curBlock;
 
-	if ((mCurTypeInstance->IsGenericTypeInstance()) && (!mCurTypeInstance->IsUnspecializedType()))
+	if (isSignaturePass)
 	{
-		// Add 'unbound' type id to cast list so things like "List<int> is List<>" work
-		auto genericTypeInst = mCurTypeInstance->mTypeDef;
-		BfTypeVector genericArgs;
-		for (int i = 0; i < (int) genericTypeInst->mGenericParamDefs.size(); i++)
-			genericArgs.push_back(GetGenericParamType(BfGenericParamKind_Type, i));
-		auto unboundType = ResolveTypeDef(mCurTypeInstance->mTypeDef->GetDefinition(), genericArgs, BfPopulateType_Declaration);
-		typeMatches.push_back(unboundType->mTypeId);
-	}
+		//auto falseBB = mBfIRBuilder->CreateBlock("check.false");
+		curBlock = mBfIRBuilder->GetInsertBlock();
 
-	if (mCurTypeInstance->IsBoxed())
+		auto signatureId = GetDelegateSignatureId(mCurTypeInstance);		
+		auto eqResult = mBfIRBuilder->CreateCmpEQ(typeIdValue, GetConstValue32(signatureId));
+		mBfIRBuilder->CreateCondBr(eqResult, trueBB, exitBB);				
+	}
+	else
 	{
-		BfBoxedType* boxedType = (BfBoxedType*)mCurTypeInstance;
-		BfTypeInstance* innerType = boxedType->mElementType->ToTypeInstance();
+		auto intPtrType = GetPrimitiveType(BfTypeCode_IntPtr);
+		auto int32Type = GetPrimitiveType(BfTypeCode_Int32);
+		typeIdValue = CastToValue(NULL, BfTypedValue(typeIdValue, intPtrType), int32Type, (BfCastFlags)(BfCastFlags_Explicit | BfCastFlags_SilentFail));
 
-		FindSubTypes(innerType, &typeMatches, &exChecks, isInterfacePass);
+		SizedArray<int, 8> typeMatches;
+		SizedArray<BfTypeInstance*, 8> exChecks;
+		FindSubTypes(mCurTypeInstance, &typeMatches, &exChecks, isInterfacePass);
 
-		if (innerType->IsTypedPrimitive())
+		if ((mCurTypeInstance->IsGenericTypeInstance()) && (!mCurTypeInstance->IsUnspecializedType()))
 		{
-			auto underlyingType = innerType->GetUnderlyingType();
-			typeMatches.push_back(underlyingType->mTypeId);
+			// Add 'unbound' type id to cast list so things like "List<int> is List<>" work
+			auto genericTypeInst = mCurTypeInstance->mTypeDef;
+			BfTypeVector genericArgs;
+			for (int i = 0; i < (int)genericTypeInst->mGenericParamDefs.size(); i++)
+				genericArgs.push_back(GetGenericParamType(BfGenericParamKind_Type, i));
+			auto unboundType = ResolveTypeDef(mCurTypeInstance->mTypeDef->GetDefinition(), genericArgs, BfPopulateType_Declaration);
+			typeMatches.push_back(unboundType->mTypeId);
 		}
 
-		auto innerTypeInst = innerType->ToTypeInstance();
-		if ((innerTypeInst->IsInstanceOf(mCompiler->mSizedArrayTypeDef)) ||
-			(innerTypeInst->IsInstanceOf(mCompiler->mPointerTTypeDef)) ||
-			(innerTypeInst->IsInstanceOf(mCompiler->mMethodRefTypeDef)))
+		if (mCurTypeInstance->IsBoxed())
 		{
-			PopulateType(innerTypeInst);
-			//TODO: What case was this supposed to handle?
-			//typeMatches.push_back(innerTypeInst->mFieldInstances[0].mResolvedType->mTypeId);
-		}
-	}
+			BfBoxedType* boxedType = (BfBoxedType*)mCurTypeInstance;
+			BfTypeInstance* innerType = boxedType->mElementType->ToTypeInstance();
 
-	auto curBlock = mBfIRBuilder->GetInsertBlock();
+			FindSubTypes(innerType, &typeMatches, &exChecks, isInterfacePass);
 
-	BfIRValue vDataPtr;
-	if (!exChecks.empty())
-	{
-		BfType* intPtrType = GetPrimitiveType(BfTypeCode_IntPtr);
-		auto ptrPtrType = mBfIRBuilder->GetPointerTo(mBfIRBuilder->GetPointerTo(mBfIRBuilder->MapType(intPtrType)));
-		auto vDataPtrPtr = mBfIRBuilder->CreateBitCast(thisValue, ptrPtrType);
-		vDataPtr = FixClassVData(mBfIRBuilder->CreateLoad(vDataPtrPtr/*, "vtable"*/));
-	}
+			if (innerType->IsTypedPrimitive())
+			{
+				auto underlyingType = innerType->GetUnderlyingType();
+				typeMatches.push_back(underlyingType->mTypeId);
+			}
+
+			auto innerTypeInst = innerType->ToTypeInstance();
+			if ((innerTypeInst->IsInstanceOf(mCompiler->mSizedArrayTypeDef)) ||
+				(innerTypeInst->IsInstanceOf(mCompiler->mPointerTTypeDef)) ||
+				(innerTypeInst->IsInstanceOf(mCompiler->mMethodRefTypeDef)))
+			{
+				PopulateType(innerTypeInst);
+				//TODO: What case was this supposed to handle?
+				//typeMatches.push_back(innerTypeInst->mFieldInstances[0].mResolvedType->mTypeId);
+			}
+		}
 
-	auto switchStatement = mBfIRBuilder->CreateSwitch(typeIdValue, exitBB, (int)typeMatches.size() + (int)exChecks.size());
-	for (auto typeMatch : typeMatches)
-		mBfIRBuilder->AddSwitchCase(switchStatement, GetConstValue32(typeMatch), trueBB);
+		curBlock = mBfIRBuilder->GetInsertBlock();
 
-	Array<BfIRValue> incomingFalses;
-	for (auto ifaceTypeInst : exChecks)
-	{
-		BfIRBlock nextBB = mBfIRBuilder->CreateBlock("exCheck", true);
-		mBfIRBuilder->AddSwitchCase(switchStatement, GetConstValue32(ifaceTypeInst->mTypeId), nextBB);
-		mBfIRBuilder->SetInsertPoint(nextBB);
+		BfIRValue vDataPtr;
+		if (!exChecks.empty())
+		{
+			BfType* intPtrType = GetPrimitiveType(BfTypeCode_IntPtr);
+			auto ptrPtrType = mBfIRBuilder->GetPointerTo(mBfIRBuilder->GetPointerTo(mBfIRBuilder->MapType(intPtrType)));
+			auto vDataPtrPtr = mBfIRBuilder->CreateBitCast(thisValue, ptrPtrType);
+			vDataPtr = FixClassVData(mBfIRBuilder->CreateLoad(vDataPtrPtr/*, "vtable"*/));
+		}
+
+		auto switchStatement = mBfIRBuilder->CreateSwitch(typeIdValue, exitBB, (int)typeMatches.size() + (int)exChecks.size());
+		for (auto typeMatch : typeMatches)
+			mBfIRBuilder->AddSwitchCase(switchStatement, GetConstValue32(typeMatch), trueBB);
+		
+		for (auto ifaceTypeInst : exChecks)
+		{
+			BfIRBlock nextBB = mBfIRBuilder->CreateBlock("exCheck", true);
+			mBfIRBuilder->AddSwitchCase(switchStatement, GetConstValue32(ifaceTypeInst->mTypeId), nextBB);
+			mBfIRBuilder->SetInsertPoint(nextBB);
 
-		BfIRValue slotOfs = GetInterfaceSlotNum(ifaceTypeInst);
+			BfIRValue slotOfs = GetInterfaceSlotNum(ifaceTypeInst);
 
-		auto ifacePtrPtr = mBfIRBuilder->CreateInBoundsGEP(vDataPtr, slotOfs/*, "iface"*/);
-		auto ifacePtr = mBfIRBuilder->CreateLoad(ifacePtrPtr);
+			auto ifacePtrPtr = mBfIRBuilder->CreateInBoundsGEP(vDataPtr, slotOfs/*, "iface"*/);
+			auto ifacePtr = mBfIRBuilder->CreateLoad(ifacePtrPtr);
 
-		auto cmpResult = mBfIRBuilder->CreateCmpNE(ifacePtr, mBfIRBuilder->CreateConst(BfTypeCode_IntPtr, 0));
-		mBfIRBuilder->CreateCondBr(cmpResult, trueBB, exitBB);
+			auto cmpResult = mBfIRBuilder->CreateCmpNE(ifacePtr, mBfIRBuilder->CreateConst(BfTypeCode_IntPtr, 0));
+			mBfIRBuilder->CreateCondBr(cmpResult, trueBB, exitBB);
 
-		incomingFalses.push_back(nextBB);
+			incomingFalses.push_back(nextBB);
+		}
 	}
 
 	mBfIRBuilder->AddBlock(trueBB);
@@ -10947,8 +10962,25 @@ void BfModule::EmitDynamicCastCheck(const BfTypedValue& targetValue, BfType* tar
 
 	auto typeTypeInstance = ResolveTypeDef(mCompiler->mReflectTypeInstanceTypeDef)->ToTypeInstance();
 
-	if (mCompiler->mOptions.mAllowHotSwapping)
+	if (targetType->IsDelegate())
 	{
+		// Delegate signature check
+		int signatureId = GetDelegateSignatureId(targetType->ToTypeInstance());		
+		BfExprEvaluator exprEvaluator(this);
+
+		AddBasicBlock(checkBB);
+		auto objectParam = mBfIRBuilder->CreateBitCast(targetValue.mValue, mBfIRBuilder->MapType(mContext->mBfObjectType));
+		auto moduleMethodInstance = GetMethodByName(mContext->mBfObjectType, "DynamicCastToSignature");
+		SizedArray<BfIRValue, 4> irArgs;
+		irArgs.push_back(objectParam);
+		irArgs.push_back(GetConstValue32(signatureId));
+		auto callResult = exprEvaluator.CreateCall(NULL, moduleMethodInstance.mMethodInstance, moduleMethodInstance.mFunc, false, irArgs);
+		auto cmpResult = mBfIRBuilder->CreateCmpNE(callResult.mValue, GetDefaultValue(callResult.mType));
+		irb->CreateCondBr(cmpResult, trueBlock, falseBlock);
+	}
+	else if (mCompiler->mOptions.mAllowHotSwapping)
+	{
+		// "Slow" check
 		BfExprEvaluator exprEvaluator(this);
 
 		AddBasicBlock(checkBB);
@@ -10967,7 +10999,7 @@ void BfModule::EmitDynamicCastCheck(const BfTypedValue& targetValue, BfType* tar
 		BfIRValue vDataPtr = irb->CreateBitCast(targetValue.mValue, irb->MapType(intPtrType));
 		vDataPtr = irb->CreateLoad(vDataPtr);
 		if ((mCompiler->mOptions.mObjectHasDebugFlags) && (!mIsComptimeModule))
-			vDataPtr = irb->CreateAnd(vDataPtr, irb->CreateConst(BfTypeCode_IntPtr, (uint64)~0xFFULL));
+			vDataPtr = irb->CreateAnd(vDataPtr, irb->CreateConst(BfTypeCode_IntPtr, (uint64)~0xFFULL));		
 
 		if (targetType->IsInterface())
 		{
@@ -17291,9 +17323,61 @@ BfType* BfModule::GetDelegateReturnType(BfType* delegateType)
 
 BfMethodInstance* BfModule::GetDelegateInvokeMethod(BfTypeInstance* typeInstance)
 {
+	if (typeInstance->IsClosure())
+		typeInstance = typeInstance->mBaseType;
 	return GetRawMethodInstanceAtIdx(typeInstance, 0, "Invoke");
 }
 
+String BfModule::GetDelegateSignatureString(BfTypeInstance* typeInstance)
+{
+	auto invokeMethod = GetDelegateInvokeMethod(typeInstance);
+	if (invokeMethod == NULL)
+		return "";
+
+	String sigString = "";
+	sigString = TypeToString(invokeMethod->mReturnType);
+	sigString += "(";
+	for (int paramIdx = 0; paramIdx < invokeMethod->GetParamCount(); paramIdx++)
+	{
+		if (paramIdx > 0)
+			sigString += ", ";
+
+		auto paramKind = invokeMethod->GetParamKind(paramIdx);
+
+		if (paramKind == BfParamKind_Params)
+		{
+			sigString += "params ";
+		}
+
+		auto paramType = invokeMethod->GetParamType(paramIdx);
+		sigString += TypeToString(paramType);
+		
+		if (paramKind == BfParamKind_ExplicitThis)
+			sigString += " this";
+	}
+	sigString += ")";
+	return sigString;
+}
+
+int BfModule::GetSignatureId(const StringImpl& str)
+{
+	int strId = mContext->GetStringLiteralId(str);
+	mSignatureIdRefs.Add(strId);
+	return strId;
+}
+
+int BfModule::GetDelegateSignatureId(BfTypeInstance* typeInstance)
+{
+	BF_ASSERT(typeInstance->IsDelegate());
+	if (typeInstance->mTypeInfoEx == NULL)
+	{
+		typeInstance->mTypeInfoEx = new BfTypeInfoEx();
+		auto signature = GetDelegateSignatureString(typeInstance);
+		typeInstance->mTypeInfoEx->mMinValue = GetSignatureId(signature);
+	}
+	return (int)typeInstance->mTypeInfoEx->mMinValue;
+}
+
 void BfModule::CreateDelegateInvokeMethod()
 {
 	// Clear out debug loc - otherwise we'll single step onto the delegate type declaration
@@ -22544,7 +22628,7 @@ void BfModule::ProcessMethod(BfMethodInstance* methodInstance, bool isInlineDup,
 				mCurMethodState->mLeftBlockUncond = true;
 			}
 		}
-		else if ((methodDef->mName == BF_METHODNAME_DYNAMICCAST) || (methodDef->mName == BF_METHODNAME_DYNAMICCAST_INTERFACE))
+		else if ((methodDef->mName == BF_METHODNAME_DYNAMICCAST) || (methodDef->mName == BF_METHODNAME_DYNAMICCAST_INTERFACE) || (methodDef->mName == BF_METHODNAME_DYNAMICCAST_SIGNATURE))
 		{
 			if (mCurTypeInstance->IsObject())
 			{

+ 4 - 0
IDEHelper/Compiler/BfModule.h

@@ -1552,6 +1552,7 @@ public:
 	Dictionary<int, BfIRValue> mStringCharPtrPool;
 	Array<int> mStringPoolRefs;
 	HashSet<int> mUnreifiedStringPoolRefs;
+	HashSet<int> mSignatureIdRefs;
 
 	Array<BfIRBuilder*> mPrevIRBuilders; // Before extensions
 	BfIRBuilder* mBfIRBuilder;
@@ -2019,6 +2020,9 @@ public:
 	void CreateDelegateInvokeMethod();
 	BfType* GetDelegateReturnType(BfType* delegateType);
 	BfMethodInstance* GetDelegateInvokeMethod(BfTypeInstance* typeInstance);
+	String GetDelegateSignatureString(BfTypeInstance* typeInstance);	
+	int GetSignatureId(const StringImpl& str);
+	int GetDelegateSignatureId(BfTypeInstance* typeInstance);
 	String GetLocalMethodName(const StringImpl& baseName, BfAstNode* anchorNode, BfMethodState* declMethodState, BfMixinState* declMixinState);
 	BfMethodDef* GetLocalMethodDef(BfLocalMethod* localMethod);
 	BfModuleMethodInstance GetLocalMethodInstance(BfLocalMethod* localMethod, const BfTypeVector& methodGenericArguments, BfMethodInstance* methodInstance = NULL, bool force = false);

+ 1 - 0
IDEHelper/Compiler/BfSystem.h

@@ -851,6 +851,7 @@ enum BfCallingConvention : uint8
 #define BF_METHODNAME_FIND_TLS_MEMBERS "GCFindTLSMembers"
 #define BF_METHODNAME_DYNAMICCAST "DynamicCastToTypeId"
 #define BF_METHODNAME_DYNAMICCAST_INTERFACE "DynamicCastToInterface"
+#define BF_METHODNAME_DYNAMICCAST_SIGNATURE "DynamicCastToSignature"
 #define BF_METHODNAME_CALCAPPEND "this$calcAppend"
 #define BF_METHODNAME_ENUM_HASFLAG "HasFlag"
 #define BF_METHODNAME_ENUM_GETUNDERLYING "get__Underlying"

+ 15 - 3
IDEHelper/Compiler/CeMachine.cpp

@@ -8110,16 +8110,28 @@ bool CeContext::Execute(CeFunction* startFunction, uint8* startStackPtr, uint8*
 			{
 				CE_CHECKADDR(valueAddr, sizeof(int32));
 
-				auto ifaceType = GetBfType(ifaceId);
+				auto wantType = GetBfType(ifaceId);
 				int32 objTypeId = *(int32*)(memStart + valueAddr);
 				auto valueType = GetBfType(objTypeId);
-				if ((ifaceType == NULL) || (valueType == NULL))
+				if ((wantType == NULL) || (valueType == NULL))
 				{
 					_Fail("Invalid type in CeOp_DynamicCastCheck");
 					return false;
 				}
 
-				if (ceModule->TypeIsSubTypeOf(valueType->ToTypeInstance(), ifaceType->ToTypeInstance(), false))
+				bool matches = false;
+				if (ceModule->TypeIsSubTypeOf(valueType->ToTypeInstance(), wantType->ToTypeInstance(), false))
+				{
+					matches = true;					
+				}
+				else if ((valueType->IsDelegate()) && (wantType->IsDelegate()))
+				{
+					int valueSignatureId = ceModule->GetDelegateSignatureId(valueType->ToTypeInstance());
+					int checkSignatureId = ceModule->GetDelegateSignatureId(wantType->ToTypeInstance());
+					matches = valueSignatureId == checkSignatureId;
+				}
+				
+				if (matches)
 					CeSetAddrVal(&result, valueAddr, ptrSize);
 				else
 					CeSetAddrVal(&result, 0, ptrSize);

+ 8 - 1
IDEHelper/Tests/src/Delegates.bf

@@ -232,9 +232,16 @@ namespace Tests
 
 		public static void TestCasting()
 		{
-			delegate int(int, int) dlg0 = null;
+			delegate int(int, int) dlg0 = scope (a, b) => 1;
 			delegate int(int a, int b) dlg1 = dlg0;
 			delegate int(int a2, int b2) dlg2 = (.)dlg1;
+			delegate int(float a, float b) dlg3 = scope (a, b) => 2;
+
+			Object obj = dlg0;
+			dlg1 = (.)obj;
+			dlg2 = (.)obj;
+			Test.Assert(obj is delegate int(int a, int b));
+			Test.Assert(!(obj is delegate int(float a, float b)));
 
 			function int(int, int) func0 = null;
 			function int(int a, int b) func1 = func0;