Prechádzať zdrojové kódy

[spirv] Translate for statement (#470)

* Only the basic for statement are tested. More complicated cases
  yet to be implemented and tested.
* Also add support for less than and prefix increment operator.
Lei Zhang 8 rokov pred
rodič
commit
59b4535990

+ 8 - 3
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -98,10 +98,15 @@ public:
 
   // \brief Creates an unconditional branch to the given target label.
   void createBranch(uint32_t targetLabel);
-  // \brief Creates a conditional branch. The OpSelectionMerge instruction
-  // required for the conditional branch will also be created.
+
+  // \brief Creates a conditional branch. An OpSelectionMerge instruction
+  // will be created if mergeLabel is not 0 and continueLabel is 0.
+  // An OpLoopMerge instruction will also be created if both continueLabel
+  // and mergeLabel are not 0. For other cases, mergeLabel and continueLabel
+  // will be ignored.
   void createConditionalBranch(uint32_t condition, uint32_t trueLabel,
-                               uint32_t falseLabel, uint32_t mergeLabel);
+                               uint32_t falseLabel, uint32_t mergeLabel = 0,
+                               uint32_t continueLabel = 0);
 
   /// \brief Creates a return instruction.
   void createReturn();

+ 141 - 2
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -241,8 +241,10 @@ public:
       }
     } else if (const auto *ifStmt = dyn_cast<IfStmt>(stmt)) {
       doIfStmt(ifStmt);
+    } else if (const auto *forStmt = dyn_cast<ForStmt>(stmt)) {
+      doForStmt(forStmt);
     } else if (const auto *nullStmt = dyn_cast<NullStmt>(stmt)) {
-      // We don't need to do anything for NullStmt
+      // For the null statement ";". We don't need to do anything.
     } else if (const auto *expr = dyn_cast<Expr>(stmt)) {
       // All cases for expressions used as statements
       doExpr(expr);
@@ -372,12 +374,92 @@ public:
     theBuilder.setInsertPoint(mergeBB);
   }
 
+  void doForStmt(const ForStmt *forStmt) {
+    // for loops are composed of:
+    //   for (<init>; <check>; <continue>) <body>
+    //
+    // To translate a for loop, we'll need to emit all <init> statements
+    // in the current basic block, and then have separate basic blocks for
+    // <check>, <continue>, and <body>. Besides, since SPIR-V requires
+    // structured control flow, we need two more basic blocks, <header>
+    // and <merge>. <header> is the block before control flow diverges,
+    // while <merge> is the block where control flow subsequently converges.
+    // The <check> block can take the responsibility of the <header> block.
+    // The final CFG should normally be like the following. Exceptions will
+    // occur with non-local exits like loop breaks or early returns.
+    //             +--------+
+    //             |  init  |
+    //             +--------+
+    //                 |
+    //                 v
+    //            +----------+
+    //            |  header  | <---------------+
+    //            | (check)  |                 |
+    //            +----------+                 |
+    //                 |                       |
+    //         +-------+-------+               |
+    //         | false         | true          |
+    //         |               v               |
+    //         |            +------+     +----------+
+    //         |            | body | --> | continue |
+    //         v            +------+     +----------+
+    //     +-------+
+    //     | merge |
+    //     +-------+
+    //
+    // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
+
+    // Create basic blocks
+    const uint32_t checkBB = theBuilder.createBasicBlock("for.check");
+    const uint32_t bodyBB = theBuilder.createBasicBlock("for.body");
+    const uint32_t continueBB = theBuilder.createBasicBlock("for.continue");
+    const uint32_t mergeBB = theBuilder.createBasicBlock("for.merge");
+
+    // Process the <init> block
+    if (const Stmt *initStmt = forStmt->getInit()) {
+      doStmt(initStmt);
+    }
+    theBuilder.createBranch(checkBB);
+
+    // Process the <check> block
+    theBuilder.setInsertPoint(checkBB);
+    uint32_t condition;
+    if (const Expr *check = forStmt->getCond()) {
+      condition = doExpr(check);
+    } else {
+      condition = theBuilder.getConstantBool(true);
+    }
+    theBuilder.createConditionalBranch(condition, bodyBB,
+                                       /*false branch*/ mergeBB,
+                                       /*merge*/ mergeBB, continueBB);
+
+    // Process the <body> block
+    theBuilder.setInsertPoint(bodyBB);
+    if (const Stmt *body = forStmt->getBody()) {
+      doStmt(body);
+    }
+    theBuilder.createBranch(continueBB);
+
+    // Process the <continue> block
+    theBuilder.setInsertPoint(continueBB);
+    if (const Expr *cont = forStmt->getInc()) {
+      doExpr(cont);
+    }
+    theBuilder.createBranch(checkBB); // <continue> should jump back to header
+
+    // Set insertion point to the <merge> block for subsequent statements
+    theBuilder.setInsertPoint(mergeBB);
+  }
+
   uint32_t doExpr(const Expr *expr) {
     if (auto *delRefExpr = dyn_cast<DeclRefExpr>(expr)) {
       // Returns the <result-id> of the referenced Decl.
       const NamedDecl *referredDecl = delRefExpr->getFoundDecl();
       assert(referredDecl && "found non-NamedDecl referenced");
       return declIdMapper.getDeclResultId(referredDecl);
+    } else if (auto *parenExpr = dyn_cast<ParenExpr>(expr)) {
+      // Just need to return what's inside the parentheses.
+      return doExpr(parenExpr->getSubExpr());
     } else if (auto *memberExpr = dyn_cast<MemberExpr>(expr)) {
       const uint32_t base = doExpr(memberExpr->getBase());
       auto *memberDecl = memberExpr->getMemberDecl();
@@ -428,6 +510,8 @@ public:
       return translateAPFloat(floatLiteral->getValue(), expr->getType());
     } else if (auto *binOp = dyn_cast<BinaryOperator>(expr)) {
       return doBinaryOperator(binOp);
+    } else if (auto *unaryOp = dyn_cast<UnaryOperator>(expr)) {
+      return doUnaryOperator(unaryOp);
     }
 
     emitError("Expr '%0' is not supported yet.") << expr->getStmtClassName();
@@ -459,7 +543,8 @@ public:
     case BO_Sub:
     case BO_Mul:
     case BO_Div:
-    case BO_Rem: {
+    case BO_Rem:
+    case BO_LT: {
       const spv::Op spvOp = translateOp(opcode, elemType);
       return theBuilder.createBinaryOp(spvOp, typeId, lhs, rhs);
     }
@@ -475,6 +560,33 @@ public:
     return 0;
   }
 
+  uint32_t doUnaryOperator(const UnaryOperator *expr) {
+    const auto opcode = expr->getOpcode();
+    const auto *subExpr = expr->getSubExpr();
+    const auto subType = subExpr->getType();
+    const auto subValue = doExpr(subExpr);
+    const auto subTypeId = typeTranslator.translateType(subType);
+
+    switch (opcode) {
+    case UO_PreInc: {
+      const spv::Op spvOp = translateOp(BO_Add, subType);
+      const uint32_t one = getValueOne(subType);
+      const uint32_t originValue = theBuilder.createLoad(subTypeId, subValue);
+      const uint32_t incValue =
+          theBuilder.createBinaryOp(spvOp, subTypeId, originValue, one);
+      theBuilder.createStore(subValue, incValue);
+      // Prefix increment operator returns a lvalue.
+      return subValue;
+    }
+    default:
+      break;
+    }
+
+    emitError("unary operator '%0' unimplemented yet") << opcode;
+    expr->dump();
+    return 0;
+  }
+
   uint32_t doImplicitCastExpr(const ImplicitCastExpr *expr) {
     const Expr *subExpr = expr->getSubExpr();
     const QualType toType = expr->getType();
@@ -509,6 +621,8 @@ public:
     }
   }
 
+  /// Translates the given frontend binary operator into its SPIR-V equivalent
+  /// taking consideration of the operand type.
   spv::Op translateOp(BinaryOperator::Opcode op, QualType type) {
     // TODO: the following is not considering vector types yet.
     const bool isSintType = type->isSignedIntegerType();
@@ -562,6 +676,7 @@ case BO_##kind : {                                                             \
       //
       // Note there is no OpURem in SPIR-V.
       BIN_OP_CASE_SINT_UINT_FLOAT(Rem, SRem, UMod, FRem);
+      BIN_OP_CASE_SINT_UINT_FLOAT(LT, SLessThan, ULessThan, FOrdLessThan);
     default:
       break;
     }
@@ -573,6 +688,26 @@ case BO_##kind : {                                                             \
     return spv::Op::OpNop;
   }
 
+  /// Returns the <result-id> for constant value 1 of the given type.
+  uint32_t getValueOne(QualType type) {
+    if (type->isSignedIntegerType()) {
+      return theBuilder.getConstantInt32(1);
+    }
+
+    if (type->isUnsignedIntegerType()) {
+      return theBuilder.getConstantUint32(1);
+    }
+
+    if (type->isFloatingType()) {
+      return theBuilder.getConstantFloat32(1.0);
+    }
+
+    emitError("getting value 1 for type '%0' unimplemented") << type;
+    return 0;
+  }
+
+  /// Translates the given frontend APValue into its SPIR-V equivalent for the
+  /// given targetType.
   uint32_t translateAPValue(const APValue &value, const QualType targetType) {
     if (targetType->isBooleanType()) {
       const bool boolValue = value.getInt().getBoolValue();
@@ -593,6 +728,8 @@ case BO_##kind : {                                                             \
     return 0;
   }
 
+  /// Translates the given frontend APInt into its SPIR-V equivalent for the
+  /// given targetType.
   uint32_t translateAPInt(const llvm::APInt &intValue, QualType targetType) {
     const auto bitwidth = astContext.getIntWidth(targetType);
 
@@ -619,6 +756,8 @@ case BO_##kind : {                                                             \
     return 0;
   }
 
+  /// Translates the given frontend APFloat into its SPIR-V equivalent for the
+  /// given targetType.
   uint32_t translateAPFloat(const llvm::APFloat &floatValue,
                             QualType targetType) {
     const auto &semantics = astContext.getFloatTypeSemantics(targetType);

+ 19 - 4
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -153,6 +153,7 @@ uint32_t ModuleBuilder::createBinaryOp(spv::Op op, uint32_t resultType,
 
 void ModuleBuilder::createBranch(uint32_t targetLabel) {
   assert(insertPoint && "null insert point");
+
   instBuilder.opBranch(targetLabel).x();
   insertPoint->appendInstruction(std::move(constructSite));
 }
@@ -160,11 +161,25 @@ void ModuleBuilder::createBranch(uint32_t targetLabel) {
 void ModuleBuilder::createConditionalBranch(uint32_t condition,
                                             uint32_t trueLabel,
                                             uint32_t falseLabel,
-                                            uint32_t mergeLabel) {
+                                            uint32_t mergeLabel,
+                                            uint32_t continueLabel) {
   assert(insertPoint && "null insert point");
-  instBuilder.opSelectionMerge(mergeLabel, spv::SelectionControlMask::MaskNone)
-      .x();
-  insertPoint->appendInstruction(std::move(constructSite));
+
+  if (mergeLabel) {
+    if (continueLabel) {
+      instBuilder
+          .opLoopMerge(mergeLabel, continueLabel,
+                       spv::LoopControlMask::MaskNone)
+          .x();
+      insertPoint->appendInstruction(std::move(constructSite));
+    } else {
+      instBuilder
+          .opSelectionMerge(mergeLabel, spv::SelectionControlMask::MaskNone)
+          .x();
+      insertPoint->appendInstruction(std::move(constructSite));
+    }
+  }
+
   instBuilder.opBranchConditional(condition, trueLabel, falseLabel, {}).x();
   insertPoint->appendInstruction(std::move(constructSite));
 }

+ 23 - 0
tools/clang/test/CodeGenSPIRV/binary-op.comparison.scalar.hlsl

@@ -0,0 +1,23 @@
+// Run: %dxc -T ps_6_0 -E main
+
+void main() {
+    bool r;
+    int a, b;
+    uint i, j;
+    float o, p;
+
+// CHECK:      [[a:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[b:%\d+]] = OpLoad %int %b
+// CHECK-NEXT: %{{\d+}} = OpSLessThan %bool [[a]] [[b]]
+    r = a < b;
+
+// CHECK:      [[i:%\d+]] = OpLoad %uint %i
+// CHECK-NEXT: [[j:%\d+]] = OpLoad %uint %j
+// CHECK-NEXT: %{{\d+}} = OpULessThan %bool [[i]] [[j]]
+    r = i < j;
+
+// CHECK:      [[o:%\d+]] = OpLoad %float %o
+// CHECK-NEXT: [[p:%\d+]] = OpLoad %float %p
+// CHECK-NEXT: %{{\d+}} = OpFOrdLessThan %bool [[o]] [[p]]
+    r = o < p;
+}

+ 61 - 0
tools/clang/test/CodeGenSPIRV/for-stmt.plain.hlsl

@@ -0,0 +1,61 @@
+// Run: %dxc -T ps_6_0 -E main
+
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+    int val = 0;
+
+// CHECK: OpBranch %for_check
+// CHECK-LABEL: %for_check = OpLabel
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[lt0:%\d+]] = OpSLessThan %bool [[i0]] %int_10
+// CHECK-NEXT: OpLoopMerge %for_merge %for_continue None
+// CHECK-NEXT: OpBranchConditional [[lt0]] %for_body %for_merge
+    for (int i = 0; i < 10; ++i) {
+// CHECK-LABEL: %for_body = OpLabel
+// CHECK-NEXT: [[i1:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: OpStore %val [[i1]]
+// CHECK-NEXT: OpBranch %for_continue
+        val = i;
+// CHECK-LABEL: %for_continue = OpLabel
+// CHECK-NEXT: [[i2:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[add0:%\d+]] = OpIAdd %int [[i2]] %int_1
+// CHECK-NEXT: OpStore %i [[add0]]
+// CHECK-NEXT: OpBranch %for_check
+    }
+
+// CHECK-LABEL: %for_merge = OpLabel
+// CHECK-NEXT: OpBranch %for_check_0
+// CHECK-LABEL: %for_check_0 = OpLabel
+// CHECK-NEXT: OpLoopMerge %for_merge_0 %for_continue_0 None
+// CHECK-NEXT: OpBranchConditional %true %for_body_0 %for_merge_0
+    // Infinite loop
+    for ( ; ; ) {
+// CHECK-LABEL: %for_body_0 = OpLabel
+// CHECK-NEXT: OpStore %val %int_0
+// CHECK-NEXT: OpBranch %for_continue_0
+        val = 0;
+// CHECK-LABEL: %for_continue_0 = OpLabel
+// CHECK-NEXT: OpBranch %for_check_0
+    }
+// CHECK-LABEL: %for_merge_0 = OpLabel
+// CHECK: OpBranch %for_check_1
+
+    // Null body
+// CHECK-LABEL: %for_check_1 = OpLabel
+// CHECK-NEXT: [[j0:%\d+]] = OpLoad %int %j
+// CHECK-NEXT: [[lt1:%\d+]] = OpSLessThan %bool [[j0]] %int_10
+// CHECK-NEXT: OpLoopMerge %for_merge_1 %for_continue_1 None
+// CHECK-NEXT: OpBranchConditional [[lt1]] %for_body_1 %for_merge_1
+    for (int j = 0; j < 10; ++j)
+// CHECK-LABEL: %for_body_1 = OpLabel
+// CHECK-NEXT: OpBranch %for_continue_1
+        ;
+// CHECK-LABEL: %for_continue_1 = OpLabel
+// CHECK-NEXT: [[j1:%\d+]] = OpLoad %int %j
+// CHECK-NEXT: [[add1:%\d+]] = OpIAdd %int [[j1]] %int_1
+// CHECK-NEXT: OpStore %j [[add1]]
+// CHECK-NEXT: OpBranch %for_check_1
+
+// CHECK-LABEL: %for_merge_1 = OpLabel
+// CHECK-NEXT: OpReturn
+}

+ 64 - 0
tools/clang/test/CodeGenSPIRV/unary-op.prefix-inc.hlsl

@@ -0,0 +1,64 @@
+// Run: %dxc -T ps_6_0 -E main
+
+void main() {
+    int a, b;
+// CHECK:      [[a0:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[a1:%\d+]] = OpIAdd %int [[a0]] %int_1
+// CHECK-NEXT: OpStore %a [[a1]]
+// CHECK-NEXT: [[a2:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: OpStore %b [[a2]]
+    b = ++a;
+// CHECK-NEXT: [[b0:%\d+]] = OpLoad %int %b
+// CHECK-NEXT: [[a3:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[a4:%\d+]] = OpIAdd %int [[a3]] %int_1
+// CHECK-NEXT: OpStore %a [[a4]]
+// CHECK-NEXT: OpStore %a [[b0]]
+    ++a = b;
+
+// Spot check a complicated usage case. No need to duplicate it for all types.
+
+// CHECK-NEXT: [[b1:%\d+]] = OpLoad %int %b
+// CHECK-NEXT: [[b2:%\d+]] = OpIAdd %int [[b1]] %int_1
+// CHECK-NEXT: OpStore %b [[b2]]
+// CHECK-NEXT: [[b3:%\d+]] = OpLoad %int %b
+// CHECK-NEXT: [[b4:%\d+]] = OpIAdd %int [[b3]] %int_1
+// CHECK-NEXT: OpStore %b [[b4]]
+// CHECK-NEXT: [[b5:%\d+]] = OpLoad %int %b
+
+// CHECK-NEXT: [[a5:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[a6:%\d+]] = OpIAdd %int [[a5]] %int_1
+// CHECK-NEXT: OpStore %a [[a6]]
+// CHECK-NEXT: [[a7:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[a8:%\d+]] = OpIAdd %int [[a7]] %int_1
+// CHECK-NEXT: OpStore %a [[a8]]
+// CHECK-NEXT: OpStore %a [[b5]]
+    ++(++a) = ++(++b);
+
+    uint i, j;
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %uint %i
+// CHECK-NEXT: [[i1:%\d+]] = OpIAdd %uint [[i0]] %uint_1
+// CHECK-NEXT: OpStore %i [[i1]]
+// CHECK-NEXT: [[i2:%\d+]] = OpLoad %uint %i
+// CHECK-NEXT: OpStore %j [[i2]]
+    j = ++i;
+// CHECK-NEXT: [[j0:%\d+]] = OpLoad %uint %j
+// CHECK-NEXT: [[i3:%\d+]] = OpLoad %uint %i
+// CHECK-NEXT: [[i4:%\d+]] = OpIAdd %uint [[i3]] %uint_1
+// CHECK-NEXT: OpStore %i [[i4]]
+// CHECK-NEXT: OpStore %i [[j0]]
+    ++i = j;
+
+    float o, p;
+// CHECK-NEXT: [[o0:%\d+]] = OpLoad %float %o
+// CHECK-NEXT: [[01:%\d+]] = OpFAdd %float [[o0]] %float_1
+// CHECK-NEXT: OpStore %o [[01]]
+// CHECK-NEXT: [[o2:%\d+]] = OpLoad %float %o
+// CHECK-NEXT: OpStore %p [[o2]]
+    p = ++o;
+// CHECK-NEXT: [[p0:%\d+]] = OpLoad %float %p
+// CHECK-NEXT: [[o3:%\d+]] = OpLoad %float %o
+// CHECK-NEXT: [[o4:%\d+]] = OpFAdd %float [[o3]] %float_1
+// CHECK-NEXT: OpStore %o [[o4]]
+// CHECK-NEXT: OpStore %o [[p0]]
+    ++o = p;
+}

+ 10 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -38,14 +38,24 @@ TEST_F(FileTest, ScalarTypes) { runFileTest("type.scalar.hlsl"); }
 
 TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }
 
+TEST_F(FileTest, UnaryOpPrefixIncrement) {
+  runFileTest("unary-op.prefix-inc.hlsl");
+}
+
 TEST_F(FileTest, BinaryOpAssign) { runFileTest("binary-op.assign.hlsl"); }
 
 TEST_F(FileTest, BinaryOpScalarArithmetic) {
   runFileTest("binary-op.arithmetic.scalar.hlsl");
 }
 
+TEST_F(FileTest, BinaryOpScalarComparison) {
+  runFileTest("binary-op.comparison.scalar.hlsl");
+}
+
 TEST_F(FileTest, IfStmtPlainAssign) { runFileTest("if-stmt.plain.hlsl"); }
 
 TEST_F(FileTest, IfStmtNestedIfStmt) { runFileTest("if-stmt.nested.hlsl"); }
 
+TEST_F(FileTest, ForStmtPlainAssign) { runFileTest("for-stmt.plain.hlsl"); }
+
 } // namespace