Browse Source

[spirv] Emit UserTypeGOOGLE decoration on variables with binding. (#2289)

Ehsan 6 years ago
parent
commit
ed86c82a8f

+ 6 - 1
tools/clang/include/clang/SPIRV/AstTypeProbe.h

@@ -206,6 +206,11 @@ bool isSubpassInput(QualType);
 /// \brief Returns true if the given type is SubpassInputMS.
 bool isSubpassInputMS(QualType);
 
+/// \brief If the given QualType is an HLSL resource type (or array of
+/// resources), returns its HLSL type name. e.g. "RWTexture2D". Returns an empty
+/// string otherwise.
+std::string getHlslResourceTypeName(QualType);
+
 /// Returns true if the given type will be translated into a SPIR-V image,
 /// sampler or struct containing images or samplers.
 ///
@@ -220,7 +225,7 @@ bool isOpaqueArrayType(QualType type);
 /// (in a recursive away).
 ///
 /// Note: legalization specific code
-bool isOpaqueStructType(QualType tye);
+bool isOpaqueStructType(QualType type);
 
 /// \brief Returns true if the given type can use relaxed precision
 /// decoration. Integer and float types with lower than 32 bits can be

+ 1 - 0
tools/clang/include/clang/SPIRV/FeatureManager.h

@@ -42,6 +42,7 @@ enum class Extension {
   AMD_gpu_shader_half_float,
   AMD_shader_explicit_vertex_parameter,
   GOOGLE_hlsl_functionality1,
+  GOOGLE_user_type,
   NV_ray_tracing,
   Unknown,
 };

+ 1 - 1
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -479,7 +479,7 @@ public:
 
   /// \brief Decorates the given target with the given descriptor set and
   /// binding number.
-  void decorateDSetBinding(SpirvInstruction *target, uint32_t setNumber,
+  void decorateDSetBinding(SpirvVariable *target, uint32_t setNumber,
                            uint32_t bindingNumber);
 
   /// \brief Decorates the given target with the given SpecId.

+ 9 - 0
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -461,9 +461,18 @@ public:
 
   bool hasInitializer() const { return initializer != nullptr; }
   SpirvInstruction *getInitializer() const { return initializer; }
+  bool hasBinding() const { return descriptorSet >= 0 || binding >= 0; }
+  llvm::StringRef getHlslUserType() const { return hlslUserType; }
+
+  void setDescriptorSetNo(int32_t dset) { descriptorSet = dset; }
+  void setBindingNo(int32_t b) { binding = b; }
+  void setHlslUserType(llvm::StringRef userType) { hlslUserType = userType; }
 
 private:
   SpirvInstruction *initializer;
+  int32_t descriptorSet;
+  int32_t binding;
+  std::string hlslUserType;
 };
 
 class SpirvFunctionParameter : public SpirvInstruction {

+ 29 - 0
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -893,6 +893,35 @@ bool isOpaqueType(QualType type) {
   return false;
 }
 
+std::string getHlslResourceTypeName(QualType type) {
+  if (type.isNull())
+    return "";
+
+  // Strip outer arrayness first
+  while (type->isArrayType())
+    type = type->getAsArrayTypeUnsafe()->getElementType();
+
+  if (const RecordType *recordType = type->getAs<RecordType>()) {
+    StringRef name = recordType->getDecl()->getName();
+    if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
+        name == "ByteAddressBuffer" || name == "RWByteAddressBuffer" ||
+        name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer" ||
+        name == "Texture1D" || name == "Texture2D" || name == "Texture3D" ||
+        name == "TextureCube" || name == "Texture1DArray" ||
+        name == "Texture2DArray" || name == "Texture2DMS" ||
+        name == "Texture2DMSArray" || name == "TextureCubeArray" ||
+        name == "RWTexture1D" || name == "RWTexture2D" ||
+        name == "RWTexture3D" || name == "RWTexture1DArray" ||
+        name == "RWTexture2DArray" || name == "Buffer" || name == "RWBuffer" ||
+        name == "SubpassInput" || name == "SubpassInputMS" ||
+        name == "InputPatch" || name == "OutputPatch") {
+      return name;
+    }
+  }
+
+  return "";
+}
+
 bool isOpaqueStructType(QualType type) {
   if (isOpaqueType(type))
     return false;

+ 9 - 0
tools/clang/lib/SPIRV/CapabilityVisitor.cpp

@@ -463,6 +463,15 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
   case spv::Op::OpGroupNonUniformQuadSwap:
     addCapability(spv::Capability::GroupNonUniformQuad);
     break;
+  case spv::Op::OpVariable: {
+    if (spvOptions.enableReflect &&
+        !cast<SpirvVariable>(instr)->getHlslUserType().empty()) {
+      addExtension(Extension::GOOGLE_user_type, "HLSL User Type", loc);
+      addExtension(Extension::GOOGLE_hlsl_functionality1, "HLSL User Type",
+                   loc);
+    }
+    break;
+  }
   default:
     break;
   }

+ 2 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -770,6 +770,7 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
           : (forTBuffer ? spirvOptions.tBufferLayoutRule
                         : spirvOptions.sBufferLayoutRule);
 
+  var->setHlslUserType(forCBuffer ? "cbuffer" : forTBuffer ? "tbuffer" : "");
   var->setLayoutRule(layoutRule);
   return var;
 }
@@ -1501,7 +1502,7 @@ bool DeclResultIdMapper::decorateResourceBindings() {
 
   // Decorates the given varId of the given category with set number
   // setNo, binding number bindingNo. Ignores overlaps.
-  const auto tryToDecorate = [this, &bindingSet](SpirvInstruction *var,
+  const auto tryToDecorate = [this, &bindingSet](SpirvVariable *var,
                                                  const uint32_t setNo,
                                                  const uint32_t bindingNo) {
     bindingSet.useBinding(bindingNo, setNo);

+ 12 - 0
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -479,6 +479,13 @@ bool EmitVisitor::visit(SpirvVariable *inst) {
   finalizeInstruction();
   emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
                               inst->getDebugName());
+  if (spvOptions.enableReflect && inst->hasBinding() &&
+      !inst->getHlslUserType().empty()) {
+    typeHandler.emitDecoration(
+        getOrAssignResultId<SpirvInstruction>(inst),
+        spv::Decoration::UserTypeGOOGLE,
+        string::encodeSPIRVString(inst->getHlslUserType().lower()));
+  }
   return true;
 }
 
@@ -1601,6 +1608,11 @@ void EmitTypeHandler::emitDecoration(uint32_t typeResultId,
 
   spv::Op op =
       memberIndex.hasValue() ? spv::Op::OpMemberDecorate : spv::Op::OpDecorate;
+  if (decoration == spv::Decoration::UserTypeGOOGLE) {
+    op = memberIndex.hasValue() ? spv::Op::OpMemberDecorateString
+                                : spv::Op::OpDecorateString;
+  }
+
   assert(curDecorationInst.empty());
   curDecorationInst.push_back(static_cast<uint32_t>(op));
   curDecorationInst.push_back(typeResultId);

+ 4 - 0
tools/clang/lib/SPIRV/FeatureManager.cpp

@@ -115,6 +115,8 @@ Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
             Extension::AMD_shader_explicit_vertex_parameter)
       .Case("SPV_GOOGLE_hlsl_functionality1",
             Extension::GOOGLE_hlsl_functionality1)
+      .Case("SPV_GOOGLE_user_type",
+            Extension::GOOGLE_user_type)
       .Case("SPV_KHR_post_depth_coverage", Extension::KHR_post_depth_coverage)
       .Case("SPV_NV_ray_tracing", Extension::NV_ray_tracing)
       .Default(Extension::Unknown);
@@ -150,6 +152,8 @@ const char *FeatureManager::getExtensionName(Extension symbol) {
     return "SPV_AMD_shader_explicit_vertex_parameter";
   case Extension::GOOGLE_hlsl_functionality1:
     return "SPV_GOOGLE_hlsl_functionality1";
+  case Extension::GOOGLE_user_type:
+    return "SPV_GOOGLE_user_type";
   case Extension::NV_ray_tracing:
     return "SPV_NV_ray_tracing";
   default:

+ 5 - 0
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -86,6 +86,11 @@ bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) {
   // Variables and function parameters must have a pointer type.
   case spv::Op::OpFunctionParameter:
   case spv::Op::OpVariable: {
+    if (auto *var = dyn_cast<SpirvVariable>(instr)) {
+      if (var->hasBinding() && var->getHlslUserType().empty()) {
+        var->setHlslUserType(getHlslResourceTypeName(var->getAstResultType()));
+      }
+    }
     const SpirvType *pointerType =
         spvContext.getPointerType(resultType, instr->getStorageClass());
     instr->setResultType(pointerType);

+ 5 - 1
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -869,7 +869,7 @@ void SpirvBuilder::decorateIndex(SpirvInstruction *target, uint32_t index,
   module->addDecoration(decor);
 }
 
-void SpirvBuilder::decorateDSetBinding(SpirvInstruction *target,
+void SpirvBuilder::decorateDSetBinding(SpirvVariable *target,
                                        uint32_t setNumber,
                                        uint32_t bindingNumber) {
   const SourceLocation srcLoc = target->getSourceLocation();
@@ -879,6 +879,10 @@ void SpirvBuilder::decorateDSetBinding(SpirvInstruction *target,
 
   auto *binding = new (context) SpirvDecoration(
       srcLoc, target, spv::Decoration::Binding, {bindingNumber});
+
+  target->setDescriptorSetNo(setNumber);
+  target->setBindingNo(bindingNumber);
+
   module->addDecoration(binding);
 }
 

+ 4 - 2
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -226,7 +226,8 @@ SpirvDecoration::SpirvDecoration(SourceLocation loc,
 
 spv::Op SpirvDecoration::getDecorateOpcode(
     spv::Decoration decoration, const llvm::Optional<uint32_t> &memberIndex) {
-  if (decoration == spv::Decoration::HlslSemanticGOOGLE)
+  if (decoration == spv::Decoration::HlslSemanticGOOGLE ||
+      decoration == spv::Decoration::UserTypeGOOGLE)
     return memberIndex.hasValue() ? spv::Op::OpMemberDecorateStringGOOGLE
                                   : spv::Op::OpDecorateStringGOOGLE;
 
@@ -245,7 +246,8 @@ SpirvVariable::SpirvVariable(QualType resultType, SourceLocation loc,
                              spv::StorageClass sc, bool precise,
                              SpirvInstruction *initializerInst)
     : SpirvInstruction(IK_Variable, spv::Op::OpVariable, resultType, loc),
-      initializer(initializerInst) {
+      initializer(initializerInst), descriptorSet(-1), binding(-1),
+      hlslUserType("") {
   setStorageClass(sc);
   setPrecise(precise);
 }

+ 86 - 0
tools/clang/test/CodeGenSPIRV/decoration.user-type.hlsl

@@ -0,0 +1,86 @@
+// Run: %dxc -T ps_6_0 -E main -fspv-reflect
+
+// CHECK: OpDecorateString %a UserTypeGOOGLE "structuredbuffer"
+StructuredBuffer<float> a;
+// CHECK: OpDecorateString %b UserTypeGOOGLE "rwstructuredbuffer"
+RWStructuredBuffer<float> b;
+// CHECK: OpDecorateString %c UserTypeGOOGLE "appendstructuredbuffer"
+AppendStructuredBuffer<float> c;
+// CHECK: OpDecorateString %d UserTypeGOOGLE "consumestructuredbuffer"
+ConsumeStructuredBuffer<float> d;
+// CHECK: OpDecorateString %e UserTypeGOOGLE "texture1d"
+Texture1D<float> e;
+// CHECK: OpDecorateString %f UserTypeGOOGLE "texture2d"
+Texture2D<float> f;
+// CHECK: OpDecorateString %g UserTypeGOOGLE "texture3d"
+Texture3D<float> g;
+// CHECK: OpDecorateString %h UserTypeGOOGLE "texturecube"
+TextureCube<float> h;
+// CHECK: OpDecorateString %i UserTypeGOOGLE "texture1darray"
+Texture1DArray<float> i;
+// CHECK: OpDecorateString %j UserTypeGOOGLE "texture2darray"
+Texture2DArray<float> j;
+// CHECK: OpDecorateString %k UserTypeGOOGLE "texture2dms"
+Texture2DMS<float> k;
+// CHECK: OpDecorateString %l UserTypeGOOGLE "texture2dmsarray"
+Texture2DMSArray<float> l;
+// CHECK: OpDecorateString %m UserTypeGOOGLE "texturecubearray"
+TextureCubeArray<float> m;
+// CHECK: OpDecorateString %n UserTypeGOOGLE "rwtexture1d"
+RWTexture1D<float> n;
+// CHECK: OpDecorateString %o UserTypeGOOGLE "rwtexture2d"
+RWTexture2D<float> o;
+// CHECK: OpDecorateString %p UserTypeGOOGLE "rwtexture3d"
+RWTexture3D<float> p;
+// CHECK: OpDecorateString %q UserTypeGOOGLE "rwtexture1darray"
+RWTexture1DArray<float> q;
+// CHECK: OpDecorateString %r UserTypeGOOGLE "rwtexture2darray"
+RWTexture2DArray<float> r;
+// CHECK: OpDecorateString %s UserTypeGOOGLE "buffer"
+Buffer<float> s;
+// CHECK: OpDecorateString %t UserTypeGOOGLE "rwbuffer"
+RWBuffer<float> t;
+
+// CHECK: OpDecorateString %eArr UserTypeGOOGLE "texture1d"
+Texture1D<float> eArr[5];
+// CHECK: OpDecorateString %fArr UserTypeGOOGLE "texture2d"
+Texture2D<float> fArr[5];
+// CHECK: OpDecorateString %gArr UserTypeGOOGLE "texture3d"
+Texture3D<float> gArr[5];
+// CHECK: OpDecorateString %hArr UserTypeGOOGLE "texturecube"
+TextureCube<float> hArr[5];
+// CHECK: OpDecorateString %iArr UserTypeGOOGLE "texture1darray"
+Texture1DArray<float> iArr[5];
+// CHECK: OpDecorateString %jArr UserTypeGOOGLE "texture2darray"
+Texture2DArray<float> jArr[5];
+// CHECK: OpDecorateString %kArr UserTypeGOOGLE "texture2dms"
+Texture2DMS<float> kArr[5];
+// CHECK: OpDecorateString %lArr UserTypeGOOGLE "texture2dmsarray"
+Texture2DMSArray<float> lArr[5];
+// CHECK: OpDecorateString %mArr UserTypeGOOGLE "texturecubearray"
+TextureCubeArray<float> mArr[5];
+// CHECK: OpDecorateString %nArr UserTypeGOOGLE "rwtexture1d"
+RWTexture1D<float> nArr[5];
+// CHECK: OpDecorateString %oArr UserTypeGOOGLE "rwtexture2d"
+RWTexture2D<float> oArr[5];
+// CHECK: OpDecorateString %pArr UserTypeGOOGLE "rwtexture3d"
+RWTexture3D<float> pArr[5];
+// CHECK: OpDecorateString %qArr UserTypeGOOGLE "rwtexture1darray"
+RWTexture1DArray<float> qArr[5];
+// CHECK: OpDecorateString %rArr UserTypeGOOGLE "rwtexture2darray"
+RWTexture2DArray<float> rArr[5];
+// CHECK: OpDecorateString %sArr UserTypeGOOGLE "buffer"
+Buffer<float> sArr[5];
+// CHECK: OpDecorateString %tArr UserTypeGOOGLE "rwbuffer"
+RWBuffer<float> tArr[5];
+
+// CHECK: OpDecorateString %MyCBuffer UserTypeGOOGLE "cbuffer"
+cbuffer MyCBuffer { float x; };
+
+// CHECK: OpDecorateString %MyTBuffer UserTypeGOOGLE "tbuffer"
+tbuffer MyTBuffer { float y; };
+
+float4 main() : SV_Target{
+    return 0.0.xxxx;
+}
+

+ 5 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -1977,6 +1977,11 @@ TEST_F(FileTest, DecorationNoContractionStageVars) {
   runFileTest("decoration.no-contraction.stage-vars.hlsl");
 }
 
+// For UserTypeGOOGLE decorations
+TEST_F(FileTest, DecorationUserTypeGOOGLE) {
+  runFileTest("decoration.user-type.hlsl");
+}
+
 // For pragmas
 TEST_F(FileTest, PragmaPackMatrix) { runFileTest("pragma.pack_matrix.hlsl"); }