Browse Source

Fixed 'ref' params for reflected method invocations

Brian Fiete 5 years ago
parent
commit
c826bac949

+ 47 - 5
BeefLibs/corlib/src/Reflection/MethodInfo.bf

@@ -122,7 +122,24 @@ namespace System.Reflection
 				bool isPtrToPtr = false;
 				bool isValid = true;
 
-				if (paramType.IsValueType)
+				bool added =  false;
+
+				if (var refParamType = paramType as RefType)
+				{
+					if (argType.IsPointer)
+					{
+						Type elemType = argType.UnderlyingType;
+						if (elemType != refParamType.UnderlyingType)
+							isValid = false;
+
+						ffiParamList.Add(&FFIType.Pointer);
+						ffiArgList.Add(dataPtr);
+						added = true;
+					}
+					else
+						isValid = false;
+				}
+				else if (paramType.IsValueType)
 				{
 					if (argType.IsPointer)
 					{
@@ -162,7 +179,11 @@ namespace System.Reflection
 						return .Err(.InvalidArgument((.)argIdx));
 				}
 
-				if (paramType.IsStruct)
+				if (added)
+				{
+					// Already handled
+				}
+				else if (paramType.IsStruct)
 				{
 					TypeInstance paramTypeInst = (TypeInstance)paramType;
 
@@ -376,7 +397,24 @@ namespace System.Reflection
 				void* dataPtr = (uint8*)Internal.UnsafeCastToPtr(arg) + argType.[Friend]mMemberDataOffset;
 				bool isValid = true;
 
-				if (paramType.IsValueType)
+				bool added = false;
+
+				if (var refParamType = paramType as RefType)
+				{
+					if (argType.IsBoxedStructPtr || argType.IsBoxedPrimitivePtr)
+					{
+						var elemType = argType.BoxedPtrType;
+						if (elemType != refParamType.UnderlyingType)
+							isValid = false;
+
+						ffiParamList.Add(&FFIType.Pointer);
+						ffiArgList.Add(dataPtr);
+						added = true;
+					}
+					else
+						isValid = false;
+				}
+				else if (paramType.IsValueType)
 				{
 					bool handled = true;
 
@@ -387,7 +425,7 @@ namespace System.Reflection
 					if ((paramType.IsPrimitive) && (underlyingType.IsTypedPrimitive)) // Boxed primitive?
 						underlyingType = underlyingType.UnderlyingType;
 
-					if ((argType.IsBoxedStructPtr) || (argIdx == -1))
+					if (argType.IsBoxedStructPtr || argType.IsBoxedPrimitivePtr)
 					{
 						dataPtr = *(void**)dataPtr;
 						handled = true;
@@ -421,7 +459,11 @@ namespace System.Reflection
 						return .Err(.InvalidArgument((.)argIdx));
 				}
 
-				if (paramType.IsStruct)
+				if (added)
+				{
+					// Already handled
+				}
+				else if (paramType.IsStruct)
 				{
 					TypeInstance paramTypeInst = (TypeInstance)paramType;
 

+ 49 - 0
BeefLibs/corlib/src/Type.bf

@@ -254,6 +254,47 @@ namespace System
 		    }
 		}
 
+		public bool IsBoxedPrimitivePtr
+		{
+			get
+			{
+				if (!mTypeFlags.HasFlag(.Boxed))
+					return false;
+
+				let underyingType = UnderlyingType;
+				if (var genericTypeInstance = underyingType as SpecializedGenericType)
+				{
+					if (genericTypeInstance.UnspecializedType == typeof(Pointer<>))
+						return true;
+				}
+
+				return false;
+			}
+		}
+
+		public Type BoxedPtrType
+		{
+			get
+			{
+				if (!mTypeFlags.HasFlag(.Boxed))
+					return null;
+
+				if (mTypeFlags.HasFlag(.Pointer))
+				{
+					return UnderlyingType;
+				}
+
+				let underyingType = UnderlyingType;
+				if (var genericTypeInstance = underyingType as SpecializedGenericType)
+				{
+					if (genericTypeInstance.UnspecializedType == typeof(Pointer<>))
+						return genericTypeInstance.GetGenericArg(0);
+				}
+
+				return null;
+			}
+		}
+
 		public bool IsEnum
 		{
 		    get
@@ -877,6 +918,14 @@ namespace System.Reflection
         TypeId mUnspecializedType;
         TypeId* mResolvedTypeRefs;
 
+		public Type UnspecializedType
+		{
+			get
+			{
+				return Type.[Friend]GetType(mUnspecializedType);
+			}
+		}
+
 		public override int32 GenericParamCount
 		{
 			get

+ 103 - 3
IDEHelper/Tests/src/Reflection.bf

@@ -54,8 +54,11 @@ namespace Tests
 		class ClassA
 		{
 			[AlwaysInclude, AttrC(71, 72)]
-			static float StaticMethodA(int32 a, int32 b, float c)
+			static float StaticMethodA(int32 a, int32 b, float c, ref int32 d, ref StructA sa)
 			{
+				d += a + b;
+				sa.mA += a;
+				sa.mB += b;
 				return a + b + c;
 			}
 
@@ -76,6 +79,24 @@ namespace Tests
 			}
 		}
 
+		[Reflect, AlwaysInclude(IncludeAllMethods=true)]
+		struct StructA
+		{
+			public int mA;
+			public int mB;
+
+			int GetA(int a)
+			{
+				return a + mA * 100;
+			}
+
+			int GetB(int a) mut
+			{
+				mB += a;
+				return a + mA * 100;
+			}
+		}
+
 		class ClassA2 : ClassA
 		{
 			public override int GetA(int32 a)
@@ -136,12 +157,21 @@ namespace Tests
 				switch (methodIdx)
 				{
 				case 0:
+					StructA sa = .() { mA = 1, mB = 2 };
+
 					Test.Assert(methodInfo.Name == "StaticMethodA");
-					var result = methodInfo.Invoke(null, 100, (int32)20, 3.0f).Get();
+					int32 a = 0;
+					var result = methodInfo.Invoke(null, 100, (int32)20, 3.0f, &a, &sa).Get();
+					Test.Assert(a == 120);
+					Test.Assert(sa.mA == 101);
+					Test.Assert(sa.mB == 22);
 					Test.Assert(result.Get<float>() == 123);
 					result.Dispose();
 
-					result = methodInfo.Invoke(.(), .Create(100), .Create((int32)20), .Create(3.0f)).Get();
+					result = methodInfo.Invoke(.(), .Create(100), .Create((int32)20), .Create(3.0f), .Create(&a), .Create(&sa)).Get();
+					Test.Assert(a == 240);
+					Test.Assert(sa.mA == 201);
+					Test.Assert(sa.mB == 42);
 					Test.Assert(result.Get<float>() == 123);
 					result.Dispose();
 
@@ -225,5 +255,75 @@ namespace Tests
 			Test.Assert(attrC.mA == 1);
 			Test.Assert(attrC.mB == 2);
 		}
+
+		[Test]
+		static void TestD()
+		{
+			StructA sa = .() { mA = 12, mB = 23 };
+			var typeInfo = typeof(StructA);
+
+			int methodIdx = 0;
+			for (let methodInfo in typeInfo.GetMethods())
+			{
+				switch (methodIdx)
+				{
+				case 0:
+					Test.Assert(methodInfo.Name == "GetA");
+
+					var result = methodInfo.Invoke(sa, 34).Get();
+					Test.Assert(result.Get<int32>() == 1234);
+					result.Dispose();
+
+					result = methodInfo.Invoke(&sa, 34).Get();
+					Test.Assert(result.Get<int32>() == 1234);
+					result.Dispose();
+
+					Variant saV = .Create(sa);
+					defer saV.Dispose();
+					result = methodInfo.Invoke(saV, .Create(34));
+					Test.Assert(result.Get<int32>() == 1234);
+					result.Dispose();
+
+					result = methodInfo.Invoke(.Create(&sa), .Create(34));
+					Test.Assert(result.Get<int32>() == 1234);
+					result.Dispose();
+				case 1:
+					Test.Assert(methodInfo.Name == "GetB");
+
+					var result = methodInfo.Invoke(sa, 34).Get();
+					Test.Assert(result.Get<int32>() == 1234);
+					Test.Assert(sa.mB == 23);
+					result.Dispose();
+
+					result = methodInfo.Invoke(&sa, 34).Get();
+					Test.Assert(result.Get<int32>() == 1234);
+					Test.Assert(sa.mB == 57);
+					result.Dispose();
+
+					Variant saV = .Create(sa);
+					defer saV.Dispose();
+					result = methodInfo.Invoke(saV, .Create(34));
+					Test.Assert(result.Get<int32>() == 1234);
+					Test.Assert(sa.mB == 57);
+					result.Dispose();
+
+					result = methodInfo.Invoke(.Create(&sa), .Create(34));
+					Test.Assert(result.Get<int32>() == 1234);
+					Test.Assert(sa.mB == 91);
+					result.Dispose();
+
+				case 2:
+					Test.Assert(methodInfo.Name == "__BfCtor");
+				case 3:
+					Test.Assert(methodInfo.Name == "__Equals");
+				case 4:
+					Test.Assert(methodInfo.Name == "__StrictEquals");
+				default:
+					Test.FatalError(); // Shouldn't have any more
+				}
+
+				methodIdx++;
+			}
+		}
 	}
 }