Răsfoiți Sursa

[spirv] Code refactoring (#886)

* Restructured DeclEvalInfo to contain a SpirvEvalInfo
* Refactored various variable creation methods
* Formatted code
Lei Zhang 7 ani în urmă
părinte
comite
ea76fa189e

+ 46 - 36
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -142,13 +142,13 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
 
   SemanticInfo inheritSemantic = {};
 
-  return createStageVars(
-      sigPoint, decl, /*asInput=*/false, type,
-      /*arraySize=*/0, "out.var", llvm::None, &storedValue,
-      // Write back of stage output variables in GS is manually controlled by
-      // .Append() intrinsic method, implemented in writeBackOutputStream().
-      // So noWriteBack should be set to true for GS.
-      shaderModel.IsGS(), &inheritSemantic);
+  return createStageVars(sigPoint, decl, /*asInput=*/false, type,
+                         /*arraySize=*/0, "out.var", llvm::None, &storedValue,
+                         // Write back of stage output variables in GS is
+                         // manually controlled by .Append() intrinsic method,
+                         // implemented in writeBackOutputStream(). So
+                         // noWriteBack should be set to true for GS.
+                         shaderModel.IsGS(), &inheritSemantic);
 }
 
 bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
@@ -222,15 +222,15 @@ SpirvEvalInfo DeclResultIdMapper::getDeclResultId(const NamedDecl *decl,
           cast<VarDecl>(decl)->getType(),
           // We need to set decorateLayout here to avoid creating SPIR-V
           // instructions for the current type without decorations.
-          info->layoutRule);
+          info->info.getLayoutRule());
 
       const uint32_t elemId = theBuilder.createAccessChain(
-          theBuilder.getPointerType(varType, info->storageClass),
-          info->resultId, {theBuilder.getConstantInt32(info->indexInCTBuffer)});
+          theBuilder.getPointerType(varType, info->info.getStorageClass()),
+          info->info, {theBuilder.getConstantInt32(info->indexInCTBuffer)});
 
       return SpirvEvalInfo(elemId)
-          .setStorageClass(info->storageClass)
-          .setLayoutRule(info->layoutRule);
+          .setStorageClass(info->info.getStorageClass())
+          .setLayoutRule(info->info.getLayoutRule());
     } else {
       return *info;
     }
@@ -243,27 +243,31 @@ SpirvEvalInfo DeclResultIdMapper::getDeclResultId(const NamedDecl *decl,
   return 0;
 }
 
-uint32_t DeclResultIdMapper::createFnParam(uint32_t paramType,
-                                           const ParmVarDecl *param) {
-  const uint32_t id = theBuilder.addFnParam(paramType, param->getName());
-  astDecls[param] = {id, spv::StorageClass::Function};
+uint32_t DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
+  const uint32_t type = typeTranslator.translateType(param->getType());
+  const uint32_t ptrType =
+      theBuilder.getPointerType(type, spv::StorageClass::Function);
+  const uint32_t id = theBuilder.addFnParam(ptrType, param->getName());
+  astDecls[param] = SpirvEvalInfo(id);
 
   return id;
 }
 
-uint32_t DeclResultIdMapper::createFnVar(uint32_t varType, const VarDecl *var,
+uint32_t DeclResultIdMapper::createFnVar(const VarDecl *var,
                                          llvm::Optional<uint32_t> init) {
-  const uint32_t id = theBuilder.addFnVar(varType, var->getName(), init);
-  astDecls[var] = {id, spv::StorageClass::Function};
+  const uint32_t type = typeTranslator.translateType(var->getType());
+  const uint32_t id = theBuilder.addFnVar(type, var->getName(), init);
+  astDecls[var] = SpirvEvalInfo(id);
 
   return id;
 }
 
-uint32_t DeclResultIdMapper::createFileVar(uint32_t varType, const VarDecl *var,
+uint32_t DeclResultIdMapper::createFileVar(const VarDecl *var,
                                            llvm::Optional<uint32_t> init) {
-  const uint32_t id = theBuilder.addModuleVar(
-      varType, spv::StorageClass::Private, var->getName(), init);
-  astDecls[var] = {id, spv::StorageClass::Private};
+  const uint32_t type = typeTranslator.translateType(var->getType());
+  const uint32_t id = theBuilder.addModuleVar(type, spv::StorageClass::Private,
+                                              var->getName(), init);
+  astDecls[var] = SpirvEvalInfo(id).setStorageClass(spv::StorageClass::Private);
 
   return id;
 }
@@ -295,7 +299,8 @@ uint32_t DeclResultIdMapper::createExternVar(const VarDecl *var) {
   const auto varType = typeTranslator.translateType(var->getType(), rule);
   const uint32_t id = theBuilder.addModuleVar(varType, storageClass,
                                               var->getName(), llvm::None);
-  astDecls[var] = {id, storageClass, rule};
+  astDecls[var] =
+      SpirvEvalInfo(id).setStorageClass(storageClass).setLayoutRule(rule);
 
   const auto *regAttr = getResourceBinding(var);
   const auto *bindingAttr = var->getAttr<VKBindingAttr>();
@@ -396,9 +401,11 @@ uint32_t DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
   int index = 0;
   for (const auto *subDecl : decl->decls()) {
     const auto *varDecl = cast<VarDecl>(subDecl);
-    astDecls[varDecl] = {bufferVar, spv::StorageClass::Uniform,
-                         decl->isCBuffer() ? LayoutRule::GLSLStd140
-                                           : LayoutRule::GLSLStd430,
+    astDecls[varDecl] = {SpirvEvalInfo(bufferVar)
+                             .setStorageClass(spv::StorageClass::Uniform)
+                             .setLayoutRule(decl->isCBuffer()
+                                                ? LayoutRule::GLSLStd140
+                                                : LayoutRule::GLSLStd430),
                          index++};
   }
   resourceVars.emplace_back(
@@ -423,9 +430,11 @@ uint32_t DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
       recordType->getDecl(), usageKind, structName, decl->getName());
 
   // We register the VarDecl here.
-  astDecls[decl] = {bufferVar, spv::StorageClass::Uniform,
-                    context->isCBuffer() ? LayoutRule::GLSLStd140
-                                         : LayoutRule::GLSLStd430};
+  astDecls[decl] =
+      SpirvEvalInfo(bufferVar)
+          .setStorageClass(spv::StorageClass::Uniform)
+          .setLayoutRule(context->isCBuffer() ? LayoutRule::GLSLStd140
+                                              : LayoutRule::GLSLStd430);
   resourceVars.emplace_back(
       bufferVar, ResourceVar::Category::Other, getResourceBinding(context),
       decl->getAttr<VKBindingAttr>(), decl->getAttr<VKCounterBindingAttr>());
@@ -444,8 +453,9 @@ uint32_t DeclResultIdMapper::createPushConstant(const VarDecl *decl) {
       decl->getName());
 
   // Register the VarDecl
-  astDecls[decl] = {var, spv::StorageClass::PushConstant,
-                    LayoutRule::GLSLStd430};
+  astDecls[decl] = SpirvEvalInfo(var)
+                       .setStorageClass(spv::StorageClass::PushConstant)
+                       .setLayoutRule(LayoutRule::GLSLStd430);
   // Do not push this variable into resourceVars since it does not need
   // descriptor set.
 
@@ -454,10 +464,10 @@ uint32_t DeclResultIdMapper::createPushConstant(const VarDecl *decl) {
 
 uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
   if (const auto *info = getDeclSpirvInfo(fn))
-    return info->resultId;
+    return info->info;
 
   const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
-  astDecls[fn] = {id, spv::StorageClass::Function};
+  astDecls[fn] = SpirvEvalInfo(id);
 
   return id;
 }
@@ -473,8 +483,8 @@ uint32_t DeclResultIdMapper::createCounterVar(const ValueDecl *decl) {
   const auto *info = getDeclSpirvInfo(decl);
   const uint32_t counterType = typeTranslator.getACSBufferCounter();
   const std::string counterName = "counter.var." + decl->getName().str();
-  const uint32_t counterId =
-      theBuilder.addModuleVar(counterType, info->storageClass, counterName);
+  const uint32_t counterId = theBuilder.addModuleVar(
+      counterType, info->info.getStorageClass(), counterName);
 
   resourceVars.emplace_back(counterId, ResourceVar::Category::Other,
                             getResourceBinding(decl),

+ 11 - 20
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -178,16 +178,14 @@ public:
 
   /// \brief Creates a function-scope paramter in the current function and
   /// returns its <result-id>.
-  uint32_t createFnParam(uint32_t paramType, const ParmVarDecl *param);
+  uint32_t createFnParam(const ParmVarDecl *param);
 
   /// \brief Creates a function-scope variable in the current function and
   /// returns its <result-id>.
-  uint32_t createFnVar(uint32_t varType, const VarDecl *variable,
-                       llvm::Optional<uint32_t> init);
+  uint32_t createFnVar(const VarDecl *var, llvm::Optional<uint32_t> init);
 
   /// \brief Creates a file-scope variable and returns its <result-id>.
-  uint32_t createFileVar(uint32_t varType, const VarDecl *variable,
-                         llvm::Optional<uint32_t> init);
+  uint32_t createFileVar(const VarDecl *var, llvm::Optional<uint32_t> init);
 
   /// \brief Creates an external-visible variable and returns its <result-id>.
   uint32_t createExternVar(const VarDecl *var);
@@ -222,23 +220,16 @@ public:
 private:
   /// The struct containing SPIR-V information of a AST Decl.
   struct DeclSpirvInfo {
-    DeclSpirvInfo(uint32_t result = 0,
-                  spv::StorageClass sc = spv::StorageClass::Function,
-                  LayoutRule lr = LayoutRule::Void, int indexInCTB = -1)
-        : resultId(result), storageClass(sc), layoutRule(lr),
-          indexInCTBuffer(indexInCTB) {}
+    /// Default constructor to satisfy DenseMap
+    DeclSpirvInfo() : info(0), indexInCTBuffer(-1) {}
+
+    DeclSpirvInfo(const SpirvEvalInfo &info_, int index = -1)
+        : info(info_), indexInCTBuffer(index) {}
 
     /// Implicit conversion to SpirvEvalInfo.
-    operator SpirvEvalInfo() const {
-      return SpirvEvalInfo(resultId)
-          .setStorageClass(storageClass)
-          .setLayoutRule(layoutRule);
-    }
-
-    uint32_t resultId;
-    spv::StorageClass storageClass;
-    /// Layout rule for this decl.
-    LayoutRule layoutRule;
+    operator SpirvEvalInfo() const { return info; }
+
+    SpirvEvalInfo info;
     /// Value >= 0 means that this decl is a VarDecl inside a cbuffer/tbuffer
     /// and this is the index; value < 0 means this is just a standalone decl.
     int indexInCTBuffer;

+ 22 - 25
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -168,6 +168,12 @@ const Expr *isStructuredBufferLoad(const Expr *expr, const Expr **index) {
   return nullptr;
 }
 
+/// Returns true if the given VarDecl will be translated into a SPIR-V variable
+/// in Private or Function storage class.
+inline bool isNonExternalVar(const VarDecl *var) {
+  return !var->isExternallyVisible() || var->isStaticDataMember();
+}
+
 /// Returns the referenced variable's DeclContext if the given expr is
 /// a DeclRefExpr referencing a ConstantBuffer/TextureBuffer. Otherwise,
 /// returns nullptr.
@@ -726,8 +732,7 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
   // Create all parameters.
   for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
     const ParmVarDecl *paramDecl = decl->getParamDecl(i);
-    (void)declIdMapper.createFnParam(paramTypes[i + isNonStaticMemberFn],
-                                     paramDecl);
+    (void)declIdMapper.createFnParam(paramDecl);
   }
 
   if (decl->hasBody()) {
@@ -849,20 +854,15 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
   // File scope variables (static "global" and "local" variables) belongs to
   // the Private storage class, while function scope variables (normal "local"
   // variables) belongs to the Function storage class.
-  if (!decl->isExternallyVisible() || decl->isStaticDataMember()) {
-    // Note: cannot move varType outside of this scope because it generates
-    // SPIR-V types without decorations, while external visible variable should
-    // have SPIR-V type with decorations.
-    const uint32_t varType = typeTranslator.translateType(decl->getType());
-
+  if (isNonExternalVar(decl)) {
     // We already know the variable is not externally visible here. If it does
     // not have local storage, it should be file scope variable.
     const bool isFileScopeVar = !decl->hasLocalStorage();
 
     if (isFileScopeVar)
-      varId = declIdMapper.createFileVar(varType, decl, llvm::None);
+      varId = declIdMapper.createFileVar(decl, llvm::None);
     else
-      varId = declIdMapper.createFnVar(varType, decl, llvm::None);
+      varId = declIdMapper.createFnVar(decl, llvm::None);
 
     // Emit OpStore to initialize the variable
     // TODO: revert back to use OpVariable initializer
@@ -2276,7 +2276,7 @@ uint32_t SPIRVEmitter::processTextureGatherRGBACmpRGBA(
   if (numOffsetArgs == 1) {
     // The offset arg is not optional.
     handleOffsetInMethodCall(expr, 2 + isCmp, &constOffset, &varOffset);
-  } else if(numOffsetArgs == 4) {
+  } else if (numOffsetArgs == 4) {
     const auto offset0 = tryToEvaluateAsConst(expr->getArg(2 + isCmp));
     const auto offset1 = tryToEvaluateAsConst(expr->getArg(3 + isCmp));
     const auto offset2 = tryToEvaluateAsConst(expr->getArg(4 + isCmp));
@@ -2337,7 +2337,7 @@ uint32_t SPIRVEmitter::processTextureGatherCmp(const CXXMemberCallExpr *expr) {
   const uint32_t coordinate = doExpr(expr->getArg(1));
   const uint32_t comparator = doExpr(expr->getArg(2));
   uint32_t constOffset = 0, varOffset = 0;
-  if(hasOffsetArg)
+  if (hasOffsetArg)
     handleOffsetInMethodCall(expr, 3, &constOffset, &varOffset);
 
   const auto retType = typeTranslator.translateType(callee->getReturnType());
@@ -2813,7 +2813,7 @@ uint32_t SPIRVEmitter::processTextureSampleGather(const CXXMemberCallExpr *expr,
   const uint32_t coordinate = doExpr(expr->getArg(1));
   // .Sample()/.Gather() may have a third optional paramter for offset.
   uint32_t constOffset = 0, varOffset = 0;
-  if(hasOffsetArg)
+  if (hasOffsetArg)
     handleOffsetInMethodCall(expr, 2, &constOffset, &varOffset);
 
   const auto retType =
@@ -3296,20 +3296,19 @@ spv::Op SPIRVEmitter::translateOp(BinaryOperator::Opcode op, QualType type) {
   const bool isFloatType = isFloatOrVecMatOfFloatType(type);
 
 #define BIN_OP_CASE_INT_FLOAT(kind, intBinOp, floatBinOp)                      \
-  \
-case BO_##kind : {                                                             \
+                                                                               \
+  case BO_##kind: {                                                            \
     if (isSintType || isUintType) {                                            \
       return spv::Op::Op##intBinOp;                                            \
     }                                                                          \
     if (isFloatType) {                                                         \
       return spv::Op::Op##floatBinOp;                                          \
     }                                                                          \
-  }                                                                            \
-  break
+  } break
 
 #define BIN_OP_CASE_SINT_UINT_FLOAT(kind, sintBinOp, uintBinOp, floatBinOp)    \
-  \
-case BO_##kind : {                                                             \
+                                                                               \
+  case BO_##kind: {                                                            \
     if (isSintType) {                                                          \
       return spv::Op::Op##sintBinOp;                                           \
     }                                                                          \
@@ -3319,20 +3318,18 @@ case BO_##kind : {                                                             \
     if (isFloatType) {                                                         \
       return spv::Op::Op##floatBinOp;                                          \
     }                                                                          \
-  }                                                                            \
-  break
+  } break
 
 #define BIN_OP_CASE_SINT_UINT(kind, sintBinOp, uintBinOp)                      \
-  \
-case BO_##kind : {                                                             \
+                                                                               \
+  case BO_##kind: {                                                            \
     if (isSintType) {                                                          \
       return spv::Op::Op##sintBinOp;                                           \
     }                                                                          \
     if (isUintType) {                                                          \
       return spv::Op::Op##uintBinOp;                                           \
     }                                                                          \
-  }                                                                            \
-  break
+  } break
 
   switch (op) {
   case BO_EQ: {