Browse Source

Fix handling of boolean expressions in Wave Intrinsics (#2502)

Vishal Sharma 6 years ago
parent
commit
bd57c4413e

+ 5 - 5
tools/clang/lib/Sema/gen_intrin_main_tables_15.h

@@ -487,7 +487,7 @@ static const HLSL_INTRINSIC_ARGUMENT g_Intrinsics_Args67[] =
 static const HLSL_INTRINSIC_ARGUMENT g_Intrinsics_Args68[] =
 {
     {"WaveActiveAllEqual", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 0, LICOMPTYPE_BOOL, IA_R, IA_C},
-    {"value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, IA_R, IA_C},
+    {"value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_ANY, IA_R, IA_C},
 };
 
 static const HLSL_INTRINSIC_ARGUMENT g_Intrinsics_Args69[] =
@@ -639,15 +639,15 @@ static const HLSL_INTRINSIC_ARGUMENT g_Intrinsics_Args92[] =
 
 static const HLSL_INTRINSIC_ARGUMENT g_Intrinsics_Args93[] =
 {
-    {"WaveReadLaneAt", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, IA_R, IA_C},
-    {"value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, IA_R, IA_C},
+    {"WaveReadLaneAt", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_ANY, IA_R, IA_C},
+    {"value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_ANY, IA_R, IA_C},
     {"lane", AR_QUAL_IN, 2, LITEMPLATE_SCALAR, 2, LICOMPTYPE_UINT, 1, 1},
 };
 
 static const HLSL_INTRINSIC_ARGUMENT g_Intrinsics_Args94[] =
 {
-    {"WaveReadLaneFirst", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, IA_R, IA_C},
-    {"value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, IA_R, IA_C},
+    {"WaveReadLaneFirst", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_ANY, IA_R, IA_C},
+    {"value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_ANY, IA_R, IA_C},
 };
 
 static const HLSL_INTRINSIC_ARGUMENT g_Intrinsics_Args95[] =

+ 59 - 0
tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/reduction/WaveReductionOps.hlsl

@@ -0,0 +1,59 @@
+// RUN: %dxc -E test_wavereadlaneat -T vs_6_0 /DTYPE=float /DRET_TYPE=float %s | FileCheck %s -check-prefix=RDLNAT_FLT
+// RUN: %dxc -E test_wavereadlaneat -T vs_6_2 -enable-16bit-types /DTYPE=half /DRET_TYPE=half %s | FileCheck %s -check-prefix=RDLNAT_HALF
+// RUN: %dxc -E test_wavereadlaneat -T vs_6_0 /DTYPE=min16float /DRET_TYPE=min16float %s | FileCheck %s -check-prefix=RDLNAT_MIN16FLT
+// RUN: %dxc -E test_wavereadlaneat -T vs_6_0 /DTYPE=int /DRET_TYPE=int %s | FileCheck %s -check-prefix=RDLNAT_INT
+// RUN: %dxc -E test_wavereadlaneat -T vs_6_0 /DTYPE=uint /DRET_TYPE=uint %s | FileCheck %s -check-prefix=RDLNAT_UINT
+// RUN: %dxc -E test_wavereadlaneat -T vs_6_0 /DTYPE=double /DRET_TYPE=float %s | FileCheck %s -check-prefix=RDLNAT_DBL
+// RUN: %dxc -E test_wavereadlaneat -T vs_6_0 /DTYPE=bool /DRET_TYPE=bool %s | FileCheck %s -check-prefix=RDLNAT_BOOL
+// RUN: %dxc -E test_wavereadlaneat -T vs_6_0 /DTYPE=bool3 /DRET_TYPE=bool3 %s | FileCheck %s -check-prefix=RDLNAT_BOOL_VEC3
+// RUN: %dxc -E test_wavereadlaneat -T vs_6_0 /DTYPE=int2x3 /DRET_TYPE=int %s | FileCheck %s -check-prefix=RDLNAT_MAT
+// RUN: %dxc -E test_wavereadlanefirst -T vs_6_0 /DTYPE=float /DRET_TYPE=float %s | FileCheck %s -check-prefix=RDLNFRST_FLT
+// RUN: %dxc -E test_wavereadlanefirst -T vs_6_2 -enable-16bit-types /DTYPE=half /DRET_TYPE=half %s | FileCheck %s -check-prefix=RDLNFRST_HALF
+// RUN: %dxc -E test_wavereadlanefirst -T vs_6_0 /DTYPE=min16float /DRET_TYPE=min16float %s | FileCheck %s -check-prefix=RDLNFRST_MIN16FLT
+// RUN: %dxc -E test_wavereadlanefirst -T vs_6_0 /DTYPE=int /DRET_TYPE=int %s | FileCheck %s -check-prefix=RDLNFRST_INT
+// RUN: %dxc -E test_wavereadlanefirst -T vs_6_0 /DTYPE=uint /DRET_TYPE=uint %s | FileCheck %s -check-prefix=RDLNFRST_UINT
+// RUN: %dxc -E test_wavereadlanefirst -T vs_6_0 /DTYPE=bool /DRET_TYPE=bool %s | FileCheck %s -check-prefix=RDLNFRST_BOOL
+// RUN: %dxc -E test_wavereadlanefirst -T vs_6_0 /DTYPE=bool3 /DRET_TYPE=bool3 %s | FileCheck %s -check-prefix=RDLNFRST_BOOL_VEC3
+// RUN: %dxc -E test_wavereadlanefirst -T vs_6_0 /DTYPE=int2x3 /DRET_TYPE=int %s | FileCheck %s -check-prefix=RDLNFRST_MAT
+
+// This file should contain tests to cover all supported overloads of WaveIntrinsics used for reduction operations
+// TODO: Currently only covers WaveReadFirstLane and WaveReadLaneAt. Add coverage for others.
+// TODO: Add related coverage once these bugs are fixed: bug# 2501
+
+cbuffer CB
+{
+  TYPE expr;
+}
+
+// RDLNAT_FLT: call float @dx.op.waveReadLaneAt.f32(i32 117, float
+// RDLNAT_HALF: call half @dx.op.waveReadLaneAt.f16(i32 117, half
+// RDLNAT_MIN16FLT: call half @dx.op.waveReadLaneAt.f16(i32 117, half
+// RDLNAT_INT: call i32 @dx.op.waveReadLaneAt.i32(i32 117, i32
+// RDLNAT_UINT: call i32 @dx.op.waveReadLaneAt.i32(i32 117, i32
+// RDLNAT_DBL: call double @dx.op.waveReadLaneAt.f64(i32 117, double
+// RDLNAT_BOOL: call i1 @dx.op.waveReadLaneAt.i1(i32 117, i1
+// RDLNAT_BOOL_VEC3: call i1 @dx.op.waveReadLaneAt.i1(i32 117, i1
+// RDLNAT_BOOL_VEC3: call i1 @dx.op.waveReadLaneAt.i1(i32 117, i1
+// RDLNAT_BOOL_VEC3: call i1 @dx.op.waveReadLaneAt.i1(i32 117, i1
+// RDLNAT_MAT: call i32 @dx.op.waveReadLaneAt.i32(i32 117, i32
+
+RET_TYPE test_wavereadlaneat(uint id: IN0) : OUT
+{
+  return WaveReadLaneAt(expr, id);
+}
+
+
+// RDLNFRST_FLT: call float @dx.op.waveReadLaneFirst.f32(i32 118, float
+// RDLNFRST_HALF: call half @dx.op.waveReadLaneFirst.f16(i32 118, half
+// RDLNFRST_MIN16FLT: call half @dx.op.waveReadLaneFirst.f16(i32 118, half
+// RDLNFRST_INT: call i32 @dx.op.waveReadLaneFirst.i32(i32 118, i32
+// RDLNFRST_UINT: call i32 @dx.op.waveReadLaneFirst.i32(i32 118, i32
+// RDLNFRST_BOOL: call i1 @dx.op.waveReadLaneFirst.i1(i32 118, i1
+// RDLNFRST_BOOL_VEC3: call i1 @dx.op.waveReadLaneFirst.i1(i32 118, i1
+// RDLNFRST_BOOL_VEC3: call i1 @dx.op.waveReadLaneFirst.i1(i32 118, i1
+// RDLNFRST_BOOL_VEC3: call i1 @dx.op.waveReadLaneFirst.i1(i32 118, i1
+// RDLNFRST_MAT: call i32 @dx.op.waveReadLaneFirst.i32(i32 118, i32
+RET_TYPE test_wavereadlanefirst() : OUT
+{
+  return WaveReadLaneFirst(expr);
+}

+ 23 - 19
tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/vote/wave.hlsl

@@ -1,25 +1,26 @@
 // RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
 
 // CHECK: Wave level operations
-// CHECK: waveIsFirstLane
-// CHECK: waveGetLaneIndex
-// CHECK: waveGetLaneCount
-// CHECK: waveAnyTrue
-// CHECK: waveAllTrue
-// CHECK: waveActiveAllEqual
-// CHECK: waveActiveBallot
-// CHECK: waveReadLaneAt
-// CHECK: waveReadLaneFirst
-// CHECK: waveActiveOp
-// CHECK: waveActiveOp
-// CHECK: waveActiveBit
-// CHECK: waveActiveBit
-// CHECK: waveActiveBit
-// CHECK: waveActiveOp
-// CHECK: waveActiveOp
-// CHECK: quadReadLaneAt
-// CHECK: quadOp
-// CHECK: quadOp
+// CHECK: call i1 @dx.op.waveIsFirstLane(i32 110
+// CHECK: call i32 @dx.op.waveGetLaneIndex(i32 111
+// CHECK: call i32 @dx.op.waveGetLaneCount(i32 112
+// CHECK: call i1 @dx.op.waveAnyTrue(i32 113, i1
+// CHECK: call i1 @dx.op.waveAllTrue(i32 114, i1
+// CHECK: call i1 @dx.op.waveActiveAllEqual.i32(i32 115, i32
+// CHECK: call i1 @dx.op.waveActiveAllEqual.i1(i32 115, i1
+// CHECK: call %dx.types.fouri32 @dx.op.waveActiveBallot(i32 116, i1
+// CHECK: call float @dx.op.waveReadLaneAt.f32(i32 117, float
+// CHECK: call float @dx.op.waveReadLaneFirst.f32(i32 118, float
+// CHECK: call float @dx.op.waveActiveOp.f32(i32 119, float
+// CHECK: call float @dx.op.waveActiveOp.f32(i32 119, float
+// CHECK: call i32 @dx.op.waveActiveBit.i32(i32 120, i32
+// CHECK: call i32 @dx.op.waveActiveBit.i32(i32 120, i32
+// CHECK: call i32 @dx.op.waveActiveBit.i32(i32 120, i32
+// CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32
+// CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32
+// CHECK: call float @dx.op.quadReadLaneAt.f32(i32 122, float
+// CHECK: call i32 @dx.op.quadOp.i32(i32 123, i32
+// CHECK: call i32 @dx.op.quadOp.i32(i32 123, i32
 
 float4 main() : SV_TARGET {
   float f = 1;
@@ -39,6 +40,9 @@ float4 main() : SV_TARGET {
   if (WaveActiveAllEqual(WaveGetLaneIndex())) {
     f += 1;
   }
+  if(WaveActiveAllEqual((f<100))){
+    f += 1;
+  }
   uint4 val = WaveActiveBallot(true);
   if (val.x == 1) {
     f += 1;

+ 3 - 3
utils/hct/gen_intrin_main.txt

@@ -251,10 +251,10 @@ uint   [[rn]] WaveGetLaneIndex();
 uint   [[rn]] WaveGetLaneCount();
 bool   [[]] WaveActiveAnyTrue(in bool cond);
 bool   [[]] WaveActiveAllTrue(in bool cond);
-$match<1, 0> bool<> [[]] WaveActiveAllEqual(in numeric<> value);
+$match<1, 0> bool<> [[]] WaveActiveAllEqual(in any<> value);
 uint<4> [[]] WaveActiveBallot(in bool cond);
-$type1 [[]] WaveReadLaneAt(in numeric<> value, in uint lane);
-$type1 [[]] WaveReadLaneFirst(in numeric<> value);
+$type1 [[]] WaveReadLaneAt(in any<> value, in uint lane);
+$type1 [[]] WaveReadLaneFirst(in any<> value);
 uint   [[]] WaveActiveCountBits(in bool value);
 $type1 [[unsigned_op=WaveActiveUSum]] WaveActiveSum(in numeric<> value);
 $type1 [[unsigned_op=WaveActiveUProduct]] WaveActiveProduct(in numeric<> value);