فهرست منبع

[spirv] Clean up usage of FunctionType.

Ehsan Nasiri 6 سال پیش
والد
کامیت
429b1bfdee

+ 1 - 4
tools/clang/include/clang/SPIRV/SPIRVContext.h

@@ -129,9 +129,6 @@ struct FunctionTypeMapInfo {
 /// context is deleted. Therefore, this context should outlive the usages of the
 /// context is deleted. Therefore, this context should outlive the usages of the
 /// the SPIR-V entities allocated in memory.
 /// the SPIR-V entities allocated in memory.
 class SpirvContext {
 class SpirvContext {
-  friend class SpirvBuilder;
-  friend class EmitTypeHandler;
-
 public:
 public:
   SpirvContext();
   SpirvContext();
   ~SpirvContext() = default;
   ~SpirvContext() = default;
@@ -197,7 +194,7 @@ public:
   FunctionType *getFunctionType(const SpirvType *ret,
   FunctionType *getFunctionType(const SpirvType *ret,
                                 llvm::ArrayRef<const SpirvType *> param);
                                 llvm::ArrayRef<const SpirvType *> param);
   HybridFunctionType *getFunctionType(QualType ret,
   HybridFunctionType *getFunctionType(QualType ret,
-                                      llvm::ArrayRef<const SpirvType *> param);
+                                      llvm::ArrayRef<QualType> param);
 
 
   const StructType *getByteAddressBufferType(bool isWritable);
   const StructType *getByteAddressBufferType(bool isWritable);
   const StructType *getACSBufferCounterType();
   const StructType *getACSBufferCounterType();

+ 0 - 2
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -63,8 +63,6 @@ public:
   /// type in the current function and returns its pointer.
   /// type in the current function and returns its pointer.
   SpirvFunctionParameter *addFnParam(QualType ptrType, SourceLocation,
   SpirvFunctionParameter *addFnParam(QualType ptrType, SourceLocation,
                                      llvm::StringRef name = "");
                                      llvm::StringRef name = "");
-  SpirvFunctionParameter *addFnParam(const SpirvType *ptrType, SourceLocation,
-                                     llvm::StringRef name = "");
 
 
   /// \brief Creates a local variable of the given type in the current
   /// \brief Creates a local variable of the given type in the current
   /// function and returns it.
   /// function and returns it.

+ 11 - 10
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -337,6 +337,7 @@ private:
   StructInterfaceType interfaceType;
   StructInterfaceType interfaceType;
 };
 };
 
 
+/// Represents a SPIR-V pointer type.
 class SpirvPointerType : public SpirvType {
 class SpirvPointerType : public SpirvType {
 public:
 public:
   SpirvPointerType(const SpirvType *pointee, spv::StorageClass sc)
   SpirvPointerType(const SpirvType *pointee, spv::StorageClass sc)
@@ -356,11 +357,11 @@ private:
   spv::StorageClass storageClass;
   spv::StorageClass storageClass;
 };
 };
 
 
+/// Represents a SPIR-V function type. None of the parameters nor the return
+/// type is allowed to be a hybrid type.
 class FunctionType : public SpirvType {
 class FunctionType : public SpirvType {
 public:
 public:
-  FunctionType(const SpirvType *ret, llvm::ArrayRef<const SpirvType *> param)
-      : SpirvType(TK_Function), returnType(ret),
-        paramTypes(param.begin(), param.end()) {}
+  FunctionType(const SpirvType *ret, llvm::ArrayRef<const SpirvType *> param);
 
 
   static bool classof(const SpirvType *t) {
   static bool classof(const SpirvType *t) {
     return t->getKind() == TK_Function;
     return t->getKind() == TK_Function;
@@ -484,8 +485,8 @@ private:
 // This class can be extended to also accept QualType vector as param types.
 // This class can be extended to also accept QualType vector as param types.
 class HybridFunctionType : public HybridType {
 class HybridFunctionType : public HybridType {
 public:
 public:
-  HybridFunctionType(QualType ret, llvm::ArrayRef<const SpirvType *> param)
-      : HybridType(TK_HybridFunction), astReturnType(ret),
+  HybridFunctionType(QualType ret, llvm::ArrayRef<QualType> param)
+      : HybridType(TK_HybridFunction), returnType(ret),
         paramTypes(param.begin(), param.end()) {}
         paramTypes(param.begin(), param.end()) {}
 
 
   static bool classof(const SpirvType *t) {
   static bool classof(const SpirvType *t) {
@@ -493,15 +494,15 @@ public:
   }
   }
 
 
   bool operator==(const HybridFunctionType &that) const {
   bool operator==(const HybridFunctionType &that) const {
-    return astReturnType == that.astReturnType && paramTypes == that.paramTypes;
+    return returnType == that.returnType && paramTypes == that.paramTypes;
   }
   }
 
 
-  QualType getAstReturnType() const { return astReturnType; }
-  llvm::ArrayRef<const SpirvType *> getParamTypes() const { return paramTypes; }
+  QualType getReturnType() const { return returnType; }
+  llvm::ArrayRef<QualType> getParamTypes() const { return paramTypes; }
 
 
 private:
 private:
-  QualType astReturnType;
-  llvm::SmallVector<const SpirvType *, 8> paramTypes;
+  QualType returnType;
+  llvm::SmallVector<QualType, 8> paramTypes;
 };
 };
 
 
 } // end namespace spirv
 } // end namespace spirv

+ 1 - 3
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -535,11 +535,9 @@ SpirvInstruction *DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl) {
 SpirvFunctionParameter *
 SpirvFunctionParameter *
 DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
 DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
   const auto type = getTypeOrFnRetType(param);
   const auto type = getTypeOrFnRetType(param);
-  const auto *ptrType =
-      spvContext.getPointerType(type, spv::StorageClass::Function);
   const auto loc = param->getLocation();
   const auto loc = param->getLocation();
   SpirvFunctionParameter *fnParamInstr =
   SpirvFunctionParameter *fnParamInstr =
-      spvBuilder.addFnParam(ptrType, loc, param->getName());
+      spvBuilder.addFnParam(type, loc, param->getName());
 
 
   bool isAlias = false;
   bool isAlias = false;
   (void)getTypeAndCreateCounterForPotentialAliasVar(param, &isAlias);
   (void)getTypeAndCreateCounterForPotentialAliasVar(param, &isAlias);

+ 15 - 37
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -83,7 +83,8 @@ bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) {
     }
     }
     break;
     break;
   }
   }
-  // Variables must have a pointer type.
+  // Variables and function parameters must have a pointer type.
+  case spv::Op::OpFunctionParameter:
   case spv::Op::OpVariable: {
   case spv::Op::OpVariable: {
     const SpirvType *pointerType =
     const SpirvType *pointerType =
         spvContext.getPointerType(resultType, instr->getStorageClass());
         spvContext.getPointerType(resultType, instr->getStorageClass());
@@ -141,18 +142,17 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
     return spvContext.getSampledImageType(cast<ImageType>(imageSpirvType));
     return spvContext.getSampledImageType(cast<ImageType>(imageSpirvType));
   } else if (const auto *hybridFn = dyn_cast<HybridFunctionType>(type)) {
   } else if (const auto *hybridFn = dyn_cast<HybridFunctionType>(type)) {
     // Lower the return type.
     // Lower the return type.
-    const QualType astReturnType = hybridFn->getAstReturnType();
+    const QualType astReturnType = hybridFn->getReturnType();
     const SpirvType *spirvReturnType =
     const SpirvType *spirvReturnType =
         lowerType(astReturnType, rule, /*isRowMajor*/ llvm::None, loc);
         lowerType(astReturnType, rule, /*isRowMajor*/ llvm::None, loc);
 
 
-    // Go over all params. If any of them is hybrid, lower it.
+    // Go over all params and lower them.
     std::vector<const SpirvType *> paramTypes;
     std::vector<const SpirvType *> paramTypes;
-    for (auto *paramType : hybridFn->getParamTypes()) {
-      if (const auto *hybridParam = dyn_cast<HybridType>(paramType)) {
-        paramTypes.push_back(lowerType(hybridParam, rule, loc));
-      } else {
-        paramTypes.push_back(paramType);
-      }
+    for (auto paramType : hybridFn->getParamTypes()) {
+      const auto *spirvParamType =
+          lowerType(paramType, rule, /*isRowMajor*/ llvm::None, loc);
+      paramTypes.push_back(spvContext.getPointerType(
+          spirvParamType, spv::StorageClass::Function));
     }
     }
 
 
     return spvContext.getFunctionType(spirvReturnType, paramTypes);
     return spvContext.getFunctionType(spirvReturnType, paramTypes);
@@ -169,9 +169,12 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
   // sampledType in image types can only be numberical type.
   // sampledType in image types can only be numberical type.
   // Sampler types cannot be further lowered.
   // Sampler types cannot be further lowered.
   // SampledImage types cannot be further lowered.
   // SampledImage types cannot be further lowered.
+  // FunctionType is not allowed to contain hybrid parameters or return type.
+  // StructType is not allowed to contain any hybrid types.
   else if (isa<VoidType>(type) || isa<ScalarType>(type) ||
   else if (isa<VoidType>(type) || isa<ScalarType>(type) ||
            isa<MatrixType>(type) || isa<ImageType>(type) ||
            isa<MatrixType>(type) || isa<ImageType>(type) ||
-           isa<SamplerType>(type) || isa<SampledImageType>(type)) {
+           isa<SamplerType>(type) || isa<SampledImageType>(type) ||
+           isa<FunctionType>(type) || isa<StructType>(type)) {
     return type;
     return type;
   }
   }
   // Vectors could contain a hybrid type
   // Vectors could contain a hybrid type
@@ -204,11 +207,6 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
       return raType;
       return raType;
     return spvContext.getRuntimeArrayType(loweredElemType, raType->getStride());
     return spvContext.getRuntimeArrayType(loweredElemType, raType->getStride());
   }
   }
-  // Struct types could contain a hybrid type
-  else if (const auto *structType = dyn_cast<StructType>(type)) {
-    // Struct types can not contain hybrid types.
-    return structType;
-  }
   // Pointer types could point to a hybrid type.
   // Pointer types could point to a hybrid type.
   else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
   else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
     const auto *loweredPointee =
     const auto *loweredPointee =
@@ -220,26 +218,6 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
     return spvContext.getPointerType(loweredPointee,
     return spvContext.getPointerType(loweredPointee,
                                      ptrType->getStorageClass());
                                      ptrType->getStorageClass());
   }
   }
-  // Function types may have a parameter or return type that is hybrid.
-  else if (const auto *fnType = dyn_cast<FunctionType>(type)) {
-    const auto *loweredRetType = lowerType(fnType->getReturnType(), rule, loc);
-    bool wasLowered = fnType->getReturnType() != loweredRetType;
-    llvm::SmallVector<const SpirvType *, 4> loweredParams;
-    const auto &paramTypes = fnType->getParamTypes();
-    for (auto *paramType : paramTypes) {
-      const auto *loweredParamType = lowerType(paramType, rule, loc);
-      loweredParams.push_back(loweredParamType);
-      if (loweredParamType != paramType) {
-        wasLowered = true;
-      }
-    }
-    // If the function type didn't include any hybrid types, return itself.
-    if (!wasLowered) {
-      return fnType;
-    }
-
-    return spvContext.getFunctionType(loweredRetType, loweredParams);
-  }
 
 
   llvm_unreachable("lowering of hybrid type not implemented");
   llvm_unreachable("lowering of hybrid type not implemented");
 }
 }
@@ -319,10 +297,10 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
         // LowerTypeVisitor is invoked. We should error out if we encounter a
         // LowerTypeVisitor is invoked. We should error out if we encounter a
         // literal type.
         // literal type.
         case BuiltinType::LitInt:
         case BuiltinType::LitInt:
-          //emitError("found literal int type when lowering types", srcLoc);
+          // emitError("found literal int type when lowering types", srcLoc);
           return spvContext.getUIntType(64);
           return spvContext.getUIntType(64);
         case BuiltinType::LitFloat: {
         case BuiltinType::LitFloat: {
-          //emitError("found literal float type when lowering types", srcLoc);
+          // emitError("found literal float type when lowering types", srcLoc);
           return spvContext.getFloatType(64);
           return spvContext.getFloatType(64);
 
 
         default:
         default:

+ 1 - 2
tools/clang/lib/SPIRV/SPIRVContext.cpp

@@ -240,8 +240,7 @@ SpirvContext::getFunctionType(const SpirvType *ret,
 }
 }
 
 
 HybridFunctionType *
 HybridFunctionType *
-SpirvContext::getFunctionType(QualType ret,
-                              llvm::ArrayRef<const SpirvType *> param) {
+SpirvContext::getFunctionType(QualType ret, llvm::ArrayRef<QualType> param) {
   return new (this) HybridFunctionType(ret, param);
   return new (this) HybridFunctionType(ret, param);
 }
 }
 
 

+ 4 - 9
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -964,7 +964,7 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
       declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(decl);
       declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(decl);
 
 
   // Construct the function signature.
   // Construct the function signature.
-  llvm::SmallVector<const SpirvType *, 4> paramTypes;
+  llvm::SmallVector<QualType, 4> paramTypes;
 
 
   bool isNonStaticMemberFn = false;
   bool isNonStaticMemberFn = false;
   if (const auto *memberFn = dyn_cast<CXXMethodDecl>(decl)) {
   if (const auto *memberFn = dyn_cast<CXXMethodDecl>(decl)) {
@@ -975,18 +975,14 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
       // object on which we are invoking this method.
       // object on which we are invoking this method.
       const QualType valueType =
       const QualType valueType =
           memberFn->getThisType(astContext)->getPointeeType();
           memberFn->getThisType(astContext)->getPointeeType();
-      const SpirvType *ptrType =
-          spvContext.getPointerType(valueType, spv::StorageClass::Function);
-      paramTypes.push_back(ptrType);
+      paramTypes.push_back(valueType);
     }
     }
   }
   }
 
 
   for (const auto *param : decl->params()) {
   for (const auto *param : decl->params()) {
     const QualType valueType =
     const QualType valueType =
         declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param);
         declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param);
-    const SpirvType *ptrType =
-        spvContext.getPointerType(valueType, spv::StorageClass::Function);
-    paramTypes.push_back(ptrType);
+    paramTypes.push_back(valueType);
   }
   }
 
 
   auto *funcType = spvContext.getFunctionType(retType, paramTypes);
   auto *funcType = spvContext.getFunctionType(retType, paramTypes);
@@ -8972,8 +8968,7 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
   uint32_t outputArraySize = 0;
   uint32_t outputArraySize = 0;
 
 
   // Construct the wrapper function signature.
   // Construct the wrapper function signature.
-  const SpirvType *voidType = spvContext.getVoidType();
-  FunctionType *funcType = spvContext.getFunctionType(voidType, {});
+  auto *funcType = spvContext.getFunctionType(astContext.VoidTy, {});
 
 
   // The wrapper entry function surely does not have pre-assigned <result-id>
   // The wrapper entry function surely does not have pre-assigned <result-id>
   // for it like other functions that got added to the work queue following
   // for it like other functions that got added to the work queue following

+ 0 - 12
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -56,18 +56,6 @@ SpirvFunctionParameter *SpirvBuilder::addFnParam(QualType ptrType,
   return param;
   return param;
 }
 }
 
 
-SpirvFunctionParameter *SpirvBuilder::addFnParam(const SpirvType *ptrType,
-                                                 SourceLocation loc,
-                                                 llvm::StringRef name) {
-  assert(function && "found detached parameter");
-  auto *param =
-      new (context) SpirvFunctionParameter(/*QualType*/ {}, /*id*/ 0, loc);
-  param->setResultType(ptrType);
-  param->setDebugName(name);
-  function->addParameter(param);
-  return param;
-}
-
 SpirvVariable *SpirvBuilder::addFnVar(QualType valueType, SourceLocation loc,
 SpirvVariable *SpirvBuilder::addFnVar(QualType valueType, SourceLocation loc,
                                       llvm::StringRef name,
                                       llvm::StringRef name,
                                       SpirvInstruction *init) {
                                       SpirvInstruction *init) {

+ 10 - 0
tools/clang/lib/SPIRV/SpirvType.cpp

@@ -243,5 +243,15 @@ bool HybridStructType::operator==(const HybridStructType &that) const {
          readOnly == that.readOnly && interfaceType == that.interfaceType;
          readOnly == that.readOnly && interfaceType == that.interfaceType;
 }
 }
 
 
+FunctionType::FunctionType(const SpirvType *ret,
+                           llvm::ArrayRef<const SpirvType *> param)
+    : SpirvType(TK_Function), returnType(ret),
+      paramTypes(param.begin(), param.end()) {
+  // Make sure
+  assert(!isa<HybridType>(ret));
+  for (auto *paramType : param)
+    assert(!isa<HybridType>(param));
+}
+
 } // namespace spirv
 } // namespace spirv
 } // namespace clang
 } // namespace clang