Pārlūkot izejas kodu

[spirv] Handle all types for ternary op (#989)

HLSL ternary operator allows scalar, vector, and matrix arguments.
The SPIR-V CodeGen used to use OpSelect which only works on scalars and
vectors.
Ehsan 7 gadi atpakaļ
vecāks
revīzija
4fd4c66cb9

+ 49 - 16
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -2135,6 +2135,8 @@ SPIRVEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) {
 
 SpirvEvalInfo
 SPIRVEmitter::doConditionalOperator(const ConditionalOperator *expr) {
+  const auto type = expr->getType();
+
   // Enhancement for special case when the ConditionalOperator return type is a
   // literal type. For example:
   //
@@ -2153,34 +2155,65 @@ SPIRVEmitter::doConditionalOperator(const ConditionalOperator *expr) {
   TypeTranslator::LiteralTypeHint hint(typeTranslator);
   if (canBeRepresentedIn32Bits(expr->getTrueExpr()) &&
       canBeRepresentedIn32Bits(expr->getFalseExpr())) {
-    if (expr->getType()->isSpecificBuiltinType(BuiltinType::LitInt))
+    if (type->isSpecificBuiltinType(BuiltinType::LitInt))
       hint.setHint(astContext.IntTy);
-    else if (expr->getType()->isSpecificBuiltinType(BuiltinType::LitFloat))
+    else if (type->isSpecificBuiltinType(BuiltinType::LitFloat))
       hint.setHint(astContext.FloatTy);
   }
 
   // According to HLSL doc, all sides of the ?: expression are always
   // evaluated.
-  const uint32_t type = typeTranslator.translateType(expr->getType());
+  const uint32_t typeId = typeTranslator.translateType(type);
   uint32_t condition = doExpr(expr->getCond());
   const uint32_t trueBranch = doExpr(expr->getTrueExpr());
   const uint32_t falseBranch = doExpr(expr->getFalseExpr());
 
-  // The SPIR-V OpSelect instruction must have a selection argument that is the
-  // same size as the return type. If the return type is a vector, the selection
-  // must be a vector of booleans (one per output component).
-  uint32_t count = 0;
-  if (TypeTranslator::isVectorType(expr->getType(), nullptr, &count) &&
-      !TypeTranslator::isVectorType(expr->getCond()->getType())) {
-    const uint32_t condVecType =
-        theBuilder.getVecType(theBuilder.getBoolType(), count);
-    const llvm::SmallVector<uint32_t, 4> components(size_t(count), condition);
-    condition = theBuilder.createCompositeConstruct(condVecType, components);
+  // For cases where the return type is a scalar or a vector, we can use
+  // OpSelect to choose between the two. OpSelect's return type must be either
+  // scalar or vector.
+  if (TypeTranslator::isScalarType(type) ||
+      TypeTranslator::isVectorType(type)) {
+    // The SPIR-V OpSelect instruction must have a selection argument that is
+    // the same size as the return type. If the return type is a vector, the
+    // selection must be a vector of booleans (one per output component).
+    uint32_t count = 0;
+    if (TypeTranslator::isVectorType(expr->getType(), nullptr, &count) &&
+        !TypeTranslator::isVectorType(expr->getCond()->getType())) {
+      const uint32_t condVecType =
+          theBuilder.getVecType(theBuilder.getBoolType(), count);
+      const llvm::SmallVector<uint32_t, 4> components(size_t(count), condition);
+      condition = theBuilder.createCompositeConstruct(condVecType, components);
+    }
+
+    auto valueId =
+        theBuilder.createSelect(typeId, condition, trueBranch, falseBranch);
+    return SpirvEvalInfo(valueId).setRValue();
   }
 
-  auto valueId =
-      theBuilder.createSelect(type, condition, trueBranch, falseBranch);
-  return SpirvEvalInfo(valueId).setRValue();
+  // If we can't use OpSelect, we need to create if-else control flow.
+  const uint32_t tempVar = theBuilder.addFnVar(typeId, "temp.var.ternary");
+  const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
+  const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");
+  const uint32_t elseBB = theBuilder.createBasicBlock("if.false");
+
+  // Create the branch instruction. This will end the current basic block.
+  theBuilder.createConditionalBranch(condition, thenBB, elseBB, mergeBB);
+  theBuilder.addSuccessor(thenBB);
+  theBuilder.addSuccessor(elseBB);
+  theBuilder.setMergeTarget(mergeBB);
+  // Handle the then branch
+  theBuilder.setInsertPoint(thenBB);
+  theBuilder.createStore(tempVar, trueBranch);
+  theBuilder.createBranch(mergeBB);
+  theBuilder.addSuccessor(mergeBB);
+  // Handle the else branch
+  theBuilder.setInsertPoint(elseBB);
+  theBuilder.createStore(tempVar, falseBranch);
+  theBuilder.createBranch(mergeBB);
+  theBuilder.addSuccessor(mergeBB);
+  // From now on, emit instructions into the merge block.
+  theBuilder.setInsertPoint(mergeBB);
+  return SpirvEvalInfo(theBuilder.createLoad(typeId, tempVar)).setRValue();
 }
 
 uint32_t SPIRVEmitter::processByteAddressBufferStructuredBufferGetDimensions(

+ 20 - 0
tools/clang/test/CodeGenSPIRV/ternary-op.cond-op.hlsl

@@ -5,6 +5,8 @@
 void main() {
 // CHECK-LABEL: %bb_entry = OpLabel
 
+// CHECK: %temp_var_ternary = OpVariable %_ptr_Function_mat2v3float Function
+
     bool b0;
     int m, n, o;
     // Plain assign (scalar)
@@ -88,4 +90,22 @@ void main() {
 // CHECK:      [[d_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
 // CHECK-NEXT:       {{%\d+}} = OpBitcast %uint [[d_int]]
     uint d = cond ? 1 : 0;
+
+    float2x3 e;
+    float2x3 f;
+// CHECK:     [[cond:%\d+]] = OpLoad %bool %cond
+// CHECK-NEXT:   [[e:%\d+]] = OpLoad %mat2v3float %e
+// CHECK-NEXT:   [[f:%\d+]] = OpLoad %mat2v3float %f
+// CHECK-NEXT:                OpSelectionMerge %if_merge None
+// CHECK-NEXT:                OpBranchConditional [[cond]] %if_true %if_false
+// CHECK-NEXT:     %if_true = OpLabel
+// CHECK-NEXT:                OpStore %temp_var_ternary [[e]]
+// CHECK-NEXT:                OpBranch %if_merge
+// CHECK-NEXT:    %if_false = OpLabel
+// CHECK-NEXT:                OpStore %temp_var_ternary [[f]]
+// CHECK-NEXT:                OpBranch %if_merge
+// CHECK-NEXT:    %if_merge = OpLabel
+// CHECK-NEXT:[[temp:%\d+]] = OpLoad %mat2v3float %temp_var_ternary
+// CHECK-NEXT:                OpStore %g [[temp]]
+    float2x3 g = cond ? e : f;
 }