Procházet zdrojové kódy

[spirv] More v2, including more hybrid types.

Ehsan před 6 roky
rodič
revize
2bae115d7f

+ 8 - 0
tools/clang/include/clang/SPIRV/SPIRVContext.h

@@ -202,9 +202,12 @@ public:
 
   const SpirvPointerType *getPointerType(const SpirvType *pointee,
                                          spv::StorageClass);
+  const HybridPointerType *getPointerType(QualType pointee, spv::StorageClass);
 
   const FunctionType *getFunctionType(const SpirvType *ret,
                                       llvm::ArrayRef<const SpirvType *> param);
+  const HybridFunctionType *
+  getFunctionType(QualType ret, llvm::ArrayRef<const SpirvType *> param);
 
   const StructType *getByteAddressBufferType(bool isWritable);
   const StructType *getACSBufferCounterType();
@@ -316,6 +319,9 @@ private:
   using SCToPtrTyMap =
       llvm::DenseMap<spv::StorageClass, const SpirvPointerType *,
                      StorageClassDenseMapInfo>;
+  using SCToHybridPtrTyMap =
+      llvm::DenseMap<spv::StorageClass, const HybridPointerType *,
+                     StorageClassDenseMapInfo>;
 
   // Vector/matrix types for each possible element count.
   // Type at index is for vector of index components. Index 0/1 is unused.
@@ -334,8 +340,10 @@ private:
   llvm::SmallVector<const HybridStructType *, 8> hybridStructTypes;
 
   llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
+  llvm::DenseMap<QualType, SCToHybridPtrTyMap> hybridPointerTypes;
 
   llvm::SmallVector<const FunctionType *, 8> functionTypes;
+  llvm::SmallVector<const HybridFunctionType *, 8> hybridFunctionTypes;
 
   // Unique constants
   // We currently do a linear search to find an existing constant (if any). This

+ 7 - 3
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -55,19 +55,23 @@ public:
   /// on failure.
   ///
   /// At any time, there can only exist at most one function under building.
-  SpirvFunction *beginFunction(QualType returnType, SourceLocation,
+  SpirvFunction *beginFunction(QualType returnType,
+                               const SpirvType *functionType, SourceLocation,
                                llvm::StringRef name = "");
 
   /// \brief Creates and registers a function parameter of the given pointer
   /// type in the current function and returns its pointer.
   SpirvFunctionParameter *addFnParam(QualType ptrType, SourceLocation,
                                      llvm::StringRef name = "");
+  SpirvFunctionParameter *addFnParam(const SpirvType *ptrType, SourceLocation,
+                                     llvm::StringRef name = "");
 
   /// \brief Creates a SpirvFunction object and adds it to the list of module
   /// functions. This does not change the current function under construction.
   /// The handle can be used to create function call instructions for functions
-  /// that we have not yet discovered in the source code.
-  SpirvFunction *createFunction(QualType returnType, SourceLocation,
+  /// that we have not yet been discovered in the source code.
+  SpirvFunction *createFunction(QualType returnType,
+                                const SpirvType *functionType, SourceLocation,
                                 llvm::StringRef name = "",
                                 bool isAlias = false);
 

+ 6 - 5
tools/clang/include/clang/SPIRV/SpirvFunction.h

@@ -24,8 +24,9 @@ class SpirvVisitor;
 /// The class representing a SPIR-V function in memory.
 class SpirvFunction {
 public:
-  SpirvFunction(QualType type, uint32_t id, spv::FunctionControlMask,
-                SourceLocation, llvm::StringRef name = "");
+  SpirvFunction(QualType returnType, const SpirvType *fnSpirvType, uint32_t id,
+                spv::FunctionControlMask, SourceLocation,
+                llvm::StringRef name = "");
   ~SpirvFunction() = default;
 
   // Forbid copy construction and assignment
@@ -59,7 +60,7 @@ public:
   // Sets the SPIR-V type of the function
   void setFunctionType(FunctionType *type) { fnType = type; }
   // Returns the SPIR-V type of the function
-  FunctionType *getFunctionType() const { return fnType; }
+  const FunctionType *getFunctionType() const { return fnType; }
 
   // Sets the result-id of the OpTypeFunction
   void setFunctionTypeId(uint32_t id) { fnTypeId = id; }
@@ -86,8 +87,8 @@ private:
   SpirvType *returnType;  ///< The lowered return type
   uint32_t returnTypeId;  ///< result-id for the return type
 
-  FunctionType *fnType; ///< The SPIR-V function type
-  uint32_t fnTypeId;    ///< result-id for the SPIR-V function type
+  const SpirvType *fnType; ///< The SPIR-V function type
+  uint32_t fnTypeId;       ///< result-id for the SPIR-V function type
 
   bool containsAlias; ///< Whether function return type is aliased
   bool rvalue;        ///< Whether the return value is an rvalue

+ 60 - 2
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -37,9 +37,11 @@ public:
     TK_Array,
     TK_RuntimeArray,
     TK_Struct,
-    TK_HybridStruct, // TODO: Remove once HybridStrcut type is removed.
     TK_Pointer,
     TK_Function,
+    TK_HybridStruct,
+    TK_HybridPointer,
+    TK_HybridFunction,
   };
 
   virtual ~SpirvType() = default;
@@ -310,6 +312,10 @@ public:
   const SpirvType *getPointeeType() const { return pointeeType; }
   spv::StorageClass getStorageClass() const { return storageClass; }
 
+  bool operator==(const SpirvPointerType &that) const {
+    return pointeeType == that.pointeeType && storageClass == that.storageClass;
+  }
+
 private:
   const SpirvType *pointeeType;
   spv::StorageClass storageClass;
@@ -337,12 +343,22 @@ private:
   llvm::SmallVector<const SpirvType *, 8> paramTypes;
 };
 
+class HybridType : public SpirvType {
+public:
+  static bool classof(const SpirvType *t) {
+    return t->getKind() >= TK_HybridStruct && t->getKind() <= TK_HybridFunction;
+  }
+
+protected:
+  HybridType(Kind k) : SpirvType(k) {}
+};
+
 /// **NOTE**: This type is created in order to facilitate transition of old
 /// infrastructure to the new infrastructure. Using this type should be avoided
 /// as much as possible.
 ///
 /// This type uses a mix of SpirvType and QualType for the structure fields.
-class HybridStructType : public SpirvType {
+class HybridStructType : public HybridType {
 public:
   enum class InterfaceType : uint32_t {
     InternalStorage = 0,
@@ -401,6 +417,48 @@ private:
   InterfaceType interfaceType;
 };
 
+class HybridPointerType : public HybridType {
+public:
+  HybridPointerType(QualType pointee, spv::StorageClass sc)
+      : HybridType(TK_HybridPointer), pointeeType(pointee), storageClass(sc) {}
+
+  static bool classof(const SpirvType *t) { return t->getKind() == TK_Pointer; }
+
+  QualType getPointeeType() const { return pointeeType; }
+  spv::StorageClass getStorageClass() const { return storageClass; }
+
+  bool operator==(const HybridPointerType &that) const {
+    return pointeeType == that.pointeeType && storageClass == that.storageClass;
+  }
+
+private:
+  QualType pointeeType;
+  spv::StorageClass storageClass;
+};
+
+// This class can be extended to also accept QualType vector as param types.
+class HybridFunctionType : public HybridType {
+public:
+  HybridFunctionType(QualType ret, llvm::ArrayRef<const SpirvType *> param)
+      : HybridType(TK_HybridFunction), returnType(ret),
+        paramTypes(param.begin(), param.end()) {}
+
+  static bool classof(const SpirvType *t) {
+    return t->getKind() == TK_Function;
+  }
+
+  bool operator==(const HybridFunctionType &that) const {
+    return returnType == that.returnType && paramTypes == that.paramTypes;
+  }
+
+  QualType getReturnType() const { return returnType; }
+  llvm::ArrayRef<const SpirvType *> getParamTypes() const { return paramTypes; }
+
+private:
+  QualType returnType;
+  llvm::SmallVector<const SpirvType *, 8> paramTypes;
+};
+
 } // end namespace spirv
 } // end namespace clang
 

+ 36 - 0
tools/clang/lib/SPIRV/SPIRVContext.cpp

@@ -290,6 +290,22 @@ const SpirvPointerType *SpirvContext::getPointerType(const SpirvType *pointee,
   return pointerTypes[pointee][sc] = new (this) SpirvPointerType(pointee, sc);
 }
 
+const HybridPointerType *SpirvContext::getPointerType(QualType pointee,
+                                                      spv::StorageClass sc) {
+  auto foundPointee = hybridPointerTypes.find(pointee);
+
+  if (foundPointee != hybridPointerTypes.end()) {
+    auto &pointeeMap = foundPointee->second;
+    auto foundSC = pointeeMap.find(sc);
+
+    if (foundSC != pointeeMap.end())
+      return foundSC->second;
+  }
+
+  return hybridPointerTypes[pointee][sc] =
+             new (this) HybridPointerType(pointee, sc);
+}
+
 const FunctionType *
 SpirvContext::getFunctionType(const SpirvType *ret,
                               llvm::ArrayRef<const SpirvType *> param) {
@@ -308,6 +324,26 @@ SpirvContext::getFunctionType(const SpirvType *ret,
   return functionTypes.back();
 }
 
+const HybridFunctionType *
+SpirvContext::getFunctionType(QualType ret,
+                              llvm::ArrayRef<const SpirvType *> param) {
+  // Create a temporary object for finding in the vector.
+  HybridFunctionType type(ret, param);
+
+  auto found =
+      std::find_if(hybridFunctionTypes.begin(), hybridFunctionTypes.end(),
+                   [&type](const HybridFunctionType *cachedType) {
+                     return type == *cachedType;
+                   });
+
+  if (found != hybridFunctionTypes.end())
+    return *found;
+
+  hybridFunctionTypes.push_back(new (this) HybridFunctionType(ret, param));
+
+  return hybridFunctionTypes.back();
+}
+
 const StructType *SpirvContext::getByteAddressBufferType(bool isWritable) {
   // Create a uint RuntimeArray.
   const auto *raType = getRuntimeArrayType(getUIntType(32));

+ 95 - 90
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1087,37 +1087,39 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
 
   // We are about to start translation for a new function. Clear the break stack
   // and the continue stack.
-  breakStack = std::stack<uint32_t>();
-  continueStack = std::stack<uint32_t>();
+  breakStack = std::stack<SpirvBasicBlock *>();
+  continueStack = std::stack<SpirvBasicBlock *>();
 
   // This will allow the entry-point name to be something like
   // myNamespace::myEntrypointFunc.
   std::string funcName = getFnName(decl);
 
-  uint32_t funcId = 0;
+  SpirvFunction *func = nullptr;
 
   if (funcName == entryFunctionName) {
     // The entry function surely does not have pre-assigned <result-id> for
     // it like other functions that got added to the work queue following
     // function calls.
-    funcId = theContext.takeNextId();
+    func = declIdMapper.getOrRegisterFn(decl);
     funcName = "src." + funcName;
 
     // Create wrapper for the entry function
-    if (!emitEntryFunctionWrapper(decl, funcId))
+    if (!emitEntryFunctionWrapper(decl, func))
       return;
   } else {
     // Non-entry functions are added to the work queue following function
     // calls. We have already assigned <result-id>s for it when translating
     // its call site. Query it here.
-    funcId = declIdMapper.getDeclEvalInfo(decl);
+    // TODO(ehsan): just call getOrRegisterFn in both cases.
+    func = declIdMapper.getOrRegisterFn(decl);
+    // funcId = declIdMapper.getDeclEvalInfo(decl);
   }
 
-  const uint32_t retType =
+  const QualType retType =
       declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(decl);
 
   // Construct the function signature.
-  llvm::SmallVector<uint32_t, 4> paramTypes;
+  llvm::SmallVector<const SpirvType *, 4> paramTypes;
 
   bool isNonStaticMemberFn = false;
   if (const auto *memberFn = dyn_cast<CXXMethodDecl>(decl)) {
@@ -1126,29 +1128,29 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
     if (isNonStaticMemberFn) {
       // For non-static member function, the first parameter should be the
       // object on which we are invoking this method.
-      const uint32_t valueType = typeTranslator.translateType(
-          memberFn->getThisType(astContext)->getPointeeType());
-      const uint32_t ptrType =
-          theBuilder.getPointerType(valueType, spv::StorageClass::Function);
+      const QualType valueType =
+          memberFn->getThisType(astContext)->getPointeeType();
+      const SpirvType *ptrType =
+          spvContext.getPointerType(valueType, spv::StorageClass::Function);
       paramTypes.push_back(ptrType);
     }
   }
 
   for (const auto *param : decl->params()) {
-    const uint32_t valueType =
+    const QualType valueType =
         declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param);
-    const uint32_t ptrType =
-        theBuilder.getPointerType(valueType, spv::StorageClass::Function);
+    const SpirvType *ptrType =
+        spvContext.getPointerType(valueType, spv::StorageClass::Function);
     paramTypes.push_back(ptrType);
   }
 
-  const uint32_t funcType = theBuilder.getFunctionType(retType, paramTypes);
-  theBuilder.beginFunction(funcType, retType, funcName, funcId);
+  const auto *funcType = spvContext.getFunctionType(retType, paramTypes);
+  spvBuilder.beginFunction(retType, funcType, decl->getLocation(), funcName);
 
   if (isNonStaticMemberFn) {
     // Remember the parameter for the this object so later we can handle
     // CXXThisExpr correctly.
-    curThis = theBuilder.addFnParam(paramTypes[0], "param.this");
+    curThis = spvBuilder.addFnParam(paramTypes[0], "param.this");
   }
 
   // Create all parameters.
@@ -1159,25 +1161,24 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
 
   if (decl->hasBody()) {
     // The entry basic block.
-    const uint32_t entryLabel = theBuilder.createBasicBlock("bb.entry");
-    theBuilder.setInsertPoint(entryLabel);
+    auto *entryLabel = spvBuilder.createBasicBlock("bb.entry");
+    spvBuilder.setInsertPoint(entryLabel);
 
     // Process all statments in the body.
     doStmt(decl->getBody());
 
     // We have processed all Stmts in this function and now in the last
     // basic block. Make sure we have a termination instruction.
-    if (!theBuilder.isCurrentBasicBlockTerminated()) {
+    if (!spvBuilder.isCurrentBasicBlockTerminated()) {
       const auto retType = decl->getReturnType();
 
       if (retType->isVoidType()) {
-        theBuilder.createReturn();
+        spvBuilder.createReturn();
       } else {
         // If the source code does not provide a proper return value for some
         // control flow path, it's undefined behavior. We just return null
         // value here.
-        theBuilder.createReturnValue(
-            theBuilder.getConstantNull(typeTranslator.translateType(retType)));
+        spvBuilder.createReturnValue(spvContext.getConstantNull(retType));
       }
     }
   }
@@ -6776,7 +6777,8 @@ SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
 #undef INTRINSIC_OP_CASE
 #undef INTRINSIC_OP_CASE_INT_FLOAT
 
-  return SpirvEvalInfo(retVal).setRValue();
+  retVal->setRValue();
+  return retVal;
 }
 
 SpirvInstruction *
@@ -9387,13 +9389,15 @@ bool SPIRVEmitter::processTessellationShaderAttributes(
 }
 
 bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
-                                            const uint32_t entryFuncId) {
+                                            SpirvFunction *entryFuncInstr) {
   // HS specific attributes
   uint32_t numOutputControlPoints = 0;
-  uint32_t outputControlPointIdVal = 0; // SV_OutputControlPointID value
-  uint32_t primitiveIdVar = 0;          // SV_PrimitiveID variable
-  uint32_t viewIdVar = 0;               // SV_ViewID variable
-  uint32_t hullMainInputPatchParam = 0; // Temporary parameter for InputPatch<>
+  SpirvInstruction *outputControlPointIdVal =
+      nullptr;                                // SV_OutputControlPointID value
+  SpirvInstruction *primitiveIdVar = nullptr; // SV_PrimitiveID variable
+  SpirvInstruction *viewIdVar = nullptr;      // SV_ViewID variable
+  SpirvInstruction *hullMainInputPatchParam =
+      nullptr; // Temporary parameter for InputPatch<>
 
   // The array size of per-vertex input/output variables
   // Used by HS/DS/GS for the additional arrayness, zero means not an array.
@@ -9401,14 +9405,14 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
   uint32_t outputArraySize = 0;
 
   // Construct the wrapper function signature.
-  const uint32_t voidType = theBuilder.getVoidType();
-  const uint32_t funcType = theBuilder.getFunctionType(voidType, {});
+  const SpirvType *voidType = spvContext.getVoidType();
+  const FunctionType *funcType = spvContext.getFunctionType(voidType, {});
 
   // The wrapper entry function surely does not have pre-assigned <result-id>
   // for it like other functions that got added to the work queue following
   // function calls. And the wrapper is the entry function.
   entryFunction = spvBuilder.beginFunction(
-      astContext.VoidTy, /*SourceLocation*/ {}, decl->getName());
+      astContext.VoidTy, funcType, /*SourceLocation*/ {}, decl->getName());
   // Note this should happen before using declIdMapper for other tasks.
   declIdMapper.setEntryFunction(entryFunction);
 
@@ -9489,8 +9493,8 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
   declIdMapper.glPerVertex.requireCapabilityIfNecessary();
 
   // The entry basic block.
-  const uint32_t entryLabel = theBuilder.createBasicBlock();
-  theBuilder.setInsertPoint(entryLabel);
+  auto *entryLabel = spvBuilder.createBasicBlock();
+  spvBuilder.setInsertPoint(entryLabel);
 
   // Initialize all global variables at the beginning of the wrapper
   for (const VarDecl *varDecl : toInitGloalVars) {
@@ -9504,18 +9508,19 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
     // If not explicitly initialized, initialize with their zero values if not
     // resource objects
     else if (!hlsl::IsHLSLResourceType(varDecl->getType())) {
-      const auto typeId = typeTranslator.translateType(varDecl->getType());
-      theBuilder.createStore(varInfo, theBuilder.getConstantNull(typeId));
+      const QualType type = varDecl->getType();
+      auto *nullValue = spvContext.getConstantNull(varDecl->getType());
+      spvBuilder.createStore(varInfo, nullValue);
     }
   }
 
   // Create temporary variables for holding function call arguments
-  llvm::SmallVector<uint32_t, 4> params;
+  llvm::SmallVector<SpirvInstruction *, 4> params;
   for (const auto *param : decl->params()) {
     const auto paramType = param->getType();
-    const uint32_t typeId = typeTranslator.translateType(paramType);
     std::string tempVarName = "param.var." + param->getNameAsString();
-    const uint32_t tempVar = theBuilder.addFnVar(typeId, tempVarName);
+    auto *tempVar =
+        spvBuilder.addFnVar(paramType, param->getLocation(), tempVarName);
 
     params.push_back(tempVar);
 
@@ -9530,14 +9535,14 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
         hullMainInputPatchParam = tempVar;
       }
 
-      uint32_t loadedValue = 0;
+      SpirvInstruction *loadedValue = nullptr;
 
       if (!declIdMapper.createStageInputVar(param, &loadedValue, false))
         return false;
 
       // Only initialize the temporary variable if the parameter is indeed used.
       if (param->isUsed()) {
-        theBuilder.createStore(tempVar, loadedValue);
+        spvBuilder.createStore(tempVar, loadedValue);
       }
 
       // Record the temporary variable holding SV_OutputControlPointID,
@@ -9553,9 +9558,8 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
   }
 
   // Call the original entry function
-  const uint32_t retType = typeTranslator.translateType(decl->getReturnType());
-  const uint32_t retVal =
-      theBuilder.createFunctionCall(retType, entryFuncId, params);
+  const QualType retType = decl->getReturnType();
+  auto *retVal = spvBuilder.createFunctionCall(retType, entryFuncInstr, params);
 
   // Create and write stage output variables for return value. Special case for
   // Hull shaders since they operate differently in 2 ways:
@@ -9584,8 +9588,7 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
     const auto *param = decl->getParamDecl(i);
     if (canActAsOutParmVar(param)) {
       // Load the value from the parameter after function call
-      const uint32_t typeId = typeTranslator.translateType(param->getType());
-      uint32_t loadedParam = 0;
+      SpirvInstruction *loadedParam = nullptr;
 
       // No need to write back the value if the parameter is not used at all in
       // the original entry function.
@@ -9594,15 +9597,15 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
       // .Append() intrinsic method. No need to load the parameter since we
       // won't need to write back here.
       if (param->isUsed() && !shaderModel.IsGS())
-        loadedParam = theBuilder.createLoad(typeId, params[i]);
+        loadedParam = spvBuilder.createLoad(param->getType(), params[i]);
 
       if (!declIdMapper.createStageOutputVar(param, loadedParam, false))
         return false;
     }
   }
 
-  theBuilder.createReturn();
-  theBuilder.endFunction();
+  spvBuilder.createReturn();
+  spvBuilder.endFunction();
 
   // For Hull shaders, there is no explicit call to the PCF in the HLSL source.
   // We should invoke a translation of the PCF manually.
@@ -9613,9 +9616,10 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
 }
 
 bool SPIRVEmitter::processHSEntryPointOutputAndPCF(
-    const FunctionDecl *hullMainFuncDecl, uint32_t retType, uint32_t retVal,
-    uint32_t numOutputControlPoints, uint32_t outputControlPointId,
-    uint32_t primitiveId, uint32_t viewId, uint32_t hullMainInputPatch) {
+    const FunctionDecl *hullMainFuncDecl, QualType retType,
+    SpirvInstruction *retVal, uint32_t numOutputControlPoints,
+    SpirvInstruction *outputControlPointId, SpirvInstruction *primitiveId,
+    SpirvInstruction *viewId, SpirvInstruction *hullMainInputPatch) {
   // This method may only be called for Hull shaders.
   assert(shaderModel.IsHS());
 
@@ -9641,19 +9645,21 @@ bool SPIRVEmitter::processHSEntryPointOutputAndPCF(
     return false;
   }
 
-  uint32_t hullMainOutputPatch = 0;
+  SpirvInstruction *hullMainOutputPatch = nullptr;
   // If the patch constant function (PCF) takes the result of the Hull main
   // entry point, create a temporary function-scope variable and write the
   // results to it, so it can be passed to the PCF.
   if (patchConstFuncTakesHullOutputPatch(patchConstFunc)) {
-    const uint32_t hullMainRetType = theBuilder.getArrayType(
-        retType, theBuilder.getConstantUint32(numOutputControlPoints));
-    hullMainOutputPatch =
-        theBuilder.addFnVar(hullMainRetType, "temp.var.hullMainRetVal");
-    const auto tempLocation = theBuilder.createAccessChain(
-        theBuilder.getPointerType(retType, spv::StorageClass::Function),
-        hullMainOutputPatch, {outputControlPointId});
-    theBuilder.createStore(tempLocation, retVal);
+    // ehsan was here.
+    const QualType hullMainRetType = astContext.getConstantArrayType(
+        retType, llvm::APInt(32, numOutputControlPoints),
+        clang::ArrayType::Normal, 0);
+    hullMainOutputPatch = spvBuilder.addFnVar(
+        hullMainRetType, /*SourceLocation*/ {}, "temp.var.hullMainRetVal");
+    // Note (ehsan): Using value type rather than pointer type in access chain.
+    auto *tempLocation = spvBuilder.createAccessChain(
+        retType, hullMainOutputPatch, {outputControlPointId});
+    spvBuilder.createStore(tempLocation, retVal);
   }
 
   // Now create a barrier before calling the Patch Constant Function (PCF).
@@ -9661,32 +9667,30 @@ bool SPIRVEmitter::processHSEntryPointOutputAndPCF(
   // Execution Barrier scope = Workgroup (2)
   // Memory Barrier scope = Invocation (4)
   // Memory Semantics Barrier scope = None (0)
-  const auto constZero = theBuilder.getConstantUint32(0);
-  const auto constFour = theBuilder.getConstantUint32(4);
-  const auto constTwo = theBuilder.getConstantUint32(2);
-  theBuilder.createBarrier(constTwo, constFour, constZero);
+  spvBuilder.createBarrier(spv::Scope::Invocation,
+                           spv::MemorySemanticsMask::MaskNone,
+                           spv::Scope::Workgroup);
 
   // The PCF should be called only once. Therefore, we check the invocationID,
   // and we only allow ID 0 to call the PCF.
-  const uint32_t condition = theBuilder.createBinaryOp(
-      spv::Op::OpIEqual, theBuilder.getBoolType(), outputControlPointId,
-      theBuilder.getConstantUint32(0));
-  const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
-  const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");
-  theBuilder.createConditionalBranch(condition, thenBB, mergeBB, mergeBB);
-  theBuilder.addSuccessor(thenBB);
-  theBuilder.addSuccessor(mergeBB);
-  theBuilder.setMergeTarget(mergeBB);
-
-  theBuilder.setInsertPoint(thenBB);
+  auto *condition = spvBuilder.createBinaryOp(
+      spv::Op::OpIEqual, astContext.BoolTy, outputControlPointId,
+      spvContext.getConstantUint32(0));
+  auto *thenBB = spvBuilder.createBasicBlock("if.true");
+  auto *mergeBB = spvBuilder.createBasicBlock("if.merge");
+  spvBuilder.createConditionalBranch(condition, thenBB, mergeBB, mergeBB);
+  spvBuilder.addSuccessor(thenBB);
+  spvBuilder.addSuccessor(mergeBB);
+  spvBuilder.setMergeTarget(mergeBB);
+
+  spvBuilder.setInsertPoint(thenBB);
 
   // Call the PCF. Since the function is not explicitly called, we must first
   // register an ID for it.
-  const uint32_t pcfId = declIdMapper.getOrRegisterFnResultId(patchConstFunc);
-  const uint32_t pcfRetType =
-      typeTranslator.translateType(patchConstFunc->getReturnType());
+  SpirvFunction *pcfId = declIdMapper.getOrRegisterFn(patchConstFunc);
+  const QualType pcfRetType = patchConstFunc->getReturnType();
 
-  std::vector<uint32_t> pcfParams;
+  std::vector<SpirvInstruction *> pcfParams;
 
   // A lambda for creating a stage input variable and its associated temporary
   // variable for function call. Also initializes the temporary variable using
@@ -9694,12 +9698,13 @@ bool SPIRVEmitter::processHSEntryPointOutputAndPCF(
   // of the temporary variable.
   const auto createParmVarAndInitFromStageInputVar =
       [this](const ParmVarDecl *param) {
-        const uint32_t typeId = typeTranslator.translateType(param->getType());
+        const QualType type = param->getType();
         std::string tempVarName = "param.var." + param->getNameAsString();
-        const uint32_t tempVar = theBuilder.addFnVar(typeId, tempVarName);
-        uint32_t loadedValue = 0;
+        auto *tempVar =
+            spvBuilder.addFnVar(type, param->getLocation(), tempVarName);
+        SpirvInstruction *loadedValue = nullptr;
         declIdMapper.createStageInputVar(param, &loadedValue, /*forPCF*/ true);
-        theBuilder.createStore(tempVar, loadedValue);
+        spvBuilder.createStore(tempVar, loadedValue);
         return tempVar;
       };
 
@@ -9728,15 +9733,15 @@ bool SPIRVEmitter::processHSEntryPointOutputAndPCF(
           << param->getName();
     }
   }
-  const uint32_t pcfResultId =
-      theBuilder.createFunctionCall(pcfRetType, pcfId, {pcfParams});
+  auto *pcfResultId =
+      spvBuilder.createFunctionCall(pcfRetType, pcfId, {pcfParams});
   if (!declIdMapper.createStageOutputVar(patchConstFunc, pcfResultId,
                                          /*forPCF*/ true))
     return false;
 
-  theBuilder.createBranch(mergeBB);
-  theBuilder.addSuccessor(mergeBB);
-  theBuilder.setInsertPoint(mergeBB);
+  spvBuilder.createBranch(mergeBB);
+  spvBuilder.addSuccessor(mergeBB);
+  spvBuilder.setInsertPoint(mergeBB);
   return true;
 }
 

+ 6 - 7
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -628,7 +628,7 @@ private:
   /// The wrapper function is also responsible for initializing global static
   /// variables for some cases.
   bool emitEntryFunctionWrapper(const FunctionDecl *entryFunction,
-                                uint32_t entryFuncId);
+                                SpirvFunction *entryFuncId);
 
   /// \brief Performs the following operations for the Hull shader:
   /// * Creates an output variable which is an Array containing results for all
@@ -652,12 +652,11 @@ private:
   ///
   /// The method panics if it is called for any shader kind other than Hull
   /// shaders.
-  bool processHSEntryPointOutputAndPCF(const FunctionDecl *hullMainFuncDecl,
-                                       uint32_t retType, uint32_t retVal,
-                                       uint32_t numOutputControlPoints,
-                                       uint32_t outputControlPointId,
-                                       uint32_t primitiveId, uint32_t viewId,
-                                       uint32_t hullMainInputPatch);
+  bool processHSEntryPointOutputAndPCF(
+      const FunctionDecl *hullMainFuncDecl, QualType retType,
+      SpirvInstruction *retVal, uint32_t numOutputControlPoints,
+      SpirvInstruction *outputControlPointId, SpirvInstruction *primitiveId,
+      SpirvInstruction *viewId, SpirvInstruction *hullMainInputPatch);
 
 private:
   /// \brief Returns true iff *all* the case values in the given switch

+ 20 - 4
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -21,20 +21,24 @@ SpirvBuilder::SpirvBuilder(ASTContext &ac, SpirvContext &ctx,
 }
 
 SpirvFunction *SpirvBuilder::beginFunction(QualType returnType,
+                                           const SpirvType *functionType,
                                            SourceLocation loc,
                                            llvm::StringRef funcName) {
   assert(!function && "found nested function");
-  function = new (context) SpirvFunction(
-      returnType, /*id*/ 0, spv::FunctionControlMask::MaskNone, loc, funcName);
+  function = new (context)
+      SpirvFunction(returnType, functionType, /*id*/ 0,
+                    spv::FunctionControlMask::MaskNone, loc, funcName);
   return function;
 }
 
 SpirvFunction *SpirvBuilder::createFunction(QualType returnType,
+                                            const SpirvType *functionType,
                                             SourceLocation loc,
                                             llvm::StringRef funcName,
                                             bool isAlias) {
-  function = new (context) SpirvFunction(
-      returnType, /*id*/ 0, spv::FunctionControlMask::MaskNone, loc, funcName);
+  SpirvFunction *fn = new (context)
+      SpirvFunction(returnType, functionType, /*id*/ 0,
+                    spv::FunctionControlMask::MaskNone, loc, funcName);
   function->setConstainsAliasComponent(isAlias);
   module->addFunction(function);
   return function;
@@ -50,6 +54,18 @@ SpirvFunctionParameter *SpirvBuilder::addFnParam(QualType ptrType,
   return param;
 }
 
+SpirvFunctionParameter *SpirvBuilder::addFnParam(const SpirvType *ptrType,
+                                                 SourceLocation loc,
+                                                 llvm::StringRef name) {
+  assert(function && "found detached parameter");
+  auto *param =
+      new (context) SpirvFunctionParameter(/*QualType*/ {}, /*id*/ 0, loc);
+  param->setResultType(ptrType);
+  param->setDebugName(name);
+  function->addParameter(param);
+  return param;
+}
+
 SpirvVariable *SpirvBuilder::addFnVar(QualType valueType, SourceLocation loc,
                                       llvm::StringRef name,
                                       SpirvInstruction *init) {

+ 5 - 5
tools/clang/lib/SPIRV/SpirvFunction.cpp

@@ -13,12 +13,12 @@
 namespace clang {
 namespace spirv {
 
-SpirvFunction::SpirvFunction(QualType type, uint32_t id,
-                             spv::FunctionControlMask control,
+SpirvFunction::SpirvFunction(QualType returnType, const SpirvType *functionType,
+                             uint32_t id, spv::FunctionControlMask control,
                              SourceLocation loc, llvm::StringRef name)
-    : functionId(id), astReturnType(type), returnType(nullptr), returnTypeId(0),
-      fnType(nullptr), fnTypeId(0), functionControl(control), functionLoc(loc),
-      functionName(name) {}
+    : functionId(id), astReturnType(returnType), returnType(nullptr),
+      returnTypeId(0), fnType(functionType), fnTypeId(0),
+      functionControl(control), functionLoc(loc), functionName(name) {}
 
 bool SpirvFunction::invokeVisitor(Visitor *visitor) {
   if (!visitor->visit(this, Visitor::Phase::Init))

+ 1 - 1
tools/clang/lib/SPIRV/SpirvType.cpp

@@ -71,7 +71,7 @@ bool StructType::operator==(const StructType &that) const {
 HybridStructType::HybridStructType(
     llvm::ArrayRef<HybridStructType::FieldInfo> fieldsVec, llvm::StringRef name,
     bool isReadOnly, HybridStructType::InterfaceType iface)
-    : SpirvType(TK_HybridStruct), fields(fieldsVec.begin(), fieldsVec.end()),
+    : HybridType(TK_HybridStruct), fields(fieldsVec.begin(), fieldsVec.end()),
       structName(name), readOnly(isReadOnly), interfaceType(iface) {}
 
 bool HybridStructType::FieldInfo::