Преглед на файлове

Added strand based hair system running on GPU (#1869)

* This system is based on Cosserad rods.
* It supports simulation (guide) and render (follow) hairs.
* Hair vs hair collision is handled by accumulating the average velocity in a grid and using those velocities to drive hairs.
* Supports collision with the environment, although it only supports ConvexHull and CompoundShapes at the moment.
Jorrit Rouwe преди 2 седмици
родител
ревизия
4a708eea7b
променени са 62 файла, в които са добавени 5138 реда и са изтрити 3 реда
  1. 1 0
      .gitignore
  2. 8 1
      Assets/LICENSE
  3. 5 0
      Assets/download_hair.sh
  4. 143 0
      Assets/export_face.py
  5. BIN
      Assets/face.bin
  6. 1 1
      Build/cmake_vs2022_cl_32bit.bat
  7. 1 1
      Jolt/Core/StridedPtr.h
  8. 44 0
      Jolt/Jolt.cmake
  9. 1033 0
      Jolt/Physics/Hair/Hair.cpp
  10. 227 0
      Jolt/Physics/Hair/Hair.h
  11. 869 0
      Jolt/Physics/Hair/HairSettings.cpp
  12. 373 0
      Jolt/Physics/Hair/HairSettings.h
  13. 33 0
      Jolt/Physics/Hair/HairShaders.cpp
  14. 37 0
      Jolt/Physics/Hair/HairShaders.h
  15. 42 0
      Jolt/Shaders/HairApplyDeltaTransform.hlsl
  16. 14 0
      Jolt/Shaders/HairApplyDeltaTransformBindings.h
  17. 19 0
      Jolt/Shaders/HairApplyGlobalPose.h
  18. 38 0
      Jolt/Shaders/HairApplyGlobalPose.hlsl
  19. 16 0
      Jolt/Shaders/HairApplyGlobalPoseBindings.h
  20. 114 0
      Jolt/Shaders/HairCalculateCollisionPlanes.hlsl
  21. 13 0
      Jolt/Shaders/HairCalculateCollisionPlanesBindings.h
  22. 16 0
      Jolt/Shaders/HairCalculateRenderPositions.h
  23. 22 0
      Jolt/Shaders/HairCalculateRenderPositions.hlsl
  24. 16 0
      Jolt/Shaders/HairCalculateRenderPositionsBindings.h
  25. 56 0
      Jolt/Shaders/HairCommon.h
  26. 50 0
      Jolt/Shaders/HairGridAccumulate.hlsl
  27. 12 0
      Jolt/Shaders/HairGridAccumulateBindings.h
  28. 17 0
      Jolt/Shaders/HairGridClear.hlsl
  29. 9 0
      Jolt/Shaders/HairGridClearBindings.h
  30. 26 0
      Jolt/Shaders/HairGridNormalize.hlsl
  31. 9 0
      Jolt/Shaders/HairGridNormalizeBindings.h
  32. 88 0
      Jolt/Shaders/HairIntegrate.h
  33. 35 0
      Jolt/Shaders/HairIntegrate.hlsl
  34. 17 0
      Jolt/Shaders/HairIntegrateBindings.h
  35. 50 0
      Jolt/Shaders/HairSkinRoots.hlsl
  36. 23 0
      Jolt/Shaders/HairSkinRootsBindings.h
  37. 26 0
      Jolt/Shaders/HairSkinVertices.hlsl
  38. 12 0
      Jolt/Shaders/HairSkinVerticesBindings.h
  39. 120 0
      Jolt/Shaders/HairStructs.h
  40. 28 0
      Jolt/Shaders/HairTeleport.hlsl
  41. 12 0
      Jolt/Shaders/HairTeleportBindings.h
  42. 30 0
      Jolt/Shaders/HairUpdateRoots.hlsl
  43. 12 0
      Jolt/Shaders/HairUpdateRootsBindings.h
  44. 154 0
      Jolt/Shaders/HairUpdateStrands.hlsl
  45. 17 0
      Jolt/Shaders/HairUpdateStrandsBindings.h
  46. 64 0
      Jolt/Shaders/HairUpdateVelocity.h
  47. 40 0
      Jolt/Shaders/HairUpdateVelocity.hlsl
  48. 20 0
      Jolt/Shaders/HairUpdateVelocityBindings.h
  49. 44 0
      Jolt/Shaders/HairUpdateVelocityIntegrate.hlsl
  50. 21 0
      Jolt/Shaders/HairUpdateVelocityIntegrateBindings.h
  51. 135 0
      Jolt/Shaders/HairWrapper.cpp
  52. 13 0
      Jolt/Shaders/HairWrapper.h
  53. 4 0
      Jolt/Shaders/ShaderCore.h
  54. 1 0
      README.md
  55. 7 0
      Samples/Samples.cmake
  56. 14 0
      Samples/SamplesApp.cpp
  57. 112 0
      Samples/Tests/Hair/HairCollisionTest.cpp
  58. 46 0
      Samples/Tests/Hair/HairCollisionTest.h
  59. 135 0
      Samples/Tests/Hair/HairGravityPreloadTest.cpp
  60. 44 0
      Samples/Tests/Hair/HairGravityPreloadTest.h
  61. 445 0
      Samples/Tests/Hair/HairTest.cpp
  62. 105 0
      Samples/Tests/Hair/HairTest.h

+ 1 - 0
.gitignore

@@ -12,3 +12,4 @@
 /detlog.txt
 /detlog.txt
 *.spv
 *.spv
 *.metallib
 *.metallib
+/Assets/*.hair

+ 8 - 1
Assets/LICENSE

@@ -5,4 +5,11 @@ convex_hulls.bin contains point clouds used as the source for convex hulls
 heightfield1.bin contains a single 'tile' of terrain data
 heightfield1.bin contains a single 'tile' of terrain data
 Human.tof and Human/*.tof contain the Aloy ragdoll setup and a couple of sample animations mapped onto the ragdoll
 Human.tof and Human/*.tof contain the Aloy ragdoll setup and a couple of sample animations mapped onto the ragdoll
 
 
-Permission was granted by Guerrilla Games to release this under the same MIT license as the rest of the project.
+Permission was granted by Guerrilla Games to release this under the same MIT license as the rest of the project.
+
+---
+
+The following assets were taken from www.cemyuksel.com/research/hairmodels:
+
+face.bin - Created by Murat Afshar (https://www.cemyuksel.com/research/hairmodels/woman.zip), rigged in Blender 5.0 and exported to a binary file using export_face.py.
+wCurly.hair, wStraight.hair and wWavy.hair should be downloaded from https://www.cemyuksel.com/research/hairmodels/ (or by running download_hair.sh)

+ 5 - 0
Assets/download_hair.sh

@@ -0,0 +1,5 @@
+#!/bin/sh
+
+curl https://www.cemyuksel.com/research/hairmodels/wStraight.zip | gzip -d > wStraight.hair
+curl https://www.cemyuksel.com/research/hairmodels/wWavy.zip | gzip -d > wWavy.hair
+curl https://www.cemyuksel.com/research/hairmodels/wCurly.zip | gzip -d > wCurly.hair

+ 143 - 0
Assets/export_face.py

@@ -0,0 +1,143 @@
+# Usage: Run in Blender with the correct scene setup to produce a face.bin file:
+# - A collection named "Visual"
+#   - An Armature with Bones and an Animation
+#   - One of the Bones should be called Neck and will be used to parent the hair to
+#   - A mesh skinned to those Bones
+# - A collection named "Collision"
+#   - One or more meshes parent constrained to one of the Bones
+
+import bpy
+import mathutils
+import struct
+
+export_path = "C:\\Users\\jrouw\\Documents\\Code\\JoltPhysics\\Assets\\"
+head_joint = "Neck"
+scale = 0.00254
+basis = mathutils.Matrix([[0, scale, 0, 0], [0, 0, scale, 0], [scale, 0, 0, 0], [0, 0, 0, 1]])
+num_weights = 3
+
+def apply_basis(m):
+	return basis @ m @ basis.inverted()
+
+def export_scene_meshes_to_bin(filepath):
+	with open(filepath, "wb") as f:
+		bpy.context.scene.frame_set(1)
+
+		# Find the skinned mesh in the "Visual" collection
+		visual_coll = bpy.data.collections.get("Visual")
+		assert visual_coll, "No 'Visual' collection found"
+		visual_meshes = [obj for obj in visual_coll.objects if obj.type == 'MESH']
+		assert len(visual_meshes) == 1, "There must be exactly one mesh in the 'Visual' collection"
+		obj = visual_meshes[0]
+		mesh = obj.data
+
+		# Head joint index
+		armature = next((m for m in obj.modifiers if m.type == 'ARMATURE'), None)
+		assert armature
+		bones = armature.object.data.bones
+		head_joint_idx = [b.name for b in bones].index(head_joint)
+		f.write(struct.pack("<I", head_joint_idx))
+
+		# Vertices
+		vertices = [basis @ obj.matrix_world @ v.co for v in mesh.vertices]
+		f.write(struct.pack("<I", len(vertices)))
+		for v in vertices:
+			f.write(struct.pack("<3f", v[0], v[1], v[2]))
+
+		# Indices (triangles)
+		indices = [tuple(p.vertices) for p in mesh.polygons]
+		f.write(struct.pack("<I", len(indices)))
+		for t in indices:
+			f.write(struct.pack("<3I", t[0], t[1], t[2]))
+
+		# Inverse Bind Matrices
+		inv_bind_matrices = [apply_basis(b.matrix_local).inverted() for b in bones]
+		f.write(struct.pack("<I", len(inv_bind_matrices)))
+		for m in inv_bind_matrices:
+			# Write 16 floats (column major)
+			for i in range(4):
+				for j in range(4):
+					f.write(struct.pack("<f", m[j][i]))
+
+		# Skin Weights
+		weights = []
+		for v in mesh.vertices:
+			vw = []
+			total_weight = 0
+			for g in sorted(v.groups, key=lambda g: g.weight, reverse=True)[:num_weights]:
+				bone_index = [b.name for b in bones].index(obj.vertex_groups[g.group].name)
+				vw.append((bone_index, g.weight))
+				total_weight += g.weight
+			# Pad to 1
+			if len(vw) < 1:
+				vw.append((0, 1.0))
+			elif total_weight <= 0:
+				vw = [[v[0], 1.0] for v in vw]
+			else:
+				vw = [[v[0], v[1] / total_weight] for v in vw]
+			weights.append(vw)
+		f.write(struct.pack("<I", num_weights))
+		for vw in weights:
+			# Always write 'num_weights' weights per vertex (pad with zeros if needed)
+			for i in range(num_weights):
+				if i < len(vw):
+					bone_index, weight = vw[i]
+				else:
+					bone_index, weight = 0, 0.0
+				f.write(struct.pack("<If", bone_index, weight))
+
+		# Animation (per frame, per joint)
+		armature_obj = armature.object
+		if armature_obj.animation_data and armature_obj.animation_data.action:
+			action = armature_obj.animation_data.action
+			frame_start = int(action.frame_range[0])
+			frame_end = int(action.frame_range[1])
+			num_frames = frame_end - frame_start + 1
+			f.write(struct.pack("<I", num_frames))
+			for frame in range(frame_start, frame_end + 1):
+				bpy.context.scene.frame_set(frame)
+				for bone in bones:
+					pose_bone = armature.object.pose.bones[bone.name]
+					mat = apply_basis(pose_bone.matrix)
+
+					# Translation from matrix
+					t = mat.to_translation()
+
+					# Rotation quaternion from matrix
+					q = mat.to_quaternion().normalized()
+
+					# Ensure unique quaternion sign: make W positive
+					if q.w < 0.0:
+						q = -q
+
+					# Export translation (x,y,z) + quaternion real part (x,y,z)
+					f.write(struct.pack("<3f", t.x, t.y, t.z))
+					f.write(struct.pack("<3f", q.x, q.y, q.z))
+		else:
+			print("No animation data found on the armature object.")
+		bpy.context.scene.frame_set(1)
+
+		# Export collision hulls
+		collision_coll = bpy.data.collections.get("Collision")
+		assert collision_coll, "No 'Collision' collection found"
+		collision_meshes = [obj for obj in collision_coll.objects if obj.type == 'MESH']
+
+		f.write(struct.pack("<I", len(collision_meshes)))
+		for col_obj in collision_meshes:
+			# Find parent bone name and index
+			parent = col_obj.parent
+			assert parent and parent.type == 'ARMATURE', f"Collision mesh '{col_obj.name}' must be parented to a bone"
+			bone_name = col_obj.parent_bone
+			try:
+				joint_index = [b.name for b in bones].index(bone_name)
+			except ValueError:
+				joint_index = 0xffffffff
+			f.write(struct.pack("<I", joint_index))
+
+			# Write vertices
+			verts = [basis @ col_obj.matrix_world @ v.co for v in col_obj.data.vertices]
+			f.write(struct.pack("<I", len(verts)))
+			for v in verts:
+				f.write(struct.pack("<3f", v[0], v[1], v[2]))
+
+export_scene_meshes_to_bin(export_path + "face.bin")

BIN
Assets/face.bin


+ 1 - 1
Build/cmake_vs2022_cl_32bit.bat

@@ -1,3 +1,3 @@
 @echo off
 @echo off
-cmake -S . -B VS2022_CL_32BIT -G "Visual Studio 17 2022" -A Win32 -DUSE_SSE4_1=OFF -DUSE_SSE4_2=OFF -DUSE_AVX=OFF -DUSE_AVX2=OFF -DUSE_AVX512=OFF -DUSE_LZCNT=OFF -DUSE_TZCNT=OFF -DUSE_F16C=OFF -DUSE_FMADD=OFF -DJPH_USE_VK=OFF %*
+cmake -S . -B VS2022_CL_32BIT -G "Visual Studio 17 2022" -A Win32 -DUSE_SSE4_1=OFF -DUSE_SSE4_2=OFF -DUSE_AVX=OFF -DUSE_AVX2=OFF -DUSE_AVX512=OFF -DUSE_LZCNT=OFF -DUSE_TZCNT=OFF -DUSE_F16C=OFF -DUSE_FMADD=OFF -DJPH_USE_VK=OFF -DJPH_USE_DXC=OFF %*
 echo Open VS2022_CL_32BIT\JoltPhysics.sln to build the project.
 echo Open VS2022_CL_32BIT\JoltPhysics.sln to build the project.

+ 1 - 1
Jolt/Core/StridedPtr.h

@@ -10,7 +10,7 @@ JPH_NAMESPACE_BEGIN
 /// elements that the pointer points to can be part of a larger structure.
 /// elements that the pointer points to can be part of a larger structure.
 /// The stride gives the number of bytes from one element to the next.
 /// The stride gives the number of bytes from one element to the next.
 template <class T>
 template <class T>
-class JPH_EXPORT StridedPtr
+class StridedPtr
 {
 {
 public:
 public:
 	using value_type = T;
 	using value_type = T;

+ 44 - 0
Jolt/Jolt.cmake

@@ -372,6 +372,12 @@ set(JOLT_PHYSICS_SRC_FILES
 	${JOLT_PHYSICS_ROOT}/Physics/DeterminismLog.h
 	${JOLT_PHYSICS_ROOT}/Physics/DeterminismLog.h
 	${JOLT_PHYSICS_ROOT}/Physics/EActivation.h
 	${JOLT_PHYSICS_ROOT}/Physics/EActivation.h
 	${JOLT_PHYSICS_ROOT}/Physics/EPhysicsUpdateError.h
 	${JOLT_PHYSICS_ROOT}/Physics/EPhysicsUpdateError.h
+	${JOLT_PHYSICS_ROOT}/Physics/Hair/Hair.cpp
+	${JOLT_PHYSICS_ROOT}/Physics/Hair/Hair.h
+	${JOLT_PHYSICS_ROOT}/Physics/Hair/HairSettings.cpp
+	${JOLT_PHYSICS_ROOT}/Physics/Hair/HairSettings.h
+	${JOLT_PHYSICS_ROOT}/Physics/Hair/HairShaders.cpp
+	${JOLT_PHYSICS_ROOT}/Physics/Hair/HairShaders.h
 	${JOLT_PHYSICS_ROOT}/Physics/IslandBuilder.cpp
 	${JOLT_PHYSICS_ROOT}/Physics/IslandBuilder.cpp
 	${JOLT_PHYSICS_ROOT}/Physics/IslandBuilder.h
 	${JOLT_PHYSICS_ROOT}/Physics/IslandBuilder.h
 	${JOLT_PHYSICS_ROOT}/Physics/LargeIslandSplitter.cpp
 	${JOLT_PHYSICS_ROOT}/Physics/LargeIslandSplitter.cpp
@@ -436,6 +442,8 @@ set(JOLT_PHYSICS_SRC_FILES
 	${JOLT_PHYSICS_ROOT}/Renderer/DebugRendererRecorder.h
 	${JOLT_PHYSICS_ROOT}/Renderer/DebugRendererRecorder.h
 	${JOLT_PHYSICS_ROOT}/Renderer/DebugRendererSimple.cpp
 	${JOLT_PHYSICS_ROOT}/Renderer/DebugRendererSimple.cpp
 	${JOLT_PHYSICS_ROOT}/Renderer/DebugRendererSimple.h
 	${JOLT_PHYSICS_ROOT}/Renderer/DebugRendererSimple.h
+	${JOLT_PHYSICS_ROOT}/Shaders/HairWrapper.cpp
+	${JOLT_PHYSICS_ROOT}/Shaders/HairWrapper.h
 	${JOLT_PHYSICS_ROOT}/Shaders/TestComputeWrapper.cpp
 	${JOLT_PHYSICS_ROOT}/Shaders/TestComputeWrapper.cpp
 	${JOLT_PHYSICS_ROOT}/Skeleton/SkeletalAnimation.cpp
 	${JOLT_PHYSICS_ROOT}/Skeleton/SkeletalAnimation.cpp
 	${JOLT_PHYSICS_ROOT}/Skeleton/SkeletalAnimation.h
 	${JOLT_PHYSICS_ROOT}/Skeleton/SkeletalAnimation.h
@@ -478,11 +486,47 @@ endif()
 if (JPH_USE_DX12 OR JPH_USE_VK OR JPH_USE_MTL)
 if (JPH_USE_DX12 OR JPH_USE_VK OR JPH_USE_MTL)
 	# Compute shaders
 	# Compute shaders
 	set(JOLT_PHYSICS_SHADERS
 	set(JOLT_PHYSICS_SHADERS
+		${JOLT_PHYSICS_ROOT}/Shaders/HairApplyDeltaTransform.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/HairApplyGlobalPose.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/HairCalculateCollisionPlanes.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/HairCalculateRenderPositions.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/HairGridAccumulate.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/HairGridClear.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/HairGridNormalize.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/HairIntegrate.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/HairSkinRoots.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/HairSkinVertices.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/HairTeleport.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/HairUpdateRoots.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/HairUpdateStrands.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/HairUpdateVelocity.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/HairUpdateVelocityIntegrate.hlsl
 		${JOLT_PHYSICS_ROOT}/Shaders/TestCompute.hlsl
 		${JOLT_PHYSICS_ROOT}/Shaders/TestCompute.hlsl
 		${JOLT_PHYSICS_ROOT}/Shaders/TestCompute2.hlsl
 		${JOLT_PHYSICS_ROOT}/Shaders/TestCompute2.hlsl
 	)
 	)
 
 
 	set(JOLT_PHYSICS_SHADER_HEADERS
 	set(JOLT_PHYSICS_SHADER_HEADERS
+		${JOLT_PHYSICS_ROOT}/Shaders/HairApplyDeltaTransformBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairApplyGlobalPose.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairApplyGlobalPoseBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairCalculateCollisionPlanesBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairCalculateRenderPositions.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairCalculateRenderPositionsBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairCommon.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairGridAccumulateBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairGridClearBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairGridNormalizeBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairIntegrate.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairIntegrateBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairSkinRootsBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairSkinVerticesBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairStructs.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairTeleportBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairUpdateRootsBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairUpdateStrandsBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairUpdateVelocity.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairUpdateVelocityBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/HairUpdateVelocityIntegrateBindings.h
 		${JOLT_PHYSICS_ROOT}/Shaders/ShaderCore.h
 		${JOLT_PHYSICS_ROOT}/Shaders/ShaderCore.h
 		${JOLT_PHYSICS_ROOT}/Shaders/ShaderMat44.h
 		${JOLT_PHYSICS_ROOT}/Shaders/ShaderMat44.h
 		${JOLT_PHYSICS_ROOT}/Shaders/ShaderMath.h
 		${JOLT_PHYSICS_ROOT}/Shaders/ShaderMath.h

+ 1033 - 0
Jolt/Physics/Hair/Hair.cpp

@@ -0,0 +1,1033 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#include <Jolt/Physics/Hair/Hair.h>
+#include <Jolt/Physics/Hair/HairShaders.h>
+#include <Jolt/Physics/Collision/Shape/ConvexHullShape.h>
+#include <Jolt/Physics/Collision/Shape/ScaleHelpers.h>
+#include <Jolt/Physics/PhysicsSystem.h>
+#include <Jolt/Core/Profiler.h>
+#ifdef JPH_DEBUG_RENDERER
+	#include <Jolt/Renderer/DebugRenderer.h>
+#endif
+
+JPH_NAMESPACE_BEGIN
+
+Hair::Hair(const HairSettings *inSettings, RVec3Arg inPosition, QuatArg inRotation, ObjectLayer inLayer) :
+	mSettings(inSettings),
+	mPrevPosition(inPosition),
+	mPosition(inPosition),
+	mPrevRotation(inRotation),
+	mRotation(inRotation),
+	mLayer(inLayer)
+{
+}
+
+Hair::~Hair()
+{
+	// Delete debug data
+	if (mPositions != nullptr)
+		delete [] mPositions;
+	if (mRotations != nullptr)
+		delete [] mRotations;
+	if (mVelocities != nullptr)
+		delete [] mVelocities;
+	if (mRenderPositionsOverridden)
+		delete [] mRenderPositions;
+}
+
+void Hair::Init(ComputeSystem *inComputeSystem)
+{
+	// Create compute buffers
+	size_t num_vertices_padded = mSettings->GetNumVerticesPadded();
+	size_t grid_size = mSettings->mNeutralDensity.size();
+	size_t num_render_vertices = mSettings->mRenderVertices.size();
+
+	if (!mSettings->mScalpInverseBindPose.empty() && !mSettings->mScalpVertices.empty())
+	{
+		mScalpJointMatricesCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::UploadBuffer, mSettings->mScalpInverseBindPose.size() * sizeof(Mat44), sizeof(Mat44)).Get();
+		mScalpVerticesCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, mSettings->mScalpVertices.size(), sizeof(Float3)).Get();
+		mScalpTrianglesCB = mSettings->mScalpTrianglesCB;
+	}
+
+	if (mScalpVerticesCB != nullptr)
+	{
+		mGlobalPoseTransformsCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, mSettings->mSimStrands.size(), sizeof(JPH_HairGlobalPoseTransform)).Get();
+	}
+	else
+	{
+		// No vertices provided externally and none in settings, use identity transforms
+		JPH_HairGlobalPoseTransform identity;
+		identity.mPosition = JPH_float3(0, 0, 0);
+		identity.mRotation = JPH_float4(0, 0, 0, 1);
+		Array<JPH_HairGlobalPoseTransform> identity_array(mSettings->mSimStrands.size(), identity);
+		mGlobalPoseTransformsCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, mSettings->mSimStrands.size(), sizeof(JPH_HairGlobalPoseTransform), identity_array.data()).Get();
+	}
+
+	mCollisionPlanesCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, num_vertices_padded, sizeof(JPH_HairCollisionPlane)).Get();
+	mMaterialsCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::UploadBuffer, mSettings->mMaterials.size(), sizeof(JPH_HairMaterial)).Get();
+	mPreviousPositionsCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, num_vertices_padded, sizeof(JPH_HairPosition)).Get();
+	mPositionsCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, num_vertices_padded, sizeof(JPH_HairPosition)).Get();
+	mVelocitiesCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, num_vertices_padded, sizeof(JPH_HairVelocity)).Get();
+	mConstantsCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::ConstantBuffer, 1, sizeof(JPH_HairUpdateContext)).Get();
+	mVelocityAndDensityCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, grid_size, sizeof(Float4)).Get();
+	if (!mRenderPositionsOverridden)
+		mRenderPositionsCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, num_render_vertices, sizeof(Float3)).Get();
+}
+
+void Hair::InitializeContext(UpdateContext &outCtx, float inDeltaTime, const PhysicsSystem &inSystem)
+{
+	float clamped_delta_time = min(inDeltaTime, mSettings->mMaxDeltaTime);
+	outCtx.mNumIterations = (uint)std::round(clamped_delta_time * mSettings->mNumIterationsPerSecond);
+	outCtx.mDeltaTime = outCtx.mNumIterations > 0? clamped_delta_time / outCtx.mNumIterations : 0.0f;
+	outCtx.mTimeRatio = outCtx.mDeltaTime * float(HairSettings::cDefaultIterationsPerSecond);
+	outCtx.mHalfDeltaTime = 0.5f * outCtx.mDeltaTime;
+	outCtx.mInvDeltaTimeSq = outCtx.mDeltaTime > 0.0f? 1.0f / Square(outCtx.mDeltaTime) : 1.0e12f;
+	outCtx.mTwoDivDeltaTime = outCtx.mDeltaTime > 0.0f? 2.0f / outCtx.mDeltaTime : 1.0e12f;
+	outCtx.mSubStepGravity = (mRotation.Conjugated() * inSystem.GetGravity()) * outCtx.mDeltaTime;
+
+	// Calculate delta transform from previous to current position and rotation
+	outCtx.mHasTransformChanged = mPosition != mPrevPosition || mRotation != mPrevRotation;
+	RMat44 prev_com = RMat44::sRotationTranslation(mPrevRotation, mPrevPosition);
+	outCtx.mDeltaTransform = (GetWorldTransform().InversedRotationTranslation() * prev_com).ToMat44();
+	outCtx.mDeltaTransformQuat = outCtx.mDeltaTransform.GetQuaternion();
+	mPrevPosition = mPosition;
+	mPrevRotation = mRotation;
+
+	// Check if we need collision detection / grid
+	outCtx.mNeedsCollision = false;
+	outCtx.mNeedsGrid = false;
+	outCtx.mGlobalPoseOnly = true;
+	for (const HairSettings::Material &material : mSettings->mMaterials)
+	{
+		outCtx.mNeedsCollision |= material.mEnableCollision;
+		outCtx.mNeedsGrid |= material.NeedsGrid();
+		outCtx.mGlobalPoseOnly &= material.GlobalPoseOnly();
+	}
+
+	if (outCtx.mNeedsCollision)
+	{
+		struct Collector : public CollideShapeBodyCollector
+		{
+										Collector(const PhysicsSystem &inSystem, RMat44Arg inTransform, const AABox &inLocalBounds, Array<LeafShape> &ioHits) :
+											mSystem(inSystem),
+											mTransform(inTransform),
+											mInverseTransform(inTransform.InversedRotationTranslation()),
+											mLocalBounds(inLocalBounds),
+											mHits(ioHits)
+			{
+			}
+
+			virtual void				AddHit(const BodyID &inResult) override
+			{
+				BodyLockRead lock(mSystem.GetBodyLockInterface(), inResult);
+				if (lock.Succeeded())
+				{
+					const Body &body = lock.GetBody();
+					if (body.IsRigidBody()
+						&& !body.IsSensor())
+					{
+						// Calculate transform of this body relative to the hair instance
+						Mat44 com = (mInverseTransform * body.GetCenterOfMassTransform()).ToMat44();
+
+						// Collect leaf shapes
+						struct LeafShapeCollector : public TransformedShapeCollector
+						{
+												LeafShapeCollector(RMat44Arg inHeadTransform, const Body &inBody, Array<LeafShape> &ioHits) : mHeadTransform(inHeadTransform), mBody(inBody), mHits(ioHits) { }
+
+							virtual void		AddHit(const TransformedShape &inResult) override
+							{
+								mHits.emplace_back(Mat44::sRotationTranslation(inResult.mShapeRotation, Vec3(inResult.mShapePositionCOM)),
+									inResult.GetShapeScale(),
+									mHeadTransform.Multiply3x3Transposed(mBody.GetPointVelocity(mHeadTransform * inResult.mShapePositionCOM)), // Calculate velocity of shape at its center of mass position
+									mHeadTransform.Multiply3x3Transposed(mBody.GetAngularVelocity()),
+									inResult.mShape);
+							}
+
+							RMat44				mHeadTransform;
+							const Body &		mBody;
+							Array<LeafShape> &	mHits;
+						};
+						LeafShapeCollector collector(mTransform, body, mHits);
+						body.GetShape()->CollectTransformedShapes(mLocalBounds, com.GetTranslation(), com.GetQuaternion(), Vec3::sOne(), SubShapeIDCreator(), collector, { });
+					}
+				}
+			}
+
+		private:
+			const PhysicsSystem &		mSystem;
+			RMat44						mTransform;
+			RMat44						mInverseTransform;
+			AABox						mLocalBounds;
+			Array<LeafShape> &			mHits;
+		};
+
+		// Calculate world space bounding box
+		RMat44 transform = GetWorldTransform();
+		AABox world_bounds = mSettings->mSimulationBounds.Transformed(transform);
+
+		// Collect shapes that intersect with the bounding box
+		Collector collector(inSystem, transform, mSettings->mSimulationBounds, outCtx.mShapes);
+		DefaultBroadPhaseLayerFilter broadphase_layer_filter = inSystem.GetDefaultBroadPhaseLayerFilter(mLayer);
+		DefaultObjectLayerFilter object_layer_filter = inSystem.GetDefaultLayerFilter(mLayer);
+		inSystem.GetBroadPhaseQuery().CollideAABox(world_bounds, collector, broadphase_layer_filter, object_layer_filter);
+
+		// If no shapes were found, we don't need collision
+		if (outCtx.mShapes.empty())
+			outCtx.mNeedsCollision = false;
+	}
+}
+
+void Hair::Update(float inDeltaTime, Mat44Arg inJointToHair, const Mat44 *inJointMatrices, const PhysicsSystem &inSystem, const HairShaders &inShaders, ComputeSystem *inComputeSystem, ComputeQueue *inComputeQueue)
+{
+	UpdateContext ctx;
+	InitializeContext(ctx, inDeltaTime, inSystem);
+
+	if (inJointMatrices != nullptr && mScalpJointMatricesCB != nullptr)
+	{
+		JPH_PROFILE("Prepare for Skinning");
+
+		Mat44 *joints = mScalpJointMatricesCB->Map<Mat44>(ComputeBuffer::EMode::Write);
+		mSettings->PrepareForScalpSkinning(inJointToHair, inJointMatrices, joints);
+		mScalpJointMatricesCB->Unmap();
+	}
+
+	if (ctx.mNeedsCollision)
+	{
+		JPH_PROFILE("Create Collision Shapes");
+
+		// First determine buffer sizes
+		uint num_shapes = 0;
+		uint num_faces = 0;
+		uint num_vertices = 0;
+		uint num_header = 0;
+		uint num_indices = 0;
+		uint max_vertices_per_face = 0;
+		uint max_points = 0;
+		for (const LeafShape &shape : ctx.mShapes)
+			if (shape.mShape->GetSubType() == EShapeSubType::ConvexHull)
+			{
+				const ConvexHullShape *ch = static_cast<const ConvexHullShape *>(shape.mShape.GetPtr());
+				++num_shapes;
+				++num_header; // Write number of vertices
+				uint np = ch->GetNumPoints();
+				max_points = max(max_points, np);
+				num_vertices += np;
+				uint nf = ch->GetNumFaces();
+				num_faces += nf;
+				for (uint f = 0; f < nf; ++f)
+				{
+					num_header += 2; // Write indices start + end
+					uint num_vertices_in_face = ch->GetNumVerticesInFace(f);
+					num_indices += num_vertices_in_face;
+					max_vertices_per_face = max(max_vertices_per_face, num_vertices_in_face);
+				}
+			}
+		++num_header; // Terminator
+		num_indices += num_header;
+
+		// Now allocate buffers
+		if (mCollisionShapesCB == nullptr || mCollisionShapesCB->GetSize() < num_shapes)
+		{
+			mCollisionShapesCB = nullptr;
+			mCollisionShapesCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::UploadBuffer, num_shapes, sizeof(JPH_HairCollisionShape)).Get();
+		}
+		if (mShapePlanesCB == nullptr || mShapePlanesCB->GetSize() < num_faces)
+		{
+			mShapePlanesCB = nullptr;
+			mShapePlanesCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::UploadBuffer, max(num_faces, 1u), sizeof(Float4)).Get();
+		}
+		if (mShapeVerticesCB == nullptr || mShapeVerticesCB->GetSize() < num_vertices)
+		{
+			mShapeVerticesCB = nullptr;
+			mShapeVerticesCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::UploadBuffer, max(num_vertices, 1u), sizeof(Float3)).Get();
+		}
+		if (mShapeIndicesCB == nullptr || mShapeIndicesCB->GetSize() < num_indices)
+		{
+			mShapeIndicesCB = nullptr;
+			mShapeIndicesCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::UploadBuffer, num_indices, sizeof(uint32)).Get();
+		}
+
+		JPH_HairCollisionShape *collision_shapes = mCollisionShapesCB->Map<JPH_HairCollisionShape>(ComputeBuffer::EMode::Write);
+		Float4 *shape_planes = mShapePlanesCB->Map<Float4>(ComputeBuffer::EMode::Write);
+		Float3 *shape_vertices = mShapeVerticesCB->Map<Float3>(ComputeBuffer::EMode::Write);
+		uint32 *shape_indices = mShapeIndicesCB->Map<uint32>(ComputeBuffer::EMode::Write);
+		uint *face_indices = (uint *)JPH_STACK_ALLOC(max_vertices_per_face * sizeof(uint));
+		Vec3 *points = (Vec3 *)JPH_STACK_ALLOC(max_points * sizeof(Vec3));
+
+		// Convert the hulls to compute buffers
+		Float4 *sp = shape_planes;
+		Float3 *sv = shape_vertices;
+		uint32 *sh = shape_indices;
+		JPH_HairCollisionShape *cs = collision_shapes;
+		uint32 *si = shape_indices + num_header;
+		for (const LeafShape &shape : ctx.mShapes)
+			if (shape.mShape->GetSubType() == EShapeSubType::ConvexHull)
+			{
+				const ConvexHullShape *ch = static_cast<const ConvexHullShape *>(shape.mShape.GetPtr());
+
+				// Store collision shape
+				shape.mTransform.GetTranslation().StoreFloat3(&cs->mCenterOfMass);
+				shape.mLinearVelocity.StoreFloat3(&cs->mLinearVelocity);
+				shape.mAngularVelocity.StoreFloat3(&cs->mAngularVelocity);
+				++cs;
+
+				// Store points transformed to hair space
+				Mat44 shape_transform = shape.mTransform.PreScaled(shape.mScale);
+				uint first_vertex_index = uint(sv - shape_vertices);
+				for (uint p = 0, np = ch->GetNumPoints(); p < np; ++p)
+				{
+					Vec3 v = shape_transform * ch->GetPoint(p);
+					points[p] = v; // Store points in a temporary buffer so we avoid reading from GPU memory
+					v.StoreFloat3(sv);
+					++sv;
+				}
+
+				// Store number of faces
+				uint nf = ch->GetNumFaces();
+				*sh = nf;
+				++sh;
+
+				// Store the indices
+				if (ScaleHelpers::IsInsideOut(shape.mScale))
+				{
+					// Reverse winding order
+					for (uint f = 0; f < nf; ++f)
+					{
+						// Store indices
+						uint nv = ch->GetFaceVertices(f, max_vertices_per_face, face_indices);
+						uint32 indices_start = uint32(si - shape_indices);
+						*sh = indices_start;
+						++sh;
+						*sh = indices_start + nv;
+						++sh;
+						for (int v = int(nv) - 1; v >= 0; --v, ++si)
+							*si = face_indices[v] + first_vertex_index;
+
+						// Calculate plane (avoids reading from GPU memory)
+						Plane::sFromPointsCCW(points[face_indices[2]], points[face_indices[1]], points[face_indices[0]]).StoreFloat4(sp);
+						++sp;
+					}
+				}
+				else
+				{
+					// Keep winding order
+					for (uint f = 0; f < nf; ++f)
+					{
+						// Store indices
+						uint nv = ch->GetFaceVertices(f, max_vertices_per_face, face_indices);
+						uint32 indices_start = uint32(si - shape_indices);
+						*sh++ = indices_start;
+						*sh++ = indices_start + nv;
+						for (uint v = 0; v < nv; ++v)
+							*si++ = face_indices[v] + first_vertex_index;
+
+						// Calculate plane (avoids reading from GPU memory)
+						Plane::sFromPointsCCW(points[face_indices[0]], points[face_indices[1]], points[face_indices[2]]).StoreFloat4(sp);
+						++sp;
+					}
+				}
+			}
+		*sh = 0; // Terminator
+		++sh;
+		JPH_ASSERT(uint(cs - collision_shapes) == num_shapes);
+		JPH_ASSERT(uint(sp - shape_planes) == num_faces);
+		JPH_ASSERT(uint(sv - shape_vertices) == num_vertices);
+		JPH_ASSERT(uint(sh - shape_indices) == num_header);
+		JPH_ASSERT(uint(si - shape_indices) == num_indices);
+
+		// Unmap buffers
+		mCollisionShapesCB->Unmap();
+		mShapePlanesCB->Unmap();
+		mShapeVerticesCB->Unmap();
+		mShapeIndicesCB->Unmap();
+	}
+
+	{
+		JPH_PROFILE("Set materials");
+
+		JPH_HairMaterial *materials = mMaterialsCB->Map<JPH_HairMaterial>(ComputeBuffer::EMode::Write);
+		for (size_t i = 0, n = mSettings->mMaterials.size(); i < n; ++i)
+		{
+			const HairSettings::Material &m_in = mSettings->mMaterials[i];
+			JPH_HairMaterial &m_out = materials[i];
+
+			GradientSampler world_transform_influence(m_in.mWorldTransformInfluence);
+			m_out.mWorldTransformInfluence = world_transform_influence.ToFloat4();
+			GradientSampler global_pose(ctx.mGlobalPoseOnly? m_in.mGlobalPose : m_in.mGlobalPose.MakeStepDependent(ctx.mTimeRatio));
+			m_out.mGlobalPose = global_pose.ToFloat4();
+			GradientSampler global_pose_skin_to_root(m_in.mSkinGlobalPose);
+			m_out.mSkinGlobalPose = global_pose_skin_to_root.ToFloat4();
+			GradientSampler gravity_factor(m_in.mGravityFactor);
+			m_out.mGravityFactor = gravity_factor.ToFloat4();
+			GradientSampler hair_radius(m_in.mHairRadius);
+			m_out.mHairRadius = hair_radius.ToFloat4();
+			m_out.mBendComplianceMultiplier = m_in.mBendComplianceMultiplier;
+			GradientSampler grid_velocity_factor(m_in.mGridVelocityFactor.MakeStepDependent(ctx.mTimeRatio));
+			m_out.mGridVelocityFactor = grid_velocity_factor.ToFloat4();
+			m_out.mEnableCollision = ctx.mNeedsCollision && m_in.mEnableCollision? 1 : 0;
+			m_out.mEnableLRA = m_in.mEnableLRA? 1 : 0;
+			m_out.mEnableGrid = m_in.mGridVelocityFactor.mMin != 0.0f || m_in.mGridVelocityFactor.mMax != 0.0f || m_in.mGridDensityForceFactor != 0.0f;
+			m_out.mFriction = m_in.mFriction;
+			m_out.mExpLinearDampingDeltaTime = std::exp(-m_in.mLinearDamping * ctx.mDeltaTime);
+			m_out.mExpAngularDampingDeltaTime = std::exp(-m_in.mAngularDamping * ctx.mDeltaTime);
+			m_out.mBendComplianceInvDeltaTimeSq = m_in.mBendCompliance * ctx.mInvDeltaTimeSq;
+			m_out.mStretchComplianceInvDeltaTimeSq = m_in.mStretchCompliance * ctx.mInvDeltaTimeSq;
+			m_out.mGridDensityForceFactor = m_in.mGridDensityForceFactor;
+			m_out.mInertiaMultiplier = m_in.mInertiaMultiplier;
+			m_out.mMaxLinearVelocitySq = Square(m_in.mMaxLinearVelocity);
+			m_out.mMaxAngularVelocitySq = Square(m_in.mMaxAngularVelocity);
+		}
+		mMaterialsCB->Unmap();
+	}
+
+	{
+		JPH_PROFILE("Set constants");
+
+		JPH_HairUpdateContext *cdata = mConstantsCB->Map<JPH_HairUpdateContext>(ComputeBuffer::EMode::Write);
+		cdata->cNumStrands = uint32(mSettings->mSimStrands.size());
+		cdata->cNumVertices = mSettings->GetNumVerticesPadded();
+		cdata->cNumGridPoints = (uint32)mSettings->mNeutralDensity.size();
+		cdata->cNumRenderVertices = (uint)mSettings->mRenderVertices.size();
+		HairSettings::GridSampler grid_sampler(mSettings);
+		memcpy(&cdata->cGridSizeMin2, &grid_sampler.mGridSizeMin2, 3 * sizeof(float));
+		cdata->cTwoDivDeltaTime = ctx.mTwoDivDeltaTime;
+		grid_sampler.mGridSizeMin1.StoreFloat3(&cdata->cGridSizeMin1);
+		cdata->cDeltaTime = ctx.mDeltaTime;
+		grid_sampler.mOffset.StoreFloat3(&cdata->cGridOffset);
+		cdata->cHalfDeltaTime = ctx.mHalfDeltaTime;
+		grid_sampler.mScale.StoreFloat3(&cdata->cGridScale);
+		cdata->cInvDeltaTimeSq = ctx.mInvDeltaTimeSq;
+		ctx.mSubStepGravity.StoreFloat3(&cdata->cSubStepGravity);
+		cdata->cNumSkinVertices = (uint)mSettings->mScalpVertices.size();
+		memcpy(&cdata->cGridStride, &grid_sampler.mGridStride, 3 * sizeof(uint32));
+		cdata->cNumSkinWeightsPerVertex = mSettings->mScalpNumSkinWeightsPerVertex;
+		for (int i = 0; i < 4; ++i)
+			ctx.mDeltaTransform.GetColumn4(i).StoreFloat4(&cdata->cDeltaTransform[i]);
+		for (int i = 0; i < 4; ++i)
+			mScalpToHead.GetColumn4(i).StoreFloat4(&cdata->cScalpToHead[i]);
+		ctx.mDeltaTransformQuat.StoreFloat4(&cdata->cDeltaTransformQuat);
+		mConstantsCB->Unmap();
+	}
+
+	{
+		JPH_PROFILE("Set iteration constants");
+
+		// Ensure that we have the right number of constant buffers allocated
+		uint old_size = uint(mIterationConstantsCB.size());
+		if (old_size < ctx.mNumIterations)
+		{
+			mIterationConstantsCB.resize(ctx.mNumIterations);
+			for (uint i = old_size; i < ctx.mNumIterations; ++i)
+				mIterationConstantsCB[i] = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::ConstantBuffer, 1, sizeof(JPH_HairIterationContext)).Get();
+		}
+
+		// Fill in the constant buffers
+		JPH_HairIterationContext iteration_data;
+		for (uint i = 0; i < ctx.mNumIterations; ++i)
+		{
+			iteration_data.cAccumulatedDeltaTime = ctx.mDeltaTime * (i + 1);
+			iteration_data.cIterationFraction = 1.0f / float(ctx.mNumIterations - i);
+
+			JPH_HairIterationContext *idata = mIterationConstantsCB[i]->Map<JPH_HairIterationContext>(ComputeBuffer::EMode::Write);
+			*idata = iteration_data;
+			mIterationConstantsCB[i]->Unmap();
+		}
+	}
+
+	{
+		JPH_PROFILE("Queue Compute");
+
+		uint dispatch_per_vertex = (mSettings->GetNumVerticesPadded() + cHairPerVertexBatch - 1) / cHairPerVertexBatch;
+		uint dispatch_per_vertex_skip_first_vertex = (mSettings->GetNumVerticesPadded() - (uint)mSettings->mSimStrands.size() + cHairPerVertexBatch - 1) / cHairPerVertexBatch; // Skip the first vertex of each strand
+		uint dispatch_per_grid_cell = uint((mSettings->mNeutralDensity.size() + cHairPerGridCellBatch - 1) / cHairPerGridCellBatch);
+		uint dispatch_per_strand = uint((mSettings->mSimStrands.size() + cHairPerStrandBatch - 1) / cHairPerStrandBatch);
+		uint dispatch_per_render_vertex = uint((mSettings->mRenderVertices.size() + cHairPerRenderVertexBatch - 1) / cHairPerRenderVertexBatch);
+
+		bool was_teleported = mTeleported;
+		mTeleported = false;
+		if (was_teleported)
+		{
+			// Initialize positions and velocities
+			inComputeQueue->SetShader(inShaders.mTeleportCS);
+			inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+			inComputeQueue->SetBuffer("gInitialPositions", mSettings->mVerticesPositionCB);
+			inComputeQueue->SetBuffer("gInitialBishops", mSettings->mVerticesBishopCB);
+			inComputeQueue->SetRWBuffer("gPositions", mPositionsCB);
+			inComputeQueue->SetRWBuffer("gVelocities", mVelocitiesCB);
+			inComputeQueue->Dispatch(dispatch_per_vertex);
+		}
+		else if (!ctx.mGlobalPoseOnly && ctx.mHasTransformChanged)
+		{
+			// Apply delta transform
+			inComputeQueue->SetShader(inShaders.mApplyDeltaTransformCS);
+			inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+			inComputeQueue->SetBuffer("gVerticesFixed", mSettings->mVerticesFixedCB);
+			inComputeQueue->SetBuffer("gStrandFractions", mSettings->mVerticesStrandFractionCB);
+			inComputeQueue->SetBuffer("gMaterials", mMaterialsCB);
+			inComputeQueue->SetBuffer("gStrandMaterialIndex", mSettings->mStrandMaterialIndexCB);
+			inComputeQueue->SetRWBuffer("gPositions", mPositionsCB);
+			inComputeQueue->SetRWBuffer("gVelocities", mVelocitiesCB);
+			inComputeQueue->Dispatch(dispatch_per_vertex_skip_first_vertex);
+		}
+
+		if (mScalpJointMatricesCB != nullptr)
+		{
+			// Skin the scalp mesh
+			inComputeQueue->SetShader(inShaders.mSkinVerticesCS);
+			inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+			inComputeQueue->SetBuffer("gScalpVertices", mSettings->mScalpVerticesCB);
+			inComputeQueue->SetBuffer("gScalpSkinWeights", mSettings->mScalpSkinWeightsCB);
+			inComputeQueue->SetBuffer("gScalpJointMatrices", mScalpJointMatricesCB);
+			inComputeQueue->SetRWBuffer("gScalpVerticesOut", mScalpVerticesCB);
+			inComputeQueue->Dispatch(uint((mSettings->mScalpVertices.size() + cHairPerVertexBatch - 1) / cHairPerVertexBatch));
+		}
+
+		if (mScalpVerticesCB != nullptr)
+		{
+			// Determine if we directly write to the position / transform buffers or if we need to interpolate
+			bool needs_interpolate = !ctx.mGlobalPoseOnly && !was_teleported;
+
+			// Create target buffers if they don't exist yet
+			if (mTargetPositionsCB == nullptr && needs_interpolate)
+			{
+				mTargetPositionsCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, mSettings->mSimStrands.size(), sizeof(JPH_HairPosition)).Get();
+				mTargetGlobalPoseTransformsCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, mSettings->mSimStrands.size(), sizeof(JPH_HairGlobalPoseTransform)).Get();
+			}
+
+			// Skin the strand roots to the scalp mesh
+			inComputeQueue->SetShader(inShaders.mSkinRootsCS);
+			inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+			inComputeQueue->SetBuffer("gSkinPoints", mSettings->mSkinPointsCB);
+			inComputeQueue->SetBuffer("gScalpVertices", mScalpVerticesCB);
+			inComputeQueue->SetBuffer("gScalpTriangles", mScalpTrianglesCB);
+			inComputeQueue->SetBuffer("gInitialPositions", mSettings->mVerticesPositionCB);
+			inComputeQueue->SetBuffer("gInitialBishops", mSettings->mVerticesBishopCB);
+			inComputeQueue->SetRWBuffer("gPositions", needs_interpolate? mTargetPositionsCB : mPositionsCB);
+			inComputeQueue->SetRWBuffer("gGlobalPoseTransforms", needs_interpolate? mTargetGlobalPoseTransformsCB : mGlobalPoseTransformsCB);
+			inComputeQueue->Dispatch(dispatch_per_strand);
+		}
+
+		if (ctx.mGlobalPoseOnly)
+		{
+			// Only run global pose logic
+			inComputeQueue->SetShader(inShaders.mApplyGlobalPoseCS);
+			inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+			inComputeQueue->SetBuffer("gVerticesFixed", mSettings->mVerticesFixedCB);
+			inComputeQueue->SetBuffer("gStrandFractions", mSettings->mVerticesStrandFractionCB);
+			inComputeQueue->SetBuffer("gInitialPositions", mSettings->mVerticesPositionCB);
+			inComputeQueue->SetBuffer("gInitialBishops", mSettings->mVerticesBishopCB);
+			inComputeQueue->SetBuffer("gStrandMaterialIndex", mSettings->mStrandMaterialIndexCB);
+			inComputeQueue->SetBuffer("gMaterials", mMaterialsCB);
+			inComputeQueue->SetBuffer("gGlobalPoseTransforms", mGlobalPoseTransformsCB);
+			inComputeQueue->SetRWBuffer("gPositions", mPositionsCB);
+			inComputeQueue->Dispatch(dispatch_per_vertex_skip_first_vertex);
+		}
+		else if (ctx.mNumIterations > 0)
+		{
+			if (ctx.mNeedsCollision)
+			{
+				// Calculate collision planes
+				inComputeQueue->SetShader(inShaders.mCalculateCollisionPlanesCS);
+				inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+				inComputeQueue->SetBuffer("gPositions", mPositionsCB);
+				inComputeQueue->SetBuffer("gShapePlanes", mShapePlanesCB);
+				inComputeQueue->SetBuffer("gShapeVertices", mShapeVerticesCB);
+				inComputeQueue->SetBuffer("gShapeIndices", mShapeIndicesCB);
+				inComputeQueue->SetRWBuffer("gCollisionPlanes", mCollisionPlanesCB);
+				inComputeQueue->Dispatch(dispatch_per_vertex_skip_first_vertex);
+			}
+
+			if (ctx.mNeedsGrid)
+			{
+				// Clear the grid
+				inComputeQueue->SetShader(inShaders.mGridClearCS);
+				inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+				inComputeQueue->SetRWBuffer("gVelocityAndDensity", mVelocityAndDensityCB);
+				inComputeQueue->Dispatch(dispatch_per_grid_cell);
+
+				// Accumulate vertices into the grid
+				inComputeQueue->SetShader(inShaders.mGridAccumulateCS);
+				inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+				inComputeQueue->SetBuffer("gVerticesFixed", mSettings->mVerticesFixedCB);
+				inComputeQueue->SetBuffer("gPositions", mPositionsCB);
+				inComputeQueue->SetBuffer("gVelocities", mVelocitiesCB);
+				inComputeQueue->SetRWBuffer("gVelocityAndDensity", mVelocityAndDensityCB);
+				inComputeQueue->Dispatch(dispatch_per_vertex_skip_first_vertex);
+
+				// Normalize velocities in the grid
+				inComputeQueue->SetShader(inShaders.mGridNormalizeCS);
+				inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+				inComputeQueue->SetRWBuffer("gVelocityAndDensity", mVelocityAndDensityCB);
+				inComputeQueue->Dispatch(dispatch_per_grid_cell);
+			}
+
+			// First integrate
+			inComputeQueue->SetShader(inShaders.mIntegrateCS);
+			inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+			inComputeQueue->SetBuffer("gVerticesFixed", mSettings->mVerticesFixedCB);
+			inComputeQueue->SetBuffer("gStrandFractions", mSettings->mVerticesStrandFractionCB);
+			inComputeQueue->SetBuffer("gNeutralDensity", mSettings->mNeutralDensityCB);
+			inComputeQueue->SetBuffer("gVelocityAndDensity", mVelocityAndDensityCB);
+			inComputeQueue->SetBuffer("gStrandMaterialIndex", mSettings->mStrandMaterialIndexCB);
+			inComputeQueue->SetBuffer("gMaterials", mMaterialsCB);
+			inComputeQueue->SetBuffer("gVelocities", mVelocitiesCB);
+			inComputeQueue->SetRWBuffer("gPositions", mPositionsCB);
+			inComputeQueue->SetRWBuffer("gPreviousPositions", mPreviousPositionsCB);
+			inComputeQueue->Dispatch(dispatch_per_vertex_skip_first_vertex);
+
+			for (uint it = 0; it < ctx.mNumIterations; ++it)
+			{
+				if (mTargetPositionsCB != nullptr && !was_teleported)
+				{
+					// Update skinned roots for this iteration (interpolate them towards the target positions)
+					inComputeQueue->SetShader(inShaders.mUpdateRootsCS);
+					inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+					inComputeQueue->SetConstantBuffer("gIterationContext", mIterationConstantsCB[it]);
+					inComputeQueue->SetBuffer("gTargetPositions", mTargetPositionsCB);
+					inComputeQueue->SetBuffer("gTargetGlobalPoseTransforms", mTargetGlobalPoseTransformsCB);
+					inComputeQueue->SetRWBuffer("gPositions", mPositionsCB);
+					inComputeQueue->SetRWBuffer("gGlobalPoseTransforms", mGlobalPoseTransformsCB);
+					inComputeQueue->Dispatch(dispatch_per_strand);
+				}
+
+				// Then update the constraints per strand
+				inComputeQueue->SetShader(inShaders.mUpdateStrandsCS);
+				inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+				inComputeQueue->SetBuffer("gVerticesFixed", mSettings->mVerticesFixedCB);
+				inComputeQueue->SetBuffer("gStrandFractions", mSettings->mVerticesStrandFractionCB);
+				inComputeQueue->SetBuffer("gInitialPositions", mSettings->mVerticesPositionCB);
+				inComputeQueue->SetBuffer("gOmega0s", mSettings->mVerticesOmega0CB);
+				inComputeQueue->SetBuffer("gInitialLengths", mSettings->mVerticesLengthCB);
+				inComputeQueue->SetBuffer("gStrandVertexCounts", mSettings->mStrandVertexCountsCB);
+				inComputeQueue->SetBuffer("gStrandMaterialIndex", mSettings->mStrandMaterialIndexCB);
+				inComputeQueue->SetBuffer("gMaterials", mMaterialsCB);
+				inComputeQueue->SetRWBuffer("gPositions", mPositionsCB);
+				inComputeQueue->Dispatch(dispatch_per_strand);
+
+				if (it == ctx.mNumIterations - 1)
+				{
+					// Last iteration: only update velocities
+					inComputeQueue->SetShader(inShaders.mUpdateVelocityCS);
+					inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+					inComputeQueue->SetConstantBuffer("gIterationContext", mIterationConstantsCB[it]);
+					inComputeQueue->SetBuffer("gVerticesFixed", mSettings->mVerticesFixedCB);
+					inComputeQueue->SetBuffer("gStrandFractions", mSettings->mVerticesStrandFractionCB);
+					inComputeQueue->SetBuffer("gInitialPositions", mSettings->mVerticesPositionCB);
+					inComputeQueue->SetBuffer("gInitialBishops", mSettings->mVerticesBishopCB);
+					inComputeQueue->SetBuffer("gStrandMaterialIndex", mSettings->mStrandMaterialIndexCB);
+					inComputeQueue->SetBuffer("gMaterials", mMaterialsCB);
+					inComputeQueue->SetBuffer("gPreviousPositions", mPreviousPositionsCB);
+					inComputeQueue->SetBuffer("gGlobalPoseTransforms", mGlobalPoseTransformsCB);
+					inComputeQueue->SetBuffer("gCollisionShapes", mCollisionShapesCB);
+					inComputeQueue->SetBuffer("gCollisionPlanes", mCollisionPlanesCB);
+					inComputeQueue->SetRWBuffer("gPositions", mPositionsCB);
+					inComputeQueue->SetRWBuffer("gVelocities", mVelocitiesCB);
+					inComputeQueue->Dispatch(dispatch_per_vertex_skip_first_vertex);
+				}
+				else
+				{
+					// Other iterations: update velocities then integrate again
+					inComputeQueue->SetShader(inShaders.mUpdateVelocityIntegrateCS);
+					inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+					inComputeQueue->SetConstantBuffer("gIterationContext", mIterationConstantsCB[it]);
+					inComputeQueue->SetBuffer("gVerticesFixed", mSettings->mVerticesFixedCB);
+					inComputeQueue->SetBuffer("gStrandFractions", mSettings->mVerticesStrandFractionCB);
+					inComputeQueue->SetBuffer("gInitialPositions", mSettings->mVerticesPositionCB);
+					inComputeQueue->SetBuffer("gInitialBishops", mSettings->mVerticesBishopCB);
+					inComputeQueue->SetBuffer("gNeutralDensity", mSettings->mNeutralDensityCB);
+					inComputeQueue->SetBuffer("gVelocityAndDensity", mVelocityAndDensityCB);
+					inComputeQueue->SetBuffer("gStrandMaterialIndex", mSettings->mStrandMaterialIndexCB);
+					inComputeQueue->SetBuffer("gMaterials", mMaterialsCB);
+					inComputeQueue->SetBuffer("gGlobalPoseTransforms", mGlobalPoseTransformsCB);
+					inComputeQueue->SetBuffer("gCollisionShapes", mCollisionShapesCB);
+					inComputeQueue->SetBuffer("gCollisionPlanes", mCollisionPlanesCB);
+					inComputeQueue->SetRWBuffer("gPreviousPositions", mPreviousPositionsCB);
+					inComputeQueue->SetRWBuffer("gPositions", mPositionsCB);
+					inComputeQueue->Dispatch(dispatch_per_vertex_skip_first_vertex);
+				}
+			}
+		}
+
+		// Remap simulation positions to render positions
+		inComputeQueue->SetShader(inShaders.mCalculateRenderPositionsCS);
+		inComputeQueue->SetConstantBuffer("gContext", mConstantsCB);
+		inComputeQueue->SetBuffer("gSVertexInfluences", mSettings->mSVertexInfluencesCB);
+		inComputeQueue->SetBuffer("gPositions", mPositionsCB);
+		inComputeQueue->SetRWBuffer("gRenderPositions", mRenderPositionsCB);
+		inComputeQueue->Dispatch(dispatch_per_render_vertex);
+	}
+}
+
+void Hair::ReadBackGPUState(ComputeQueue *inComputeQueue)
+{
+	if (mPositionsReadBackCB == nullptr)
+	{
+		// Create read back buffers
+		if (mScalpVerticesCB != nullptr)
+			mScalpVerticesReadBackCB = mScalpVerticesCB->CreateReadBackBuffer().Get();
+		mPositionsReadBackCB = mPositionsCB->CreateReadBackBuffer().Get();
+		mVelocitiesReadBackCB = mVelocitiesCB->CreateReadBackBuffer().Get();
+		mVelocityAndDensityReadBackCB = mVelocityAndDensityCB->CreateReadBackBuffer().Get();
+		mRenderPositionsReadBackCB = mRenderPositionsCB->CreateReadBackBuffer().Get();
+	}
+
+	{
+		JPH_PROFILE("Transfer data from GPU");
+
+		// Read back the skinned vertices
+		if (mScalpVerticesCB != nullptr)
+			inComputeQueue->ScheduleReadback(mScalpVerticesReadBackCB, mScalpVerticesCB);
+
+		// Read back the vertices
+		inComputeQueue->ScheduleReadback(mPositionsReadBackCB, mPositionsCB);
+		inComputeQueue->ScheduleReadback(mVelocitiesReadBackCB, mVelocitiesCB);
+		inComputeQueue->ScheduleReadback(mRenderPositionsReadBackCB, mRenderPositionsCB);
+
+		// Read back the velocity and density
+		inComputeQueue->ScheduleReadback(mVelocityAndDensityReadBackCB, mVelocityAndDensityCB);
+
+		// Wait for the compute queue to finish
+		inComputeQueue->ExecuteAndWait();
+	}
+
+	{
+		JPH_PROFILE("Reorder hair data");
+
+		// Reorder position and velocity data
+		const JPH_HairPosition *positions = mPositionsReadBackCB->Map<JPH_HairPosition>(ComputeBuffer::EMode::Read);
+		const JPH_HairVelocity *velocities = mVelocitiesReadBackCB->Map<JPH_HairVelocity>(ComputeBuffer::EMode::Read);
+		size_t num_vertices = mSettings->mSimVertices.size();
+		if (mPositions == nullptr)
+			mPositions = new Float3 [num_vertices];
+		if (mRotations == nullptr)
+			mRotations = new Quat [num_vertices];
+		if (mVelocities == nullptr)
+			mVelocities = new JPH_HairVelocity [num_vertices];
+		uint32 num_strands = (uint32)mSettings->mSimStrands.size();
+		for (uint32 s = 0; s < num_strands; ++s)
+		{
+			const HairSettings::SStrand &strand = mSettings->mSimStrands[s];
+			for (uint32 v = 0; v < strand.VertexCount(); ++v)
+			{
+				uint32 in_index = s + v * num_strands;
+				uint32 out_index = strand.mStartVtx + v;
+				mPositions[out_index] = Float3(positions[in_index].mPosition);
+				mRotations[out_index] = Quat(positions[in_index].mRotation);
+				mVelocities[out_index] = velocities[in_index];
+			}
+		}
+		mPositionsReadBackCB->Unmap();
+		mVelocitiesReadBackCB->Unmap();
+	}
+}
+
+void Hair::LockReadBackBuffers()
+{
+	if (mScalpVerticesReadBackCB != nullptr)
+		mScalpVertices = mScalpVerticesReadBackCB->Map<Float3>(ComputeBuffer::EMode::Read);
+	mVelocityAndDensity = mVelocityAndDensityReadBackCB->Map<Float4>(ComputeBuffer::EMode::Read);
+	if (mRenderPositionsOverridden)
+	{
+		uint num_render_vertices = (uint)mSettings->mRenderVertices.size();
+		if (mRenderPositions == nullptr)
+			mRenderPositions = new Float3 [num_render_vertices];
+		mRenderPositionsToFloat3(mRenderPositionsReadBackCB, const_cast<Float3 *>(mRenderPositions), num_render_vertices);
+	}
+	else
+		mRenderPositions = mRenderPositionsReadBackCB->Map<Float3>(ComputeBuffer::EMode::Read);
+}
+
+void Hair::UnlockReadBackBuffers()
+{
+	if (mScalpVerticesReadBackCB != nullptr)
+		mScalpVerticesReadBackCB->Unmap();
+	mVelocityAndDensityReadBackCB->Unmap();
+	if (!mRenderPositionsOverridden)
+		mRenderPositionsReadBackCB->Unmap();
+}
+
+#ifdef JPH_DEBUG_RENDERER
+
+void Hair::Draw(const DrawSettings &inSettings, DebugRenderer *inRenderer)
+{
+	LockReadBackBuffers();
+
+	const Float3 *positions = GetPositions();
+	const Float3 *render_positions = GetRenderPositions();
+	const Quat *rotations = GetRotations();
+	StridedPtr<const Float3> velocities = GetVelocities();
+	StridedPtr<const Float3> angular_velocities = GetAngularVelocities();
+	const Float4 *grid_velocity_and_density = GetGridVelocityAndDensity();
+	const Float3 *scalp_vertices = GetScalpVertices();
+
+	float arrow_size = 0.01f * mSettings->mSimulationBounds.GetSize().ReduceMin();
+	RMat44 com = GetWorldTransform();
+
+	// Draw the render strands
+	if (inSettings.mDrawRenderStrands)
+	{
+		JPH_PROFILE("Draw Render Strands");
+
+		// Calculate a map of sim vertex index to strand index
+		Array<uint> sim_vertex_to_strand;
+		sim_vertex_to_strand.resize(mSettings->mSimVertices.size(), 0);
+		for (uint i = 0, n = (uint)mSettings->mSimStrands.size(); i < n; ++i)
+		{
+			const HairSettings::SStrand &strand = mSettings->mSimStrands[i];
+			for (uint v = strand.mStartVtx; v < strand.mEndVtx; ++v)
+				sim_vertex_to_strand[v] = i;
+		}
+
+		Hash<uint32> hasher;
+		switch (inSettings.mRenderStrandColor)
+		{
+		case ERenderStrandColor::PerRenderStrand:
+			{
+				Color color = Color::sGreen;
+				for (const HairSettings::RStrand &strand : mSettings->mRenderStrands)
+				{
+					uint32 strand_idx = sim_vertex_to_strand[mSettings->mRenderVertices[strand.mStartVtx].mInfluences[0].mVertexIndex];
+					if (strand_idx >= inSettings.mSimulationStrandBegin && strand_idx < inSettings.mSimulationStrandEnd)
+					{
+						RVec3 x0 = com * Vec3(render_positions[strand.mStartVtx]);
+						for (uint32 v = strand.mStartVtx + 1; v < strand.mEndVtx; ++v)
+						{
+							RVec3 x1 = com * Vec3(render_positions[v]);
+							inRenderer->DrawLine(x0, x1, color);
+							x0 = x1;
+						}
+						color = Color(uint32(hasher(color.GetUInt32())) | 0xff000000);
+					}
+				}
+			}
+			break;
+
+		case ERenderStrandColor::PerSimulatedStrand:
+			for (const HairSettings::RStrand &strand : mSettings->mRenderStrands)
+			{
+				uint32 strand_idx = sim_vertex_to_strand[mSettings->mRenderVertices[strand.mStartVtx].mInfluences[0].mVertexIndex];
+				if (strand_idx >= inSettings.mSimulationStrandBegin && strand_idx < inSettings.mSimulationStrandEnd)
+				{
+					Color color = Color(uint32(hasher(strand_idx)) | 0xff000000);
+					RVec3 x0 = com * Vec3(render_positions[strand.mStartVtx]);
+					for (uint32 v = strand.mStartVtx + 1; v < strand.mEndVtx; ++v)
+					{
+						RVec3 x1 = com * Vec3(render_positions[v]);
+						inRenderer->DrawLine(x0, x1, color);
+						x0 = x1;
+					}
+				}
+			}
+			break;
+
+		case ERenderStrandColor::GravityFactor:
+		case ERenderStrandColor::WorldTransformInfluence:
+		case ERenderStrandColor::GridVelocityFactor:
+		case ERenderStrandColor::GlobalPose:
+		case ERenderStrandColor::SkinGlobalPose:
+			for (const HairSettings::RStrand &strand : mSettings->mRenderStrands)
+			{
+				uint32 strand_idx = sim_vertex_to_strand[mSettings->mRenderVertices[strand.mStartVtx].mInfluences[0].mVertexIndex];
+				const HairSettings::Material &material = mSettings->mMaterials[mSettings->mSimStrands[strand_idx].mMaterialIndex];
+
+				// Prepare sampler
+				GradientSampler sampler;
+				if (inSettings.mRenderStrandColor == ERenderStrandColor::GravityFactor)
+					sampler = GradientSampler(material.mGravityFactor);
+				else if (inSettings.mRenderStrandColor == ERenderStrandColor::WorldTransformInfluence)
+					sampler = GradientSampler(material.mWorldTransformInfluence);
+				else if (inSettings.mRenderStrandColor == ERenderStrandColor::GridVelocityFactor)
+					sampler = GradientSampler(material.mGridVelocityFactor);
+				else if (inSettings.mRenderStrandColor == ERenderStrandColor::GlobalPose)
+					sampler = GradientSampler(material.mGlobalPose);
+				else
+					sampler = GradientSampler(material.mSkinGlobalPose);
+
+				if (strand_idx >= inSettings.mSimulationStrandBegin && strand_idx < inSettings.mSimulationStrandEnd)
+				{
+					RVec3 x0 = com * Vec3(render_positions[strand.mStartVtx]);
+					for (uint32 v = strand.mStartVtx + 1; v < strand.mEndVtx; ++v)
+					{
+						RVec3 x1 = com * Vec3(render_positions[v]);
+						uint32 simulated_vtx = mSettings->mRenderVertices[v].mInfluences[0].mVertexIndex;
+						float factor = sampler.Sample(mSettings->mSimVertices[simulated_vtx].mStrandFraction);
+						inRenderer->DrawLine(x0, x1, Color::sGreenRedGradient(factor));
+						x0 = x1;
+					}
+				}
+			}
+			break;
+		}
+	}
+
+	// Draw the rods
+	if (inSettings.mDrawRods)
+	{
+		JPH_PROFILE("Draw Rods");
+
+		Color color = Color::sRed;
+		Hash<uint32> hasher;
+		for (uint i = 0, n = (uint)mSettings->mSimStrands.size(); i < n; ++i)
+			if (i >= inSettings.mSimulationStrandBegin && i < inSettings.mSimulationStrandEnd)
+			{
+				const HairSettings::SStrand &strand = mSettings->mSimStrands[i];
+				RVec3 x0 = com * Vec3(positions[strand.mStartVtx]);
+				for (uint32 v = strand.mStartVtx + 1; v < strand.mEndVtx; ++v)
+				{
+					RVec3 x1 = com * Vec3(positions[v]);
+					inRenderer->DrawLine(x0, x1, color);
+					x0 = x1;
+				}
+				color = Color(uint32(hasher(color.GetUInt32())) | 0xff000000);
+			}
+	}
+
+	// Draw the rods in their unloaded pose
+	if (inSettings.mDrawUnloadedRods)
+	{
+		JPH_PROFILE("Draw Unloaded Rods");
+
+		Color color = Color::sYellow;
+		Hash<uint32> hasher;
+		for (uint i = 0, n = (uint)mSettings->mSimStrands.size(); i < n; ++i)
+			if (i >= inSettings.mSimulationStrandBegin && i < inSettings.mSimulationStrandEnd)
+			{
+				const HairSettings::SStrand &strand = mSettings->mSimStrands[i];
+				RVec3 x0 = com * Vec3(positions[strand.mStartVtx]);
+				Quat rotation = mRotation * rotations[strand.mStartVtx];
+				for (uint32 v = strand.mStartVtx + 1; v < strand.mEndVtx; ++v)
+				{
+					RVec3 x1 = x0 + rotation.RotateAxisZ() * mSettings->mSimVertices[v - 1].mLength;
+					inRenderer->DrawLine(x0, x1, color);
+					rotation = (rotation * Quat(mSettings->mSimVertices[v].mOmega0)).Normalized();
+					x0 = x1;
+				}
+				color = Color(uint32(hasher(color.GetUInt32())) | 0xff000000);
+			}
+	}
+
+	// Draw vertex velocities
+	if (inSettings.mDrawVertexVelocity)
+		for (uint i = 0, n = (uint)mSettings->mSimStrands.size(); i < n; ++i)
+			if (i >= inSettings.mSimulationStrandBegin && i < inSettings.mSimulationStrandEnd)
+			{
+				const HairSettings::SStrand &strand = mSettings->mSimStrands[i];
+				for (uint32 v = strand.mStartVtx; v < strand.mEndVtx; ++v)
+				{
+					Vec3 velocity(velocities[v]);
+					if (velocity.LengthSq() > 1.0e-6f)
+					{
+						Vec3 pos = Vec3(positions[v]);
+						inRenderer->DrawArrow(com * pos, com * (pos + velocity), Color::sGreen, arrow_size);
+					}
+				}
+			}
+
+	// Draw angular velocities
+	if (inSettings.mDrawAngularVelocity)
+		for (uint i = 0, n = (uint)mSettings->mSimStrands.size(); i < n; ++i)
+			if (i >= inSettings.mSimulationStrandBegin && i < inSettings.mSimulationStrandEnd)
+			{
+				const HairSettings::SStrand &strand = mSettings->mSimStrands[i];
+				for (uint32 v = strand.mStartVtx; v < strand.mEndVtx; ++v)
+				{
+					Vec3 angular_velocity(angular_velocities[v]);
+					if (angular_velocity.LengthSq() > 1.0e-6f)
+					{
+						Vec3 pos = Vec3(positions[v]);
+						inRenderer->DrawArrow(com * pos, com * (pos + 0.1f * angular_velocity), Color::sOrange, arrow_size);
+					}
+				}
+			}
+
+	// Draw rod orientations
+	if (inSettings.mDrawOrientations)
+		for (uint i = 0, n = (uint)mSettings->mSimStrands.size(); i < n; ++i)
+			if (i >= inSettings.mSimulationStrandBegin && i < inSettings.mSimulationStrandEnd)
+			{
+				const HairSettings::SStrand &strand = mSettings->mSimStrands[i];
+				for (uint32 v = strand.mStartVtx; v < strand.mEndVtx; ++v)
+					inRenderer->DrawCoordinateSystem(com * Mat44::sRotationTranslation(rotations[v], Vec3(positions[v])), arrow_size);
+			}
+
+	// Draw grid bounds
+	if (inSettings.mDrawNeutralDensity || inSettings.mDrawGridDensity || inSettings.mDrawGridVelocity)
+		inRenderer->DrawWireBox(com, mSettings->mSimulationBounds, Color::sGrey);
+
+	// Draw neutral density
+	if (inSettings.mDrawNeutralDensity)
+	{
+		Vec3 offset = mSettings->mSimulationBounds.mMin;
+		Vec3 scale = mSettings->mSimulationBounds.GetSize() / Vec3(mSettings->mGridSize.ToFloat());
+		float marker_size = 0.5f * scale.ReduceMin();
+		for (uint32 z = 0; z < mSettings->mGridSize.GetX(); ++z)
+			for (uint32 y = 0; y < mSettings->mGridSize.GetY(); ++y)
+				for (uint32 x = 0; x < mSettings->mGridSize.GetZ(); ++x)
+				{
+					float density = mSettings->GetNeutralDensity(x, y, z);
+					JPH_ASSERT(density >= 0.0f);
+					if (density > 0.0f)
+					{
+						Vec3 pos = offset + Vec3(UVec4(x, y, z, 0).ToFloat()) * scale;
+						inRenderer->DrawMarker(com * pos, Color::sGreenRedGradient(density * mSettings->mDensityScale), marker_size);
+					}
+				}
+	}
+
+	// Draw current density
+	if (inSettings.mDrawGridDensity || inSettings.mDrawGridVelocity)
+	{
+		Vec3 offset = mSettings->mSimulationBounds.mMin;
+		Vec3 scale = mSettings->mSimulationBounds.GetSize() / Vec3(mSettings->mGridSize.ToFloat());
+		float marker_size = 0.5f * scale.ReduceMin();
+		for (uint32 z = 0; z < mSettings->mGridSize.GetX(); ++z)
+			for (uint32 y = 0; y < mSettings->mGridSize.GetY(); ++y)
+				for (uint32 x = 0; x < mSettings->mGridSize.GetZ(); ++x)
+				{
+					const Float4 &velocity_and_density = grid_velocity_and_density[x + y * mSettings->mGridSize.GetX() + z * mSettings->mGridSize.GetX() * mSettings->mGridSize.GetY()];
+					float density = velocity_and_density.w;
+					Vec3 velocity = Vec3::sLoadFloat3Unsafe((const Float3 &)velocity_and_density);
+					if (density > 0.0f)
+					{
+						RVec3 pos = com * (offset + Vec3(UVec4(x, y, z, 0).ToFloat()) * scale);
+						if (inSettings.mDrawGridDensity)
+							inRenderer->DrawMarker(pos, Color::sGreenRedGradient(density * mSettings->mDensityScale), marker_size);
+						if (inSettings.mDrawGridVelocity && velocity.LengthSq() > 1.0e-6f)
+							inRenderer->DrawArrow(pos, pos + com.Multiply3x3(velocity), Color::sYellow, arrow_size);
+					}
+				}
+	}
+
+	if (inSettings.mDrawSkinPoints)
+		for (uint i = 0, n = (uint)mSettings->mSkinPoints.size(); i < n; ++i)
+			if (i >= inSettings.mSimulationStrandBegin && i < inSettings.mSimulationStrandEnd)
+			{
+				const HairSettings::SkinPoint &sp = mSettings->mSkinPoints[i];
+				const IndexedTriangleNoMaterial &tri = mSettings->mScalpTriangles[sp.mTriangleIndex];
+				RVec3 v0 = com * Vec3(scalp_vertices[tri.mIdx[0]]);
+				RVec3 v1 = com * Vec3(scalp_vertices[tri.mIdx[1]]);
+				RVec3 v2 = com * Vec3(scalp_vertices[tri.mIdx[2]]);
+				inRenderer->DrawWireTriangle(v0, v1, v2, Color::sYellow);
+
+				RVec3 point = Real(sp.mU) * v0 + Real(sp.mV) * v1 + Real(1.0f - sp.mU - sp.mV) * v2;
+				Vec3 tangent = Vec3(v1 - v0).Normalized();
+				Vec3 normal = tangent.Cross(Vec3(v2 - v0)).Normalized();
+				Vec3 binormal = tangent.Cross(normal);
+				RMat44 basis(Vec4(normal, 0), Vec4(binormal, 0), Vec4(tangent, 0), point);
+				inRenderer->DrawCoordinateSystem(basis, 0.01f);
+			}
+
+	// Draw initial gravity
+	if (inSettings.mDrawInitialGravity)
+		inRenderer->DrawArrow(com.GetTranslation(), com * mSettings->mInitialGravity, Color::sBlue, 0.05f * mSettings->mInitialGravity.Length());
+
+	UnlockReadBackBuffers();
+}
+
+#endif // JPH_DEBUG_RENDERER
+
+JPH_NAMESPACE_END

+ 227 - 0
Jolt/Physics/Hair/Hair.h

@@ -0,0 +1,227 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Physics/Hair/HairSettings.h>
+#include <Jolt/Physics/Collision/ObjectLayer.h>
+#include <Jolt/Physics/Collision/Shape/Shape.h>
+#include <Jolt/Core/StridedPtr.h>
+#include <Jolt/Core/NonCopyable.h>
+
+JPH_NAMESPACE_BEGIN
+
+class PhysicsSystem;
+#ifdef JPH_DEBUG_RENDERER
+class DebugRenderer;
+#endif
+class HairShaders;
+
+/// Hair simulation instance
+///
+/// Note that this system is currently still in development, it is missing important features like:
+///
+/// - Level of detail
+/// - Wind forces
+/// - Advection step for the grid velocity field
+/// - Support for collision detection against shapes other than ConvexHullShape
+/// - The Gradient class is very limited and will be replaced by a texture lookup
+/// - Gravity preload factor is not fully functioning yet
+/// - It is wasteful of memory (e.g. stores everything both on CPU and GPU)
+/// - Only supports a single neutral pose to drive towards
+/// - It could use further optimizations
+class JPH_EXPORT Hair : public NonCopyable
+{
+public:
+	/// Constructor / destructor
+										Hair(const HairSettings *inSettings, RVec3Arg inPosition, QuatArg inRotation, ObjectLayer inLayer);
+										~Hair();
+
+	/// Initialize
+	void								Init(ComputeSystem *inComputeSystem);
+
+	/// Position and rotation of the hair in world space
+	void								SetPosition(RVec3Arg inPosition)				{ mPosition = inPosition; }
+	void								SetRotation(QuatArg inRotation)					{ mRotation = inRotation; }
+	RMat44								GetWorldTransform() const						{ return RMat44::sRotationTranslation(mRotation, mPosition); }
+
+	/// The hair will be initialized in its default pose with zero velocity at the new position and rotation during the next update
+	void								OnTeleported()									{ mTeleported = true; }
+
+	/// Ability to externally provide the scalp vertices buffer. This allows skipping skinning the scalp during the simulation update. You may need to override JPH_SHADER_BIND_SCALP_VERTICES in HairSkinRootsBindings.h to match the format of the provided buffer.
+	void								SetScalpVerticesCB(ComputeBuffer *inBuffer)		{ mScalpVerticesCB = inBuffer; }
+
+	/// Ability to externally provide the scalp triangle indices buffer. This allows skipping skinning the scalp in during the simulation update. You may need to override JPH_SHADER_BIND_SCALP_TRIANGLES in HairSkinRootsBindings.h to match the format of the provided buffer.
+	void								SetScalpTrianglesCB(ComputeBuffer *inBuffer)	{ mScalpTrianglesCB = inBuffer; }
+
+	/// When skipping skinning, this allow specifying a transform that transforms the scalp mesh into head space.
+	void								SetScalpToHead(Mat44Arg inMat)					{ mScalpToHead = inMat; }
+
+	/// Function that converts the render positions buffer to Float3 vertices for debugging purposes. It maps an application defined format to Float3. Third parameter is the number of vertices.
+	using RenderPositionsToFloat3 = std::function<void(ComputeBuffer *, Float3 *, uint)>;
+
+	/// Enable externally set render vertices buffer (with potentially different vertex layout). Note that this also requires replacing the HairCalculateRenderPositions shader.
+	void								OverrideRenderPositionsCB(const RenderPositionsToFloat3 &inRenderPositionsToFloat3) { JPH_ASSERT(mRenderPositionsCB == nullptr, "Must be called before Init"); mRenderPositionsOverridden = true; mRenderPositionsToFloat3 = inRenderPositionsToFloat3; }
+
+	/// Allow setting the render vertices buffer externally in case it has special requirements for the calling application. You may need to override JPH_SHADER_BIND_RENDER_POSITIONS in HairCalculateRenderPositionsBindings.h to match the format of the provided buffer.
+	void								SetRenderPositionsCB(ComputeBuffer *inBuffer)	{ JPH_ASSERT(mRenderPositionsOverridden, "Must call OverrideRenderPositionsCB first"); mRenderPositionsCB = inBuffer; }
+
+	/// Step the hair simulation forward in time
+	/// @param inDeltaTime Time step
+	/// @param inJointToHair Transform that transforms from joint space to hair local space (as defined by GetWorldTransform)
+	/// @param inJointMatrices Array of joint matrices in world space, length needs to match HairSettings::mScalpInverseBindPose.size()
+	/// @param inSystem Physics system used for collision detection
+	/// @param inShaders Preloaded hair compute shaders
+	/// @param inComputeSystem Compute system to use
+	/// @param inComputeQueue Compute queue to use
+	void								Update(float inDeltaTime, Mat44Arg inJointToHair, const Mat44 *inJointMatrices, const PhysicsSystem &inSystem, const HairShaders &inShaders, ComputeSystem *inComputeSystem, ComputeQueue *inComputeQueue);
+
+	/// Access to the resulting simulation data
+	ComputeBuffer *						GetScalpVerticesCB() const						{ return mScalpVerticesCB; }		///< Skinned scalp vertices
+	ComputeBuffer *						GetScalpTrianglesCB() const						{ return mScalpTrianglesCB; }		///< Skinned scalp triangle indices
+	ComputeBuffer *						GetPositionsCB() const							{ return mPositionsCB; }			///< Note transposed for better memory access
+	ComputeBuffer *						GetVelocitiesCB() const							{ return mVelocitiesCB; }			///< Note transposed for better memory access
+	ComputeBuffer *						GetVelocityAndDensityCB() const					{ return mVelocityAndDensityCB; }	///< Velocity grid
+	ComputeBuffer *						GetRenderPositionsCB() const					{ return mRenderPositionsCB; }		///< Render positions of the hair strands (see HairSettings::mRenderStrands to see where each strand starts and ends)
+
+	/// Read back the GPU state so that the functions below can be used. For debugging purposes only, this is slow!
+	void								ReadBackGPUState(ComputeQueue *inComputeQueue);
+
+	/// Lock/unlock the data buffers so that the functions below return valid values.
+	void								LockReadBackBuffers();
+	void								UnlockReadBackBuffers();
+
+	/// Access to the resulting simulation data (only valid when ReadBackGPUState has been called and the buffers have been locked)
+	const Float3 *						GetScalpVertices() const						{ return mScalpVertices; }
+	const Float3 *						GetPositions() const							{ return mPositions; }
+	const Quat *						GetRotations() const							{ return mRotations; }
+	StridedPtr<const Float3>			GetVelocities() const							{ return { &mVelocities->mVelocity, sizeof(JPH_HairVelocity) }; }
+	StridedPtr<const Float3>			GetAngularVelocities() const					{ return { &mVelocities->mAngularVelocity, sizeof(JPH_HairVelocity) }; }
+	const Float4 *						GetGridVelocityAndDensity() const				{ return mVelocityAndDensity; }
+	const Float3 *						GetRenderPositions() const						{ return mRenderPositions; }
+
+#ifdef JPH_DEBUG_RENDERER
+	enum class ERenderStrandColor
+	{
+		PerRenderStrand,
+		PerSimulatedStrand,
+		GravityFactor,
+		WorldTransformInfluence,
+		GridVelocityFactor,
+		GlobalPose,
+		SkinGlobalPose,
+	};
+
+	struct DrawSettings
+	{
+		/// This specifies the range of simulation strands to draw, when drawing render strands we only draw the strands that belong to these simulation strands.
+		uint							mSimulationStrandBegin = 0;
+		uint							mSimulationStrandEnd = UINT_MAX;
+
+		bool							mDrawRods = true;								///< Draws the simulated rods
+		bool							mDrawUnloadedRods = false;						///< Draw rods in their unloaded pose. This pose is obtained by removing gravity influence from the modeled pose.
+		bool							mDrawVertexVelocity = false;					///< Draws the velocity at each simulated vertex as an arrow
+		bool							mDrawAngularVelocity = false;					///< Draws the angular velocity at each simulated vertex as an arrow
+		bool							mDrawOrientations = false;						///< Draws a coordinate space for each simulated vertex
+		bool							mDrawNeutralDensity = false;					///< Draws grid density of the hair in its neutral pose
+		bool							mDrawGridDensity = false;						///< Draws the current grid density of the hair
+		bool							mDrawGridVelocity = false;						///< Draws the velocity of each grid cell as an arrow
+		bool							mDrawSkinPoints = false;						///< Draws the skinning points on the scalp
+		bool 							mDrawRenderStrands = false;						///< Draws the render strands (slow, for debugging purposes!)
+		bool 							mDrawInitialGravity = true;						///< Draws the configured initial gravity vector used to calculate the unloaded vertex positions
+		ERenderStrandColor				mRenderStrandColor = ERenderStrandColor::PerSimulatedStrand; ///< Color for each strand
+	};
+
+	/// Debug functionality to draw the hair and its simulation properties
+	void								Draw(const DrawSettings &inSettings, DebugRenderer *inRenderer);
+#endif // JPH_DEBUG_RENDERER
+
+protected:
+	using Gradient = HairSettings::Gradient;
+	using GradientSampler = HairSettings::GradientSampler;
+
+	// Information about a colliding shape. Is always a leaf shape, compound shapes are expanded.
+	struct LeafShape
+	{
+										LeafShape() = default;
+										LeafShape(Mat44Arg inTransform, Vec3Arg inScale, Vec3Arg inLinearVelocity, Vec3Arg inAngularVelocity, const Shape *inShape) : mTransform(inTransform), mScale(inScale), mLinearVelocity(inLinearVelocity), mAngularVelocity(inAngularVelocity), mShape(inShape) { }
+
+		Mat44							mTransform;
+		Vec3							mScale;
+		Vec3							mLinearVelocity;
+		Vec3							mAngularVelocity;
+		RefConst<Shape>					mShape;
+	};
+
+	// Internal context used during a simulation step
+	struct UpdateContext
+	{
+		Mat44							mDeltaTransform;								// Transforms positions from the old hair transform to the new
+		Quat							mDeltaTransformQuat;							// Rotation part of mDeltaTransform
+		uint							mNumIterations;									// Number of iterations to run the solver for
+		bool							mNeedsCollision;								// If collision detection should be performed
+		bool							mNeedsGrid;										// If the grid should be calculated
+		bool							mGlobalPoseOnly;								// If no simulation is needed and only the global pose needs to be applied
+		bool							mHasTransformChanged;							// If the world transform has changed
+		float							mDeltaTime;										// Delta time for a sub step
+		float							mHalfDeltaTime;									// 0.5 * mDeltaTime
+		float							mInvDeltaTimeSq;								// 1 / mDeltaTime^2
+		float							mTwoDivDeltaTime;								// 2 / mDeltaTime
+		float							mTimeRatio;										// Ratio between sub step delta time and default sub step delta time
+		Vec3							mSubStepGravity;								// Gravity to apply in a sub step
+		Array<LeafShape>				mShapes;										// List of colliding shapes
+	};
+
+	// Calculate the UpdateContext parameters
+	void								InitializeContext(UpdateContext &outCtx, float inDeltaTime, const PhysicsSystem &inSystem);
+
+	RefConst<HairSettings>				mSettings;										// Shared hair settings, must be kept alive during the lifetime of this hair instance
+
+	RVec3								mPrevPosition;									// Position at the start of the last time step
+	RVec3								mPosition;										// Current position in world space
+	Quat								mPrevRotation;									// Rotation at the start of the last time step
+	Quat								mRotation;										// Current rotation in world space
+	bool								mTeleported = true;								// If the hair got teleported and should be set to the default pose
+	ObjectLayer							mLayer;											// Layer for the hair to collide with
+
+	Mat44								mScalpToHead = Mat44::sIdentity();				// When skipping skinning, this allow specifying a transform that transforms the scalp mesh into head space
+
+	bool								mRenderPositionsOverridden = false;				// Indicates that the render positions buffer is provided externally
+	RenderPositionsToFloat3				mRenderPositionsToFloat3;						// Function that transforms the render positions buffer to Float3 vertices for debugging purposes
+
+	Ref<ComputeBuffer>					mScalpJointMatricesCB;
+	Ref<ComputeBuffer>					mScalpVerticesCB;
+	Ref<ComputeBuffer>					mScalpTrianglesCB;
+	Ref<ComputeBuffer>					mTargetPositionsCB;								// Target root positions determined by skinning (where we're interpolating to, eventually written to mPositionsCB)
+	Ref<ComputeBuffer>					mTargetGlobalPoseTransformsCB;					// Target global pose transforms determined by skinning (where we're interpolating to, eventually written to mGlobalPoseTransformsCB)
+	Ref<ComputeBuffer>					mGlobalPoseTransformsCB;						// Current global pose transforms used for skinning the hairs
+	Ref<ComputeBuffer>					mShapePlanesCB;
+	Ref<ComputeBuffer>					mShapeVerticesCB;
+	Ref<ComputeBuffer>					mShapeIndicesCB;
+	Ref<ComputeBuffer>					mCollisionPlanesCB;
+	Ref<ComputeBuffer>					mCollisionShapesCB;
+	Ref<ComputeBuffer>					mMaterialsCB;
+	Ref<ComputeBuffer>					mPreviousPositionsCB;
+	Ref<ComputeBuffer>					mPositionsCB;
+	Ref<ComputeBuffer>					mVelocitiesCB;
+	Ref<ComputeBuffer>					mVelocityAndDensityCB;
+	Ref<ComputeBuffer>					mConstantsCB;
+	Array<Ref<ComputeBuffer>>			mIterationConstantsCB;
+	Ref<ComputeBuffer>					mRenderPositionsCB;
+
+	// Only valid after ReadBackGPUState has been called
+	Ref<ComputeBuffer>					mScalpVerticesReadBackCB;
+	Ref<ComputeBuffer>					mPositionsReadBackCB;
+	Ref<ComputeBuffer>					mVelocitiesReadBackCB;
+	Ref<ComputeBuffer>					mVelocityAndDensityReadBackCB;
+	Ref<ComputeBuffer>					mRenderPositionsReadBackCB;
+	const Float3 *						mScalpVertices = nullptr;
+	Float3 *							mPositions = nullptr;
+	Quat *								mRotations = nullptr;
+	JPH_HairVelocity *					mVelocities = nullptr;
+	const Float4 *						mVelocityAndDensity = nullptr;
+	const Float3 *						mRenderPositions = nullptr;
+};
+
+JPH_NAMESPACE_END

+ 869 - 0
Jolt/Physics/Hair/HairSettings.cpp

@@ -0,0 +1,869 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#include <Jolt/Physics/Hair/HairSettings.h>
+#include <Jolt/ObjectStream/TypeDeclarations.h>
+#include <Jolt/Geometry/ClosestPoint.h>
+#include <Jolt/TriangleSplitter/TriangleSplitterBinning.h>
+#include <Jolt/AABBTree/AABBTreeBuilder.h>
+#include <Jolt/Core/QuickSort.h>
+
+JPH_NAMESPACE_BEGIN
+
+JPH_IMPLEMENT_SERIALIZABLE_NON_VIRTUAL(HairSettings)
+{
+	JPH_ADD_ATTRIBUTE(HairSettings, mSimVertices)
+	JPH_ADD_ATTRIBUTE(HairSettings, mSimStrands)
+	JPH_ADD_ATTRIBUTE(HairSettings, mRenderVertices)
+	JPH_ADD_ATTRIBUTE(HairSettings, mRenderStrands)
+	JPH_ADD_ATTRIBUTE(HairSettings, mScalpVertices)
+	JPH_ADD_ATTRIBUTE(HairSettings, mScalpTriangles)
+	JPH_ADD_ATTRIBUTE(HairSettings, mScalpInverseBindPose)
+	JPH_ADD_ATTRIBUTE(HairSettings, mScalpSkinWeights)
+	JPH_ADD_ATTRIBUTE(HairSettings, mScalpNumSkinWeightsPerVertex)
+	JPH_ADD_ATTRIBUTE(HairSettings, mNumIterationsPerSecond)
+	JPH_ADD_ATTRIBUTE(HairSettings, mMaxDeltaTime)
+	JPH_ADD_ATTRIBUTE(HairSettings, mGridSize)
+	JPH_ADD_ATTRIBUTE(HairSettings, mSimulationBoundsPadding)
+	JPH_ADD_ATTRIBUTE(HairSettings, mInitialGravity)
+	JPH_ADD_ATTRIBUTE(HairSettings, mMaterials)
+}
+
+JPH_IMPLEMENT_SERIALIZABLE_NON_VIRTUAL(HairSettings::SkinWeight)
+{
+	JPH_ADD_ATTRIBUTE(HairSettings::SkinWeight, mJointIdx)
+	JPH_ADD_ATTRIBUTE(HairSettings::SkinWeight, mWeight)
+}
+
+JPH_IMPLEMENT_SERIALIZABLE_NON_VIRTUAL(HairSettings::SkinPoint)
+{
+	JPH_ADD_ATTRIBUTE(HairSettings::SkinPoint, mTriangleIndex)
+	JPH_ADD_ATTRIBUTE(HairSettings::SkinPoint, mU)
+	JPH_ADD_ATTRIBUTE(HairSettings::SkinPoint, mV)
+}
+
+JPH_IMPLEMENT_SERIALIZABLE_NON_VIRTUAL(HairSettings::SVertexInfluence)
+{
+	JPH_ADD_ATTRIBUTE(HairSettings::SVertexInfluence, mVertexIndex)
+	JPH_ADD_ATTRIBUTE(HairSettings::SVertexInfluence, mRelativePosition)
+	JPH_ADD_ATTRIBUTE(HairSettings::SVertexInfluence, mWeight)
+}
+
+JPH_IMPLEMENT_SERIALIZABLE_NON_VIRTUAL(HairSettings::RVertex)
+{
+	JPH_ADD_ATTRIBUTE(HairSettings::RVertex, mPosition)
+	JPH_ADD_ATTRIBUTE(HairSettings::RVertex, mInfluences)
+}
+
+JPH_IMPLEMENT_SERIALIZABLE_NON_VIRTUAL(HairSettings::SVertex)
+{
+	JPH_ADD_ATTRIBUTE(HairSettings::SVertex, mPosition)
+	JPH_ADD_ATTRIBUTE(HairSettings::SVertex, mInvMass)
+}
+
+JPH_IMPLEMENT_SERIALIZABLE_NON_VIRTUAL(HairSettings::RStrand)
+{
+	JPH_ADD_ATTRIBUTE(HairSettings::RStrand, mStartVtx)
+	JPH_ADD_ATTRIBUTE(HairSettings::RStrand, mEndVtx)
+}
+
+JPH_IMPLEMENT_SERIALIZABLE_NON_VIRTUAL(HairSettings::SStrand)
+{
+	JPH_ADD_BASE_CLASS(HairSettings::SStrand, HairSettings::RStrand)
+
+		JPH_ADD_ATTRIBUTE(HairSettings::SStrand, mMaterialIndex)
+}
+
+JPH_IMPLEMENT_SERIALIZABLE_NON_VIRTUAL(HairSettings::Gradient)
+{
+	JPH_ADD_ATTRIBUTE(HairSettings::Gradient, mMin)
+	JPH_ADD_ATTRIBUTE(HairSettings::Gradient, mMax)
+	JPH_ADD_ATTRIBUTE(HairSettings::Gradient, mMinFraction)
+	JPH_ADD_ATTRIBUTE(HairSettings::Gradient, mMaxFraction)
+}
+
+JPH_IMPLEMENT_SERIALIZABLE_NON_VIRTUAL(HairSettings::Material)
+{
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mEnableCollision)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mEnableLRA)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mLinearDamping)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mAngularDamping)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mMaxLinearVelocity)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mMaxAngularVelocity)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mGravityFactor)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mFriction)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mBendCompliance)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mBendComplianceMultiplier)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mStretchCompliance)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mInertiaMultiplier)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mHairRadius)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mWorldTransformInfluence)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mGridVelocityFactor)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mGridDensityForceFactor)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mGlobalPose)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mSkinGlobalPose)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mSimulationStrandsFraction)
+	JPH_ADD_ATTRIBUTE(HairSettings::Material, mGravityPreloadFactor)
+}
+
+void HairSettings::Gradient::SaveBinaryState(StreamOut &inStream) const
+{
+	inStream.Write(mMin);
+	inStream.Write(mMax);
+	inStream.Write(mMinFraction);
+	inStream.Write(mMaxFraction);
+}
+
+void HairSettings::Gradient::RestoreBinaryState(StreamIn &inStream)
+{
+	inStream.Read(mMin);
+	inStream.Read(mMax);
+	inStream.Read(mMinFraction);
+	inStream.Read(mMaxFraction);
+}
+
+void HairSettings::InitRenderAndSimulationStrands(const Array<SVertex> &inVertices, const Array<SStrand> &inStrands)
+{
+	// Copy original strands to render strands
+	mRenderVertices.resize(inVertices.size());
+	for (uint i = 0, n = uint(inVertices.size()); i < n; ++i)
+		mRenderVertices[i].mPosition = inVertices[i].mPosition;
+	mRenderStrands.resize(inStrands.size());
+	for (uint i = 0, n = uint(inStrands.size()); i < n; ++i)
+		mRenderStrands[i] = RStrand(inStrands[i].mStartVtx, inStrands[i].mEndVtx);
+
+	// Create buffer that holds indices to the strands
+	Array<uint> indices_shuffle;
+	indices_shuffle.resize(inStrands.size());
+	for (uint i = 0, n = uint(inStrands.size()); i < n; ++i)
+		indices_shuffle[i] = i;
+
+	// Order on material index
+	QuickSort(indices_shuffle.begin(), indices_shuffle.end(), [&inStrands](uint inLHS, uint inRHS) {
+		return inStrands[inLHS].mMaterialIndex < inStrands[inRHS].mMaterialIndex;
+	});
+
+	// Loop over all materials
+	Array<uint>::iterator begin_material = indices_shuffle.begin();
+	while (begin_material < indices_shuffle.end())
+	{
+		uint32 material_index = inStrands[*begin_material].mMaterialIndex;
+
+		// Find end of this material
+		Array<uint>::iterator end_material = begin_material;
+		do
+			++end_material;
+		while (end_material < indices_shuffle.end() && inStrands[*end_material].mMaterialIndex == material_index);
+
+		// Select X% random strands to simulate
+		std::mt19937 random;
+		std::shuffle(begin_material, end_material, random);
+		size_t num_simulated = max<size_t>(size_t(ceil(double(mMaterials[material_index].mSimulationStrandsFraction) * double(end_material - begin_material))), 1);
+		Array<uint>::iterator end_simulation = begin_material + num_simulated;
+		QuickSort(begin_material, end_simulation, std::less<uint>()); // Sort simulated strands back to original order
+		for (Array<uint>::const_iterator idx = begin_material; idx < end_simulation; ++idx)
+		{
+			// Add simulation strand
+			const HairSettings::SStrand &sim_strand = inStrands[*idx];
+			mSimStrands.push_back(HairSettings::SStrand((uint32)mSimVertices.size(), (uint32)mSimVertices.size() + sim_strand.VertexCount(), sim_strand.mMaterialIndex));
+
+			for (uint32 v = sim_strand.mStartVtx; v < sim_strand.mEndVtx; ++v)
+			{
+				// Link render vertex to simulation vertex
+				mRenderVertices[v].mInfluences[0].mVertexIndex = uint32(mSimVertices.size());
+
+				// Add simulation vertex
+				mSimVertices.push_back(inVertices[v]);
+			}
+		}
+
+		// Get influences for remaining strands
+		for (Array<uint>::const_iterator idx = end_simulation; idx < end_material; ++idx)
+		{
+			const HairSettings::SStrand &render_strand = inStrands[*idx];
+
+			// Find closest simulation strand
+			float closest_d_sq = FLT_MAX;
+			uint closest_strand_idx = 0;
+			for (const HairSettings::SStrand &sim_strand : mSimStrands)
+				if (sim_strand.mMaterialIndex == render_strand.mMaterialIndex)
+				{
+					// Get the first 2 vertices of the simulation strand
+					uint32 v_max = sim_strand.mEndVtx - 1;
+					uint32 v = sim_strand.mStartVtx, v_next = min(v + 1, v_max);
+					Vec3 v_pos(mSimVertices[v].mPosition), v_next_pos(mSimVertices[v_next].mPosition);
+
+					// Track total error when selecting this sim strand as parent for the render strand
+					float d_sq_total = 0.0f;
+
+					// Loop over the render strand
+					for (uint32 rv = render_strand.mStartVtx; rv < render_strand.mEndVtx; ++rv)
+					{
+						Vec3 rv_pos(mRenderVertices[rv].mPosition);
+
+						// Find closest simulated vertex (note that we assume that the strands do not loop back
+						// on themselves so that an earlier vertex in the strand could be the closest)
+						float d_sq = (rv_pos - v_pos).LengthSq();
+						float d_sq_next = (rv_pos - v_next_pos).LengthSq();
+						while (d_sq_next < d_sq)
+						{
+							// Get the next vertex of the simulation strand
+							v = v_next;
+							v_next = min(v + 1, v_max);
+							v_pos = v_next_pos;
+							v_next_pos = Vec3(mSimVertices[v_next].mPosition);
+
+							// Update distance to render vertex
+							d_sq = d_sq_next;
+							d_sq_next = (rv_pos - v_next_pos).LengthSq();
+						}
+
+						// Accumulate total error
+						d_sq_total += d_sq;
+
+						// No point in continuing the search if our result is worse already
+						if (d_sq_total > closest_d_sq)
+							break;
+					}
+
+					// If this is the smallest error, accept
+					if (d_sq_total < closest_d_sq)
+					{
+						closest_d_sq = d_sq_total;
+						closest_strand_idx = uint(&sim_strand - mSimStrands.data());
+					}
+				}
+			const HairSettings::SStrand &closest_strand = mSimStrands[closest_strand_idx];
+
+			// Link render vertices to simulation vertices
+			for (uint32 v = render_strand.mStartVtx; v < render_strand.mEndVtx; ++v)
+			{
+				HairSettings::RVertex &rv = mRenderVertices[v];
+
+				// Find closest simulated vertex
+				closest_d_sq = FLT_MAX;
+				for (uint32 cv = closest_strand.mStartVtx; cv < closest_strand.mEndVtx; ++cv)
+				{
+					float d_sq = (Vec3(mSimVertices[cv].mPosition) - Vec3(rv.mPosition)).LengthSq();
+					if (d_sq < closest_d_sq)
+					{
+						closest_d_sq = d_sq;
+						rv.mInfluences[0].mVertexIndex = cv;
+					}
+				}
+			}
+		}
+
+		// Next material
+		begin_material = end_material;
+	}
+}
+
+void HairSettings::sResample(Array<SVertex> &ioVertices, Array<SStrand> &ioStrands, uint32 inNumVerticesPerStrand)
+{
+	Array<SVertex> vertices;
+	ioVertices.swap(vertices);
+	Array<SStrand> strands;
+	ioStrands.swap(strands);
+
+	for (const SStrand &strand : strands)
+	{
+		// Determine output strand
+		SStrand out_strand;
+		out_strand.mStartVtx = (uint32)ioVertices.size();
+		out_strand.mEndVtx = out_strand.mStartVtx + inNumVerticesPerStrand;
+		out_strand.mMaterialIndex = strand.mMaterialIndex;
+		ioStrands.push_back(out_strand);
+
+		// Measure length of the strand
+		float length = strand.MeasureLength(vertices);
+
+		// Add the first vertex of the strand
+		ioVertices.push_back(vertices[strand.mStartVtx]);
+
+		// Resample the strand
+		float cur_length = 0.0f;
+		const SVertex *v0 = &vertices[strand.mStartVtx];
+		const SVertex *v1 = &vertices[strand.mStartVtx + 1];
+		float segment_length = (Vec3(v1->mPosition) - Vec3(v0->mPosition)).Length();
+		for (uint32 resampled_point = 1; resampled_point < inNumVerticesPerStrand - 1; ++resampled_point)
+		{
+			float desired_len = resampled_point * length / (inNumVerticesPerStrand - 1);
+
+			while (cur_length + segment_length < desired_len)
+			{
+				cur_length += segment_length;
+				++v0;
+				++v1;
+				JPH_ASSERT(uint32(v1 - vertices.data()) < strand.mEndVtx);
+				segment_length = (Vec3(v1->mPosition) - Vec3(v0->mPosition)).Length();
+			}
+
+			SVertex out_v = *v0;
+			float fraction = (desired_len - cur_length) / segment_length;
+			(Vec3(v0->mPosition) + (Vec3(v1->mPosition) - Vec3(v0->mPosition)) * fraction).StoreFloat3(&out_v.mPosition);
+			out_v.mInvMass = v0->mInvMass + (v1->mInvMass - v0->mInvMass) * fraction < 0.5f? 0.0f : 1.0f;
+			ioVertices.push_back(out_v);
+		}
+
+		// Add the last vertex of the strand
+		ioVertices.push_back(vertices[strand.mEndVtx - 1]);
+
+		JPH_ASSERT(uint32(ioVertices.size()) == out_strand.mEndVtx);
+	}
+}
+
+static void sHairSettingsFindClosestTriangle(Vec3Arg inPoint, const AABBTreeBuilder &inBuilder, const AABBTreeBuilder::Node *inNode, Array<Float3> &inScalpVertices, float &ioClosestDistSq, HairSettings::SkinPoint &outSkinPoint)
+{
+	if (inNode->HasChildren())
+	{
+		// Get children
+		const AABBTreeBuilder::Node *child0 = inNode->GetChild(0, inBuilder.GetNodes());
+		const AABBTreeBuilder::Node *child1 = inNode->GetChild(1, inBuilder.GetNodes());
+
+		// Order so that the first one is closest
+		float dist_sq0 = child0 != nullptr? child0->mBounds.GetSqDistanceTo(inPoint) : FLT_MAX;
+		float dist_sq1 = child1 != nullptr? child1->mBounds.GetSqDistanceTo(inPoint) : FLT_MAX;
+		if (dist_sq1 < dist_sq0)
+		{
+			std::swap(child0, child1);
+			std::swap(dist_sq0, dist_sq1);
+		}
+
+		// Visit in order of closeness
+		if (dist_sq0 < ioClosestDistSq)
+			sHairSettingsFindClosestTriangle(inPoint, inBuilder, child0, inScalpVertices, ioClosestDistSq, outSkinPoint);
+		if (dist_sq1 < ioClosestDistSq)
+			sHairSettingsFindClosestTriangle(inPoint, inBuilder, child1, inScalpVertices, ioClosestDistSq, outSkinPoint);
+	}
+	else
+	{
+		// Loop over the triangles
+		for (const IndexedTriangle *t = inBuilder.GetTriangles().data() + inNode->mTrianglesBegin, *t_end = t + inNode->mNumTriangles; t < t_end; ++t)
+		{
+			Vec3 v0 = Vec3(inScalpVertices[t->mIdx[0]]) - inPoint;
+			Vec3 v1 = Vec3(inScalpVertices[t->mIdx[1]]) - inPoint;
+			Vec3 v2 = Vec3(inScalpVertices[t->mIdx[2]]) - inPoint;
+
+			// Check if it is the closest triangle
+			uint32 set;
+			Vec3 closest_point = ClosestPoint::GetClosestPointOnTriangle(v0, v1, v2, set);
+			float dist_sq = closest_point.LengthSq();
+			if (dist_sq < ioClosestDistSq)
+			{
+				ioClosestDistSq = dist_sq;
+				outSkinPoint.mTriangleIndex = t->mMaterialIndex;
+
+				// Get barycentric coordinates of attachment point
+				float w;
+				ClosestPoint::GetBaryCentricCoordinates(v0, v1, v2, outSkinPoint.mU, outSkinPoint.mV, w);
+			}
+		}
+	}
+}
+
+void HairSettings::Init(float &outMaxDistSqHairToScalp)
+{
+	outMaxDistSqHairToScalp = 0.0f;
+
+	if (!mScalpTriangles.empty())
+	{
+		// Build a tree for all scalp triangles
+		IndexedTriangleList triangles;
+		triangles.reserve(mScalpTriangles.size());
+		for (const IndexedTriangleNoMaterial &t : mScalpTriangles)
+			triangles.push_back(IndexedTriangle(t.mIdx[0], t.mIdx[1], t.mIdx[2], uint32(&t - mScalpTriangles.data())));
+		TriangleSplitterBinning splitter(mScalpVertices, triangles);
+		AABBTreeBuilder builder(splitter, 8);
+		AABBTreeBuilderStats builder_stats;
+		const AABBTreeBuilder::Node *root = builder.Build(builder_stats);
+
+		mSkinPoints.reserve(mSimStrands.size());
+		for (const SStrand &strand : mSimStrands)
+		{
+			SkinPoint sp;
+			sp.mTriangleIndex = 0;
+			sp.mU = 0.0f;
+			sp.mV = 0.0f;
+
+			// Get root position
+			Vec3 p = Vec3(mSimVertices[strand.mStartVtx].mPosition);
+
+			// Find closest triangle on scalp
+			float closest_dist_sq = FLT_MAX;
+			sHairSettingsFindClosestTriangle(p, builder, root, mScalpVertices, closest_dist_sq, sp);
+			outMaxDistSqHairToScalp = max(outMaxDistSqHairToScalp, closest_dist_sq);
+
+			// Project root to the triangle as we will during simulation.
+			// This ensures that we calculate the Bishop frame for the root correctly.
+			const IndexedTriangleNoMaterial &t = mScalpTriangles[sp.mTriangleIndex];
+			Vec3 v0 = Vec3(mScalpVertices[t.mIdx[0]]);
+			Vec3 v1 = Vec3(mScalpVertices[t.mIdx[1]]);
+			Vec3 v2 = Vec3(mScalpVertices[t.mIdx[2]]);
+			p = sp.mU * v0 + sp.mV * v1 + (1.0f - sp.mU - sp.mV) * v2;
+			p.StoreFloat3(&mSimVertices[strand.mStartVtx].mPosition);
+
+			mSkinPoints.push_back(sp);
+		}
+	}
+
+	Array<Vec3> r; // Outside loop to avoid reallocations
+	Array<Vec3> x;
+	Array<Vec3> k; // (bend_compliance, bend_compliance, stretch_compliance)
+	Array<Vec3> g;
+	Array<Quat> bishop;
+	mMaxVerticesPerStrand = 0;
+	for (const SStrand &strand : mSimStrands)
+	{
+		// Calculate max number of vertices per strand
+		uint32 vertex_count = strand.VertexCount();
+		mMaxVerticesPerStrand = max(mMaxVerticesPerStrand, vertex_count);
+
+		// Calculate strand fraction for each vertex
+		float total_length = strand.MeasureLength(mSimVertices);
+		float cur_length = 0.0f;
+		for (uint32 i = strand.mStartVtx; i < strand.mEndVtx - 1; ++i)
+		{
+			SVertex &v = mSimVertices[i];
+			v.mStrandFraction = cur_length / total_length;
+			cur_length += (Vec3(mSimVertices[i + 1].mPosition) - Vec3(v.mPosition)).Length();
+		}
+		mSimVertices[strand.mEndVtx - 1].mStrandFraction = 1.0f;
+
+		// Particles
+		// i=0     1       2
+		// +------>+------>+
+		//    x1      x2
+		//
+		// Let r_i be the edge between particle i - 1 and i in the rest pose
+		// Let x_i be the edge between particle i - 1 and i in the deformed pose
+		//
+		// The force on particle i is:
+		// f_i = k_i * (r_i - x_i) - k_{i+1} * (r_{i+1} - x_{i+1})
+		// Where k_i = 1 / compliance_i
+		//
+		// We want to counter gravity, so:
+		// f_i = -m_i * g
+		//
+		// Rearranging gives:
+		// x_{i+1} * k_{i+1} - x_i * k_i = k_{i+1} * r_{i+1} - k_i * r_i + m_i * g
+		//
+		// Solving this with Gauss Seidel iteration:
+		// x_i = (k_i * r_i - k_{i+1} * (r_{i+1} - x_{i+1}) - m_i * g) / k_i
+
+		r.resize(vertex_count); // Rest edge
+		x.resize(vertex_count); // Deformed edge
+		k.resize(vertex_count); // Spring constant
+		g.resize(vertex_count); // Gravity
+		bishop.resize(vertex_count);
+
+		// First element unused
+		x[0] = r[0] = g[0] = k[0] = Vec3::sNaN();
+
+		const HairSettings::Material &material = mMaterials[strand.mMaterialIndex];
+		HairSettings::GradientSampler gravity_sampler(material.mGravityFactor);
+		for (uint32 i = 1; i < vertex_count; ++i)
+		{
+			const SVertex &v1 = mSimVertices[strand.mStartVtx + i - 1];
+			const SVertex &v2 = mSimVertices[strand.mStartVtx + i];
+			x[i] = r[i] = Vec3(v2.mPosition) - Vec3(v1.mPosition);
+			constexpr float cMinCompliance = 1.0e-10f;
+			float bend_compliance = 1.0f / max(cMinCompliance, material.GetBendCompliance(v2.mStrandFraction));
+			float stretch_compliance = 1.0f / max(cMinCompliance, material.mStretchCompliance);
+			k[i] = Vec3(bend_compliance, bend_compliance, stretch_compliance);
+			g[i] = v2.mInvMass > 0.0f? (material.mGravityPreloadFactor / v2.mInvMass) * mInitialGravity * gravity_sampler.Sample(v2.mStrandFraction) : Vec3::sZero();
+		}
+
+		// Solve for x
+		if (material.mGravityPreloadFactor > 0.0f)
+			for (int iteration = 0; iteration < 10; ++iteration)
+			{
+				// Don't modify the 1st vertex since it's fixed
+				// Loop backwards so that we can use the latest value of x[i + 1]
+				for (uint32 i = vertex_count - 1; i >= 1; --i)
+				{
+					// Calculate reference frame for this edge
+					Vec3 frame_x = x[i].Normalized();
+					Vec3 frame_y = frame_x.GetNormalizedPerpendicular();
+					Vec3 frame_z = frame_x.Cross(frame_y);
+					Mat44 frame(Vec4(frame_y, 0), Vec4(frame_z, 0), Vec4(frame_x, 0), Vec4(0, 0, 0, 1));
+
+					// Gauss Seidel iteration
+					// Note that we take all quantities to local space so that we can separate bend and stretch compliance and apply those as a simple vector multiplication
+					Vec3 x_local = k[i] * frame.Multiply3x3Transposed(r[i]) - frame.Multiply3x3Transposed(g[i]);
+					if (i < vertex_count - 1)
+						x_local -= k[i + 1] * frame.Multiply3x3Transposed(r[i + 1] - x[i + 1]);
+					x[i] = frame.Multiply3x3(x_local / k[i]);
+				}
+			}
+
+		// Calculate the Bishop frame for the first rod in the strand
+		{
+			SVertex &v1 = mSimVertices[strand.mStartVtx];
+			Vec3 tangent = x[1];
+			v1.mLength = tangent.Length();
+			JPH_ASSERT(v1.mLength > 0.0f, "Rods of zero length are not supported!");
+			tangent /= v1.mLength;
+			Vec3 normal = tangent.GetNormalizedPerpendicular();
+			Vec3 binormal = tangent.Cross(normal);
+			bishop[0] = Mat44(Vec4(normal, 0), Vec4(binormal, 0), Vec4(tangent, 0), Vec4(0, 0, 0, 1)).GetQuaternion().Normalized();
+			bishop[0].StoreFloat4(&v1.mBishop);
+		}
+
+		// Calculate the Bishop frames for the rest of the rods in the strand
+		for (uint32 i = 1; i < vertex_count - 1; ++i)
+		{
+			SVertex &v1 = mSimVertices[strand.mStartVtx + i];
+			const SVertex &v2 = mSimVertices[strand.mStartVtx + i + 1];
+
+			// Get the normal and tangent of the first rod's Bishop frame (that was already calculated)
+			Mat44 r1_frame = Mat44::sRotation(bishop[i - 1]);
+			Vec3 tangent1 = r1_frame.GetAxisZ();
+			Vec3 normal1 = r1_frame.GetAxisX();
+
+			// Calculate the Bishop frame for the 2nd rod
+			Vec3 tangent2 = x[i + 1];
+			v1.mLength = tangent2.Length();
+			JPH_ASSERT(v1.mLength > 0.0f, "Rods of zero length are not supported!");
+			tangent2 /= v1.mLength;
+			Vec3 t1_cross_t2 = tangent1.Cross(tangent2);
+			float sin_angle = t1_cross_t2.Length();
+			Vec3 normal2 = normal1;
+			if (sin_angle > 1.0e-6f)
+			{
+				// Rotate normal2
+				t1_cross_t2 /= sin_angle;
+				normal2 = Quat::sRotation(t1_cross_t2, ASin(sin_angle)) * normal2;
+
+				// Ensure normal2 is perpendicular to tangent2
+				normal2 -= normal2.Dot(tangent2) * tangent2;
+				normal2 = normal2.Normalized();
+			}
+			Vec3 binormal2 = tangent2.Cross(normal2);
+			bishop[i] = Mat44(Vec4(normal2, 0), Vec4(binormal2, 0), Vec4(tangent2, 0), Vec4(0, 0, 0, 1)).GetQuaternion().Normalized();
+
+			// Calculate the delta, used in simulation
+			(bishop[i - 1].Conjugated() * bishop[i]).Normalized().StoreFloat4(&v1.mOmega0);
+
+			// Calculate the Bishop frame in the modeled pose for initializing the simulation
+			Vec3 modeled_tangent2 = (Vec3(v2.mPosition) - Vec3(v1.mPosition)).Normalized();
+			Quat modeled_bishop = Quat::sFromTo(tangent2, modeled_tangent2) * bishop[i];
+			modeled_bishop.StoreFloat4(&v1.mBishop);
+		}
+
+		// Copy Bishop frame to the last vertex
+		mSimVertices[strand.mEndVtx - 1].mBishop = mSimVertices[strand.mEndVtx - 2].mBishop;
+	}
+
+	// Finalize skin points by calculating how to go from triangle frame to Bishop frame
+	for (SkinPoint &sp : mSkinPoints)
+	{
+		const IndexedTriangleNoMaterial &t = mScalpTriangles[sp.mTriangleIndex];
+		Vec3 v0 = Vec3(mScalpVertices[t.mIdx[0]]);
+		Vec3 v1 = Vec3(mScalpVertices[t.mIdx[1]]);
+		Vec3 v2 = Vec3(mScalpVertices[t.mIdx[2]]);
+
+		// Get tangent vector
+		Vec3 tangent = (v1 - v0).Normalized();
+
+		// Get normal of the triangle
+		Vec3 normal = tangent.Cross(v2 - v0).Normalized();
+
+		// Calculate basis for the triangle
+		Vec3 binormal = tangent.Cross(normal);
+		Quat triangle_basis = Mat44(Vec4(normal, 0), Vec4(binormal, 0), Vec4(tangent, 0), Vec4(0, 0, 0, 1)).GetQuaternion();
+
+		// Calculate how to rotate from the triangle basis to the Bishop frame of the root
+		Quat to_bishop = triangle_basis.Conjugated() * Quat(mSimVertices[mSimStrands[&sp - mSkinPoints.data()].mStartVtx].mBishop);
+		sp.mToBishop = to_bishop.CompressUnitQuat();
+	}
+
+	// Calculate the grid size
+	mSimulationBounds = {};
+	for (const SVertex &v : mSimVertices)
+		mSimulationBounds.Encapsulate(Vec3(v.mPosition));
+	mSimulationBounds.ExpandBy(mSimulationBoundsPadding);
+
+	// Prepare neutral density grid
+	mNeutralDensity.resize(mGridSize.GetX() * mGridSize.GetY() * mGridSize.GetZ(), 0.0f);
+	GridSampler sampler(this);
+	for (const SVertex &v : mSimVertices)
+		if (v.mInvMass > 0.0f)
+		{
+			sampler.Sample(Vec3(v.mPosition), [this, &v](uint32 inIndex, float inFraction) {
+				mNeutralDensity[inIndex] += inFraction / v.mInvMass;
+			});
+		}
+
+	// Calculate density scale for drawing the grid
+	mDensityScale = 0.0f;
+	for (float density : mNeutralDensity)
+		mDensityScale = max(mDensityScale, density);
+	if (mDensityScale > 0.0f)
+		mDensityScale = 1.0f / mDensityScale;
+
+	// Prepare render vertices
+	for (RVertex &v : mRenderVertices)
+	{
+		Vec3 render_pos(v.mPosition);
+
+		float total_weight = 0.0f;
+		for (SVertexInfluence &inf : v.mInfluences)
+			if (inf.mVertexIndex != cNoInfluence)
+			{
+				const SVertex &simulated_vertex = mSimVertices[inf.mVertexIndex];
+				Vec3 simulated_pos(simulated_vertex.mPosition);
+				Vec3 local_position = Quat(simulated_vertex.mBishop).InverseRotate(render_pos - simulated_pos);
+				local_position.StoreFloat3(&inf.mRelativePosition);
+
+				// Weigh according to inverse distance to the simulated vertex
+				inf.mWeight = 1.0f / (local_position.Length() + 1.0e-6f);
+				total_weight += inf.mWeight;
+			}
+			else
+				inf.mWeight = 0.0f;
+
+		// Normalize weights
+		if (total_weight > 0.0f)
+			for (SVertexInfluence &a : v.mInfluences)
+				if (a.mVertexIndex != cNoInfluence)
+					a.mWeight /= total_weight;
+
+		// Order so that largest weight comes first
+		QuickSort(std::begin(v.mInfluences), std::end(v.mInfluences), [](const SVertexInfluence &inLHS, const SVertexInfluence &inRHS) {
+				return inLHS.mWeight > inRHS.mWeight;
+			});
+	}
+}
+
+void HairSettings::InitCompute(ComputeSystem *inComputeSystem)
+{
+	// Optional: We can attach the roots of the hairs to the scalp
+	if (!mScalpTriangles.empty() && !mSkinPoints.empty())
+	{
+		mScalpTrianglesCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, mScalpTriangles.size() * 3, sizeof(uint32), mScalpTriangles.data()).Get();
+		mSkinPointsCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, mSkinPoints.size(), sizeof(SkinPoint), mSkinPoints.data()).Get();
+
+		// We can skin the scalp or the skinned vertices can be provided externally
+		if (!mScalpVertices.empty() && !mScalpInverseBindPose.empty() && !mScalpSkinWeights.empty())
+		{
+			mScalpVerticesCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, mScalpVertices.size(), sizeof(Float3), mScalpVertices.data()).Get();
+			mScalpSkinWeightsCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, mScalpSkinWeights.size(), sizeof(JPH_HairSkinWeight), mScalpSkinWeights.data()).Get();
+		}
+	}
+
+	// Calculate the number of vertices for every strand
+	Array<uint8> strand_vertex_counts;
+	strand_vertex_counts.resize((mSimStrands.size() + sizeof(uint32) - 1) & ~(sizeof(uint32) - 1), 0); // Make size multiple of sizeof(uint32)
+	for (size_t i = 0, n = mSimStrands.size(); i < n; ++i)
+	{
+		uint32 count = mSimStrands[i].VertexCount();
+		JPH_ASSERT(count < 256);
+		strand_vertex_counts[i] = (uint8)count;
+	}
+
+	// Calculate material index for every strand
+	Array<uint8> strand_material_indices;
+	strand_material_indices.resize((mSimStrands.size() + sizeof(uint32) - 1) & ~(sizeof(uint32) - 1), 0); // Make size multiple of sizeof(uint32)
+	for (size_t i = 0, n = mSimStrands.size(); i < n; ++i)
+	{
+		uint32 material_index = mSimStrands[i].mMaterialIndex;
+		JPH_ASSERT(material_index < 256);
+		strand_material_indices[i] = (uint8)material_index;
+	}
+
+	// Create buffers that contain information about the rest pose of the hair
+	// Rearrange vertices so that the first vertices of all strands are grouped together, then the second vertices, etc.
+	uint num_vertices = uint(mMaxVerticesPerStrand * mSimStrands.size());
+	Array<Float3> vertices_position;
+	vertices_position.resize(num_vertices);
+	Array<uint32> vertices_bishop;
+	vertices_bishop.resize(num_vertices);
+	Array<uint32> vertices_omega0;
+	vertices_omega0.resize(num_vertices);
+	Array<uint32> vertices_fixed;
+	vertices_fixed.resize((num_vertices + 31) / 32, 0);
+	Array<float> vertices_length;
+	vertices_length.resize(num_vertices);
+	Array<uint32> vertices_strand_fraction;
+	vertices_strand_fraction.resize((num_vertices + 3) / 4, 0);
+	for (size_t s = 0, ns = mSimStrands.size(); s < ns; ++s)
+	{
+		const SStrand &strand = mSimStrands[s];
+		for (uint32 v = 0, nv = strand.VertexCount(); v < nv; ++v)
+		{
+			const SVertex &in_v = mSimVertices[strand.mStartVtx + v];
+			size_t idx = v * mSimStrands.size() + s;
+
+			vertices_position[idx] = in_v.mPosition;
+			vertices_bishop[idx] = Vec4::sLoadFloat4(&in_v.mBishop).CompressUnitVector();
+			vertices_omega0[idx] = Vec4::sLoadFloat4(&in_v.mOmega0).CompressUnitVector();
+			vertices_length[idx] = in_v.mLength;
+			if (in_v.mInvMass <= 0.0f)
+				vertices_fixed[idx >> 5] |= uint32(1 << (idx & 31));
+			vertices_strand_fraction[idx >> 2] |= uint32(in_v.mStrandFraction * 255.0f) << ((idx & 3) << 3);
+		}
+	}
+
+	// Calculate a map from render vertex to strand index
+	Array<uint32> simulation_vertex_to_strand_idx;
+	simulation_vertex_to_strand_idx.resize(mSimVertices.size(), ~uint32(0));
+	for (const SStrand &strand : mSimStrands)
+		for (uint v = strand.mStartVtx; v < strand.mEndVtx; ++v)
+				simulation_vertex_to_strand_idx[v] = uint32(&strand - mSimStrands.data());
+
+	// Create buffer for simulated vertex influences
+	Array<JPH_HairSVertexInfluence> svertex_influences;
+	svertex_influences.resize(mRenderVertices.size() * cHairNumSVertexInfluences);
+	for (size_t v = 0, n = mRenderVertices.size(); v < n; ++v)
+		for (uint a = 0; a < cHairNumSVertexInfluences; ++a)
+		{
+			JPH_HairSVertexInfluence &inf = svertex_influences[v * cHairNumSVertexInfluences + a];
+			inf = static_cast<const JPH_HairSVertexInfluence &>(mRenderVertices[v].mInfluences[a]);
+
+			// Remap vertex index to reflect the transposing of the position buffer
+			if (inf.mVertexIndex != cNoInfluence)
+			{
+				uint32 strand_idx = simulation_vertex_to_strand_idx[inf.mVertexIndex];
+				uint32 start_vtx = mSimStrands[strand_idx].mStartVtx;
+				inf.mVertexIndex = strand_idx + (inf.mVertexIndex - start_vtx) * uint32(mSimStrands.size());
+			}
+			else
+			{
+				// The shader doesn't check if weight is zero, it just takes the vertex. Make sure the index points to something.
+				inf.mVertexIndex = 0;
+			}
+		}
+
+	mVerticesPositionCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, vertices_position.size(), sizeof(Float3), vertices_position.data()).Get();
+	mVerticesBishopCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, vertices_bishop.size(), sizeof(uint32), vertices_bishop.data()).Get();
+	mVerticesOmega0CB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, vertices_omega0.size(), sizeof(uint32), vertices_omega0.data()).Get();
+	mVerticesLengthCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, vertices_length.size(), sizeof(float), vertices_length.data()).Get();
+	mVerticesFixedCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, vertices_fixed.size(), sizeof(uint32), vertices_fixed.data()).Get();
+	mVerticesStrandFractionCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, vertices_strand_fraction.size(), sizeof(uint32), vertices_strand_fraction.data()).Get();
+	mStrandVertexCountsCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, strand_vertex_counts.size() / sizeof(uint32), sizeof(uint32), strand_vertex_counts.data()).Get();
+	mStrandMaterialIndexCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, strand_material_indices.size() / sizeof(uint32), sizeof(uint32), strand_material_indices.data()).Get();
+	mNeutralDensityCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, mNeutralDensity.size(), sizeof(float), mNeutralDensity.data()).Get();
+	mSVertexInfluencesCB = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, mRenderVertices.size() * cHairNumSVertexInfluences, sizeof(JPH_HairSVertexInfluence), svertex_influences.data()).Get();
+}
+
+void HairSettings::SaveBinaryState(StreamOut &inStream) const
+{
+	inStream.Write(mSimVertices);
+	inStream.Write(mSimStrands);
+	inStream.Write(mRenderVertices);
+	inStream.Write(mRenderStrands);
+	inStream.Write(mScalpVertices);
+	inStream.Write(mScalpTriangles);
+	inStream.Write(mScalpInverseBindPose);
+	inStream.Write(mScalpSkinWeights);
+	inStream.Write(mScalpNumSkinWeightsPerVertex);
+	inStream.Write(mNumIterationsPerSecond);
+	inStream.Write(mMaxDeltaTime);
+	inStream.Write(mGridSize);
+	inStream.Write(mSimulationBoundsPadding);
+	inStream.Write(mInitialGravity);
+	inStream.Write(mMaterials, [](const Material &inElement, StreamOut &inS) {
+		inS.Write(inElement.mEnableCollision);
+		inS.Write(inElement.mEnableLRA);
+		inS.Write(inElement.mLinearDamping);
+		inS.Write(inElement.mAngularDamping);
+		inS.Write(inElement.mMaxLinearVelocity);
+		inS.Write(inElement.mMaxAngularVelocity);
+		inElement.mGravityFactor.SaveBinaryState(inS);
+		inS.Write(inElement.mFriction);
+		inS.Write(inElement.mBendCompliance);
+		inS.Write(inElement.mBendComplianceMultiplier);
+		inS.Write(inElement.mStretchCompliance);
+		inS.Write(inElement.mInertiaMultiplier);
+		inElement.mHairRadius.SaveBinaryState(inS);
+		inElement.mWorldTransformInfluence.SaveBinaryState(inS);
+		inElement.mGridVelocityFactor.SaveBinaryState(inS);
+		inS.Write(inElement.mGridDensityForceFactor);
+		inElement.mGlobalPose.SaveBinaryState(inS);
+		inElement.mSkinGlobalPose.SaveBinaryState(inS);
+		inS.Write(inElement.mSimulationStrandsFraction);
+		inS.Write(inElement.mGravityPreloadFactor);
+	});
+	inStream.Write(mSkinPoints);
+	inStream.Write(mSimulationBounds);
+	inStream.Write(mNeutralDensity);
+	inStream.Write(mDensityScale);
+	inStream.Write(mMaxVerticesPerStrand);
+}
+
+void HairSettings::RestoreBinaryState(StreamIn &inStream)
+{
+	inStream.Read(mSimVertices);
+	inStream.Read(mSimStrands);
+	inStream.Read(mRenderVertices);
+	inStream.Read(mRenderStrands);
+	inStream.Read(mScalpVertices);
+	inStream.Read(mScalpTriangles);
+	inStream.Read(mScalpInverseBindPose);
+	inStream.Read(mScalpSkinWeights);
+	inStream.Read(mScalpNumSkinWeightsPerVertex);
+	inStream.Read(mNumIterationsPerSecond);
+	inStream.Read(mMaxDeltaTime);
+	inStream.Read(mGridSize);
+	inStream.Read(mSimulationBoundsPadding);
+	inStream.Read(mInitialGravity);
+	inStream.Read(mMaterials, [](StreamIn &inS, Material &outElement) {
+		inS.Read(outElement.mEnableCollision);
+		inS.Read(outElement.mEnableLRA);
+		inS.Read(outElement.mLinearDamping);
+		inS.Read(outElement.mAngularDamping);
+		inS.Read(outElement.mMaxLinearVelocity);
+		inS.Read(outElement.mMaxAngularVelocity);
+		outElement.mGravityFactor.RestoreBinaryState(inS);
+		inS.Read(outElement.mFriction);
+		inS.Read(outElement.mBendCompliance);
+		inS.Read(outElement.mBendComplianceMultiplier);
+		inS.Read(outElement.mStretchCompliance);
+		inS.Read(outElement.mInertiaMultiplier);
+		outElement.mHairRadius.RestoreBinaryState(inS);
+		outElement.mWorldTransformInfluence.RestoreBinaryState(inS);
+		outElement.mGridVelocityFactor.RestoreBinaryState(inS);
+		inS.Read(outElement.mGridDensityForceFactor);
+		outElement.mGlobalPose.RestoreBinaryState(inS);
+		outElement.mSkinGlobalPose.RestoreBinaryState(inS);
+		inS.Read(outElement.mSimulationStrandsFraction);
+		inS.Read(outElement.mGravityPreloadFactor);
+	});
+	inStream.Read(mSkinPoints);
+	inStream.Read(mSimulationBounds);
+	inStream.Read(mNeutralDensity);
+	inStream.Read(mDensityScale);
+	inStream.Read(mMaxVerticesPerStrand);
+}
+
+void HairSettings::PrepareForScalpSkinning(Mat44Arg inJointToHair, const Mat44 *inJointMatrices, Mat44 *outJointMatrices) const
+{
+	for (uint32 i = 0, n = (uint32)mScalpInverseBindPose.size(); i < n; ++i)
+		outJointMatrices[i] = inJointToHair * inJointMatrices[i] * mScalpInverseBindPose[i];
+}
+
+void HairSettings::SkinScalpVertices(Mat44Arg inJointToHair, const Mat44 *inJointMatrices, Array<Vec3> &outVertices) const
+{
+	outVertices.resize(mScalpVertices.size());
+
+	// Pre transform all joint matrices
+	Array<Mat44> joint_matrices;
+	joint_matrices.resize((uint32)mScalpInverseBindPose.size());
+	PrepareForScalpSkinning(inJointToHair, inJointMatrices, joint_matrices.data());
+
+	// Skin all vertices
+	for (uint32 i = 0; i < (uint32)mScalpVertices.size(); ++i)
+	{
+		Vec3 &v = outVertices[i];
+		v = Vec3::sZero();
+		for (const SkinWeight *w = mScalpSkinWeights.data() + i * mScalpNumSkinWeightsPerVertex, *w_end = w + mScalpNumSkinWeightsPerVertex; w < w_end; ++w)
+			if (w->mWeight > 0.0f)
+				v += w->mWeight * joint_matrices[w->mJointIdx] * Vec3(mScalpVertices[i]);
+	}
+}
+
+JPH_NAMESPACE_END

+ 373 - 0
Jolt/Physics/Hair/HairSettings.h

@@ -0,0 +1,373 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Core/Reference.h>
+#include <Jolt/Core/StreamUtils.h>
+#include <Jolt/Geometry/AABox.h>
+#include <Jolt/Geometry/IndexedTriangle.h>
+#include <Jolt/ObjectStream/SerializableObject.h>
+#include <Jolt/Compute/ComputeBuffer.h>
+#include <Jolt/Compute/ComputeSystem.h>
+#include <Jolt/Shaders/HairStructs.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// This class defines the setup of a hair groom, it can be shared between multiple hair instances
+class JPH_EXPORT HairSettings : public RefTarget<HairSettings>
+{
+	JPH_DECLARE_SERIALIZABLE_NON_VIRTUAL(JPH_EXPORT, HairSettings)
+
+public:
+	/// How much a vertex is influenced by a joint
+	struct JPH_EXPORT SkinWeight : public JPH_HairSkinWeight
+	{
+		JPH_DECLARE_SERIALIZABLE_NON_VIRTUAL(JPH_EXPORT, SkinWeight)
+	};
+
+	/// Information about where a hair strand is attached to the scalp mesh
+	struct JPH_EXPORT SkinPoint : public JPH_HairSkinPoint
+	{
+		JPH_DECLARE_SERIALIZABLE_NON_VIRTUAL(JPH_EXPORT, SkinPoint)
+	};
+
+	static constexpr uint32 cNoInfluence = ~uint32(0);
+
+	/// Describes how a render vertex is influenced by a simulated vertex
+	struct JPH_EXPORT SVertexInfluence : public JPH_HairSVertexInfluence
+	{
+		JPH_DECLARE_SERIALIZABLE_NON_VIRTUAL(JPH_EXPORT, SVertexInfluence)
+
+		inline			SVertexInfluence()							{ mVertexIndex = cNoInfluence; mRelativePosition = JPH_float3(0, 0, 0); mWeight = 0.0f; }
+	};
+
+	/// A render vertex
+	struct JPH_EXPORT RVertex
+	{
+		JPH_DECLARE_SERIALIZABLE_NON_VIRTUAL(JPH_EXPORT, RVertex)
+
+		Float3			mPosition { 0, 0, 0 };						///< Initial position of the vertex
+		SVertexInfluence mInfluences[cHairNumSVertexInfluences];	///< Attach to X simulated vertices (computed during Init)
+	};
+
+	/// A simulated vertex in a hair strand
+	struct JPH_EXPORT SVertex
+	{
+		JPH_DECLARE_SERIALIZABLE_NON_VIRTUAL(JPH_EXPORT, SVertex)
+
+		/// Constructor
+						SVertex() = default;
+		explicit		SVertex(const Float3 &inPosition, float inInvMass = 1.0f) : mPosition(inPosition), mInvMass(inInvMass) { }
+
+		Float3			mPosition { 0, 0, 0 };						///< Initial position of the vertex in its modeled pose
+		float			mInvMass = 1.0f;							///< Inverse of the mass of the vertex
+		float			mLength = 0.0f;								///< Initial distance of this vertex to the next of the unloaded strand, computed by Init
+		float			mStrandFraction = 0.0f;						///< Fraction along the strand, 0 = start, 1 = end, computed by Init
+		Float4			mBishop { 0, 0, 0, 1.0f };					///< Bishop frame of the strand in its modeled pose, computed by Init
+		Float4			mOmega0 { 0, 0, 0, 1.0f };					///< Conjugate(Previous Bishop) * Bishop, defines the rotation difference between the previous rod and this one of the unloaded strand, computed by Init
+	};
+
+	/// A hair render strand
+	struct JPH_EXPORT RStrand
+	{
+		JPH_DECLARE_SERIALIZABLE_NON_VIRTUAL(JPH_EXPORT, RStrand)
+
+		/// Constructor
+						RStrand() = default;
+						RStrand(uint32 inStartVtx, uint32 inEndVtx) : mStartVtx(inStartVtx), mEndVtx(inEndVtx) { }
+
+		uint32			VertexCount() const							{ return mEndVtx - mStartVtx; }
+
+		float			MeasureLength(const Array<SVertex> &inVertices) const
+		{
+			float length = 0.0f;
+			for (uint32 v = mStartVtx; v < mEndVtx - 1; ++v)
+				length += (Vec3(inVertices[v + 1].mPosition) - Vec3(inVertices[v].mPosition)).Length();
+			return length;
+		}
+
+		uint32			mStartVtx;
+		uint32			mEndVtx;
+	};
+
+	/// A hair simulation strand
+	struct JPH_EXPORT SStrand : public RStrand
+	{
+		JPH_DECLARE_SERIALIZABLE_NON_VIRTUAL(JPH_EXPORT, SStrand)
+
+						SStrand() = default;
+						SStrand(uint32 inStartVtx, uint32 inEndVtx, uint32 inMaterialIndex) : RStrand(inStartVtx, inEndVtx), mMaterialIndex(inMaterialIndex) { }
+
+		uint32			mMaterialIndex = 0;							///< Index in mMaterials
+	};
+
+	/// Gradient along a hair strand of a value, e.g. compliance, friction, etc.
+	class JPH_EXPORT Gradient
+	{
+		JPH_DECLARE_SERIALIZABLE_NON_VIRTUAL(JPH_EXPORT, Gradient)
+
+	public:
+						Gradient() = default;
+						Gradient(float inMin, float inMax, float inMinFraction = 0.0f, float inMaxFraction = 1.0f) : mMin(inMin), mMax(inMax), mMinFraction(inMinFraction), mMaxFraction(inMaxFraction) { }
+
+		/// We drive a value to its target with fixed time steps using:
+		///
+		/// x(t + fixed_dt) = target + (1 - k) * (x(t) - target)
+		///
+		/// For varying time steps we can rewrite this to:
+		///
+		/// x(t + dt) = target + (1 - k)^inTimeRatio * (x(t) - target)
+		///
+		/// Where inTimeRatio is defined as dt / fixed_dt.
+		///
+		/// This means k' = 1 - (1 - k)^inTimeRatio
+		Gradient		MakeStepDependent(float inTimeRatio) const
+		{
+			auto make_dependent = [inTimeRatio](float inValue) {
+					return 1.0f - std::pow(1.0f - inValue, inTimeRatio);
+				};
+
+			return Gradient(make_dependent(mMin), make_dependent(mMax), mMinFraction, mMaxFraction);
+		}
+
+		/// Saves the state of this object in binary form to inStream. Doesn't store the compute buffers.
+		void			SaveBinaryState(StreamOut &inStream) const;
+
+		/// Restore the state of this object from inStream.
+		void			RestoreBinaryState(StreamIn &inStream);
+
+		float			mMin = 0.0f;								///< Minimum value of the gradient
+		float			mMax = 1.0f;								///< Maximum value of the gradient
+		float			mMinFraction = 0.0f;						///< Fraction along the hair strand that corresponds to the minimum value
+		float			mMaxFraction = 1.0f;						///< Fraction along the hair strand that corresponds to the maximum value
+	};
+
+	class GradientSampler
+	{
+	public:
+						GradientSampler() = default;
+
+		explicit		GradientSampler(const Gradient &inGradient) :
+			mMultiplier((inGradient.mMax - inGradient.mMin) / (inGradient.mMaxFraction - inGradient.mMinFraction)),
+			mOffset(inGradient.mMin - inGradient.mMinFraction * mMultiplier),
+			mMin(min(inGradient.mMin, inGradient.mMax)),
+			mMax(max(inGradient.mMin, inGradient.mMax))
+		{
+		}
+
+		/// Sample the value along the strand
+		inline float	Sample(float inFraction) const
+		{
+			return min(mMax, max(mMin, mOffset + inFraction * mMultiplier));
+		}
+
+		inline float	Sample(const SStrand &inStrand, uint32 inVertex) const
+		{
+			return Sample(float(inVertex - inStrand.mStartVtx) / float(inStrand.VertexCount() - 1));
+		}
+
+		/// Convert to Float4 to pass to shader
+		inline Float4	ToFloat4() const
+		{
+			return Float4(mMultiplier, mOffset, mMin, mMax);
+		}
+
+	private:
+		float			mMultiplier;
+		float			mOffset;
+		float			mMin;
+		float			mMax;
+	};
+
+	/// The material determines the simulation parameters for a hair strand
+	struct JPH_EXPORT Material
+	{
+		JPH_DECLARE_SERIALIZABLE_NON_VIRTUAL(JPH_EXPORT, Material)
+
+		/// Returns if this material needs a density/velocity grid
+		bool			NeedsGrid() const							{ return mGridVelocityFactor.mMin != 0.0f || mGridVelocityFactor.mMax != 0.0f || mGridDensityForceFactor != 0.0f; }
+
+		/// If this material only needs running the global pose logic
+		bool			GlobalPoseOnly() const						{ return !mEnableCollision && mGlobalPose.mMin == 1.0f && mGlobalPose.mMax == 1.0f; }
+
+		/// Calculate the bend compliance at a fraction along the strand
+		float			GetBendCompliance(float inStrandFraction) const
+		{
+			float fraction = inStrandFraction * 3.0f;
+			uint idx = min(uint(fraction), 2u);
+			fraction = fraction - float(idx);
+			JPH_ASSERT(fraction >= 0.0f && fraction <= 1.0f);
+			float multiplier = mBendComplianceMultiplier[idx] * (1.0f - fraction) + mBendComplianceMultiplier[idx + 1] * fraction;
+			return multiplier * mBendCompliance;
+		}
+
+		bool			mEnableCollision = true;					///< Enable collision detection between hair strands and the environment.
+		bool			mEnableLRA = true;							///< Enable Long Range Attachments to keep hair close to the modeled pose. This prevents excessive stretching when the head moves quickly.
+		float			mLinearDamping = 2.0f;						///< Linear damping coefficient for the simulated rods.
+		float			mAngularDamping = 2.0f;						///< Angular damping coefficient for the simulated rods.
+		float			mMaxLinearVelocity = 10.0f;					///< Maximum linear velocity of a vertex.
+		float			mMaxAngularVelocity = 50.0f;				///< Maximum angular velocity of a vertex.
+		Gradient		mGravityFactor { 0.1f, 1.0f, 0.2f, 0.8f };	///< How much gravity affects the hair along its length, 0 = no gravity, 1 = full gravity. Can be used to reduce the effect of gravity.
+		float			mFriction = 0.2f;							///< Collision friction coefficient. Usually in the range [0, 1]. 0 = no friction.
+		float			mBendCompliance = 1.0e-7f;					///< Compliance for bend constraints: 1 / stiffness.
+		Float4			mBendComplianceMultiplier = { 1.0f, 100.0f, 100.0f, 1.0f }; ///< Multiplier for bend compliance at 0%, 33%, 66% and 100% of the strand length.
+		float			mStretchCompliance = 1.0e-8f;				///< Compliance for stretch constraints: 1 / stiffness.
+		float			mInertiaMultiplier = 10.0f;					///< Multiplier applied to the mass of a rod to calculate its inertia.
+		Gradient		mHairRadius = { 0.001f, 0.001f };			///< Radius of the hair strand along its length, used for collision detection.
+		Gradient		mWorldTransformInfluence { 0.0f, 1.0f };	///< How much rotating the head influences the hair, 0 = not at all, the hair will move with the head as if it had no inertia. 1 = hair stays in place as the head moves and is correctly simulated. This can be used to reduce the effect of turning the head towards the root of strands.
+		Gradient		mGridVelocityFactor { 0.05f, 0.01f };		///< Every iteration this fraction of the grid velocity will be applied to the vertex velocity. Defined at cDefaultIterationsPerSecond, if this changes, the value will internally be adjusted to result in the same behavior.
+		float			mGridDensityForceFactor = 0.0f;				///< This factor will try to push the density of the hair towards the neutral density defined in the density grid. Note that can result in artifacts so defaults to 0.
+		Gradient		mGlobalPose { 0.01f, 0, 0.0f, 0.3f };		///< Every iteration this fraction of the neutral pose will be applied to the vertex position. Defined at cDefaultIterationsPerSecond, if this changes, the value will internally be adjusted to result in the same behavior.
+		Gradient		mSkinGlobalPose { 1.0f, 0.0f, 0.0f, 0.1f }; ///< How much the global pose follows the skin of the scalp. 0 is not following, 1 is fully following.
+		float			mSimulationStrandsFraction = 0.1f;			///< Used by InitRenderAndSimulationStrands only. Indicates the fraction of strands that should be simulated.
+		float			mGravityPreloadFactor = 0.0f;				///< Note: Not fully functional yet! This controls how much of the gravity we will remove from the modeled pose when initializing. A value of 1 fully removes gravity and should result in no sagging when the simulation starts. A value of 0 doesn't remove gravity.
+	};
+
+	/// Split the supplied render strands into render and simulation strands and calculate connections between them.
+	/// When this function returns mSimVertices, mSimStrands, mRenderVertices and mRenderStrands are overwritten.
+	/// @param inVertices Vertices for the strands.
+	/// @param inStrands The strands that this instance should have.
+	void				InitRenderAndSimulationStrands(const Array<SVertex> &inVertices, const Array<SStrand> &inStrands);
+
+	/// Resample the hairs to a new fixed number of vertices per strand. Must be called prior to Init if desired.
+	static void			sResample(Array<SVertex> &ioVertices, Array<SStrand> &ioStrands, uint32 inNumVerticesPerStrand);
+
+	/// Initialize the structure, calculating simulation bounds and vertex properties
+	/// @param outMaxDistSqHairToScalp Maximum distance^2 the root vertex of a hair is from the scalp, can be used to check if the hair matched the scalp correctly
+	void				Init(float &outMaxDistSqHairToScalp);
+
+	/// Must be called after Init to setup the compute buffers
+	void				InitCompute(ComputeSystem *inComputeSystem);
+
+	/// Sample the neutral density at a grid position
+	float				GetNeutralDensity(uint32 inX, uint32 inY, uint32 inZ) const
+	{
+		JPH_ASSERT(inX < mGridSize.GetX() && inY < mGridSize.GetY() && inZ < mGridSize.GetZ());
+		return mNeutralDensity[inX + inY * mGridSize.GetX() + inZ * mGridSize.GetX() * mGridSize.GetY()];
+	}
+
+	/// Get the number of vertices in the vertex buffers padded to a multiple of mMaxVerticesPerStrand.
+	inline uint32		GetNumVerticesPadded() const
+	{
+		return uint32(mSimStrands.size()) * mMaxVerticesPerStrand;
+	}
+
+	/// @brief Calculates the pose used for skinning the scalp
+	/// @param inJointToHair Transform to bring the model space joint matrices to the hair local space
+	/// @param inJointMatrices Model space joint matrices of the joints in the face
+	/// @param outJointMatrices Joint matrices combined with the inverse bind pose
+	void				PrepareForScalpSkinning(Mat44Arg inJointToHair, const Mat44 *inJointMatrices, Mat44 *outJointMatrices) const;
+
+	/// Skin the scalp mesh to the given joint matrices and output the skinned scalp vertices
+	/// @param inJointToHair Transform to bring the model space joint matrices to the hair local space
+	/// @param inJointMatrices Model space joint matrices of the joints in the face
+	/// @param outVertices Returns skinned vertices
+	void				SkinScalpVertices(Mat44Arg inJointToHair, const Mat44 *inJointMatrices, Array<Vec3> &outVertices) const;
+
+	/// Saves the state of this object in binary form to inStream. Doesn't store the compute buffers.
+	void				SaveBinaryState(StreamOut &inStream) const;
+
+	/// Restore the state of this object from inStream.
+	void				RestoreBinaryState(StreamIn &inStream);
+
+	class GridSampler
+	{
+	public:
+		inline explicit	GridSampler(const HairSettings *inSettings) :
+			mGridSizeMin2(inSettings->mGridSize - UVec4::sReplicate(2)),
+			mGridSizeMin1((inSettings->mGridSize - UVec4::sReplicate(1)).ToFloat()),
+			mGridStride(1, inSettings->mGridSize.GetX(), inSettings->mGridSize.GetX() * inSettings->mGridSize.GetY(), 0),
+			mOffset(inSettings->mSimulationBounds.mMin),
+			mScale(Vec3(inSettings->mGridSize.ToFloat()) / inSettings->mSimulationBounds.GetSize())
+		{
+		}
+
+		/// Convert a position in hair space to a grid index and fraction
+		inline void		PositionToIndexAndFraction(Vec3Arg inPosition, UVec4 &outIndex, Vec3 &outFraction) const
+		{
+			// Get position in grid space
+			Vec3 grid_pos = Vec3::sMin(Vec3::sMax(inPosition - mOffset, Vec3::sZero()) * mScale, mGridSizeMin1);
+			outIndex = UVec4::sMin(Vec4(grid_pos).ToInt(), mGridSizeMin2);
+			outFraction = grid_pos - Vec3(outIndex.ToFloat());
+		}
+
+		template <typename F>
+		inline void		Sample(UVec4Arg inIndex, Vec3Arg inFraction, const F &inFunc) const
+		{
+			Vec3 fraction[] = { Vec3::sReplicate(1.0f) - inFraction, inFraction };
+
+			// Sample the grid
+			for (uint32 z = 0; z < 2; ++z)
+				for (uint32 y = 0; y < 2; ++y)
+					for (uint32 x = 0; x < 2; ++x)
+					{
+						uint32 index = mGridStride.Dot(inIndex + UVec4(x, y, z, 0));
+						float combined_fraction = fraction[x].GetX() * fraction[y].GetY() * fraction[z].GetZ();
+						inFunc(index, combined_fraction);
+					}
+		}
+
+		template <typename F>
+		inline void		Sample(Vec3Arg inPosition, const F &inFunc) const
+		{
+			UVec4 index;
+			Vec3 fraction;
+			PositionToIndexAndFraction(inPosition, index, fraction);
+			Sample(index, fraction, inFunc);
+		}
+
+		UVec4			mGridSizeMin2;
+		Vec3			mGridSizeMin1;
+		UVec4			mGridStride;
+		Vec3			mOffset;
+		Vec3			mScale;
+	};
+
+	static constexpr uint32 cDefaultIterationsPerSecond = 360;
+
+	Array<SVertex>		mSimVertices;								///< Simulated vertices. Used by mSimStrands.
+	Array<SStrand>		mSimStrands;								///< Defines the start and end of each simulated strand.
+
+	Array<RVertex>		mRenderVertices;							///< Rendered vertices. Used by mRenderStrands.
+	Array<RStrand>		mRenderStrands;								///< Defines the start and end of each rendered strand.
+
+	Array<Float3>		mScalpVertices;								///< Vertices of the scalp mesh, used to attach hairs. Note that the hair vertices mSimVertices must be in the same space as these vertices.
+	Array<IndexedTriangleNoMaterial> mScalpTriangles;				///< Triangles of the scalp mesh.
+	Array<Mat44>		mScalpInverseBindPose;						///< Inverse bind pose of the scalp mesh, joints are in model space
+	Array<SkinWeight>	mScalpSkinWeights;							///< Skin weights of the scalp mesh, for each vertex we have mScalpNumSkinWeightsPerVertex entries
+	uint				mScalpNumSkinWeightsPerVertex = 0;			///< Number of skin weights per vertex
+
+	uint32				mNumIterationsPerSecond = cDefaultIterationsPerSecond;
+	float				mMaxDeltaTime = 1.0f / 30.0f;				///< Maximum delta time for the simulation step (to avoid running an excessively long step, note that this will effectively slow down time)
+	UVec4				mGridSize { 32, 32, 32, 0 };				///< Number of grid cells used to simulate the hair. W unused.
+	Vec3				mSimulationBoundsPadding = Vec3::sReplicate(0.1f); ///< Padding around the simulation bounds to ensure that the grid is large enough and that we detect collisions with the hairs. This is added on all sides after calculating the bounds in the neutral pose.
+	Vec3				mInitialGravity { 0, -9.81f, 0 };			///< Initial gravity in local space of the hair, used to calculate the unloaded rest pose
+	Array<Material>		mMaterials;									///< Materials used by the hair strands
+
+	// Values computed by Init
+	Array<SkinPoint>	mSkinPoints;								///< For each simulated vertex, where it is attached to the scalp mesh
+	AABox				mSimulationBounds { Vec3::sZero(), 1.0f };	///< Bounds that the simulation is supposed to fit in
+	Array<float>		mNeutralDensity;							///< Neutral density grid used to apply forces to keep the hair in place
+	float				mDensityScale = 0.0f;						///< Highest density value in the neutral density grid, used to scale the density for rendering
+	uint32				mMaxVerticesPerStrand = 0;					///< Maximum number of vertices per strand, used for padding the compute buffers
+
+	// Compute data
+	Ref<ComputeBuffer>	mScalpVerticesCB;
+	Ref<ComputeBuffer>	mScalpTrianglesCB;
+	Ref<ComputeBuffer>	mScalpSkinWeightsCB;
+	Ref<ComputeBuffer>	mSkinPointsCB;
+	Ref<ComputeBuffer>	mVerticesFixedCB;
+	Ref<ComputeBuffer>	mVerticesPositionCB;
+	Ref<ComputeBuffer>	mVerticesBishopCB;
+	Ref<ComputeBuffer>	mVerticesOmega0CB;
+	Ref<ComputeBuffer>	mVerticesLengthCB;
+	Ref<ComputeBuffer>	mVerticesStrandFractionCB;
+	Ref<ComputeBuffer>	mStrandVertexCountsCB;
+	Ref<ComputeBuffer>	mStrandMaterialIndexCB;
+	Ref<ComputeBuffer>	mNeutralDensityCB;
+	Ref<ComputeBuffer>	mSVertexInfluencesCB;
+};
+
+JPH_NAMESPACE_END

+ 33 - 0
Jolt/Physics/Hair/HairShaders.cpp

@@ -0,0 +1,33 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#include <Jolt/Physics/Hair/HairShaders.h>
+#include <Jolt/Shaders/HairStructs.h>
+
+JPH_NAMESPACE_BEGIN
+
+void HairShaders::Init(ComputeSystem *inComputeSystem)
+{
+	auto get = [](const ComputeShaderResult &inResult) { return inResult.IsValid()? inResult.Get() : nullptr; };
+
+	mTeleportCS = get(inComputeSystem->CreateComputeShader("HairTeleport", cHairPerVertexBatch));
+	mApplyDeltaTransformCS = get(inComputeSystem->CreateComputeShader("HairApplyDeltaTransform", cHairPerVertexBatch));
+	mSkinVerticesCS = get(inComputeSystem->CreateComputeShader("HairSkinVertices", cHairPerVertexBatch));
+	mSkinRootsCS = get(inComputeSystem->CreateComputeShader("HairSkinRoots", cHairPerStrandBatch));
+	mApplyGlobalPoseCS = get(inComputeSystem->CreateComputeShader("HairApplyGlobalPose", cHairPerVertexBatch));
+	mCalculateCollisionPlanesCS = get(inComputeSystem->CreateComputeShader("HairCalculateCollisionPlanes", cHairPerVertexBatch));
+	mGridClearCS = get(inComputeSystem->CreateComputeShader("HairGridClear", cHairPerGridCellBatch));
+	mGridAccumulateCS = get(inComputeSystem->CreateComputeShader("HairGridAccumulate", cHairPerVertexBatch));
+	mGridNormalizeCS = get(inComputeSystem->CreateComputeShader("HairGridNormalize", cHairPerGridCellBatch));
+	mIntegrateCS = get(inComputeSystem->CreateComputeShader("HairIntegrate", cHairPerVertexBatch));
+	mUpdateRootsCS = get(inComputeSystem->CreateComputeShader("HairUpdateRoots", cHairPerStrandBatch));
+	mUpdateStrandsCS = get(inComputeSystem->CreateComputeShader("HairUpdateStrands", cHairPerStrandBatch));
+	mUpdateVelocityCS = get(inComputeSystem->CreateComputeShader("HairUpdateVelocity", cHairPerVertexBatch));
+	mUpdateVelocityIntegrateCS = get(inComputeSystem->CreateComputeShader("HairUpdateVelocityIntegrate", cHairPerVertexBatch));
+	mCalculateRenderPositionsCS = get(inComputeSystem->CreateComputeShader("HairCalculateRenderPositions", cHairPerRenderVertexBatch));
+}
+
+JPH_NAMESPACE_END

+ 37 - 0
Jolt/Physics/Hair/HairShaders.h

@@ -0,0 +1,37 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Core/Reference.h>
+#include <Jolt/Compute/ComputeSystem.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// This class loads the shaders used by the hair system. This can be shared among all hair instances.
+class JPH_EXPORT HairShaders : public RefTarget<HairShaders>
+{
+public:
+	/// Loads all shaders
+	/// Note that if you want to run the sim on CPU you need call HairRegisterShaders first.
+	void				Init(ComputeSystem *inComputeSystem);
+
+	Ref<ComputeShader>	mTeleportCS;
+	Ref<ComputeShader>	mApplyDeltaTransformCS;
+	Ref<ComputeShader>	mSkinVerticesCS;
+	Ref<ComputeShader>	mSkinRootsCS;
+	Ref<ComputeShader>	mApplyGlobalPoseCS;
+	Ref<ComputeShader>	mCalculateCollisionPlanesCS;
+	Ref<ComputeShader>	mGridClearCS;
+	Ref<ComputeShader>	mGridAccumulateCS;
+	Ref<ComputeShader>	mGridNormalizeCS;
+	Ref<ComputeShader>	mIntegrateCS;
+	Ref<ComputeShader>	mUpdateRootsCS;
+	Ref<ComputeShader>	mUpdateStrandsCS;
+	Ref<ComputeShader>	mUpdateVelocityCS;
+	Ref<ComputeShader>	mUpdateVelocityIntegrateCS;
+	Ref<ComputeShader>	mCalculateRenderPositionsCS;
+};
+
+JPH_NAMESPACE_END

+ 42 - 0
Jolt/Shaders/HairApplyDeltaTransform.hlsl

@@ -0,0 +1,42 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairApplyDeltaTransformBindings.h"
+#include "HairCommon.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerVertexBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	// Check if this is a valid vertex
+	uint vtx = tid.x + cNumStrands; // Skip the root of each strand, it's fixed
+	if (vtx >= cNumVertices)
+		return;
+	if (IsVertexFixed(gVerticesFixed, vtx))
+		return;
+
+	// Load the material
+	uint strand_idx = vtx % cNumStrands;
+	JPH_HairMaterial material = gMaterials[GetStrandMaterialIndex(gStrandMaterialIndex, strand_idx)];
+
+	// Load the vertex
+	float strand_fraction = GetVertexStrandFraction(gStrandFractions, vtx);
+	JPH_HairPosition pos = gPositions[vtx];
+	JPH_HairVelocity vel = gVelocities[vtx];
+
+	// Transform the position so that it stays in the same place in world space (if influence is 1)
+	float influence = GradientSamplerSample(material.mWorldTransformInfluence, strand_fraction);
+	pos.mPosition += influence * (JPH_Mat44Mul3x4Vec3(cDeltaTransform, pos.mPosition) - pos.mPosition);
+
+	// Linear interpolate the rotation based on the influence
+	pos.mRotation = normalize(JPH_QuatMulQuat(influence * cDeltaTransformQuat + float4(0, 0, 0, 1.0f - influence), pos.mRotation));
+
+	// Transform velocities
+	vel.mVelocity = JPH_Mat44Mul3x3Vec3(cDeltaTransform, vel.mVelocity);
+	vel.mAngularVelocity = JPH_Mat44Mul3x3Vec3(cDeltaTransform, vel.mAngularVelocity);
+
+	// Write back vertex
+	gPositions[vtx] = pos;
+	gVelocities[vtx] = vel;
+}

+ 14 - 0
Jolt/Shaders/HairApplyDeltaTransformBindings.h

@@ -0,0 +1,14 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+JPH_SHADER_BIND_BEGIN(JPH_HairApplyDeltaTransform)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gVerticesFixed)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gStrandFractions)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gStrandMaterialIndex)
+	JPH_SHADER_BIND_BUFFER(JPH_HairMaterial, gMaterials)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairPosition, gPositions)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairVelocity, gVelocities)
+JPH_SHADER_BIND_END(JPH_HairApplyDeltaTransform)

+ 19 - 0
Jolt/Shaders/HairApplyGlobalPose.h

@@ -0,0 +1,19 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+void ApplyGlobalPose(JPH_IN_OUT(JPH_HairPosition) ioPos, float3 inRestPosition, JPH_Quat inRestOrientation, JPH_IN(JPH_HairGlobalPoseTransform) inGlobalPoseTransform, JPH_IN(JPH_HairMaterial) inMaterial, float inStrandFraction)
+{
+	// LERP between stored global pose and global pose skinned to the scalp
+	float skin_factor = GradientSamplerSample(inMaterial.mSkinGlobalPose, inStrandFraction);
+	float3 in_position = inRestPosition;
+	in_position += skin_factor * (inGlobalPoseTransform.mPosition + JPH_QuatMulVec3(inGlobalPoseTransform.mRotation, in_position) - in_position);
+	JPH_Quat in_rotation = inRestOrientation;
+	in_rotation += skin_factor * (JPH_QuatMulQuat(inGlobalPoseTransform.mRotation, in_rotation) - in_rotation);
+
+	// LERP between simulated position and skinned position
+	float pose_factor = GradientSamplerSample(inMaterial.mGlobalPose, inStrandFraction);
+	ioPos.mPosition += pose_factor * (in_position - ioPos.mPosition);
+	ioPos.mRotation += pose_factor * (in_rotation - ioPos.mRotation);
+	ioPos.mRotation = normalize(ioPos.mRotation);
+}

+ 38 - 0
Jolt/Shaders/HairApplyGlobalPose.hlsl

@@ -0,0 +1,38 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairApplyGlobalPoseBindings.h"
+#include "HairCommon.h"
+#include "HairApplyGlobalPose.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerVertexBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	// Check if this is a valid vertex
+	uint vtx = tid.x + cNumStrands; // Skip the root of each strand, it's fixed
+	if (vtx >= cNumVertices)
+		return;
+	if (IsVertexFixed(gVerticesFixed, vtx))
+		return;
+
+	// Load the material
+	uint strand_idx = vtx % cNumStrands;
+	JPH_HairMaterial material = gMaterials[GetStrandMaterialIndex(gStrandMaterialIndex, strand_idx)];
+
+	// Load the vertex
+	float strand_fraction = GetVertexStrandFraction(gStrandFractions, vtx);
+	float3 initial_pos = gInitialPositions[vtx];
+	float4 initial_bishop = JPH_QuatDecompress(gInitialBishops[vtx]);
+	JPH_HairGlobalPoseTransform global_pose_transform = gGlobalPoseTransforms[strand_idx];
+
+	// Only apply global pose
+	JPH_HairPosition pos;
+	pos.mPosition = float3(0, 0, 0);
+	pos.mRotation = float4(0, 0, 0, 0);
+	ApplyGlobalPose(pos, initial_pos, initial_bishop, global_pose_transform, material, strand_fraction);
+
+	// Write back vertex
+	gPositions[vtx] = pos;
+}

+ 16 - 0
Jolt/Shaders/HairApplyGlobalPoseBindings.h

@@ -0,0 +1,16 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+JPH_SHADER_BIND_BEGIN(JPH_HairApplyGlobalPose)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gVerticesFixed)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gStrandFractions)
+	JPH_SHADER_BIND_BUFFER(JPH_float3, gInitialPositions)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gInitialBishops)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gStrandMaterialIndex)
+	JPH_SHADER_BIND_BUFFER(JPH_HairMaterial, gMaterials)
+	JPH_SHADER_BIND_BUFFER(JPH_HairGlobalPoseTransform, gGlobalPoseTransforms)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairPosition, gPositions)
+JPH_SHADER_BIND_END(JPH_HairApplyGlobalPose)

+ 114 - 0
Jolt/Shaders/HairCalculateCollisionPlanes.hlsl

@@ -0,0 +1,114 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairCalculateCollisionPlanesBindings.h"
+#include "HairCommon.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerVertexBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	// Check if this is a valid vertex
+	uint vtx = tid.x + cNumStrands; // Skip the root of each strand, it's fixed
+	if (vtx >= cNumVertices)
+		return;
+
+	// Load the vertex
+	float3 pos = gPositions[vtx].mPosition;
+
+	// Start with a plane that is far away (i.e. no collision)
+	JPH_HairCollisionPlane collision_plane;
+	collision_plane.mPlane = float4(1, 0, 0, 1.0e6f);
+	collision_plane.mShapeIndex = 0;
+	float largest_penetration = -1.0e6f;
+
+	// Loop over all shapes
+	uint current_idx = 0;
+	uint current_plane = 0;
+	for (uint current_shape_idx = 0;; ++current_shape_idx)
+	{
+		// Find most facing plane
+		float max_distance = -1.0e6f;
+		float3 max_plane_normal = float3(0, 0, 0);
+		uint max_plane_face_info = 0;
+
+		// Get number of faces in this shape
+		uint nf = gShapeIndices[current_idx++];
+		if (nf == 0)
+			break;
+
+		for (uint f = 0; f < nf; ++f)
+		{
+			// Get the plane
+			JPH_Plane plane = gShapePlanes[current_plane++];
+			float distance = JPH_PlaneSignedDistance(plane, pos);
+			if (distance > max_distance)
+			{
+				max_distance = distance;
+				max_plane_normal = JPH_PlaneGetNormal(plane);
+				max_plane_face_info = current_idx;
+			}
+
+			// Skip over vertex start and end
+			current_idx += 2;
+		}
+
+		// Project point onto that plane, in local space to the vertex
+		float3 closest_point = -max_distance * max_plane_normal;
+
+		// Check edges if we're outside the hull (when inside we know the closest face is also the closest point to the surface)
+		bool is_outside = max_distance > 0.0f;
+		if (is_outside)
+		{
+			// Loop over edges
+			float closest_point_dist_sq = 1.0e12f;
+			uint vi = gShapeIndices[max_plane_face_info];
+			uint vi_end = gShapeIndices[max_plane_face_info + 1];
+			float3 p1 = gShapeVertices[gShapeIndices[vi_end - 1]];
+			for (; vi < vi_end; ++vi)
+			{
+				// Get edge points
+				float3 p2 = gShapeVertices[gShapeIndices[vi]];
+
+				// Check if the position is outside the edge (if not, the face will be closer)
+				float3 p1_p2 = p2 - p1;
+				float3 p1_pos = p1 - pos;
+				float3 edge_normal = cross(p1_p2, max_plane_normal);
+				if (dot(edge_normal, p1_pos) <= 0.0f)
+				{
+					// Get closest point on edge
+					float3 closest = JPH_GetClosestPointOnLine(p1_pos, p1_p2);
+					float distance_sq = dot(closest, closest);
+					if (distance_sq < closest_point_dist_sq)
+					{
+						closest_point_dist_sq = distance_sq;
+						closest_point = closest;
+					}
+				}
+
+				// Cycle vertex
+				p1 = p2;
+			}
+		}
+
+		// Check if this is the largest penetration
+		float3 normal = -closest_point;
+		float normal_length = length(normal);
+		float penetration = normal_length;
+		if (is_outside)
+			penetration = -penetration;
+		else
+			normal = -normal;
+		if (penetration > largest_penetration)
+		{
+			// Calculate contact plane
+			normal = normal_length > 0.0f? normal / normal_length : max_plane_normal;
+			collision_plane.mPlane = JPH_PlaneFromPointAndNormal(pos + closest_point, normal);
+			collision_plane.mShapeIndex = current_shape_idx;
+			largest_penetration = penetration;
+		}
+	}
+
+	gCollisionPlanes[vtx] = collision_plane;
+}

+ 13 - 0
Jolt/Shaders/HairCalculateCollisionPlanesBindings.h

@@ -0,0 +1,13 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+JPH_SHADER_BIND_BEGIN(JPH_HairCalculateCollisionPlanes)
+	JPH_SHADER_BIND_BUFFER(JPH_HairPosition, gPositions)
+	JPH_SHADER_BIND_BUFFER(JPH_Plane, gShapePlanes)
+	JPH_SHADER_BIND_BUFFER(JPH_float3, gShapeVertices)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gShapeIndices)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairCollisionPlane, gCollisionPlanes)
+JPH_SHADER_BIND_END(JPH_HairCalculateCollisionPlanes)

+ 16 - 0
Jolt/Shaders/HairCalculateRenderPositions.h

@@ -0,0 +1,16 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+float3 SkinRenderVertex(uint inVertexIndex)
+{
+	// Calculating resulting render position
+	float3 out_position = float3(0, 0, 0);
+	for (uint idx = inVertexIndex * cHairNumSVertexInfluences, idx_end = idx + cHairNumSVertexInfluences; idx < idx_end; ++idx)
+	{
+		JPH_HairSVertexInfluence inf = gSVertexInfluences[idx];
+		JPH_HairPosition sim_vtx = gPositions[inf.mVertexIndex];
+		out_position += inf.mWeight * (sim_vtx.mPosition + JPH_QuatMulVec3(sim_vtx.mRotation, inf.mRelativePosition));
+	}
+	return out_position;
+}

+ 22 - 0
Jolt/Shaders/HairCalculateRenderPositions.hlsl

@@ -0,0 +1,22 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairCalculateRenderPositionsBindings.h"
+#include "HairCommon.h"
+#include "HairCalculateRenderPositions.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerRenderVertexBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	// Check if this is a valid vertex
+	uint vtx = tid.x;
+	if (vtx >= cNumRenderVertices)
+		return;
+
+	float3 out_position = SkinRenderVertex(vtx);
+
+	// Copy the vertex position to the output buffer
+	gRenderPositions[vtx] = out_position;
+}

+ 16 - 0
Jolt/Shaders/HairCalculateRenderPositionsBindings.h

@@ -0,0 +1,16 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+// Overridable output type
+#ifndef JPH_SHADER_BIND_RENDER_POSITIONS
+	#define JPH_SHADER_BIND_RENDER_POSITIONS(name)	JPH_SHADER_BIND_RW_BUFFER(JPH_float3, name)
+#endif
+
+JPH_SHADER_BIND_BEGIN(JPH_HairCalculateRenderPositions)
+	JPH_SHADER_BIND_BUFFER(JPH_HairSVertexInfluence, gSVertexInfluences)
+	JPH_SHADER_BIND_BUFFER(JPH_HairPosition, gPositions)
+	JPH_SHADER_BIND_RENDER_POSITIONS(gRenderPositions)
+JPH_SHADER_BIND_END(JPH_HairCalculateRenderPositions)

+ 56 - 0
Jolt/Shaders/HairCommon.h

@@ -0,0 +1,56 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "ShaderMath.h"
+#include "ShaderMat44.h"
+#include "ShaderQuat.h"
+#include "ShaderPlane.h"
+#include "ShaderVec3.h"
+
+// The density and velocity fields are stored in fixed point while accumulating, this constant converts from float to fixed point.
+JPH_SHADER_CONSTANT(int, cFloatToFixed, 1 << 10)
+JPH_SHADER_CONSTANT(float, cFixedToFloat, 1.0f / float(cFloatToFixed))
+
+bool IsVertexFixed(JPH_SHADER_BUFFER(JPH_uint) inVertexFixed, uint inVertexIndex)
+{
+	return (inVertexFixed[inVertexIndex >> 5] & (1u << (inVertexIndex & 31))) != 0;
+}
+
+float GetVertexInvMass(JPH_SHADER_BUFFER(JPH_uint) inVertexFixed, uint inVertexIndex)
+{
+	return IsVertexFixed(inVertexFixed, inVertexIndex)? 0.0f : 1.0f;
+}
+
+float GetVertexStrandFraction(JPH_SHADER_BUFFER(JPH_uint) inStrandFractions, uint inVertexIndex)
+{
+	return ((inStrandFractions[inVertexIndex >> 2] >> ((inVertexIndex & 3) << 3)) & 0xff) * (1.0f / 255.0f);
+}
+
+uint GetStrandVertexCount(JPH_SHADER_BUFFER(JPH_uint) inStrandVertexCounts, uint inStrandIndex)
+{
+	return (inStrandVertexCounts[inStrandIndex >> 2] >> ((inStrandIndex & 3) << 3)) & 0xff;
+}
+
+uint GetStrandMaterialIndex(JPH_SHADER_BUFFER(JPH_uint) inStrandMaterialIndex, uint inStrandIndex)
+{
+	return (inStrandMaterialIndex[inStrandIndex >> 2] >> ((inStrandIndex & 3) << 3)) & 0xff;
+}
+
+float GradientSamplerSample(float4 inSampler, float inStrandFraction)
+{
+	return min(inSampler.w, max(inSampler.z, inSampler.y + inStrandFraction * inSampler.x));
+}
+
+void GridPositionToIndexAndFraction(float3 inPosition, JPH_OUT(uint3) outIndex, JPH_OUT(float3) outFraction)
+{
+	// Get position in grid space
+	float3 grid_pos = min(max(inPosition - cGridOffset, float3(0, 0, 0)) * cGridScale, cGridSizeMin1);
+	outIndex = min(uint3(grid_pos), cGridSizeMin2);
+	outFraction = grid_pos - float3(outIndex);
+}
+
+uint GridIndexToBufferIndex(uint3 inIndex)
+{
+	return inIndex.x + inIndex.y * cGridStride.y + inIndex.z * cGridStride.z;
+}

+ 50 - 0
Jolt/Shaders/HairGridAccumulate.hlsl

@@ -0,0 +1,50 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairGridAccumulateBindings.h"
+#include "HairCommon.h"
+
+void AtomicAddVelocityAndDensity(uint inIndex, int4 inValue)
+{
+	JPH_AtomicAdd(gVelocityAndDensity[inIndex].x, inValue.x);
+	JPH_AtomicAdd(gVelocityAndDensity[inIndex].y, inValue.y);
+	JPH_AtomicAdd(gVelocityAndDensity[inIndex].z, inValue.z);
+	JPH_AtomicAdd(gVelocityAndDensity[inIndex].w, inValue.w);
+}
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerVertexBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	// Check that we are processing a valid vertex
+	uint vtx = tid.x + cNumStrands; // Skip the root of each strand, it's fixed
+	if (vtx >= cNumVertices)
+		return;
+	if (IsVertexFixed(gVerticesFixed, vtx))
+		return;
+
+	// Convert position to grid index and fraction
+	uint3 index;
+	float3 ma;
+	GridPositionToIndexAndFraction(gPositions[vtx].mPosition, index, ma);
+	float3 a = float3(1, 1, 1) - ma;
+
+	// Get velocity
+	float4 velocity_and_density = float4(gVelocities[vtx].mVelocity, 1) * cFloatToFixed;
+
+	// Calculate contribution of density and velocity for each cell
+	uint3 stride = cGridStride;
+	uint adr_000 = GridIndexToBufferIndex(index);
+	uint adr_100 = adr_000 + 1;
+	uint adr_010 = adr_000 + stride.y;
+	uint adr_110 = adr_010 + 1;
+	AtomicAddVelocityAndDensity(adr_000,            (int4)round( a.x *  a.y *  a.z * velocity_and_density));
+	AtomicAddVelocityAndDensity(adr_100,            (int4)round(ma.x *  a.y *  a.z * velocity_and_density));
+	AtomicAddVelocityAndDensity(adr_010,            (int4)round( a.x * ma.y *  a.z * velocity_and_density));
+	AtomicAddVelocityAndDensity(adr_110,            (int4)round(ma.x * ma.y *  a.z * velocity_and_density));
+	AtomicAddVelocityAndDensity(adr_000 + stride.z, (int4)round( a.x *  a.y * ma.z * velocity_and_density));
+	AtomicAddVelocityAndDensity(adr_100 + stride.z, (int4)round(ma.x *  a.y * ma.z * velocity_and_density));
+	AtomicAddVelocityAndDensity(adr_010 + stride.z, (int4)round( a.x * ma.y * ma.z * velocity_and_density));
+	AtomicAddVelocityAndDensity(adr_110 + stride.z, (int4)round(ma.x * ma.y * ma.z * velocity_and_density));
+}

+ 12 - 0
Jolt/Shaders/HairGridAccumulateBindings.h

@@ -0,0 +1,12 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+JPH_SHADER_BIND_BEGIN(JPH_HairGridAccumulate)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gVerticesFixed)
+	JPH_SHADER_BIND_BUFFER(JPH_HairPosition, gPositions)
+	JPH_SHADER_BIND_BUFFER(JPH_HairVelocity, gVelocities)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_int4, gVelocityAndDensity)
+JPH_SHADER_BIND_END(JPH_HairGridAccumulate)

+ 17 - 0
Jolt/Shaders/HairGridClear.hlsl

@@ -0,0 +1,17 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairGridClearBindings.h"
+#include "HairCommon.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerGridCellBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	uint index = tid.x;
+	if (index >= cNumGridPoints)
+		return;
+
+	gVelocityAndDensity[index] = int4(0, 0, 0, 0);
+}

+ 9 - 0
Jolt/Shaders/HairGridClearBindings.h

@@ -0,0 +1,9 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+JPH_SHADER_BIND_BEGIN(JPH_HairGridClear)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_int4, gVelocityAndDensity)
+JPH_SHADER_BIND_END(JPH_HairGridClear)

+ 26 - 0
Jolt/Shaders/HairGridNormalize.hlsl

@@ -0,0 +1,26 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairGridNormalizeBindings.h"
+#include "HairCommon.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerGridCellBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	uint index = tid.x;
+	if (index >= cNumGridPoints)
+		return;
+
+	// Convert from fixed point back to float and divide velocity by density to get average velocity
+	float4 v = (float4)gVelocityAndDensity[index] * cFixedToFloat;
+	float density = v.w;
+	if (density > 1.0e-12f)
+	{
+		v.x /= density;
+		v.y /= density;
+		v.z /= density;
+	}
+	gVelocityAndDensity[index] = asint(v);
+}

+ 9 - 0
Jolt/Shaders/HairGridNormalizeBindings.h

@@ -0,0 +1,9 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+JPH_SHADER_BIND_BEGIN(JPH_HairGridNormalize)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_int4, gVelocityAndDensity)
+JPH_SHADER_BIND_END(JPH_HairGridNormalize)

+ 88 - 0
Jolt/Shaders/HairIntegrate.h

@@ -0,0 +1,88 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+float DeltaDensity(uint inIndex)
+{
+	return gVelocityAndDensity[inIndex].w - gNeutralDensity[inIndex];
+}
+
+void ApplyGrid(JPH_IN(JPH_HairPosition) inPos, JPH_IN_OUT(JPH_HairVelocity) ioVel, JPH_IN(JPH_HairMaterial) inMaterial, float inStrandFraction)
+{
+	if (!inMaterial.mEnableGrid)
+		return;
+
+	// Convert position to grid index and fraction
+	uint3 index;
+	float3 ma;
+	GridPositionToIndexAndFraction(inPos.mPosition, index, ma);
+	float3 a = float3(1, 1, 1) - ma;
+
+	// Get average velocity at the vertex position (trilinear sample)
+	float3 velocity;
+	uint3 stride = cGridStride;
+	uint adr_000 = GridIndexToBufferIndex(index);
+	uint adr_100 = adr_000 + 1;
+	uint adr_010 = adr_000 + stride.y;
+	uint adr_110 = adr_010 + 1;
+	velocity =  gVelocityAndDensity[adr_000].xyz            * ( a.x *  a.y *  a.z);
+	velocity += gVelocityAndDensity[adr_100].xyz            * (ma.x *  a.y *  a.z);
+	velocity += gVelocityAndDensity[adr_010].xyz            * ( a.x * ma.y *  a.z);
+	velocity += gVelocityAndDensity[adr_110].xyz            * (ma.x * ma.y *  a.z);
+	velocity += gVelocityAndDensity[adr_000 + stride.z].xyz * ( a.x *  a.y * ma.z);
+	velocity += gVelocityAndDensity[adr_100 + stride.z].xyz * (ma.x *  a.y * ma.z);
+	velocity += gVelocityAndDensity[adr_010 + stride.z].xyz * ( a.x * ma.y * ma.z);
+	velocity += gVelocityAndDensity[adr_110 + stride.z].xyz * (ma.x * ma.y * ma.z);
+
+	// Drive towards the average velocity of the cell
+	ioVel.mVelocity += GradientSamplerSample(inMaterial.mGridVelocityFactor, inStrandFraction) * (velocity - ioVel.mVelocity);
+
+	// Calculate force to go towards neutral density
+	// Based on eq 3 of Volumetric Methods for Simulation and Rendering of Hair - Lena Petrovic, Mark Henne and John Anderson
+	float dd000 = DeltaDensity(adr_000);
+	float dd100 = DeltaDensity(adr_100);
+	float dd010 = DeltaDensity(adr_010);
+	float dd110 = DeltaDensity(adr_110);
+	float dd001 = DeltaDensity(adr_000 + stride.z);
+	float dd101 = DeltaDensity(adr_100 + stride.z);
+	float dd011 = DeltaDensity(adr_010 + stride.z);
+	float dd111 = DeltaDensity(adr_110 + stride.z);
+
+	float3 force = float3(
+		   a.y *  a.z * (dd000 - dd100)
+		+ ma.y *  a.z * (dd010 - dd110)
+		+  a.y * ma.z * (dd001 - dd101)
+		+ ma.y * ma.z * (dd011 - dd111),
+
+		   a.x *  a.z * (dd000 - dd010)
+		+ ma.x *  a.z * (dd100 - dd110)
+		+  a.x * ma.z * (dd001 - dd011)
+		+ ma.x * ma.z * (dd101 - dd111),
+
+		   a.x *  a.y * (dd000 - dd001)
+		+ ma.x *  a.y * (dd100 - dd101)
+		+  a.x * ma.y * (dd010 - dd011)
+		+ ma.x * ma.y * (dd110 - dd111));
+
+	ioVel.mVelocity += inMaterial.mGridDensityForceFactor * force * cDeltaTime; // / mass, but mass is 1
+}
+
+void Integrate(JPH_IN_OUT(JPH_HairPosition) ioPos, JPH_IN(JPH_HairVelocity) inVel, JPH_IN(JPH_HairMaterial) inMaterial, float inStrandFraction)
+{
+	JPH_HairVelocity vel = inVel;
+
+	// Gravity
+	vel.mVelocity += cSubStepGravity * GradientSamplerSample(inMaterial.mGravityFactor, inStrandFraction);
+
+	// Damping
+	vel.mVelocity *= inMaterial.mExpLinearDampingDeltaTime;
+	vel.mAngularVelocity *= inMaterial.mExpAngularDampingDeltaTime;
+
+	// Integrate position
+	ioPos.mPosition += vel.mVelocity * cDeltaTime;
+
+	// Integrate rotation
+	JPH_Quat rotation = ioPos.mRotation;
+	JPH_Quat delta_rotation = cHalfDeltaTime * JPH_QuatImaginaryMulQuat(vel.mAngularVelocity, rotation);
+	ioPos.mRotation = normalize(rotation + delta_rotation);
+}

+ 35 - 0
Jolt/Shaders/HairIntegrate.hlsl

@@ -0,0 +1,35 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairIntegrateBindings.h"
+#include "HairCommon.h"
+#include "HairIntegrate.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerVertexBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	// Check if this is a valid vertex
+	uint vtx = tid.x + cNumStrands; // Skip the root of each strand, it's fixed
+	if (vtx >= cNumVertices)
+		return;
+	if (IsVertexFixed(gVerticesFixed, vtx))
+		return;
+
+	// Load the material
+	uint strand_idx = vtx % cNumStrands;
+	JPH_HairMaterial material = gMaterials[GetStrandMaterialIndex(gStrandMaterialIndex, strand_idx)];
+
+	// Load the vertex
+	float strand_fraction = GetVertexStrandFraction(gStrandFractions, vtx);
+	JPH_HairPosition pos = gPositions[vtx];
+	JPH_HairVelocity vel = gVelocities[vtx];
+
+	// Update previous position
+	gPreviousPositions[vtx] = pos;
+
+	ApplyGrid(pos, vel, material, strand_fraction);
+	Integrate(pos, vel, material, strand_fraction);
+	gPositions[vtx] = pos;
+}

+ 17 - 0
Jolt/Shaders/HairIntegrateBindings.h

@@ -0,0 +1,17 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+JPH_SHADER_BIND_BEGIN(JPH_HairIntegrate)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gVerticesFixed)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gStrandFractions)
+	JPH_SHADER_BIND_BUFFER(JPH_float, gNeutralDensity)
+	JPH_SHADER_BIND_BUFFER(JPH_float4, gVelocityAndDensity)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gStrandMaterialIndex)
+	JPH_SHADER_BIND_BUFFER(JPH_HairMaterial, gMaterials)
+	JPH_SHADER_BIND_BUFFER(JPH_HairVelocity, gVelocities)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairPosition, gPositions)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairPosition, gPreviousPositions)
+JPH_SHADER_BIND_END(JPH_HairIntegrate)

+ 50 - 0
Jolt/Shaders/HairSkinRoots.hlsl

@@ -0,0 +1,50 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairSkinRootsBindings.h"
+#include "HairCommon.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerStrandBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	// Check if this is a valid strand
+	uint strand_idx = tid.x;
+	if (strand_idx >= cNumStrands)
+		return;
+
+	JPH_HairSkinPoint sp = gSkinPoints[strand_idx];
+
+	// Get the vertices of the attached triangle
+	uint tri_idx = sp.mTriangleIndex * 3;
+	float3 v0 = JPH_Mat44Mul3x3Vec3(cScalpToHead, gScalpVertices[gScalpTriangles[tri_idx + 0]]);
+	float3 v1 = JPH_Mat44Mul3x3Vec3(cScalpToHead, gScalpVertices[gScalpTriangles[tri_idx + 1]]);
+	float3 v2 = JPH_Mat44Mul3x3Vec3(cScalpToHead, gScalpVertices[gScalpTriangles[tri_idx + 2]]);
+
+	JPH_HairPosition root;
+
+	// Set the position of the root
+	root.mPosition = sp.mU * v0 + sp.mV * v1 + (1.0f - sp.mU - sp.mV) * v2 + cScalpToHead[3].xyz;
+
+	// Get tangent vector
+	float3 tangent = normalize(v1 - v0);
+
+	// Get normal of the triangle
+	float3 normal = normalize(cross(tangent, v2 - v0));
+
+	// Calculate basis for the triangle
+	float3 binormal = cross(tangent, normal);
+	JPH_Quat triangle_basis = JPH_QuatFromMat33(normal, binormal, tangent);
+
+	// Calculate the new Bishop frame of the root
+	root.mRotation = JPH_QuatMulQuat(triangle_basis, JPH_QuatDecompress(sp.mToBishop));
+
+	gPositions[strand_idx] = root;
+
+	// Calculate the transform that transforms the stored global pose to the space of the skinned root of the strand
+	JPH_HairGlobalPoseTransform transform;
+	transform.mRotation = JPH_QuatMulQuat(root.mRotation, JPH_QuatConjugate(JPH_QuatDecompress(gInitialBishops[strand_idx])));
+	transform.mPosition = root.mPosition - JPH_QuatMulVec3(transform.mRotation, gInitialPositions[strand_idx]);
+	gGlobalPoseTransforms[strand_idx] = transform;
+}

+ 23 - 0
Jolt/Shaders/HairSkinRootsBindings.h

@@ -0,0 +1,23 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+// Overridable input types
+#ifndef JPH_SHADER_BIND_SCALP_VERTICES
+	#define JPH_SHADER_BIND_SCALP_VERTICES(name)	JPH_SHADER_BIND_BUFFER(JPH_float3, name)
+#endif
+#ifndef JPH_SHADER_BIND_SCALP_TRIANGLES
+	#define JPH_SHADER_BIND_SCALP_TRIANGLES(name)	JPH_SHADER_BIND_BUFFER(JPH_uint, name)
+#endif
+
+JPH_SHADER_BIND_BEGIN(JPH_HairSkinRoots)
+	JPH_SHADER_BIND_BUFFER(JPH_HairSkinPoint, gSkinPoints)
+	JPH_SHADER_BIND_SCALP_VERTICES(gScalpVertices)
+	JPH_SHADER_BIND_SCALP_TRIANGLES(gScalpTriangles)
+	JPH_SHADER_BIND_BUFFER(JPH_float3, gInitialPositions)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gInitialBishops)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairPosition, gPositions)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairGlobalPoseTransform, gGlobalPoseTransforms)
+JPH_SHADER_BIND_END(JPH_HairSkinRoots)

+ 26 - 0
Jolt/Shaders/HairSkinVertices.hlsl

@@ -0,0 +1,26 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairSkinVerticesBindings.h"
+#include "HairCommon.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerVertexBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	// Check if this is a valid vertex
+	uint vtx = tid.x;
+		if (vtx >= cNumSkinVertices)
+			return;
+
+	// Skin the vertex
+	float3 v = float3(0, 0, 0);
+	for (uint w = vtx * cNumSkinWeightsPerVertex, w_end = w + cNumSkinWeightsPerVertex; w < w_end; ++w)
+	{
+		JPH_HairSkinWeight sw = gScalpSkinWeights[w];
+		if (sw.mWeight > 0.0f)
+			v += sw.mWeight * JPH_Mat44Mul3x4Vec3(gScalpJointMatrices[sw.mJointIdx], gScalpVertices[vtx]);
+	}
+	gScalpVerticesOut[vtx] = v;
+}

+ 12 - 0
Jolt/Shaders/HairSkinVerticesBindings.h

@@ -0,0 +1,12 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+JPH_SHADER_BIND_BEGIN(JPH_HairSkinVertices)
+	JPH_SHADER_BIND_BUFFER(JPH_float3, gScalpVertices)
+	JPH_SHADER_BIND_BUFFER(JPH_HairSkinWeight, gScalpSkinWeights)
+	JPH_SHADER_BIND_BUFFER(JPH_Mat44, gScalpJointMatrices)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_float3, gScalpVerticesOut)
+JPH_SHADER_BIND_END(JPH_HairSkinVertices)

+ 120 - 0
Jolt/Shaders/HairStructs.h

@@ -0,0 +1,120 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "ShaderCore.h"
+
+// Prevent including this file multiple times unless we're generating bindings
+#if !defined(HAIR_STRUCTS_H) || defined(JPH_SHADER_GENERATE_WRAPPER)
+#ifndef JPH_SHADER_GENERATE_WRAPPER
+#define HAIR_STRUCTS_H
+#endif
+
+JPH_SUPPRESS_WARNING_PUSH
+JPH_SUPPRESS_WARNINGS
+
+JPH_SHADER_CONSTANT(int, cHairPerVertexBatch, 64)
+JPH_SHADER_CONSTANT(int, cHairPerGridCellBatch, 32)
+JPH_SHADER_CONSTANT(int, cHairPerStrandBatch, 32)
+JPH_SHADER_CONSTANT(int, cHairPerRenderVertexBatch, 128)
+
+JPH_SHADER_CONSTANT(int, cHairNumSVertexInfluences, 3)
+
+JPH_SHADER_STRUCT_BEGIN(JPH_HairSkinWeight)
+	JPH_SHADER_STRUCT_MEMBER(JPH_uint,			JointIdx)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float,			Weight)
+JPH_SHADER_STRUCT_END(JPH_HairSkinWeight)
+
+JPH_SHADER_STRUCT_BEGIN(JPH_HairSkinPoint)
+	JPH_SHADER_STRUCT_MEMBER(JPH_uint,			TriangleIndex)		///< Index of triangle in mScalpVertices to which this skin point is attached
+	JPH_SHADER_STRUCT_MEMBER(JPH_float,			U)					///< Barycentric u coordinate of skin point
+	JPH_SHADER_STRUCT_MEMBER(JPH_float,			V)					///< Barycentric v coordinate of skin point
+	JPH_SHADER_STRUCT_MEMBER(JPH_uint,			ToBishop)			///< Compressed quaternion to rotate the frame defined by the triangle normal and the first edge to the Bishop frame of the first vertex of the strand
+JPH_SHADER_STRUCT_END(JPH_HairSkinPoint)
+
+JPH_SHADER_STRUCT_BEGIN(JPH_HairGlobalPoseTransform)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float3,		Position)
+	JPH_SHADER_STRUCT_MEMBER(JPH_Quat,			Rotation)
+JPH_SHADER_STRUCT_END(JPH_HairGlobalPoseTransform)
+
+JPH_SHADER_STRUCT_BEGIN(JPH_HairSVertexInfluence)
+	JPH_SHADER_STRUCT_MEMBER(JPH_uint,			VertexIndex)		///< Index in mSimVertices that indicates to which simulated vertex this vertex is attached.
+	JPH_SHADER_STRUCT_MEMBER(JPH_float3,		RelativePosition)	///< Position in local space from the simulated vertex to the render vertex
+	JPH_SHADER_STRUCT_MEMBER(JPH_float,			Weight)				///< Influence weight, 0 = not attached, 1 = fully attached
+JPH_SHADER_STRUCT_END(JPH_HairSVertexInfluence)
+
+JPH_SHADER_STRUCT_BEGIN(JPH_HairPosition)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float3,		Position)
+	JPH_SHADER_STRUCT_MEMBER(JPH_Quat,			Rotation)
+JPH_SHADER_STRUCT_END(JPH_HairPosition)
+
+JPH_SHADER_STRUCT_BEGIN(JPH_HairVelocity)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float3,		Velocity)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float3,		AngularVelocity)
+JPH_SHADER_STRUCT_END(JPH_HairVelocity)
+
+JPH_SHADER_STRUCT_BEGIN(JPH_HairMaterial)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float4,		WorldTransformInfluence)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float4,		GlobalPose)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float4,		SkinGlobalPose)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float4,		GravityFactor)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float4,		HairRadius)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float4,		BendComplianceMultiplier)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float4,		GridVelocityFactor)
+	JPH_SHADER_STRUCT_MEMBER(JPH_uint,			EnableCollision)
+	JPH_SHADER_STRUCT_MEMBER(JPH_uint,			EnableLRA)
+	JPH_SHADER_STRUCT_MEMBER(JPH_uint,			EnableGrid)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float,			Friction)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float,			ExpLinearDampingDeltaTime)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float,			ExpAngularDampingDeltaTime)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float,			BendComplianceInvDeltaTimeSq)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float,			StretchComplianceInvDeltaTimeSq)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float,			GridDensityForceFactor)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float,			InertiaMultiplier)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float,			MaxLinearVelocitySq)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float,			MaxAngularVelocitySq)
+JPH_SHADER_STRUCT_END(JPH_HairMaterial)
+
+JPH_SHADER_STRUCT_BEGIN(JPH_HairCollisionPlane)
+	JPH_SHADER_STRUCT_MEMBER(JPH_Plane,			Plane)
+	JPH_SHADER_STRUCT_MEMBER(JPH_uint,			ShapeIndex)
+JPH_SHADER_STRUCT_END(JPH_HairCollisionPlane)
+
+JPH_SHADER_STRUCT_BEGIN(JPH_HairCollisionShape)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float3,		CenterOfMass)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float3,		LinearVelocity)
+	JPH_SHADER_STRUCT_MEMBER(JPH_float3,		AngularVelocity)
+JPH_SHADER_STRUCT_END(JPH_HairCollisionShape)
+
+// Note: The order was chosen to match the struct between C++ and HLSL.
+JPH_SHADER_CONSTANTS_BEGIN(JPH_HairUpdateContext, gContext)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		NumStrands)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		NumVertices)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		NumGridPoints)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		NumRenderVertices)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint3,		GridSizeMin2)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_float,		TwoDivDeltaTime)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_float3,		GridSizeMin1)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_float,		DeltaTime)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_float3,		GridOffset)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_float,		HalfDeltaTime)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_float3,		GridScale)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_float,		InvDeltaTimeSq)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_float3,		SubStepGravity)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		NumSkinVertices)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint3,		GridStride)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		NumSkinWeightsPerVertex)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_Mat44,		DeltaTransform)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_Mat44,		ScalpToHead)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_Quat,		DeltaTransformQuat)
+JPH_SHADER_CONSTANTS_END(JPH_HairUpdateContext)
+
+// Note: The order was chosen to match the struct between C++ and HLSL.
+JPH_SHADER_CONSTANTS_BEGIN(JPH_HairIterationContext, gIterationContext)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_float,		AccumulatedDeltaTime)		///< = Iteration * DeltaTime
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_float,		IterationFraction)			///< = 1 / (NumIterations - Iteration) or the fraction to apply to get from current to target for this iteration step
+JPH_SHADER_CONSTANTS_END(JPH_HairIterationContext)
+
+JPH_SUPPRESS_WARNING_POP
+
+#endif // !defined(HAIR_STRUCTS_H) || defined(JPH_SHADER_GENERATE_WRAPPER)

+ 28 - 0
Jolt/Shaders/HairTeleport.hlsl

@@ -0,0 +1,28 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairTeleportBindings.h"
+#include "HairCommon.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerVertexBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	// Check if this is a valid vertex
+	uint vtx = tid.x;
+	if (vtx >= cNumVertices)
+		return;
+
+	// Initialize position based on the initial vertex data
+	JPH_HairPosition pos;
+	pos.mPosition = gInitialPositions[vtx];
+	pos.mRotation = JPH_QuatDecompress(gInitialBishops[vtx]);
+	gPositions[vtx] = pos;
+
+	// Initialize velocity to zero
+	JPH_HairVelocity vel;
+	vel.mVelocity = float3(0, 0, 0);
+	vel.mAngularVelocity = float3(0, 0, 0);
+	gVelocities[vtx] = vel;
+}

+ 12 - 0
Jolt/Shaders/HairTeleportBindings.h

@@ -0,0 +1,12 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+JPH_SHADER_BIND_BEGIN(JPH_HairTeleport)
+	JPH_SHADER_BIND_BUFFER(JPH_float3, gInitialPositions)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gInitialBishops)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairPosition, gPositions)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairVelocity, gVelocities)
+JPH_SHADER_BIND_END(JPH_HairTeleport)

+ 30 - 0
Jolt/Shaders/HairUpdateRoots.hlsl

@@ -0,0 +1,30 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairUpdateRootsBindings.h"
+#include "HairCommon.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerStrandBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	// Check if this is a valid strand
+	uint strand_idx = tid.x;
+	if (strand_idx >= cNumStrands)
+		return;
+
+	float inv_fraction = 1.0f - cIterationFraction;
+
+	JPH_HairPosition pos = gPositions[strand_idx];
+	JPH_HairPosition target_pos = gTargetPositions[strand_idx];
+	pos.mPosition = pos.mPosition * inv_fraction + target_pos.mPosition * cIterationFraction;
+	pos.mRotation = normalize(pos.mRotation * inv_fraction + target_pos.mRotation * cIterationFraction);
+	gPositions[strand_idx] = pos;
+
+	JPH_HairGlobalPoseTransform transf = gGlobalPoseTransforms[strand_idx];
+	JPH_HairGlobalPoseTransform target_transf = gTargetGlobalPoseTransforms[strand_idx];
+	transf.mPosition = transf.mPosition * inv_fraction + target_transf.mPosition * cIterationFraction;
+	transf.mRotation = normalize(transf.mRotation * inv_fraction + target_transf.mRotation * cIterationFraction);
+	gGlobalPoseTransforms[strand_idx] = transf;
+}

+ 12 - 0
Jolt/Shaders/HairUpdateRootsBindings.h

@@ -0,0 +1,12 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+JPH_SHADER_BIND_BEGIN(JPH_HairUpdateRoots)
+	JPH_SHADER_BIND_BUFFER(JPH_HairPosition, gTargetPositions)
+	JPH_SHADER_BIND_BUFFER(JPH_HairGlobalPoseTransform, gTargetGlobalPoseTransforms)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairPosition, gPositions)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairGlobalPoseTransform, gGlobalPoseTransforms)
+JPH_SHADER_BIND_END(JPH_HairUpdateRoots)

+ 154 - 0
Jolt/Shaders/HairUpdateStrands.hlsl

@@ -0,0 +1,154 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairUpdateStrandsBindings.h"
+#include "HairCommon.h"
+
+void ApplyLRA(float3 inX0, float inMaxDist, JPH_IN_OUT(JPH_HairPosition) ioV0)
+{
+	float3 delta = ioV0.mPosition - inX0;
+	float delta_len_sq = dot(delta, delta);
+	if (delta_len_sq > JPH_Square(inMaxDist))
+		ioV0.mPosition = inX0 + delta * inMaxDist / sqrt(delta_len_sq);
+}
+
+void ApplyStretchShear(JPH_IN_OUT(JPH_HairPosition) ioV0, JPH_IN_OUT(JPH_HairPosition) ioV1, float inLength01, float inInvMass0, float inInvMass1, JPH_IN(JPH_HairMaterial) inMaterial)
+{
+	// Inertia of a thin rod of length L ~ m * L^2, we take the maximum mass of the two vertices
+	float rod_inv_mass = min(inInvMass0, inInvMass1) / inMaterial.mInertiaMultiplier; // / L^2, which we'll apply later
+
+	// Equation 37 from "Position and Orientation Based Cosserat Rods" - Kugelstadt and Schoemer - SIGGRAPH 2016
+	float denom = inInvMass0 + inInvMass1 + 4.0f * rod_inv_mass /* / L^2 * L^2 cancels */ + inMaterial.mStretchComplianceInvDeltaTimeSq;
+	if (denom >= 1.0e-12f)
+	{
+		float3 x0 = ioV0.mPosition;
+		float3 x1 = ioV1.mPosition;
+		JPH_Quat rotation1 = ioV0.mRotation;
+		float3 d3 = JPH_QuatRotateAxisZ(rotation1);
+		float3 delta = (x1 - x0 - d3 * inLength01) / denom;
+		ioV0.mPosition = x0 + inInvMass0 * delta;
+		ioV1.mPosition = x1 - inInvMass1 * delta;
+		// q * e3_bar = q * (0, 0, -1, 0) = [-qy, qx, -qw, qz]
+		JPH_Quat q_e3_bar = JPH_Quat(-rotation1.y, rotation1.x, -rotation1.w, rotation1.z);
+		rotation1 += (2.0f * rod_inv_mass / inLength01 /* / L^2 * L => / L */) * JPH_QuatImaginaryMulQuat(delta, q_e3_bar);
+		ioV0.mRotation = normalize(rotation1);
+	}
+}
+
+void ApplyBendTwist(JPH_IN_OUT(JPH_HairPosition) ioV0, JPH_IN_OUT(JPH_HairPosition) ioV1, JPH_Quat inOmega0, float inLength01, float inLength12, float inStrandFraction1, float inInvMass0, float inInvMass1, float inInvMass2, JPH_IN(JPH_HairMaterial) inMaterial)
+{
+	// Inertia of a thin rod of length L ~ m * L^2, we take the maximum mass of the two vertices
+	float rod_inv_mass = min(inInvMass0, inInvMass1) / (inMaterial.mInertiaMultiplier * JPH_Square(inLength01));
+	float rod2_inv_mass = min(inInvMass1, inInvMass2) / (inMaterial.mInertiaMultiplier * JPH_Square(inLength12));
+
+	// Calculate multiplier for the bend compliance based on strand fraction
+	float fraction = inStrandFraction1 * 3.0f;
+	uint idx = uint(fraction);
+	fraction = fraction - float(idx);
+	float multiplier = inMaterial.mBendComplianceMultiplier[idx] * (1.0f - fraction) + inMaterial.mBendComplianceMultiplier[idx + 1] * fraction;
+
+	// Equation 40 from "Position and Orientation Based Cosserat Rods" - Kugelstadt and Schoemer - SIGGRAPH 2016
+	float denom = rod_inv_mass + rod2_inv_mass + inMaterial.mBendComplianceInvDeltaTimeSq * multiplier;
+	if (denom >= 1.0e-12f)
+	{
+		JPH_Quat rotation1 = ioV0.mRotation;
+		JPH_Quat rotation2 = ioV1.mRotation;
+		JPH_Quat omega = JPH_QuatMulQuat(JPH_QuatConjugate(rotation1), rotation2);
+		JPH_Quat omega_min_omega0 = omega - inOmega0;
+		JPH_Quat omega_plus_omega0 = omega + inOmega0;
+		// Take the shortest of the two rotations
+		JPH_Quat delta_omega = dot(omega_plus_omega0, omega_plus_omega0) < dot(omega_min_omega0, omega_min_omega0) ? omega_plus_omega0 : omega_min_omega0;
+		delta_omega /= denom;
+		delta_omega.w = 0; // Scalar part needs to be zero because the real part of the Darboux vector doesn't vanish, see text between eq. 39 and 40.
+		JPH_Quat delta_rod2 = rod2_inv_mass * JPH_QuatMulQuat(rotation1, delta_omega);
+		rotation1 += rod_inv_mass * JPH_QuatMulQuat(rotation2, delta_omega);
+		rotation2 -= delta_rod2;
+		ioV0.mRotation = normalize(rotation1);
+		ioV1.mRotation = normalize(rotation2);
+	}
+}
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerStrandBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	// Check if this is a valid strand
+	uint strand_idx = tid.x;
+	if (strand_idx >= cNumStrands)
+		return;
+	uint strand_vtx_count = GetStrandVertexCount(gStrandVertexCounts, strand_idx);
+
+	// Load the material
+	JPH_HairMaterial material = gMaterials[GetStrandMaterialIndex(gStrandMaterialIndex, strand_idx)];
+
+	// Load the first two vertices
+	uint vtx_idx_to_load = strand_idx;
+	float inv_mass0 = GetVertexInvMass(gVerticesFixed, vtx_idx_to_load);
+	float length0 = gInitialLengths[vtx_idx_to_load];
+	JPH_HairPosition v0 = gPositions[vtx_idx_to_load];
+
+	// LRA: Vertex everything is attached to
+	float3 x0 = gInitialPositions[vtx_idx_to_load];
+
+	// LRA: Tracks the distance from the first vertex
+	float max_dist = length0;
+
+	vtx_idx_to_load += cNumStrands;
+	float inv_mass1 = GetVertexInvMass(gVerticesFixed, vtx_idx_to_load);
+	float strand_fraction1 = GetVertexStrandFraction(gStrandFractions, vtx_idx_to_load);
+	float length1 = gInitialLengths[vtx_idx_to_load];
+	JPH_HairPosition v1 = gPositions[vtx_idx_to_load];
+
+	// Process 2nd vertex
+	if (material.mEnableLRA && inv_mass1 > 0.0f)
+		ApplyLRA(x0, max_dist, v1);
+	max_dist += length1;
+
+	uint vtx_idx_to_retire = strand_idx;
+	for (uint vtx = 2; vtx < strand_vtx_count; ++vtx)
+	{
+		// Get the initial rotation difference from the middle vertex
+		JPH_Quat omega0 = JPH_QuatDecompress(gOmega0s[vtx_idx_to_load]);
+
+		// Load the next vertex
+		vtx_idx_to_load += cNumStrands;
+		float inv_mass2 = GetVertexInvMass(gVerticesFixed, vtx_idx_to_load);
+		float strand_fraction2 = GetVertexStrandFraction(gStrandFractions, vtx_idx_to_load);
+		float length2 = gInitialLengths[vtx_idx_to_load];
+		JPH_HairPosition v2 = gPositions[vtx_idx_to_load];
+
+		// Process newly added vertex
+		if (material.mEnableLRA && inv_mass2 > 0.0f)
+			ApplyLRA(x0, max_dist, v2);
+		max_dist += length2;
+
+		// Stitched mode as per Strand-based Hair System - Pedersen - SIGGRAPH 2022
+		ApplyStretchShear(v1, v2, length1, inv_mass1, inv_mass2, material);
+		ApplyStretchShear(v0, v1, length0, inv_mass0, inv_mass1, material);
+		ApplyBendTwist(v0, v1, omega0, length0, length1, strand_fraction1, inv_mass0, inv_mass1, inv_mass2, material);
+
+		// Retire vertex
+		gPositions[vtx_idx_to_retire] = v0;
+		vtx_idx_to_retire += cNumStrands;
+
+		// Shift the vertices
+		inv_mass0 = inv_mass1;
+		inv_mass1 = inv_mass2;
+		strand_fraction1 = strand_fraction2;
+		length0 = length1;
+		length1 = length2;
+		v0 = v1;
+		v1 = v2;
+	}
+
+	// Retire 2nd to last vertex
+	gPositions[vtx_idx_to_retire] = v0;
+	vtx_idx_to_retire += cNumStrands;
+
+	// Cannot calculate rotation for last vertex, take the 2nd last
+	v1.mRotation = v0.mRotation;
+
+	// Retire last vertex
+	gPositions[vtx_idx_to_retire] = v1;
+}

+ 17 - 0
Jolt/Shaders/HairUpdateStrandsBindings.h

@@ -0,0 +1,17 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+JPH_SHADER_BIND_BEGIN(JPH_HairUpdateStrands)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gVerticesFixed)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gStrandFractions)
+	JPH_SHADER_BIND_BUFFER(JPH_float3, gInitialPositions)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gOmega0s)
+	JPH_SHADER_BIND_BUFFER(JPH_float, gInitialLengths)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gStrandVertexCounts)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gStrandMaterialIndex)
+	JPH_SHADER_BIND_BUFFER(JPH_HairMaterial, gMaterials)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairPosition, gPositions)
+JPH_SHADER_BIND_END(JPH_HairUpdateStrands)

+ 64 - 0
Jolt/Shaders/HairUpdateVelocity.h

@@ -0,0 +1,64 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairApplyGlobalPose.h"
+
+void ApplyCollisionAndUpdateVelocity(uint inVtx, JPH_IN_OUT(JPH_HairPosition) ioPos, JPH_IN(JPH_HairPosition) inPreviousPos, JPH_IN(JPH_HairMaterial) inMaterial, float inStrandFraction, JPH_OUT(JPH_HairVelocity) outVel)
+{
+	// Update velocities
+	outVel.mVelocity = (ioPos.mPosition - inPreviousPos.mPosition) / cDeltaTime;
+	outVel.mAngularVelocity = cTwoDivDeltaTime * JPH_QuatMulQuat(ioPos.mRotation, JPH_QuatConjugate(inPreviousPos.mRotation)).xyz;
+
+	if (inMaterial.mEnableCollision)
+	{
+		// Calculate closest point on the collision plane
+		JPH_HairCollisionPlane plane = gCollisionPlanes[inVtx];
+		float distance_to_plane = JPH_PlaneSignedDistance(plane.mPlane, ioPos.mPosition);
+		float3 contact_normal = JPH_PlaneGetNormal(plane.mPlane);
+		float3 point_on_plane = ioPos.mPosition - distance_to_plane * contact_normal;
+
+		// Calculate how much the plane moved in this time step
+		JPH_HairCollisionShape shape = gCollisionShapes[plane.mShapeIndex];
+		float3 plane_velocity = shape.mLinearVelocity + cross(shape.mAngularVelocity, point_on_plane - shape.mCenterOfMass);
+		float plane_movement = dot(plane_velocity, contact_normal) * cAccumulatedDeltaTime;
+
+		float projected_distance = -distance_to_plane + plane_movement + GradientSamplerSample(inMaterial.mHairRadius, inStrandFraction);
+		if (projected_distance > 0.0f)
+		{
+			// Resolve penetration
+			ioPos.mPosition += contact_normal * projected_distance;
+
+			// Only update velocity when moving towards each other
+			float3 v_relative = outVel.mVelocity - plane_velocity;
+			float v_relative_dot_normal = dot(contact_normal, v_relative);
+			if (v_relative_dot_normal < 0.0f)
+			{
+				// Calculate normal and tangential velocity (equation 30)
+				float3 v_normal = contact_normal * v_relative_dot_normal;
+				float3 v_tangential = v_relative - v_normal;
+				float v_tangential_length = length(v_tangential);
+
+				// Apply friction as described in Detailed Rigid Body Simulation with Extended Position Based Dynamics - Matthias Muller et al. (modified equation 31)
+				if (v_tangential_length > 0.0f)
+					outVel.mVelocity -= v_tangential * min(inMaterial.mFriction * projected_distance / (v_tangential_length * cDeltaTime), 1.0f);
+
+				// Apply restitution of zero (equation 35)
+				outVel.mVelocity -= v_normal;
+			}
+		}
+	}
+}
+
+void LimitVelocity(JPH_IN_OUT(JPH_HairVelocity) ioVel, JPH_IN(JPH_HairMaterial) inMaterial)
+{
+	// Limit linear velocity
+	float linear_velocity_sq = dot(ioVel.mVelocity, ioVel.mVelocity);
+	if (linear_velocity_sq > inMaterial.mMaxLinearVelocitySq)
+		ioVel.mVelocity *= sqrt(inMaterial.mMaxLinearVelocitySq / linear_velocity_sq);
+
+	// Limit angular velocity
+	float angular_velocity_sq = dot(ioVel.mAngularVelocity, ioVel.mAngularVelocity);
+	if (angular_velocity_sq > inMaterial.mMaxAngularVelocitySq)
+		ioVel.mAngularVelocity *= sqrt(inMaterial.mMaxAngularVelocitySq / angular_velocity_sq);
+}

+ 40 - 0
Jolt/Shaders/HairUpdateVelocity.hlsl

@@ -0,0 +1,40 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairUpdateVelocityBindings.h"
+#include "HairCommon.h"
+#include "HairUpdateVelocity.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerVertexBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	// Check if this is a valid vertex
+	uint vtx = tid.x + cNumStrands; // Skip the root of each strand, it's fixed
+	if (vtx >= cNumVertices)
+		return;
+	if (IsVertexFixed(gVerticesFixed, vtx))
+		return;
+
+	// Load the material
+	uint strand_idx = vtx % cNumStrands;
+	JPH_HairMaterial material = gMaterials[GetStrandMaterialIndex(gStrandMaterialIndex, strand_idx)];
+
+	// Load the vertex
+	float strand_fraction = GetVertexStrandFraction(gStrandFractions, vtx);
+	float3 initial_pos = gInitialPositions[vtx];
+	float4 initial_bishop = JPH_QuatDecompress(gInitialBishops[vtx]);
+	JPH_HairGlobalPoseTransform global_pose_transform = gGlobalPoseTransforms[strand_idx];
+	JPH_HairPosition pos = gPositions[vtx];
+	JPH_HairPosition prev_pos = gPreviousPositions[vtx];
+
+	JPH_HairVelocity vel;
+	ApplyGlobalPose(pos, initial_pos, initial_bishop, global_pose_transform, material, strand_fraction);
+	ApplyCollisionAndUpdateVelocity(vtx, pos, prev_pos, material, strand_fraction, vel);
+	LimitVelocity(vel, material);
+
+	// Write back vertex
+	gPositions[vtx] = pos;
+	gVelocities[vtx] = vel;
+}

+ 20 - 0
Jolt/Shaders/HairUpdateVelocityBindings.h

@@ -0,0 +1,20 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+JPH_SHADER_BIND_BEGIN(JPH_HairUpdateVelocity)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gVerticesFixed)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gStrandFractions)
+	JPH_SHADER_BIND_BUFFER(JPH_float3, gInitialPositions)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gInitialBishops)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gStrandMaterialIndex)
+	JPH_SHADER_BIND_BUFFER(JPH_HairMaterial, gMaterials)
+	JPH_SHADER_BIND_BUFFER(JPH_HairPosition, gPreviousPositions)
+	JPH_SHADER_BIND_BUFFER(JPH_HairGlobalPoseTransform, gGlobalPoseTransforms)
+	JPH_SHADER_BIND_BUFFER(JPH_HairCollisionShape, gCollisionShapes)
+	JPH_SHADER_BIND_BUFFER(JPH_HairCollisionPlane, gCollisionPlanes)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairPosition, gPositions)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairVelocity, gVelocities)
+JPH_SHADER_BIND_END(JPH_HairUpdateVelocity)

+ 44 - 0
Jolt/Shaders/HairUpdateVelocityIntegrate.hlsl

@@ -0,0 +1,44 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairUpdateVelocityIntegrateBindings.h"
+#include "HairCommon.h"
+#include "HairIntegrate.h"
+#include "HairUpdateVelocity.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cHairPerVertexBatch, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	// Check if this is a valid vertex
+	uint vtx = tid.x + cNumStrands; // Skip the root of each strand, it's fixed
+	if (vtx >= cNumVertices)
+		return;
+	if (IsVertexFixed(gVerticesFixed, vtx))
+		return;
+
+	// Load the material
+	uint strand_idx = vtx % cNumStrands;
+	JPH_HairMaterial material = gMaterials[GetStrandMaterialIndex(gStrandMaterialIndex, strand_idx)];
+
+	// Load the vertex
+	float strand_fraction = GetVertexStrandFraction(gStrandFractions, vtx);
+	float3 initial_pos = gInitialPositions[vtx];
+	float4 initial_bishop = JPH_QuatDecompress(gInitialBishops[vtx]);
+	JPH_HairGlobalPoseTransform global_pose_transform = gGlobalPoseTransforms[strand_idx];
+	JPH_HairPosition pos = gPositions[vtx];
+	JPH_HairPosition prev_pos = gPreviousPositions[vtx];
+
+	// HairUpdateVelocity shader
+	JPH_HairVelocity vel; // Keeps velocity as a local variable
+	ApplyGlobalPose(pos, initial_pos, initial_bishop, global_pose_transform, material, strand_fraction);
+	ApplyCollisionAndUpdateVelocity(vtx, pos, prev_pos, material, strand_fraction, vel);
+	LimitVelocity(vel, material);
+
+	// HairIntegrate shader
+	gPreviousPositions[vtx] = pos;
+	ApplyGrid(pos, vel, material, strand_fraction);
+	Integrate(pos, vel, material, strand_fraction);
+	gPositions[vtx] = pos;
+}

+ 21 - 0
Jolt/Shaders/HairUpdateVelocityIntegrateBindings.h

@@ -0,0 +1,21 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "HairStructs.h"
+
+JPH_SHADER_BIND_BEGIN(JPH_HairUpdateVelocityIntegrate)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gVerticesFixed)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gStrandFractions)
+	JPH_SHADER_BIND_BUFFER(JPH_float3, gInitialPositions)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gInitialBishops)
+	JPH_SHADER_BIND_BUFFER(JPH_float, gNeutralDensity)
+	JPH_SHADER_BIND_BUFFER(JPH_float4, gVelocityAndDensity)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gStrandMaterialIndex)
+	JPH_SHADER_BIND_BUFFER(JPH_HairMaterial, gMaterials)
+	JPH_SHADER_BIND_BUFFER(JPH_HairGlobalPoseTransform, gGlobalPoseTransforms)
+	JPH_SHADER_BIND_BUFFER(JPH_HairCollisionShape, gCollisionShapes)
+	JPH_SHADER_BIND_BUFFER(JPH_HairCollisionPlane, gCollisionPlanes)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairPosition, gPreviousPositions)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_HairPosition, gPositions)
+JPH_SHADER_BIND_END(JPH_HairUpdateVelocityIntegrate)

+ 135 - 0
Jolt/Shaders/HairWrapper.cpp

@@ -0,0 +1,135 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#include <Jolt/Shaders/HairWrapper.h>
+
+#define JPH_SHADER_NAME HairTeleport
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairTeleport.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairTeleportBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME HairApplyDeltaTransform
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairApplyDeltaTransform.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairApplyDeltaTransformBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME HairSkinVertices
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairSkinVertices.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairSkinVerticesBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME HairSkinRoots
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairSkinRoots.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairSkinRootsBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME HairApplyGlobalPose
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairApplyGlobalPose.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairApplyGlobalPoseBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME HairCalculateCollisionPlanes
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairCalculateCollisionPlanes.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairCalculateCollisionPlanesBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME HairGridClear
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairGridClear.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairGridClearBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME HairGridAccumulate
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairGridAccumulate.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairGridAccumulateBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME HairGridNormalize
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairGridNormalize.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairGridNormalizeBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME HairIntegrate
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairIntegrate.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairIntegrateBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME HairUpdateRoots
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairUpdateRoots.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairUpdateRootsBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME HairUpdateStrands
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairUpdateStrands.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairUpdateStrandsBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME HairUpdateVelocity
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairUpdateVelocity.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairUpdateVelocityBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME HairUpdateVelocityIntegrate
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairUpdateVelocityIntegrate.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairUpdateVelocityIntegrateBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME HairCalculateRenderPositions
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "HairCalculateRenderPositions.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "HairCalculateRenderPositionsBindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+JPH_NAMESPACE_BEGIN
+
+void JPH_EXPORT HairRegisterShaders(ComputeSystemCPU *inComputeSystem)
+{
+	JPH_REGISTER_SHADER(inComputeSystem, HairTeleport);
+	JPH_REGISTER_SHADER(inComputeSystem, HairApplyDeltaTransform);
+	JPH_REGISTER_SHADER(inComputeSystem, HairSkinVertices);
+	JPH_REGISTER_SHADER(inComputeSystem, HairSkinRoots);
+	JPH_REGISTER_SHADER(inComputeSystem, HairApplyGlobalPose);
+	JPH_REGISTER_SHADER(inComputeSystem, HairCalculateCollisionPlanes);
+	JPH_REGISTER_SHADER(inComputeSystem, HairGridClear);
+	JPH_REGISTER_SHADER(inComputeSystem, HairGridAccumulate);
+	JPH_REGISTER_SHADER(inComputeSystem, HairGridNormalize);
+	JPH_REGISTER_SHADER(inComputeSystem, HairIntegrate);
+	JPH_REGISTER_SHADER(inComputeSystem, HairUpdateRoots);
+	JPH_REGISTER_SHADER(inComputeSystem, HairUpdateStrands);
+	JPH_REGISTER_SHADER(inComputeSystem, HairUpdateVelocity);
+	JPH_REGISTER_SHADER(inComputeSystem, HairUpdateVelocityIntegrate);
+	JPH_REGISTER_SHADER(inComputeSystem, HairCalculateRenderPositions);
+}
+
+JPH_NAMESPACE_END

+ 13 - 0
Jolt/Shaders/HairWrapper.h

@@ -0,0 +1,13 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+JPH_NAMESPACE_BEGIN
+
+class ComputeSystemCPU;
+
+void JPH_EXPORT HairRegisterShaders(ComputeSystemCPU *inComputeSystem);
+
+JPH_NAMESPACE_END

+ 4 - 0
Jolt/Shaders/ShaderCore.h

@@ -34,6 +34,10 @@
 
 
 	JPH_SUPPRESS_WARNING_POP
 	JPH_SUPPRESS_WARNING_POP
 #else
 #else
+	#define JPH_SUPPRESS_WARNING_PUSH
+	#define JPH_SUPPRESS_WARNING_POP
+	#define JPH_SUPPRESS_WARNINGS
+
 	typedef float JPH_float;
 	typedef float JPH_float;
 	typedef float3 JPH_float3;
 	typedef float3 JPH_float3;
 	typedef float4 JPH_float4;
 	typedef float4 JPH_float4;

+ 1 - 0
README.md

@@ -86,6 +86,7 @@ Why create yet another physics engine? Firstly, it has been a personal learning
 	* Internal pressure.
 	* Internal pressure.
 	* Collision with simulated rigid bodies.
 	* Collision with simulated rigid bodies.
 	* Collision tests against soft bodies.
 	* Collision tests against soft bodies.
+* A GPU based hair simulation.
 * Water buoyancy calculations.
 * Water buoyancy calculations.
 * An optional double precision mode that allows large worlds.
 * An optional double precision mode that allows large worlds.
 
 

+ 7 - 0
Samples/Samples.cmake

@@ -155,6 +155,12 @@ set(SAMPLES_SRC_FILES
 	${SAMPLES_ROOT}/Tests/General/WallTest.h
 	${SAMPLES_ROOT}/Tests/General/WallTest.h
 	${SAMPLES_ROOT}/Tests/General/ActivateDuringUpdateTest.cpp
 	${SAMPLES_ROOT}/Tests/General/ActivateDuringUpdateTest.cpp
 	${SAMPLES_ROOT}/Tests/General/ActivateDuringUpdateTest.h
 	${SAMPLES_ROOT}/Tests/General/ActivateDuringUpdateTest.h
+	${SAMPLES_ROOT}/Tests/Hair/HairCollisionTest.cpp
+	${SAMPLES_ROOT}/Tests/Hair/HairCollisionTest.h
+	${SAMPLES_ROOT}/Tests/Hair/HairGravityPreloadTest.cpp
+	${SAMPLES_ROOT}/Tests/Hair/HairGravityPreloadTest.h
+	${SAMPLES_ROOT}/Tests/Hair/HairTest.cpp
+	${SAMPLES_ROOT}/Tests/Hair/HairTest.h
 	${SAMPLES_ROOT}/Tests/Rig/CreateRigTest.cpp
 	${SAMPLES_ROOT}/Tests/Rig/CreateRigTest.cpp
 	${SAMPLES_ROOT}/Tests/Rig/CreateRigTest.h
 	${SAMPLES_ROOT}/Tests/Rig/CreateRigTest.h
 	${SAMPLES_ROOT}/Tests/SoftBody/SoftBodyBendConstraintTest.cpp
 	${SAMPLES_ROOT}/Tests/SoftBody/SoftBodyBendConstraintTest.cpp
@@ -317,6 +323,7 @@ endif()
 # Assets used by the samples
 # Assets used by the samples
 set(SAMPLES_ASSETS
 set(SAMPLES_ASSETS
 	${PHYSICS_REPO_ROOT}/Assets/convex_hulls.bin
 	${PHYSICS_REPO_ROOT}/Assets/convex_hulls.bin
+	${PHYSICS_REPO_ROOT}/Assets/face.bin
 	${PHYSICS_REPO_ROOT}/Assets/heightfield1.bin
 	${PHYSICS_REPO_ROOT}/Assets/heightfield1.bin
 	${PHYSICS_REPO_ROOT}/Assets/Human/dead_pose1.tof
 	${PHYSICS_REPO_ROOT}/Assets/Human/dead_pose1.tof
 	${PHYSICS_REPO_ROOT}/Assets/Human/dead_pose2.tof
 	${PHYSICS_REPO_ROOT}/Assets/Human/dead_pose2.tof

+ 14 - 0
Samples/SamplesApp.cpp

@@ -46,6 +46,7 @@
 #include <Jolt/Physics/Constraints/DistanceConstraint.h>
 #include <Jolt/Physics/Constraints/DistanceConstraint.h>
 #include <Jolt/Physics/Constraints/PulleyConstraint.h>
 #include <Jolt/Physics/Constraints/PulleyConstraint.h>
 #include <Jolt/Physics/Character/CharacterVirtual.h>
 #include <Jolt/Physics/Character/CharacterVirtual.h>
+#include <Jolt/Shaders/HairWrapper.h>
 #include <Utils/Log.h>
 #include <Utils/Log.h>
 #include <Utils/ShapeCreator.h>
 #include <Utils/ShapeCreator.h>
 #include <Utils/CustomMemoryHook.h>
 #include <Utils/CustomMemoryHook.h>
@@ -391,6 +392,17 @@ static TestNameAndRTTI sSoftBodyTests[] =
 	{ "Soft Body vs Sensor",				JPH_RTTI(SoftBodySensorTest) }
 	{ "Soft Body vs Sensor",				JPH_RTTI(SoftBodySensorTest) }
 };
 };
 
 
+JPH_DECLARE_RTTI_FOR_FACTORY(JPH_NO_EXPORT, HairTest)
+JPH_DECLARE_RTTI_FOR_FACTORY(JPH_NO_EXPORT, HairCollisionTest)
+JPH_DECLARE_RTTI_FOR_FACTORY(JPH_NO_EXPORT, HairGravityPreloadTest)
+
+static TestNameAndRTTI sHairTests[] =
+{
+	{ "Hair",								JPH_RTTI(HairTest) },
+	{ "Hair Collision",						JPH_RTTI(HairCollisionTest) },
+	{ "Hair Gravity Preload",				JPH_RTTI(HairGravityPreloadTest) },
+};
+
 JPH_DECLARE_RTTI_FOR_FACTORY(JPH_NO_EXPORT, BroadPhaseCastRayTest)
 JPH_DECLARE_RTTI_FOR_FACTORY(JPH_NO_EXPORT, BroadPhaseCastRayTest)
 JPH_DECLARE_RTTI_FOR_FACTORY(JPH_NO_EXPORT, BroadPhaseInsertionTest)
 JPH_DECLARE_RTTI_FOR_FACTORY(JPH_NO_EXPORT, BroadPhaseInsertionTest)
 
 
@@ -437,6 +449,7 @@ static TestCategory sAllCategories[] =
 	{ "Water", sWaterTests, size(sWaterTests) },
 	{ "Water", sWaterTests, size(sWaterTests) },
 	{ "Vehicle", sVehicleTests, size(sVehicleTests) },
 	{ "Vehicle", sVehicleTests, size(sVehicleTests) },
 	{ "Soft Body", sSoftBodyTests, size(sSoftBodyTests) },
 	{ "Soft Body", sSoftBodyTests, size(sSoftBodyTests) },
+	{ "Hair", sHairTests, size(sHairTests) },
 	{ "Broad Phase", sBroadPhaseTests, size(sBroadPhaseTests) },
 	{ "Broad Phase", sBroadPhaseTests, size(sBroadPhaseTests) },
 	{ "Convex Collision", sConvexCollisionTests, size(sConvexCollisionTests) },
 	{ "Convex Collision", sConvexCollisionTests, size(sConvexCollisionTests) },
 	{ "Tools", sTools, size(sTools) }
 	{ "Tools", sTools, size(sTools) }
@@ -494,6 +507,7 @@ SamplesApp::SamplesApp(const String &inCommandLine) :
 
 
 	// Create compute system CPU
 	// Create compute system CPU
 	mComputeSystemCPU = StaticCast<ComputeSystemCPU>(CreateComputeSystemCPU().Get());
 	mComputeSystemCPU = StaticCast<ComputeSystemCPU>(CreateComputeSystemCPU().Get());
+	HairRegisterShaders(mComputeSystemCPU);
 	mComputeQueueCPU = mComputeSystemCPU->CreateComputeQueue().Get();
 	mComputeQueueCPU = mComputeSystemCPU->CreateComputeQueue().Get();
 
 
 	{
 	{

+ 112 - 0
Samples/Tests/Hair/HairCollisionTest.cpp

@@ -0,0 +1,112 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <TestFramework.h>
+
+#include <Tests/Hair/HairCollisionTest.h>
+#include <Jolt/Physics/Body/BodyCreationSettings.h>
+#include <Jolt/Physics/Collision/Shape/ConvexHullShape.h>
+#include <Jolt/Physics/Collision/Shape/StaticCompoundShape.h>
+#include <Application/DebugUI.h>
+#include <Layers.h>
+
+JPH_IMPLEMENT_RTTI_VIRTUAL(HairCollisionTest)
+{
+	JPH_ADD_BASE_CLASS(HairCollisionTest, Test)
+}
+
+void HairCollisionTest::Initialize()
+{
+	// Load shaders
+	mHairShaders.Init(mComputeSystem);
+
+	// Create a single strand
+	mHairSettings = new HairSettings;
+	HairSettings::Material m;
+	m.mHairRadius = HairSettings::Gradient(0, 0); // Override radius to 0 so we can see it touch the moving body
+	mHairSettings->mMaterials.push_back(m);
+	mHairSettings->mSimulationBoundsPadding = Vec3::sReplicate(1.0f);
+	Array<HairSettings::SVertex> hair_vertices = { HairSettings::SVertex(Float3(0, 2, 0), 0), HairSettings::SVertex(Float3(0, 0, 0), 1) };
+	Array<HairSettings::SStrand> hair_strands = { HairSettings::SStrand(0, 2, 0) };
+	mHairSettings->InitRenderAndSimulationStrands(hair_vertices, hair_strands);
+	float max_dist_sq = 0.0f;
+	mHairSettings->Init(max_dist_sq);
+	mHairSettings->InitCompute(mComputeSystem);
+
+	mHair = new Hair(mHairSettings, RVec3::sZero(), Quat::sRotation(Vec3::sAxisY(), 0.5f * JPH_PI), Layers::MOVING); // Ensure hair is rotated
+	mHair->Init(mComputeSystem);
+	mHair->Update(0.0f, Mat44::sIdentity(), nullptr, *mPhysicsSystem, mHairShaders, mComputeSystem, mComputeQueue);
+	mHair->ReadBackGPUState(mComputeQueue);
+
+	// Create moving body that moves through the strand
+	ConvexHullShapeSettings shape1;
+	shape1.SetEmbedded();
+	constexpr float cWidth = 0.01f, cHeight = 0.5f, cLength1 = 0.6f;
+	shape1.mPoints = {
+		Vec3( cWidth,  cHeight,  cLength1),
+		Vec3(-cWidth,  cHeight,  cLength1),
+		Vec3( cWidth, -cHeight,  cLength1),
+		Vec3(-cWidth, -cHeight,  cLength1),
+		Vec3( cWidth,  cHeight, -cLength1),
+		Vec3(-cWidth,  cHeight, -cLength1),
+		Vec3( cWidth, -cHeight, -cLength1),
+		Vec3(-cWidth, -cHeight, -cLength1)
+	};
+	ConvexHullShapeSettings shape2;
+	shape2.SetEmbedded();
+	constexpr float cLength2 = 0.5f;
+	shape2.mPoints = {
+		Vec3( cWidth,  cHeight,  cLength2),
+		Vec3(-cWidth,  cHeight,  cLength2),
+		Vec3( cWidth, -cHeight,  cLength2),
+		Vec3(-cWidth, -cHeight,  cLength2),
+		Vec3( cWidth,  cHeight, -cLength2),
+		Vec3(-cWidth,  cHeight, -cLength2),
+		Vec3( cWidth, -cHeight, -cLength2),
+		Vec3(-cWidth, -cHeight, -cLength2)
+	};
+	StaticCompoundShapeSettings compound; // Use a compound to test center of mass differences between body and shape
+	compound.SetEmbedded();
+	compound.AddShape(Vec3(0, 0, -cLength2), Quat::sIdentity(), &shape1);
+	compound.AddShape(Vec3(0, 0, cLength1), Quat::sIdentity(), &shape2);
+	BodyCreationSettings moving_body(&compound, RVec3(-1, 0, 0), Quat::sIdentity(), EMotionType::Kinematic, Layers::MOVING);
+	mMovingBodyID = mBodyInterface->CreateAndAddBody(moving_body, EActivation::Activate);
+}
+
+void HairCollisionTest::PrePhysicsUpdate(const PreUpdateParams &inParams)
+{
+#ifdef JPH_DEBUG_RENDERER
+	Hair::DrawSettings settings;
+	settings.mDrawRods = true;
+	settings.mDrawOrientations = true;
+	mHair->Draw(settings, mDebugRenderer);
+#endif // JPH_DEBUG_RENDERER
+
+	// Set moving body velocity
+	++mFrame;
+	if (sRotating)
+		mBodyInterface->SetLinearAndAngularVelocity(mMovingBodyID, Vec3::sZero(), Vec3(0, 1, 0));
+	else
+		mBodyInterface->SetLinearAndAngularVelocity(mMovingBodyID, mFrame % 240 < 120? Vec3(1, 0, 0) : Vec3(-1, 0, 0), Vec3::sZero());
+
+	// Update the hair
+	mHair->Update(inParams.mDeltaTime, Mat44::sIdentity(), nullptr, *mPhysicsSystem, mHairShaders, mComputeSystem, mComputeQueue);
+	mComputeQueue->ExecuteAndWait();
+	mHair->ReadBackGPUState(mComputeQueue);
+}
+
+void HairCollisionTest::SaveState(StateRecorder &inStream) const
+{
+	inStream.Write(mFrame);
+}
+
+void HairCollisionTest::RestoreState(StateRecorder &inStream)
+{
+	inStream.Read(mFrame);
+}
+
+void HairCollisionTest::CreateSettingsMenu(DebugUI *inUI, UIElement *inSubMenu)
+{
+	inUI->CreateCheckBox(inSubMenu, "Rotating", sRotating, [](UICheckBox::EState inState) { sRotating = inState == UICheckBox::STATE_CHECKED; });
+}

+ 46 - 0
Samples/Tests/Hair/HairCollisionTest.h

@@ -0,0 +1,46 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Tests/Test.h>
+#include <Jolt/Physics/Hair/Hair.h>
+#include <Jolt/Physics/Hair/HairShaders.h>
+
+class HairCollisionTest : public Test
+{
+public:
+	JPH_DECLARE_RTTI_VIRTUAL(JPH_NO_EXPORT, HairCollisionTest)
+
+	// Destructor
+	virtual					~HairCollisionTest() override											{ delete mHair; }
+
+	// Description of the test
+	virtual const char *	GetDescription() const override
+	{
+		return "Hair collision demo.";
+	}
+
+	// See: Test
+	virtual void			Initialize() override;
+	virtual void			PrePhysicsUpdate(const PreUpdateParams &inParams) override;
+	virtual void			SaveState(StateRecorder &inStream) const override;
+	virtual void			RestoreState(StateRecorder &inStream) override;
+
+	// Number used to scale the terrain and camera movement to the scene
+	virtual float			GetWorldScale() const override											{ return 0.01f; }
+
+	// Optional settings menu
+	virtual bool			HasSettingsMenu() const override										{ return true; }
+	virtual void			CreateSettingsMenu(DebugUI *inUI, UIElement *inSubMenu) override;
+
+private:
+	inline static bool		sRotating = false;
+
+	Ref<HairSettings>		mHairSettings = nullptr;
+	HairShaders				mHairShaders;
+	Hair *					mHair = nullptr;
+	uint32					mFrame = 0;
+	BodyID					mMovingBodyID;
+};

+ 135 - 0
Samples/Tests/Hair/HairGravityPreloadTest.cpp

@@ -0,0 +1,135 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <TestFramework.h>
+
+#include <Tests/Hair/HairGravityPreloadTest.h>
+#include <Layers.h>
+#include <Application/DebugUI.h>
+
+JPH_IMPLEMENT_RTTI_VIRTUAL(HairGravityPreloadTest)
+{
+	JPH_ADD_BASE_CLASS(HairGravityPreloadTest, Test)
+}
+
+const char *HairGravityPreloadTest::sScenes[] =
+{
+	"Zig Zag",
+	"Helix",
+	"Horizontal Bar",
+};
+
+const char *HairGravityPreloadTest::sSceneName = "Zig Zag";
+
+void HairGravityPreloadTest::Initialize()
+{
+	// Load shaders
+	mHairShaders.Init(mComputeSystem);
+
+	Array<HairSettings::SVertex> hair_vertices;
+	Array<HairSettings::SStrand> hair_strands;
+
+	if (strcmp(sSceneName, "Zig Zag") == 0)
+	{
+		// Create a hanging zig zag
+		constexpr float cHoriz = 0.05f;
+		constexpr int cNumVertices = 128;
+		constexpr float cHeight = 0.5f;
+		for (int j = 0; j < 2; ++j)
+			for (int i = 0; i < cNumVertices; ++i)
+			{
+				float fraction = float(i) / (cNumVertices - 1);
+
+				HairSettings::SVertex v;
+				v.mPosition = Float3((j == 0? -0.1f : 0.1f) + (i & 1? cHoriz : -cHoriz), (1.0f - fraction) * cHeight, 0);
+				v.mInvMass = i == 0? 0.0f : 1.0f;
+				hair_vertices.push_back(v);
+			}
+		hair_strands = { HairSettings::SStrand(0, cNumVertices, 0), HairSettings::SStrand(cNumVertices, 2 * cNumVertices, 1) };
+	}
+	else if (strcmp(sSceneName, "Helix") == 0)
+	{
+		// Create a hanging helix
+		constexpr float cRadius = 0.05f;
+		constexpr int cNumVertices = 128;
+		constexpr float cHeight = 0.5f;
+		constexpr float cNumCycles = 10;
+		for (int j = 0; j < 2; ++j)
+			for (int i = 0; i < cNumVertices; ++i)
+			{
+				float fraction = float(i) / (cNumVertices - 1);
+
+				HairSettings::SVertex v;
+				float alpha = cNumCycles * 2.0f * JPH_PI * fraction;
+				v.mPosition = Float3((j == 0? -0.1f : 0.1f) + cRadius * Sin(alpha), (1.0f - fraction) * cHeight, cRadius * Cos(alpha));
+				v.mInvMass = i == 0? 0.0f : 1.0f;
+				hair_vertices.push_back(v);
+			}
+		hair_strands = { HairSettings::SStrand(0, cNumVertices, 0), HairSettings::SStrand(cNumVertices, 2 * cNumVertices, 1) };
+	}
+	else if (strcmp(sSceneName, "Horizontal Bar") == 0)
+	{
+		// Create horizontal bar
+		constexpr int cNumVertices = 10;
+		for (int j = 0; j < 2; ++j)
+			for (int i = 0; i < cNumVertices; ++i)
+			{
+				HairSettings::SVertex v;
+				v.mPosition = Float3(j == 0? -0.1f : 0.1f, 0, 1.0f * float(i));
+				v.mInvMass = i == 0? 0.0f : 1.0f;
+				hair_vertices.push_back(v);
+			}
+
+		hair_strands = { HairSettings::SStrand(0, cNumVertices, 0), HairSettings::SStrand(cNumVertices, 2 * cNumVertices, 0) };
+	}
+
+	mHairSettings = new HairSettings;
+	HairSettings::Material m;
+	m.mGlobalPose = HairSettings::Gradient(0, 0);
+	m.mEnableLRA = false; // We're testing gravity preloading, so disable LRA to avoid hitting the stretch limits
+	m.mBendCompliance = 1e-8f;
+	m.mStretchCompliance = 1e-10f;
+	m.mBendComplianceMultiplier = { 1, 100, 100, 1 }; // Non uniform
+	m.mGridVelocityFactor = { 0.0f, 0.0f }; // Don't let the grid affect the simulation
+	m.mGravityPreloadFactor = 0.0f;
+	m.mGravityFactor = { 1.0f, 0.5f, 0.2f, 0.8f }; // Non uniform
+	m.mSimulationStrandsFraction = 1.0f;
+	mHairSettings->mMaterials.push_back(m);
+	m.mGravityPreloadFactor = 1.0f;
+	mHairSettings->mMaterials.push_back(m);
+	mHairSettings->mSimulationBoundsPadding = Vec3::sReplicate(1.0f);
+	mHairSettings->InitRenderAndSimulationStrands(hair_vertices, hair_strands);
+	float max_dist_sq = 0.0f;
+	mHairSettings->Init(max_dist_sq);
+	mHairSettings->InitCompute(mComputeSystem);
+	mHair = new Hair(mHairSettings, RVec3::sZero(), Quat::sIdentity(), Layers::MOVING); // Ensure hair is rotated
+	mHair->Init(mComputeSystem);
+	mHair->Update(0.0f, Mat44::sIdentity(), nullptr, *mPhysicsSystem, mHairShaders, mComputeSystem, mComputeQueue);
+	mHair->ReadBackGPUState(mComputeQueue);
+}
+
+void HairGravityPreloadTest::PrePhysicsUpdate(const PreUpdateParams &inParams)
+{
+#ifdef JPH_DEBUG_RENDERER
+	Hair::DrawSettings settings;
+	settings.mDrawRods = true;
+	settings.mDrawUnloadedRods = true;
+	mHair->Draw(settings, mDebugRenderer);
+#endif // JPH_DEBUG_RENDERER
+
+	// Update the hair
+	mHair->Update(inParams.mDeltaTime, Mat44::sIdentity(), nullptr, *mPhysicsSystem, mHairShaders, mComputeSystem, mComputeQueue);
+	mComputeQueue->ExecuteAndWait();
+	mHair->ReadBackGPUState(mComputeQueue);
+}
+
+void HairGravityPreloadTest::CreateSettingsMenu(DebugUI *inUI, UIElement *inSubMenu)
+{
+	inUI->CreateTextButton(inSubMenu, "Select Scene", [this, inUI]() {
+		UIElement *scene_name = inUI->CreateMenu();
+		for (uint i = 0; i < size(sScenes); ++i)
+			inUI->CreateTextButton(scene_name, sScenes[i], [this, i]() { sSceneName = sScenes[i]; RestartTest(); });
+		inUI->ShowMenu(scene_name);
+	});
+}

+ 44 - 0
Samples/Tests/Hair/HairGravityPreloadTest.h

@@ -0,0 +1,44 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Tests/Test.h>
+#include <Jolt/Physics/Hair/Hair.h>
+#include <Jolt/Physics/Hair/HairShaders.h>
+
+class HairGravityPreloadTest : public Test
+{
+public:
+	JPH_DECLARE_RTTI_VIRTUAL(JPH_NO_EXPORT, HairGravityPreloadTest)
+
+	// Destructor
+	virtual					~HairGravityPreloadTest() override										{ delete mHair; }
+
+	// Description of the test
+	virtual const char *	GetDescription() const override
+	{
+		return	"Hair gravity preloading demo. This prevents the hair from sagging at the start of the simulation.\n"
+				"Note: Not fully functional!";
+	}
+
+	// See: Test
+	virtual void			Initialize() override;
+	virtual void			PrePhysicsUpdate(const PreUpdateParams &inParams) override;
+
+	// Number used to scale the terrain and camera movement to the scene
+	virtual float			GetWorldScale() const override											{ return 0.01f; }
+
+	// Optional settings menu
+	virtual bool			HasSettingsMenu() const override										{ return true; }
+	virtual void			CreateSettingsMenu(DebugUI *inUI, UIElement *inSubMenu) override;
+
+private:
+	static const char *		sScenes[];
+	static const char *		sSceneName;
+
+	Ref<HairSettings>		mHairSettings = nullptr;
+	HairShaders				mHairShaders;
+	Hair *					mHair = nullptr;
+};

+ 445 - 0
Samples/Tests/Hair/HairTest.cpp

@@ -0,0 +1,445 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <TestFramework.h>
+
+#include <Tests/Hair/HairTest.h>
+#include <Jolt/Physics/Body/BodyCreationSettings.h>
+#include <Jolt/Physics/Collision/Shape/ConvexHullShape.h>
+#include <Jolt/Physics/Collision/RayCast.h>
+#include <Jolt/Physics/Collision/CastResult.h>
+#include <Jolt/Core/StreamWrapper.h>
+#include <Utils/ReadData.h>
+#include <Utils/Log.h>
+#include <Utils/AssetStream.h>
+#include <Application/DebugUI.h>
+#include <Layers.h>
+#include <Renderer/DebugRendererImp.h>
+
+JPH_SUPPRESS_WARNINGS_STD_BEGIN
+#include <filesystem>
+JPH_SUPPRESS_WARNINGS_STD_END
+
+JPH_IMPLEMENT_RTTI_VIRTUAL(HairTest)
+{
+	JPH_ADD_BASE_CLASS(HairTest, Test)
+}
+
+auto tenth_of_inch_to_m = [](Mat44Arg inInvNeckTransform, Vec3Arg inVertex) { return inInvNeckTransform * ((2.54f / 1000.0f) * inVertex.Swizzle<SWIZZLE_Y, SWIZZLE_Z, SWIZZLE_X>()); }; // Original model seems to be in 10ths of inches
+
+const HairTest::Groom HairTest::sGrooms[] =
+{
+	{ "Straight", tenth_of_inch_to_m, false },
+	{ "Curly", tenth_of_inch_to_m, false },
+	{ "Wavy", tenth_of_inch_to_m, false },
+};
+
+const HairTest::Groom *HairTest::sSelectedGroom = &sGrooms[0];
+
+void HairTest::Initialize()
+{
+	// Check groom file exists
+	String groom_file = "w" + String(sSelectedGroom->mName) + ".hair";
+	String full_path = AssetStream::sGetAssetsBasePath() + groom_file;
+	if (!std::filesystem::exists(full_path))
+		FatalError("File %s not found.\n\n"
+			"wCurly.hair, wStraight.hair and wWavy.hair should be downloaded from https://www.cemyuksel.com/research/hairmodels/ (or by running Assets/download_hair.sh)", full_path.c_str());
+
+	// Read face mesh and animation
+	AssetStream asset_stream("face.bin", std::ios::in | std::ios::binary);
+	StreamInWrapper stream(asset_stream.Get());
+
+	// Neck joint index
+	stream.Read(mHeadJointIdx);
+
+	// Vertices
+	uint32 num_vertices;
+	stream.Read(num_vertices);
+	Array<Float3> vertices;
+	vertices.resize(num_vertices);
+	stream.ReadBytes(vertices.data(), sizeof(Float3) * num_vertices);
+
+	// Indices
+	uint32 num_indices;
+	stream.Read(num_indices);
+	Array<IndexedTriangleNoMaterial> indices;
+	indices.resize(num_indices);
+	stream.ReadBytes(indices.data(), sizeof(IndexedTriangleNoMaterial) * num_indices);
+
+	// Inverse Bind Matrices
+	uint32 num_joints;
+	stream.Read(num_joints);
+	Array<Mat44> inv_bind_pose;
+	inv_bind_pose.resize(num_joints);
+	stream.ReadBytes(inv_bind_pose.data(), sizeof(Mat44) * num_joints);
+
+	// Skin Weights
+	uint num_skin_weights_per_vertex;
+	Array<HairSettings::SkinWeight> skin_weights;
+	stream.Read(num_skin_weights_per_vertex);
+	skin_weights.resize(num_skin_weights_per_vertex * num_vertices);
+	stream.ReadBytes(skin_weights.data(), sizeof(HairSettings::SkinWeight) * num_skin_weights_per_vertex * num_vertices);
+
+	// Animation
+	uint32 num_frames;
+	stream.Read(num_frames);
+	mFaceAnimation.resize(num_frames);
+	for (uint32 frame = 0; frame < num_frames; ++frame)
+	{
+		mFaceAnimation[frame].resize(num_joints);
+		for (uint32 joint = 0; joint < num_joints; ++joint)
+		{
+			Float3 translation, rotation;
+			stream.Read(translation);
+			stream.Read(rotation);
+			Quat rotation_quat(rotation.x, rotation.y, rotation.z, sqrt(1.0f - Vec3(rotation).LengthSq()));
+			mFaceAnimation[frame][joint] = Mat44::sRotationTranslation(rotation_quat, Vec3(translation));
+		}
+	}
+
+	// Read collision hulls
+	uint32 num_hulls;
+	stream.Read(num_hulls);
+	for (uint32 i = 0; i < num_hulls; ++i)
+	{
+		// Attached to joint
+		uint32 joint_index;
+		stream.Read(joint_index);
+
+		// Read number of vertices
+		uint32 num_hull_vertices;
+		stream.Read(num_hull_vertices);
+
+		// Read vertices
+		ConvexHullShapeSettings shape_settings;
+		shape_settings.SetEmbedded();
+		shape_settings.mPoints.resize(num_hull_vertices);
+		for (uint32 j = 0; j < num_hull_vertices; ++j)
+			stream.Read(shape_settings.mPoints[j]);
+
+		Mat44 transform = joint_index != 0xffffffff? mFaceAnimation[0][joint_index] : Mat44::sIdentity();
+		Mat44 inv_transform = transform.Inversed();
+		for (Vec3 &v : shape_settings.mPoints)
+			v = inv_transform * v;
+
+		// Create the body
+		BodyCreationSettings body(&shape_settings, RVec3(transform.GetTranslation()), transform.GetQuaternion(), EMotionType::Kinematic, Layers::MOVING);
+		BodyID body_id = mBodyInterface->CreateAndAddBody(body, EActivation::DontActivate);
+
+		mAttachedBodies.push_back({ joint_index, body_id });
+	}
+
+	// Make mesh relative to neck bind pose
+	Mat44 inv_bind_neck = inv_bind_pose[mHeadJointIdx];
+	Mat44 bind_neck = inv_bind_neck.Inversed();
+	for (Float3 &v : vertices)
+		(inv_bind_neck * Vec3(v)).StoreFloat3(&v);
+	for (Mat44 &m : inv_bind_pose)
+		m = m * bind_neck;
+
+	// Read hair file
+	Array<uint8> data = ReadData(groom_file.c_str());
+	if (data[0] != 'H' || data[1] != 'A' || data[2] != 'I' || data[3] != 'R')
+		FatalError("Invalid hair file");
+
+	uint32 features = *reinterpret_cast<const uint32 *>(&data[12]);
+	if ((features & 0b10) != 0b10)
+		FatalError("We require points to be defined");
+
+	uint32 num_strands = *reinterpret_cast<const uint32 *>(&data[4]);
+	uint32 num_points = *reinterpret_cast<const uint32 *>(&data[8]);
+
+	const uint16 *num_segments = nullptr;
+	int num_segments_delta = 0;
+	const Float3 *points = nullptr;
+	if (features & 0b01)
+	{
+		// Num segments differs per strand
+		num_segments = reinterpret_cast<const uint16 *>(&data[128]);
+		num_segments_delta = 1;
+		points = reinterpret_cast<const Float3 *>(&data[128 + num_strands * sizeof(uint16)]);
+	}
+	else
+	{
+		// Num segments is constant
+		num_segments = reinterpret_cast<const uint16 *>(&data[16]);
+		num_segments_delta = 0;
+		points = reinterpret_cast<const Float3 *>(&data[128]);
+	}
+
+	// Init strands
+	if (sLimitMaxStrands)
+		num_strands = std::min(num_strands, sMaxStrands);
+	Array<HairSettings::SVertex> hair_vertices;
+	hair_vertices.resize(num_points);
+	Array<HairSettings::SStrand> hair_strands;
+	hair_strands.reserve(num_strands);
+	const Mat44 &neck_transform = mFaceAnimation[0][mHeadJointIdx];
+	Mat44 inv_neck_transform = neck_transform.Inversed();
+	for (uint32 strand = 0; strand < num_strands; ++strand)
+	{
+		// Transform relative to neck
+		Array<Vec3> out_points;
+		for (uint16 point = 0; point < *num_segments + 1; ++point)
+			out_points.push_back(sSelectedGroom->mVertexTransform(inv_neck_transform, Vec3(points[point])));
+
+		// Attach the first vertex to the skull collision
+		if (sSelectedGroom->mAttachToHull)
+		{
+			const float cMaxDist = 10.0f;
+			Vec3 direction = cMaxDist * (out_points[0] - out_points[1]).NormalizedOr(-Vec3::sAxisY());
+			Vec3 origin = out_points[0] - 0.5f * direction;
+			RRayCast ray(RVec3(neck_transform * origin), neck_transform.Multiply3x3(direction));
+			RayCastResult hit;
+			if (mPhysicsSystem->GetNarrowPhaseQuery().CastRay(ray, hit))
+			{
+				Vec3 delta = origin + hit.mFraction * direction - out_points[0];
+				for (Vec3 &v : out_points)
+					v += delta;
+			}
+		}
+
+		// Add the strand to the hair settings
+		uint32 first_point = uint32(hair_vertices.size());
+		for (uint32 point = 0; point < uint32(out_points.size()); ++point)
+		{
+			HairSettings::SVertex v;
+			out_points[point].StoreFloat3(&v.mPosition);
+			v.mInvMass = point == 0? 0.0f : 1.0f;
+			hair_vertices.push_back(v);
+		}
+		hair_strands.push_back(HairSettings::SStrand(first_point, uint32(hair_vertices.size()), 0));
+
+		points += *num_segments + 1;
+		num_segments += num_segments_delta;
+	}
+
+	// Resample if requested
+	if (sOverrideVerticesPerStrand > 1)
+		HairSettings::sResample(hair_vertices, hair_strands, sOverrideVerticesPerStrand);
+
+	// Load shaders
+	mHairShaders.Init(mComputeSystem);
+
+	// Init hair settings
+	mHairSettings = new HairSettings;
+	mHairSettings->mScalpVertices = std::move(vertices);
+	mHairSettings->mScalpTriangles = std::move(indices);
+	mHairSettings->mScalpInverseBindPose = std::move(inv_bind_pose);
+	mHairSettings->mScalpSkinWeights = std::move(skin_weights);
+	mHairSettings->mScalpNumSkinWeightsPerVertex = num_skin_weights_per_vertex;
+	mHairSettings->mNumIterationsPerSecond = sNumSolverIterationsPerSecond;
+	HairSettings::Material m;
+	m.mEnableCollision = sEnableCollision;
+	m.mEnableLRA = sEnableLRA;
+	m.mLinearDamping = sLinearDamping;
+	m.mAngularDamping = sAngularDamping;
+	m.mFriction = sFriction;
+	m.mMaxLinearVelocity = sMaxLinearVelocity;
+	m.mMaxAngularVelocity = sMaxAngularVelocity;
+	m.mGravityFactor = sGravityFactor;
+	m.mGravityPreloadFactor = sGravityPreloadFactor;
+	m.mBendCompliance = std::pow(10.0f, sBendComplianceExponent);
+	m.mStretchCompliance = std::pow(10.0f, sStretchComplianceExponent);
+	m.mInertiaMultiplier = sInertiaMultiplier;
+	m.mHairRadius = sHairRadius;
+	m.mWorldTransformInfluence = sWorldTransformInfluence;
+	m.mGridVelocityFactor = sGridVelocityFactor;
+	m.mGridDensityForceFactor = sGridDensityForceFactor;
+	m.mGlobalPose = sGlobalPose;
+	m.mSkinGlobalPose = sSkinGlobalPose;
+	m.mMaxLinearVelocity = 10.0f;
+	m.mSimulationStrandsFraction = 0.01f * sSimulationStrandsPercentage;
+	mHairSettings->mMaterials.push_back(m);
+	mHairSettings->mSimulationBoundsPadding = Vec3::sReplicate(0.1f);
+	mHairSettings->mInitialGravity = inv_bind_neck.Multiply3x3(mPhysicsSystem->GetGravity());
+	mHairSettings->InitRenderAndSimulationStrands(hair_vertices, hair_strands);
+	float max_dist_sq = 0.0f;
+	mHairSettings->Init(max_dist_sq);
+	JPH_ASSERT(max_dist_sq < 1.0e-4f);
+
+	// Write and read back to test SaveBinaryState
+	stringstream stream_data;
+	{
+		StreamOutWrapper stream_out(stream_data);
+		mHairSettings->SaveBinaryState(stream_out);
+	}
+	mHairSettings = new HairSettings;
+	{
+		StreamInWrapper stream_in(stream_data);
+		mHairSettings->RestoreBinaryState(stream_in);
+	}
+	mHairSettings->InitCompute(mComputeSystem);
+
+	mHair = new Hair(mHairSettings, RVec3(neck_transform.GetTranslation()), neck_transform.GetQuaternion(), Layers::MOVING);
+	mHair->Init(mComputeSystem);
+	mHair->Update(0.0f, inv_neck_transform, mFaceAnimation[0].data(), *mPhysicsSystem, mHairShaders, mComputeSystem, mComputeQueue);
+	mHair->ReadBackGPUState(mComputeQueue);
+
+#ifdef JPH_DEBUG_RENDERER
+	// Update drawing range
+	sDrawSimulationStrandCount = (uint)mHairSettings->mSimStrands.size();
+#endif // JPH_DEBUG_RENDERER
+}
+
+void HairTest::PrePhysicsUpdate(const PreUpdateParams &inParams)
+{
+	BodyInterface &bi = mPhysicsSystem->GetBodyInterfaceNoLock();
+
+#ifdef JPH_DEBUG_RENDERER
+	Hair::DrawSettings settings;
+	settings.mSimulationStrandBegin = sDrawSimulationStrandBegin;
+	settings.mSimulationStrandEnd = sDrawSimulationStrandBegin + sDrawSimulationStrandCount;
+	settings.mDrawRods = sDrawRods;
+	settings.mDrawUnloadedRods = sDrawUnloadedRods;
+	settings.mDrawRenderStrands = sDrawRenderStrands;
+	settings.mRenderStrandColor = sRenderStrandColor;
+	settings.mDrawVertexVelocity = sDrawVertexVelocity;
+	settings.mDrawAngularVelocity = sDrawAngularVelocity;
+	settings.mDrawOrientations = sDrawOrientations;
+	settings.mDrawGridVelocity = sDrawGridVelocity;
+	settings.mDrawGridDensity = sDrawGridDensity;
+	settings.mDrawSkinPoints = sDrawSkinPoints;
+	settings.mDrawNeutralDensity = sDrawNeutralDensity;
+	settings.mDrawInitialGravity = sDrawInitialGravity;
+	mHair->Draw(settings, mDebugRenderer);
+#else
+	// Draw the rods
+	mHair->LockReadBackBuffers();
+	const Float3 *positions = mHair->GetRenderPositions();
+	RMat44 com = mHair->GetWorldTransform();
+	if (sDrawRenderStrands)
+	{
+		JPH_PROFILE("Draw Render Strands");
+
+		Color color = Color::sWhite;
+		Hash<uint32> hasher;
+		for (const HairSettings::RStrand &render_strand : mHairSettings->mRenderStrands)
+		{
+			RVec3 x0 = com * Vec3(positions[render_strand.mStartVtx]);
+			for (uint32 v = render_strand.mStartVtx + 1; v < render_strand.mEndVtx; ++v)
+			{
+				RVec3 x1 = com * Vec3(positions[v]);
+				mDebugRenderer->DrawLine(x0, x1, color);
+				x0 = x1;
+			}
+			color = Color(uint32(hasher(color.GetUInt32())) | 0xff000000);
+		}
+	}
+	mHair->UnlockReadBackBuffers();
+#endif // JPH_DEBUG_RENDERER
+
+	// Get skinned vertices
+	RMat44 neck_transform = mHair->GetWorldTransform();
+
+	if (sDrawHeadMesh)
+	{
+		JPH_PROFILE("Draw Head Mesh");
+
+		const Float3 *scalp_vertices = mHair->GetScalpVertices();
+		Ref<DebugRenderer::Geometry> geometry = new DebugRenderer::Geometry(mDebugRenderer->CreateTriangleBatch(scalp_vertices, (uint)mHairSettings->mScalpVertices.size(), mHairSettings->mScalpTriangles.data(), (uint)mHairSettings->mScalpTriangles.size()), mHairSettings->mSimulationBounds);
+		mDebugRenderer->DrawGeometry(neck_transform, Color::sGrey, geometry, DebugRenderer::ECullMode::CullBackFace, DebugRenderer::ECastShadow::On, DebugRenderer::EDrawMode::Solid);
+	}
+
+	// Select the next animation frame
+	++mFrame;
+	mFrame = mFrame % (uint32)mFaceAnimation.size();
+	Array<Mat44> &joints = mFaceAnimation[mFrame];
+
+	// Position the collision hulls
+	for (const AttachedBody &ab : mAttachedBodies)
+	{
+		Mat44 body_transform = ab.mJointIdx != 0xffffffff? joints[ab.mJointIdx] : Mat44::sIdentity();
+		bi.MoveKinematic(ab.mBodyID, RVec3(body_transform.GetTranslation()), body_transform.GetQuaternion(), inParams.mDeltaTime);
+	}
+
+	// Set the new rotation of the hair
+	RVec3 position = RVec3(joints[mHeadJointIdx].GetTranslation());
+	Quat rotation = joints[mHeadJointIdx].GetQuaternion();
+	mHair->SetPosition(position);
+	mHair->SetRotation(rotation);
+
+	// Update the hair
+	mHair->Update(inParams.mDeltaTime, joints[mHeadJointIdx].Inversed(), joints.data(), *mPhysicsSystem, mHairShaders, mComputeSystem, mComputeQueue);
+	{
+		JPH_PROFILE("Hair Compute");
+		mComputeQueue->ExecuteAndWait();
+	}
+	mHair->ReadBackGPUState(mComputeQueue);
+}
+
+void HairTest::SaveState(StateRecorder &inStream) const
+{
+	inStream.Write(mFrame);
+}
+
+void HairTest::RestoreState(StateRecorder &inStream)
+{
+	inStream.Read(mFrame);
+}
+
+void HairTest::GradientSetting(DebugUI *inUI, UIElement *inSubMenu, const String &inName, float inMax, float inStep, HairSettings::Gradient &inStaticStorage, HairSettings::Gradient &inDynamicStorage)
+{
+	inUI->CreateTextButton(inSubMenu, inName, [inUI, inName, inMax, inStep, &inStaticStorage, &inDynamicStorage]() {
+		UIElement *gradient_setting = inUI->CreateMenu();
+		inUI->CreateSlider(gradient_setting, inName + " Min", inStaticStorage.mMin, 0.0f, inMax, inStep, [&inStaticStorage, &inDynamicStorage](float inValue) { inStaticStorage.mMin = inDynamicStorage.mMin = inValue; });
+		inUI->CreateSlider(gradient_setting, inName + " Max", inStaticStorage.mMax, 0.0f, inMax, inStep, [&inStaticStorage, &inDynamicStorage](float inValue) { inStaticStorage.mMax = inDynamicStorage.mMax = inValue; });
+		inUI->CreateSlider(gradient_setting, inName + " Min Fraction", inStaticStorage.mMinFraction, 0.0f, 1.0f, 0.01f, [&inStaticStorage, &inDynamicStorage](float inValue) { inStaticStorage.mMinFraction = inDynamicStorage.mMinFraction = min(inStaticStorage.mMaxFraction - 0.001f, inValue); });
+		inUI->CreateSlider(gradient_setting, inName + " Max Fraction", inStaticStorage.mMaxFraction, 0.0f, 1.0f, 0.01f, [&inStaticStorage, &inDynamicStorage](float inValue) { inStaticStorage.mMaxFraction = inDynamicStorage.mMaxFraction = max(inStaticStorage.mMinFraction + 0.001f, inValue); });
+		inUI->ShowMenu(gradient_setting);
+	});
+}
+
+void HairTest::CreateSettingsMenu(DebugUI *inUI, UIElement *inSubMenu)
+{
+	inUI->CreateTextButton(inSubMenu, "Select Groom", [this, inUI]() {
+		UIElement *groom_name = inUI->CreateMenu();
+		for (uint i = 0; i < size(sGrooms); ++i)
+			inUI->CreateTextButton(groom_name, sGrooms[i].mName, [this, i]() { sSelectedGroom = &sGrooms[i]; RestartTest(); });
+		inUI->ShowMenu(groom_name);
+	});
+	inUI->CreateCheckBox(inSubMenu, "Limit Max Strands", sLimitMaxStrands, [](UICheckBox::EState inState) { sLimitMaxStrands = inState == UICheckBox::STATE_CHECKED; });
+	inUI->CreateSlider(inSubMenu, "Max Strands", float(sMaxStrands), 1.0f, 10000.0f, 1.0f, [](float inValue) { sMaxStrands = uint(inValue); });
+	inUI->CreateSlider(inSubMenu, "Simulation Strands Percentage", float(sSimulationStrandsPercentage), 1.0f, 100.0f, 1.0f, [](float inValue) { sSimulationStrandsPercentage = inValue; });
+	inUI->CreateSlider(inSubMenu, "Override Vertices Per Strand", float(sOverrideVerticesPerStrand), 1.0f, 64.0f, 1.0f, [](float inValue) { sOverrideVerticesPerStrand = uint(inValue); });
+	inUI->CreateSlider(inSubMenu, "Num Solver Iterations Per Second", float(sNumSolverIterationsPerSecond), 1.0f, 960.0f, 1.0f, [settings = mHairSettings](float inValue) { sNumSolverIterationsPerSecond = uint(inValue); settings->mNumIterationsPerSecond = sNumSolverIterationsPerSecond; });
+	GradientSetting(inUI, inSubMenu, "Hair Radius", 0.01f, 0.001f, sHairRadius, mHairSettings->mMaterials[0].mHairRadius);
+	inUI->CreateCheckBox(inSubMenu, "Enable Collision", sEnableCollision, [settings = mHairSettings](UICheckBox::EState inState) { sEnableCollision = inState == UICheckBox::STATE_CHECKED; settings->mMaterials[0].mEnableCollision = sEnableCollision; });
+	inUI->CreateCheckBox(inSubMenu, "Enable LRA", sEnableLRA, [settings = mHairSettings](UICheckBox::EState inState) { sEnableLRA = inState == UICheckBox::STATE_CHECKED; settings->mMaterials[0].mEnableLRA = sEnableLRA; });
+	inUI->CreateSlider(inSubMenu, "Bend Compliance (10^x)", sBendComplianceExponent, -10.0f, 0.0f, 0.01f, [settings = mHairSettings](float inValue) { sBendComplianceExponent = inValue; settings->mMaterials[0].mBendCompliance = std::pow(10.0f, inValue); });
+	inUI->CreateSlider(inSubMenu, "Stretch Compliance (10^x)", sStretchComplianceExponent, -10.0f, 0.0f, 0.01f, [settings = mHairSettings](float inValue) { sStretchComplianceExponent = inValue; settings->mMaterials[0].mStretchCompliance = std::pow(10.0f, inValue); });
+	inUI->CreateSlider(inSubMenu, "Inertia Multiplier", sInertiaMultiplier, 1.0f, 100.0f, 0.1f, [settings = mHairSettings](float inValue) { sInertiaMultiplier = inValue; settings->mMaterials[0].mInertiaMultiplier = inValue; });
+	inUI->CreateSlider(inSubMenu, "Linear Damping", sLinearDamping, 0.0f, 5.0f, 0.01f, [settings = mHairSettings](float inValue) { sLinearDamping = inValue; settings->mMaterials[0].mLinearDamping = inValue; });
+	inUI->CreateSlider(inSubMenu, "Angular Damping", sAngularDamping, 0.0f, 5.0f, 0.01f, [settings = mHairSettings](float inValue) { sAngularDamping = inValue; settings->mMaterials[0].mAngularDamping = inValue; });
+	inUI->CreateSlider(inSubMenu, "Friction", sFriction, 0.0f, 1.0f, 0.01f, [settings = mHairSettings](float inValue) { sFriction = inValue; settings->mMaterials[0].mFriction = inValue; });
+	inUI->CreateSlider(inSubMenu, "Max Linear Velocity", sMaxLinearVelocity, 0.01f, 10.0f, 0.01f, [settings = mHairSettings](float inValue) { sMaxLinearVelocity = inValue; settings->mMaterials[0].mMaxLinearVelocity = inValue; });
+	inUI->CreateSlider(inSubMenu, "Max Angular Velocity", sMaxAngularVelocity, 0.01f, 50.0f, 0.01f, [settings = mHairSettings](float inValue) { sMaxAngularVelocity = inValue; settings->mMaterials[0].mMaxAngularVelocity = inValue; });
+	GradientSetting(inUI, inSubMenu, "World Transform Influence", 1.0f, 0.01f, sWorldTransformInfluence, mHairSettings->mMaterials[0].mWorldTransformInfluence);
+	GradientSetting(inUI, inSubMenu, "Gravity Factor", 1.0f, 0.01f, sGravityFactor, mHairSettings->mMaterials[0].mGravityFactor);
+	inUI->CreateSlider(inSubMenu, "Gravity Preload Factor", sGravityPreloadFactor, 0.0f, 1.0f, 0.01f, [settings = mHairSettings](float inValue) { sGravityPreloadFactor = inValue; });
+	GradientSetting(inUI, inSubMenu, "Grid Velocity Factor", 1.0f, 0.01f, sGridVelocityFactor, mHairSettings->mMaterials[0].mGridVelocityFactor);
+	inUI->CreateSlider(inSubMenu, "Grid Density Force Factor", sGridDensityForceFactor, 0.0f, 10.0f, 0.1f, [settings = mHairSettings](float inValue) { sGridDensityForceFactor = inValue; settings->mMaterials[0].mGridDensityForceFactor = sGridDensityForceFactor; });
+	GradientSetting(inUI, inSubMenu, "Global Pose", 1.0f, 0.001f, sGlobalPose, mHairSettings->mMaterials[0].mGlobalPose);
+	GradientSetting(inUI, inSubMenu, "Skin Global Pose", 1.0f, 0.001f, sSkinGlobalPose, mHairSettings->mMaterials[0].mSkinGlobalPose);
+#ifdef JPH_DEBUG_RENDERER
+	if (mHairSettings->mSimStrands.size() > 1)
+	{
+		inUI->CreateSlider(inSubMenu, "Draw Simulation Strand Begin", (float)sDrawSimulationStrandBegin, 0.0f, float(mHairSettings->mSimStrands.size() - 1), 1.0f, [](float inValue) { sDrawSimulationStrandBegin = (uint)inValue; });
+		inUI->CreateSlider(inSubMenu, "Draw Simulation Strand Count", (float)sDrawSimulationStrandCount, 1.0f, (float)mHairSettings->mSimStrands.size(), 1.0f, [](float inValue) { sDrawSimulationStrandCount = (uint)inValue; });
+	}
+	inUI->CreateCheckBox(inSubMenu, "Draw Rods", sDrawRods, [](UICheckBox::EState inState) { sDrawRods = inState == UICheckBox::STATE_CHECKED; });
+	inUI->CreateCheckBox(inSubMenu, "Draw Unloaded Rods", sDrawUnloadedRods, [](UICheckBox::EState inState) { sDrawUnloadedRods = inState == UICheckBox::STATE_CHECKED; });
+	inUI->CreateCheckBox(inSubMenu, "Draw Vertex Velocity", sDrawVertexVelocity, [](UICheckBox::EState inState) { sDrawVertexVelocity = inState == UICheckBox::STATE_CHECKED; });
+	inUI->CreateCheckBox(inSubMenu, "Draw Angular Velocity", sDrawAngularVelocity, [](UICheckBox::EState inState) { sDrawAngularVelocity = inState == UICheckBox::STATE_CHECKED; });
+	inUI->CreateCheckBox(inSubMenu, "Draw Rod Orientations", sDrawOrientations, [](UICheckBox::EState inState) { sDrawOrientations = inState == UICheckBox::STATE_CHECKED; });
+	inUI->CreateCheckBox(inSubMenu, "Draw Neutral Density", sDrawNeutralDensity, [](UICheckBox::EState inState) { sDrawNeutralDensity = inState == UICheckBox::STATE_CHECKED; });
+	inUI->CreateCheckBox(inSubMenu, "Draw Grid Density", sDrawGridDensity, [](UICheckBox::EState inState) { sDrawGridDensity = inState == UICheckBox::STATE_CHECKED; });
+	inUI->CreateCheckBox(inSubMenu, "Draw Grid Velocity", sDrawGridVelocity, [](UICheckBox::EState inState) { sDrawGridVelocity = inState == UICheckBox::STATE_CHECKED; });
+	inUI->CreateCheckBox(inSubMenu, "Draw Skin Points", sDrawSkinPoints, [](UICheckBox::EState inState) { sDrawSkinPoints = inState == UICheckBox::STATE_CHECKED; });
+	inUI->CreateCheckBox(inSubMenu, "Draw Render Strands", sDrawRenderStrands, [](UICheckBox::EState inState) { sDrawRenderStrands = inState == UICheckBox::STATE_CHECKED; });
+	inUI->CreateComboBox(inSubMenu, "Render Strands Color", { "PerRenderStrand", "PerSimulatedStrand", "GravityFactor", "WorldInfluence", "GridVelocityFactor", "GlobalPose", "SkinGlobalPose" }, (int)sRenderStrandColor, [](int inItem) { sRenderStrandColor = (Hair::ERenderStrandColor)inItem; });
+	inUI->CreateCheckBox(inSubMenu, "Draw Initial Gravity", sDrawInitialGravity, [](UICheckBox::EState inState) { sDrawInitialGravity = inState == UICheckBox::STATE_CHECKED; });
+#endif // JPH_DEBUG_RENDERER
+	inUI->CreateCheckBox(inSubMenu, "Draw Head Mesh", sDrawHeadMesh, [](UICheckBox::EState inState) { sDrawHeadMesh = inState == UICheckBox::STATE_CHECKED; });
+}

+ 105 - 0
Samples/Tests/Hair/HairTest.h

@@ -0,0 +1,105 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Tests/Test.h>
+#include <Jolt/Physics/Hair/Hair.h>
+#include <Jolt/Physics/Hair/HairShaders.h>
+
+class HairTest : public Test
+{
+public:
+	JPH_DECLARE_RTTI_VIRTUAL(JPH_NO_EXPORT, HairTest)
+
+	// Destructor
+	virtual					~HairTest() override										{ delete mHair; }
+
+	// Description of the test
+	virtual const char *	GetDescription() const override
+	{
+		return "Hair demo.";
+	}
+
+	// See: Test
+	virtual void			Initialize() override;
+	virtual void			PrePhysicsUpdate(const PreUpdateParams &inParams) override;
+	virtual void			SaveState(StateRecorder &inStream) const override;
+	virtual void			RestoreState(StateRecorder &inStream) override;
+
+	// Number used to scale the terrain and camera movement to the scene
+	virtual float			GetWorldScale() const override								{ return 0.01f; }
+
+	// Optional settings menu
+	virtual bool			HasSettingsMenu() const override							{ return true; }
+	virtual void			CreateSettingsMenu(DebugUI *inUI, UIElement *inSubMenu) override;
+
+private:
+	using Gradient = HairSettings::Gradient;
+
+	void					GradientSetting(DebugUI *inUI, UIElement *inSubMenu, const String &inName, float inMax, float inStep, Gradient &inStaticStorage, Gradient &inDynamicStorage);
+
+	struct Groom
+	{
+		const char *		mName;
+		std::function<Vec3(Mat44Arg, Vec3Arg)> mVertexTransform;
+		bool				mAttachToHull;
+	};
+
+	static const Groom		sGrooms[];
+	static const Groom *	sSelectedGroom;
+	inline static bool		sLimitMaxStrands = true;
+	inline static uint		sMaxStrands = JPH_IF_DEBUG(500u) JPH_IF_NOT_DEBUG(25000u);
+	inline static float		sSimulationStrandsPercentage = 10.0f;
+	inline static uint		sOverrideVerticesPerStrand = 32;
+	inline static uint		sNumSolverIterationsPerSecond = HairSettings::cDefaultIterationsPerSecond;
+	inline static bool		sEnableCollision = true;
+	inline static bool		sEnableLRA = true;
+	inline static float		sLinearDamping = 2.0f;
+	inline static float		sAngularDamping = 2.0f;
+	inline static float		sFriction = 0.2f;
+	inline static float		sMaxLinearVelocity = 10.0f;
+	inline static float		sMaxAngularVelocity = 50.0f;
+	inline static float		sBendComplianceExponent = -7.0f; // 1.0e-7f
+	inline static float		sStretchComplianceExponent = -8.0f; // 1.0e-8f
+	inline static float		sInertiaMultiplier = 10.0f;
+	inline static Gradient	sHairRadius { 0.001f, 0.001f };
+	inline static Gradient	sWorldTransformInfluence { 0.0f, 1.0f };
+	inline static Gradient	sGravityFactor { 0.1f, 1.0f, 0.2f, 0.8f };
+	inline static float		sGravityPreloadFactor = 1.0f;
+	inline static Gradient	sGridVelocityFactor { 0.05f, 0.01f };
+	inline static Gradient	sGlobalPose { 0.01f, 0, 0.0f, 0.3f };
+	inline static Gradient	sSkinGlobalPose { 1.0f, 0.0f, 0.0f, 0.1f };
+	inline static float		sGridDensityForceFactor = 0.0f;
+#ifdef JPH_DEBUG_RENDERER
+	inline static uint		sDrawSimulationStrandBegin = 0;
+	inline static uint		sDrawSimulationStrandCount = UINT_MAX;
+	inline static bool		sDrawRods = false;
+	inline static bool		sDrawUnloadedRods = false;
+	inline static bool		sDrawVertexVelocity = false;
+	inline static bool		sDrawAngularVelocity = false;
+	inline static bool		sDrawOrientations = false;
+	inline static bool		sDrawNeutralDensity = false;
+	inline static bool		sDrawGridDensity = false;
+	inline static bool		sDrawGridVelocity = false;
+	inline static bool		sDrawSkinPoints = false;
+	inline static Hair::ERenderStrandColor sRenderStrandColor = Hair::ERenderStrandColor::PerSimulatedStrand;
+	inline static bool		sDrawInitialGravity = false;
+#endif // JPH_DEBUG_RENDERER
+	inline static bool		sDrawRenderStrands = true;
+	inline static bool		sDrawHeadMesh = true;
+
+	uint32					mHeadJointIdx = 0;
+	Array<Array<Mat44>>		mFaceAnimation;
+	struct AttachedBody
+	{
+		uint32				mJointIdx;
+		BodyID				mBodyID;
+	};
+	Array<AttachedBody>		mAttachedBodies;
+	Ref<HairSettings>		mHairSettings = nullptr;
+	HairShaders				mHairShaders;
+	Hair *					mHair = nullptr;
+	uint					mFrame = 0;
+};