Browse Source

Generate fxc-compatible reflection type info for buffer members (#146)

* Add support for reflection on types of constant-buffer members.

This change performs a best-effort translation from the LLVM types and their DXIL annotations over to the representation exported by the current `ID3D12ShaderReflectionType` interface.

- Added missing implementation of `CShaderReflectionType` members that implement `ID3D12ShaderReflectionType`

- Added a new `CShaderReflectionType::Initialize()` that intializates the instance data from an `llvm::Type*` and its `DxilFieldAnnotation`
  - The main complexity here is a helper routine to extract the appropriate `D3D_SHADER_VARIABLE_TYPE` from an `llvm::StructType*`. The current code copies the logic used in `HLModule::IsHLSLObjectType()`; that bit of copy-paste programming may need to be refactored away.

- Call into the new logic where there had been a `TODO: create reflection type` comment before

- Know issues:
  - Leaks the `CShaderReflectionType` instances. It is probably best to just store the type of a variable directly in the `CShaderReflectionVariable`, but the existing declarations had an unowned pointer, so I expect a tiny bit of discussion is warranted about the right way to proceed.
  - Related: I did not implement the type equality test (`ID3D12ShaderReflectionType::IsEqual()`), because I don't expect it is needed by most users, and it adds complexity. A trivial implementation might store and then compare the underlying `llvm::Type` pointers, but it isn't clear to me whether the type-equality test here is supposed to consider additional properites like offsets for `struct` fields.

Fixes issue #134

* Address small-scope review feedback.

Thanks to @marcelolr for the comments. This change tries to address the simpler style/convention issues. The memory ownership fix will come next.

* Update DXBC reflection comparison tests to comapre type reflection data.

This is a small change. I added a `CompareType()` method to `DxilContainerTest` that can compare the fields of two `ID3D12ShaderReflectionType*`, and then invoked aht routine when comparing the members of reflection contstant buffers.

The comparison applies to all the simple data fields, but does *not* test:

- Anything related to inheritance, interfaces, sub-types, etc.
- Type equality tests (they aren't being implemented right now, and I don't know if most clients of the reflection API care about them)

* Fix computation of names for fxc-compatible reflection types.

Previously I'd used the name from the DXIL *field* annotation, which obviously isn't correct.
For now I just compute type names in a relatively ad hoc fashion, using the information I glean as I destructure the type.
I also don't try to be efficient, and just use `std::string` operations to concatenate on suffixes for matrices/vectors.

Not yet handled here:

- Arrays (do they get suffixed in DXBC reflection data?)
- Object types
- Any demangling required for user-defined `struct` types (is there any?)

* More fixes to fxc-compatible reflection for types.

- A matrix may show up as an LLVM array type, so we need to be careful when unwrapping outer array types to *not* count a matrix as an array.
- Be a little more overt about handling array layouts: include a note that the `Undefined` layout is being treated as column-major (the default).
- Correct name for `uint` type in reflection (apparently it reflects as `dword`)
- Remove any `struct.` prefix from user-define `struct` type names
- Apparently `struct` types are "scalar enough" that they get `Rows` and `Columns` set to 1. Need to double-check whether this should just apply to *all* non-matrix/vector types.
- Store field types with an indirection, so that type name (stored in a `std::string`) doesn't get copied/moved and thus invalidate the pointer we store in `m_Desc.Name`
- Special-case `void` so that it can have a correct name (not sure if this will matter in practice)

* Create fxc-compatible reflection types for `StructuredBuffer` element type

- This was another case where a shader reflection type was just being left `NULL`; it just didn't have a handy `TODO` comment calling it out. :)

* More fixes for fxc-compatible type reflection.

- Reflect scalar `uint` type name as `"dword"`, but not for vector/matrix (e.g., still use `"uint3"`)

- Supress checking of the `Offset` field in type reflection for element type of a structured buffer
  - Note: looking at the existing compiler output for structured buffers, there seems to be an existing (hopefully known) issue that the field offsets for `StructuredBuffer` elements don't match fxc

- Strip of prefixes from user-defined `struct` type names: `"struct."` and `"dx.alignment.legacy."`

- Try to emulate fxc behavior for computing the `Columns` field for reflection on a `struct` type
  - The behavior here doesn't seem to be documented on MSDN (it implies that `Columns` would be zero for a `struct` type)
  - From what I can tell, the desired value is something like "total number of scalar values (not counting objects) recursively contained in this `struct` type"

With these changes, the revamped `DxilContainerTest::ReflectionMatchesDXBC_`{`Checkin`,`Full`} tests pass.

* Implement memory management for fxc-compatible reflection types.

In previous changes the `CShaderReflectionType` instances were allowed to leak.
This change still heap allocates the types, but places `std::unique_ptr`s to them all into a field in the base shader reflection object.

The reason for storing the types together in one master list (rather than having, e.g., a `CShaderReflectionVariable` directly store a `unique_ptr` to its type) is to allow for the possibility of re-using identical types in cases where that is possible (e.g., two `cbuffer` fields that use the same `struct` type should be able to share the `CShaderReflectionType` instances for their nested fields).
Tim Foley 8 years ago
parent
commit
7aa3fbd04b
2 changed files with 611 additions and 16 deletions
  1. 546 16
      lib/HLSL/DxilContainerReflection.cpp
  2. 65 0
      tools/clang/unittests/HLSL/DxilContainerTest.cpp

+ 546 - 16
lib/HLSL/DxilContainerReflection.cpp

@@ -67,6 +67,7 @@ public:
 };
 
 class CShaderReflectionConstantBuffer;
+class CShaderReflectionType;
 class DxilShaderReflection : public ID3D12ShaderReflection {
 private:
   DXC_MICROCOM_REF_FIELD(m_dwRef)
@@ -80,6 +81,7 @@ private:
   std::vector<D3D12_SIGNATURE_PARAMETER_DESC>     m_OutputSignature;
   std::vector<D3D12_SIGNATURE_PARAMETER_DESC>     m_PatchConstantSignature;
   std::vector<std::unique_ptr<char[]>>            m_UpperCaseNames;
+  std::vector<std::unique_ptr<CShaderReflectionType>> m_Types;
   void CreateReflectionObjects();
   void SetCBufferUsage();
   void CreateReflectionObjectForResource(DxilResourceBase *R);
@@ -268,20 +270,22 @@ class CShaderReflectionType : public ID3D12ShaderReflectionType
 {
 protected:
   D3D12_SHADER_TYPE_DESC              m_Desc;
+  std::string                         m_Name;
   std::vector<StringRef>              m_MemberNames;
-  std::vector<CShaderReflectionType>  m_MemberTypes;
+  std::vector<CShaderReflectionType*> m_MemberTypes;
   CShaderReflectionType*              m_pSubType;
   CShaderReflectionType*              m_pBaseClass;
-  std::vector<CShaderReflectionType>  m_Interfaces;
+  std::vector<CShaderReflectionType*> m_Interfaces;
   ULONG_PTR                           m_Identity;
 
 public:
   // Internal
-  CShaderReflectionType();
-  ~CShaderReflectionType();
-
-  HRESULT Initialize(D3D11_INTERNALSHADER_RESOURCE_DEF *pResourceDef,
-                     BYTE *pBase, BYTE *pMax, BYTE *pRawTypeDef);
+  HRESULT Initialize(
+    DxilModule              &M,
+    llvm::Type              *type,
+    DxilFieldAnnotation     &typeAnnotation,
+    unsigned int            baseOffset,
+    std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes);
 
   // ID3D12ShaderReflectionType
   STDMETHOD(GetDesc)(D3D12_SHADER_TYPE_DESC *pDesc);
@@ -340,8 +344,12 @@ public:
     std::swap(m_Variables, other.m_Variables);
   }
 
-  void Initialize(DxilModule &M, DxilCBuffer &CB);
-  void InitializeStructuredBuffer(DxilModule &M, DxilResource &R);
+  void Initialize(DxilModule &M,
+                  DxilCBuffer &CB,
+                  std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes);
+  void InitializeStructuredBuffer(DxilModule &M,
+                                  DxilResource &R,
+                                  std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes);
   LPCSTR GetName() { return m_Desc.Name; }
 
   // ID3D12ShaderReflectionConstantBuffer
@@ -430,7 +438,494 @@ ID3D12ShaderReflectionConstantBuffer *CInvalidSRVariable::GetBuffer() {
   return &g_InvalidSRConstantBuffer;
 }
 
-void CShaderReflectionConstantBuffer::Initialize(DxilModule &M, DxilCBuffer &CB) {
+STDMETHODIMP CShaderReflectionType::GetDesc(D3D12_SHADER_TYPE_DESC *pDesc)
+{
+  if (!pDesc) return E_POINTER;
+  memcpy(pDesc, &m_Desc, sizeof(m_Desc));
+  return S_OK;
+}
+
+STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetMemberTypeByIndex(UINT Index)
+{
+  if (Index >= m_MemberTypes.size()) {
+    return &g_InvalidSRType;
+  }
+  return m_MemberTypes[Index];
+}
+
+STDMETHODIMP_(LPCSTR) CShaderReflectionType::GetMemberTypeName(UINT Index)
+{
+  if (Index >= m_MemberTypes.size()) {
+    return nullptr;
+  }
+  return (LPCSTR) m_MemberNames[Index].bytes_begin();
+}
+
+STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetMemberTypeByName(LPCSTR Name)
+{
+  UINT memberCount = m_Desc.Members;
+  for( UINT mm = 0; mm < memberCount; ++mm ) {
+    if( m_MemberNames[mm] == Name ) {
+      return m_MemberTypes[mm];
+    }
+  }
+  return nullptr;
+}
+
+STDMETHODIMP CShaderReflectionType::IsEqual(THIS_ ID3D12ShaderReflectionType* pType)
+{
+  // TODO: implement this check, if users actually depend on it
+  return S_FALSE;
+}
+
+STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetSubType(THIS)
+{
+  // TODO: implement `class`-related features, if requested
+  return nullptr;
+}
+
+STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetBaseClass(THIS)
+{
+  // TODO: implement `class`-related features, if requested
+  return nullptr;
+}
+
+STDMETHODIMP_(UINT) CShaderReflectionType::GetNumInterfaces(THIS)
+{
+  // HLSL interfaces have been deprecated
+  return 0;
+}
+
+STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetInterfaceByIndex(THIS_ UINT uIndex)
+{
+  // HLSL interfaces have been deprecated
+  return nullptr;
+}
+
+STDMETHODIMP CShaderReflectionType::IsOfType(THIS_ ID3D12ShaderReflectionType* pType)
+{
+  // TODO: implement `class`-related features, if requested
+  return S_FALSE;
+}
+
+STDMETHODIMP CShaderReflectionType::ImplementsInterface(THIS_ ID3D12ShaderReflectionType* pBase)
+{
+  // HLSL interfaces have been deprecated
+  return S_FALSE;
+}
+
+// Helper routine for types that don't have an obvious mapping
+// to the existing shader reflection interface.
+static bool ProcessUnhandledObjectType(
+  llvm::StructType            *structType,
+  D3D_SHADER_VARIABLE_TYPE    *outObjectType)
+{
+  // Don't actually make this a hard error, but instead report the problem using a suitable debug message.
+#ifdef DBG
+  OutputDebugFormatA("DxilContainerReflection.cpp: error: unhandled object type '%s'.\n", structType->getName().str().c_str());
+#endif
+  *outObjectType = D3D_SVT_VOID;
+  return true;
+}
+
+// Helper routine to try to detect if a type represents an HLSL "object" type
+// (a texture, sampler, buffer, etc.), and to extract the coresponding shader
+// reflection type.
+static bool TryToDetectObjectType(
+  llvm::StructType            *structType,
+  D3D_SHADER_VARIABLE_TYPE    *outObjectType)
+{
+  // Note: This logic is largely duplicated from `HLModule::IsHLSLObjectType`
+  // with the addition of returning the appropriate reflection type tag.
+  //
+  // That logic looks error-prone, since it relies on string tests against
+  // type names, including cases that just test against a prefix.
+  // This code doesn't try to be any more robust.
+
+  StringRef name = structType->getName();
+
+  if(name.startswith("dx.types.wave_t") )
+  {
+    return ProcessUnhandledObjectType(structType, outObjectType);
+  }
+
+  // Strip off some prefixes we are likely to see.
+  name = name.ltrim("class.");
+  name = name.ltrim("struct.");
+
+  // Slice types occur as intermediates (they aren not objects)
+  if(name.endswith("_slice_type")) { return false; }
+
+  // We might check for an exact name match, or a prefix match
+#define EXACT_MATCH(NAME, TAG) \
+  else if(name == #NAME) do { *outObjectType = TAG; return true; } while(0)
+#define PREFIX_MATCH(NAME, TAG) \
+  else if(name.startswith(#NAME)) do { *outObjectType = TAG; return true; } while(0)
+
+  if(0) {}
+  EXACT_MATCH(SamplerState,               D3D_SVT_SAMPLER);
+  EXACT_MATCH(SamplerComparisonState,     D3D_SVT_SAMPLER);
+
+  // Note: GS output stream types are supported in the reflection interface.
+  else if(name.startswith("TriangleStream"))    { return ProcessUnhandledObjectType(structType, outObjectType); }
+  else if(name.startswith("PointStream"))       { return ProcessUnhandledObjectType(structType, outObjectType); }
+  else if(name.startswith("LineStream"))        { return ProcessUnhandledObjectType(structType, outObjectType); }
+
+  PREFIX_MATCH(AppendStructuredBuffer,    D3D_SVT_APPEND_STRUCTURED_BUFFER);
+  PREFIX_MATCH(ConsumeStructuredBuffer,   D3D_SVT_CONSUME_STRUCTURED_BUFFER);
+  PREFIX_MATCH(ConstantBuffer,            D3D_SVT_CBUFFER);
+
+  // Note: the `HLModule` code does this trick to avoid checking more names
+  // than it has to, but it doesn't seem 100% correct to do this.
+  // TODO: consider just listing the `RasterizerOrdered` cases explicitly,
+  // just as we do for the `RW` cases already.
+  name = name.ltrim("RasterizerOrdered");
+
+  if(0) {}
+  EXACT_MATCH(ByteAddressBuffer,          D3D_SVT_BYTEADDRESS_BUFFER);
+  EXACT_MATCH(RWByteAddressBuffer,        D3D_SVT_RWBYTEADDRESS_BUFFER);
+  PREFIX_MATCH(Buffer,                    D3D_SVT_BUFFER);
+  PREFIX_MATCH(RWBuffer,                  D3D_SVT_RWBUFFER);
+  PREFIX_MATCH(StructuredBuffer,          D3D_SVT_STRUCTURED_BUFFER);
+  PREFIX_MATCH(RWStructuredBuffer,        D3D_SVT_RWSTRUCTURED_BUFFER);
+  PREFIX_MATCH(Texture1D,                 D3D_SVT_TEXTURE1D);
+  PREFIX_MATCH(RWTexture1D,               D3D_SVT_RWTEXTURE1D);
+  PREFIX_MATCH(Texture1DArray,            D3D_SVT_TEXTURE1DARRAY);
+  PREFIX_MATCH(RWTexture1DArray,          D3D_SVT_RWTEXTURE1DARRAY);
+  PREFIX_MATCH(Texture2D,                 D3D_SVT_TEXTURE2D);
+  PREFIX_MATCH(RWTexture2D,               D3D_SVT_RWTEXTURE2D);
+  PREFIX_MATCH(Texture2DArray,            D3D_SVT_TEXTURE2DARRAY);
+  PREFIX_MATCH(RWTexture2DArray,          D3D_SVT_RWTEXTURE2DARRAY);
+  PREFIX_MATCH(Texture3D,                 D3D_SVT_TEXTURE3D);
+  PREFIX_MATCH(RWTexture3D,               D3D_SVT_RWTEXTURE3D);
+  PREFIX_MATCH(TextureCube,               D3D_SVT_TEXTURECUBE);
+  PREFIX_MATCH(TextureCubeArray,          D3D_SVT_TEXTURECUBEARRAY);
+  PREFIX_MATCH(Texture2DMS,               D3D_SVT_TEXTURE2DMS);
+  PREFIX_MATCH(Texture2DMSArray,          D3D_SVT_TEXTURE2DMSARRAY);
+
+#undef EXACT_MATCH
+#undef PREFIX_MATCH
+
+  // Default: not an object type
+  return false;
+}
+
+// Helper to determine if an LLVM type represents an HLSL
+// object type (uses the `TryToDetectObjectType()` function
+// defined previously).
+static bool IsObjectType(
+  llvm::Type* inType)
+{
+  llvm::Type* type = inType;
+  while(type->isArrayTy())
+  {
+    type = type->getArrayElementType();
+  }
+
+  llvm::StructType* structType = dyn_cast<StructType>(type);
+  if(!structType)
+    return false;
+
+  D3D_SHADER_VARIABLE_TYPE ignored;
+  return TryToDetectObjectType(structType, &ignored);
+}
+
+// Main logic for translating an LLVM type and associated
+// annotations into a D3D shader reflection type.
+HRESULT CShaderReflectionType::Initialize(
+  DxilModule              &M,
+  llvm::Type              *inType,
+  DxilFieldAnnotation     &typeAnnotation,
+  unsigned int            baseOffset,
+  std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes)
+{
+  DXASSERT_NOMSG(inType);
+
+  // Set a bunch of fields to default values, to avoid duplication.
+  m_Desc.Rows = 0;
+  m_Desc.Columns = 0;
+  m_Desc.Elements = 0;
+  m_Desc.Members = 0;
+
+  // Extract offset relative to parent.
+  // Note: the `baseOffset` is used in the case where the type in
+  // question is a field in a constant buffer, since then both the
+  // field and the variable store the same offset information, and
+  // we need to zero out the value in the type to avoid the user
+  // of the reflection interface seeing 2x the correct value.
+  m_Desc.Offset = typeAnnotation.GetCBufferOffset() - baseOffset;
+
+  // Arrays don't seem to be represented directly in the reflection
+  // data, but only as the `Elements` field being non-zero.
+  // We "unwrap" any array type here, and then proceed to look
+  // at the element type.
+  llvm::Type* type = inType;
+
+  while(type->isArrayTy())
+  {
+    llvm::Type* elementType = type->getArrayElementType();
+
+    // Note: At this point an HLSL matrix type may appear as an ordinary
+    // array (not wrapped in a `struct`), so `HLMatrixLower::IsMatrixType()`
+    // is not sufficient. Instead we need to check the field annotation.
+    //
+    // We might have an array of matrices, though, so we only exit if
+    // the field annotation says we have a matrix, and we've bottomed
+    // out and the element type isn't itself an array.
+    if(typeAnnotation.HasMatrixAnnotation() && !elementType->isArrayTy())
+    {
+      break;
+    }
+
+    // Non-array types should have `Elements` be zero, so as soon as we
+    // find that we have our first real array (not a matrix), we initialize `Elements`
+    if(!m_Desc.Elements) m_Desc.Elements = 1;
+
+    // It isn't clear what is the desired behavior for multi-dimensional arrays,
+    // but for now we do the expedient thing of multiplying out all their
+    // dimensions.
+    m_Desc.Elements *= type->getArrayNumElements();
+    type = elementType;
+  }
+
+  // Default to a scalar type, just to avoid some duplication later.
+  m_Desc.Class = D3D_SVC_SCALAR;
+
+  // Look at the annotation to try to determine the basic type of value.
+  //
+  // Note that DXIL supports some types that don't currently have equivalents
+  // in the reflection interface, so we try to muddle through here.
+  D3D_SHADER_VARIABLE_TYPE componentType = D3D_SVT_VOID;
+  switch(typeAnnotation.GetCompType().GetKind())
+  {
+  case hlsl::DXIL::ComponentType::Invalid:
+    break;
+
+  case hlsl::DXIL::ComponentType::I1:
+    componentType = D3D_SVT_BOOL;
+    m_Name = "bool";
+    break;
+
+  case hlsl::DXIL::ComponentType::I16:
+    componentType = D3D_SVT_MIN16INT;
+    m_Name = "min16int";
+    break;
+
+  case hlsl::DXIL::ComponentType::U16:
+    componentType = D3D_SVT_MIN16UINT;
+    m_Name = "min16uint";
+    break;
+
+  case hlsl::DXIL::ComponentType::I64:
+#ifdef DBG
+    OutputDebugStringA("DxilContainerReflection.cpp: warning: component of type 'I64' being reflected as if 'I32'\n");
+#endif
+  case hlsl::DXIL::ComponentType::I32:
+    componentType = D3D_SVT_INT;
+    m_Name = "int";
+    break;
+
+  case hlsl::DXIL::ComponentType::U64:
+#ifdef DBG
+    OutputDebugStringA("DxilContainerReflection.cpp: warning: component of type 'U64' being reflected as if 'U32'\n");
+#endif
+  case hlsl::DXIL::ComponentType::U32:
+    componentType = D3D_SVT_UINT;
+    m_Name = "uint";
+    break;
+
+  case hlsl::DXIL::ComponentType::F16:
+  case hlsl::DXIL::ComponentType::SNormF16:
+  case hlsl::DXIL::ComponentType::UNormF16:
+    componentType = D3D_SVT_MIN16FLOAT;
+    m_Name = "min16float";
+    break;
+
+  case hlsl::DXIL::ComponentType::F32:
+  case hlsl::DXIL::ComponentType::SNormF32:
+  case hlsl::DXIL::ComponentType::UNormF32:
+    componentType = D3D_SVT_FLOAT;
+    m_Name = "float";
+    break;
+
+  case hlsl::DXIL::ComponentType::F64:
+  case hlsl::DXIL::ComponentType::SNormF64:
+  case hlsl::DXIL::ComponentType::UNormF64:
+    componentType = D3D_SVT_DOUBLE;
+    m_Name = "double";
+    break;
+
+  default:
+#ifdef DBG
+    OutputDebugStringA("DxilContainerReflection.cpp: error: unknown component type\n");
+#endif
+    break;
+  }
+  m_Desc.Type = componentType;
+
+  // A matrix type is encoded as a vector type, plus annotations, so we
+  // need to check for this case before other vector cases.
+  if(typeAnnotation.HasMatrixAnnotation())
+  {
+    // We can extract the details from the annotation.
+    DxilMatrixAnnotation const& matrixAnnotation = typeAnnotation.GetMatrixAnnotation();
+
+    switch(matrixAnnotation.Orientation)
+    {
+    default:
+#ifdef DBG
+      OutputDebugStringA("DxilContainerReflection.cpp: error: unknown matrix orientation\n");
+#endif
+    // Note: column-major layout is the default
+    case hlsl::MatrixOrientation::Undefined:
+    case hlsl::MatrixOrientation::ColumnMajor:
+      m_Desc.Class = D3D_SVC_MATRIX_COLUMNS;
+      break;
+
+    case hlsl::MatrixOrientation::RowMajor:
+      m_Desc.Class = D3D_SVC_MATRIX_ROWS;
+      break;
+    }
+
+    m_Desc.Rows = matrixAnnotation.Rows;
+    m_Desc.Columns = matrixAnnotation.Cols;
+    m_Name += std::to_string(matrixAnnotation.Rows) + "x" + std::to_string(matrixAnnotation.Cols);
+  }
+  else if( type->isVectorTy() )
+  {
+    // We assume that LLVM vectors either represent matrices (handled above)
+    // or HLSL vectors.
+    //
+    // Note: the reflection interface encodes an N-vector as if it had 1 row
+    // and N columns.
+    m_Desc.Class = D3D_SVC_VECTOR;
+    m_Desc.Rows = 1;
+    m_Desc.Columns = type->getVectorNumElements();
+
+    m_Name += std::to_string(type->getVectorNumElements());
+  }
+  else if( type->isStructTy() )
+  {
+    // A struct type might be an ordinary user-defined `struct`,
+    // or one of the builtin in HLSL "object" types.
+    StructType *structType = cast<StructType>(type);
+
+    // We use our function to try to detect an object type
+    // based on its name.
+    if(TryToDetectObjectType(structType, &m_Desc.Type))
+    {
+      m_Desc.Class = D3D_SVC_OBJECT;
+    }
+    else
+    {
+      // Otherwise we have a struct and need to recurse on its fields.
+      m_Desc.Class = D3D_SVC_STRUCT;
+      m_Desc.Rows = 1;
+
+      // Try to "clean" the type name for use in reflection data
+      llvm::StringRef name = structType->getName();
+      name = name.ltrim("dx.alignment.legacy.");
+      name = name.ltrim("struct.");
+      m_Name = name;
+
+      unsigned int fieldCount = type->getStructNumElements();
+
+      // Fields may have annotations, and we need to look at these
+      // in order to decode their types properly.
+      DxilTypeSystem &typeSys = M.GetTypeSystem();
+      DxilStructAnnotation *structAnnotation = typeSys.GetStructAnnotation(structType);
+      DXASSERT(structAnnotation, "else type system is missing annotations for user-defined struct");
+
+      // The DXBC reflection info computes `Columns` for a
+      // `struct` type from the fields (see below)
+      UINT columnCounter = 0;
+
+      for(unsigned int ff = 0; ff < fieldCount; ++ff)
+      {
+        DxilFieldAnnotation& fieldAnnotation = structAnnotation->GetFieldAnnotation(ff);
+        llvm::Type* fieldType = structType->getStructElementType(ff);
+
+        // Skip fields with object types, since applications may not expect to see them here.
+        //
+        // TODO: should skipping be context-dependent, since we might not be inside
+        // a constant buffer?
+        if( IsObjectType(fieldType) )
+        {
+          continue;
+        }
+
+        CShaderReflectionType *fieldReflectionType = new CShaderReflectionType();
+        allTypes.push_back(std::unique_ptr<CShaderReflectionType>(fieldReflectionType));
+
+        fieldReflectionType->Initialize(M, fieldType, fieldAnnotation, 0, allTypes);
+
+        m_MemberTypes.push_back(fieldReflectionType);
+        m_MemberNames.push_back(fieldAnnotation.GetFieldName().c_str());
+
+        // Effectively, we want to add one to `Columns` for every scalar nested recursively
+        // inside this `struct` type (ignoring objects, which we filtered above). We should
+        // be able to compute this as the product of the `Columns`, `Rows` and `Elements`
+        // of each field, with the caveat that some of these may be zero, but shoud be
+        // treated as one.
+        columnCounter +=
+            (fieldReflectionType->m_Desc.Columns  ? fieldReflectionType->m_Desc.Columns  : 1)
+          * (fieldReflectionType->m_Desc.Rows     ? fieldReflectionType->m_Desc.Rows     : 1)
+          * (fieldReflectionType->m_Desc.Elements ? fieldReflectionType->m_Desc.Elements : 1);
+      }
+
+      m_Desc.Columns = columnCounter;
+
+      // Because we might have skipped fields during enumeration,
+      // the `Members` count in the description might not be the same
+      // as the field count of the original LLVM type.
+      m_Desc.Members = m_MemberTypes.size();
+    }
+  }
+  else if( type->isPointerTy() )
+  {
+#ifdef DBG
+      OutputDebugStringA("DxilContainerReflection.cpp: error: cannot reflect pointer type\n");
+#endif
+  }
+  else if( type->isVoidTy() )
+  {
+    // Name for `void` wasn't handle in the component-type `switch` above
+    m_Name = "void";
+    m_Desc.Class = D3D_SVC_SCALAR;
+    m_Desc.Rows = 1;
+    m_Desc.Columns = 1;
+  }
+  else
+  {
+    // Assume we have a scalar at this point.
+    m_Desc.Class = D3D_SVC_SCALAR;
+    m_Desc.Rows = 1;
+    m_Desc.Columns = 1;
+
+    // Special-case naming
+    switch(m_Desc.Type)
+    {
+    default:
+      break;
+
+    case D3D_SVT_UINT:
+      // Scalar `uint` gets reflected as `dword`, while vectors/matrices use `uint`...
+      m_Name = "dword";
+      break;
+    }
+  }
+  // TODO: are there other cases to be handled?
+
+  m_Desc.Name = m_Name.c_str();
+
+  return S_OK;
+}
+
+
+void CShaderReflectionConstantBuffer::Initialize(
+  DxilModule &M,
+  DxilCBuffer &CB,
+  std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes) {
   ZeroMemory(&m_Desc, sizeof(m_Desc));
   m_Desc.Name = CB.GetGlobalName().c_str();
   m_Desc.Size = CB.GetSize() / CB.GetRangeSize();
@@ -461,8 +956,11 @@ void CShaderReflectionConstantBuffer::Initialize(DxilModule &M, DxilCBuffer &CB)
     ZeroMemory(&VarDesc, sizeof(VarDesc));
     VarDesc.uFlags |= D3D_SVF_USED; // Will update in SetCBufferUsage.
     CShaderReflectionVariable Var;
-    // TODO: create reflection type.
-    CShaderReflectionType *pVarType = nullptr;
+    //Create reflection type.
+    CShaderReflectionType *pVarType = new CShaderReflectionType();
+    allTypes.push_back(std::unique_ptr<CShaderReflectionType>(pVarType));
+    pVarType->Initialize(M, ST->getContainedType(i), fieldAnnotation, fieldAnnotation.GetCBufferOffset(), allTypes);
+
     BYTE *pDefaultValue = nullptr;
 
     VarDesc.Name = fieldAnnotation.GetFieldName().c_str();
@@ -515,7 +1013,10 @@ static unsigned CalcResTypeSize(DxilModule &M, DxilResource &R) {
   return CalcTypeSize(Ty);
 }
 
-void CShaderReflectionConstantBuffer::InitializeStructuredBuffer(DxilModule &M, DxilResource &R) {
+void CShaderReflectionConstantBuffer::InitializeStructuredBuffer(
+  DxilModule &M,
+  DxilResource &R,
+  std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes) {
   ZeroMemory(&m_Desc, sizeof(m_Desc));
   m_Desc.Name = R.GetGlobalName().c_str();
   //m_Desc.Size = R.GetSize();
@@ -532,6 +1033,35 @@ void CShaderReflectionConstantBuffer::InitializeStructuredBuffer(DxilModule &M,
   VarDesc.uFlags |= D3D_SVF_USED; // TODO: not necessarily true
   CShaderReflectionVariable Var;
   CShaderReflectionType *pVarType = nullptr;
+
+  // Create reflection type, if we have the necessary annotation info
+
+  // Extract the `struct` that wraps element type of the buffer resource
+  Constant *GV = R.GetGlobalSymbol();
+  Type *Ty = GV->getType()->getPointerElementType();
+  if(Ty->isArrayTy())
+      Ty = Ty->getArrayElementType();
+  StructType *ST = cast<StructType>(Ty);
+
+  // Look up struct type annotation on the element type
+  DxilTypeSystem &typeSys = M.GetTypeSystem();
+  DxilStructAnnotation *annotation =
+    typeSys.GetStructAnnotation(cast<StructType>(ST));
+
+  // Dxil from dxbc doesn't have annotation.
+  if(annotation)
+  {
+    // Actually create the reflection type.
+    pVarType = new CShaderReflectionType();
+    allTypes.push_back(std::unique_ptr<CShaderReflectionType>(pVarType));
+
+    // The user-visible element type is the first field of the wrapepr `struct`
+    Type *fieldType = ST->getElementType(0);
+    DxilFieldAnnotation &fieldAnnotation = annotation->GetFieldAnnotation(0);
+
+    pVarType->Initialize(M, fieldType, fieldAnnotation, fieldAnnotation.GetCBufferOffset(), allTypes);
+  }
+
   BYTE *pDefaultValue = nullptr;
   Var.Initialize(this, &VarDesc, pVarType, pDefaultValue);
   m_Variables.push_back(Var);
@@ -884,7 +1414,7 @@ void DxilShaderReflection::CreateReflectionObjects() {
   // Create constant buffers, resources and signatures.
   for (auto && cb : m_pDxilModule->GetCBuffers()) {
     CShaderReflectionConstantBuffer rcb;
-    rcb.Initialize(*m_pDxilModule, *(cb.get()));
+    rcb.Initialize(*m_pDxilModule, *(cb.get()), m_Types);
     m_CBs.push_back(std::move(rcb));
   }
   // Set cbuf usage.
@@ -896,7 +1426,7 @@ void DxilShaderReflection::CreateReflectionObjects() {
       continue;
     }
     CShaderReflectionConstantBuffer rcb;
-    rcb.InitializeStructuredBuffer(*m_pDxilModule, *(uav.get()));
+    rcb.InitializeStructuredBuffer(*m_pDxilModule, *(uav.get()), m_Types);
     m_CBs.push_back(std::move(rcb));
   }
   for (auto && srv : m_pDxilModule->GetSRVs()) {
@@ -904,7 +1434,7 @@ void DxilShaderReflection::CreateReflectionObjects() {
       continue;
     }
     CShaderReflectionConstantBuffer rcb;
-    rcb.InitializeStructuredBuffer(*m_pDxilModule, *(srv.get()));
+    rcb.InitializeStructuredBuffer(*m_pDxilModule, *(srv.get()), m_Types);
     m_CBs.push_back(std::move(rcb));
   }
 

+ 65 - 0
tools/clang/unittests/HLSL/DxilContainerTest.cpp

@@ -140,6 +140,42 @@ public:
     VERIFY_ARE_EQUAL(pTestDesc->SystemValueType, pBaseDesc->SystemValueType);
   }
 
+  void CompareType(ID3D12ShaderReflectionType *pTest,
+                   ID3D12ShaderReflectionType *pBase,
+                   bool shouldSuppressOffsetChecks = false)
+  {
+    D3D12_SHADER_TYPE_DESC testDesc, baseDesc;
+    VERIFY_SUCCEEDED(pTest->GetDesc(&testDesc));
+    VERIFY_SUCCEEDED(pBase->GetDesc(&baseDesc));
+
+    VERIFY_ARE_EQUAL(testDesc.Class,    baseDesc.Class);
+    VERIFY_ARE_EQUAL(testDesc.Type,     baseDesc.Type);
+    VERIFY_ARE_EQUAL(testDesc.Rows,     baseDesc.Rows);
+    VERIFY_ARE_EQUAL(testDesc.Columns,  baseDesc.Columns);
+    VERIFY_ARE_EQUAL(testDesc.Elements, baseDesc.Elements);
+    VERIFY_ARE_EQUAL(testDesc.Members,  baseDesc.Members);
+
+    if(!shouldSuppressOffsetChecks)
+    {
+      VERIFY_ARE_EQUAL(testDesc.Offset,   baseDesc.Offset);
+    }
+
+    VERIFY_ARE_EQUAL(0, strcmp(testDesc.Name, baseDesc.Name));
+
+    for (UINT i = 0; i < baseDesc.Members; ++i) {
+      ID3D12ShaderReflectionType* testMemberType = pTest->GetMemberTypeByIndex(i);
+      ID3D12ShaderReflectionType* baseMemberType = pBase->GetMemberTypeByIndex(i);
+      VERIFY_IS_NOT_NULL(testMemberType);
+      VERIFY_IS_NOT_NULL(baseMemberType);
+
+      CompareType(testMemberType, baseMemberType, shouldSuppressOffsetChecks);
+
+      LPCSTR testMemberName = pTest->GetMemberTypeName(i);
+      LPCSTR baseMemberName = pBase->GetMemberTypeName(i);
+      VERIFY_ARE_EQUAL(0, strcmp(testMemberName, baseMemberName));
+    }
+  }
+
   typedef HRESULT (__stdcall ID3D12ShaderReflection::*GetParameterDescFn)(UINT, D3D12_SIGNATURE_PARAMETER_DESC*);
 
   void SortNameIdxVector(std::vector<std::tuple<LPCSTR, UINT, UINT>> &value) {
@@ -204,12 +240,17 @@ public:
         VERIFY_ARE_EQUAL(testCB.uFlags, baseCB.uFlags);
 
         llvm::StringMap<D3D12_SHADER_VARIABLE_DESC> variableMap;
+        llvm::StringMap<ID3D12ShaderReflectionType*> variableTypeMap;
         for (UINT vi = 0; vi < testCB.Variables; ++vi) {
           ID3D12ShaderReflectionVariable *pBaseConst;
           D3D12_SHADER_VARIABLE_DESC baseConst;
           pBaseConst = pBaseCB->GetVariableByIndex(vi);
           VERIFY_SUCCEEDED(pBaseConst->GetDesc(&baseConst));
           variableMap[baseConst.Name] = baseConst;
+
+          ID3D12ShaderReflectionType* pBaseType = pBaseConst->GetType();
+          VERIFY_IS_NOT_NULL(pBaseType);
+          variableTypeMap[baseConst.Name] = pBaseType;
         }
         for (UINT vi = 0; vi < testCB.Variables; ++vi) {
           ID3D12ShaderReflectionVariable *pTestConst;
@@ -222,6 +263,30 @@ public:
           VERIFY_ARE_EQUAL(testConst.StartOffset, baseConst.StartOffset);
           // TODO: enalbe size cmp.
           //VERIFY_ARE_EQUAL(testConst.Size, baseConst.Size);
+
+          ID3D12ShaderReflectionType* pTestType = pTestConst->GetType();
+          VERIFY_IS_NOT_NULL(pTestType);
+          VERIFY_ARE_EQUAL(variableTypeMap.count(testConst.Name), 1);
+          ID3D12ShaderReflectionType* pBaseType = variableTypeMap[testConst.Name];
+
+          // Note: we suppress comparing offsets for structured buffers, because dxc and fxc don't
+          // seem to agree in that case.
+          //
+          // The information in the `D3D12_SHADER_BUFFER_DESC` doesn't give us enough to
+          // be able to isolate structured buffers, so we do the test negatively: suppress
+          // offset checks *unless* we are looking at a `cbuffer` or `tbuffer`.
+          bool shouldSuppressOffsetChecks = true;
+          switch( baseCB.Type )
+          {
+          default:
+            break;
+
+          case D3D_CT_CBUFFER:
+          case D3D_CT_TBUFFER:
+            shouldSuppressOffsetChecks = false;
+            break;
+          }
+          CompareType(pTestType, pBaseType, shouldSuppressOffsetChecks);
         }
       }
     }