浏览代码

[spirv] Support boolean matrix as ternary condition. (#3096)

Ehsan 5 年之前
父节点
当前提交
2667e538d6
共有 2 个文件被更改,包括 104 次插入24 次删除
  1. 45 3
      tools/clang/lib/SPIRV/SpirvEmitter.cpp
  2. 59 21
      tools/clang/test/CodeGenSPIRV/ternary-op.cond-op.hlsl

+ 45 - 3
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -2660,14 +2660,56 @@ SpirvInstruction *
 SpirvEmitter::doConditionalOperator(const ConditionalOperator *expr) {
 SpirvEmitter::doConditionalOperator(const ConditionalOperator *expr) {
   const auto type = expr->getType();
   const auto type = expr->getType();
   const SourceLocation loc = expr->getExprLoc();
   const SourceLocation loc = expr->getExprLoc();
+  const Expr *cond = expr->getCond();
+  const Expr *falseExpr = expr->getFalseExpr();
+  const Expr *trueExpr = expr->getTrueExpr();
 
 
   // According to HLSL doc, all sides of the ?: expression are always evaluated.
   // According to HLSL doc, all sides of the ?: expression are always evaluated.
 
 
+  // Corner-case: In HLSL, the condition of the ternary operator can be a
+  // matrix of booleans which results in selecting between components of two
+  // matrices. However, a matrix of booleans is not a valid type in SPIR-V.
+  // If the AST has inserted a splat of a scalar/vector to a matrix, we can just
+  // use that scalar/vector as an if-clause condition.
+  if (auto *cast = dyn_cast<ImplicitCastExpr>(cond))
+    if (cast->getCastKind() == CK_HLSLMatrixSplat)
+      cond = cast->getSubExpr();
+
   // If we are selecting between two SampleState objects, none of the three
   // If we are selecting between two SampleState objects, none of the three
   // operands has a LValueToRValue implicit cast.
   // operands has a LValueToRValue implicit cast.
-  auto *condition = loadIfGLValue(expr->getCond());
-  auto *trueBranch = loadIfGLValue(expr->getTrueExpr());
-  auto *falseBranch = loadIfGLValue(expr->getFalseExpr());
+  auto *condition = loadIfGLValue(cond);
+  auto *trueBranch = loadIfGLValue(trueExpr);
+  auto *falseBranch = loadIfGLValue(falseExpr);
+
+  // Corner-case: In HLSL, the condition of the ternary operator can be a
+  // matrix of booleans which results in selecting between components of two
+  // matrices. However, a matrix of booleans is not a valid type in SPIR-V.
+  // Therefore, we need to perform OpSelect for each row of the matrix.
+  {
+    QualType condElemType = {}, elemType = {};
+    uint32_t rowCount = 0, colCount = 0;
+    if (isMxNMatrix(type, &elemType, &rowCount, &colCount) &&
+        isMxNMatrix(cond->getType(), &condElemType) &&
+        condElemType->isBooleanType()) {
+      const auto rowType = astContext.getExtVectorType(elemType, colCount);
+      const auto condRowType =
+          astContext.getExtVectorType(condElemType, colCount);
+      llvm::SmallVector<SpirvInstruction *, 4> rows;
+      for (uint32_t i = 0; i < rowCount; ++i) {
+        auto *condRow =
+            spvBuilder.createCompositeExtract(condRowType, condition, {i}, loc);
+        auto *trueRow =
+            spvBuilder.createCompositeExtract(rowType, trueBranch, {i}, loc);
+        auto *falseRow =
+            spvBuilder.createCompositeExtract(rowType, falseBranch, {i}, loc);
+        rows.push_back(
+            spvBuilder.createSelect(rowType, condRow, trueRow, falseRow, loc));
+      }
+      auto *result = spvBuilder.createCompositeConstruct(type, rows, loc);
+      result->setRValue();
+      return result;
+    }
+  }
 
 
   // For cases where the return type is a scalar or a vector, we can use
   // For cases where the return type is a scalar or a vector, we can use
   // OpSelect to choose between the two. OpSelect's return type must be either
   // OpSelect to choose between the two. OpSelect's return type must be either

+ 59 - 21
tools/clang/test/CodeGenSPIRV/ternary-op.cond-op.hlsl

@@ -14,6 +14,8 @@ void main() {
   // CHECK-LABEL: %bb_entry = OpLabel
   // CHECK-LABEL: %bb_entry = OpLabel
 
 
   // CHECK: %temp_var_ternary = OpVariable %_ptr_Function_mat2v3float Function
   // CHECK: %temp_var_ternary = OpVariable %_ptr_Function_mat2v3float Function
+  // CHECK: %temp_var_ternary_0 = OpVariable %_ptr_Function_mat2v3float Function
+  // CHECK: %temp_var_ternary_1 = OpVariable %_ptr_Function_type_sampler Function
 
 
   bool b0;
   bool b0;
   int m, n, o;
   int m, n, o;
@@ -46,9 +48,11 @@ void main() {
   // Try condition with various type.
   // Try condition with various type.
   // Note: the SPIR-V OpSelect selection argument must be the same size as the return type.
   // Note: the SPIR-V OpSelect selection argument must be the same size as the return type.
   int3 u, v, w;
   int3 u, v, w;
+  float2x3 umat, vmat, wmat;
   bool cond;
   bool cond;
   bool3 cond3;
   bool3 cond3;
   float floatCond;
   float floatCond;
+  int intCond;
   int3 int3Cond;
   int3 int3Cond;
 
 
   // CHECK:      [[cond3:%\d+]] = OpLoad %v3bool %cond3
   // CHECK:      [[cond3:%\d+]] = OpLoad %v3bool %cond3
@@ -79,6 +83,23 @@ void main() {
   // CHECK-NEXT:           {{%\d+}} = OpSelect %v3int [[bool3Cond]] [[u]] [[v]]
   // CHECK-NEXT:           {{%\d+}} = OpSelect %v3int [[bool3Cond]] [[u]] [[v]]
   w = int3Cond ? u : v;
   w = int3Cond ? u : v;
 
 
+  // CHECK:       [[intCond:%\d+]] = OpLoad %int %intCond
+  // CHECK-NEXT: [[boolCond:%\d+]] = OpINotEqual %bool [[intCond]] %int_0
+  // CHECK-NEXT:     [[umat:%\d+]] = OpLoad %mat2v3float %umat
+  // CHECK-NEXT:     [[vmat:%\d+]] = OpLoad %mat2v3float %vmat
+  // CHECK-NEXT:                     OpSelectionMerge %if_merge None
+  // CHECK-NEXT:                     OpBranchConditional [[boolCond]] %if_true %if_false
+  // CHECK-NEXT:          %if_true = OpLabel
+  // CHECK-NEXT:                     OpStore %temp_var_ternary [[umat]]
+  // CHECK-NEXT:                     OpBranch %if_merge
+  // CHECK-NEXT:         %if_false = OpLabel
+  // CHECK-NEXT:                     OpStore %temp_var_ternary [[vmat]]
+  // CHECK-NEXT:                     OpBranch %if_merge
+  // CHECK-NEXT:         %if_merge = OpLabel
+  // CHECK-NEXT:  [[tempVar:%\d+]] = OpLoad %mat2v3float %temp_var_ternary
+  // CHECK-NEXT:                     OpStore %wmat [[tempVar]]
+  wmat = intCond ? umat : vmat;
+
   // Make sure literal types are handled correctly in ternary ops
   // Make sure literal types are handled correctly in ternary ops
 
 
   // CHECK: [[b_float:%\d+]] = OpSelect %float {{%\d+}} %float_1_5 %float_2_5
   // CHECK: [[b_float:%\d+]] = OpSelect %float {{%\d+}} %float_1_5 %float_2_5
@@ -100,16 +121,16 @@ void main() {
   // CHECK:     [[cond:%\d+]] = OpLoad %bool %cond
   // CHECK:     [[cond:%\d+]] = OpLoad %bool %cond
   // CHECK-NEXT:   [[e:%\d+]] = OpLoad %mat2v3float %e
   // CHECK-NEXT:   [[e:%\d+]] = OpLoad %mat2v3float %e
   // CHECK-NEXT:   [[f:%\d+]] = OpLoad %mat2v3float %f
   // CHECK-NEXT:   [[f:%\d+]] = OpLoad %mat2v3float %f
-  // CHECK-NEXT:                OpSelectionMerge %if_merge None
-  // CHECK-NEXT:                OpBranchConditional [[cond]] %if_true %if_false
-  // CHECK-NEXT:     %if_true = OpLabel
-  // CHECK-NEXT:                OpStore %temp_var_ternary [[e]]
-  // CHECK-NEXT:                OpBranch %if_merge
-  // CHECK-NEXT:    %if_false = OpLabel
-  // CHECK-NEXT:                OpStore %temp_var_ternary [[f]]
-  // CHECK-NEXT:                OpBranch %if_merge
-  // CHECK-NEXT:    %if_merge = OpLabel
-  // CHECK-NEXT:[[temp:%\d+]] = OpLoad %mat2v3float %temp_var_ternary
+  // CHECK-NEXT:                OpSelectionMerge %if_merge_0 None
+  // CHECK-NEXT:                OpBranchConditional [[cond]] %if_true_0 %if_false_0
+  // CHECK-NEXT:   %if_true_0 = OpLabel
+  // CHECK-NEXT:                OpStore %temp_var_ternary_0 [[e]]
+  // CHECK-NEXT:                OpBranch %if_merge_0
+  // CHECK-NEXT:  %if_false_0 = OpLabel
+  // CHECK-NEXT:                OpStore %temp_var_ternary_0 [[f]]
+  // CHECK-NEXT:                OpBranch %if_merge_0
+  // CHECK-NEXT:  %if_merge_0 = OpLabel
+  // CHECK-NEXT:[[temp:%\d+]] = OpLoad %mat2v3float %temp_var_ternary_0
   // CHECK-NEXT:                OpStore %g [[temp]]
   // CHECK-NEXT:                OpStore %g [[temp]]
   float2x3 g = cond ? e : f;
   float2x3 g = cond ? e : f;
 
 
@@ -139,26 +160,43 @@ void main() {
   // CHECK:      [[cond:%\d+]] = OpLoad %bool %cond
   // CHECK:      [[cond:%\d+]] = OpLoad %bool %cond
   // CHECK-NEXT: [[gSS1:%\d+]] = OpLoad %type_sampler %gSS1
   // CHECK-NEXT: [[gSS1:%\d+]] = OpLoad %type_sampler %gSS1
   // CHECK-NEXT: [[gSS2:%\d+]] = OpLoad %type_sampler %gSS2
   // CHECK-NEXT: [[gSS2:%\d+]] = OpLoad %type_sampler %gSS2
-  // CHECK-NEXT:                 OpSelectionMerge %if_merge_0 None
-  // CHECK-NEXT:                 OpBranchConditional [[cond]] %if_true_0 %if_false_0
-  // CHECK-NEXT:    %if_true_0 = OpLabel
-  // CHECK-NEXT:                 OpStore %temp_var_ternary_0 [[gSS1]]
-  // CHECK-NEXT:                 OpBranch %if_merge_0
-  // CHECK-NEXT:   %if_false_0 = OpLabel
-  // CHECK-NEXT:                 OpStore %temp_var_ternary_0 [[gSS2]]
-  // CHECK-NEXT:                 OpBranch %if_merge_0
-  // CHECK-NEXT:   %if_merge_0 = OpLabel
-  // CHECK-NEXT:   [[ss:%\d+]] = OpLoad %type_sampler %temp_var_ternary_0
+  // CHECK-NEXT:                 OpSelectionMerge %if_merge_1 None
+  // CHECK-NEXT:                 OpBranchConditional [[cond]] %if_true_1 %if_false_1
+  // CHECK-NEXT:    %if_true_1 = OpLabel
+  // CHECK-NEXT:                 OpStore %temp_var_ternary_1 [[gSS1]]
+  // CHECK-NEXT:                 OpBranch %if_merge_1
+  // CHECK-NEXT:   %if_false_1 = OpLabel
+  // CHECK-NEXT:                 OpStore %temp_var_ternary_1 [[gSS2]]
+  // CHECK-NEXT:                 OpBranch %if_merge_1
+  // CHECK-NEXT:   %if_merge_1 = OpLabel
+  // CHECK-NEXT:   [[ss:%\d+]] = OpLoad %type_sampler %temp_var_ternary_1
   // CHECK-NEXT:      {{%\d+}} = OpSampledImage %type_sampled_image {{%\d+}} [[ss]]
   // CHECK-NEXT:      {{%\d+}} = OpSampledImage %type_sampled_image {{%\d+}} [[ss]]
   float4 l = gTex.Sample(cond ? gSS1 : gSS2, float2(1., 2.));
   float4 l = gTex.Sample(cond ? gSS1 : gSS2, float2(1., 2.));
 
 
   zoo();
   zoo();
+
+// CHECK:       [[cond2x3:%\d+]] = OpLoad %_arr_v3bool_uint_2 %cond2x3
+// CHECK-NEXT:  [[true2x3:%\d+]] = OpLoad %mat2v3float %true2x3
+// CHECK-NEXT: [[false2x3:%\d+]] = OpLoad %mat2v3float %false2x3
+// CHECK-NEXT:       [[c0:%\d+]] = OpCompositeExtract %v3bool [[cond2x3]] 0
+// CHECK-NEXT:       [[t0:%\d+]] = OpCompositeExtract %v3float [[true2x3]] 0
+// CHECK-NEXT:       [[f0:%\d+]] = OpCompositeExtract %v3float [[false2x3]] 0
+// CHECK-NEXT:       [[r0:%\d+]] = OpSelect %v3float [[c0]] [[t0]] [[f0]]
+// CHECK-NEXT:       [[c1:%\d+]] = OpCompositeExtract %v3bool [[cond2x3]] 1
+// CHECK-NEXT:       [[t1:%\d+]] = OpCompositeExtract %v3float [[true2x3]] 1
+// CHECK-NEXT:       [[f1:%\d+]] = OpCompositeExtract %v3float [[false2x3]] 1
+// CHECK-NEXT:       [[r1:%\d+]] = OpSelect %v3float [[c1]] [[t1]] [[f1]]
+// CHECK-NEXT:   [[result:%\d+]] = OpCompositeConstruct %mat2v3float [[r0]] [[r1]]
+// CHECK-NEXT:                     OpStore %result2x3 [[result]]
+  bool2x3 cond2x3;
+  float2x3 true2x3, false2x3;
+  float2x3 result2x3 = cond2x3 ? true2x3 : false2x3;
 }
 }
 
 
 //
 //
 // The literal integer type should be deduced from the function return type.
 // The literal integer type should be deduced from the function return type.
 //
 //
-// CHECK: OpSelect %uint %174 %uint_1 %uint_2
+// CHECK: OpSelect %uint {{%\d+}} %uint_1 %uint_2
 uint zoo() {
 uint zoo() {
   bool cond;
   bool cond;
   return cond ? 1 : 2;
   return cond ? 1 : 2;