Browse Source

Fixed StaticCast and DynamicCast (#1135)

- They don't need to return Ref or RefConst (saves incrementing a refcount in some cases)
- The Ref variants didn't take a const reference, so were most of the time unavailable
- StaticCast doesn't need to use RTTI to check if the types are compatible
Jorrit Rouwe 1 year ago
parent
commit
ff2217ed02

+ 8 - 12
Jolt/Core/RTTI.h

@@ -384,31 +384,27 @@ inline bool IsKindOf(const Ref<Type> &inObject, const RTTI *inRTTI)
 }
 
 /// Cast inObject to DstType, asserts on failure
-template <class DstType, class SrcType>
+template <class DstType, class SrcType, std::enable_if_t<std::is_base_of_v<DstType, SrcType> || std::is_base_of_v<SrcType, DstType>, bool> = true>
 inline const DstType *StaticCast(const SrcType *inObject)
 {
-	JPH_ASSERT(IsKindOf(inObject, JPH_RTTI(DstType)), "Invalid cast");
 	return static_cast<const DstType *>(inObject);
 }
 
-template <class DstType, class SrcType>
+template <class DstType, class SrcType, std::enable_if_t<std::is_base_of_v<DstType, SrcType> || std::is_base_of_v<SrcType, DstType>, bool> = true>
 inline DstType *StaticCast(SrcType *inObject)
 {
-	JPH_ASSERT(IsKindOf(inObject, JPH_RTTI(DstType)), "Invalid cast");
 	return static_cast<DstType *>(inObject);
 }
 
-template <class DstType, class SrcType>
-inline RefConst<DstType> StaticCast(RefConst<SrcType> &inObject)
+template <class DstType, class SrcType, std::enable_if_t<std::is_base_of_v<DstType, SrcType> || std::is_base_of_v<SrcType, DstType>, bool> = true>
+inline const DstType *StaticCast(const RefConst<SrcType> &inObject)
 {
-	JPH_ASSERT(IsKindOf(inObject, JPH_RTTI(DstType)), "Invalid cast");
 	return static_cast<const DstType *>(inObject.GetPtr());
 }
 
-template <class DstType, class SrcType>
-inline Ref<DstType> StaticCast(Ref<SrcType> &inObject)
+template <class DstType, class SrcType, std::enable_if_t<std::is_base_of_v<DstType, SrcType> || std::is_base_of_v<SrcType, DstType>, bool> = true>
+inline DstType *StaticCast(const Ref<SrcType> &inObject)
 {
-	JPH_ASSERT(IsKindOf(inObject, JPH_RTTI(DstType)), "Invalid cast");
 	return static_cast<DstType *>(inObject.GetPtr());
 }
 
@@ -426,13 +422,13 @@ inline DstType *DynamicCast(SrcType *inObject)
 }
 
 template <class DstType, class SrcType>
-inline RefConst<DstType> DynamicCast(RefConst<SrcType> &inObject)
+inline const DstType *DynamicCast(const RefConst<SrcType> &inObject)
 {
 	return inObject != nullptr? reinterpret_cast<const DstType *>(inObject->CastTo(JPH_RTTI(DstType))) : nullptr;
 }
 
 template <class DstType, class SrcType>
-inline Ref<DstType> DynamicCast(Ref<SrcType> &inObject)
+inline DstType *DynamicCast(const Ref<SrcType> &inObject)
 {
 	return inObject != nullptr? const_cast<DstType *>(reinterpret_cast<const DstType *>(inObject->CastTo(JPH_RTTI(DstType)))) : nullptr;
 }

+ 2 - 2
Jolt/Physics/Constraints/GearConstraint.cpp

@@ -106,7 +106,7 @@ bool GearConstraint::SolvePositionConstraint(float inDeltaTime, float inBaumgart
 	float gear1rot;
 	if (mGear1Constraint->GetSubType() == EConstraintSubType::Hinge)
 	{
-		gear1rot = static_cast<const HingeConstraint *>(mGear1Constraint.GetPtr())->GetCurrentAngle();
+		gear1rot = StaticCast<HingeConstraint>(mGear1Constraint)->GetCurrentAngle();
 	}
 	else
 	{
@@ -117,7 +117,7 @@ bool GearConstraint::SolvePositionConstraint(float inDeltaTime, float inBaumgart
 	float gear2rot;
 	if (mGear2Constraint->GetSubType() == EConstraintSubType::Hinge)
 	{
-		gear2rot = static_cast<const HingeConstraint *>(mGear2Constraint.GetPtr())->GetCurrentAngle();
+		gear2rot = StaticCast<HingeConstraint>(mGear2Constraint)->GetCurrentAngle();
 	}
 	else
 	{

+ 2 - 2
Jolt/Physics/Constraints/RackAndPinionConstraint.cpp

@@ -107,7 +107,7 @@ bool RackAndPinionConstraint::SolvePositionConstraint(float inDeltaTime, float i
 	float rotation;
 	if (mPinionConstraint->GetSubType() == EConstraintSubType::Hinge)
 	{
-		rotation = static_cast<const HingeConstraint *>(mPinionConstraint.GetPtr())->GetCurrentAngle();
+		rotation = StaticCast<HingeConstraint>(mPinionConstraint)->GetCurrentAngle();
 	}
 	else
 	{
@@ -118,7 +118,7 @@ bool RackAndPinionConstraint::SolvePositionConstraint(float inDeltaTime, float i
 	float translation;
 	if (mRackConstraint->GetSubType() == EConstraintSubType::Slider)
 	{
-		translation = static_cast<const SliderConstraint *>(mRackConstraint.GetPtr())->GetCurrentPosition();
+		translation = StaticCast<SliderConstraint>(mRackConstraint)->GetCurrentPosition();
 	}
 	else
 	{

+ 2 - 2
Jolt/Physics/PhysicsScene.cpp

@@ -181,7 +181,7 @@ PhysicsScene::PhysicsSceneResult PhysicsScene::sRestoreFromBinaryState(StreamIn
 			result.SetError(c_result.GetError());
 			return result;
 		}
-		cc.mSettings = static_cast<const TwoBodyConstraintSettings *>(c_result.Get().GetPtr());
+		cc.mSettings = StaticCast<TwoBodyConstraintSettings>(c_result.Get());
 		inStream.Read(cc.mBody1);
 		inStream.Read(cc.mBody2);
 	}
@@ -254,7 +254,7 @@ void PhysicsScene::FromPhysicsSystem(const PhysicsSystem *inSystem)
 
 			// Create constraint settings and add the constraint
 			Ref<ConstraintSettings> settings = c->GetConstraintSettings();
-			AddConstraint(static_cast<const TwoBodyConstraintSettings *>(settings.GetPtr()), b1->second, b2->second);
+			AddConstraint(StaticCast<TwoBodyConstraintSettings>(settings), b1->second, b2->second);
 		}
 }
 

+ 2 - 2
Jolt/Physics/Ragdoll/Ragdoll.cpp

@@ -341,7 +341,7 @@ RagdollSettings::RagdollResult RagdollSettings::sRestoreFromBinaryState(StreamIn
 				result.SetError(constraint_result.GetError());
 				return result;
 			}
-			p.mToParent = DynamicCast<TwoBodyConstraintSettings>(constraint_result.Get().GetPtr());
+			p.mToParent = DynamicCast<TwoBodyConstraintSettings>(constraint_result.Get());
 		}
 	}
 
@@ -361,7 +361,7 @@ RagdollSettings::RagdollResult RagdollSettings::sRestoreFromBinaryState(StreamIn
 			result.SetError(constraint_result.GetError());
 			return result;
 		}
-		c.mConstraint = DynamicCast<TwoBodyConstraintSettings>(constraint_result.Get().GetPtr());
+		c.mConstraint = DynamicCast<TwoBodyConstraintSettings>(constraint_result.Get());
 	}
 
 	// Create mapping tables

+ 1 - 1
Jolt/Physics/Vehicle/TrackedVehicleController.h

@@ -38,7 +38,7 @@ public:
 	explicit					WheelTV(const WheelSettingsTV &inWheel);
 
 	/// Override GetSettings and cast to the correct class
-	const WheelSettingsTV *		GetSettings() const							{ return static_cast<const WheelSettingsTV *>(mSettings.GetPtr()); }
+	const WheelSettingsTV *		GetSettings() const							{ return StaticCast<WheelSettingsTV>(mSettings); }
 
 	/// Update the angular velocity of the wheel based on the angular velocity of the track
 	void						CalculateAngularVelocity(const VehicleConstraint &inConstraint);

+ 1 - 1
Jolt/Physics/Vehicle/WheeledVehicleController.h

@@ -47,7 +47,7 @@ public:
 	explicit					WheelWV(const WheelSettingsWV &inWheel);
 
 	/// Override GetSettings and cast to the correct class
-	const WheelSettingsWV *		GetSettings() const							{ return static_cast<const WheelSettingsWV *>(mSettings.GetPtr()); }
+	const WheelSettingsWV *		GetSettings() const							{ return StaticCast<WheelSettingsWV>(mSettings); }
 
 	/// Apply a torque (N m) to the wheel for a particular delta time
 	void						ApplyTorque(float inTorque, float inDeltaTime)

+ 1 - 1
Samples/Tests/Constraints/ConstraintPriorityTest.cpp

@@ -42,7 +42,7 @@ void ConstraintPriorityTest::Initialize()
 			settings.mConstraintPriority = priority == 0? i : num_bodies - i; // Priority is reversed for one chain compared to the other
 			Ref<Constraint> c = settings.Create(*prev, segment);
 			mPhysicsSystem->AddConstraint(c);
-			mConstraints.push_back(static_cast<FixedConstraint *>(c.GetPtr()));
+			mConstraints.push_back(StaticCast<FixedConstraint>(c));
 
 			prev = &segment;
 		}

+ 1 - 1
Samples/Tests/ConvexCollision/ConvexHullShrinkTest.cpp

@@ -132,7 +132,7 @@ void ConvexHullShrinkTest::PrePhysicsUpdate(const PreUpdateParams &inParams)
 		Trace("%d: %s", mIteration - 1, result.GetError().c_str());
 		return;
 	}
-	RefConst<ConvexHullShape> shape = static_cast<const ConvexHullShape *>(result.Get().GetPtr());
+	RefConst<ConvexHullShape> shape = StaticCast<ConvexHullShape>(result.Get());
 
 	// Shape creation may have reduced the convex radius, fetch the result
 	const float convex_radius = shape->GetConvexRadius();

+ 1 - 1
Samples/Tests/Shapes/DeformedHeightFieldShapeTest.cpp

@@ -38,7 +38,7 @@ void DeformedHeightFieldShapeTest::Initialize()
 	settings.mBlockSize = cBlockSize;
 	settings.mBitsPerSample = 8;
 	settings.mMinHeightValue = -15.0f;
-	mHeightField = static_cast<HeightFieldShape *>(settings.Create().Get().GetPtr());
+	mHeightField = StaticCast<HeightFieldShape>(settings.Create().Get());
 	mHeightFieldID = mBodyInterface->CreateAndAddBody(BodyCreationSettings(mHeightField, RVec3::sZero(), Quat::sIdentity(), EMotionType::Static, Layers::NON_MOVING), EActivation::DontActivate);
 
 	// Spheres on top of the terrain

+ 1 - 1
Samples/Tests/Shapes/HeightFieldShapeTest.cpp

@@ -137,7 +137,7 @@ void HeightFieldShapeTest::Initialize()
 	HeightFieldShapeSettings settings(mTerrain.data(), mTerrainOffset, mTerrainScale, mTerrainSize, mMaterialIndices.data(), mMaterials);
 	settings.mBlockSize = 1 << sBlockSizeShift;
 	settings.mBitsPerSample = sBitsPerSample;
-	mHeightField = static_cast<const HeightFieldShape *>(settings.Create().Get().GetPtr());
+	mHeightField = StaticCast<HeightFieldShape>(settings.Create().Get());
 	Body &terrain = *mBodyInterface->CreateBody(BodyCreationSettings(mHeightField, RVec3::sZero(), Quat::sIdentity(), EMotionType::Static, Layers::NON_MOVING));
 	mBodyInterface->AddBody(terrain.GetID(), EActivation::DontActivate);
 

+ 1 - 1
Samples/Tests/Shapes/MutableCompoundShapeTest.cpp

@@ -160,7 +160,7 @@ void MutableCompoundShapeTest::RestoreState(StateRecorder &inStream)
 			stringstream data(str);
 			StreamInWrapper stream_in(data);
 			Shape::ShapeResult result = Shape::sRestoreFromBinaryState(stream_in);
-			MutableCompoundShape *shape = static_cast<MutableCompoundShape *>(result.Get().GetPtr());
+			MutableCompoundShape *shape = StaticCast<MutableCompoundShape>(result.Get());
 
 			// Restore the pointers to the sub compound
 			ShapeList sub_shapes(shape->GetNumSubShapes(), mSubCompound);

+ 1 - 1
UnitTests/Physics/CollideShapeTests.cpp

@@ -296,7 +296,7 @@ TEST_SUITE("CollideShapeTests")
 			Vec3(-132.543304f, 164.551971f, 617.646362f)
 		};
 		ConvexHullShapeSettings hull_settings(obox_points, 0.0f);
-		RefConst<ConvexShape> convex_hull = static_cast<const ConvexShape *>(hull_settings.Create().Get().GetPtr());
+		RefConst<ConvexShape> convex_hull = StaticCast<ConvexShape>(hull_settings.Create().Get());
 
 		// Create triangle support function
 		TriangleConvexSupport triangle(v0, v1, v2);

+ 4 - 4
UnitTests/Physics/HeightFieldShapeTests.cpp

@@ -32,7 +32,7 @@ TEST_SUITE("HeightFieldShapeTests")
 	static Ref<HeightFieldShape> sValidateGetPosition(const HeightFieldShapeSettings &inSettings, float inMaxError)
 	{
 		// Create shape
-		Ref<HeightFieldShape> shape = static_cast<HeightFieldShape *>(inSettings.Create().Get().GetPtr());
+		Ref<HeightFieldShape> shape = StaticCast<HeightFieldShape>(inSettings.Create().Get());
 
 		// Validate it
 		float max_diff = -1.0f;
@@ -223,7 +223,7 @@ TEST_SUITE("HeightFieldShapeTests")
 
 		// Create shape
 		ShapeRefC shape = settings.Create().Get();
-		const HeightFieldShape *height_field = static_cast<const HeightFieldShape *>(shape.GetPtr());
+		const HeightFieldShape *height_field = StaticCast<HeightFieldShape>(shape);
 
 		{
 			// Check that the GetHeights function returns the same values as the original height samples
@@ -276,7 +276,7 @@ TEST_SUITE("HeightFieldShapeTests")
 
 		// Create shape
 		Ref<Shape> shape = settings.Create().Get();
-		HeightFieldShape *height_field = static_cast<HeightFieldShape *>(shape.GetPtr());
+		HeightFieldShape *height_field = StaticCast<HeightFieldShape>(shape);
 
 		// Get the original (quantized) heights
 		Array<float> original_heights;
@@ -353,7 +353,7 @@ TEST_SUITE("HeightFieldShapeTests")
 
 		// Create shape
 		Ref<Shape> shape = settings.Create().Get();
-		HeightFieldShape *height_field = static_cast<HeightFieldShape *>(shape.GetPtr());
+		HeightFieldShape *height_field = StaticCast<HeightFieldShape>(shape);
 
 		// Check that the material is set
 		auto check_materials = [height_field, &current_state]() {

+ 1 - 1
UnitTests/Physics/PhysicsTests.cpp

@@ -1786,7 +1786,7 @@ TEST_SUITE("PhysicsTests")
 		shape_settings->AddShape(Vec3(-5, 0, 0), Quat::sIdentity(), box_shape);
 		shape_settings->AddShape(Vec3(0, 0, 5), Quat::sIdentity(), box_shape);
 		shape_settings->AddShape(Vec3(0, 0, -5), Quat::sIdentity(), box_shape);
-		RefConst<StaticCompoundShape> compound_shape = static_cast<const StaticCompoundShape *>(shape_settings->Create().Get().GetPtr());
+		RefConst<StaticCompoundShape> compound_shape = StaticCast<StaticCompoundShape>(shape_settings->Create().Get());
 		SubShapeID sub_shape_ids[] = {
 			compound_shape->GetSubShapeIDFromIndex(0, SubShapeIDCreator()).GetID(),
 			compound_shape->GetSubShapeIDFromIndex(1, SubShapeIDCreator()).GetID(),

+ 7 - 7
UnitTests/Physics/ShapeTests.cpp

@@ -362,16 +362,16 @@ TEST_SUITE("ShapeTests")
 	{
 		// Create a sphere and check radius
 		SphereShapeSettings sphere_settings(1.0f);
-		RefConst<SphereShape> sphere1 = static_cast<const SphereShape *>(sphere_settings.Create().Get().GetPtr());
+		RefConst<SphereShape> sphere1 = StaticCast<SphereShape>(sphere_settings.Create().Get());
 		CHECK(sphere1->GetRadius() == 1.0f);
 
 		// Modify radius and check that creating the shape again returns the cached result
 		sphere_settings.mRadius = 2.0f;
-		RefConst<SphereShape> sphere2 = static_cast<const SphereShape *>(sphere_settings.Create().Get().GetPtr());
+		RefConst<SphereShape> sphere2 = StaticCast<SphereShape>(sphere_settings.Create().Get());
 		CHECK(sphere2 == sphere1);
 
 		sphere_settings.ClearCachedResult();
-		RefConst<SphereShape> sphere3 = static_cast<const SphereShape *>(sphere_settings.Create().Get().GetPtr());
+		RefConst<SphereShape> sphere3 = StaticCast<SphereShape>(sphere_settings.Create().Get());
 		CHECK(sphere3->GetRadius() == 2.0f);
 	}
 
@@ -599,20 +599,20 @@ TEST_SUITE("ShapeTests")
 		CHECK(sphere->GetType() == EShapeType::Convex);
 		CHECK(sphere->GetSubType() == EShapeSubType::Sphere);
 		CHECK(sphere->GetUserData() == 0x5678123443218765);
-		CHECK(static_cast<SphereShape *>(sphere.GetPtr())->GetRadius() == cRadius);
+		CHECK(StaticCast<SphereShape>(sphere)->GetRadius() == cRadius);
 	}
 
 	// Test setting user data on shapes
 	TEST_CASE("TestIsValidSubShapeID")
 	{
 		MutableCompoundShapeSettings shape1_settings;
-		RefConst<CompoundShape> shape1 = static_cast<const CompoundShape *>(shape1_settings.Create().Get().GetPtr());
+		RefConst<CompoundShape> shape1 = StaticCast<CompoundShape>(shape1_settings.Create().Get());
 
 		MutableCompoundShapeSettings shape2_settings;
 		shape2_settings.AddShape(Vec3::sZero(), Quat::sIdentity(), new SphereShape(1.0f));
 		shape2_settings.AddShape(Vec3::sZero(), Quat::sIdentity(), new SphereShape(1.0f));
 		shape2_settings.AddShape(Vec3::sZero(), Quat::sIdentity(), new SphereShape(1.0f));
-		RefConst<CompoundShape> shape2 = static_cast<const CompoundShape *>(shape2_settings.Create().Get().GetPtr());
+		RefConst<CompoundShape> shape2 = StaticCast<CompoundShape>(shape2_settings.Create().Get());
 
 		// Get sub shape IDs of shape 2 and test if they're valid
 		SubShapeID sub_shape1 = shape2->GetSubShapeIDFromIndex(0, SubShapeIDCreator()).GetID();
@@ -731,7 +731,7 @@ TEST_SUITE("ShapeTests")
 			StreamInWrapper iwrapper(stream);
 			Shape::ShapeResult result = Shape::sRestoreFromBinaryState(iwrapper);
 			CHECK(result.IsValid());
-			RefConst<MeshShape> mesh_shape = static_cast<const MeshShape *>(result.Get().GetPtr());
+			RefConst<MeshShape> mesh_shape = StaticCast<MeshShape>(result.Get());
 
 			// Test if it contains the same amount of triangles
 			Shape::Stats stats = mesh_shape->GetStats();