Browse Source

[spirv] Handle variable definitions in if statements (#536)

Also use BinaryOperator methods to test compound assignment and
add tests for typedefs.
Lei Zhang 8 years ago
parent
commit
c5a2da0275

+ 10 - 25
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -104,24 +104,6 @@ bool isFloatOrVecMatOfFloatType(QualType type) {
           hlsl::GetHLSLMatElementType(type)->isFloatingType());
 }
 
-bool isCompoundAssignment(BinaryOperatorKind opcode) {
-  switch (opcode) {
-  case BO_AddAssign:
-  case BO_SubAssign:
-  case BO_MulAssign:
-  case BO_DivAssign:
-  case BO_RemAssign:
-  case BO_AndAssign:
-  case BO_OrAssign:
-  case BO_XorAssign:
-  case BO_ShlAssign:
-  case BO_ShrAssign:
-    return true;
-  default:
-    return false;
-  }
-}
-
 bool isSpirvMatrixOp(spv::Op opcode) {
   switch (opcode) {
   case spv::Op::OpMatrixTimesMatrix:
@@ -210,9 +192,7 @@ void SPIRVEmitter::doStmt(const Stmt *stmt,
   } else if (const auto *retStmt = dyn_cast<ReturnStmt>(stmt)) {
     doReturnStmt(retStmt);
   } else if (const auto *declStmt = dyn_cast<DeclStmt>(stmt)) {
-    for (auto *decl : declStmt->decls()) {
-      doDecl(decl);
-    }
+    doDeclStmt(declStmt);
   } else if (const auto *ifStmt = dyn_cast<IfStmt>(stmt)) {
     doIfStmt(ifStmt);
   } else if (const auto *switchStmt = dyn_cast<SwitchStmt>(stmt)) {
@@ -698,6 +678,9 @@ void SPIRVEmitter::doIfStmt(const IfStmt *ifStmt) {
   //         +-> | merge | <-+                  +---> | merge |
   //             +-------+                            +-------+
 
+  if (const auto *declStmt = ifStmt->getConditionVariableDeclStmt())
+    doDeclStmt(declStmt);
+
   // First emit the instruction for evaluating the condition.
   const uint32_t condition = doExpr(ifStmt->getCond());
 
@@ -1531,7 +1514,7 @@ uint32_t SPIRVEmitter::processBinaryOp(const Expr *lhs, const Expr *rhs,
                             : mandateGenOpcode;
 
   uint32_t rhsVal, lhsPtr, lhsVal;
-  if (isCompoundAssignment(opcode)) {
+  if (BinaryOperator::isCompoundAssignmentOp(opcode)) {
     // Evalute rhs before lhs
     rhsVal = doExpr(rhs);
     lhsVal = lhsPtr = doExpr(lhs);
@@ -2021,7 +2004,7 @@ uint32_t SPIRVEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
   const spv::Op spvOp = translateOp(opcode, lhsType);
 
   uint32_t rhsVal, lhsPtr, lhsVal;
-  if (isCompoundAssignment(opcode)) {
+  if (BinaryOperator::isCompoundAssignmentOp(opcode)) {
     // Evalute rhs before lhs
     rhsVal = doExpr(rhs);
     lhsPtr = doExpr(lhs);
@@ -2728,7 +2711,7 @@ void SPIRVEmitter::processSwitchStmtUsingSpirvOpSwitch(
   // For example: handle 'int a = b' in the following:
   // switch (int a = b) {...}
   if (const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt())
-    doStmt(condVarDeclStmt);
+    doDeclStmt(condVarDeclStmt);
 
   const uint32_t selector = doExpr(switchStmt->getCond());
 
@@ -2765,7 +2748,7 @@ void SPIRVEmitter::processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt) {
   // For example: handle 'int a = b' in the following:
   // switch (int a = b) {...}
   if (const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt())
-    doStmt(condVarDeclStmt);
+    doDeclStmt(condVarDeclStmt);
 
   // Figure out the indexes of CaseStmts (and DefaultStmt if it exists) in
   // the flattened switch AST.
@@ -2813,6 +2796,8 @@ void SPIRVEmitter::processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt) {
       bo->setType(astContext.getLogicalOperationType());
       curIf->setCond(bo);
       curIf->setThen(cs);
+      // No conditional variable associated with this faux if statement.
+      curIf->setConditionVariable(astContext, nullptr);
       // Each If statement is the "else" of the previous if statement.
       if (prevIfStmt)
         prevIfStmt->setElse(curIf);

+ 7 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -72,12 +72,13 @@ private:
   void doVarDecl(const VarDecl *decl);
 
   void doBreakStmt(const BreakStmt *stmt);
-  void doWhileStmt(const WhileStmt *, llvm::ArrayRef<const Attr *> attrs = {});
+  inline void doDeclStmt(const DeclStmt *stmt);
   void doForStmt(const ForStmt *, llvm::ArrayRef<const Attr *> attrs = {});
   void doIfStmt(const IfStmt *ifStmt);
   void doReturnStmt(const ReturnStmt *stmt);
   void doSwitchStmt(const SwitchStmt *stmt,
                     llvm::ArrayRef<const Attr *> attrs = {});
+  void doWhileStmt(const WhileStmt *, llvm::ArrayRef<const Attr *> attrs = {});
 
   uint32_t doBinaryOperator(const BinaryOperator *expr);
   uint32_t doCallExpr(const CallExpr *callExpr);
@@ -414,6 +415,11 @@ private:
   llvm::DenseMap<const Stmt *, uint32_t> stmtBasicBlock;
 };
 
+void SPIRVEmitter::doDeclStmt(const DeclStmt *declStmt) {
+  for (auto *decl : declStmt->decls())
+    doDecl(decl);
+}
+
 } // end namespace spirv
 } // end namespace clang
 

+ 14 - 0
tools/clang/test/CodeGenSPIRV/if-stmt.plain.hlsl

@@ -65,5 +65,19 @@ void main() {
         ;
 
 // CHECK-LABEL: %if_merge_2 = OpLabel
+
+// CHECK-NEXT: [[val4:%\d+]] = OpLoad %int %val
+// CHECK-NEXT: OpStore %d [[val4]]
+// CHECK-NEXT: [[d:%\d+]] = OpLoad %int %d
+// CHECK-NEXT: [[cmp:%\d+]] = OpINotEqual %bool [[d]] %int_0
+// CHECK-NEXT: OpSelectionMerge %if_merge_3 None
+// CHECK-NEXT: OpBranchConditional [[cmp]] %if_true_3 %if_merge_3
+    if (int d = val) {
+// CHECK-LABEL: %if_true_3 = OpLabel
+// CHECK-NEXT: OpStore %c %true
+        c = true;
+// CHECK-NEXT: OpBranch %if_merge_3
+// CHECK-LABEL:%if_merge_3 = OpLabel
+    }
 // CHECK-NEXT: OpReturn
 }

+ 19 - 0
tools/clang/test/CodeGenSPIRV/type.typedef.hlsl

@@ -0,0 +1,19 @@
+// Run: %dxc -T vs_6_0 -E main
+
+typedef int myInt;
+typedef const uint myConstUint;
+typedef float4 v4f;
+typedef float2x3 m2v3f;
+
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+// CHECK: %v1 = OpVariable %_ptr_Function_int Function
+    myInt v1;
+// CHECK: %v2 = OpVariable %_ptr_Function_uint Function
+    myConstUint v2;
+// CHECK: %v3 = OpVariable %_ptr_Function_v4float Function
+    v4f v3;
+// CHECK: %v4 = OpVariable %_ptr_Function_mat2v3float Function
+    m2v3f v4;
+}

+ 1 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -38,6 +38,7 @@ TEST_F(WholeFileTest, ConstantPixelShader) {
 TEST_F(FileTest, ScalarTypes) { runFileTest("type.scalar.hlsl"); }
 TEST_F(FileTest, VectorTypes) { runFileTest("type.vector.hlsl"); }
 TEST_F(FileTest, MatrixTypes) { runFileTest("type.matrix.hlsl"); }
+TEST_F(FileTest, TypedefTypes) { runFileTest("type.typedef.hlsl"); }
 
 // For constants
 TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }