Przeglądaj źródła

Allow Wave intrinsics in DXR shader stages. (#2742)

Tex Riddell 5 lat temu
rodzic
commit
cd48b34db0

+ 2 - 2
lib/DXIL/DxilOperations.cpp

@@ -656,7 +656,7 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
   // WaveReadLaneFirst=118, WaveActiveOp=119, WaveActiveBit=120,
   // WavePrefixOp=121, WaveAllBitCount=135, WavePrefixBitCount=136
   if ((110 <= op && op <= 121) || (135 <= op && op <= 136)) {
-    mask = SFLAG(Library) | SFLAG(Compute) | SFLAG(Amplification) | SFLAG(Mesh) | SFLAG(Pixel) | SFLAG(Vertex) | SFLAG(Hull) | SFLAG(Domain) | SFLAG(Geometry);
+    mask = SFLAG(Library) | SFLAG(Compute) | SFLAG(Amplification) | SFLAG(Mesh) | SFLAG(Pixel) | SFLAG(Vertex) | SFLAG(Hull) | SFLAG(Domain) | SFLAG(Geometry) | SFLAG(RayGeneration) | SFLAG(Intersection) | SFLAG(AnyHit) | SFLAG(ClosestHit) | SFLAG(Miss) | SFLAG(Callable);
     return;
   }
   // Instructions: Sample=60, SampleBias=61, SampleCmp=64, CalculateLOD=81,
@@ -792,7 +792,7 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
   // WaveMultiPrefixBitCount=167
   if ((165 <= op && op <= 167)) {
     major = 6;  minor = 5;
-    mask = SFLAG(Library) | SFLAG(Compute) | SFLAG(Amplification) | SFLAG(Mesh) | SFLAG(Pixel) | SFLAG(Vertex) | SFLAG(Hull) | SFLAG(Domain) | SFLAG(Geometry);
+    mask = SFLAG(Library) | SFLAG(Compute) | SFLAG(Amplification) | SFLAG(Mesh) | SFLAG(Pixel) | SFLAG(Vertex) | SFLAG(Hull) | SFLAG(Domain) | SFLAG(Geometry) | SFLAG(RayGeneration) | SFLAG(Intersection) | SFLAG(AnyHit) | SFLAG(ClosestHit) | SFLAG(Miss) | SFLAG(Callable);
     return;
   }
   // Instructions: GeometryIndex=213

+ 2 - 2
lib/HLSL/DxilValidation.cpp

@@ -887,7 +887,7 @@ static bool ValidateOpcodeInProfile(DXIL::OpCode opcode,
   // WaveReadLaneFirst=118, WaveActiveOp=119, WaveActiveBit=120,
   // WavePrefixOp=121, WaveAllBitCount=135, WavePrefixBitCount=136
   if ((110 <= op && op <= 121) || (135 <= op && op <= 136))
-    return (SK == DXIL::ShaderKind::Library || SK == DXIL::ShaderKind::Compute || SK == DXIL::ShaderKind::Amplification || SK == DXIL::ShaderKind::Mesh || SK == DXIL::ShaderKind::Pixel || SK == DXIL::ShaderKind::Vertex || SK == DXIL::ShaderKind::Hull || SK == DXIL::ShaderKind::Domain || SK == DXIL::ShaderKind::Geometry);
+    return (SK == DXIL::ShaderKind::Library || SK == DXIL::ShaderKind::Compute || SK == DXIL::ShaderKind::Amplification || SK == DXIL::ShaderKind::Mesh || SK == DXIL::ShaderKind::Pixel || SK == DXIL::ShaderKind::Vertex || SK == DXIL::ShaderKind::Hull || SK == DXIL::ShaderKind::Domain || SK == DXIL::ShaderKind::Geometry || SK == DXIL::ShaderKind::RayGeneration || SK == DXIL::ShaderKind::Intersection || SK == DXIL::ShaderKind::AnyHit || SK == DXIL::ShaderKind::ClosestHit || SK == DXIL::ShaderKind::Miss || SK == DXIL::ShaderKind::Callable);
   // Instructions: Sample=60, SampleBias=61, SampleCmp=64, CalculateLOD=81,
   // DerivCoarseX=83, DerivCoarseY=84, DerivFineX=85, DerivFineY=86
   if ((60 <= op && op <= 61) || op == 64 || op == 81 || (83 <= op && op <= 86))
@@ -981,7 +981,7 @@ static bool ValidateOpcodeInProfile(DXIL::OpCode opcode,
   // WaveMultiPrefixBitCount=167
   if ((165 <= op && op <= 167))
     return (major > 6 || (major == 6 && minor >= 5))
-        && (SK == DXIL::ShaderKind::Library || SK == DXIL::ShaderKind::Compute || SK == DXIL::ShaderKind::Amplification || SK == DXIL::ShaderKind::Mesh || SK == DXIL::ShaderKind::Pixel || SK == DXIL::ShaderKind::Vertex || SK == DXIL::ShaderKind::Hull || SK == DXIL::ShaderKind::Domain || SK == DXIL::ShaderKind::Geometry);
+        && (SK == DXIL::ShaderKind::Library || SK == DXIL::ShaderKind::Compute || SK == DXIL::ShaderKind::Amplification || SK == DXIL::ShaderKind::Mesh || SK == DXIL::ShaderKind::Pixel || SK == DXIL::ShaderKind::Vertex || SK == DXIL::ShaderKind::Hull || SK == DXIL::ShaderKind::Domain || SK == DXIL::ShaderKind::Geometry || SK == DXIL::ShaderKind::RayGeneration || SK == DXIL::ShaderKind::Intersection || SK == DXIL::ShaderKind::AnyHit || SK == DXIL::ShaderKind::ClosestHit || SK == DXIL::ShaderKind::Miss || SK == DXIL::ShaderKind::Callable);
   // Instructions: GeometryIndex=213
   if (op == 213)
     return (major > 6 || (major == 6 && minor >= 5))

+ 23 - 0
tools/clang/test/HLSLFileCheck/shader_targets/raytracing/raytracing_anyhit_wave.hlsl

@@ -0,0 +1,23 @@
+// RUN: %dxc -T lib_6_3 -auto-binding-space 11 %s | FileCheck %s
+
+// CHECK: call i1 @dx.op.waveActiveAllEqual.i32(i32 115,
+
+struct MyPayload {
+  float4 color;
+  uint2 pos;
+};
+
+struct MyAttributes {
+  float2 bary;
+  uint id;
+};
+
+[shader("anyhit")]
+void anyhit1( inout MyPayload payload : SV_RayPayload,
+              in MyAttributes attr : SV_IntersectionAttributes )
+{
+  if (WaveActiveAllEqual(attr.id)) {
+    AcceptHitAndEndSearch();
+  }
+  payload.color += float4(0.125, 0.25, 0.5, 1.0);
+}

+ 25 - 0
tools/clang/test/HLSLFileCheck/shader_targets/raytracing/raytracing_callable_wave.hlsl

@@ -0,0 +1,25 @@
+// RUN: %dxc -T lib_6_3 -auto-binding-space 11 %s | FileCheck %s
+
+// CHECK:   %[[_7_:[0-9]+]] = call %dx.types.ResRet.f32 @dx.op.sampleLevel.f32(i32 62,
+// CHECK:   %[[_8_:[0-9]+]] = extractvalue %dx.types.ResRet.f32 %[[_7_]], 0
+// CHECK:   %[[_9_:[0-9]+]] = extractvalue %dx.types.ResRet.f32 %[[_7_]], 1
+// CHECK:   %[[_10_:[0-9]+]] = extractvalue %dx.types.ResRet.f32 %[[_7_]], 2
+// CHECK:   %[[_11_:[0-9]+]] = extractvalue %dx.types.ResRet.f32 %[[_7_]], 3
+// CHECK:   call float @dx.op.wavePrefixOp.f32(i32 121, float %[[_8_]], i8 1, i8 0)
+// CHECK:   call float @dx.op.wavePrefixOp.f32(i32 121, float %[[_9_]], i8 1, i8 0)
+// CHECK:   call float @dx.op.wavePrefixOp.f32(i32 121, float %[[_10_]], i8 1, i8 0)
+// CHECK:   call float @dx.op.wavePrefixOp.f32(i32 121, float %[[_11_]], i8 1, i8 0)
+
+struct MyParam {
+  float2 coord;
+  float4 output;
+};
+
+Texture2D T : register(t1);
+SamplerState S : register(s1);
+
+[shader("callable")]
+void callable1(inout MyParam param)
+{
+  param.output = WavePrefixProduct(T.SampleLevel(S, param.coord, 0));
+}

+ 21 - 0
tools/clang/test/HLSLFileCheck/shader_targets/raytracing/raytracing_closesthit_wave.hlsl

@@ -0,0 +1,21 @@
+// RUN: %dxc -T lib_6_3 -auto-binding-space 11 %s | FileCheck %s
+
+// CHECK: call float @dx.op.waveActiveOp.f32(i32 119, float %{{.*}}, i8 2, i8 0)
+// CHECK: call float @dx.op.waveActiveOp.f32(i32 119, float %{{.*}}, i8 2, i8 0)
+
+struct MyPayload {
+  float4 color;
+  uint2 pos;
+};
+
+struct MyParam {
+  float2 coord;
+  float4 output;
+};
+
+[shader("closesthit")]
+void closesthit1( inout MyPayload payload : SV_RayPayload,
+                  in BuiltInTriangleIntersectionAttributes attr : SV_IntersectionAttributes )
+{
+  payload.color.xy += attr.barycentrics + WaveActiveMin(attr.barycentrics);
+}

+ 18 - 0
tools/clang/test/HLSLFileCheck/shader_targets/raytracing/raytracing_intersection_wave.hlsl

@@ -0,0 +1,18 @@
+// RUN: %dxc -T lib_6_3 -auto-binding-space 11 %s | FileCheck %s
+
+// CHECK:   %[[RayTCurrent:RayTCurrent|[0-9]+]] = call float @dx.op.rayTCurrent.f32(i32 154)
+// CHECK:   %[[WaveActiveOp:WaveActiveOp|[0-9]+]] = call float @dx.op.waveActiveOp.f32(i32 119, float %[[RayTCurrent]], i8 2, i8 0)
+// CHECK:   call i1 @dx.op.reportHit.struct.MyAttributes(i32 158, float %[[WaveActiveOp]], i32 0,
+
+struct MyAttributes {
+  float2 bary;
+  uint id;
+};
+
+[shader("intersection")]
+void intersection1()
+{
+  float hitT = RayTCurrent();
+  MyAttributes attr = (MyAttributes)0;
+  bool bReported = ReportHit(WaveActiveMin(hitT), 0, attr);
+}

+ 17 - 0
tools/clang/test/HLSLFileCheck/shader_targets/raytracing/raytracing_miss_wave.hlsl

@@ -0,0 +1,17 @@
+// RUN: %dxc -T lib_6_3 -auto-binding-space 11 %s | FileCheck %s
+
+// CHECK:   call float @dx.op.wavePrefixOp.f32(i32 121, float %{{.*}}, i8 1, i8 0)
+// CHECK:   call float @dx.op.wavePrefixOp.f32(i32 121, float %{{.*}}, i8 1, i8 0)
+// CHECK:   call float @dx.op.wavePrefixOp.f32(i32 121, float %{{.*}}, i8 1, i8 0)
+// CHECK:   call float @dx.op.wavePrefixOp.f32(i32 121, float %{{.*}}, i8 1, i8 0)
+
+struct MyPayload {
+  float4 color;
+  uint2 pos;
+};
+
+[shader("miss")]
+void miss1(inout MyPayload payload : SV_RayPayload)
+{
+  payload.color = WavePrefixProduct(payload.color);
+}

+ 30 - 0
tools/clang/test/HLSLFileCheck/shader_targets/raytracing/raytracing_raygeneration_wave.hlsl

@@ -0,0 +1,30 @@
+// RUN: %dxc -T lib_6_3 -auto-binding-space 11 %s | FileCheck %s
+
+// CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 8, i8 0, i8 0)
+
+struct MyPayload {
+  float4 color;
+  uint2 pos;
+};
+
+RaytracingAccelerationStructure RTAS : register(t5);
+
+RWByteAddressBuffer Log;
+
+[shader("raygeneration")]
+void raygen1()
+{
+  MyPayload p = (MyPayload)0;
+  p.pos = DispatchRaysIndex().xy;
+
+  uint offset = 0;
+  Log.InterlockedAdd(0, 8, offset);
+  offset += WaveActiveSum(8);
+  Log.Store(offset, p.pos.x);
+  Log.Store(offset + 4, p.pos.y);
+
+  float3 origin = {0, 0, 0};
+  float3 dir = normalize(float3(p.pos / (float)DispatchRaysDimensions(), 1));
+  RayDesc ray = { origin, 0.125, dir, 128.0};
+  TraceRay(RTAS, RAY_FLAG_NONE, 0, 0, 1, 0, ray, p);
+}

+ 4 - 1
utils/hct/hctdb.py

@@ -318,7 +318,10 @@ class db_dxil(object):
             if i.name.startswith("Wave"):
                 i.category = "Wave"
                 i.is_wave = True
-                i.shader_stages = ("library", "compute", "amplification", "mesh", "pixel", "vertex", "hull", "domain", "geometry")
+                i.shader_stages = (
+                    "library", "compute", "amplification", "mesh",
+                    "pixel", "vertex", "hull", "domain", "geometry",
+                    "raygeneration", "intersection", "anyhit", "closesthit", "miss", "callable")
             elif i.name.startswith("Quad"):
                 i.category = "Quad Wave Ops"
                 i.is_wave = True