فهرست منبع

[SPIR-V] Add short-circuiting ternary operator for HLSL 2021 (#4672)

Daniele Vettorel 3 سال پیش
والد
کامیت
16ad9c07db

+ 58 - 3
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -1054,7 +1054,12 @@ SpirvInstruction *SpirvEmitter::doExpr(const Expr *expr,
   } else if (const auto *subscriptExpr = dyn_cast<ArraySubscriptExpr>(expr)) {
     result = doArraySubscriptExpr(subscriptExpr, range);
   } else if (const auto *condExpr = dyn_cast<ConditionalOperator>(expr)) {
-    result = doConditionalOperator(condExpr);
+    // Beginning with HLSL 2021, the ternary operator is short-circuited.
+    if (getCompilerInstance().getLangOpts().EnableShortCircuit) {
+      result = doShortCircuitedConditionalOperator(condExpr);
+    } else {
+      result = doConditionalOperator(condExpr);
+    }
   } else if (const auto *defaultArgExpr = dyn_cast<CXXDefaultArgExpr>(expr)) {
     if (defaultArgExpr->getParam()->hasUninstantiatedDefaultArg()) {
       auto defaultArg = defaultArgExpr->getParam()->getUninstantiatedDefaultArg();
@@ -3323,6 +3328,58 @@ SpirvEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) {
   return processAssignment(lhs, result, true, lhsPtr, expr->getSourceRange());
 }
 
+SpirvInstruction *SpirvEmitter::doShortCircuitedConditionalOperator(
+    const ConditionalOperator *expr) {
+  const auto type = expr->getType();
+  const SourceLocation loc = expr->getExprLoc();
+  const SourceRange range = expr->getSourceRange();
+  const Expr *cond = expr->getCond();
+  const Expr *falseExpr = expr->getFalseExpr();
+  const Expr *trueExpr = expr->getTrueExpr();
+
+  // Short-circuited operators can only be used with scalar conditions. This
+  // is checked earlier.
+  assert(cond->getType()->isScalarType());
+
+  auto *tempVar = spvBuilder.addFnVar(type, loc, "temp.var.ternary");
+  auto *thenBB = spvBuilder.createBasicBlock("ternary.lhs");
+  auto *elseBB = spvBuilder.createBasicBlock("ternary.rhs");
+  auto *mergeBB = spvBuilder.createBasicBlock("ternary.merge");
+
+  // Create the branch instruction. This will end the current basic block.
+  SpirvInstruction *condition = loadIfGLValue(cond);
+  condition = castToBool(condition, cond->getType(), astContext.BoolTy,
+                         cond->getLocEnd());
+  spvBuilder.createConditionalBranch(condition, thenBB, elseBB, loc, mergeBB);
+  spvBuilder.addSuccessor(thenBB);
+  spvBuilder.addSuccessor(elseBB);
+  spvBuilder.setMergeTarget(mergeBB);
+
+  // Handle the true case.
+  spvBuilder.setInsertPoint(thenBB);
+  SpirvInstruction *trueVal = loadIfGLValue(trueExpr);
+  trueVal = castToType(trueVal, trueExpr->getType(), type,
+                       trueExpr->getExprLoc(), range);
+  spvBuilder.createStore(tempVar, trueVal, trueExpr->getLocStart(), range);
+  spvBuilder.createBranch(mergeBB, trueExpr->getLocEnd());
+  spvBuilder.addSuccessor(mergeBB);
+
+  // Handle the false case.
+  spvBuilder.setInsertPoint(elseBB);
+  SpirvInstruction *falseVal = loadIfGLValue(falseExpr);
+  falseVal = castToType(falseVal, falseExpr->getType(), type,
+                        falseExpr->getExprLoc(), range);
+  spvBuilder.createStore(tempVar, falseVal, falseExpr->getLocStart(), range);
+  spvBuilder.createBranch(mergeBB, falseExpr->getLocEnd());
+  spvBuilder.addSuccessor(mergeBB);
+
+  // From now on, emit instructions into the merge block.
+  spvBuilder.setInsertPoint(mergeBB);
+  SpirvInstruction *result = spvBuilder.createLoad(type, tempVar, loc, range);
+  result->setRValue();
+  return result;
+}
+
 SpirvInstruction *
 SpirvEmitter::doConditionalOperator(const ConditionalOperator *expr) {
   const auto type = expr->getType();
@@ -3332,8 +3389,6 @@ SpirvEmitter::doConditionalOperator(const ConditionalOperator *expr) {
   const Expr *falseExpr = expr->getFalseExpr();
   const Expr *trueExpr = expr->getTrueExpr();
 
-  // According to HLSL doc, all sides of the ?: expression are always evaluated.
-
   // Corner-case: In HLSL, the condition of the ternary operator can be a
   // matrix of booleans which results in selecting between components of two
   // matrices. However, a matrix of booleans is not a valid type in SPIR-V.

+ 2 - 0
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -122,6 +122,8 @@ private:
                                SourceRange rangeOverride = {});
   SpirvInstruction *doCompoundAssignOperator(const CompoundAssignOperator *);
   SpirvInstruction *doConditionalOperator(const ConditionalOperator *expr);
+  SpirvInstruction *
+  doShortCircuitedConditionalOperator(const ConditionalOperator *expr);
   SpirvInstruction *doCXXMemberCallExpr(const CXXMemberCallExpr *expr);
   SpirvInstruction *doCXXOperatorCallExpr(const CXXOperatorCallExpr *expr,
                                           SourceRange rangeOverride = {});

+ 24 - 0
tools/clang/test/CodeGenSPIRV/ternary-op.short-circuited-cond-op.hlsl

@@ -0,0 +1,24 @@
+// RUN: %dxc -T ps_6_0 -E main -HV 2021
+
+void main() {
+  // CHECK-LABEL: %bb_entry = OpLabel
+
+  bool b;
+  int m, n, o;
+  // CHECK:      %temp_var_ternary = OpVariable %_ptr_Function_int Function
+  // CHECK-NEXT: [[b:%\d+]] = OpLoad %bool %b
+  // CHECK-NEXT: OpSelectionMerge %ternary_merge None
+  // CHECK-NEXT: OpBranchConditional [[b]] %ternary_lhs %ternary_rhs
+  // CHECK-NEXT: %ternary_lhs = OpLabel
+  // CHECK-NEXT: [[m:%\d+]] = OpLoad %int %m
+  // CHECK-NEXT: OpStore %temp_var_ternary [[m]]
+  // CHECK-NEXT: OpBranch %ternary_merge
+  // CHECK-NEXT: %ternary_rhs = OpLabel
+  // CHECK-NEXT: [[n:%\d+]] = OpLoad %int %n
+  // CHECK-NEXT: OpStore %temp_var_ternary [[n]]
+  // CHECK-NEXT: OpBranch %ternary_merge
+  // CHECK-NEXT: %ternary_merge = OpLabel
+  // CHECK-NEXT: [[r:%\d+]] = OpLoad %int %temp_var_ternary
+  // CHECK-NEXT: OpStore %o [[r]]
+  o = b ? m : n;
+}

+ 5 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -395,6 +395,11 @@ TEST_F(FileTest, TernaryOpConditionalOp) {
   runFileTest("ternary-op.cond-op.hlsl");
 }
 
+// For short-circuited ternary operators (HLSL 2021)
+TEST_F(FileTest, TernaryOpShortCircuitedConditionalOp) {
+  runFileTest("ternary-op.short-circuited-cond-op.hlsl");
+}
+
 // For vector accessing/swizzling operators
 TEST_F(FileTest, OpVectorSwizzle) { runFileTest("op.vector.swizzle.hlsl"); }
 TEST_F(FileTest, OpVectorSwizzle1) {