2
0
Эх сурвалжийг харах

[spirv] Refactoring: split files and use LLVM libraries (#450)

* [spirv] Refactoring: move classes to their own files

* [spirv] Refactoring: use llvm libraries and cosmetic improvements

* Use llvm StringRef and ArrayRef in utils
* Sort methods according to their declaration order
* Qualify parameters with const when possible
* Put test harness in clang::spirv namespace
Lei Zhang 8 жил өмнө
parent
commit
6c64875343

+ 154 - 0
tools/clang/include/clang/SPIRV/DeclResultIdMapper.h

@@ -0,0 +1,154 @@
+//===--- DeclResultIdMapper.h - AST Decl to SPIR-V <result-id> mapper ------==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_SPIRV_DECLRESULTIDMAPPER_H
+#define LLVM_CLANG_SPIRV_DECLRESULTIDMAPPER_H
+
+#include <string>
+#include <vector>
+
+#include "spirv/1.0/spirv.hpp11"
+#include "clang/SPIRV/ModuleBuilder.h"
+#include "clang/SPIRV/TypeTranslator.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace clang {
+namespace spirv {
+
+/// \brief The class containing mappings from Clang frontend Decls to their
+/// corresponding SPIR-V <result-id>s.
+///
+/// All symbols defined in the AST should be "defined" or registered in this
+/// class and have their <result-id>s queried from this class. In the process
+/// of defining a Decl, the SPIR-V module builder passed into the constructor
+/// will be used to generate all SPIR-V instructions required.
+///
+/// This class acts as a middle layer to handle the mapping between HLSL
+/// semantics and Vulkan stage (builtin/input/output) variables. Such mapping
+/// is required because of the semantic differences between DirectX and
+/// Vulkan and the essence of HLSL as the front-end language for DirectX.
+/// A normal variable attached with some semantic will be translated into a
+/// single stage variables if it is of non-struct type. If it is of struct
+/// type, the fields with attached semantics will need to be translated into
+/// stage variables per Vulkan's requirements.
+///
+/// In the following class, we call a Decl as *remapped* when it is translated
+/// into a stage variable; otherwise, we call it as *normal*. Remapped decls
+/// include:
+/// * FunctionDecl if the return value is attached with a semantic
+/// * ParmVarDecl if the parameter is attached with a semantic
+/// * FieldDecl if the field is attached with a semantic.
+class DeclResultIdMapper {
+public:
+  inline DeclResultIdMapper(spv::ExecutionModel stage, ModuleBuilder &builder,
+                            DiagnosticsEngine &diag);
+
+  /// \brief Creates the stage variables by parsing the semantics attached to
+  /// the given function's return value.
+  void createStageVarFromFnReturn(const FunctionDecl *funcDecl);
+
+  /// \brief Creates the stage variables by parsing the semantics attached to
+  /// the given function's parameter.
+  void createStageVarFromFnParam(const ParmVarDecl *paramDecl);
+
+  /// \brief Registers a Decl's <result-id> without generating any SPIR-V
+  /// instruction.
+  void registerDeclResultId(const NamedDecl *symbol, uint32_t resultId);
+
+  /// \brief Returns true if the given <result-id> is for a stage variable.
+  bool isStageVariable(uint32_t varId) const;
+
+  /// \brief Returns the <result-id> for the given Decl.
+  uint32_t getDeclResultId(const NamedDecl *decl) const;
+
+  /// \brief Returns the <result-id> for the given remapped Decl. Returns zero
+  /// if it is not a registered remapped Decl.
+  uint32_t getRemappedDeclResultId(const NamedDecl *decl) const;
+
+  /// \brief Returns the <result-id> for the given normal Decl. Returns zero if
+  /// it is not a registered normal Decl.
+  uint32_t getNormalDeclResultId(const NamedDecl *decl) const;
+
+  /// \brief Returns all defined stage (builtin/input/ouput) variables in this
+  /// mapper.
+  std::vector<uint32_t> collectStageVariables() const;
+
+  /// \brief Decorates all stage input and output variables with proper
+  /// location.
+  ///
+  /// This method will writes the location assignment into the module under
+  /// construction.
+  void finalizeStageIOLocations();
+
+private:
+  /// \brief Stage variable kind.
+  ///
+  /// Stage variables include builtin, input, and output variables.
+  /// They participate in interface matching in Vulkan pipelines.
+  enum class StageVarKind {
+    None,
+    Arbitary,
+    Position,
+    Color,
+    Target,
+    // TODO: other possible kinds
+  };
+
+  using StageVarIdSemanticPair = std::pair<uint32_t, std::string>;
+
+  /// Returns the type of the given decl. If the given decl is a FunctionDecl,
+  /// returns its result type.
+  QualType getFnParamOrRetType(const DeclaratorDecl *decl) const;
+
+  /// Creates all the stage variables mapped from semantics on the given decl.
+  ///
+  /// Assumes the decl has semantic attached to itself or to its fields.
+  void createStageVariables(const DeclaratorDecl *decl, bool actAsInput);
+
+  /// \brief Returns the stage variable's kind for the given semantic.
+  StageVarKind getStageVarKind(llvm::StringRef semantic) const;
+
+  /// \brief Returns the stage variable's semantic for the given Decl.
+  std::string getStageVarSemantic(const NamedDecl *decl) const;
+
+private:
+  const spv::ExecutionModel shaderStage;
+  ModuleBuilder &theBuilder;
+  TypeTranslator typeTranslator;
+
+  /// Mapping of all remapped decls to their <result-id>s.
+  llvm::DenseMap<const NamedDecl *, uint32_t> remappedDecls;
+  /// Mapping of all normal decls to their <result-id>s.
+  llvm::DenseMap<const NamedDecl *, uint32_t> normalDecls;
+  /// <result-id>s of all defined stage variables.
+  ///
+  /// We need to keep a separate list here to avoid looping through the
+  /// remappedDecls to find whether an <result-id> is for a stage variable.
+  llvm::SmallSet<uint32_t, 16> stageVars;
+
+  /// Stage input/oupt/builtin variables and their kinds.
+  ///
+  /// We need to keep a separate list here in order to sort them at the end
+  /// of the module building.
+  llvm::SmallVector<StageVarIdSemanticPair, 8> stageInputs;
+  llvm::SmallVector<StageVarIdSemanticPair, 8> stageOutputs;
+  llvm::SmallVector<StageVarIdSemanticPair, 8> stageBuiltins;
+};
+
+DeclResultIdMapper::DeclResultIdMapper(spv::ExecutionModel stage,
+                                       ModuleBuilder &builder,
+                                       DiagnosticsEngine &diag)
+    : shaderStage(stage), theBuilder(builder), typeTranslator(builder, diag) {}
+
+} // end namespace spirv
+} // end namespace clang
+
+#endif

+ 6 - 17
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -107,8 +107,8 @@ public:
                             llvm::ArrayRef<uint32_t> interfaces);
 
   /// \brief Adds an execution mode to the module under construction.
-  inline void addExecutionMode(uint32_t entryPointId, spv::ExecutionMode em,
-                               const std::vector<uint32_t> &params);
+  void addExecutionMode(uint32_t entryPointId, spv::ExecutionMode em,
+                        const std::vector<uint32_t> &params);
 
   /// \brief Adds a stage input/ouput variable whose value is of the given type.
   ///
@@ -128,9 +128,9 @@ public:
   // === Type ===
 
   uint32_t getVoidType();
-  uint32_t getUint32Type();
   uint32_t getInt32Type();
-  uint32_t getFloatType();
+  uint32_t getUint32Type();
+  uint32_t getFloat32Type();
   uint32_t getVecType(uint32_t elemType, uint32_t elemCount);
   uint32_t getPointerType(uint32_t pointeeType, spv::StorageClass);
   uint32_t getStructType(llvm::ArrayRef<uint32_t> fieldTypes);
@@ -138,9 +138,9 @@ public:
                            const std::vector<uint32_t> &paramTypes);
 
   // === Constant ===
-  uint32_t getConstantFloat32(float value);
   uint32_t getConstantInt32(int32_t value);
   uint32_t getConstantUint32(uint32_t value);
+  uint32_t getConstantFloat32(float value);
   uint32_t getConstantComposite(uint32_t typeId,
                                 llvm::ArrayRef<uint32_t> constituents);
 
@@ -180,18 +180,7 @@ void ModuleBuilder::requireCapability(spv::Capability cap) {
 void ModuleBuilder::addEntryPoint(spv::ExecutionModel em, uint32_t targetId,
                                   std::string targetName,
                                   llvm::ArrayRef<uint32_t> interfaces) {
-  theModule.addEntryPoint(em, targetId, targetName, interfaces);
-}
-
-void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
-                                     spv::ExecutionMode em,
-                                     const std::vector<uint32_t> &params) {
-  instBuilder.opExecutionMode(entryPointId, em);
-  for (const auto &param : params) {
-    instBuilder.literalInteger(param);
-  }
-  instBuilder.x();
-  theModule.addExecutionMode(std::move(constructSite));
+  theModule.addEntryPoint(em, targetId, std::move(targetName), interfaces);
 }
 
 } // end namespace spirv

+ 5 - 2
tools/clang/include/clang/SPIRV/String.h

@@ -12,18 +12,21 @@
 #include <string>
 #include <vector>
 
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+
 namespace clang {
 namespace spirv {
 namespace string {
 
 /// \brief Reinterprets a given string as sequence of words. It follows the
 /// SPIR-V string encoding requirements.
-std::vector<uint32_t> encodeSPIRVString(std::string s);
+std::vector<uint32_t> encodeSPIRVString(llvm::StringRef strChars);
 
 /// \brief Reinterprets the given vector of 32-bit words as a string.
 /// Expectes that the words represent a NULL-terminated string.
 /// It follows the SPIR-V string encoding requirements.
-std::string decodeSPIRVString(const std::vector<uint32_t> &vec);
+std::string decodeSPIRVString(llvm::ArrayRef<uint32_t> strWords);
 
 } // end namespace string
 } // end namespace spirv

+ 57 - 0
tools/clang/include/clang/SPIRV/TypeTranslator.h

@@ -0,0 +1,57 @@
+//===--- TypeTranslator.h - AST type to SPIR-V type translator ---*- C++ -*-==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_SPIRV_TYPETRANSLATOR_H
+#define LLVM_CLANG_SPIRV_TYPETRANSLATOR_H
+
+#include "clang/AST/Type.h"
+#include "clang/Basic/Diagnostic.h"
+#include "clang/SPIRV/ModuleBuilder.h"
+
+namespace clang {
+namespace spirv {
+
+/// The class responsible to translate Clang frontend types into SPIR-V type
+/// instructions.
+///
+/// SPIR-V type instructions generated during translation will be emitted to
+/// the SPIR-V module builder passed into the constructor.
+/// Warnings and errors during the translation will be reported to the
+/// DiagnosticEngine passed into the constructor.
+class TypeTranslator {
+public:
+  TypeTranslator(ModuleBuilder &builder, DiagnosticsEngine &diag)
+      : theBuilder(builder), diags(diag) {}
+
+  /// \brief Generates the corresponding SPIR-V type for the given Clang
+  /// frontend type and returns the type's <result-id>. On failure, reports
+  /// the error and returns 0.
+  ///
+  /// The translation is recursive; all the types that the target type depends
+  /// on will be generated.
+  uint32_t translateType(QualType type);
+
+private:
+  /// \brief Wrapper method to create an error message and report it
+  /// in the diagnostic engine associated with this consumer.
+  template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]) {
+    const auto diagId =
+        diags.getCustomDiagID(clang::DiagnosticsEngine::Error, message);
+    return diags.Report(diagId);
+  }
+
+private:
+  ModuleBuilder &theBuilder;
+  DiagnosticsEngine &diags;
+};
+
+} // end namespace spirv
+} // end namespace clang
+
+#endif

+ 2 - 0
tools/clang/lib/SPIRV/CMakeLists.txt

@@ -4,6 +4,7 @@ set(LLVM_LINK_COMPONENTS
 
 add_clang_library(clangSPIRV
   Constant.cpp
+  DeclResultIdMapper.cpp
   Decoration.cpp
   EmitSPIRVAction.cpp
   InstBuilderAuto.cpp
@@ -13,6 +14,7 @@ add_clang_library(clangSPIRV
   String.cpp
   Structure.cpp
   Type.cpp
+  TypeTranslator.cpp
 
   LINK_LIBS
   clangAST

+ 185 - 0
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -0,0 +1,185 @@
+//===--- DeclResultIdMapper.cpp - DeclResultIdMapper impl --------*- C++ -*-==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/SPIRV/DeclResultIdMapper.h"
+
+#include "clang/AST/HlslTypes.h"
+#include "llvm/ADT/StringSwitch.h"
+
+namespace clang {
+namespace spirv {
+
+void DeclResultIdMapper::createStageVarFromFnReturn(
+    const FunctionDecl *funcDecl) {
+  // SemanticDecl for the return value is attached to the FunctionDecl.
+  createStageVariables(funcDecl, false);
+}
+
+void DeclResultIdMapper::createStageVarFromFnParam(
+    const ParmVarDecl *paramDecl) {
+  // TODO: We cannot treat all parameters as stage inputs because of
+  // out/input modifiers.
+  createStageVariables(paramDecl, true);
+}
+
+void DeclResultIdMapper::registerDeclResultId(const NamedDecl *symbol,
+                                              uint32_t resultId) {
+  normalDecls[symbol] = resultId;
+}
+
+bool DeclResultIdMapper::isStageVariable(uint32_t varId) const {
+  return stageVars.count(varId) != 0;
+}
+
+uint32_t DeclResultIdMapper::getDeclResultId(const NamedDecl *decl) const {
+  if (const uint32_t id = getRemappedDeclResultId(decl))
+    return id;
+  if (const uint32_t id = getNormalDeclResultId(decl))
+    return id;
+
+  assert(false && "found unregistered Decl in DeclResultIdMapper");
+  return 0;
+}
+
+uint32_t
+DeclResultIdMapper::getRemappedDeclResultId(const NamedDecl *decl) const {
+  auto it = remappedDecls.find(decl);
+  if (it != remappedDecls.end())
+    return it->second;
+  return 0;
+}
+
+uint32_t
+DeclResultIdMapper::getNormalDeclResultId(const NamedDecl *decl) const {
+  auto it = normalDecls.find(decl);
+  if (it != normalDecls.end())
+    return it->second;
+  return 0;
+}
+
+std::vector<uint32_t> DeclResultIdMapper::collectStageVariables() const {
+  std::vector<uint32_t> stageVars;
+
+  for (const auto &builtin : stageBuiltins) {
+    stageVars.push_back(builtin.first);
+  }
+  for (const auto &input : stageInputs) {
+    stageVars.push_back(input.first);
+  }
+  for (const auto &output : stageOutputs) {
+    stageVars.push_back(output.first);
+  }
+
+  return stageVars;
+}
+
+void DeclResultIdMapper::finalizeStageIOLocations() {
+  uint32_t nextInputLocation = 0;
+  uint32_t nextOutputLocation = 0;
+
+  // TODO: sort the variables according to some criteria first, e.g.,
+  // alphabetical order of semantic names.
+  for (const auto &input : stageInputs) {
+    theBuilder.decorateLocation(input.first, nextInputLocation++);
+  }
+  for (const auto &output : stageOutputs) {
+    theBuilder.decorateLocation(output.first, nextOutputLocation++);
+  }
+}
+
+QualType
+DeclResultIdMapper::getFnParamOrRetType(const DeclaratorDecl *decl) const {
+  if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
+    return funcDecl->getReturnType();
+  }
+  return decl->getType();
+}
+
+void DeclResultIdMapper::createStageVariables(const DeclaratorDecl *decl,
+                                              bool actAsInput) {
+  QualType type = getFnParamOrRetType(decl);
+
+  if (type->isVoidType()) {
+    // No stage variables will be created for void type.
+    return;
+  }
+
+  const std::string semantic = getStageVarSemantic(decl);
+  if (!semantic.empty()) {
+    // Found semantic attached directly to this Decl. This means we need to
+    // map this decl to a single stage variable.
+    const uint32_t typeId = typeTranslator.translateType(type);
+    const auto kind = getStageVarKind(semantic);
+
+    if (actAsInput) {
+      // Stage (builtin) input variable cases
+      const uint32_t varId =
+          theBuilder.addStageIOVariable(typeId, spv::StorageClass::Input);
+
+      stageInputs.push_back(std::make_pair(varId, semantic));
+      remappedDecls[decl] = varId;
+      stageVars.insert(varId);
+    } else {
+      // Handle output builtin variables first
+      if (shaderStage == spv::ExecutionModel::Vertex &&
+          kind == StageVarKind::Position) {
+        const uint32_t varId =
+            theBuilder.addStageBuiltinVariable(typeId, spv::BuiltIn::Position);
+
+        stageBuiltins.push_back(std::make_pair(varId, semantic));
+        remappedDecls[decl] = varId;
+        stageVars.insert(varId);
+      } else {
+        // The rest are normal stage output variables
+        const uint32_t varId =
+            theBuilder.addStageIOVariable(typeId, spv::StorageClass::Output);
+
+        stageOutputs.push_back(std::make_pair(varId, semantic));
+        remappedDecls[decl] = varId;
+        stageVars.insert(varId);
+      }
+    }
+  } else {
+    // If the decl itself doesn't have semantic, it should be a struct having
+    // all its fields with semantics.
+    assert(type->isStructureType() &&
+           "found non-struct decls without semantics");
+
+    const auto *structDecl = cast<RecordType>(type.getTypePtr())->getDecl();
+
+    // Recursively handle all the fields.
+    for (const auto *field : structDecl->fields()) {
+      createStageVariables(field, actAsInput);
+    }
+  }
+}
+
+DeclResultIdMapper::StageVarKind
+DeclResultIdMapper::getStageVarKind(llvm::StringRef semantic) const {
+  return llvm::StringSwitch<StageVarKind>(semantic)
+      .Case("", StageVarKind::None)
+      .StartsWith("COLOR", StageVarKind::Color)
+      .StartsWith("POSITION", StageVarKind::Position)
+      .StartsWith("SV_POSITION", StageVarKind::Position)
+      .StartsWith("SV_TARGET", StageVarKind::Target)
+      .Default(StageVarKind::Arbitary);
+}
+
+std::string
+DeclResultIdMapper::getStageVarSemantic(const NamedDecl *decl) const {
+  for (auto *annotation : decl->getUnusualAnnotations()) {
+    if (auto *semantic = dyn_cast<hlsl::SemanticDecl>(annotation)) {
+      return semantic->SemanticName.upper();
+    }
+  }
+  return "";
+}
+
+} // end namespace spirv
+} // end namespace clang

+ 8 - 338
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -12,345 +12,15 @@
 #include "clang/AST/AST.h"
 #include "clang/AST/ASTConsumer.h"
 #include "clang/AST/ASTContext.h"
-#include "clang/AST/HlslTypes.h"
-#include "clang/AST/RecordLayout.h"
 #include "clang/Basic/Diagnostic.h"
-#include "clang/Basic/FileManager.h"
-#include "clang/Basic/SourceManager.h"
 #include "clang/Frontend/CompilerInstance.h"
+#include "clang/SPIRV/DeclResultIdMapper.h"
 #include "clang/SPIRV/ModuleBuilder.h"
-#include "llvm/Support/Path.h"
-#include "llvm/Support/raw_ostream.h"
+#include "clang/SPIRV/TypeTranslator.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace clang {
-namespace {
-
-/// The class responsible to translate Clang frontend types into SPIR-V type
-/// instructions.
-///
-/// SPIR-V type instructions generated during translation will be emitted to
-/// the SPIR-V module builder passed into the constructor.
-/// Warnings and errors during the translation will be reported to the
-/// DiagnosticEngine passed into the constructor.
-class TypeTranslator {
-public:
-  TypeTranslator(spirv::ModuleBuilder &builder, DiagnosticsEngine &diag)
-      : theBuilder(builder), diags(diag) {}
-
-  /// \brief Generates the corresponding SPIR-V type for the given Clang
-  /// frontend type and returns the type's <result-id>. On failure, reports
-  /// the error and returns 0.
-  ///
-  /// The translation is recursive; all the types that the target type depends
-  /// on will be generated.
-  uint32_t translateType(QualType type) {
-    const auto *typePtr = type.getTypePtr();
-
-    // Primitive types
-    if (const auto *builtinType = dyn_cast<BuiltinType>(typePtr)) {
-      switch (builtinType->getKind()) {
-      case BuiltinType::Void:
-        return theBuilder.getVoidType();
-      case BuiltinType::Float:
-        return theBuilder.getFloatType();
-      default:
-        emitError("Primitive type '%0' is not supported yet.")
-            << builtinType->getTypeClassName();
-        return 0;
-      }
-    }
-
-    // In AST, vector types are TypedefType of TemplateSpecializationType.
-    // We handle them via HLSL type inspection functions.
-    if (hlsl::IsHLSLVecType(type)) {
-      const auto elemType = hlsl::GetHLSLVecElementType(type);
-      const auto elemCount = hlsl::GetHLSLVecSize(type);
-      return theBuilder.getVecType(translateType(elemType), elemCount);
-    }
-
-    // Struct type
-    if (const auto *structType = dyn_cast<RecordType>(typePtr)) {
-      const auto *decl = structType->getDecl();
-
-      // Collect all fields' types.
-      std::vector<uint32_t> fieldTypes;
-      for (const auto *field : decl->fields()) {
-        fieldTypes.push_back(translateType(field->getType()));
-      }
-
-      return theBuilder.getStructType(fieldTypes);
-    }
-
-    emitError("Type '%0' is not supported yet.") << type->getTypeClassName();
-    return 0;
-  }
-
-private:
-  /// \brief Wrapper method to create an error message and report it
-  /// in the diagnostic engine associated with this consumer.
-  template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]) {
-    const auto diagId =
-        diags.getCustomDiagID(clang::DiagnosticsEngine::Error, message);
-    return diags.Report(diagId);
-  }
-
-private:
-  spirv::ModuleBuilder &theBuilder;
-  DiagnosticsEngine &diags;
-};
-
-/// \brief The class containing mappings from Clang frontend Decls to their
-/// corresponding SPIR-V <result-id>s.
-///
-/// All symbols defined in the AST should be "defined" or registered in this
-/// class and have their <result-id>s queried from this class. In the process
-/// of defining a Decl, the SPIR-V module builder passed into the constructor
-/// will be used to generate all SPIR-V instructions required.
-///
-/// This class acts as a middle layer to handle the mapping between HLSL
-/// semantics and Vulkan stage (builtin/input/output) variables. Such mapping
-/// is required because of the semantic differences between DirectX and
-/// Vulkan and the essence of HLSL as the front-end language for DirectX.
-/// A normal variable attached with some semantic will be translated into a
-/// single stage variables if it is of non-struct type. If it is of struct
-/// type, the fields with attached semantics will need to be translated into
-/// stage variables per Vulkan's requirements.
-///
-/// In the following class, we call a Decl as *remapped* when it is translated
-/// into a stage variable; otherwise, we call it as *normal*. Remapped decls
-/// include:
-/// * FunctionDecl if the return value is attached with a semantic
-/// * ParmVarDecl if the parameter is attached with a semantic
-/// * FieldDecl if the field is attached with a semantic.
-class DeclResultIdMapper {
-public:
-  DeclResultIdMapper(spv::ExecutionModel stage, spirv::ModuleBuilder &builder,
-                     DiagnosticsEngine &diag)
-      : shaderStage(stage), theBuilder(builder), typeTranslator(builder, diag) {
-  }
-
-  /// \brief Creates the stage variables by parsing the semantics attached to
-  /// the given function's return value.
-  void createStageVarFromFnReturn(FunctionDecl *funcDecl) {
-    // SemanticDecl for the return value is attached to the FunctionDecl.
-    createStageVariables(funcDecl, false);
-  }
-
-  /// \brief Creates the stage variables by parsing the semantics attached to
-  /// the given function's parameter.
-  void createStageVarFromFnParam(ParmVarDecl *paramDecl) {
-    // TODO: We cannot treat all parameters as stage inputs because of
-    // out/input modifiers.
-    createStageVariables(paramDecl, true);
-  }
-
-  /// \brief Registers a Decl's <result-id> without generating any SPIR-V
-  /// instruction.
-  void registerDeclResultId(const NamedDecl *symbol, uint32_t resultId) {
-    normalDecls[symbol] = resultId;
-  }
-
-  /// \brief Returns true if the given <result-id> is for a stage variable.
-  bool isStageVariable(uint32_t varId) const {
-    return stageVars.count(varId) != 0;
-  }
-
-  /// \brief Returns the <result-id> for the given Decl.
-  uint32_t getDeclResultId(const NamedDecl *decl) const {
-    if (const uint32_t id = getRemappedDeclResultId(decl))
-      return id;
-    if (const uint32_t id = getNormalDeclResultId(decl))
-      return id;
-
-    assert(false && "found unregistered Decl in DeclResultIdMapper");
-    return 0;
-  }
-
-  /// \brief Returns the <result-id> for the given remapped Decl. Returns zero
-  /// if it is not a registered remapped Decl.
-  uint32_t getRemappedDeclResultId(const NamedDecl *decl) const {
-    auto it = remappedDecls.find(decl);
-    if (it != remappedDecls.end())
-      return it->second;
-    return 0;
-  }
-
-  /// \brief Returns the <result-id> for the given normal Decl. Returns zero if
-  /// it is not a registered normal Decl.
-  uint32_t getNormalDeclResultId(const NamedDecl *decl) const {
-    auto it = normalDecls.find(decl);
-    if (it != normalDecls.end())
-      return it->second;
-    return 0;
-  }
-
-  /// \brief Returns all defined stage (builtin/input/ouput) variables in this
-  /// mapper.
-  std::vector<uint32_t> collectStageVariables() const {
-    std::vector<uint32_t> stageVars;
-
-    for (const auto &builtin : stageBuiltins) {
-      stageVars.push_back(builtin.first);
-    }
-    for (const auto &input : stageInputs) {
-      stageVars.push_back(input.first);
-    }
-    for (const auto &output : stageOutputs) {
-      stageVars.push_back(output.first);
-    }
-
-    return stageVars;
-  }
-
-  /// \brief Decorates all stage input and output variables with proper
-  /// location.
-  ///
-  /// This method will writes the location assignment into the module under
-  /// construction.
-  void finalizeStageIOLocations() {
-    uint32_t nextInputLocation = 0;
-    uint32_t nextOutputLocation = 0;
-
-    // TODO: sort the variables according to some criteria first, e.g.,
-    // alphabetical order of semantic names.
-    for (const auto &input : stageInputs) {
-      theBuilder.decorateLocation(input.first, nextInputLocation++);
-    }
-    for (const auto &output : stageOutputs) {
-      theBuilder.decorateLocation(output.first, nextOutputLocation++);
-    }
-  }
-
-private:
-  /// \brief Stage variable kind.
-  ///
-  /// Stage variables include builtin, input, and output variables.
-  /// They participate in interface matching in Vulkan pipelines.
-  enum class StageVarKind {
-    None,
-    Arbitary,
-    Position,
-    Color,
-    Target,
-    // TODO: other possible kinds
-  };
-
-  using StageVarIdSemanticPair = std::pair<uint32_t, std::string>;
-
-  /// Returns the type of the given decl. If the given decl is a FunctionDecl,
-  /// returns its result type.
-  QualType getFnParamOrRetType(const DeclaratorDecl *decl) const {
-    if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
-      return funcDecl->getReturnType();
-    }
-    return decl->getType();
-  }
-
-  /// Creates all the stage variables mapped from semantics on the given decl.
-  ///
-  /// Assumes the decl has semantic attached to itself or to its fields.
-  void createStageVariables(const DeclaratorDecl *decl, bool actAsInput) {
-    QualType type = getFnParamOrRetType(decl);
-
-    if (type->isVoidType()) {
-      // No stage variables will be created for void type.
-      return;
-    }
-
-    const std::string semantic = getStageVarSemantic(decl);
-    if (!semantic.empty()) {
-      // Found semantic attached directly to this Decl. This means we need to
-      // map this decl to a single stage variable.
-      const uint32_t typeId = typeTranslator.translateType(type);
-      const auto kind = getStageVarKind(semantic);
-
-      if (actAsInput) {
-        // Stage (builtin) input variable cases
-        const uint32_t varId =
-            theBuilder.addStageIOVariable(typeId, spv::StorageClass::Input);
-
-        stageInputs.push_back(std::make_pair(varId, semantic));
-        remappedDecls[decl] = varId;
-        stageVars.insert(varId);
-      } else {
-        // Handle output builtin variables first
-        if (shaderStage == spv::ExecutionModel::Vertex &&
-            kind == StageVarKind::Position) {
-          const uint32_t varId = theBuilder.addStageBuiltinVariable(
-              typeId, spv::BuiltIn::Position);
-
-          stageBuiltins.push_back(std::make_pair(varId, semantic));
-          remappedDecls[decl] = varId;
-          stageVars.insert(varId);
-        } else {
-          // The rest are normal stage output variables
-          const uint32_t varId =
-              theBuilder.addStageIOVariable(typeId, spv::StorageClass::Output);
-
-          stageOutputs.push_back(std::make_pair(varId, semantic));
-          remappedDecls[decl] = varId;
-          stageVars.insert(varId);
-        }
-      }
-    } else {
-      // If the decl itself doesn't have semantic, it should be a struct having
-      // all its fields with semantics.
-      assert(type->isStructureType() &&
-             "found non-struct decls without semantics");
-
-      const auto *structDecl = cast<RecordType>(type.getTypePtr())->getDecl();
-
-      // Recursively handle all the fields.
-      for (const auto *field : structDecl->fields()) {
-        createStageVariables(field, actAsInput);
-      }
-    }
-  }
-
-  /// \brief Returns the stage variable's kind for the given semantic.
-  StageVarKind getStageVarKind(llvm::StringRef semantic) const {
-    return llvm::StringSwitch<StageVarKind>(semantic)
-        .Case("", StageVarKind::None)
-        .StartsWith("COLOR", StageVarKind::Color)
-        .StartsWith("POSITION", StageVarKind::Position)
-        .StartsWith("SV_POSITION", StageVarKind::Position)
-        .StartsWith("SV_TARGET", StageVarKind::Target)
-        .Default(StageVarKind::Arbitary);
-  }
-
-  /// \brief Returns the stage variable's semantic for the given Decl.
-  std::string getStageVarSemantic(const NamedDecl *decl) const {
-    for (auto *annotation : decl->getUnusualAnnotations()) {
-      if (auto *semantic = dyn_cast<hlsl::SemanticDecl>(annotation)) {
-        return semantic->SemanticName.upper();
-      }
-    }
-    return "";
-  }
-
-private:
-  const spv::ExecutionModel shaderStage;
-  spirv::ModuleBuilder &theBuilder;
-  TypeTranslator typeTranslator;
-
-  /// Mapping of all remapped decls to their <result-id>s.
-  llvm::DenseMap<const NamedDecl *, uint32_t> remappedDecls;
-  /// Mapping of all normal decls to their <result-id>s.
-  llvm::DenseMap<const NamedDecl *, uint32_t> normalDecls;
-  /// <result-id>s of all defined stage variables.
-  ///
-  /// We need to keep a separate list here to avoid looping through the
-  /// remappedDecls to find whether an <result-id> is for a stage variable.
-  llvm::SmallSet<uint32_t, 16> stageVars;
-
-  /// Stage input/oupt/builtin variables and their kinds.
-  ///
-  /// We need to keep a separate list here in order to sort them at the end
-  /// of the module building.
-  llvm::SmallVector<StageVarIdSemanticPair, 8> stageInputs;
-  llvm::SmallVector<StageVarIdSemanticPair, 8> stageOutputs;
-  llvm::SmallVector<StageVarIdSemanticPair, 8> stageBuiltins;
-};
+namespace spirv {
 
 /// SPIR-V emitter class. It consumes the HLSL AST and emits SPIR-V words.
 ///
@@ -717,8 +387,8 @@ private:
   const llvm::StringRef entryFunctionName;
   const spv::ExecutionModel shaderStage;
 
-  spirv::SPIRVContext theContext;
-  spirv::ModuleBuilder theBuilder;
+  SPIRVContext theContext;
+  ModuleBuilder theBuilder;
   DeclResultIdMapper declIdMapper;
   TypeTranslator typeTranslator;
 
@@ -729,10 +399,10 @@ private:
   const FunctionDecl *curFunction;
 };
 
-} // namespace
+} // end namespace spirv
 
 std::unique_ptr<ASTConsumer>
 EmitSPIRVAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
-  return llvm::make_unique<SPIRVEmitter>(CI);
+  return llvm::make_unique<spirv::SPIRVEmitter>(CI);
 }
 } // end namespace clang

+ 96 - 103
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -24,6 +24,18 @@ ModuleBuilder::ModuleBuilder(SPIRVContext *C)
   });
 }
 
+std::vector<uint32_t> ModuleBuilder::takeModule() {
+  theModule.setBound(theContext.getNextId());
+
+  std::vector<uint32_t> binary;
+  auto ib = InstBuilder([&binary](std::vector<uint32_t> &&words) {
+    binary.insert(binary.end(), words.begin(), words.end());
+  });
+
+  theModule.take(&ib);
+  return std::move(binary);
+}
+
 uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
                                       std::string funcName) {
   if (theFunction) {
@@ -142,68 +154,72 @@ void ModuleBuilder::createReturnValue(uint32_t value) {
   insertPoint->appendInstruction(std::move(constructSite));
 }
 
-uint32_t
-ModuleBuilder::getConstantComposite(uint32_t typeId,
-                                    llvm::ArrayRef<uint32_t> constituents) {
-  const Constant *constant =
-      Constant::getComposite(theContext, typeId, constituents);
-  const uint32_t constId = theContext.getResultIdForConstant(constant);
-  theModule.addConstant(constant, constId);
-  return constId;
+void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
+                                     spv::ExecutionMode em,
+                                     const std::vector<uint32_t> &params) {
+  instBuilder.opExecutionMode(entryPointId, em);
+  for (const auto &param : params) {
+    instBuilder.literalInteger(param);
+  }
+  instBuilder.x();
+  theModule.addExecutionMode(std::move(constructSite));
 }
 
-uint32_t ModuleBuilder::getConstantFloat32(float value) {
-  const uint32_t floatTypeId = getFloatType();
-  const Constant *constant =
-      Constant::getFloat32(theContext, floatTypeId, value);
-  const uint32_t constId = theContext.getResultIdForConstant(constant);
-  theModule.addConstant(constant, constId);
-  return constId;
+uint32_t ModuleBuilder::addStageIOVariable(uint32_t type,
+                                           spv::StorageClass storageClass) {
+  const uint32_t pointerType = getPointerType(type, storageClass);
+  const uint32_t varId = theContext.takeNextId();
+  instBuilder.opVariable(pointerType, varId, storageClass, llvm::None).x();
+  theModule.addVariable(std::move(constructSite));
+  return varId;
 }
 
-uint32_t ModuleBuilder::getConstantInt32(int32_t value) {
-  const uint32_t intTypeId = getInt32Type();
-  const Constant *constant = Constant::getInt32(theContext, intTypeId, value);
-  const uint32_t constId = theContext.getResultIdForConstant(constant);
-  theModule.addConstant(constant, constId);
-  return constId;
-}
+uint32_t ModuleBuilder::addStageBuiltinVariable(uint32_t type,
+                                                spv::BuiltIn builtin) {
+  spv::StorageClass sc = spv::StorageClass::Input;
+  switch (builtin) {
+  case spv::BuiltIn::Position:
+  case spv::BuiltIn::PointSize:
+    // TODO: add the rest output builtins
+    sc = spv::StorageClass::Output;
+    break;
+  default:
+    break;
+  }
+  const uint32_t pointerType = getPointerType(type, sc);
+  const uint32_t varId = theContext.takeNextId();
+  instBuilder.opVariable(pointerType, varId, sc, llvm::None).x();
+  theModule.addVariable(std::move(constructSite));
 
-uint32_t ModuleBuilder::getConstantUint32(uint32_t value) {
-  const uint32_t uintTypeId = getUint32Type();
-  const Constant *constant = Constant::getUint32(theContext, uintTypeId, value);
-  const uint32_t constId = theContext.getResultIdForConstant(constant);
-  theModule.addConstant(constant, constId);
-  return constId;
-}
+  // Decorate with the specified Builtin
+  const Decoration *d = Decoration::getBuiltIn(theContext, builtin);
+  theModule.addDecoration(*d, varId);
 
-uint32_t ModuleBuilder::getVoidType() {
-  const Type *type = Type::getVoid(theContext);
-  const uint32_t typeId = theContext.getResultIdForType(type);
-  theModule.addType(type, typeId);
-  return typeId;
+  return varId;
 }
 
-uint32_t ModuleBuilder::getUint32Type() {
-  const Type *type = Type::getUint32(theContext);
-  const uint32_t typeId = theContext.getResultIdForType(type);
-  theModule.addType(type, typeId);
-  return typeId;
+void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
+  const Decoration *d =
+      Decoration::getLocation(theContext, location, llvm::None);
+  theModule.addDecoration(*d, targetId);
 }
 
-uint32_t ModuleBuilder::getInt32Type() {
-  const Type *type = Type::getInt32(theContext);
-  const uint32_t typeId = theContext.getResultIdForType(type);
-  theModule.addType(type, typeId);
-  return typeId;
+#define IMPL_GET_PRIMITIVE_TYPE(ty)                                            \
+  \
+uint32_t ModuleBuilder::get##ty##Type() {                                      \
+    const Type *type = Type::get##ty(theContext);                              \
+    const uint32_t typeId = theContext.getResultIdForType(type);               \
+    theModule.addType(type, typeId);                                           \
+    return typeId;                                                             \
+  \
 }
 
-uint32_t ModuleBuilder::getFloatType() {
-  const Type *type = Type::getFloat32(theContext);
-  const uint32_t typeId = theContext.getResultIdForType(type);
-  theModule.addType(type, typeId);
-  return typeId;
-}
+IMPL_GET_PRIMITIVE_TYPE(Void)
+IMPL_GET_PRIMITIVE_TYPE(Int32)
+IMPL_GET_PRIMITIVE_TYPE(Uint32)
+IMPL_GET_PRIMITIVE_TYPE(Float32)
+
+#undef IMPL_GET_PRIMITIVE_TYPE
 
 uint32_t ModuleBuilder::getVecType(uint32_t elemType, uint32_t elemCount) {
   const Type *type = nullptr;
@@ -229,6 +245,14 @@ uint32_t ModuleBuilder::getVecType(uint32_t elemType, uint32_t elemCount) {
   return typeId;
 }
 
+uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
+                                       spv::StorageClass storageClass) {
+  const Type *type = Type::getPointer(theContext, storageClass, pointeeType);
+  const uint32_t typeId = theContext.getResultIdForType(type);
+  theModule.addType(type, typeId);
+  return typeId;
+}
+
 uint32_t ModuleBuilder::getStructType(llvm::ArrayRef<uint32_t> fieldTypes) {
   const Type *type = Type::getStruct(theContext, fieldTypes);
   const uint32_t typeId = theContext.getResultIdForType(type);
@@ -245,63 +269,32 @@ ModuleBuilder::getFunctionType(uint32_t returnType,
   return typeId;
 }
 
-uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
-                                       spv::StorageClass storageClass) {
-  const Type *type = Type::getPointer(theContext, storageClass, pointeeType);
-  const uint32_t typeId = theContext.getResultIdForType(type);
-  theModule.addType(type, typeId);
-  return typeId;
+#define IMPL_GET_PRIMITIVE_VALUE(builderTy, cppTy)                             \
+  \
+uint32_t ModuleBuilder::getConstant##builderTy(cppTy value) {                  \
+    const uint32_t typeId = get##builderTy##Type();                            \
+    const Constant *constant =                                                 \
+        Constant::get##builderTy(theContext, typeId, value);                   \
+    const uint32_t constId = theContext.getResultIdForConstant(constant);      \
+    theModule.addConstant(constant, constId);                                  \
+    return constId;                                                            \
+  \
 }
 
-uint32_t ModuleBuilder::addStageIOVariable(uint32_t type,
-                                           spv::StorageClass storageClass) {
-  const uint32_t pointerType = getPointerType(type, storageClass);
-  const uint32_t varId = theContext.takeNextId();
-  instBuilder.opVariable(pointerType, varId, storageClass, llvm::None).x();
-  theModule.addVariable(std::move(constructSite));
-  return varId;
-}
+IMPL_GET_PRIMITIVE_VALUE(Int32, int32_t)
+IMPL_GET_PRIMITIVE_VALUE(Uint32, uint32_t)
+IMPL_GET_PRIMITIVE_VALUE(Float32, float)
 
-uint32_t ModuleBuilder::addStageBuiltinVariable(uint32_t type,
-                                                spv::BuiltIn builtin) {
-  spv::StorageClass sc = spv::StorageClass::Input;
-  switch (builtin) {
-  case spv::BuiltIn::Position:
-  case spv::BuiltIn::PointSize:
-    // TODO: add the rest output builtins
-    sc = spv::StorageClass::Output;
-    break;
-  default:
-    break;
-  }
-  const uint32_t pointerType = getPointerType(type, sc);
-  const uint32_t varId = theContext.takeNextId();
-  instBuilder.opVariable(pointerType, varId, sc, llvm::None).x();
-  theModule.addVariable(std::move(constructSite));
-
-  // Decorate with the specified Builtin
-  const Decoration *d = Decoration::getBuiltIn(theContext, builtin);
-  theModule.addDecoration(*d, varId);
-
-  return varId;
-}
-
-void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
-  const Decoration *d =
-      Decoration::getLocation(theContext, location, llvm::None);
-  theModule.addDecoration(*d, targetId);
-}
+#undef IMPL_GET_PRIMITIVE_VALUE
 
-std::vector<uint32_t> ModuleBuilder::takeModule() {
-  theModule.setBound(theContext.getNextId());
-
-  std::vector<uint32_t> binary;
-  auto ib = InstBuilder([&binary](std::vector<uint32_t> &&words) {
-    binary.insert(binary.end(), words.begin(), words.end());
-  });
-
-  theModule.take(&ib);
-  return std::move(binary);
+uint32_t
+ModuleBuilder::getConstantComposite(uint32_t typeId,
+                                    llvm::ArrayRef<uint32_t> constituents) {
+  const Constant *constant =
+      Constant::getComposite(theContext, typeId, constituents);
+  const uint32_t constId = theContext.getResultIdForConstant(constant);
+  theModule.addConstant(constant, constId);
+  return constId;
 }
 
 } // end namespace spirv

+ 8 - 8
tools/clang/lib/SPIRV/String.cpp

@@ -15,9 +15,9 @@ namespace spirv {
 namespace string {
 
 /// \brief Reinterprets a given string as sequence of words.
-std::vector<uint32_t> encodeSPIRVString(std::string s) {
+std::vector<uint32_t> encodeSPIRVString(llvm::StringRef strChars) {
   // Initialize all words to 0.
-  size_t numChars = s.size();
+  size_t numChars = strChars.size();
   std::vector<uint32_t> result(numChars / 4 + 1, 0);
 
   // From the SPIR-V spec, literal string is
@@ -32,19 +32,19 @@ std::vector<uint32_t> encodeSPIRVString(std::string s) {
   //
   // So the following works on little endian machines.
   char *strDest = reinterpret_cast<char *>(result.data());
-  strncpy(strDest, s.c_str(), numChars);
+  strncpy(strDest, strChars.data(), numChars);
   return result;
 }
 
 /// \brief Reinterprets the given vector of 32-bit words as a string.
 /// Expectes that the words represent a NULL-terminated string.
 /// Assumes Little Endian architecture.
-std::string decodeSPIRVString(const std::vector<uint32_t> &vec) {
-  std::string result;
-  if (!vec.empty()) {
-    result = std::string(reinterpret_cast<const char *>(vec.data()));
+std::string decodeSPIRVString(llvm::ArrayRef<uint32_t> strWords) {
+  if (!strWords.empty()) {
+    return reinterpret_cast<const char *>(strWords.data());
   }
-  return result;
+
+  return "";
 }
 
 } // end namespace string

+ 59 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -0,0 +1,59 @@
+//===--- TypeTranslator.cpp - TypeTranslator implementation ------*- C++ -*-==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/SPIRV/TypeTranslator.h"
+#include "clang/AST/HlslTypes.h"
+
+namespace clang {
+namespace spirv {
+
+uint32_t TypeTranslator::translateType(QualType type) {
+  const auto *typePtr = type.getTypePtr();
+
+  // Primitive types
+  if (const auto *builtinType = dyn_cast<BuiltinType>(typePtr)) {
+    switch (builtinType->getKind()) {
+    case BuiltinType::Void:
+      return theBuilder.getVoidType();
+    case BuiltinType::Float:
+      return theBuilder.getFloat32Type();
+    default:
+      emitError("Primitive type '%0' is not supported yet.")
+          << builtinType->getTypeClassName();
+      return 0;
+    }
+  }
+
+  // In AST, vector types are TypedefType of TemplateSpecializationType.
+  // We handle them via HLSL type inspection functions.
+  if (hlsl::IsHLSLVecType(type)) {
+    const auto elemType = hlsl::GetHLSLVecElementType(type);
+    const auto elemCount = hlsl::GetHLSLVecSize(type);
+    return theBuilder.getVecType(translateType(elemType), elemCount);
+  }
+
+  // Struct type
+  if (const auto *structType = dyn_cast<RecordType>(typePtr)) {
+    const auto *decl = structType->getDecl();
+
+    // Collect all fields' types.
+    std::vector<uint32_t> fieldTypes;
+    for (const auto *field : decl->fields()) {
+      fieldTypes.push_back(translateType(field->getType()));
+    }
+
+    return theBuilder.getStructType(fieldTypes);
+  }
+
+  emitError("Type '%0' is not supported yet.") << type->getTypeClassName();
+  return 0;
+}
+
+} // end namespace spirv
+} // end namespace clang

+ 5 - 4
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -9,10 +9,10 @@
 
 #include "WholeFileCheck.h"
 
+namespace {
+using clang::spirv::WholeFileTest;
+
 TEST_F(WholeFileTest, EmptyVoidMain) {
-  // Ideally all generated SPIR-V must be valid, but this currently fails with
-  // this error message: "No OpEntryPoint instruction was found...".
-  // TODO: change this test such that it does run validation.
   runWholeFileTest("empty-void-main.hlsl2spv",
                    /*generateHeader*/ true,
                    /*runValidation*/ true);
@@ -33,5 +33,6 @@ TEST_F(WholeFileTest, PassThruVertexShader) {
 TEST_F(WholeFileTest, ConstantPixelShader) {
   runWholeFileTest("constant-ps.hlsl2spv",
                    /*generateHeader*/ true,
-                   /*runValidation*/ false);
+                   /*runValidation*/ true);
+}
 }

+ 14 - 3
tools/clang/unittests/SPIRV/WholeFileCheck.cpp

@@ -12,6 +12,14 @@
 #include "WholeFileCheck.h"
 #include "gtest/gtest.h"
 
+namespace clang {
+namespace spirv {
+
+namespace {
+const char hlslStartLabel[] = "// Run:";
+const char spirvStartLabel[] = "// CHECK-WHOLE-SPIR-V:";
+}
+
 WholeFileTest::WholeFileTest() : spirvTools(SPV_ENV_UNIVERSAL_1_0) {
   spirvTools.SetMessageConsumer(
       [](spv_message_level_t, const char *, const spv_position_t &,
@@ -185,8 +193,7 @@ void WholeFileTest::convertIDxcBlobToUint32(const CComPtr<IDxcBlob> &blob) {
   memcpy(generatedBinary.data(), binaryStr.data(), binaryStr.size());
 }
 
-std::string
-WholeFileTest::getAbsPathOfInputDataFile(const std::string &filename) {
+std::string WholeFileTest::getAbsPathOfInputDataFile(llvm::StringRef filename) {
   std::string path = clang::spirv::testOptions::inputDataDir;
 
 #ifdef _WIN32
@@ -203,7 +210,8 @@ WholeFileTest::getAbsPathOfInputDataFile(const std::string &filename) {
   return path;
 }
 
-void WholeFileTest::runWholeFileTest(std::string filename, bool generateHeader,
+void WholeFileTest::runWholeFileTest(llvm::StringRef filename,
+                                     bool generateHeader,
                                      bool runSpirvValidation) {
   inputFilePath = getAbsPathOfInputDataFile(filename);
 
@@ -224,3 +232,6 @@ void WholeFileTest::runWholeFileTest(std::string filename, bool generateHeader,
     EXPECT_TRUE(validateSpirvBinary());
   }
 }
+
+} // end namespace spirv
+} // end namespace clang

+ 13 - 6
tools/clang/unittests/SPIRV/WholeFileCheck.h

@@ -7,6 +7,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#ifndef LLVM_CLANG_UNITTESTS_SPIRV_WHOLEFILECHECK_H
+#define LLVM_CLANG_UNITTESTS_SPIRV_WHOLEFILECHECK_H
+
 #include <algorithm>
 #include <fstream>
 
@@ -14,14 +17,13 @@
 #include "dxc/Support/WinIncludes.h"
 #include "dxc/Support/dxcapi.use.h"
 #include "spirv-tools/libspirv.hpp"
+#include "llvm/ADT/StringRef.h"
 #include "gtest/gtest.h"
 
 #include "SpirvTestOptions.h"
 
-namespace {
-const char hlslStartLabel[] = "// Run:";
-const char spirvStartLabel[] = "// CHECK-WHOLE-SPIR-V:";
-}
+namespace clang {
+namespace spirv {
 
 /// \brief The purpose of the this test class is to take in an input file with
 /// the following format:
@@ -55,7 +57,7 @@ public:
   /// It is also important that all generated SPIR-V code is valid. Users of
   /// WholeFileTest may choose not to run the SPIR-V Validator (for cases where
   /// a certain feature has not been added to the Validator yet).
-  void runWholeFileTest(std::string path, bool generateHeader = false,
+  void runWholeFileTest(llvm::StringRef path, bool generateHeader = false,
                         bool runSpirvValidation = true);
 
 private:
@@ -86,7 +88,7 @@ private:
   void convertIDxcBlobToUint32(const CComPtr<IDxcBlob> &blob);
 
   /// \brief Returns the absolute path to the input file of the test.
-  std::string getAbsPathOfInputDataFile(const std::string &filename);
+  std::string getAbsPathOfInputDataFile(llvm::StringRef filename);
 
   std::string targetProfile;             ///< Target profile (argument of -T)
   std::string entryPoint;                ///< Entry point name (argument of -E)
@@ -96,3 +98,8 @@ private:
   std::string generatedSpirvAsm;         ///< Disassembled binary (SPIR-V code)
   spvtools::SpirvTools spirvTools;       ///< SPIR-V Tools used by the test
 };
+
+} // end namespace spirv
+} // end namespace clang
+
+#endif