2
0
Эх сурвалжийг харах

[spirv] Translate non-entry functions and function calls (#487)

Lei Zhang 8 жил өмнө
parent
commit
dc240f47e7

+ 46 - 1
docs/SPIR-V.rst

@@ -149,7 +149,7 @@ Normal scalar types
 Minimal precision scalar types
 ++++++++++++++++++++++++++++++
 
-HLSL also supports various `minimal precision scalar types <https://msdn.microsoft.com/en-us/library/windows/desktop/bb509646(v=vs.85).aspx>`_, which graphics drivers can implement by using any precision greater than or equal to their specified bit precision. 
+HLSL also supports various `minimal precision scalar types <https://msdn.microsoft.com/en-us/library/windows/desktop/bb509646(v=vs.85).aspx>`_, which graphics drivers can implement by using any precision greater than or equal to their specified bit precision.
 
 - ``min16float`` - minimum 16-bit floating point value
 - ``min10float`` - minimum 10-bit floating point value
@@ -260,6 +260,51 @@ Control flows
 
 [TODO]
 
+Functions
+---------
+
+All functions reachable from the entry-point function will be translated into SPIR-V code. Functions not reachable from the entry-point function will be ignored.
+
+Function parameter
+++++++++++++++++++
+
+For a function ``f`` which has a parameter of type ``T``, the generated SPIR-V signature will use type ``T*`` for the parameter. At every call site of ``f``, additional local variables will be allocated to hold the actual arguments. The local variables are passed in as direct function arguments. For example::
+
+  // HLSL source code
+
+  float4 f(float a, int b) { ... }
+
+  void caller(...) {
+    ...
+    float4 result = f(...);
+    ...
+  }
+
+  // SPIR-V code
+
+                ...
+  %i32PtrType = OpTypePointer Function %int
+  %f32PtrType = OpTypePointer Function %float
+      %fnType = OpTypeFunction %v4float %f32PtrType %i32PtrType
+                ...
+
+           %f = OpFunction %v4float None %fnType
+           %a = OpFunctionParameter %f32PtrType
+           %b = OpFunctionParameter %i32PtrType
+                ...
+
+      %caller = OpFunction ...
+                ...
+     %aAlloca = OpVariable %_ptr_Function_float Function
+     %bAlloca = OpVariable %_ptr_Function_int Function
+                ...
+                OpStore %aAlloca ...
+                OpStore %bAlloca ...
+      %result = OpFunctioncall %v4float %f %aAlloca %bAlloca
+                ...
+
+This approach gives us unified handling of function parameters and local variables: both of them are accessed via load/store instructions.
+
 Builtin functions
 -----------------
 

+ 14 - 7
tools/clang/include/clang/SPIRV/DeclResultIdMapper.h

@@ -59,22 +59,29 @@ public:
   /// the given function's parameter.
   void createStageVarFromFnParam(const ParmVarDecl *paramDecl);
 
-  /// \brief Registers a Decl's <result-id> without generating any SPIR-V
-  /// instruction.
+  /// \brief Registers a decl's <result-id> without generating any SPIR-V
+  /// instruction. The given decl will be treated as normal decl.
   void registerDeclResultId(const NamedDecl *symbol, uint32_t resultId);
 
   /// \brief Returns true if the given <result-id> is for a stage variable.
   bool isStageVariable(uint32_t varId) const;
 
-  /// \brief Returns the <result-id> for the given Decl.
+  /// \brief Returns the <result-id> for the given decl.
+  ///
+  /// This method will panic if the given decl is not registered.
   uint32_t getDeclResultId(const NamedDecl *decl) const;
 
-  /// \brief Returns the <result-id> for the given remapped Decl. Returns zero
-  /// if it is not a registered remapped Decl.
+  /// \brief Returns the <result-id> for the given decl if already registered;
+  /// otherwise, treats the given decl as a normal decl and returns a newly
+  /// assigned <result-id> for it.
+  uint32_t getOrRegisterDeclResultId(const NamedDecl *decl);
+
+  /// \brief Returns the <result-id> for the given remapped decl. Returns zero
+  /// if it is not a registered remapped decl.
   uint32_t getRemappedDeclResultId(const NamedDecl *decl) const;
 
-  /// \brief Returns the <result-id> for the given normal Decl. Returns zero if
-  /// it is not a registered normal Decl.
+  /// \brief Returns the <result-id> for the given normal decl. Returns zero if
+  /// it is not a registered normal decl.
   uint32_t getNormalDeclResultId(const NamedDecl *decl) const;
 
   /// \brief Returns all defined stage (builtin/input/ouput) variables in this

+ 26 - 13
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -34,28 +34,31 @@ public:
   /// \brief Constructs a ModuleBuilder with the given SPIR-V context.
   explicit ModuleBuilder(SPIRVContext *);
 
+  /// \brief Returns the associated SPIRVContext.
+  inline SPIRVContext *getSPIRVContext();
+
   /// \brief Takes the SPIR-V module under building. This will consume the
   /// module under construction.
   std::vector<uint32_t> takeModule();
 
   // === Function and Basic Block ===
 
-  /// \brief Begins building a SPIR-V function. At any time, there can only
-  /// exist at most one function under building. Returns the <result-id> for the
+  /// \brief Begins building a SPIR-V function. Returns the <result-id> for the
   /// function on success. Returns zero on failure.
+  ///
+  /// If the resultId supplied is not zero, the created function will use it;
+  /// otherwise, an unused <result-id> will be assgined.
+  /// At any time, there can only exist at most one function under building.
   uint32_t beginFunction(uint32_t funcType, uint32_t returnType,
-                         llvm::StringRef name = "");
+                         llvm::StringRef name = "", uint32_t resultId = 0);
 
-  /// \brief Registers a function parameter of the given type onto the current
-  /// function and returns its <result-id>.
-  uint32_t addFnParameter(uint32_t type, llvm::StringRef name = "");
+  /// \brief Registers a function parameter of the given pointer type in the
+  /// current function and returns its <result-id>.
+  uint32_t addFnParameter(uint32_t ptrType, llvm::StringRef name = "");
 
-  /// \brief Creates a local variable of the given value type in the current
+  /// \brief Creates a local variable of the given pointer type in the current
   /// function and returns its <result-id>.
-  ///
-  /// The corresponding pointer type of the given value type will be constructed
-  /// for the variable itself.
-  uint32_t addFnVariable(uint32_t valueType, llvm::StringRef name = "",
+  uint32_t addFnVariable(uint32_t ptrType, llvm::StringRef name = "",
                          llvm::Optional<uint32_t> init = llvm::None);
 
   /// \brief Ends building of the current function. Returns true of success,
@@ -104,6 +107,11 @@ public:
   /// address.
   void createStore(uint32_t address, uint32_t value);
 
+  /// \brief Creates a function call instruction and returns the <result-id> for
+  /// the return value.
+  uint32_t createFunctionCall(uint32_t returnType, uint32_t functionId,
+                              llvm::ArrayRef<uint32_t> params);
+
   /// \brief Creates an access chain instruction to retrieve the element from
   /// the given base by walking through the given indexes. Returns the
   /// <result-id> for the pointer to the element.
@@ -175,7 +183,7 @@ public:
   uint32_t getPointerType(uint32_t pointeeType, spv::StorageClass);
   uint32_t getStructType(llvm::ArrayRef<uint32_t> fieldTypes);
   uint32_t getFunctionType(uint32_t returnType,
-                           const std::vector<uint32_t> &paramTypes);
+                           llvm::ArrayRef<uint32_t> paramTypes);
 
   // === Constant ===
   uint32_t getConstantBool(bool value);
@@ -205,10 +213,15 @@ private:
   OrderedBasicBlockMap basicBlocks;      ///< The basic blocks under building.
   BasicBlock *insertPoint;               ///< The current insertion point.
 
-  std::vector<uint32_t> constructSite; ///< InstBuilder construction site.
+  /// An InstBuilder associated with the current ModuleBuilder.
+  /// It can be used to contruct instructions on the fly.
+  /// The constructed instruction will appear in constructSite.
   InstBuilder instBuilder;
+  std::vector<uint32_t> constructSite; ///< InstBuilder construction site.
 };
 
+SPIRVContext *ModuleBuilder::getSPIRVContext() { return &theContext; }
+
 bool ModuleBuilder::isCurrentBasicBlockTerminated() const {
   return insertPoint && insertPoint->isTerminated();
 }

+ 14 - 2
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -38,13 +38,25 @@ bool DeclResultIdMapper::isStageVariable(uint32_t varId) const {
 }
 
 uint32_t DeclResultIdMapper::getDeclResultId(const NamedDecl *decl) const {
+  if (const uint32_t id = getNormalDeclResultId(decl))
+    return id;
   if (const uint32_t id = getRemappedDeclResultId(decl))
     return id;
+
+  assert(false && "found unregistered decl");
+  return 0;
+}
+
+uint32_t DeclResultIdMapper::getOrRegisterDeclResultId(const NamedDecl *decl) {
   if (const uint32_t id = getNormalDeclResultId(decl))
     return id;
+  if (const uint32_t id = getRemappedDeclResultId(decl))
+    return id;
 
-  assert(false && "found unregistered Decl in DeclResultIdMapper");
-  return 0;
+  const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
+  registerDeclResultId(decl, id);
+
+  return id;
 }
 
 uint32_t

+ 110 - 30
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -18,6 +18,7 @@
 #include "clang/SPIRV/ModuleBuilder.h"
 #include "clang/SPIRV/TypeTranslator.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
 
 namespace clang {
 namespace spirv {
@@ -112,23 +113,21 @@ public:
 
     TranslationUnitDecl *tu = context.getTranslationUnitDecl();
 
-    // A queue of functions we need to translate.
-    std::deque<FunctionDecl *> workQueue;
-
     // The entry function is the seed of the queue.
     for (auto *decl : tu->decls()) {
       if (auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
         if (funcDecl->getName() == entryFunctionName) {
-          workQueue.push_back(funcDecl);
+          workQueue.insert(funcDecl);
         }
       }
     }
     // TODO: enlarge the queue upon seeing a function call.
 
     // Translate all functions reachable from the entry function.
-    while (!workQueue.empty()) {
-      doFunctionDecl(workQueue.front());
-      workQueue.pop_front();
+    // The queue can grow in the meanwhile; so need to keep evaluating
+    // workQueue.size().
+    for (uint32_t i = 0; i < workQueue.size(); ++i) {
+      doDecl(workQueue[i]);
     }
 
     theBuilder.addEntryPoint(shaderStage, entryFunctionId, entryFunctionName,
@@ -145,9 +144,11 @@ public:
         reinterpret_cast<const char *>(m.data()), m.size() * 4);
   }
 
-  void doDecl(Decl *decl) {
-    if (auto *varDecl = dyn_cast<VarDecl>(decl)) {
+  void doDecl(const Decl *decl) {
+    if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
       doVarDecl(varDecl);
+    } else if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
+      doFunctionDecl(funcDecl);
     } else {
       // TODO: Implement handling of other Decl types.
       emitWarning("Decl type '%0' is not supported yet.")
@@ -155,50 +156,81 @@ public:
     }
   }
 
-  void doFunctionDecl(FunctionDecl *decl) {
+  void doFunctionDecl(const FunctionDecl *decl) {
     curFunction = decl;
 
     const llvm::StringRef funcName = decl->getName();
 
+    uint32_t funcId;
+
     if (funcName == entryFunctionName) {
       // First create stage variables for the entry point.
       declIdMapper.createStageVarFromFnReturn(decl);
-      for (auto *param : decl->params())
+      for (const auto *param : decl->params())
         declIdMapper.createStageVarFromFnParam(param);
 
       // Construct the function signature.
       const uint32_t voidType = theBuilder.getVoidType();
       const uint32_t funcType = theBuilder.getFunctionType(voidType, {});
-      const uint32_t funcId =
-          theBuilder.beginFunction(funcType, voidType, funcName);
 
-      if (decl->hasBody()) {
-        // The entry basic block.
-        const uint32_t entryLabel = theBuilder.createBasicBlock("bb.entry");
-        theBuilder.setInsertPoint(entryLabel);
+      // The entry function surely does not have pre-assigned <result-id> for
+      // it like other functions that got added to the work queue following
+      // function calls.
+      funcId = theBuilder.beginFunction(funcType, voidType, funcName);
 
-        // Process all statments in the body.
-        doStmt(decl->getBody());
+      // Record the entry function's <result-id>.
+      entryFunctionId = funcId;
+    } else {
+      const uint32_t retType =
+          typeTranslator.translateType(decl->getReturnType());
 
-        // We have processed all Stmts in this function and now in the last
-        // basic block. Make sure we have OpReturn if missing.
-        if (!theBuilder.isCurrentBasicBlockTerminated()) {
-          theBuilder.createReturn();
-        }
+      // Construct the function signature.
+      llvm::SmallVector<uint32_t, 4> paramTypes;
+      for (const auto *param : decl->params()) {
+        const uint32_t valueType =
+            typeTranslator.translateType(param->getType());
+        const uint32_t ptrType =
+            theBuilder.getPointerType(valueType, spv::StorageClass::Function);
+        paramTypes.push_back(ptrType);
       }
+      const uint32_t funcType = theBuilder.getFunctionType(retType, paramTypes);
+
+      // Non-entry functions are added to the work queue following function
+      // calls. We have already assigned <result-id>s for it when translating
+      // its call site. Query it here.
+      funcId = declIdMapper.getDeclResultId(decl);
+      theBuilder.beginFunction(funcType, retType, funcName, funcId);
+
+      // Create all parameters.
+      for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
+        const ParmVarDecl *paramDecl = decl->getParamDecl(i);
+        const uint32_t paramId =
+            theBuilder.addFnParameter(paramTypes[i], paramDecl->getName());
+        declIdMapper.registerDeclResultId(paramDecl, paramId);
+      }
+    }
 
-      theBuilder.endFunction();
+    if (decl->hasBody()) {
+      // The entry basic block.
+      const uint32_t entryLabel = theBuilder.createBasicBlock("bb.entry");
+      theBuilder.setInsertPoint(entryLabel);
 
-      // Record the entry function's <result-id>.
-      entryFunctionId = funcId;
-    } else {
-      emitError("Non-entry functions are not supported yet.");
+      // Process all statments in the body.
+      doStmt(decl->getBody());
+
+      // We have processed all Stmts in this function and now in the last
+      // basic block. Make sure we have OpReturn if missing.
+      if (!theBuilder.isCurrentBasicBlockTerminated()) {
+        theBuilder.createReturn();
+      }
     }
 
+    theBuilder.endFunction();
+
     curFunction = nullptr;
   }
 
-  void doVarDecl(VarDecl *decl) {
+  void doVarDecl(const VarDecl *decl) {
     if (decl->isLocalVarDecl()) {
       const uint32_t ptrType = theBuilder.getPointerType(
           typeTranslator.translateType(decl->getType()),
@@ -526,6 +558,8 @@ public:
       return doBinaryOperator(binOp);
     } else if (auto *unaryOp = dyn_cast<UnaryOperator>(expr)) {
       return doUnaryOperator(unaryOp);
+    } else if (auto *funcCall = dyn_cast<CallExpr>(expr)) {
+      return doCallExpr(funcCall);
     }
 
     emitError("Expr '%0' is not supported yet.") << expr->getStmtClassName();
@@ -639,13 +673,54 @@ public:
       const uint32_t resultType = typeTranslator.translateType(toType);
       return theBuilder.createLoad(resultType, fromValue);
     }
+    case CastKind::CK_FunctionToPointerDecay:
+      // Just need to return the function id
+      return doExpr(subExpr);
     default:
       emitError("ImplictCast Kind '%0' is not supported yet.")
           << expr->getCastKind();
+      expr->dump();
       return 0;
     }
   }
 
+  uint32_t doCallExpr(const CallExpr *callExpr) {
+    const FunctionDecl *callee = callExpr->getDirectCallee();
+
+    if (callee) {
+      const uint32_t returnType =
+          typeTranslator.translateType(callExpr->getType());
+
+      // Get or forward declare the function <result-id>
+      const uint32_t funcId = declIdMapper.getOrRegisterDeclResultId(callee);
+
+      // Evaluate parameters
+      llvm::SmallVector<uint32_t, 4> params;
+      for (const auto *arg : callExpr->arguments()) {
+        // We need to create variables for holding the values to be used as
+        // arguments. The variables themselves are of pointer types.
+        const uint32_t ptrType = theBuilder.getPointerType(
+            typeTranslator.translateType(arg->getType()),
+            spv::StorageClass::Function);
+
+        const uint32_t tempVarId = theBuilder.addFnVariable(ptrType);
+        theBuilder.createStore(tempVarId, doExpr(arg));
+
+        params.push_back(tempVarId);
+      }
+
+      // Push the callee into the work queue if it is not there.
+      if (!workQueue.count(callee)) {
+        workQueue.insert(callee);
+      }
+
+      return theBuilder.createFunctionCall(returnType, funcId, params);
+    }
+
+    emitError("calling non-function unimplemented");
+    return 0;
+  }
+
   /// Translates the given frontend binary operator into its SPIR-V equivalent
   /// taking consideration of the operand type.
   spv::Op translateOp(BinaryOperator::Opcode op, QualType type) {
@@ -833,6 +908,11 @@ private:
   DeclResultIdMapper declIdMapper;
   TypeTranslator typeTranslator;
 
+  /// A queue of decls reachable from the entry function. Decls inserted into
+  /// this queue will persist to avoid duplicated translations. And we'd like
+  /// a deterministic order of iterating the queue for finding the next decl
+  /// to translate. So we need SetVector here.
+  llvm::SetVector<const DeclaratorDecl *> workQueue;
   /// <result-id> for the entry function. Initially it is zero and will be reset
   /// when starting to translate the entry function.
   uint32_t entryFunctionId;

+ 21 - 12
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -33,17 +33,19 @@ std::vector<uint32_t> ModuleBuilder::takeModule() {
   });
 
   theModule.take(&ib);
-  return std::move(binary);
+  return binary;
 }
 
 uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
-                                      llvm::StringRef funcName) {
+                                      llvm::StringRef funcName, uint32_t fId) {
   if (theFunction) {
     assert(false && "found nested function");
     return 0;
   }
 
-  const uint32_t fId = theContext.takeNextId();
+  // If the caller doesn't supply a function <result-id>, we need to get one.
+  if (!fId)
+    fId = theContext.takeNextId();
 
   theFunction = llvm::make_unique<Function>(
       returnType, fId, spv::FunctionControlMask::MaskNone, funcType);
@@ -52,24 +54,22 @@ uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
   return fId;
 }
 
-uint32_t ModuleBuilder::addFnParameter(uint32_t type, llvm::StringRef name) {
+uint32_t ModuleBuilder::addFnParameter(uint32_t ptrType, llvm::StringRef name) {
   assert(theFunction && "found detached parameter");
 
-  const uint32_t pointerType =
-      getPointerType(type, spv::StorageClass::Function);
   const uint32_t paramId = theContext.takeNextId();
-  theFunction->addParameter(pointerType, paramId);
+  theFunction->addParameter(ptrType, paramId);
   theModule.addDebugName(paramId, name);
 
   return paramId;
 }
 
-uint32_t ModuleBuilder::addFnVariable(uint32_t type, llvm::StringRef name,
+uint32_t ModuleBuilder::addFnVariable(uint32_t ptrType, llvm::StringRef name,
                                       llvm::Optional<uint32_t> init) {
   assert(theFunction && "found detached local variable");
 
   const uint32_t varId = theContext.takeNextId();
-  theFunction->addVariable(type, varId, init);
+  theFunction->addVariable(ptrType, varId, init);
   theModule.addDebugName(varId, name);
   return varId;
 }
@@ -152,6 +152,16 @@ void ModuleBuilder::createStore(uint32_t address, uint32_t value) {
   insertPoint->appendInstruction(std::move(constructSite));
 }
 
+uint32_t ModuleBuilder::createFunctionCall(uint32_t returnType,
+                                           uint32_t functionId,
+                                           llvm::ArrayRef<uint32_t> params) {
+  assert(insertPoint && "null insert point");
+  const uint32_t id = theContext.takeNextId();
+  instBuilder.opFunctionCall(returnType, id, functionId, params).x();
+  insertPoint->appendInstruction(std::move(constructSite));
+  return id;
+}
+
 uint32_t ModuleBuilder::createAccessChain(uint32_t resultType, uint32_t base,
                                           llvm::ArrayRef<uint32_t> indexes) {
   assert(insertPoint && "null insert point");
@@ -322,9 +332,8 @@ uint32_t ModuleBuilder::getStructType(llvm::ArrayRef<uint32_t> fieldTypes) {
   return typeId;
 }
 
-uint32_t
-ModuleBuilder::getFunctionType(uint32_t returnType,
-                               const std::vector<uint32_t> &paramTypes) {
+uint32_t ModuleBuilder::getFunctionType(uint32_t returnType,
+                                        llvm::ArrayRef<uint32_t> paramTypes) {
   const Type *type = Type::getFunction(theContext, returnType, paramTypes);
   const uint32_t typeId = theContext.getResultIdForType(type);
   theModule.addType(type, typeId);

+ 92 - 0
tools/clang/test/CodeGenSPIRV/fn.call.hlsl

@@ -0,0 +1,92 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// CHECK-NOT: OpName %fnNoCaller "fnNoCaller"
+
+// CHECK: [[voidf:%\d+]] = OpTypeFunction %void
+// CHECK: [[intfint:%\d+]] = OpTypeFunction %int %_ptr_Function_int
+// CHECK: [[intfintint:%\d+]] = OpTypeFunction %int %_ptr_Function_int %_ptr_Function_int
+
+// CHECK-NOT: %fnNoCaller = OpFunction
+void fnNoCaller() {}
+
+int fnOneParm(int v) { return v; }
+
+int fnTwoParm(int m, int n) { return m + n; }
+
+int fnCallOthers(int v) { return fnOneParm(v); }
+
+// Recursive functions are disallowed by the front end.
+
+// CHECK: %main = OpFunction %void None [[voidf]]
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+// CHECK-NEXT: %v = OpVariable %_ptr_Function_int Function
+    int v;
+// CHECK-NEXT: [[oneParam:%\d+]] = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: [[twoParam1:%\d+]] = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: [[twoParam2:%\d+]] = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: [[nestedParam1:%\d+]] = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: [[nestedParam2:%\d+]] = OpVariable %_ptr_Function_int Function
+
+// CHECK-NEXT: OpStore [[oneParam]] %int_1
+// CHECK-NEXT: [[call0:%\d+]] = OpFunctionCall %int %fnOneParm [[oneParam]]
+// CHECK-NEXT: OpStore %v [[call0]]
+    v = fnOneParm(1); // Pass in constant; use return value
+
+// CHECK-NEXT: [[v0:%\d+]] = OpLoad %int %v
+// CHECK-NEXT: OpStore [[twoParam1]] [[v0]]
+// CHECK-NEXT: [[v1:%\d+]] = OpLoad %int %v
+// CHECK-NEXT: OpStore [[twoParam2]] [[v1]]
+// CHECK-NEXT: [[call2:%\d+]] = OpFunctionCall %int %fnTwoParm [[twoParam1]] [[twoParam2]]
+    fnTwoParm(v, v);  // Pass in variable; ignore return value
+
+// CHECK-NEXT: OpStore [[nestedParam2]] %int_1
+// CHECK-NEXT: [[call3:%\d+]] = OpFunctionCall %int %fnOneParm [[nestedParam2]]
+// CHECK-NEXT: OpStore [[nestedParam1]] [[call3]]
+// CHECK-NEXT: [[call4:%\d+]] = OpFunctionCall %int %fnCallOthers [[nestedParam1]]
+// CHECK-NEXT: OpReturn
+// CHECK-NEXT: OpFunctionEnd
+    fnCallOthers(fnOneParm(1)); // Nested function calls
+}
+
+// CHECK-NOT: %fnNoCaller = OpFunction
+
+/* For int fnOneParm(int v) { return v; } */
+
+// %fnOneParm = OpFunction %int None [[intfint]]
+// %v_0 = OpFunctionParameter %_ptr_Function_int
+// %bb_entry_0 = OpLabel
+// [[v2:%\d+]] = OpLoad %int %v_0
+// OpReturnValue [[v2]]
+// OpFunctionEnd
+
+
+// CHECK-NOT: %fnNoCaller = OpFunction
+
+/* For int fnTwoParm(int m, int n) { return m + n; } */
+
+// %fnTwoParm = OpFunction %int None %27
+// %m = OpFunctionParameter %_ptr_Function_int
+// %n = OpFunctionParameter %_ptr_Function_int
+// %bb_entry_1 = OpLabel
+// [[m0:%\d+]] = OpLoad %int %m
+// [[n0:%\d+]] = OpLoad %int %n
+// [[add0:%\d+]] = OpIAdd %int [[m0]] [[n0]]
+// OpReturnValue [[add0]]
+// OpFunctionEnd
+
+// CHECK-NOT: %fnNoCaller = OpFunction
+
+/* For int fnCallOthers(int v) { return fnOneParm(v); } */
+
+// %fnCallOthers = OpFunction %int None [[intfintint]]
+// %v_1 = OpFunctionParameter %_ptr_Function_int
+// %bb_entry_2 = OpLabel
+// [[param0:%\d+]] = OpVariable %_ptr_Function_int Function
+// [[v3:%\d+]] = OpLoad %int %v_1
+// OpStore [[param0]] [[v3]]
+// [[call5:%\d+]] = OpFunctionCall %int %fnOneParm [[param0]]
+// OpReturnValue [[call5]]
+// OpFunctionEnd
+
+// CHECK-NOT: %fnNoCaller = OpFunction

+ 2 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -62,4 +62,6 @@ TEST_F(FileTest, ForStmtNestedForStmt) { runFileTest("for-stmt.nested.hlsl"); }
 
 TEST_F(FileTest, ControlFlowNestedIfForStmt) { runFileTest("cf.if.for.hlsl"); }
 
+TEST_F(FileTest, FunctionCall) { runFileTest("fn.call.hlsl"); }
+
 } // namespace