Browse Source

[spirv] Variable/Resource in namespaces. (#1208)

Ehsan 7 years ago
parent
commit
e24f4b7a65

+ 6 - 10
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -500,17 +500,16 @@ uint32_t DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
   const auto *blockDec = forTBuffer ? Decoration::getBufferBlock(context)
   const auto *blockDec = forTBuffer ? Decoration::getBufferBlock(context)
                                     : Decoration::getBlock(context);
                                     : Decoration::getBlock(context);
 
 
-  auto decorations = typeTranslator.getLayoutDecorations(decl, layoutRule);
+  const llvm::SmallVector<const Decl *, 4> &declGroup =
+      typeTranslator.collectDeclsInDeclContext(decl);
+  auto decorations = typeTranslator.getLayoutDecorations(declGroup, layoutRule);
   decorations.push_back(blockDec);
   decorations.push_back(blockDec);
 
 
   // Collect the type and name for each field
   // Collect the type and name for each field
   llvm::SmallVector<uint32_t, 4> fieldTypes;
   llvm::SmallVector<uint32_t, 4> fieldTypes;
   llvm::SmallVector<llvm::StringRef, 4> fieldNames;
   llvm::SmallVector<llvm::StringRef, 4> fieldNames;
   uint32_t fieldIndex = 0;
   uint32_t fieldIndex = 0;
-  for (const auto *subDecl : decl->decls()) {
-    if (TypeTranslator::shouldSkipInStructLayout(subDecl))
-      continue;
-
+  for (const auto *subDecl : declGroup) {
     // The field can only be FieldDecl (for normal structs) or VarDecl (for
     // The field can only be FieldDecl (for normal structs) or VarDecl (for
     // HLSLBufferDecls).
     // HLSLBufferDecls).
     assert(isa<VarDecl>(subDecl) || isa<FieldDecl>(subDecl));
     assert(isa<VarDecl>(subDecl) || isa<FieldDecl>(subDecl));
@@ -648,7 +647,7 @@ void DeclResultIdMapper::createGlobalsCBuffer(const VarDecl *var) {
   if (astDecls.count(var) != 0)
   if (astDecls.count(var) != 0)
     return;
     return;
 
 
-  const auto *context = var->getDeclContext();
+  const auto *context = var->getTranslationUnitDecl();
   const uint32_t globals = createStructOrStructArrayVarOfExplicitLayout(
   const uint32_t globals = createStructOrStructArrayVarOfExplicitLayout(
       context, /*arraySize*/ 0, ContextUsageKind::Globals, "type.$Globals",
       context, /*arraySize*/ 0, ContextUsageKind::Globals, "type.$Globals",
       "$Globals");
       "$Globals");
@@ -657,11 +656,8 @@ void DeclResultIdMapper::createGlobalsCBuffer(const VarDecl *var) {
                             nullptr, nullptr);
                             nullptr, nullptr);
 
 
   uint32_t index = 0;
   uint32_t index = 0;
-  for (const auto *decl : context->decls())
+  for (const auto *decl : typeTranslator.collectDeclsInDeclContext(context))
     if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
     if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
-      if (TypeTranslator::shouldSkipInStructLayout(varDecl))
-        continue;
-
       if (const auto *attr = varDecl->getAttr<VKBindingAttr>()) {
       if (const auto *attr = varDecl->getAttr<VKBindingAttr>()) {
         emitError("variable '%0' will be placed in $Globals so cannot have "
         emitError("variable '%0' will be placed in $Globals so cannot have "
                   "vk::binding attribute",
                   "vk::binding attribute",

+ 24 - 19
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -510,6 +510,29 @@ spv::Capability getCapabilityForGroupNonUniform(spv::Op opcode) {
   return spv::Capability::Max;
   return spv::Capability::Max;
 }
 }
 
 
+std::string getNamespacePrefix(const Decl* decl) {
+  std::string nsPrefix = "";
+  const DeclContext *dc = decl->getDeclContext();
+  while (dc && !dc->isTranslationUnit()) {
+    if (const NamespaceDecl *ns = dyn_cast<NamespaceDecl>(dc)) {
+      if (!ns->isAnonymousNamespace()) {
+        nsPrefix = ns->getName().str() + "::" + nsPrefix;
+      }
+    }
+    dc = dc->getParent();
+  }
+  return nsPrefix;
+}
+
+std::string getFnName(const FunctionDecl *fn) {
+  // Prefix the function name with the struct name if necessary
+  std::string classOrStructName = "";
+  if (const auto *memberFn = dyn_cast<CXXMethodDecl>(fn))
+    if (const auto *st = dyn_cast<CXXRecordDecl>(memberFn->getDeclContext()))
+      classOrStructName = st->getName().str() + ".";
+  return getNamespacePrefix(fn) + classOrStructName + fn->getName().str();
+}
+
 } // namespace
 } // namespace
 
 
 SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci, EmitSPIRVOptions &options)
 SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci, EmitSPIRVOptions &options)
@@ -979,23 +1002,9 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
   breakStack = std::stack<uint32_t>();
   breakStack = std::stack<uint32_t>();
   continueStack = std::stack<uint32_t>();
   continueStack = std::stack<uint32_t>();
 
 
-  std::string funcName = decl->getName();
-  std::string nsPrefix = "";
-
-  // Add namespace name as prefix of function name (if any).
-  const DeclContext *dc = decl->getEnclosingNamespaceContext();
-  while (dc && !dc->isTranslationUnit()) {
-    if (const NamespaceDecl *ns = dyn_cast<NamespaceDecl>(dc)) {
-      if (!ns->isAnonymousNamespace()) {
-        nsPrefix = ns->getName().str() + "::" + nsPrefix;
-      }
-    }
-    dc = dc->getParent();
-  }
-
   // This will allow the entry-point name to be something like
   // This will allow the entry-point name to be something like
   // myNamespace::myEntrypointFunc.
   // myNamespace::myEntrypointFunc.
-  funcName = nsPrefix + funcName;
+  std::string funcName = getFnName(decl);
 
 
   uint32_t funcId = 0;
   uint32_t funcId = 0;
 
 
@@ -1035,10 +1044,6 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
           theBuilder.getPointerType(valueType, spv::StorageClass::Function);
           theBuilder.getPointerType(valueType, spv::StorageClass::Function);
       paramTypes.push_back(ptrType);
       paramTypes.push_back(ptrType);
     }
     }
-
-    // Prefix the function name with the struct name
-    if (const auto *st = dyn_cast<CXXRecordDecl>(memberFn->getDeclContext()))
-      funcName = nsPrefix + st->getName().str() + "." + decl->getName().str();
   }
   }
 
 
   for (const auto *param : decl->params()) {
   for (const auto *param : decl->params()) {

+ 51 - 24
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -592,7 +592,7 @@ uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule) {
 
 
     llvm::SmallVector<const Decoration *, 4> decorations;
     llvm::SmallVector<const Decoration *, 4> decorations;
     if (rule != LayoutRule::Void) {
     if (rule != LayoutRule::Void) {
-      decorations = getLayoutDecorations(decl, rule);
+      decorations = getLayoutDecorations(collectDeclsInDeclContext(decl), rule);
     }
     }
 
 
     return theBuilder.getStructType(fieldTypes, decl->getName(), fieldNames,
     return theBuilder.getStructType(fieldTypes, decl->getName(), fieldNames,
@@ -1154,39 +1154,32 @@ bool TypeTranslator::shouldSkipInStructLayout(const Decl *decl) {
       decl->getDeclContext()->getLexicalParent()->isTranslationUnit())
       decl->getDeclContext()->getLexicalParent()->isTranslationUnit())
     return true;
     return true;
 
 
-  // For others we can check their DeclContext directly.
-  if (decl->getDeclContext()->isTranslationUnit()) {
-    // External visibility
-    if (const auto *declDecl = dyn_cast<DeclaratorDecl>(decl))
-      if (!declDecl->hasExternalFormalLinkage())
-        return true;
-
-    // cbuffer/tbuffer
-    if (isa<HLSLBufferDecl>(decl))
+  // External visibility
+  if (const auto *declDecl = dyn_cast<DeclaratorDecl>(decl))
+    if (!declDecl->hasExternalFormalLinkage())
       return true;
       return true;
 
 
-    // Other resource types
-    if (const auto *valueDecl = dyn_cast<ValueDecl>(decl))
-      if (isResourceType(valueDecl))
-        return true;
-  }
+  // cbuffer/tbuffer
+  if (isa<HLSLBufferDecl>(decl))
+    return true;
+
+  // Other resource types
+  if (const auto *valueDecl = dyn_cast<ValueDecl>(decl))
+    if (isResourceType(valueDecl))
+      return true;
 
 
   return false;
   return false;
 }
 }
 
 
-llvm::SmallVector<const Decoration *, 4>
-TypeTranslator::getLayoutDecorations(const DeclContext *decl, LayoutRule rule) {
+llvm::SmallVector<const Decoration *, 4> TypeTranslator::getLayoutDecorations(
+    const llvm::SmallVector<const Decl *, 4> &decls, LayoutRule rule) {
   const auto spirvContext = theBuilder.getSPIRVContext();
   const auto spirvContext = theBuilder.getSPIRVContext();
   llvm::SmallVector<const Decoration *, 4> decorations;
   llvm::SmallVector<const Decoration *, 4> decorations;
   uint32_t offset = 0, index = 0;
   uint32_t offset = 0, index = 0;
-
-  for (const auto *field : decl->decls()) {
-    if (shouldSkipInStructLayout(field))
-      continue;
-
+  for (const auto *decl : decls) {
     // The field can only be FieldDecl (for normal structs) or VarDecl (for
     // The field can only be FieldDecl (for normal structs) or VarDecl (for
     // HLSLBufferDecls).
     // HLSLBufferDecls).
-    const auto *declDecl = cast<DeclaratorDecl>(field);
+    const auto *declDecl = cast<DeclaratorDecl>(decl);
     auto fieldType = declDecl->getType();
     auto fieldType = declDecl->getType();
 
 
     uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
     uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
@@ -1205,7 +1198,7 @@ TypeTranslator::getLayoutDecorations(const DeclContext *decl, LayoutRule rule) {
       offset = roundToPow2(offset, memberAlignment);
       offset = roundToPow2(offset, memberAlignment);
 
 
     // The vk::offset attribute takes precedence over all.
     // The vk::offset attribute takes precedence over all.
-    if (const auto *offsetAttr = field->getAttr<VKOffsetAttr>()) {
+    if (const auto *offsetAttr = decl->getAttr<VKOffsetAttr>()) {
       offset = offsetAttr->getOffset();
       offset = offsetAttr->getOffset();
     }
     }
     // The :packoffset() annotation takes precedence over normal layout
     // The :packoffset() annotation takes precedence over normal layout
@@ -1266,6 +1259,40 @@ TypeTranslator::getLayoutDecorations(const DeclContext *decl, LayoutRule rule) {
   return decorations;
   return decorations;
 }
 }
 
 
+void TypeTranslator::collectDeclsInNamespace(
+    const NamespaceDecl *nsDecl, llvm::SmallVector<const Decl *, 4> *decls) {
+  for (const auto *decl : nsDecl->decls()) {
+    collectDeclsInField(decl, decls);
+  }
+}
+
+void TypeTranslator::collectDeclsInField(
+    const Decl *field, llvm::SmallVector<const Decl *, 4> *decls) {
+
+  // Case of nested namespaces.
+  if (const auto *nsDecl = dyn_cast<NamespaceDecl>(field)) {
+    collectDeclsInNamespace(nsDecl, decls);
+  }
+
+  if (shouldSkipInStructLayout(field))
+    return;
+
+  if (!isa<DeclaratorDecl>(field)) {
+    return;
+  }
+
+  (*decls).push_back(field);
+}
+
+const llvm::SmallVector<const Decl *, 4>
+TypeTranslator::collectDeclsInDeclContext(const DeclContext *declContext) {
+  llvm::SmallVector<const Decl *, 4> decls;
+  for (const auto *field : declContext->decls()) {
+    collectDeclsInField(field, &decls);
+  }
+  return decls;
+}
+
 uint32_t TypeTranslator::translateResourceType(QualType type, LayoutRule rule,
 uint32_t TypeTranslator::translateResourceType(QualType type, LayoutRule rule,
                                                bool isDepthCmp) {
                                                bool isDepthCmp) {
   // Resource types are either represented like C struct or C++ class in the
   // Resource types are either represented like C struct or C++ class in the

+ 22 - 4
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -257,18 +257,36 @@ public:
   static bool shouldSkipInStructLayout(const Decl *decl);
   static bool shouldSkipInStructLayout(const Decl *decl);
 
 
   /// \brief Generates layout decorations (Offset, MatrixStride, RowMajor,
   /// \brief Generates layout decorations (Offset, MatrixStride, RowMajor,
-  /// ColMajor) for the given type.
+  /// ColMajor) for the given decl group.
   ///
   ///
-  /// This method is not recursive; it only handles the top-level member/field
-  /// of the given DeclContext. Besides, it does not handle ArrayStride, which
+  /// This method is not recursive; it only handles the top-level members/fields
+  /// of the given Decl group. Besides, it does not handle ArrayStride, which
   /// according to the spec, must be attached to the array type itself instead
   /// according to the spec, must be attached to the array type itself instead
   /// of a struct member.
   /// of a struct member.
   llvm::SmallVector<const Decoration *, 4>
   llvm::SmallVector<const Decoration *, 4>
-  getLayoutDecorations(const DeclContext *decl, LayoutRule rule);
+  getLayoutDecorations(const llvm::SmallVector<const Decl *, 4> &declGroup,
+                       LayoutRule rule);
 
 
   /// \brief Returns how many sequential locations are consumed by a given type.
   /// \brief Returns how many sequential locations are consumed by a given type.
   uint32_t getLocationCount(QualType type);
   uint32_t getLocationCount(QualType type);
 
 
+  /// \brief Collects and returns all member/field declarations inside the given
+  /// DeclContext. If it sees a NamespaceDecl, it recursively dives in and
+  /// collects decls in the correct order.
+  /// Utilizes collectDeclsInNamespace and collectDeclsInField private methods.
+  const llvm::SmallVector<const Decl *, 4>
+  collectDeclsInDeclContext(const DeclContext *declContext);
+
+private:
+  /// \brief Appends any member/field decls found inside the given namespace
+  /// into the give decl vector.
+  void collectDeclsInNamespace(const NamespaceDecl *nsDecl,
+                               llvm::SmallVector<const Decl *, 4> *decls);
+
+  /// \brief Appends the given member/field decl into the given decl vector.
+  void collectDeclsInField(const Decl *field,
+                           llvm::SmallVector<const Decl *, 4> *decls);
+
 private:
 private:
   /// \brief Wrapper method to create an error message and report it
   /// \brief Wrapper method to create an error message and report it
   /// in the diagnostic engine associated with this consumer.
   /// in the diagnostic engine associated with this consumer.

+ 9 - 13
tools/clang/test/CodeGenSPIRV/namespace.functions.hlsl

@@ -71,29 +71,25 @@ float4 main(float4 PosCS : SV_Position) : SV_Target
 float3 A::B::AddBlue() { return float3(1, 1, 1); }
 float3 A::B::AddBlue() { return float3(1, 1, 1); }
 float3 A::AddGreen() { return float3(3, 3, 3); }
 float3 A::AddGreen() { return float3(3, 3, 3); }
 
 
-// CHECK: %AddRed = OpFunction %v3float None {{%\d+}}
+// CHECK: %AddRed = OpFunction %v3float None
 // CHECK: OpReturnValue [[v3f2]]
 // CHECK: OpReturnValue [[v3f2]]
 
 
-// CHECK: %A__AddRed = OpFunction %v3float None {{%\d+}}
+// CHECK: %A__AddRed = OpFunction %v3float None
 // CHECK: OpReturnValue [[v3f0]]
 // CHECK: OpReturnValue [[v3f0]]
 
 
-// CHECK: %A__B__AddRed = OpFunction %v3float None {{%\d+}}
+// CHECK: %A__B__AddRed = OpFunction %v3float None
 // CHECK: OpReturnValue [[v3f1]]
 // CHECK: OpReturnValue [[v3f1]]
 
 
-// CHECK: %A__B__AddBlue = OpFunction %v3float None {{%\d+}}
+// CHECK: %A__B__AddBlue = OpFunction %v3float None
 // CHECK: OpReturnValue [[v3f1]]
 // CHECK: OpReturnValue [[v3f1]]
 
 
-// CHECK: %A__AddGreen = OpFunction %v3float None {{%\d+}}
+// CHECK: %A__AddGreen = OpFunction %v3float None
 // CHECK: OpReturnValue [[v3f3]]
 // CHECK: OpReturnValue [[v3f3]]
 
 
 // TODO: struct name should also be updated to A::myStruct
 // TODO: struct name should also be updated to A::myStruct
-// CHECK: %A__createMyStruct = OpFunction %myStruct None {{%\d+}}
+// CHECK: %A__createMyStruct = OpFunction %myStruct None
 
 
-// CHECK: %A__myStruct_add = OpFunction %int None {{%\d+}}
+// CHECK: %A__myStruct_add = OpFunction %int None
 // CHECK: %param_this = OpFunctionParameter %_ptr_Function_myStruct
 // CHECK: %param_this = OpFunctionParameter %_ptr_Function_myStruct
-// CHECK: {{%\d+}} = OpAccessChain %_ptr_Function_int %param_this %int_0
-// CHECK: {{%\d+}} = OpLoad
-// CHECK: {{%\d+}} = OpAccessChain %_ptr_Function_int %param_this %int_1
-// CHECK: {{%\d+}} = OpLoad
-// CHECK: {{%\d+}} = OpIAdd
-// CHECK: OpReturnValue {{%\d+}}
+// CHECK: OpAccessChain %_ptr_Function_int %param_this %int_0
+// CHECK: OpAccessChain %_ptr_Function_int %param_this %int_1

+ 31 - 0
tools/clang/test/CodeGenSPIRV/namespace.globals.hlsl

@@ -0,0 +1,31 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// CHECK: OpMemberName %type__Globals 0 "a"
+// CHECK: OpMemberName %type__Globals 1 "b"
+// CHECK: OpMemberName %type__Globals 2 "c"
+// CHECK: OpName %_Globals "$Globals"
+
+// CHECK: OpDecorate %_Globals DescriptorSet 0
+// CHECK: OpDecorate %_Globals Binding 0
+
+// CHECK: %type__Globals = OpTypeStruct %int %int %int
+
+namespace A {
+  int a;
+
+  namespace B {
+    int b;
+  }  // end namespace B
+
+}  // end namespace A
+
+int c;
+
+float4 main(float4 PosCS : SV_Position) : SV_Target
+{
+// CHECK: OpAccessChain %_ptr_Uniform_int %_Globals %int_1
+// CHECK: OpAccessChain %_ptr_Uniform_int %_Globals %int_0
+// CHECK: OpAccessChain %_ptr_Uniform_int %_Globals %int_2
+  int newInt = A::B::b + A::a + c;
+  return float4(0,0,0,0);
+}

+ 44 - 0
tools/clang/test/CodeGenSPIRV/namespace.resources.hlsl

@@ -0,0 +1,44 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// CHECK: OpMemberDecorate %type_RWStructuredBuffer_v4float 0 Offset 0
+// CHECK: OpDecorate %type_RWStructuredBuffer_v4float BufferBlock
+
+// CHECK: OpMemberDecorate %type__Globals 0 Offset 0
+// CHECK: OpDecorate %type__Globals Block
+
+
+// CHECK: OpDecorate %rw1 DescriptorSet 0
+// CHECK: OpDecorate %rw1 Binding 0
+// CHECK: OpDecorate %counter_var_rw1 DescriptorSet 0
+// CHECK: OpDecorate %counter_var_rw1 Binding 1
+// CHECK: OpDecorate %rw2 DescriptorSet 0
+// CHECK: OpDecorate %rw2 Binding 2
+// CHECK: OpDecorate %counter_var_rw2 DescriptorSet 0
+// CHECK: OpDecorate %counter_var_rw2 Binding 3
+// CHECK: OpDecorate %rw3 DescriptorSet 0
+// CHECK: OpDecorate %rw3 Binding 4
+// CHECK: OpDecorate %counter_var_rw3 DescriptorSet 0
+// CHECK: OpDecorate %counter_var_rw3 Binding 5
+
+RWStructuredBuffer<float4> rw1;
+
+namespace A {
+  RWStructuredBuffer<float4> rw2;
+  
+  namespace B {
+    RWStructuredBuffer<float4> rw3;
+  }  // end namespace B
+
+}  // end namespace A
+
+// Check that resources are not added to the globals struct.
+// CHECK: %type__Globals = OpTypeStruct %int
+int c;
+
+float4 main(float4 PosCS : SV_Position) : SV_Target
+{
+// CHECK: OpAccessChain %_ptr_Uniform_v4float %rw1 %int_0 %uint_0
+// CHECK: OpAccessChain %_ptr_Uniform_v4float %rw2 %int_0 %uint_1
+// CHECK: OpAccessChain %_ptr_Uniform_v4float %rw3 %int_0 %uint_2
+  return rw1[0] + A::rw2[1] + A::B::rw3[2];
+}

+ 4 - 3
tools/clang/test/CodeGenSPIRV/type.struct.hlsl

@@ -57,12 +57,13 @@ void main() {
   S s;
   S s;
   T t;
   T t;
 
 
-// CHECK: %R = OpTypeStruct %v2float
-// CHECK: %r0 = OpVariable %_ptr_Function_R Function
+// CHECK: %_ptr_Function__struct_[[num]] = OpTypePointer Function %_struct_[[num]]
+
+// CHECK: %r0 = OpVariable %_ptr_Function__struct_[[num]] Function
   struct R {
   struct R {
     float2 rVal;
     float2 rVal;
   } r0;
   } r0;
 
 
-// CHECK: %r1 = OpVariable %_ptr_Function_R Function
+// CHECK: %r1 = OpVariable %_ptr_Function__struct_[[num]] Function
   R r1;
   R r1;
 }
 }

+ 6 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -1440,6 +1440,12 @@ TEST_F(FileTest, NonFpColMajorError) {
 TEST_F(FileTest, NamespaceFunctions) {
 TEST_F(FileTest, NamespaceFunctions) {
   runFileTest("namespace.functions.hlsl");
   runFileTest("namespace.functions.hlsl");
 }
 }
+TEST_F(FileTest, NamespaceGlobals) {
+  runFileTest("namespace.globals.hlsl");
+}
+TEST_F(FileTest, NamespaceResources) {
+  runFileTest("namespace.resources.hlsl");
+}
 
 
 // HS: for different Patch Constant Functions
 // HS: for different Patch Constant Functions
 TEST_F(FileTest, HullShaderPCFVoid) { runFileTest("hs.pcf.void.hlsl"); }
 TEST_F(FileTest, HullShaderPCFVoid) { runFileTest("hs.pcf.void.hlsl"); }