Browse Source

[spirv] Handle early return and discard statements (#546)

Ehsan 8 years ago
parent
commit
546597d994

+ 3 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -155,6 +155,9 @@ public:
                     uint32_t defaultLabel,
                     llvm::ArrayRef<std::pair<uint32_t, uint32_t>> target);
 
+  /// \brief Creates a fragment-shader discard via by emitting OpKill.
+  void createKill();
+
   /// \brief Creates an unconditional branch to the given target label.
   /// If mergeBB and continueBB are non-zero, it creates an OpLoopMerge
   /// instruction followed by an unconditional branch to the given target label.

+ 7 - 0
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -243,6 +243,13 @@ void ModuleBuilder::createSwitch(
   insertPoint->appendInstruction(std::move(constructSite));
 }
 
+void ModuleBuilder::createKill() {
+  assert(insertPoint && "null insert point");
+  assert(!isCurrentBasicBlockTerminated());
+  instBuilder.opKill().x();
+  insertPoint->appendInstruction(std::move(constructSite));
+}
+
 void ModuleBuilder::createBranch(uint32_t targetLabel, uint32_t mergeBB,
                                  uint32_t continueBB,
                                  spv::LoopControlMask loopControl) {

+ 34 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -262,6 +262,8 @@ void SPIRVEmitter::doStmt(const Stmt *stmt,
     doBreakStmt(breakStmt);
   } else if (const auto *theDoStmt = dyn_cast<DoStmt>(stmt)) {
     doDoStmt(theDoStmt, attrs);
+  } else if (const auto *discardStmt = dyn_cast<DiscardStmt>(stmt)) {
+    doDiscardStmt(discardStmt);
   } else if (const auto *continueStmt = dyn_cast<ContinueStmt>(stmt)) {
     doContinueStmt(continueStmt);
   } else if (const auto *whileStmt = dyn_cast<WhileStmt>(stmt)) {
@@ -383,6 +385,11 @@ uint32_t SPIRVEmitter::castToType(uint32_t value, QualType fromType,
 }
 
 void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
+  // We are about to start translation for a new function. Clear the break stack
+  // and the continue stack.
+  breakStack = std::stack<uint32_t>();
+  continueStack = std::stack<uint32_t>();
+
   curFunction = decl;
 
   const llvm::StringRef funcName = decl->getName();
@@ -538,6 +545,16 @@ spv::LoopControlMask SPIRVEmitter::translateLoopAttribute(const Attr &attr) {
   return spv::LoopControlMask::MaskNone;
 }
 
+void SPIRVEmitter::doDiscardStmt(const DiscardStmt *discardStmt) {
+  assert(!theBuilder.isCurrentBasicBlockTerminated());
+  theBuilder.createKill();
+  if (!isLastStmtBeforeControlFlowBranching(astContext, discardStmt)) {
+    const uint32_t unreachableBB =
+        theBuilder.createBasicBlock("unreachable", /*isReachable*/ false);
+    theBuilder.setInsertPoint(unreachableBB);
+  }
+}
+
 void SPIRVEmitter::doDoStmt(const DoStmt *theDoStmt,
                             llvm::ArrayRef<const Attr *> attrs) {
   // do-while loops are composed of:
@@ -934,6 +951,17 @@ void SPIRVEmitter::doIfStmt(const IfStmt *ifStmt) {
 }
 
 void SPIRVEmitter::doReturnStmt(const ReturnStmt *stmt) {
+  processReturnStmt(stmt);
+
+  // Handle early returns
+  if (!isLastStmtBeforeControlFlowBranching(astContext, stmt)) {
+    const uint32_t unreachableBB =
+        theBuilder.createBasicBlock("unreachable", /*isReachable*/ false);
+    theBuilder.setInsertPoint(unreachableBB);
+  }
+}
+
+void SPIRVEmitter::processReturnStmt(const ReturnStmt *stmt) {
   // For normal functions, just return in the normal way.
   if (curFunction->getName() != entryFunctionName) {
     theBuilder.createReturnValue(doExpr(stmt->getRetValue()));
@@ -958,6 +986,12 @@ void SPIRVEmitter::doReturnStmt(const ReturnStmt *stmt) {
     return;
   }
 
+  // RetValue is nullptr when "return;" is used for a void function.
+  if (!stmt->getRetValue()) {
+    theBuilder.createReturn();
+    return;
+  }
+
   QualType retType = stmt->getRetValue()->getType();
 
   if (const auto *structType = retType->getAsStructureType()) {

+ 7 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -74,6 +74,7 @@ private:
   void doVarDecl(const VarDecl *decl);
 
   void doBreakStmt(const BreakStmt *stmt);
+  void doDiscardStmt(const DiscardStmt *stmt);
   inline void doDeclStmt(const DeclStmt *stmt);
   void doForStmt(const ForStmt *, llvm::ArrayRef<const Attr *> attrs = {});
   void doIfStmt(const IfStmt *ifStmt);
@@ -96,6 +97,12 @@ private:
   uint32_t doMemberExpr(const MemberExpr *expr);
   uint32_t doUnaryOperator(const UnaryOperator *expr);
 
+private:
+  /// Translates the return statement into its SPIR-V equivalent. Also generates
+  /// necessary instructions for the entry function ensuring that the signature
+  /// matches the SPIR-V requirements.
+  void processReturnStmt(const ReturnStmt *stmt);
+
 private:
   /// Translates the given frontend binary operator into its SPIR-V equivalent
   /// taking consideration of the operand type.

+ 31 - 0
tools/clang/test/CodeGenSPIRV/cf.discard.hlsl

@@ -0,0 +1,31 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// According to the HLS spec, discard can only be called from a pixel shader.
+// This translates to OpKill in SPIR-V. OpKill must be the last instruction in a block.
+
+void main() {
+  int a, b;
+  bool cond = true;
+  
+  while(cond) {
+// CHECK: %while_body = OpLabel
+    if(a==b) {
+// CHECK: %if_true = OpLabel
+// CHECK-NEXT: OpKill
+      {{discard;}}
+      discard;  // No SPIR-V should be emitted for this statement.
+      break;    // No SPIR-V should be emitted for this statement.
+    } else {
+// CHECK-NEXT: %if_false = OpLabel
+      ++a;
+// CHECK: OpKill
+      discard;
+      continue; // No SPIR-V should be emitted for this statement.
+      --b;      // No SPIR-V should be emitted for this statement.
+    }
+// CHECK-NEXT: %if_merge = OpLabel
+
+  }
+// CHECK: %while_merge = OpLabel
+
+}

+ 88 - 0
tools/clang/test/CodeGenSPIRV/cf.return.early.float4.hlsl

@@ -0,0 +1,88 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// CHECK: [[v4f1:%\d+]] = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
+// CHECK: [[v4f2:%\d+]] = OpConstantComposite %v4float %float_2 %float_2 %float_2 %float_2
+// CHECK: [[v4f3:%\d+]] = OpConstantComposite %v4float %float_3 %float_3 %float_3 %float_3
+// CHECK: [[v4f4:%\d+]] = OpConstantComposite %v4float %float_4 %float_4 %float_4 %float_4
+// CHECK: [[v4f5:%\d+]] = OpConstantComposite %v4float %float_5 %float_5 %float_5 %float_5
+// CHECK: [[v4f6:%\d+]] = OpConstantComposite %v4float %float_6 %float_6 %float_6 %float_6
+// CHECK: [[v4f7:%\d+]] = OpConstantComposite %v4float %float_7 %float_7 %float_7 %float_7
+// CHECK: [[v4f8:%\d+]] = OpConstantComposite %v4float %float_8 %float_8 %float_8 %float_8
+// CHECK: [[v4f9:%\d+]] = OpConstantComposite %v4float %float_9 %float_9 %float_9 %float_9
+
+float4 myfunc() {
+  int a, b;
+  bool cond = true;
+
+  while(cond) {
+    switch(b) {
+// CHECK: %switch_1 = OpLabel
+      case 1:
+        a = 1;
+// CHECK: OpReturnValue [[v4f1]]
+        return float4(1.0, 1.0, 1.0, 1.0);
+// CHECK-NEXT: %switch_2 = OpLabel
+      case 2: {
+        a = 3;
+// CHECK: OpReturnValue [[v4f2]]
+        {return float4(2.0, 2.0, 2.0, 2.0);}   // Return from function.
+        a = 4;                                 // No SPIR-V should be emitted for this statement.
+        break;                                 // No SPIR-V should be emitted for this statement.
+      }
+// CHECK-NEXT: %switch_5 = OpLabel
+      case 5 : {
+        a = 5;
+// CHECK: OpReturnValue [[v4f3]]
+        {{return float4(3.0, 3.0, 3.0, 3.0);}} // Return from function.
+        a = 6;                                 // No SPIR-V should be emitted for this statement.
+      }
+// CHECK-NEXT: %switch_default = OpLabel
+      default:
+        for (int i=0; i<10; ++i) {
+          if (cond) {
+// CHECK: %if_true = OpLabel
+// CHECK-NEXT: OpReturnValue [[v4f4]]
+            return float4(4.0, 4.0, 4.0, 4.0);    // Return from function.
+            return float4(5.0, 5.0, 5.0, 5.0);    // No SPIR-V should be emitted for this statement.
+            continue;                             // No SPIR-V should be emitted for this statement.
+            break;                                // No SPIR-V should be emitted for this statement.
+            ++a;                                  // No SPIR-V should be emitted for this statement.
+          } else {
+// CHECK-NEXT: %if_false = OpLabel
+// CHECK-NEXT: OpReturnValue [[v4f6]]
+            return float4(6.0, 6.0, 6.0, 6.0);;   // Return from function
+            continue;                             // No SPIR-V should be emitted for this statement.
+            break;                                // No SPIR-V should be emitted for this statement.
+            ++a;                                  // No SPIR-V should be emitted for this statement.
+          }
+        }
+// CHECK: %for_merge = OpLabel
+
+// CHECK-NEXT: OpReturnValue [[v4f7]]
+        // Return from function.
+        // Even though this statement will never be executed [because both "if" and "else" above have return statements],
+        // SPIR-V code should be emitted for it as we do not analyze the logic.
+        return float4(7.0, 7.0, 7.0, 7.0);
+    }
+// CHECK: %switch_merge = OpLabel
+
+// CHECK-NEXT: OpReturnValue [[v4f8]]
+    // Return from function.
+    // Even though this statement will never be executed [because all "case" statements above contain a return statement],
+    // SPIR-V code should be emitted for it as we do not analyze the logic.
+    return float4(8.0, 8.0, 8.0, 8.0);
+  }
+// CHECK: %while_merge = OpLabel
+
+// CHECK-NEXT: OpReturnValue [[v4f9]]
+  // Return from function.
+  // Even though this statement will never be executed [because any iteration of the loop above executes a return statement],
+  // SPIR-V code should be emitted for it as we do not analyze the logic.
+  return float4(9.0, 9.0, 9.0, 9.0);
+
+// CHECK-NEXT: OpFunctionEnd
+}
+
+void main() {
+  float4 result = myfunc();
+}

+ 69 - 0
tools/clang/test/CodeGenSPIRV/cf.return.early.hlsl

@@ -0,0 +1,69 @@
+// Run: %dxc -T ps_6_0 -E main
+
+void main() {
+  int a, b;
+  bool cond = true;
+
+  while(cond) {
+    switch(b) {
+// CHECK: %switch_1 = OpLabel    
+      case 1:
+        a = 1;
+// CHECK: OpReturn
+        return;
+// CHECK-NEXT: %switch_2 = OpLabel
+      case 2: {
+        a = 3;
+// CHECK: OpReturn
+        {return;}   // Return from function.
+        a = 4;      // No SPIR-V should be emitted for this statement.
+        break;      // No SPIR-V should be emitted for this statement.
+      }
+// CHECK-NEXT: %switch_5 = OpLabel
+      case 5 : {
+        a = 5;
+// CHECK: OpReturn
+        {{return;}} // Return from function.
+        a = 6;      // No SPIR-V should be emitted for this statement.
+      }
+// CHECK-NEXT: %switch_default = OpLabel
+      default:
+        for (int i=0; i<10; ++i) {
+          if (cond) {
+// CHECK: %if_true = OpLabel
+// CHECK-NEXT: OpReturn
+            return;    // Return from function.
+            return;    // No SPIR-V should be emitted for this statement.
+            continue;  // No SPIR-V should be emitted for this statement.
+            break;     // No SPIR-V should be emitted for this statement.
+            ++a;       // No SPIR-V should be emitted for this statement.
+          } else {
+// CHECK-NEXT: %if_false = OpLabel
+// CHECK-NEXT: OpReturn
+            return;   // Return from function
+            continue; // No SPIR-V should be emitted for this statement.
+            break;    // No SPIR-V should be emitted for this statement.
+            ++a;      // No SPIR-V should be emitted for this statement.
+          }
+        }
+// CHECK: %for_merge = OpLabel
+
+// CHECK-NEXT: OpReturn
+        // Return from function.
+        // Even though this statement will never be executed [because both "if" and "else" above have return statements],
+        // SPIR-V code should be emitted for it as we do not analyze the logic.
+        return;
+    }
+// CHECK: %switch_merge = OpLabel
+
+// CHECK-NEXT: OpReturn
+    // Return from function.
+    // Even though this statement will never be executed [because all "case" statements above contain a return statement],
+    // SPIR-V code should be emitted for it as we do not analyze the logic.
+    return;
+  }
+// CHECK: %while_merge = OpLabel
+
+// CHECK-NEXT: OpReturn
+// CHECK-NEXT: OpFunctionEnd
+}

+ 7 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -233,6 +233,13 @@ TEST_F(FileTest, ControlFlowConditionalOp) { runFileTest("cf.cond-op.hlsl"); }
 // For function calls
 TEST_F(FileTest, FunctionCall) { runFileTest("fn.call.hlsl"); }
 
+// For early returns
+TEST_F(FileTest, EarlyReturn) { runFileTest("cf.return.early.hlsl"); }
+TEST_F(FileTest, EarlyReturnFloat4) { runFileTest("cf.return.early.float4.hlsl"); }
+
+// For discard
+TEST_F(FileTest, Discard) { runFileTest("cf.discard.hlsl"); }
+
 // For semantics
 TEST_F(FileTest, SemanticPositionVS) {
   runFileTest("semantic.position.vs.hlsl");