Browse Source

Ability to save and restore the simulation in parts (#1282)

* Added StateRecorder::SetIsLastPart which allows you to specify that you are restoring the state from multiple streams, only the last part should have this set to true.
* Added StateRecorderFilter::ShouldRestoreContact function which allows you to skip restoring certain constraints. This can be used when restoring a partial snapshot onto a full snapshot by selectively ignoring contacts from one of the two snapshots.

Fixes #1254
Jorrit Rouwe 10 months ago
parent
commit
5a3935ee69

+ 107 - 72
Jolt/Physics/Constraints/ContactConstraintManager.cpp

@@ -470,7 +470,7 @@ void ContactConstraintManager::ManifoldCache::SaveState(StateRecorder &inStream,
 		inStream.Write(m_kv->GetKey());
 }
 
-bool ContactConstraintManager::ManifoldCache::RestoreState(const ManifoldCache &inReadCache, StateRecorder &inStream)
+bool ContactConstraintManager::ManifoldCache::RestoreState(const ManifoldCache &inReadCache, StateRecorder &inStream, const StateRecorderFilter *inFilter)
 {
 	JPH_ASSERT(!mIsFinalized);
 
@@ -499,72 +499,95 @@ bool ContactConstraintManager::ManifoldCache::RestoreState(const ManifoldCache &
 			body_pair_key = all_bp[i]->GetKey();
 		inStream.Read(body_pair_key);
 
-		// Create new entry for this body pair
-		uint64 body_pair_hash = body_pair_key.GetHash();
-		BPKeyValue *bp_kv = Create(contact_allocator, body_pair_key, body_pair_hash);
-		if (bp_kv == nullptr)
+		// Check if we want to restore this contact
+		if (inFilter == nullptr || inFilter->ShouldRestoreContact(body_pair_key.mBodyA, body_pair_key.mBodyB))
 		{
-			// Out of cache space
-			success = false;
-			break;
-		}
-		CachedBodyPair &bp = bp_kv->GetValue();
-
-		// Read body pair
-		if (inStream.IsValidating() && i < all_bp.size())
-			memcpy(&bp, &all_bp[i]->GetValue(), sizeof(CachedBodyPair));
-		bp.RestoreState(inStream);
-
-		// When validating, get all existing manifolds
-		Array<const MKeyValue *> all_m;
-		if (inStream.IsValidating())
-			inReadCache.GetAllManifoldsSorted(all_bp[i]->GetValue(), all_m);
-
-		// Read amount of manifolds
-		uint32 num_manifolds;
-		if (inStream.IsValidating())
-			num_manifolds = uint32(all_m.size());
-		inStream.Read(num_manifolds);
-
-		uint32 handle = ManifoldMap::cInvalidHandle;
-		for (uint32 j = 0; j < num_manifolds; ++j)
-		{
-			// Read key
-			SubShapeIDPair sub_shape_key;
-			if (inStream.IsValidating() && j < all_m.size())
-				sub_shape_key = all_m[j]->GetKey();
-			inStream.Read(sub_shape_key);
-			uint64 sub_shape_key_hash = sub_shape_key.GetHash();
-
-			// Read amount of contact points
-			uint16 num_contact_points;
-			if (inStream.IsValidating() && j < all_m.size())
-				num_contact_points = all_m[j]->GetValue().mNumContactPoints;
-			inStream.Read(num_contact_points);
-
-			// Read manifold
-			MKeyValue *m_kv = Create(contact_allocator, sub_shape_key, sub_shape_key_hash, num_contact_points);
-			if (m_kv == nullptr)
+			// Create new entry for this body pair
+			uint64 body_pair_hash = body_pair_key.GetHash();
+			BPKeyValue *bp_kv = Create(contact_allocator, body_pair_key, body_pair_hash);
+			if (bp_kv == nullptr)
 			{
 				// Out of cache space
 				success = false;
 				break;
 			}
-			CachedManifold &cm = m_kv->GetValue();
-			if (inStream.IsValidating() && j < all_m.size())
+			CachedBodyPair &bp = bp_kv->GetValue();
+
+			// Read body pair
+			if (inStream.IsValidating() && i < all_bp.size())
+				memcpy(&bp, &all_bp[i]->GetValue(), sizeof(CachedBodyPair));
+			bp.RestoreState(inStream);
+
+			// When validating, get all existing manifolds
+			Array<const MKeyValue *> all_m;
+			if (inStream.IsValidating())
+				inReadCache.GetAllManifoldsSorted(all_bp[i]->GetValue(), all_m);
+
+			// Read amount of manifolds
+			uint32 num_manifolds = 0;
+			if (inStream.IsValidating())
+				num_manifolds = uint32(all_m.size());
+			inStream.Read(num_manifolds);
+
+			uint32 handle = ManifoldMap::cInvalidHandle;
+			for (uint32 j = 0; j < num_manifolds; ++j)
 			{
-				memcpy(&cm, &all_m[j]->GetValue(), CachedManifold::sGetRequiredTotalSize(num_contact_points));
-				cm.mNumContactPoints = uint16(num_contact_points); // Restore num contact points
-			}
-			cm.RestoreState(inStream);
-			cm.mNextWithSameBodyPair = handle;
-			handle = ToHandle(m_kv);
+				// Read key
+				SubShapeIDPair sub_shape_key;
+				if (inStream.IsValidating() && j < all_m.size())
+					sub_shape_key = all_m[j]->GetKey();
+				inStream.Read(sub_shape_key);
+				uint64 sub_shape_key_hash = sub_shape_key.GetHash();
+
+				// Read amount of contact points
+				uint16 num_contact_points = 0;
+				if (inStream.IsValidating() && j < all_m.size())
+					num_contact_points = all_m[j]->GetValue().mNumContactPoints;
+				inStream.Read(num_contact_points);
+
+				// Read manifold
+				MKeyValue *m_kv = Create(contact_allocator, sub_shape_key, sub_shape_key_hash, num_contact_points);
+				if (m_kv == nullptr)
+				{
+					// Out of cache space
+					success = false;
+					break;
+				}
+				CachedManifold &cm = m_kv->GetValue();
+				if (inStream.IsValidating() && j < all_m.size())
+				{
+					memcpy(&cm, &all_m[j]->GetValue(), CachedManifold::sGetRequiredTotalSize(num_contact_points));
+					cm.mNumContactPoints = uint16(num_contact_points); // Restore num contact points
+				}
+				cm.RestoreState(inStream);
+				cm.mNextWithSameBodyPair = handle;
+				handle = ToHandle(m_kv);
 
-			// Read contact points
-			for (uint32 k = 0; k < num_contact_points; ++k)
-				cm.mContactPoints[k].RestoreState(inStream);
+				// Read contact points
+				for (uint32 k = 0; k < num_contact_points; ++k)
+					cm.mContactPoints[k].RestoreState(inStream);
+			}
+			bp.mFirstCachedManifold = handle;
+		}
+		else
+		{
+			// Skip the contact
+			CachedBodyPair bp;
+			bp.RestoreState(inStream);
+			uint32 num_manifolds = 0;
+			inStream.Read(num_manifolds);
+			for (uint32 j = 0; j < num_manifolds; ++j)
+			{
+				SubShapeIDPair sub_shape_key;
+				inStream.Read(sub_shape_key);
+				uint16 num_contact_points;
+				inStream.Read(num_contact_points);
+				CachedManifold cm;
+				cm.RestoreState(inStream);
+				for (uint32 k = 0; k < num_contact_points; ++k)
+					cm.mContactPoints[0].RestoreState(inStream);
+			}
 		}
-		bp.mFirstCachedManifold = handle;
 	}
 
 	// When validating, get all existing CCD manifolds
@@ -585,22 +608,28 @@ bool ContactConstraintManager::ManifoldCache::RestoreState(const ManifoldCache &
 		if (inStream.IsValidating() && j < all_m.size())
 			sub_shape_key = all_m[j]->GetKey();
 		inStream.Read(sub_shape_key);
-		uint64 sub_shape_key_hash = sub_shape_key.GetHash();
 
-		// Create CCD manifold
-		MKeyValue *m_kv = Create(contact_allocator, sub_shape_key, sub_shape_key_hash, 0);
-		if (m_kv == nullptr)
+		// Check if we want to restore this contact
+		if (inFilter == nullptr || inFilter->ShouldRestoreContact(sub_shape_key.GetBody1ID(), sub_shape_key.GetBody2ID()))
 		{
-			// Out of cache space
-			success = false;
-			break;
+			// Create CCD manifold
+			uint64 sub_shape_key_hash = sub_shape_key.GetHash();
+			MKeyValue *m_kv = Create(contact_allocator, sub_shape_key, sub_shape_key_hash, 0);
+			if (m_kv == nullptr)
+			{
+				// Out of cache space
+				success = false;
+				break;
+			}
+			CachedManifold &cm = m_kv->GetValue();
+			cm.mFlags |= (uint16)CachedManifold::EFlags::CCDContact;
 		}
-		CachedManifold &cm = m_kv->GetValue();
-		cm.mFlags |= (uint16)CachedManifold::EFlags::CCDContact;
 	}
 
 #ifdef JPH_ENABLE_ASSERTS
-	mIsFinalized = true;
+	// We don't finalize until the last part is restored
+	if (inStream.IsLastPart())
+		mIsFinalized = true;
 #endif
 
 	return success;
@@ -1707,11 +1736,17 @@ void ContactConstraintManager::SaveState(StateRecorder &inStream, const StateRec
 	mCache[mCacheWriteIdx ^ 1].SaveState(inStream, inFilter);
 }
 
-bool ContactConstraintManager::RestoreState(StateRecorder &inStream)
+bool ContactConstraintManager::RestoreState(StateRecorder &inStream, const StateRecorderFilter *inFilter)
 {
-	bool success = mCache[mCacheWriteIdx].RestoreState(mCache[mCacheWriteIdx ^ 1], inStream);
-	mCacheWriteIdx ^= 1;
-	mCache[mCacheWriteIdx].Clear();
+	bool success = mCache[mCacheWriteIdx].RestoreState(mCache[mCacheWriteIdx ^ 1], inStream, inFilter);
+
+	// If this is the last part, the cache is finalized
+	if (inStream.IsLastPart())
+	{
+		mCacheWriteIdx ^= 1;
+		mCache[mCacheWriteIdx].Clear();
+	}
+
 	return success;
 }
 

+ 2 - 2
Jolt/Physics/Constraints/ContactConstraintManager.h

@@ -251,7 +251,7 @@ public:
 	void						SaveState(StateRecorder &inStream, const StateRecorderFilter *inFilter) const;
 
 	/// Restoring state for replay. Returns false when failed.
-	bool						RestoreState(StateRecorder &inStream);
+	bool						RestoreState(StateRecorder &inStream, const StateRecorderFilter *inFilter);
 
 private:
 	/// Local space contact point, used for caching impulses
@@ -393,7 +393,7 @@ private:
 
 		/// Saving / restoring state for replay
 		void					SaveState(StateRecorder &inStream, const StateRecorderFilter *inFilter) const;
-		bool					RestoreState(const ManifoldCache &inReadCache, StateRecorder &inStream);
+		bool					RestoreState(const ManifoldCache &inReadCache, StateRecorder &inStream, const StateRecorderFilter *inFilter);
 
 	private:
 		/// Block size used when allocating new blocks in the contact cache

+ 11 - 8
Jolt/Physics/PhysicsSystem.cpp

@@ -2657,7 +2657,7 @@ void PhysicsSystem::SaveState(StateRecorder &inStream, EStateRecorderState inSta
 		mConstraintManager.SaveState(inStream, inFilter);
 }
 
-bool PhysicsSystem::RestoreState(StateRecorder &inStream)
+bool PhysicsSystem::RestoreState(StateRecorder &inStream, const StateRecorderFilter *inFilter)
 {
 	JPH_PROFILE_FUNCTION();
 
@@ -2676,17 +2676,20 @@ bool PhysicsSystem::RestoreState(StateRecorder &inStream)
 			return false;
 
 		// Update bounding boxes for all bodies in the broadphase
-		Array<BodyID> bodies;
-		for (const Body *b : mBodyManager.GetBodies())
-			if (BodyManager::sIsValidBodyPointer(b) && b->IsInBroadPhase())
-				bodies.push_back(b->GetID());
-		if (!bodies.empty())
-			mBroadPhase->NotifyBodiesAABBChanged(&bodies[0], (int)bodies.size());
+		if (inStream.IsLastPart())
+		{
+			Array<BodyID> bodies;
+			for (const Body *b : mBodyManager.GetBodies())
+				if (BodyManager::sIsValidBodyPointer(b) && b->IsInBroadPhase())
+					bodies.push_back(b->GetID());
+			if (!bodies.empty())
+				mBroadPhase->NotifyBodiesAABBChanged(&bodies[0], (int)bodies.size());
+		}
 	}
 
 	if (uint8(state) & uint8(EStateRecorderState::Contacts))
 	{
-		if (!mContactManager.RestoreState(inStream))
+		if (!mContactManager.RestoreState(inStream, inFilter))
 			return false;
 	}
 

+ 1 - 1
Jolt/Physics/PhysicsSystem.h

@@ -123,7 +123,7 @@ public:
 	void						SaveState(StateRecorder &inStream, EStateRecorderState inState = EStateRecorderState::All, const StateRecorderFilter *inFilter = nullptr) const;
 
 	/// Restoring state for replay. Returns false if failed.
-	bool						RestoreState(StateRecorder &inStream);
+	bool						RestoreState(StateRecorder &inStream, const StateRecorderFilter *inFilter = nullptr);
 
 	/// Saving state of a single body.
 	void						SaveBodyState(const Body &inBody, StateRecorder &inStream) const;

+ 65 - 0
Jolt/Physics/StateRecorder.h

@@ -24,6 +24,51 @@ enum class EStateRecorderState : uint8
 	All					= Global | Bodies | Contacts | Constraints					///< Save all state
 };
 
+/// Bitwise OR operator for EStateRecorderState
+constexpr EStateRecorderState operator | (EStateRecorderState inLHS, EStateRecorderState inRHS)
+{
+	return EStateRecorderState(uint8(inLHS) | uint8(inRHS));
+}
+
+/// Bitwise AND operator for EStateRecorderState
+constexpr EStateRecorderState operator & (EStateRecorderState inLHS, EStateRecorderState inRHS)
+{
+	return EStateRecorderState(uint8(inLHS) & uint8(inRHS));
+}
+
+/// Bitwise XOR operator for EStateRecorderState
+constexpr EStateRecorderState operator ^ (EStateRecorderState inLHS, EStateRecorderState inRHS)
+{
+	return EStateRecorderState(uint8(inLHS) ^ uint8(inRHS));
+}
+
+/// Bitwise NOT operator for EStateRecorderState
+constexpr EStateRecorderState operator ~ (EStateRecorderState inAllowedDOFs)
+{
+	return EStateRecorderState(~uint8(inAllowedDOFs));
+}
+
+/// Bitwise OR assignment operator for EStateRecorderState
+constexpr EStateRecorderState & operator |= (EStateRecorderState &ioLHS, EStateRecorderState inRHS)
+{
+	ioLHS = ioLHS | inRHS;
+	return ioLHS;
+}
+
+/// Bitwise AND assignment operator for EStateRecorderState
+constexpr EStateRecorderState & operator &= (EStateRecorderState &ioLHS, EStateRecorderState inRHS)
+{
+	ioLHS = ioLHS & inRHS;
+	return ioLHS;
+}
+
+/// Bitwise XOR assignment operator for EStateRecorderState
+constexpr EStateRecorderState & operator ^= (EStateRecorderState &ioLHS, EStateRecorderState inRHS)
+{
+	ioLHS = ioLHS ^ inRHS;
+	return ioLHS;
+}
+
 /// User callbacks that allow determining which parts of the simulation should be saved by a StateRecorder
 class JPH_EXPORT StateRecorderFilter
 {
@@ -31,6 +76,9 @@ public:
 	/// Destructor
 	virtual				~StateRecorderFilter() = default;
 
+	///@name Functions called during SaveState
+	///@{
+
 	/// If the state of a specific body should be saved
 	virtual bool		ShouldSaveBody([[maybe_unused]] const Body &inBody) const					{ return true; }
 
@@ -39,6 +87,15 @@ public:
 
 	/// If the state of a specific contact should be saved
 	virtual bool		ShouldSaveContact([[maybe_unused]] const BodyID &inBody1, [[maybe_unused]] const BodyID &inBody2) const { return true; }
+
+	///@}
+	///@name Functions called during RestoreState
+	///@{
+
+	/// If the state of a specific contact should be restored
+	virtual bool		ShouldRestoreContact([[maybe_unused]] const BodyID &inBody1, [[maybe_unused]] const BodyID &inBody2) const { return true; }
+
+	///@}
 };
 
 /// Class that records the state of a physics system. Can be used to check if the simulation is deterministic by putting the recorder in validation mode.
@@ -59,8 +116,16 @@ public:
 	void				SetValidating(bool inValidating)							{ mIsValidating = inValidating; }
 	bool				IsValidating() const										{ return mIsValidating; }
 
+	/// This allows splitting the state in multiple parts. While restoring, only the last part should have this flag set to true.
+	/// Note that you should ensure that the different parts contain information for disjoint sets of bodies, constraints and contacts.
+	/// E.g. if you restore the same contact twice, you get undefined behavior. In order to create disjoint sets you can use the StateRecorderFilter.
+	/// Note that validation is not compatible with restoring a simulation state in multiple parts.
+	void				SetIsLastPart(bool inIsLastPart)							{ mIsLastPart = inIsLastPart; }
+	bool				IsLastPart() const											{ return mIsLastPart; }
+
 private:
 	bool				mIsValidating = false;
+	bool				mIsLastPart = true;
 };
 
 JPH_NAMESPACE_END

+ 3 - 0
Jolt/Physics/StateRecorderImpl.h

@@ -40,6 +40,9 @@ public:
 	/// Convert the binary data to a string
 	string				GetData() const												{ return mStream.str(); }
 
+	/// Get size of the binary data in bytes
+	size_t				GetDataSize()												{ return size_t(mStream.tellp()); }
+
 private:
 	std::stringstream	mStream;
 };

+ 19 - 11
UnitTests/Layers.h

@@ -17,10 +17,11 @@ namespace Layers
 	static constexpr ObjectLayer UNUSED5 = 4;
 	static constexpr ObjectLayer NON_MOVING = 5;
 	static constexpr ObjectLayer MOVING = 6;
-	static constexpr ObjectLayer HQ_DEBRIS = 7; // High quality debris collides with MOVING and NON_MOVING but not with any debris
-	static constexpr ObjectLayer LQ_DEBRIS = 8; // Low quality debris only collides with NON_MOVING
-	static constexpr ObjectLayer SENSOR = 9; // Sensors only collide with MOVING objects
-	static constexpr ObjectLayer NUM_LAYERS = 10;
+	static constexpr ObjectLayer MOVING2 = 7; // Another moving layer that acts as MOVING but doesn't collide with MOVING
+	static constexpr ObjectLayer HQ_DEBRIS = 8; // High quality debris collides with MOVING and NON_MOVING but not with any debris
+	static constexpr ObjectLayer LQ_DEBRIS = 9; // Low quality debris only collides with NON_MOVING
+	static constexpr ObjectLayer SENSOR = 10; // Sensors only collide with MOVING objects
+	static constexpr ObjectLayer NUM_LAYERS = 11;
 };
 
 /// Class that determines if two object layers can collide
@@ -38,15 +39,17 @@ public:
 		case Layers::UNUSED5:
 			return false;
 		case Layers::NON_MOVING:
-			return inObject2 == Layers::MOVING || inObject2 == Layers::HQ_DEBRIS || inObject2 == Layers::LQ_DEBRIS;
+			return inObject2 == Layers::MOVING || inObject2 == Layers::MOVING2 || inObject2 == Layers::HQ_DEBRIS || inObject2 == Layers::LQ_DEBRIS;
 		case Layers::MOVING:
 			return inObject2 == Layers::NON_MOVING || inObject2 == Layers::MOVING || inObject2 == Layers::HQ_DEBRIS || inObject2 == Layers::SENSOR;
+		case Layers::MOVING2:
+			return inObject2 == Layers::NON_MOVING || inObject2 == Layers::MOVING2 || inObject2 == Layers::HQ_DEBRIS || inObject2 == Layers::SENSOR;
 		case Layers::HQ_DEBRIS:
-			return inObject2 == Layers::NON_MOVING || inObject2 == Layers::MOVING;
+			return inObject2 == Layers::NON_MOVING || inObject2 == Layers::MOVING || inObject2 == Layers::MOVING2;
 		case Layers::LQ_DEBRIS:
 			return inObject2 == Layers::NON_MOVING;
 		case Layers::SENSOR:
-			return inObject2 == Layers::MOVING;
+			return inObject2 == Layers::MOVING || inObject2 == Layers::MOVING2;
 		default:
 			JPH_ASSERT(false);
 			return false;
@@ -59,10 +62,11 @@ namespace BroadPhaseLayers
 {
 	static constexpr BroadPhaseLayer NON_MOVING(0);
 	static constexpr BroadPhaseLayer MOVING(1);
-	static constexpr BroadPhaseLayer LQ_DEBRIS(2);
-	static constexpr BroadPhaseLayer UNUSED(3);
-	static constexpr BroadPhaseLayer SENSOR(4);
-	static constexpr uint NUM_LAYERS(5);
+	static constexpr BroadPhaseLayer MOVING2(2);
+	static constexpr BroadPhaseLayer LQ_DEBRIS(3);
+	static constexpr BroadPhaseLayer UNUSED(4);
+	static constexpr BroadPhaseLayer SENSOR(5);
+	static constexpr uint NUM_LAYERS(6);
 };
 
 /// BroadPhaseLayerInterface implementation
@@ -79,6 +83,7 @@ public:
 		mObjectToBroadPhase[Layers::UNUSED5] = BroadPhaseLayers::UNUSED;
 		mObjectToBroadPhase[Layers::NON_MOVING] = BroadPhaseLayers::NON_MOVING;
 		mObjectToBroadPhase[Layers::MOVING] = BroadPhaseLayers::MOVING;
+		mObjectToBroadPhase[Layers::MOVING2] = BroadPhaseLayers::MOVING2;
 		mObjectToBroadPhase[Layers::HQ_DEBRIS] = BroadPhaseLayers::MOVING; // HQ_DEBRIS is also in the MOVING layer as an example on how to map multiple layers onto the same broadphase layer
 		mObjectToBroadPhase[Layers::LQ_DEBRIS] = BroadPhaseLayers::LQ_DEBRIS;
 		mObjectToBroadPhase[Layers::SENSOR] = BroadPhaseLayers::SENSOR;
@@ -102,6 +107,7 @@ public:
 		{
 		case (BroadPhaseLayer::Type)BroadPhaseLayers::NON_MOVING:	return "NON_MOVING";
 		case (BroadPhaseLayer::Type)BroadPhaseLayers::MOVING:		return "MOVING";
+		case (BroadPhaseLayer::Type)BroadPhaseLayers::MOVING2:		return "MOVING2";
 		case (BroadPhaseLayer::Type)BroadPhaseLayers::LQ_DEBRIS:	return "LQ_DEBRIS";
 		case (BroadPhaseLayer::Type)BroadPhaseLayers::UNUSED:		return "UNUSED";
 		case (BroadPhaseLayer::Type)BroadPhaseLayers::SENSOR:		return "SENSOR";
@@ -127,6 +133,8 @@ public:
 		case Layers::MOVING:
 		case Layers::HQ_DEBRIS:
 			return inLayer2 == BroadPhaseLayers::NON_MOVING || inLayer2 == BroadPhaseLayers::MOVING || inLayer2 == BroadPhaseLayers::SENSOR;
+		case Layers::MOVING2:
+			return inLayer2 == BroadPhaseLayers::NON_MOVING || inLayer2 == BroadPhaseLayers::MOVING2 || inLayer2 == BroadPhaseLayers::SENSOR;
 		case Layers::LQ_DEBRIS:
 			return inLayer2 == BroadPhaseLayers::NON_MOVING;
 		case Layers::SENSOR:

+ 103 - 1
UnitTests/Physics/PhysicsTests.cpp

@@ -1612,7 +1612,7 @@ TEST_SUITE("PhysicsTests")
 			if (mode == 1)
 			{
 				// Don't save the global state
-				state_to_save = EStateRecorderState(uint(EStateRecorderState::All) ^ uint(EStateRecorderState::Global));
+				state_to_save = EStateRecorderState::All ^ EStateRecorderState::Global;
 
 				// Don't save some bodies
 				filter.mIgnoreBodies.push_back(ground.GetID());
@@ -1767,6 +1767,108 @@ TEST_SUITE("PhysicsTests")
 		}
 	}
 
+	TEST_CASE("TestMultiPartRestoreState")
+	{
+		class MyFilter : public StateRecorderFilter
+		{
+		public:
+										MyFilter(const Array<BodyID> &inStoredBodies) : mStoredBodies(inStoredBodies) { }
+
+			bool						ShouldSaveBody(const BodyID &inBodyID) const
+			{
+				return std::find(mStoredBodies.cbegin(), mStoredBodies.cend(), inBodyID) != mStoredBodies.cend();
+			}
+
+			virtual bool				ShouldSaveBody(const Body &inBody) const override
+			{
+				if (ShouldSaveBody(inBody.GetID()))
+				{
+					++mNumBodies;
+					return true;
+				}
+				return false;
+			}
+
+			virtual bool				ShouldSaveContact(const BodyID &inBody1, const BodyID &inBody2) const override
+			{
+				if (ShouldSaveBody(inBody1) || ShouldSaveBody(inBody2))
+				{
+					++mNumContacts;
+					return true;
+				}
+				return false;
+			}
+
+			const Array<BodyID> &		mStoredBodies;
+			mutable int					mNumBodies = 0;
+			mutable int					mNumContacts = 0;
+		};
+
+		PhysicsTestContext c;
+		c.CreateFloor();
+
+		// Create 1st set of moving bodies
+		constexpr int cNumMoving1 = 10;
+		Array<BodyID> moving1;
+		for (int i = 0; i < cNumMoving1; ++i)
+			moving1.push_back(c.CreateSphere(RVec3(0, 2.0f + 2.0f * i, 0.01f * i), 1.0f, EMotionType::Dynamic, EMotionQuality::Discrete, Layers::MOVING, EActivation::Activate).GetID());
+
+		// Create 2nd set of moving bodies, note that although the bodies overlap with the 1st set, they don't collide because of their layer.
+		// We need to create disjoint sets for restoring in parts to work.
+		constexpr int cNumMoving2 = 12;
+		Array<BodyID> moving2;
+		for (int i = 0; i < cNumMoving2; ++i)
+			moving2.push_back(c.CreateSphere(RVec3(1.0f, 2.0f + 2.0f * i, 0.01f * i), 1.0f, EMotionType::Dynamic, EMotionQuality::Discrete, Layers::MOVING2, EActivation::Activate).GetID());
+
+		// Simulate for a short while to get some contacts
+		c.Simulate(2.0f);
+
+		// Save full snapshot
+		StateRecorderImpl initial_state;
+		c.GetSystem()->SaveState(initial_state);
+
+		// Save everything relating to 1st set of bodies
+		MyFilter filter1(moving1);
+		StateRecorderImpl state1;
+		c.GetSystem()->SaveState(state1, EStateRecorderState::All, &filter1);
+		CHECK(filter1.mNumBodies == cNumMoving1);
+		CHECK(filter1.mNumContacts > cNumMoving1 / 2); // Many bodies should be in contact now, if not we're not testing contact restoring
+		CHECK(state1.GetDataSize() < initial_state.GetDataSize()); // Should be smaller than the full state
+
+		// Save everything relating to 2nd set of bodies
+		MyFilter filter2(moving2);
+		StateRecorderImpl state2;
+		c.GetSystem()->SaveState(state2, EStateRecorderState::Bodies | EStateRecorderState::Contacts, &filter2);
+		CHECK(filter2.mNumBodies == cNumMoving2);
+		CHECK(filter2.mNumContacts > cNumMoving2 / 2);
+		CHECK(state2.GetDataSize() < initial_state.GetDataSize());
+
+		// Simulate for 2 seconds
+		c.Simulate(2.0f);
+
+		// Save result
+		StateRecorderImpl final_state;
+		c.GetSystem()->SaveState(final_state);
+
+		// Restore the initial state in parts
+		state1.SetIsLastPart(false);
+		c.GetSystem()->RestoreState(state1);
+		c.GetSystem()->RestoreState(state2);
+
+		// Verify we're back to the first state
+		StateRecorderImpl verify1;
+		c.GetSystem()->SaveState(verify1);
+		CHECK(initial_state.IsEqual(verify1));
+
+		// Simulate for 2 seconds again
+		c.Simulate(2.0f);
+
+		// Check we end up in the final state again
+		StateRecorderImpl verify2;
+		c.GetSystem()->SaveState(verify2);
+		CHECK(final_state.IsEqual(verify2));
+	}
+
 	// This tests that when switching UseManifoldReduction on/off we get the correct contact callbacks
 	TEST_CASE("TestSwitchUseManifoldReduction")
 	{