Browse Source

[spirv] Support function foward declaration (#847)

Lei Zhang 7 years ago
parent
commit
04a5269451

+ 18 - 2
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -445,8 +445,8 @@ void SPIRVEmitter::doStmt(const Stmt *stmt,
 SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
 SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
   expr = expr->IgnoreParens();
   expr = expr->IgnoreParens();
 
 
-  if (const auto *delRefExpr = dyn_cast<DeclRefExpr>(expr)) {
-    return declIdMapper.getDeclResultId(delRefExpr->getFoundDecl());
+  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
+    return declIdMapper.getDeclResultId(declRefExpr->getDecl());
   }
   }
 
 
   if (const auto *memberExpr = dyn_cast<MemberExpr>(expr)) {
   if (const auto *memberExpr = dyn_cast<MemberExpr>(expr)) {
@@ -557,6 +557,8 @@ uint32_t SPIRVEmitter::castToType(uint32_t value, QualType fromType,
 }
 }
 
 
 void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
 void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
+  assert(decl->isThisDeclarationADefinition());
+
   // A RAII class for maintaining the current function under traversal.
   // A RAII class for maintaining the current function under traversal.
   class FnEnvRAII {
   class FnEnvRAII {
   public:
   public:
@@ -1391,6 +1393,20 @@ SpirvEvalInfo SPIRVEmitter::doCallExpr(const CallExpr *callExpr) {
 SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
 SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
   const FunctionDecl *callee = callExpr->getDirectCallee();
   const FunctionDecl *callee = callExpr->getDirectCallee();
 
 
+  // If we are calling a forward-declared function, callee will be the
+  // FunctionDecl for the foward-declared function, not the actual
+  // definition. The foward-delcaration and defintion are two completely
+  // different AST nodes.
+  // Note that we always want the defintion because Stmts/Exprs in the
+  // function body references the parameters in the definition.
+  if (!callee->isThisDeclarationADefinition()) {
+    // We need to update callee to the actual definition here
+    if (!callee->isDefined(callee)) {
+      emitError("found undefined function", callExpr->getExprLoc());
+      return 0;
+    }
+  }
+
   if (callee) {
   if (callee) {
     const auto numParams = callee->getNumParams();
     const auto numParams = callee->getNumParams();
     bool isNonStaticMemberCall = false;
     bool isNonStaticMemberCall = false;

+ 19 - 0
tools/clang/test/CodeGenSPIRV/fn.foward-declaration.hlsl

@@ -0,0 +1,19 @@
+// Run: %dxc -T ps_6_0 -E main
+
+float4 foo(float4 input);
+
+float4 main(float4 input: A) : SV_Target0
+{
+    return foo(input);
+}
+
+float4 foo(float4 input)
+{
+    return input;
+}
+
+// CHECK:  %src_main = OpFunction %v4float None {{%\d+}}
+
+// CHECK:              OpFunctionCall %v4float %foo
+
+// CHECK:       %foo = OpFunction %v4float None {{%\d+}}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -367,6 +367,9 @@ TEST_F(FileTest, ControlFlowConditionalOp) { runFileTest("cf.cond-op.hlsl"); }
 TEST_F(FileTest, FunctionCall) { runFileTest("fn.call.hlsl"); }
 TEST_F(FileTest, FunctionCall) { runFileTest("fn.call.hlsl"); }
 TEST_F(FileTest, FunctionDefaultArg) { runFileTest("fn.default-arg.hlsl"); }
 TEST_F(FileTest, FunctionDefaultArg) { runFileTest("fn.default-arg.hlsl"); }
 TEST_F(FileTest, FunctionInOutParam) { runFileTest("fn.param.inout.hlsl"); }
 TEST_F(FileTest, FunctionInOutParam) { runFileTest("fn.param.inout.hlsl"); }
+TEST_F(FileTest, FunctionFowardDeclaration) {
+  runFileTest("fn.foward-declaration.hlsl");
+}
 
 
 // For OO features
 // For OO features
 TEST_F(FileTest, StructMethodCall) { runFileTest("oo.struct.method.hlsl"); }
 TEST_F(FileTest, StructMethodCall) { runFileTest("oo.struct.method.hlsl"); }