Browse Source

[spirv] Lower Hybrid types into SPIR-V types.

Also, StructType and HybridStructType should share the same enum for
interface type.
Ehsan 6 years ago
parent
commit
7aaa63a714

+ 1 - 0
tools/clang/include/clang/SPIRV/EmitVisitor.h

@@ -20,6 +20,7 @@ namespace spirv {
 class SpirvFunction;
 class SpirvBasicBlock;
 class SpirvType;
+class SpirvBuilder;
 
 // Provides DenseMapInfo for SpirvLayoutRule so that we can use it as key to
 // DenseMap.

+ 6 - 0
tools/clang/include/clang/SPIRV/LowerTypeVisitor.h

@@ -51,6 +51,12 @@ private:
   /// The lowering is recursive; all the types that the target type depends
   /// on will be created in SpirvContext.
   const SpirvType *lowerType(QualType type, SpirvLayoutRule, SourceLocation);
+  /// Lowers the given Hybrid type into a SPIR-V type.
+  ///
+  /// Uses the above lowerType method to lower the QualType components of hybrid
+  /// types.
+  const SpirvType *lowerType(const HybridType *, SpirvLayoutRule,
+                             SourceLocation);
 
   /// Lowers the given HLSL resource type into its SPIR-V type.
   const SpirvType *lowerResourceType(QualType type, SpirvLayoutRule rule,

+ 9 - 11
tools/clang/include/clang/SPIRV/SPIRVContext.h

@@ -206,17 +206,15 @@ public:
   const ArrayType *getArrayType(const SpirvType *elemType, uint32_t elemCount);
   const RuntimeArrayType *getRuntimeArrayType(const SpirvType *elemType);
 
-  const StructType *
-  getStructType(llvm::ArrayRef<StructType::FieldInfo> fields,
-                llvm::StringRef name, bool isReadOnly = false,
-                StructType::InterfaceType interfaceType =
-                    StructType::InterfaceType::InternalStorage);
-
-  const HybridStructType *
-  getHybridStructType(llvm::ArrayRef<HybridStructType::FieldInfo> fields,
-                      llvm::StringRef name, bool isReadOnly = false,
-                      HybridStructType::InterfaceType interfaceType =
-                          HybridStructType::InterfaceType::InternalStorage);
+  const StructType *getStructType(
+      llvm::ArrayRef<StructType::FieldInfo> fields, llvm::StringRef name,
+      bool isReadOnly = false,
+      StructInterfaceType interfaceType = StructInterfaceType::InternalStorage);
+
+  const HybridStructType *getHybridStructType(
+      llvm::ArrayRef<HybridStructType::FieldInfo> fields, llvm::StringRef name,
+      bool isReadOnly = false,
+      StructInterfaceType interfaceType = StructInterfaceType::InternalStorage);
 
   const SpirvPointerType *getPointerType(const SpirvType *pointee,
                                          spv::StorageClass);

+ 2 - 0
tools/clang/include/clang/SPIRV/SpirvFunction.h

@@ -69,6 +69,8 @@ public:
   // Returns the result-id of the OpTypeFunction
   uint32_t getFunctionTypeId() const { return fnTypeId; }
 
+  SourceLocation getSourceLocation() const { return functionLoc; }
+
   void setConstainsAliasComponent(bool isAlias) { containsAlias = isAlias; }
   bool constainsAliasComponent() { return containsAlias; }
 

+ 2 - 0
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -1539,6 +1539,8 @@ private:
 };
 
 /// \brief OpSampledImage instruction
+/// Result Type must be the OpTypeSampledImage type whose Image Type operand is
+/// the type of Image. We store the QualType for the underlying image as result type.
 class SpirvSampledImage : public SpirvInstruction {
 public:
   SpirvSampledImage(QualType resultType, uint32_t resultId, SourceLocation loc,

+ 20 - 26
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -22,6 +22,12 @@
 namespace clang {
 namespace spirv {
 
+enum class StructInterfaceType : uint32_t {
+  InternalStorage = 0,
+  StorageBuffer = 1,
+  UniformBuffer = 2,
+};
+
 class SpirvType {
 public:
   enum Kind {
@@ -250,12 +256,6 @@ private:
 
 class StructType : public SpirvType {
 public:
-  enum class InterfaceType : uint32_t {
-    InternalStorage = 0,
-    StorageBuffer = 1,
-    UniformBuffer = 2,
-  };
-
   struct FieldInfo {
   public:
     FieldInfo(const SpirvType *type_, llvm::StringRef name_ = "",
@@ -276,16 +276,16 @@ public:
     hlsl::ConstantPacking *packOffsetAttr;
   };
 
-  StructType(llvm::ArrayRef<FieldInfo> fields, llvm::StringRef name,
-             bool isReadOnly,
-             InterfaceType interfaceType = InterfaceType::InternalStorage);
+  StructType(
+      llvm::ArrayRef<FieldInfo> fields, llvm::StringRef name, bool isReadOnly,
+      StructInterfaceType interfaceType = StructInterfaceType::InternalStorage);
 
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Struct; }
 
   llvm::ArrayRef<FieldInfo> getFields() const { return fields; }
   bool isReadOnly() const { return readOnly; }
   std::string getStructName() const { return structName; }
-  InterfaceType getInterfaceType() const { return interfaceType; }
+  StructInterfaceType getInterfaceType() const { return interfaceType; }
 
   bool operator==(const StructType &that) const;
 
@@ -301,7 +301,7 @@ private:
   // storage buffer shader-interface, it will be decorated with 'BufferBlock'.
   // If this structure is a uniform buffer shader-interface, it will be
   // decorated with 'Block'.
-  InterfaceType interfaceType;
+  StructInterfaceType interfaceType;
 };
 
 class SpirvPointerType : public SpirvType {
@@ -337,7 +337,7 @@ public:
     return returnType == that.returnType && paramTypes == that.paramTypes;
   }
 
-  //void setReturnType(const SpirvType *t) { returnType = t; }
+  // void setReturnType(const SpirvType *t) { returnType = t; }
   const SpirvType *getReturnType() const { return returnType; }
   llvm::ArrayRef<const SpirvType *> getParamTypes() const { return paramTypes; }
 
@@ -363,25 +363,18 @@ protected:
 /// This type uses a mix of SpirvType and QualType for the structure fields.
 class HybridStructType : public HybridType {
 public:
-  enum class InterfaceType : uint32_t {
-    InternalStorage = 0,
-    StorageBuffer = 1,
-    UniformBuffer = 2,
-  };
-
   struct FieldInfo {
   public:
-    FieldInfo(QualType astType_, const SpirvType *type_,
-              llvm::StringRef name_ = "", clang::VKOffsetAttr *offset = nullptr,
+    FieldInfo(QualType astType_, llvm::StringRef name_ = "",
+              clang::VKOffsetAttr *offset = nullptr,
               hlsl::ConstantPacking *packOffset = nullptr)
-        : astType(astType_), spirvType(type_), name(name_),
-          vkOffsetAttr(offset), packOffsetAttr(packOffset) {}
+        : astType(astType_), name(name_), vkOffsetAttr(offset),
+          packOffsetAttr(packOffset) {}
 
     bool operator==(const FieldInfo &that) const;
 
     // The field's type.
     QualType astType;
-    const SpirvType *spirvType;
     // The field's name.
     std::string name;
     // vk::offset attributes associated with this field.
@@ -392,7 +385,7 @@ public:
 
   HybridStructType(
       llvm::ArrayRef<FieldInfo> fields, llvm::StringRef name, bool isReadOnly,
-      InterfaceType interfaceType = InterfaceType::InternalStorage);
+      StructInterfaceType interfaceType = StructInterfaceType::InternalStorage);
 
   static bool classof(const SpirvType *t) {
     return t->getKind() == TK_HybridStruct;
@@ -401,7 +394,7 @@ public:
   llvm::ArrayRef<FieldInfo> getFields() const { return fields; }
   bool isReadOnly() const { return readOnly; }
   std::string getStructName() const { return structName; }
-  InterfaceType getInterfaceType() const { return interfaceType; }
+  StructInterfaceType getInterfaceType() const { return interfaceType; }
 
   bool operator==(const HybridStructType &that) const;
 
@@ -417,7 +410,7 @@ private:
   // storage buffer shader-interface, it will be decorated with 'BufferBlock'.
   // If this structure is a uniform buffer shader-interface, it will be
   // decorated with 'Block'.
-  InterfaceType interfaceType;
+  StructInterfaceType interfaceType;
 };
 
 class HybridPointerType : public HybridType {
@@ -470,6 +463,7 @@ public:
            returnType == that.returnType && paramTypes == that.paramTypes;
   }
 
+  QualType getAstReturnType() const { return astReturnType; }
   void setReturnType(const SpirvType *t) { returnType = t; }
   const SpirvType *getReturnType() const { return returnType; }
   llvm::ArrayRef<const SpirvType *> getParamTypes() const { return paramTypes; }

+ 6 - 5
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -704,7 +704,8 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
     // We don't need it here.
     auto varType = declDecl->getType();
     varType.removeLocalConst();
-    HybridStructType::FieldInfo info(varType, nullptr, declDecl->getName());
+    HybridStructType::FieldInfo info(varType, declDecl->getName());
+    fields.push_back(info);
 
     if (spirvOptions.enable16BitTypes &&
         isOrContains16BitType(varType, spirvOptions.enable16BitTypes)) {
@@ -723,8 +724,8 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
   // tbuffer/TextureBuffers are non-writable SSBOs.
   const SpirvType *resultType = spvContext.getHybridStructType(
       fields, typeName, /*isReadOnly*/ forTBuffer,
-      forTBuffer ? HybridStructType::InterfaceType::StorageBuffer
-                 : HybridStructType::InterfaceType::UniformBuffer);
+      forTBuffer ? StructInterfaceType::StorageBuffer
+                 : StructInterfaceType::UniformBuffer);
 
   // Make an array if requested.
   if (arraySize > 0) {
@@ -968,8 +969,8 @@ void DeclResultIdMapper::createCounterVar(
         spvContext.getPointerType(counterType, spv::StorageClass::Uniform);
   }
 
-  SpirvVariable *counterInstr = spvBuilder.addModuleVar(
-      counterType, spv::StorageClass::Uniform, counterName);
+  SpirvVariable *counterInstr =
+      spvBuilder.addModuleVar(counterType, sc, counterName);
 
   if (!isAlias) {
     // Non-alias counter variables should be put in to resourceVars so that

+ 3 - 3
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -10,10 +10,10 @@
 #include "clang/SPIRV/EmitVisitor.h"
 #include "clang/SPIRV/BitwiseCast.h"
 #include "clang/SPIRV/SpirvBasicBlock.h"
+#include "clang/SPIRV/SpirvBuilder.h"
 #include "clang/SPIRV/SpirvFunction.h"
 #include "clang/SPIRV/SpirvInstruction.h"
 #include "clang/SPIRV/SpirvModule.h"
-#include "clang/SPIRV/SpirvBuilder.h"
 #include "clang/SPIRV/SpirvType.h"
 #include "clang/SPIRV/String.h"
 
@@ -1146,9 +1146,9 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
 
     // Emit Block or BufferBlock decorations if necessary.
     auto interfaceType = structType->getInterfaceType();
-    if (interfaceType == StructType::InterfaceType::StorageBuffer)
+    if (interfaceType == StructInterfaceType::StorageBuffer)
       emitDecoration(id, spv::Decoration::BufferBlock, {});
-    else if (interfaceType == StructType::InterfaceType::UniformBuffer)
+    else if (interfaceType == StructInterfaceType::UniformBuffer)
       emitDecoration(id, spv::Decoration::Block, {});
   }
   // Pointer types

+ 67 - 10
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -27,27 +27,84 @@ bool LowerTypeVisitor::visit(SpirvFunction *fn, Phase phase) {
 
     // In case the function type is a hybrid type, we should also lower the
     // return type of the SPIR-V function type.
-    if (auto *fnRetType = dyn_cast<HybridFunctionType>(fn->getFunctionType())) {
-      fnRetType->setReturnType(spirvReturnType);
+    if (auto *fnRetType = dyn_cast<HybridType>(fn->getFunctionType())) {
+      fn->setFunctionType(const_cast<SpirvType *>(lowerType(
+          fnRetType, SpirvLayoutRule::Void, fn->getSourceLocation())));
     }
   }
   return true;
 }
 
 bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) {
-  if (instr->getAstResultType() != QualType({})) {
-    const auto loweredType =
-        lowerType(instr->getAstResultType(), instr->getLayoutRule(),
-                  instr->getSourceLocation());
-
-    instr->setResultType(loweredType);
-
-    return loweredType != nullptr;
+  const QualType astType = instr->getAstResultType();
+  const SpirvType *hybridType = instr->getResultType();
+
+  // Lower QualType to SpirvType
+  if (astType != QualType({})) {
+    const SpirvType *spirvType =
+        lowerType(astType, instr->getLayoutRule(), instr->getSourceLocation());
+    instr->setResultType(spirvType);
+    return spirvType != nullptr;
+  }
+  // Lower Hybrid type to SpirvType
+  else if (hybridType) {
+    if (const auto *hybridType = dyn_cast<HybridType>(instr->getResultType())) {
+      const SpirvType *spirvType = lowerType(hybridType, instr->getLayoutRule(),
+                                             instr->getSourceLocation());
+      instr->setResultType(spirvType);
+    }
   }
 
+  // The instruction does not have a result-type, so nothing to do.
   return true;
 }
 
+const SpirvType *LowerTypeVisitor::lowerType(const HybridType *hybrid,
+                                             SpirvLayoutRule rule,
+                                             SourceLocation loc) {
+  if (const auto *hybridPointer = dyn_cast<HybridPointerType>(hybrid)) {
+    const QualType pointeeType = hybridPointer->getPointeeType();
+    const SpirvType *pointeeSpirvType = lowerType(pointeeType, rule, loc);
+    return spvContext.getPointerType(pointeeSpirvType,
+                                     hybridPointer->getStorageClass());
+  } else if (const auto *hybridSampledImage =
+                 dyn_cast<HybridSampledImageType>(hybrid)) {
+    const QualType imageAstType = hybridSampledImage->getImageType();
+    const SpirvType *imageSpirvType = lowerType(imageAstType, rule, loc);
+    assert(isa<ImageType>(imageSpirvType));
+    return spvContext.getSampledImageType(cast<ImageType>(imageSpirvType));
+  } else if (const auto *hybridFn = dyn_cast<HybridFunctionType>(hybrid)) {
+    // Lower the return type.
+    const QualType astReturnType = hybridFn->getAstReturnType();
+    const SpirvType *spirvReturnType = lowerType(astReturnType, rule, loc);
+
+    // Go over all params. If any of them is hybrid, lower it.
+    std::vector<const SpirvType *> paramTypes;
+    for (auto *paramType : hybridFn->getParamTypes()) {
+      if (const auto *hybridParam = dyn_cast<HybridType>(paramType)) {
+        paramTypes.push_back(lowerType(hybridParam, rule, loc));
+      } else {
+        paramTypes.push_back(paramType);
+      }
+    }
+
+    return spvContext.getFunctionType(spirvReturnType, paramTypes);
+  } else if (const auto *hybridStruct = dyn_cast<HybridStructType>(hybrid)) {
+    // lower all fields of the struct.
+    std::vector<StructType::FieldInfo> structFields;
+    for (auto field : hybridStruct->getFields()) {
+      const SpirvType *fieldSpirvType = lowerType(field.astType, rule, loc);
+      structFields.push_back(StructType::FieldInfo(fieldSpirvType, field.name,
+                                                   field.vkOffsetAttr,
+                                                   field.packOffsetAttr));
+    }
+    return spvContext.getStructType(structFields, hybridStruct->getStructName(),
+                                    hybridStruct->isReadOnly(),
+                                    hybridStruct->getInterfaceType());
+  }
+  llvm_unreachable("lowering of hybrid type not implemented");
+}
+
 const SpirvType *LowerTypeVisitor::lowerType(QualType type,
                                              SpirvLayoutRule rule,
                                              SourceLocation srcLoc) {

+ 2 - 3
tools/clang/lib/SPIRV/SPIRVContext.cpp

@@ -239,7 +239,7 @@ SpirvContext::getRuntimeArrayType(const SpirvType *elemType) {
 const StructType *
 SpirvContext::getStructType(llvm::ArrayRef<StructType::FieldInfo> fields,
                             llvm::StringRef name, bool isReadOnly,
-                            StructType::InterfaceType interfaceType) {
+                            StructInterfaceType interfaceType) {
   // We are creating a temporary struct type here for querying whether the
   // same type was already created. It is a little bit costly, but we can
   // avoid allocating directly from the bump pointer allocator, from which
@@ -262,7 +262,7 @@ SpirvContext::getStructType(llvm::ArrayRef<StructType::FieldInfo> fields,
 
 const HybridStructType *SpirvContext::getHybridStructType(
     llvm::ArrayRef<HybridStructType::FieldInfo> fields, llvm::StringRef name,
-    bool isReadOnly, HybridStructType::InterfaceType interfaceType) {
+    bool isReadOnly, StructInterfaceType interfaceType) {
   // We are creating a temporary struct type here for querying whether the
   // same type was already created. It is a little bit costly, but we can
   // avoid allocating directly from the bump pointer allocator, from which
@@ -375,6 +375,5 @@ const StructType *SpirvContext::getACSBufferCounterType() {
   return type;
 }
 
-
 } // end namespace spirv
 } // end namespace clang

+ 9 - 10
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -5750,7 +5750,7 @@ SPIRVEmitter::tryToAssignToVectorElements(const Expr *lhs,
       auto *result = tryToAssignToRWBufferRWTexture(base, newVec);
       assert(result); // Definitely RWBuffer/RWTexture assignment
       (void)result;
-      return rhs;     // TODO: incorrect for compound assignments
+      return rhs; // TODO: incorrect for compound assignments
     } else {
       // Assigning to one normal vector component. Nothing special, just fall
       // back to the normal CodeGen path.
@@ -7354,8 +7354,8 @@ SpirvInstruction *SPIRVEmitter::processIntrinsicModf(const CallExpr *callExpr) {
     if (isScalarType(argType) || isVectorType(argType)) {
       // The struct members *must* have the same type.
       const auto modfStructType = spvContext.getHybridStructType(
-          {HybridStructType::FieldInfo(argType, nullptr, "frac"),
-           HybridStructType::FieldInfo(argType, nullptr, "ip")},
+          {HybridStructType::FieldInfo(argType, "frac"),
+           HybridStructType::FieldInfo(argType, "ip")},
           "ModfStructType");
       auto *modf = spvBuilder.createExtInst(modfStructType, glslInstSet,
                                             GLSLstd450::GLSLstd450ModfStruct,
@@ -7377,8 +7377,8 @@ SpirvInstruction *SPIRVEmitter::processIntrinsicModf(const CallExpr *callExpr) {
     if (isMxNMatrix(argType, &elemType, &rowCount, &colCount)) {
       const auto colType = astContext.getExtVectorType(elemType, colCount);
       const auto modfStructType = spvContext.getHybridStructType(
-          {HybridStructType::FieldInfo(colType, nullptr, "frac"),
-           HybridStructType::FieldInfo(colType, nullptr, "ip")},
+          {HybridStructType::FieldInfo(colType, "frac"),
+           HybridStructType::FieldInfo(colType, "ip")},
           "ModfStructType");
       llvm::SmallVector<SpirvInstruction *, 4> fracs;
       llvm::SmallVector<SpirvInstruction *, 4> ips;
@@ -7470,8 +7470,8 @@ SPIRVEmitter::processIntrinsicFrexp(const CallExpr *callExpr) {
               ? astContext.IntTy
               : astContext.getExtVectorType(astContext.IntTy, elemCount);
       const auto *frexpStructType = spvContext.getHybridStructType(
-          {HybridStructType::FieldInfo(argType, nullptr, "mantissa"),
-           HybridStructType::FieldInfo(expType, nullptr, "exponent")},
+          {HybridStructType::FieldInfo(argType, "mantissa"),
+           HybridStructType::FieldInfo(expType, "exponent")},
           "FrexpStructType");
       auto *frexp = spvBuilder.createExtInst(frexpStructType, glslInstSet,
                                              GLSLstd450::GLSLstd450FrexpStruct,
@@ -7498,8 +7498,8 @@ SPIRVEmitter::processIntrinsicFrexp(const CallExpr *callExpr) {
       const auto colType =
           astContext.getExtVectorType(astContext.FloatTy, colCount);
       const auto *frexpStructType = spvContext.getHybridStructType(
-          {HybridStructType::FieldInfo(colType, nullptr, "mantissa"),
-           HybridStructType::FieldInfo(expType, nullptr, "exponent")},
+          {HybridStructType::FieldInfo(colType, "mantissa"),
+           HybridStructType::FieldInfo(expType, "exponent")},
           "FrexpStructType");
       llvm::SmallVector<SpirvInstruction *, 4> exponents;
       llvm::SmallVector<SpirvInstruction *, 4> mantissas;
@@ -9508,7 +9508,6 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
     // If not explicitly initialized, initialize with their zero values if not
     // resource objects
     else if (!hlsl::IsHLSLResourceType(varDecl->getType())) {
-      const QualType type = varDecl->getType();
       auto *nullValue = spvBuilder.getConstantNull(varDecl->getType());
       spvBuilder.createStore(varInfo, nullValue);
     }

+ 1 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -29,8 +29,8 @@
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/SPIRV/FeatureManager.h"
 #include "clang/SPIRV/ModuleBuilder.h"
+#include "clang/SPIRV/SPIRVContext.h"
 #include "clang/SPIRV/SpirvBuilder.h"
-#include "clang/SPIRV/SpirvContext.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 

+ 0 - 3
tools/clang/lib/SPIRV/SpirvModule.cpp

@@ -133,9 +133,6 @@ void SpirvModule::addDecoration(SpirvDecoration *decor) {
 
 void SpirvModule::addConstant(SpirvConstant *constant) {
   assert(constant);
-  if (constants.empty()) {
-    printf("wtf\n");
-  }
   constants.push_back(constant);
 }
 

+ 4 - 4
tools/clang/lib/SPIRV/SpirvType.cpp

@@ -52,7 +52,7 @@ bool ImageType::operator==(const ImageType &that) const {
 
 StructType::StructType(llvm::ArrayRef<StructType::FieldInfo> fieldsVec,
                        llvm::StringRef name, bool isReadOnly,
-                       StructType::InterfaceType iface)
+                       StructInterfaceType iface)
     : SpirvType(TK_Struct), fields(fieldsVec.begin(), fieldsVec.end()),
       structName(name), readOnly(isReadOnly), interfaceType(iface) {}
 
@@ -70,14 +70,14 @@ bool StructType::operator==(const StructType &that) const {
 
 HybridStructType::HybridStructType(
     llvm::ArrayRef<HybridStructType::FieldInfo> fieldsVec, llvm::StringRef name,
-    bool isReadOnly, HybridStructType::InterfaceType iface)
+    bool isReadOnly, StructInterfaceType iface)
     : HybridType(TK_HybridStruct), fields(fieldsVec.begin(), fieldsVec.end()),
       structName(name), readOnly(isReadOnly), interfaceType(iface) {}
 
 bool HybridStructType::FieldInfo::
 operator==(const HybridStructType::FieldInfo &that) const {
-  return astType == that.astType && spirvType == that.spirvType &&
-         name == that.name && vkOffsetAttr == that.vkOffsetAttr &&
+  return astType == that.astType && name == that.name &&
+         vkOffsetAttr == that.vkOffsetAttr &&
          packOffsetAttr == that.packOffsetAttr;
 }