Browse Source

[SPIR-V] Add support for logical operator intrinsics (#4674)

* [SPIR-V] Add support for logical operator intrinsics

Implemented `and`, `or`, and `select`. I used the existing tests for the
old behavior of `&&`, `||`, and `?:` in order to verify that they behave
as expected (with the caveat of #4673, but that has the same behavior
for DXIL as for SPIR-V).

Fixes #4148

* Clean up

* Remove unused variable
Cassandra Beckley 3 năm trước cách đây
mục cha
commit
009e7f1070

+ 25 - 13
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -1058,7 +1058,10 @@ SpirvInstruction *SpirvEmitter::doExpr(const Expr *expr,
     if (getCompilerInstance().getLangOpts().EnableShortCircuit) {
       result = doShortCircuitedConditionalOperator(condExpr);
     } else {
-      result = doConditionalOperator(condExpr);
+      const Expr *cond = condExpr->getCond();
+      const Expr *falseExpr = condExpr->getFalseExpr();
+      const Expr *trueExpr = condExpr->getTrueExpr();
+      result = doConditional(condExpr, cond, falseExpr, trueExpr);
     }
   } else if (const auto *defaultArgExpr = dyn_cast<CXXDefaultArgExpr>(expr)) {
     if (defaultArgExpr->getParam()->hasUninstantiatedDefaultArg()) {
@@ -3380,14 +3383,14 @@ SpirvInstruction *SpirvEmitter::doShortCircuitedConditionalOperator(
   return result;
 }
 
-SpirvInstruction *
-SpirvEmitter::doConditionalOperator(const ConditionalOperator *expr) {
+SpirvInstruction
+*SpirvEmitter::doConditional(const Expr *expr,
+                             const Expr *cond,
+                             const Expr *falseExpr,
+                             const Expr *trueExpr) {
   const auto type = expr->getType();
   const SourceLocation loc = expr->getExprLoc();
   const SourceRange range = expr->getSourceRange();
-  const Expr *cond = expr->getCond();
-  const Expr *falseExpr = expr->getFalseExpr();
-  const Expr *trueExpr = expr->getTrueExpr();
 
   // Corner-case: In HLSL, the condition of the ternary operator can be a
   // matrix of booleans which results in selecting between components of two
@@ -3444,12 +3447,12 @@ SpirvEmitter::doConditionalOperator(const ConditionalOperator *expr) {
     // selection must be a vector of booleans (one per output component).
     uint32_t count = 0;
     if (isVectorType(expr->getType(), nullptr, &count) &&
-        !isVectorType(expr->getCond()->getType())) {
+        !isVectorType(cond->getType())) {
       const llvm::SmallVector<SpirvInstruction *, 4> components(size_t(count),
                                                                 condition);
       condition = spvBuilder.createCompositeConstruct(
           astContext.getExtVectorType(astContext.BoolTy, count), components,
-          expr->getCond()->getLocEnd());
+          cond->getLocEnd());
     }
 
     auto *value = spvBuilder.createSelect(type, condition, trueBranch,
@@ -3475,21 +3478,21 @@ SpirvEmitter::doConditionalOperator(const ConditionalOperator *expr) {
 
   // Create the branch instruction. This will end the current basic block.
   spvBuilder.createConditionalBranch(condition, thenBB, elseBB,
-                                     expr->getCond()->getLocEnd(), mergeBB);
+                                     cond->getLocEnd(), mergeBB);
   spvBuilder.addSuccessor(thenBB);
   spvBuilder.addSuccessor(elseBB);
   spvBuilder.setMergeTarget(mergeBB);
   // Handle the then branch
   spvBuilder.setInsertPoint(thenBB);
   spvBuilder.createStore(tempVar, trueBranch,
-                         expr->getTrueExpr()->getLocStart(), range);
-  spvBuilder.createBranch(mergeBB, expr->getTrueExpr()->getLocEnd());
+                         trueExpr->getLocStart(), range);
+  spvBuilder.createBranch(mergeBB, trueExpr->getLocEnd());
   spvBuilder.addSuccessor(mergeBB);
   // Handle the else branch
   spvBuilder.setInsertPoint(elseBB);
   spvBuilder.createStore(tempVar, falseBranch,
-                         expr->getFalseExpr()->getLocStart(), range);
-  spvBuilder.createBranch(mergeBB, expr->getFalseExpr()->getLocEnd());
+                         falseExpr->getLocStart(), range);
+  spvBuilder.createBranch(mergeBB, falseExpr->getLocEnd());
   spvBuilder.addSuccessor(mergeBB);
   // From now on, emit instructions into the merge block.
   spvBuilder.setInsertPoint(mergeBB);
@@ -8249,6 +8252,13 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
   case hlsl::IntrinsicOp::IOP_SetMeshOutputCounts: {
     processMeshOutputCounts(callExpr);
     break;
+  }
+  case hlsl::IntrinsicOp::IOP_select: {
+    const Expr *cond = callExpr->getArg(0);
+    const Expr *trueExpr = callExpr->getArg(1);
+    const Expr *falseExpr = callExpr->getArg(2);
+    retVal = doConditional(callExpr, cond, falseExpr, trueExpr);
+    break;
   }
     INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
     INTRINSIC_SPIRV_OP_CASE(ddx_coarse, DPdxCoarse, false);
@@ -8262,6 +8272,8 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
     INTRINSIC_SPIRV_OP_CASE(fmod, FRem, true);
     INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true);
     INTRINSIC_SPIRV_OP_CASE(reversebits, BitReverse, false);
+    INTRINSIC_SPIRV_OP_CASE(and, LogicalAnd, false);
+    INTRINSIC_SPIRV_OP_CASE(or, LogicalOr, false);
     INTRINSIC_OP_CASE(round, RoundEven, true);
     INTRINSIC_OP_CASE(uabs, SAbs, true);
     INTRINSIC_OP_CASE_INT_FLOAT(abs, SAbs, FAbs, true);

+ 5 - 1
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -122,6 +122,10 @@ private:
                                SourceRange rangeOverride = {});
   SpirvInstruction *doCompoundAssignOperator(const CompoundAssignOperator *);
   SpirvInstruction *doConditionalOperator(const ConditionalOperator *expr);
+  SpirvInstruction *doConditional(const Expr *expr,
+                                  const Expr *cond,
+                                  const Expr *falseExpr,
+                                  const Expr *trueExpr);
   SpirvInstruction *
   doShortCircuitedConditionalOperator(const ConditionalOperator *expr);
   SpirvInstruction *doCXXMemberCallExpr(const CXXMemberCallExpr *expr);
@@ -624,7 +628,7 @@ private:
       SourceLocation loc);
   /// Process spirv intrinsic instruction
   SpirvInstruction *processSpvIntrinsicCallExpr(const CallExpr *expr);
-  
+
   /// Process spirv intrinsic type definition
   SpirvInstruction *processSpvIntrinsicTypeDef(const CallExpr *expr);
 

+ 35 - 0
tools/clang/test/CodeGenSPIRV/intrinsics.and.hlsl

@@ -0,0 +1,35 @@
+// RUN: %dxc -T ps_6_0 -E main -HV 2021
+
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+    bool a, b, c;
+    // Plain assign (scalar)
+// CHECK:      [[a0:%\d+]] = OpLoad %bool %a
+// CHECK-NEXT: [[b0:%\d+]] = OpLoad %bool %b
+// CHECK-NEXT: [[and0:%\d+]] = OpLogicalAnd %bool [[a0]] [[b0]]
+// CHECK-NEXT: OpStore %c [[and0]]
+    c = and(a, b);
+
+    bool1 i, j, k;
+    bool3 o, p, q;
+    // Plain assign (vector)
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %bool %i
+// CHECK-NEXT: [[j0:%\d+]] = OpLoad %bool %j
+// CHECK-NEXT: [[and1:%\d+]] = OpLogicalAnd %bool [[i0]] [[j0]]
+// CHECK-NEXT: OpStore %k [[and1]]
+// CHECK-NEXT: [[o0:%\d+]] = OpLoad %v3bool %o
+// CHECK-NEXT: [[p0:%\d+]] = OpLoad %v3bool %p
+// CHECK-NEXT: [[and2:%\d+]] = OpLogicalAnd %v3bool [[o0]] [[p0]]
+// CHECK-NEXT: OpStore %q [[and2]]
+    k = and(i, j);
+    q = and(o, p);
+
+// The result of '&&' could be 'const bool'. In such cases, make sure
+// the result type is correct.
+// CHECK:        [[a1:%\d+]] = OpLoad %bool %a
+// CHECK-NEXT:   [[b1:%\d+]] = OpLoad %bool %b
+// CHECK-NEXT: [[and3:%\d+]] = OpLogicalAnd %bool [[a1]] [[b1]]
+// CHECK-NEXT:      {{%\d+}} = OpCompositeConstruct %v2bool [[and3]] %true
+    bool2 t = bool2(and(a, b), true);
+}

+ 27 - 0
tools/clang/test/CodeGenSPIRV/intrinsics.or.hlsl

@@ -0,0 +1,27 @@
+// RUN: %dxc -T ps_6_0 -E main -HV 2021
+
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+    bool a, b, c;
+    // Plain assign (scalar)
+// CHECK:      [[a0:%\d+]] = OpLoad %bool %a
+// CHECK-NEXT: [[b0:%\d+]] = OpLoad %bool %b
+// CHECK-NEXT: [[or0:%\d+]] = OpLogicalOr %bool [[a0]] [[b0]]
+// CHECK-NEXT: OpStore %c [[or0]]
+    c = or(a, b);
+
+    bool1 i, j, k;
+    bool3 o, p, q;
+    // Plain assign (vector)
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %bool %i
+// CHECK-NEXT: [[j0:%\d+]] = OpLoad %bool %j
+// CHECK-NEXT: [[or1:%\d+]] = OpLogicalOr %bool [[i0]] [[j0]]
+// CHECK-NEXT: OpStore %k [[or1]]
+// CHECK-NEXT: [[o0:%\d+]] = OpLoad %v3bool %o
+// CHECK-NEXT: [[p0:%\d+]] = OpLoad %v3bool %p
+// CHECK-NEXT: [[or2:%\d+]] = OpLogicalOr %v3bool [[o0]] [[p0]]
+// CHECK-NEXT: OpStore %q [[or2]]
+    k = or(i, j);
+    q = or(o, p);
+}

+ 177 - 0
tools/clang/test/CodeGenSPIRV/intrinsics.select.hlsl

@@ -0,0 +1,177 @@
+// RUN: %dxc -T ps_6_0 -E main -HV 2021
+
+// CHECK: [[v3i0:%\d+]] = OpConstantComposite %v3int %int_0 %int_0 %int_0
+
+uint foo() { return 1; }
+float bar() { return 3.0; }
+uint zoo();
+
+void main() {
+  // CHECK-LABEL: %bb_entry = OpLabel
+
+  // CHECK: %temp_var_ternary = OpVariable %_ptr_Function_mat2v3float Function
+  // CHECK: %temp_var_ternary_0 = OpVariable %_ptr_Function_mat2v3float Function
+
+  bool b0;
+  int m, n, o;
+  // Plain assign (scalar)
+  // CHECK:      [[b0:%\d+]] = OpLoad %bool %b0
+  // CHECK-NEXT: [[m0:%\d+]] = OpLoad %int %m
+  // CHECK-NEXT: [[n0:%\d+]] = OpLoad %int %n
+  // CHECK-NEXT: [[s0:%\d+]] = OpSelect %int [[b0]] [[m0]] [[n0]]
+  // CHECK-NEXT: OpStore %o [[s0]]
+  o = select(b0, m, n);
+
+  bool1 b1;
+  bool3 b3;
+  uint1 p, q, r;
+  float3 x, y, z;
+  // Plain assign (vector)
+  // CHECK-NEXT: [[b1:%\d+]] = OpLoad %bool %b1
+  // CHECK-NEXT: [[p0:%\d+]] = OpLoad %uint %p
+  // CHECK-NEXT: [[q0:%\d+]] = OpLoad %uint %q
+  // CHECK-NEXT: [[s1:%\d+]] = OpSelect %uint [[b1]] [[p0]] [[q0]]
+  // CHECK-NEXT: OpStore %r [[s1]]
+  r = select(b1, p, q);
+  // CHECK-NEXT: [[b3:%\d+]] = OpLoad %v3bool %b3
+  // CHECK-NEXT: [[x0:%\d+]] = OpLoad %v3float %x
+  // CHECK-NEXT: [[y0:%\d+]] = OpLoad %v3float %y
+  // CHECK-NEXT: [[s2:%\d+]] = OpSelect %v3float [[b3]] [[x0]] [[y0]]
+  // CHECK-NEXT: OpStore %z [[s2]]
+  z = select(b3, x, y);
+
+  // Try condition with various type.
+  // Note: the SPIR-V OpSelect selection argument must be the same size as the return type.
+  int3 u, v, w;
+  float2x3 umat, vmat, wmat;
+  bool cond;
+  bool3 cond3;
+  float floatCond;
+  int intCond;
+  int3 int3Cond;
+
+  // CHECK:      [[cond3:%\d+]] = OpLoad %v3bool %cond3
+  // CHECK-NEXT:     [[u:%\d+]] = OpLoad %v3int %u
+  // CHECK-NEXT:     [[v:%\d+]] = OpLoad %v3int %v
+  // CHECK-NEXT:       {{%\d+}} = OpSelect %v3int [[cond3]] [[u]] [[v]]
+  w = select(cond3, u, v);
+
+  // CHECK:       [[cond:%\d+]] = OpLoad %bool %cond
+  // CHECK-NEXT: [[splat:%\d+]] = OpCompositeConstruct %v3bool [[cond]] [[cond]] [[cond]]
+  // CHECK-NEXT:     [[u:%\d+]] = OpLoad %v3int %u
+  // CHECK-NEXT:     [[v:%\d+]] = OpLoad %v3int %v
+  // CHECK-NEXT:       {{%\d+}} = OpSelect %v3int [[splat]] [[u]] [[v]]
+  w = select(cond, u, v);
+
+  // CHECK:      [[floatCond:%\d+]] = OpLoad %float %floatCond
+  // CHECK-NEXT:  [[boolCond:%\d+]] = OpFOrdNotEqual %bool [[floatCond]] %float_0
+  // CHECK-NEXT: [[bool3Cond:%\d+]] = OpCompositeConstruct %v3bool [[boolCond]] [[boolCond]] [[boolCond]]
+  // CHECK-NEXT:         [[u:%\d+]] = OpLoad %v3int %u
+  // CHECK-NEXT:         [[v:%\d+]] = OpLoad %v3int %v
+  // CHECK-NEXT:           {{%\d+}} = OpSelect %v3int [[bool3Cond]] [[u]] [[v]]
+  w = select(floatCond, u, v);
+
+  // CHECK:       [[int3Cond:%\d+]] = OpLoad %v3int %int3Cond
+  // CHECK-NEXT: [[bool3Cond:%\d+]] = OpINotEqual %v3bool [[int3Cond]] [[v3i0]]
+  // CHECK-NEXT:         [[u:%\d+]] = OpLoad %v3int %u
+  // CHECK-NEXT:         [[v:%\d+]] = OpLoad %v3int %v
+  // CHECK-NEXT:           {{%\d+}} = OpSelect %v3int [[bool3Cond]] [[u]] [[v]]
+  w = select(int3Cond, u, v);
+
+  // CHECK:       [[intCond:%\d+]] = OpLoad %int %intCond
+  // CHECK-NEXT: [[boolCond:%\d+]] = OpINotEqual %bool [[intCond]] %int_0
+  // CHECK-NEXT:     [[umat:%\d+]] = OpLoad %mat2v3float %umat
+  // CHECK-NEXT:     [[vmat:%\d+]] = OpLoad %mat2v3float %vmat
+  // CHECK-NEXT:                     OpSelectionMerge %if_merge None
+  // CHECK-NEXT:                     OpBranchConditional [[boolCond]] %if_true %if_false
+  // CHECK-NEXT:          %if_true = OpLabel
+  // CHECK-NEXT:                     OpStore %temp_var_ternary [[umat]]
+  // CHECK-NEXT:                     OpBranch %if_merge
+  // CHECK-NEXT:         %if_false = OpLabel
+  // CHECK-NEXT:                     OpStore %temp_var_ternary [[vmat]]
+  // CHECK-NEXT:                     OpBranch %if_merge
+  // CHECK-NEXT:         %if_merge = OpLabel
+  // CHECK-NEXT:  [[tempVar:%\d+]] = OpLoad %mat2v3float %temp_var_ternary
+  // CHECK-NEXT:                     OpStore %wmat [[tempVar]]
+  wmat = select(intCond, umat, vmat);
+
+  // Make sure literal types are handled correctly in ternary ops
+
+  // CHECK: [[b_float:%\d+]] = OpSelect %float {{%\d+}} %float_1_5 %float_2_5
+  // CHECK-NEXT:    {{%\d+}} = OpConvertFToS %int [[b_float]]
+  int b = select(cond, 1.5, 2.5);
+
+  // CHECK:      [[a_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
+  // CHECK-NEXT:       {{%\d+}} = OpConvertSToF %float [[a_int]]
+  float a = select(cond, 1, 0);
+
+  // CHECK:      [[c_long:%\d+]] = OpSelect %long {{%\d+}} %long_3000000000 %long_4000000000
+  double c = select(cond, 3000000000, 4000000000);
+
+  // CHECK:      [[d_int:%\d+]] = OpSelect %uint {{%\d+}} %uint_1 %uint_0
+  uint d = select(cond, 1, 0);
+
+  float2x3 e;
+  float2x3 f;
+  // CHECK:     [[cond:%\d+]] = OpLoad %bool %cond
+  // CHECK-NEXT:   [[e:%\d+]] = OpLoad %mat2v3float %e
+  // CHECK-NEXT:   [[f:%\d+]] = OpLoad %mat2v3float %f
+  // CHECK-NEXT:                OpSelectionMerge %if_merge_0 None
+  // CHECK-NEXT:                OpBranchConditional [[cond]] %if_true_0 %if_false_0
+  // CHECK-NEXT:   %if_true_0 = OpLabel
+  // CHECK-NEXT:                OpStore %temp_var_ternary_0 [[e]]
+  // CHECK-NEXT:                OpBranch %if_merge_0
+  // CHECK-NEXT:  %if_false_0 = OpLabel
+  // CHECK-NEXT:                OpStore %temp_var_ternary_0 [[f]]
+  // CHECK-NEXT:                OpBranch %if_merge_0
+  // CHECK-NEXT:  %if_merge_0 = OpLabel
+  // CHECK-NEXT:[[temp:%\d+]] = OpLoad %mat2v3float %temp_var_ternary_0
+  // CHECK-NEXT:                OpStore %g [[temp]]
+  float2x3 g = select(cond, e, f);
+
+  // CHECK:      [[inner:%\d+]] = OpSelect %uint {{%\d+}} %uint_1 %uint_2
+  // CHECK-NEXT:       {{%\d+}} = OpSelect %uint {{%\d+}} %uint_9 [[inner]]
+  uint h = select(cond, 9, select(cond, 1, 2));
+
+  //CHECK:      [[i_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
+  //CHECK-NEXT:       {{%\d+}} = OpINotEqual %bool [[i_int]] %int_0
+  bool i = select(cond, 1, 0);
+
+  // CHECK:     [[foo:%\d+]] = OpFunctionCall %uint %foo
+  // CHECKNEXT:     {{%\d+}} = OpSelect %uint {{%\d+}} %uint_3 [[foo]]
+  uint j = select(cond, 3, foo());
+
+  // CHECK:          [[bar:%\d+]] = OpFunctionCall %float %bar
+  // CHECK-NEXT: [[k_float:%\d+]] = OpSelect %float {{%\d+}} %float_4 [[bar]]
+  // CHECK-NEXT:         {{%\d+}} = OpConvertFToU %uint [[k_float]]
+  uint k = select(cond, 4, bar());
+
+  zoo();
+
+  // CHECK:       [[cond2x3:%\d+]] = OpLoad %_arr_v3bool_uint_2 %cond2x3
+  // CHECK-NEXT:  [[true2x3:%\d+]] = OpLoad %mat2v3float %true2x3
+  // CHECK-NEXT: [[false2x3:%\d+]] = OpLoad %mat2v3float %false2x3
+  // CHECK-NEXT:       [[c0:%\d+]] = OpCompositeExtract %v3bool [[cond2x3]] 0
+  // CHECK-NEXT:       [[t0:%\d+]] = OpCompositeExtract %v3float [[true2x3]] 0
+  // CHECK-NEXT:       [[f0:%\d+]] = OpCompositeExtract %v3float [[false2x3]] 0
+  // CHECK-NEXT:       [[r0:%\d+]] = OpSelect %v3float [[c0]] [[t0]] [[f0]]
+  // CHECK-NEXT:       [[c1:%\d+]] = OpCompositeExtract %v3bool [[cond2x3]] 1
+  // CHECK-NEXT:       [[t1:%\d+]] = OpCompositeExtract %v3float [[true2x3]] 1
+  // CHECK-NEXT:       [[f1:%\d+]] = OpCompositeExtract %v3float [[false2x3]] 1
+  // CHECK-NEXT:       [[r1:%\d+]] = OpSelect %v3float [[c1]] [[t1]] [[f1]]
+  // CHECK-NEXT:   [[result:%\d+]] = OpCompositeConstruct %mat2v3float [[r0]] [[r1]]
+  // CHECK-NEXT:                     OpStore %result2x3 [[result]]
+  bool2x3 cond2x3;
+  float2x3 true2x3, false2x3;
+  float2x3 result2x3 = select(cond2x3, true2x3, false2x3);
+}
+
+//
+// The literal integer type should be deduced from the function return type.
+//
+// CHECK: OpSelect %uint {{%\d+}} %uint_1 %uint_2
+uint zoo() {
+  bool cond;
+  return select(cond, 1, 2);
+}
+

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -1231,6 +1231,7 @@ TEST_F(FileTest, IntrinsicsAllMemoryBarrier) {
 TEST_F(FileTest, IntrinsicsAllMemoryBarrierWithGroupSync) {
   runFileTest("intrinsics.allmemorybarrierwithgroupsync.hlsl");
 }
+TEST_F(FileTest, IntrinsicsAnd) { runFileTest("intrinsics.and.hlsl"); }
 TEST_F(FileTest, IntrinsicsDeviceMemoryBarrierWithGroupSync) {
   runFileTest("intrinsics.devicememorybarrierwithgroupsync.hlsl");
 }
@@ -1291,6 +1292,7 @@ TEST_F(FileTest, IntrinsicsMsad4) { runFileTest("intrinsics.msad4.hlsl"); }
 TEST_F(FileTest, IntrinsicsNormalize) {
   runFileTest("intrinsics.normalize.hlsl");
 }
+TEST_F(FileTest, IntrinsicsOr) { runFileTest("intrinsics.or.hlsl"); }
 TEST_F(FileTest, IntrinsicsPow) { runFileTest("intrinsics.pow.hlsl"); }
 TEST_F(FileTest, IntrinsicsRsqrt) { runFileTest("intrinsics.rsqrt.hlsl"); }
 TEST_F(FileTest, IntrinsicsFloatSign) {
@@ -1311,6 +1313,7 @@ TEST_F(FileTest, IntrinsicsSmoothStep) {
 }
 TEST_F(FileTest, IntrinsicsStep) { runFileTest("intrinsics.step.hlsl"); }
 TEST_F(FileTest, IntrinsicsSqrt) { runFileTest("intrinsics.sqrt.hlsl"); }
+TEST_F(FileTest, IntrinsicsSelect) { runFileTest("intrinsics.select.hlsl"); }
 TEST_F(FileTest, IntrinsicsTranspose) {
   runFileTest("intrinsics.transpose.hlsl");
 }