Ver código fonte

[spirv] Call SPIR-V legalization passes from SPIRV-Tools (#655)

When seeing opaque types within structs in function parameter,
function return, and variable definition, invoke SPIRV-Tools
legalization passes.

Also refreshed external projects
Lei Zhang 8 anos atrás
pai
commit
62629d982a

+ 1 - 1
external/CMakeLists.txt

@@ -28,7 +28,7 @@ if (${ENABLE_SPIRV_CODEGEN})
   if (NOT TARGET SPIRV-Tools)
     message(FATAL_ERROR "SPIRV-Tools was not found - required for SPIR-V codegen")
   else()
-    set(SPIRV_TOOLS_INCLUDE_DIR ${SPIRV-Tools_SOURCE_DIR}/include PARENT_SCOPE)
+    set(SPIRV_TOOLS_INCLUDE_DIR ${spirv-tools_SOURCE_DIR}/include PARENT_SCOPE)
   endif()
 
   set(SPIRV_DEP_TARGETS

+ 1 - 1
external/SPIRV-Tools

@@ -1 +1 @@
-Subproject commit 768d9b42d38c7562bd42dbc29b22c61046848ee8
+Subproject commit dcf42433a63c9779cf1269a4e5f1caea3a887b63

+ 1 - 1
external/googletest

@@ -1 +1 @@
-Subproject commit b7e8a993b4125d1083cb431d91407d8ee4dba2ad
+Subproject commit f1a87d73fc604c5ab8fbb0cc6fa9a86ffd845530

+ 1 - 1
external/re2

@@ -1 +1 @@
-Subproject commit 971f917a35125c6dcfabf099d5fe9a1e5c383265
+Subproject commit d2b639578a17f459ff90f4bf9b904f66c3ebb93d

+ 2 - 0
tools/clang/lib/SPIRV/CMakeLists.txt

@@ -24,6 +24,8 @@ add_clang_library(clangSPIRV
   clangBasic
   clangFrontend
   clangLex
+  SPIRV-Tools-opt
   )
 
 target_include_directories(clangSPIRV PUBLIC ${SPIRV_HEADER_INCLUDE_DIR})
+target_include_directories(clangSPIRV PRIVATE ${SPIRV_TOOLS_INCLUDE_DIR})

+ 52 - 9
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -14,6 +14,7 @@
 #include "SPIRVEmitter.h"
 
 #include "dxc/HlslIntrinsicOp.h"
+#include "spirv-tools/optimizer.hpp"
 #include "llvm/ADT/StringExtras.h"
 
 #include "InitListHandler.h"
@@ -148,15 +149,33 @@ const Expr *isStructuredBufferLoad(const Expr *expr, const Expr **index) {
   return nullptr;
 }
 
-/// \brief Returns the statement that is the immediate parent AST node of the
-/// given statement. Returns nullptr if there are no parents nodes.
-const Stmt *getImmediateParent(ASTContext &astContext, const Stmt *stmt) {
-  const auto &parents = astContext.getParents(*stmt);
-  return parents.empty() ? nullptr : parents[0].get<Stmt>();
-}
+bool spirvToolsOptimize(std::vector<uint32_t> *module, std::string *messages) {
+  spvtools::Optimizer optimizer(SPV_ENV_VULKAN_1_0);
+
+  optimizer.SetMessageConsumer(
+      [messages](spv_message_level_t /*level*/, const char * /*source*/,
+                 const spv_position_t & /*position*/,
+                 const char *message) { *messages += message; });
+
+  optimizer.RegisterPass(spvtools::CreateInlineExhaustivePass());
+  optimizer.RegisterPass(spvtools::CreateLocalAccessChainConvertPass());
+  optimizer.RegisterPass(spvtools::CreateLocalSingleBlockLoadStoreElimPass());
+  optimizer.RegisterPass(spvtools::CreateLocalSingleStoreElimPass());
+  optimizer.RegisterPass(spvtools::CreateInsertExtractElimPass());
+  optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass());
+
+  optimizer.RegisterPass(spvtools::CreateDeadBranchElimPass());
+  optimizer.RegisterPass(spvtools::CreateBlockMergePass());
+  optimizer.RegisterPass(spvtools::CreateLocalMultiStoreElimPass());
+  optimizer.RegisterPass(spvtools::CreateInsertExtractElimPass());
+  optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass());
 
-bool isLoopStmt(const Stmt *stmt) {
-  return isa<ForStmt>(stmt) || isa<WhileStmt>(stmt) || isa<DoStmt>(stmt);
+  optimizer.RegisterPass(spvtools::CreateEliminateDeadFunctionsPass());
+  optimizer.RegisterPass(spvtools::CreateEliminateDeadConstantPass());
+
+  optimizer.RegisterPass(spvtools::CreateCompactIdsPass());
+
+  return optimizer.Run(module->data(), module->size(), module);
 }
 
 } // namespace
@@ -171,7 +190,7 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
       theContext(), theBuilder(&theContext),
       declIdMapper(shaderModel, astContext, theBuilder, diags, spirvOptions),
       typeTranslator(astContext, theBuilder, diags), entryFunctionId(0),
-      curFunction(nullptr), curThis(0) {
+      curFunction(nullptr), curThis(0), needsLegalization(false) {
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
     emitError("unknown shader module: %0") << shaderModel.GetName();
 }
@@ -230,6 +249,19 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
 
   // Output the constructed module.
   std::vector<uint32_t> m = theBuilder.takeModule();
+
+  const auto optLevel = theCompilerInstance.getCodeGenOpts().OptimizationLevel;
+  if (needsLegalization || optLevel > 0) {
+    if (needsLegalization && optLevel == 0)
+      emitWarning("-O0 ignored since SPIR-V legalization required");
+
+    std::string messages;
+    if (!spirvToolsOptimize(&m, &messages)) {
+      emitFatalError("failed to legalize/optimize SPIR-V: %0") << messages;
+      return;
+    }
+  }
+
   theCompilerInstance.getOutStream()->write(
       reinterpret_cast<const char *>(m.data()), m.size() * 4);
 }
@@ -425,6 +457,10 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
     funcId = declIdMapper.getDeclResultId(decl);
   }
 
+  if (!needsLegalization &&
+      TypeTranslator::isOpaqueStructType(decl->getReturnType()))
+    needsLegalization = true;
+
   const uint32_t retType = typeTranslator.translateType(decl->getReturnType());
 
   // Construct the function signature.
@@ -454,6 +490,10 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
     const uint32_t ptrType =
         theBuilder.getPointerType(valueType, spv::StorageClass::Function);
     paramTypes.push_back(ptrType);
+
+    if (!needsLegalization &&
+        TypeTranslator::isOpaqueStructType(param->getType()))
+      needsLegalization = true;
   }
 
   const uint32_t funcType = theBuilder.getFunctionType(retType, paramTypes);
@@ -555,6 +595,9 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
   if (TypeTranslator::isRelaxedPrecisionType(decl->getType())) {
     theBuilder.decorate(varId, spv::Decoration::RelaxedPrecision);
   }
+
+  if (!needsLegalization && TypeTranslator::isOpaqueStructType(decl->getType()))
+    needsLegalization = true;
 }
 
 spv::LoopControlMask SPIRVEmitter::translateLoopAttribute(const Attr &attr) {

+ 18 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -500,6 +500,15 @@ private:
       const CXXMemberCallExpr *);
 
 private:
+  /// \brief Wrapper method to create a fatal error message and report it
+  /// in the diagnostic engine associated with this consumer.
+  template <unsigned N>
+  DiagnosticBuilder emitFatalError(const char (&message)[N]) {
+    const auto diagId =
+        diags.getCustomDiagID(clang::DiagnosticsEngine::Fatal, message);
+    return diags.Report(diagId);
+  }
+
   /// \brief Wrapper method to create an error message and report it
   /// in the diagnostic engine associated with this consumer.
   template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]) {
@@ -548,6 +557,15 @@ private:
   /// The SPIR-V function parameter for the current this object.
   uint32_t curThis;
 
+  /// Whether the translated SPIR-V binary needs legalization.
+  ///
+  /// The following cases will require legalization:
+  /// * Opaque types (textures, samplers) within structs
+  ///
+  /// If this is true, SPIRV-Tools legalization passes will be executed after
+  /// the translation to legalize the generated SPIR-V binary.
+  bool needsLegalization;
+
   /// Global variables that should be initialized once at the begining of the
   /// entry function.
   llvm::SmallVector<const VarDecl *, 4> toInitGloalVars;

+ 47 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -61,6 +61,53 @@ bool TypeTranslator::isRelaxedPrecisionType(QualType type) {
   return false;
 }
 
+bool TypeTranslator::isOpaqueType(QualType type) {
+  if (const auto *recordType = type->getAs<RecordType>()) {
+    const auto name = recordType->getDecl()->getName();
+
+    if (name == "Texture1D" || name == "RWTexture1D")
+      return true;
+    if (name == "Texture2D" || name == "RWTexture2D")
+      return true;
+    if (name == "Texture2DMS" || name == "RWTexture2DMS")
+      return true;
+    if (name == "Texture3D" || name == "RWTexture3D")
+      return true;
+    if (name == "TextureCube" || name == "RWTextureCube")
+      return true;
+
+    if (name == "Texture1DArray" || name == "RWTexture1DArray")
+      return true;
+    if (name == "Texture2DArray" || name == "RWTexture2DArray")
+      return true;
+    if (name == "Texture2DMSArray" || name == "RWTexture2DMSArray")
+      return true;
+    if (name == "TextureCubeArray" || name == "RWTextureCubeArray")
+      return true;
+
+    if (name == "Buffer" || name == "RWBuffer")
+      return true;
+
+    if (name == "SamplerState" || name == "SamplerComparisonState")
+      return true;
+  }
+  return false;
+}
+
+bool TypeTranslator::isOpaqueStructType(QualType type) {
+  if (isOpaqueType(type))
+    return false;
+
+  if (const auto *recordType = type->getAs<RecordType>())
+    for (const auto *field : recordType->getDecl()->decls())
+      if (const auto *fieldDecl = dyn_cast<FieldDecl>(field))
+        if (isOpaqueType(fieldDecl->getType()) ||
+            isOpaqueStructType(fieldDecl->getType()))
+          return true;
+
+  return false;
+}
+
 uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
                                        bool isRowMajor) {
   // We can only apply row_major to matrices or arrays of matrices.

+ 8 - 0
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -133,6 +133,14 @@ public:
   /// operated on with a relaxed precision.
   static bool isRelaxedPrecisionType(QualType);
 
+  /// Returns true if the given type will be translated into a SPIR-V image,
+  /// sampler or struct containing images or samplers.
+  static bool isOpaqueType(QualType type);
+
+  /// Returns true if the given type is a struct type who has an opaque field
+  /// (in a recursive away).
+  static bool isOpaqueStructType(QualType tye);
+
   /// \brief Returns the the element type for the given scalar/vector/matrix
   /// type. Returns empty QualType for other cases.
   QualType getElementType(QualType type);

+ 3 - 2
tools/clang/unittests/SPIRV/FileTestUtils.cpp

@@ -22,7 +22,7 @@ namespace utils {
 bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
                             std::string *generatedSpirvAsm,
                             bool generateHeader) {
-  spvtools::SpirvTools spirvTools(SPV_ENV_UNIVERSAL_1_0);
+  spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0);
   spirvTools.SetMessageConsumer(
       [](spv_message_level_t, const char *, const spv_position_t &,
          const char *message) { fprintf(stdout, "%s\n", message); });
@@ -33,7 +33,7 @@ bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
 }
 
 bool validateSpirvBinary(std::vector<uint32_t> &binary) {
-  spvtools::SpirvTools spirvTools(SPV_ENV_UNIVERSAL_1_0);
+  spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0);
   spirvTools.SetMessageConsumer(
       [](spv_message_level_t, const char *, const spv_position_t &,
          const char *message) { fprintf(stdout, "%s\n", message); });
@@ -134,6 +134,7 @@ bool runCompilerWithSpirvGeneration(const llvm::StringRef inputFilePath,
     flags.push_back(L"-T");
     flags.push_back(profile.c_str());
     flags.push_back(L"-spirv");
+    flags.push_back(L"-O0"); // Disable optimization for testing
     flags.push_back(rest.c_str());
 
     IFT(dllSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary));