Pārlūkot izejas kodu

[spirv] Cleanup: Remove SPIRV hybrid function type. (#2620)

this is a cleanup that will help us with debug type representation
later.
Ehsan 5 gadi atpakaļ
vecāks
revīzija
b0731e7f2d

+ 2 - 1
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -54,7 +54,8 @@ public:
   /// on failure.
   ///
   /// At any time, there can only exist at most one function under building.
-  SpirvFunction *beginFunction(QualType returnType, SpirvType *functionType,
+  SpirvFunction *beginFunction(QualType returnType,
+                               llvm::ArrayRef<QualType> paramTypes,
                                SourceLocation, llvm::StringRef name = "",
                                bool isPrecise = false,
                                SpirvFunction *func = nullptr);

+ 0 - 3
tools/clang/include/clang/SPIRV/SpirvContext.h

@@ -202,9 +202,6 @@ public:
 
   const HybridPointerType *getPointerType(QualType pointee, spv::StorageClass);
 
-  HybridFunctionType *getFunctionType(QualType ret,
-                                      llvm::ArrayRef<QualType> param);
-
   /// Functions to get/set current entry point ShaderModelKind.
   ShaderModelKind getCurrentShaderModelKind() { return curShaderModelKind; }
   void setCurrentShaderModelKind(ShaderModelKind smk) {

+ 16 - 5
tools/clang/include/clang/SPIRV/SpirvFunction.h

@@ -24,8 +24,9 @@ class SpirvVisitor;
 /// The class representing a SPIR-V function in memory.
 class SpirvFunction {
 public:
-  SpirvFunction(QualType astReturnType, SpirvType *fnSpirvType, SourceLocation,
-                llvm::StringRef name = "", bool precise = false);
+  SpirvFunction(QualType astReturnType, llvm::ArrayRef<QualType> astParamTypes,
+                SourceLocation, llvm::StringRef name = "",
+                bool precise = false);
   ~SpirvFunction() = default;
 
   // Forbid copy construction and assignment
@@ -47,9 +48,18 @@ public:
   // Returns the lowered (SPIR-V) return type.
   SpirvType *getReturnType() const { return returnType; }
 
+  // Sets the function AST return type
   void setAstReturnType(QualType type) { astReturnType = type; }
+  // Gets the function AST return type
   QualType getAstReturnType() const { return astReturnType; }
 
+  // Sets the vector of parameter QualTypes.
+  void setAstParamTypes(llvm::ArrayRef<QualType> paramTypes) {
+    astParamTypes.append(paramTypes.begin(), paramTypes.end());
+  }
+  // Gets the vector of parameter QualTypes.
+  llvm::ArrayRef<QualType> getAstParamTypes() const { return astParamTypes; }
+
   // Sets the SPIR-V type of the function
   void setFunctionType(SpirvType *type) { fnType = type; }
   // Returns the SPIR-V type of the function
@@ -88,9 +98,10 @@ public:
 private:
   uint32_t functionId; ///< This function's <result-id>
 
-  QualType astReturnType; ///< The return type
-  SpirvType *returnType;  ///< The lowered return type
-  SpirvType *fnType;      ///< The SPIR-V function type
+  QualType astReturnType;                       ///< The return type
+  llvm::SmallVector<QualType, 4> astParamTypes; ///< The paratemer types in AST
+  SpirvType *returnType;                        ///< The lowered return type
+  SpirvType *fnType;                            ///< The SPIR-V function type
 
   bool relaxedPrecision; ///< Whether the return type is at relaxed precision
   bool precise;          ///< Whether the return value is 'precise'

+ 0 - 23
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -503,29 +503,6 @@ private:
   QualType imageType;
 };
 
-// This class can be extended to also accept QualType vector as param types.
-class HybridFunctionType : public HybridType {
-public:
-  HybridFunctionType(QualType ret, llvm::ArrayRef<QualType> param)
-      : HybridType(TK_HybridFunction), returnType(ret),
-        paramTypes(param.begin(), param.end()) {}
-
-  static bool classof(const SpirvType *t) {
-    return t->getKind() == TK_HybridFunction;
-  }
-
-  bool operator==(const HybridFunctionType &that) const {
-    return returnType == that.returnType && paramTypes == that.paramTypes;
-  }
-
-  QualType getReturnType() const { return returnType; }
-  llvm::ArrayRef<QualType> getParamTypes() const { return paramTypes; }
-
-private:
-  QualType returnType;
-  llvm::SmallVector<QualType, 8> paramTypes;
-};
-
 //
 // Function Definition for templated functions
 //

+ 6 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -1015,8 +1015,13 @@ SpirvFunction *DeclResultIdMapper::getOrRegisterFn(const FunctionDecl *fn) {
   (void)getTypeAndCreateCounterForPotentialAliasVar(fn, &isAlias);
 
   const bool isPrecise = fn->hasAttr<HLSLPreciseAttr>();
+  // Note: we do not need to worry about function parameter types at this point
+  // as this is used when function declarations are seen. When function
+  // definition is seen, the parameter types will be set properly and take into
+  // account whether the function is a member function of a class/struct (in
+  // which case a 'this' parameter is added at the beginnig).
   SpirvFunction *spirvFunction = new (spvContext)
-      SpirvFunction(fn->getReturnType(), /*functionType*/ nullptr,
+      SpirvFunction(fn->getReturnType(), /* param QualTypes */ {},
                     fn->getLocation(), fn->getName(), isPrecise);
 
   // No need to dereference to get the pointer. Function returns that are

+ 14 - 21
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -44,10 +44,18 @@ bool LowerTypeVisitor::visit(SpirvFunction *fn, Phase phase) {
                   /*SourceLocation*/ {});
     fn->setReturnType(const_cast<SpirvType *>(spirvReturnType));
 
-    // Lower the SPIR-V function type if necessary.
-    fn->setFunctionType(const_cast<SpirvType *>(
-        lowerType(fn->getFunctionType(), SpirvLayoutRule::Void,
-                  fn->getSourceLocation())));
+    // Lower the function parameter types.
+    auto paramQualTypes = fn->getAstParamTypes();
+    llvm::SmallVector<const SpirvType *, 4> spirvParamTypes;
+    for (auto qualtype : paramQualTypes) {
+      const auto *spirvParamType =
+          lowerType(qualtype, SpirvLayoutRule::Void,
+                    /*isRowMajor*/ llvm::None, fn->getSourceLocation());
+      spirvParamTypes.push_back(spvContext.getPointerType(
+          spirvParamType, spv::StorageClass::Function));
+    }
+    fn->setFunctionType(
+        spvContext.getFunctionType(spirvReturnType, spirvParamTypes));
   }
   return true;
 }
@@ -154,22 +162,6 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
         lowerType(imageAstType, rule, /*isRowMajor*/ llvm::None, loc);
     assert(isa<ImageType>(imageSpirvType));
     return spvContext.getSampledImageType(cast<ImageType>(imageSpirvType));
-  } else if (const auto *hybridFn = dyn_cast<HybridFunctionType>(type)) {
-    // Lower the return type.
-    const QualType astReturnType = hybridFn->getReturnType();
-    const SpirvType *spirvReturnType =
-        lowerType(astReturnType, rule, /*isRowMajor*/ llvm::None, loc);
-
-    // Go over all params and lower them.
-    std::vector<const SpirvType *> paramTypes;
-    for (auto paramType : hybridFn->getParamTypes()) {
-      const auto *spirvParamType =
-          lowerType(paramType, rule, /*isRowMajor*/ llvm::None, loc);
-      paramTypes.push_back(spvContext.getPointerType(
-          spirvParamType, spv::StorageClass::Function));
-    }
-
-    return spvContext.getFunctionType(spirvReturnType, paramTypes);
   } else if (const auto *hybridStruct = dyn_cast<HybridStructType>(type)) {
     // lower all fields of the struct.
     auto loweredFields =
@@ -553,7 +545,8 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
 
     // We have a runtime array of structures. So:
     // The stride of the runtime array is the size of the struct.
-    const auto *raType = spvContext.getRuntimeArrayType(structType, arrayStride);
+    const auto *raType =
+        spvContext.getRuntimeArrayType(structType, arrayStride);
     const bool isReadOnly = (name == "StructuredBuffer");
 
     // Attach matrix stride decorations if this is a *StructuredBuffer<matrix>.

+ 8 - 6
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -26,21 +26,23 @@ SpirvBuilder::SpirvBuilder(ASTContext &ac, SpirvContext &ctx,
   module = new (context) SpirvModule;
 }
 
-SpirvFunction *
-SpirvBuilder::beginFunction(QualType returnType, SpirvType *functionType,
-                            SourceLocation loc, llvm::StringRef funcName,
-                            bool isPrecise, SpirvFunction *func) {
+SpirvFunction *SpirvBuilder::beginFunction(QualType returnType,
+                                           llvm::ArrayRef<QualType> paramTypes,
+                                           SourceLocation loc,
+                                           llvm::StringRef funcName,
+                                           bool isPrecise,
+                                           SpirvFunction *func) {
   assert(!function && "found nested function");
   if (func) {
     function = func;
     function->setAstReturnType(returnType);
-    function->setFunctionType(functionType);
+    function->setAstParamTypes(paramTypes);
     function->setSourceLocation(loc);
     function->setFunctionName(funcName);
     function->setPrecise(isPrecise);
   } else {
     function = new (context)
-        SpirvFunction(returnType, functionType, loc, funcName, isPrecise);
+        SpirvFunction(returnType, paramTypes, loc, funcName, isPrecise);
   }
 
   return function;

+ 0 - 5
tools/clang/lib/SPIRV/SpirvContext.cpp

@@ -242,11 +242,6 @@ SpirvContext::getFunctionType(const SpirvType *ret,
   return *inserted.first;
 }
 
-HybridFunctionType *
-SpirvContext::getFunctionType(QualType ret, llvm::ArrayRef<QualType> param) {
-  return new (this) HybridFunctionType(ret, param);
-}
-
 const StructType *SpirvContext::getByteAddressBufferType(bool isWritable) {
   // Create a uint RuntimeArray.
   const auto *raType =

+ 4 - 7
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -1043,8 +1043,7 @@ void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) {
     paramTypes.push_back(valueType);
   }
 
-  auto *funcType = spvContext.getFunctionType(retType, paramTypes);
-  spvBuilder.beginFunction(retType, funcType, decl->getLocStart(), funcName,
+  spvBuilder.beginFunction(retType, paramTypes, decl->getLocStart(), funcName,
                            decl->hasAttr<HLSLPreciseAttr>(), func);
 
   if (isNonStaticMemberFn) {
@@ -10606,14 +10605,12 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
   uint32_t inputArraySize = 0;
   uint32_t outputArraySize = 0;
 
-  // Construct the wrapper function signature.
-  auto *funcType = spvContext.getFunctionType(astContext.VoidTy, {});
-
   // The wrapper entry function surely does not have pre-assigned <result-id>
   // for it like other functions that got added to the work queue following
   // function calls. And the wrapper is the entry function.
-  entryFunction = spvBuilder.beginFunction(
-      astContext.VoidTy, funcType, decl->getLocStart(), decl->getName());
+  entryFunction =
+      spvBuilder.beginFunction(astContext.VoidTy, /* param QualTypes */ {},
+                               decl->getLocStart(), decl->getName());
   // Note this should happen before using declIdMapper for other tasks.
   declIdMapper.setEntryFunction(entryFunction);
 

+ 5 - 3
tools/clang/lib/SPIRV/SpirvFunction.cpp

@@ -15,11 +15,13 @@
 namespace clang {
 namespace spirv {
 
-SpirvFunction::SpirvFunction(QualType returnType, SpirvType *functionType,
+SpirvFunction::SpirvFunction(QualType returnType,
+                             llvm::ArrayRef<QualType> paramTypes,
                              SourceLocation loc, llvm::StringRef name,
                              bool isPrecise)
-    : functionId(0), astReturnType(returnType), returnType(nullptr),
-      fnType(functionType), relaxedPrecision(false), precise(isPrecise),
+    : functionId(0), astReturnType(returnType),
+      astParamTypes(paramTypes.begin(), paramTypes.end()), returnType(nullptr),
+      fnType(nullptr), relaxedPrecision(false), precise(isPrecise),
       containsAlias(false), rvalue(false), functionLoc(loc),
       functionName(name) {}