瀏覽代碼

[spirv] Support texture and sampler types (#572)

This commit add support for the following types:
* Texture1D, Texture2D, Texture3D, TextureCube
* Texture1DArray, Texture2DArray, TextureCubeArray
* SamplerState, SamplerComparisonState, sampler

Buffer, Texture2DMS, and Texture2DMSArray is not supported yet,
also setting sampler states.

Texture types will be translated OpTypeImage, and sampler types
will be translated into OpTypeSampler.
Lei Zhang 8 年之前
父節點
當前提交
f284c982c9

+ 29 - 4
docs/SPIR-V.rst

@@ -189,10 +189,35 @@ User-defined types
 are type aliases introduced by typedef. No new types are introduced and we can
 are type aliases introduced by typedef. No new types are introduced and we can
 rely on Clang to resolve to the original types.
 rely on Clang to resolve to the original types.
 
 
-Samplers and textures
-+++++++++++++++++++++
-
-[TODO]
+Samplers
+++++++++
+
+All `sampler types <https://msdn.microsoft.com/en-us/library/windows/desktop/bb509644(v=vs.85).aspx>`_
+will be translated into SPIR-V ``OpTypeSampler``.
+
+SPIR-V ``OpTypeSampler`` is an opaque type that cannot be parameterized;
+therefore state assignments on sampler types is not supported (yet).
+
+Textures
+++++++++
+
+`Texture types <https://msdn.microsoft.com/en-us/library/windows/desktop/bb509700(v=vs.85).aspx>`_
+are translated into SPIR-V ``OpTypeImage``, with parameters:
+
+====================   ==== ===== ======= == ======= ============
+HLSL Texture Type      Dim  Depth Arrayed MS Sampled Image Format
+====================   ==== ===== ======= == ======= ============
+``Texture1D``          1D    0       0    0    1       Unknown
+``Texture2D``          2D    0       0    0    1       Unknown
+``Texture3D``          3D    0       0    0    1       Unknown
+``TextureCube``        Cube  0       0    0    1       Unknown
+``Texture1DArray``     1D    0       1    0    1       Unknown
+``Texture2DArray``     2D    0       1    0    1       Unknown
+``TextureCubeArray``   3D    0       1    0    1       Unknown
+====================   ==== ===== ======= == ======= ============
+
+The meanings of the headers in the above table is explained in ``OpTypeImage``
+of the SPIR-V spec.
 
 
 Buffers
 Buffers
 +++++++
 +++++++

+ 7 - 4
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -226,13 +226,14 @@ public:
   uint32_t addStageBuiltinVar(uint32_t type, spv::StorageClass storageClass,
   uint32_t addStageBuiltinVar(uint32_t type, spv::StorageClass storageClass,
                               spv::BuiltIn);
                               spv::BuiltIn);
 
 
-  /// \brief Adds a file/module visible variable. This variable will have
-  /// Private storage class.
+  /// \brief Adds a module variable. This variable should not have the Function
+  /// storage class.
   ///
   ///
   /// The corresponding pointer type of the given type will be constructed in
   /// The corresponding pointer type of the given type will be constructed in
   /// this method for the variable itself.
   /// this method for the variable itself.
-  uint32_t addFileVar(uint32_t valueType, llvm::StringRef name = "",
-                      llvm::Optional<uint32_t> init = llvm::None);
+  uint32_t addModuleVar(uint32_t valueType, spv::StorageClass storageClass,
+                        llvm::StringRef name = "",
+                        llvm::Optional<uint32_t> init = llvm::None);
 
 
   /// \brief Decorates the given target <result-id> with the given location.
   /// \brief Decorates the given target <result-id> with the given location.
   void decorateLocation(uint32_t targetId, uint32_t location);
   void decorateLocation(uint32_t targetId, uint32_t location);
@@ -257,6 +258,8 @@ public:
   uint32_t getArrayType(uint32_t elemType, uint32_t count);
   uint32_t getArrayType(uint32_t elemType, uint32_t count);
   uint32_t getFunctionType(uint32_t returnType,
   uint32_t getFunctionType(uint32_t returnType,
                            llvm::ArrayRef<uint32_t> paramTypes);
                            llvm::ArrayRef<uint32_t> paramTypes);
+  uint32_t getImageType(uint32_t sampledType, spv::Dim, bool isArray);
+  uint32_t getSamplerType();
 
 
   // === Constant ===
   // === Constant ===
   uint32_t getConstantBool(bool value);
   uint32_t getConstantBool(bool value);

+ 38 - 20
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -33,23 +33,6 @@ bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl,
   return createStageVars(paramDecl, loadedValue, true, "in.var");
   return createStageVars(paramDecl, loadedValue, true, "in.var");
 }
 }
 
 
-void DeclResultIdMapper::registerDeclResultId(const NamedDecl *symbol,
-                                              uint32_t resultId) {
-  auto sc = spv::StorageClass::Function;
-  if (const auto *varDecl = dyn_cast<VarDecl>(symbol)) {
-    if (varDecl->isExternallyVisible()) {
-      // TODO: Global variables are by default constant. But the default
-      // behavior can be changed via command line option. So Uniform may
-      // not be the correct storage class.
-      sc = spv::StorageClass::Uniform;
-    } else if (!varDecl->hasLocalStorage()) {
-      // File scope variables
-      sc = spv::StorageClass::Private;
-    }
-  }
-  astDecls[symbol] = {resultId, sc};
-}
-
 const DeclResultIdMapper::DeclSpirvInfo *
 const DeclResultIdMapper::DeclSpirvInfo *
 DeclResultIdMapper::getDeclSpirvInfo(const NamedDecl *decl) const {
 DeclResultIdMapper::getDeclSpirvInfo(const NamedDecl *decl) const {
   auto it = astDecls.find(decl);
   auto it = astDecls.find(decl);
@@ -67,12 +50,47 @@ uint32_t DeclResultIdMapper::getDeclResultId(const NamedDecl *decl) const {
   return 0;
   return 0;
 }
 }
 
 
-uint32_t DeclResultIdMapper::getOrRegisterDeclResultId(const NamedDecl *decl) {
-  if (const auto *info = getDeclSpirvInfo(decl))
+uint32_t DeclResultIdMapper::createFnParam(uint32_t paramType,
+                                           const ParmVarDecl *param) {
+  const uint32_t id = theBuilder.addFnParam(paramType, param->getName());
+  astDecls[param] = {id, spv::StorageClass::Function};
+
+  return id;
+}
+
+uint32_t DeclResultIdMapper::createFnVar(uint32_t varType, const VarDecl *var,
+                                         llvm::Optional<uint32_t> init) {
+  const uint32_t id = theBuilder.addFnVar(varType, var->getName(), init);
+  astDecls[var] = {id, spv::StorageClass::Function};
+
+  return id;
+}
+
+uint32_t DeclResultIdMapper::createFileVar(uint32_t varType, const VarDecl *var,
+                                           llvm::Optional<uint32_t> init) {
+  const uint32_t id = theBuilder.addModuleVar(
+      varType, spv::StorageClass::Private, var->getName(), init);
+  astDecls[var] = {id, spv::StorageClass::Private};
+
+  return id;
+}
+
+uint32_t DeclResultIdMapper::createExternVar(uint32_t varType,
+                                             const VarDecl *var) {
+  // TODO: storage class can also be Uniform
+  const uint32_t id = theBuilder.addModuleVar(
+      varType, spv::StorageClass::UniformConstant, var->getName(), llvm::None);
+  astDecls[var] = {id, spv::StorageClass::UniformConstant};
+
+  return id;
+}
+
+uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
+  if (const auto *info = getDeclSpirvInfo(fn))
     return info->resultId;
     return info->resultId;
 
 
   const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
   const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
-  registerDeclResultId(decl, id);
+  astDecls[fn] = {id, spv::StorageClass::Function};
 
 
   return id;
   return id;
 }
 }

+ 24 - 10
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -103,7 +103,7 @@ StageVar::StageVar(const hlsl::SigPoint *sig, llvm::StringRef semaStr,
 /// stage variables per Vulkan's requirements.
 /// stage variables per Vulkan's requirements.
 class DeclResultIdMapper {
 class DeclResultIdMapper {
 public:
 public:
-  inline DeclResultIdMapper(const hlsl::ShaderModel &stage,
+  inline DeclResultIdMapper(const hlsl::ShaderModel &stage, ASTContext &context,
                             ModuleBuilder &builder, DiagnosticsEngine &diag);
                             ModuleBuilder &builder, DiagnosticsEngine &diag);
 
 
   /// \brief Creates the stage output variables by parsing the semantics
   /// \brief Creates the stage output variables by parsing the semantics
@@ -119,10 +119,23 @@ public:
   /// variables and composite them into one and write to *loadedValue.
   /// variables and composite them into one and write to *loadedValue.
   bool createStageInputVar(const ParmVarDecl *paramDecl, uint32_t *loadedValue);
   bool createStageInputVar(const ParmVarDecl *paramDecl, uint32_t *loadedValue);
 
 
-  /// \brief Registers a decl's <result-id> without generating any SPIR-V
-  /// instruction. The given decl will be treated as normal decl.
-  void registerDeclResultId(const NamedDecl *symbol, uint32_t resultId);
+  /// \brief Creates a function-scope paramter in the current function and
+  /// returns its <result-id>.
+  uint32_t createFnParam(uint32_t paramType, const ParmVarDecl *param);
 
 
+  /// \brief Creates a function-scope variable in the current function and
+  /// returns its <result-id>.
+  uint32_t createFnVar(uint32_t varType, const VarDecl *variable,
+                       llvm::Optional<uint32_t> init);
+
+  /// \brief Creates a file-scope variable and returns its <result-id>.
+  uint32_t createFileVar(uint32_t varType, const VarDecl *variable,
+                         llvm::Optional<uint32_t> init);
+
+  /// \brief Creates an external-visible variable and returns its <result-id>.
+  uint32_t createExternVar(uint32_t varType, const VarDecl *var);
+
+  /// \brief Sets the <result-id> of the entry function.
   void setEntryFunctionId(uint32_t id) { entryFunctionId = id; }
   void setEntryFunctionId(uint32_t id) { entryFunctionId = id; }
 
 
 public:
 public:
@@ -141,10 +154,10 @@ public:
   /// This method will panic if the given decl is not registered.
   /// This method will panic if the given decl is not registered.
   uint32_t getDeclResultId(const NamedDecl *decl) const;
   uint32_t getDeclResultId(const NamedDecl *decl) const;
 
 
-  /// \brief Returns the <result-id> for the given decl if already registered;
-  /// otherwise, treats the given decl as a normal decl and returns a newly
-  /// assigned <result-id> for it.
-  uint32_t getOrRegisterDeclResultId(const NamedDecl *decl);
+  /// \brief Returns the <result-id> for the given function if already
+  /// registered; otherwise, treats the given function as a normal decl and
+  /// returns a newly assigned <result-id> for it.
+  uint32_t getOrRegisterFnResultId(const FunctionDecl *fn);
 
 
   /// Returns the storage class for the given expression. The expression is
   /// Returns the storage class for the given expression. The expression is
   /// expected to be an lvalue. Otherwise this method may panic.
   /// expected to be an lvalue. Otherwise this method may panic.
@@ -220,10 +233,11 @@ private:
 };
 };
 
 
 DeclResultIdMapper::DeclResultIdMapper(const hlsl::ShaderModel &model,
 DeclResultIdMapper::DeclResultIdMapper(const hlsl::ShaderModel &model,
+                                       ASTContext &context,
                                        ModuleBuilder &builder,
                                        ModuleBuilder &builder,
                                        DiagnosticsEngine &diag)
                                        DiagnosticsEngine &diag)
-    : shaderModel(model), theBuilder(builder), typeTranslator(builder, diag),
-      diags(diag), entryFunctionId(0) {}
+    : shaderModel(model), theBuilder(builder),
+      typeTranslator(context, builder, diag), diags(diag), entryFunctionId(0) {}
 
 
 bool DeclResultIdMapper::decorateStageIOLocations() {
 bool DeclResultIdMapper::decorateStageIOLocations() {
   // Try both input and output even if input location assignment failed
   // Try both input and output even if input location assignment failed

+ 57 - 5
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -341,12 +341,15 @@ uint32_t ModuleBuilder::addStageBuiltinVar(uint32_t type, spv::StorageClass sc,
   return varId;
   return varId;
 }
 }
 
 
-uint32_t ModuleBuilder::addFileVar(uint32_t type, llvm::StringRef name,
-                                   llvm::Optional<uint32_t> init) {
-  const uint32_t pointerType = getPointerType(type, spv::StorageClass::Private);
+uint32_t ModuleBuilder::addModuleVar(uint32_t type, spv::StorageClass sc,
+                                     llvm::StringRef name,
+                                     llvm::Optional<uint32_t> init) {
+  assert(sc != spv::StorageClass::Function);
+
+  // TODO: basically duplicated code of addFileVar()
+  const uint32_t pointerType = getPointerType(type, sc);
   const uint32_t varId = theContext.takeNextId();
   const uint32_t varId = theContext.takeNextId();
-  instBuilder.opVariable(pointerType, varId, spv::StorageClass::Private, init)
-      .x();
+  instBuilder.opVariable(pointerType, varId, sc, init).x();
   theModule.addVariable(std::move(constructSite));
   theModule.addVariable(std::move(constructSite));
   theModule.addDebugName(varId, name);
   theModule.addDebugName(varId, name);
   return varId;
   return varId;
@@ -474,6 +477,55 @@ uint32_t ModuleBuilder::getFunctionType(uint32_t returnType,
   return typeId;
   return typeId;
 }
 }
 
 
+uint32_t ModuleBuilder::getImageType(uint32_t sampledType, spv::Dim dim,
+                                     bool isArray) {
+  const Type *type = Type::getImage(theContext, sampledType, dim,
+                                    /*depth*/ 0, isArray, /*ms*/ 0,
+                                    /*sampled*/ 1, spv::ImageFormat::Unknown);
+  const uint32_t typeId = theContext.getResultIdForType(type);
+  theModule.addType(type, typeId);
+
+  const char *dimStr = "";
+  switch (dim) {
+  case spv::Dim::Dim1D:
+    dimStr = "1d.";
+    break;
+  case spv::Dim::Dim2D:
+    dimStr = "2d.";
+    break;
+  case spv::Dim::Dim3D:
+    dimStr = "3d.";
+    break;
+  case spv::Dim::Cube:
+    dimStr = "cube.";
+    break;
+  case spv::Dim::Rect:
+    dimStr = "rect.";
+    break;
+  case spv::Dim::Buffer:
+    dimStr = "buffer.";
+    break;
+  case spv::Dim::SubpassData:
+    dimStr = "subpass.";
+    break;
+  default:
+    break;
+  }
+  std::string name =
+      std::string("type.") + dimStr + "image" + (isArray ? ".array" : "");
+  theModule.addDebugName(typeId, name);
+
+  return typeId;
+}
+
+uint32_t ModuleBuilder::getSamplerType() {
+  const Type *type = Type::getSampler(theContext);
+  const uint32_t typeId = theContext.getResultIdForType(type);
+  theModule.addType(type, typeId);
+  theModule.addDebugName(typeId, "type.sampler");
+  return typeId;
+}
+
 uint32_t ModuleBuilder::getConstantBool(bool value) {
 uint32_t ModuleBuilder::getConstantBool(bool value) {
   const uint32_t typeId = getBoolType();
   const uint32_t typeId = getBoolType();
   const Constant *constant = value ? Constant::getTrue(theContext, typeId)
   const Constant *constant = value ? Constant::getTrue(theContext, typeId)

+ 16 - 18
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -142,8 +142,8 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci)
       shaderModel(*hlsl::ShaderModel::GetByName(
       shaderModel(*hlsl::ShaderModel::GetByName(
           ci.getCodeGenOpts().HLSLProfile.c_str())),
           ci.getCodeGenOpts().HLSLProfile.c_str())),
       theContext(), theBuilder(&theContext),
       theContext(), theBuilder(&theContext),
-      declIdMapper(shaderModel, theBuilder, diags),
-      typeTranslator(theBuilder, diags), entryFunctionId(0),
+      declIdMapper(shaderModel, astContext, theBuilder, diags),
+      typeTranslator(astContext, theBuilder, diags), entryFunctionId(0),
       curFunction(nullptr) {
       curFunction(nullptr) {
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
     emitError("unknown shader module: %0") << shaderModel.GetName();
     emitError("unknown shader module: %0") << shaderModel.GetName();
@@ -247,9 +247,7 @@ void SPIRVEmitter::doStmt(const Stmt *stmt,
 uint32_t SPIRVEmitter::doExpr(const Expr *expr) {
 uint32_t SPIRVEmitter::doExpr(const Expr *expr) {
   if (const auto *delRefExpr = dyn_cast<DeclRefExpr>(expr)) {
   if (const auto *delRefExpr = dyn_cast<DeclRefExpr>(expr)) {
     // Returns the <result-id> of the referenced Decl.
     // Returns the <result-id> of the referenced Decl.
-    const NamedDecl *referredDecl = delRefExpr->getFoundDecl();
-    assert(referredDecl && "found non-NamedDecl referenced");
-    return declIdMapper.getDeclResultId(referredDecl);
+    return declIdMapper.getDeclResultId(delRefExpr->getFoundDecl());
   }
   }
 
 
   if (const auto *parenExpr = dyn_cast<ParenExpr>(expr)) {
   if (const auto *parenExpr = dyn_cast<ParenExpr>(expr)) {
@@ -396,9 +394,7 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
   // Create all parameters.
   // Create all parameters.
   for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
   for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
     const ParmVarDecl *paramDecl = decl->getParamDecl(i);
     const ParmVarDecl *paramDecl = decl->getParamDecl(i);
-    const uint32_t paramId =
-        theBuilder.addFnParam(paramTypes[i], paramDecl->getName());
-    declIdMapper.registerDeclResultId(paramDecl, paramId);
+    (void)declIdMapper.createFnParam(paramTypes[i], paramDecl);
   }
   }
 
 
   if (decl->hasBody()) {
   if (decl->hasBody()) {
@@ -422,6 +418,8 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
 }
 }
 
 
 void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
 void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
+  const uint32_t varType = typeTranslator.translateType(decl->getType());
+
   // The contents in externally visible variables can be updated via the
   // The contents in externally visible variables can be updated via the
   // pipeline. They should be handled differently from file and function scope
   // pipeline. They should be handled differently from file and function scope
   // variables.
   // variables.
@@ -432,7 +430,6 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
     // We already know the variable is not externally visible here. If it does
     // We already know the variable is not externally visible here. If it does
     // not have local storage, it should be file scope variable.
     // not have local storage, it should be file scope variable.
     const bool isFileScopeVar = !decl->hasLocalStorage();
     const bool isFileScopeVar = !decl->hasLocalStorage();
-    const uint32_t varType = typeTranslator.translateType(decl->getType());
 
 
     // Handle initializer. SPIR-V requires that "initializer must be an <id>
     // Handle initializer. SPIR-V requires that "initializer must be an <id>
     // from a constant instruction or a global (module scope) OpVariable
     // from a constant instruction or a global (module scope) OpVariable
@@ -452,11 +449,11 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
       constInit = llvm::Optional<uint32_t>(theBuilder.getConstantNull(varType));
       constInit = llvm::Optional<uint32_t>(theBuilder.getConstantNull(varType));
     }
     }
 
 
-    const uint32_t varId =
-        isFileScopeVar
-            ? theBuilder.addFileVar(varType, decl->getName(), constInit)
-            : theBuilder.addFnVar(varType, decl->getName(), constInit);
-    declIdMapper.registerDeclResultId(decl, varId);
+    uint32_t varId;
+    if (isFileScopeVar)
+      varId = declIdMapper.createFileVar(varType, decl, constInit);
+    else
+      varId = declIdMapper.createFnVar(varType, decl, constInit);
 
 
     // If we cannot evaluate the initializer as a constant expression, we'll
     // If we cannot evaluate the initializer as a constant expression, we'll
     // need to use OpStore to write the initializer to the variable.
     // need to use OpStore to write the initializer to the variable.
@@ -475,7 +472,7 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
       }
       }
     }
     }
   } else {
   } else {
-    emitError("Global variables are not supported yet.");
+    (void)declIdMapper.createExternVar(varType, decl);
   }
   }
 }
 }
 
 
@@ -1078,7 +1075,7 @@ uint32_t SPIRVEmitter::doCallExpr(const CallExpr *callExpr) {
 
 
     const uint32_t retType = typeTranslator.translateType(callExpr->getType());
     const uint32_t retType = typeTranslator.translateType(callExpr->getType());
     // Get or forward declare the function <result-id>
     // Get or forward declare the function <result-id>
-    const uint32_t funcId = declIdMapper.getOrRegisterDeclResultId(callee);
+    const uint32_t funcId = declIdMapper.getOrRegisterFnResultId(callee);
 
 
     const uint32_t retVal =
     const uint32_t retVal =
         theBuilder.createFunctionCall(retType, funcId, params);
         theBuilder.createFunctionCall(retType, funcId, params);
@@ -1788,8 +1785,9 @@ void SPIRVEmitter::initOnce(std::string varName, uint32_t varPtr,
   varName = "init.done." + varName;
   varName = "init.done." + varName;
 
 
   // Create a file/module visible variable to hold the initialization state.
   // Create a file/module visible variable to hold the initialization state.
-  const uint32_t initDoneVar = theBuilder.addFileVar(
-      boolType, varName, theBuilder.getConstantBool(false));
+  const uint32_t initDoneVar =
+      theBuilder.addModuleVar(boolType, spv::StorageClass::Private, varName,
+                              theBuilder.getConstantBool(false));
 
 
   const uint32_t condition = theBuilder.createLoad(boolType, initDoneVar);
   const uint32_t condition = theBuilder.createLoad(boolType, initDoneVar);
 
 

+ 50 - 10
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -9,6 +9,7 @@
 
 
 #include "TypeTranslator.h"
 #include "TypeTranslator.h"
 
 
+#include "dxc/HLSL/DxilConstants.h"
 #include "clang/AST/HlslTypes.h"
 #include "clang/AST/HlslTypes.h"
 
 
 namespace clang {
 namespace clang {
@@ -23,8 +24,8 @@ uint32_t TypeTranslator::translateType(QualType type) {
   // Primitive types
   // Primitive types
   {
   {
     QualType ty = {};
     QualType ty = {};
-    if (isScalarType(type, &ty)) {
-      if (const auto *builtinType = cast<BuiltinType>(ty.getTypePtr())) {
+    if (isScalarType(type, &ty))
+      if (const auto *builtinType = cast<BuiltinType>(ty.getTypePtr()))
         switch (builtinType->getKind()) {
         switch (builtinType->getKind()) {
         case BuiltinType::Void:
         case BuiltinType::Void:
           return theBuilder.getVoidType();
           return theBuilder.getVoidType();
@@ -41,19 +42,15 @@ uint32_t TypeTranslator::translateType(QualType type) {
               << builtinType->getTypeClassName();
               << builtinType->getTypeClassName();
           return 0;
           return 0;
         }
         }
-      }
-    }
   }
   }
 
 
-  const auto *typePtr = type.getTypePtr();
-
   // Typedefs
   // Typedefs
-  if (const auto *typedefType = dyn_cast<TypedefType>(typePtr)) {
+  if (const auto *typedefType = type->getAs<TypedefType>()) {
     return translateType(typedefType->desugar());
     return translateType(typedefType->desugar());
   }
   }
 
 
   // Reference types
   // Reference types
-  if (const auto *refType = dyn_cast<ReferenceType>(typePtr)) {
+  if (const auto *refType = type->getAs<ReferenceType>()) {
     // Note: Pointer/reference types are disallowed in HLSL source code.
     // Note: Pointer/reference types are disallowed in HLSL source code.
     // Although developers cannot use them directly, they are generated into
     // Although developers cannot use them directly, they are generated into
     // the AST by out/inout parameter modifiers in function signatures.
     // the AST by out/inout parameter modifiers in function signatures.
@@ -108,9 +105,16 @@ uint32_t TypeTranslator::translateType(QualType type) {
   }
   }
 
 
   // Struct type
   // Struct type
-  if (const auto *structType = dyn_cast<RecordType>(typePtr)) {
+  if (const auto *structType = type->getAs<RecordType>()) {
     const auto *decl = structType->getDecl();
     const auto *decl = structType->getDecl();
 
 
+    // HLSL resource types are also represented as RecordType in the AST.
+    // (ClassTemplateSpecializationDecl is a subclass of CXXRecordDecl, which is
+    // then a subclass of RecordDecl.) So we need to check them before checking
+    // the general struct type.
+    if (const auto id = translateResourceType(type))
+      return id;
+
     // Collect all fields' types and names.
     // Collect all fields' types and names.
     llvm::SmallVector<uint32_t, 4> fieldTypes;
     llvm::SmallVector<uint32_t, 4> fieldTypes;
     llvm::SmallVector<llvm::StringRef, 4> fieldNames;
     llvm::SmallVector<llvm::StringRef, 4> fieldNames;
@@ -122,8 +126,9 @@ uint32_t TypeTranslator::translateType(QualType type) {
     return theBuilder.getStructType(fieldTypes, decl->getName(), fieldNames);
     return theBuilder.getStructType(fieldTypes, decl->getName(), fieldNames);
   }
   }
 
 
-  if (const auto *arrayType = dyn_cast<ConstantArrayType>(typePtr)) {
+  if (const auto *arrayType = astContext.getAsConstantArrayType(type)) {
     const uint32_t elemType = translateType(arrayType->getElementType());
     const uint32_t elemType = translateType(arrayType->getElementType());
+    // TODO: handle extra large array size?
     const auto size =
     const auto size =
         static_cast<uint32_t>(arrayType->getSize().getZExtValue());
         static_cast<uint32_t>(arrayType->getSize().getZExtValue());
     return theBuilder.getArrayType(elemType,
     return theBuilder.getArrayType(elemType,
@@ -297,5 +302,40 @@ uint32_t TypeTranslator::getComponentVectorType(QualType matrixType) {
   return theBuilder.getVecType(elemType, colCount);
   return theBuilder.getVecType(elemType, colCount);
 }
 }
 
 
+uint32_t TypeTranslator::translateResourceType(QualType type) {
+  const auto *recordType = type->getAs<RecordType>();
+  assert(recordType);
+  const llvm::StringRef name = recordType->getDecl()->getName();
+
+  // TODO: avoid string comparison once hlsl::IsHLSLResouceType() does that.
+
+  { // Texture types
+    spv::Dim dim = {};
+    bool isArray = {};
+
+    if ((dim = spv::Dim::Dim1D, isArray = false, name == "Texture1D") ||
+        (dim = spv::Dim::Dim2D, isArray = false, name == "Texture2D") ||
+        (dim = spv::Dim::Dim3D, isArray = false, name == "Texture3D") ||
+        (dim = spv::Dim::Cube, isArray = false, name == "TextureCube") ||
+        (dim = spv::Dim::Dim1D, isArray = true, name == "Texture1DArray") ||
+        (dim = spv::Dim::Dim2D, isArray = true, name == "Texture2DArray") ||
+        // There is no Texture3DArray
+        (dim = spv::Dim::Cube, isArray = true, name == "TextureCubeArray")) {
+      if (dim == spv::Dim::Dim1D)
+        theBuilder.requireCapability(spv::Capability::Sampled1D);
+      const auto sampledType = hlsl::GetHLSLResourceResultType(type);
+      return theBuilder.getImageType(translateType(getElementType(sampledType)),
+                                     dim, isArray);
+    }
+  }
+
+  // Sampler types
+  if (name == "SamplerState" || name == "SamplerComparisonState") {
+    return theBuilder.getSamplerType();
+  }
+
+  return 0;
+}
+
 } // end namespace spirv
 } // end namespace spirv
 } // end namespace clang
 } // end namespace clang

+ 8 - 2
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -26,8 +26,9 @@ namespace spirv {
 /// DiagnosticEngine passed into the constructor.
 /// DiagnosticEngine passed into the constructor.
 class TypeTranslator {
 class TypeTranslator {
 public:
 public:
-  TypeTranslator(ModuleBuilder &builder, DiagnosticsEngine &diag)
-      : theBuilder(builder), diags(diag) {}
+  TypeTranslator(ASTContext &context, ModuleBuilder &builder,
+                 DiagnosticsEngine &diag)
+      : astContext(context), theBuilder(builder), diags(diag) {}
 
 
   /// \brief Generates the corresponding SPIR-V type for the given Clang
   /// \brief Generates the corresponding SPIR-V type for the given Clang
   /// frontend type and returns the type's <result-id>. On failure, reports
   /// frontend type and returns the type's <result-id>. On failure, reports
@@ -101,7 +102,12 @@ private:
     return diags.Report(diagId);
     return diags.Report(diagId);
   }
   }
 
 
+  /// \brief Translates the given HLSL resource type into its SPIR-V
+  /// instructions and returns the <result-id>. Returns 0 on failure.
+  uint32_t translateResourceType(QualType type);
+
 private:
 private:
+  ASTContext &astContext;
   ModuleBuilder &theBuilder;
   ModuleBuilder &theBuilder;
   DiagnosticsEngine &diags;
   DiagnosticsEngine &diags;
 };
 };

+ 15 - 0
tools/clang/test/CodeGenSPIRV/type.sampler.hlsl

@@ -0,0 +1,15 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// CHECK: %type_sampler = OpTypeSampler
+// CHECK: %_ptr_UniformConstant_type_sampler = OpTypePointer UniformConstant %type_sampler
+
+// CHECK: %s1 = OpVariable %_ptr_UniformConstant_type_sampler UniformConstant
+SamplerState           s1 : register(s1);
+// CHECK: %s2 = OpVariable %_ptr_UniformConstant_type_sampler UniformConstant
+SamplerComparisonState s2 : register(s2);
+// CHECK: %s3 = OpVariable %_ptr_UniformConstant_type_sampler UniformConstant
+sampler                s3 : register(s3);
+
+void main() {
+// CHECK-LABEL: %main = OpFunction
+}

+ 45 - 0
tools/clang/test/CodeGenSPIRV/type.texture.hlsl

@@ -0,0 +1,45 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// CHECK: OpCapability Sampled1D
+
+// CHECK: %type_1d_image = OpTypeImage %float 1D 0 0 0 1 Unknown
+// CHECK: %_ptr_UniformConstant_type_1d_image = OpTypePointer UniformConstant %type_1d_image
+
+// CHECK: %type_2d_image = OpTypeImage %int 2D 0 0 0 1 Unknown
+// CHECK: %_ptr_UniformConstant_type_2d_image = OpTypePointer UniformConstant %type_2d_image
+
+// CHECK: %type_3d_image = OpTypeImage %uint 3D 0 0 0 1 Unknown
+// CHECK: %_ptr_UniformConstant_type_3d_image = OpTypePointer UniformConstant %type_3d_image
+
+// CHECK: %type_cube_image = OpTypeImage %float Cube 0 0 0 1 Unknown
+// CHECK: %_ptr_UniformConstant_type_cube_image = OpTypePointer UniformConstant %type_cube_image
+
+// CHECK: %type_1d_image_array = OpTypeImage %float 1D 0 1 0 1 Unknown
+// CHECK: %_ptr_UniformConstant_type_1d_image_array = OpTypePointer UniformConstant %type_1d_image_array
+
+// CHECK: %type_2d_image_array = OpTypeImage %int 2D 0 1 0 1 Unknown
+// CHECK: %_ptr_UniformConstant_type_2d_image_array = OpTypePointer UniformConstant %type_2d_image_array
+
+// CHECK: %type_cube_image_array = OpTypeImage %float Cube 0 1 0 1 Unknown
+// CHECK: %_ptr_UniformConstant_type_cube_image_array = OpTypePointer UniformConstant %type_cube_image_array
+
+// CHECK: %t1 = OpVariable %_ptr_UniformConstant_type_1d_image UniformConstant
+Texture1D   <float4> t1 : register(t1);
+// CHECK: %t2 = OpVariable %_ptr_UniformConstant_type_2d_image UniformConstant
+Texture2D   <int4>   t2 : register(t2);
+// CHECK: %t3 = OpVariable %_ptr_UniformConstant_type_3d_image UniformConstant
+Texture3D   <uint4>  t3 : register(t3);
+// CHECK: %t4 = OpVariable %_ptr_UniformConstant_type_cube_image UniformConstant
+TextureCube <float4> t4 : register(t4);
+
+
+// CHECK: %t5 = OpVariable %_ptr_UniformConstant_type_1d_image_array UniformConstant
+Texture1DArray   <float4> t5 : register(t5);
+// CHECK: %t6 = OpVariable %_ptr_UniformConstant_type_2d_image_array UniformConstant
+Texture2DArray   <int4>   t6 : register(t6);
+// CHECK: %t7 = OpVariable %_ptr_UniformConstant_type_cube_image_array UniformConstant
+TextureCubeArray <float4> t7 : register(t7);
+
+void main() {
+// CHECK-LABEL: %main = OpFunction
+}

+ 2 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -41,6 +41,8 @@ TEST_F(FileTest, MatrixTypes) { runFileTest("type.matrix.hlsl"); }
 TEST_F(FileTest, StructTypes) { runFileTest("type.struct.hlsl"); }
 TEST_F(FileTest, StructTypes) { runFileTest("type.struct.hlsl"); }
 TEST_F(FileTest, ArrayTypes) { runFileTest("type.array.hlsl"); }
 TEST_F(FileTest, ArrayTypes) { runFileTest("type.array.hlsl"); }
 TEST_F(FileTest, TypedefTypes) { runFileTest("type.typedef.hlsl"); }
 TEST_F(FileTest, TypedefTypes) { runFileTest("type.typedef.hlsl"); }
+TEST_F(FileTest, SamplerTypes) { runFileTest("type.sampler.hlsl"); }
+TEST_F(FileTest, TextureTypes) { runFileTest("type.texture.hlsl"); }
 
 
 // For constants
 // For constants
 TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }
 TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }