Browse Source

[spirv] More support for matrices and structs in init lists (#548)

Now we can decompose matrices appeared in init lists. Also we can
decompose and construct structs from init lists.
Lei Zhang 8 years ago
parent
commit
35138b401c

+ 122 - 17
tools/clang/lib/SPIRV/InitListHandler.cpp

@@ -13,6 +13,9 @@
 
 
 #include "InitListHandler.h"
 #include "InitListHandler.h"
 
 
+#include <algorithm>
+#include <iterator>
+
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVector.h"
 
 
 namespace clang {
 namespace clang {
@@ -28,12 +31,18 @@ uint32_t InitListHandler::process(const InitListExpr *expr) {
   scalars.clear();
   scalars.clear();
 
 
   flatten(expr);
   flatten(expr);
+  // Reverse the whole initializer list so we can manipulate the list at the
+  // tail of the vector. This is more efficient than using a deque.
+  std::reverse(std::begin(initializers), std::end(initializers));
 
 
   const uint32_t init = createInitForType(expr->getType());
   const uint32_t init = createInitForType(expr->getType());
 
 
-  /// We should have consumed all initializers and scalars extracted from them.
-  assert(initializers.empty());
-  assert(scalars.empty());
+  if (init) {
+    // For successful translation, we should have consumed all initializers and
+    // scalars extracted from them.
+    assert(initializers.empty());
+    assert(scalars.empty());
+  }
 
 
   return init;
   return init;
 }
 }
@@ -64,22 +73,80 @@ void InitListHandler::decompose(const Expr *expr) {
     const uint32_t vec = theEmitter.loadIfGLValue(expr);
     const uint32_t vec = theEmitter.loadIfGLValue(expr);
     const QualType elemType = hlsl::GetHLSLVecElementType(type);
     const QualType elemType = hlsl::GetHLSLVecElementType(type);
     const auto size = hlsl::GetHLSLVecSize(type);
     const auto size = hlsl::GetHLSLVecSize(type);
-    if (size == 1) {
-      // Decomposing of size-1 vector just results in the vector itself.
-      scalars.emplace_back(vec, elemType);
+
+    decomposeVector(vec, elemType, size);
+  } else if (hlsl::IsHLSLMatType(type)) {
+    const uint32_t mat = theEmitter.loadIfGLValue(expr);
+    const QualType elemType = hlsl::GetHLSLMatElementType(type);
+
+    uint32_t rowCount = 0, colCount = 0;
+    hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
+
+    if (rowCount == 1 || colCount == 1) {
+      // This also handles the scalar case
+      decomposeVector(mat, elemType, rowCount == 1 ? colCount : rowCount);
     } else {
     } else {
       const uint32_t elemTypeId = typeTranslator.translateType(elemType);
       const uint32_t elemTypeId = typeTranslator.translateType(elemType);
-      for (uint32_t i = 0; i < size; ++i) {
-        const uint32_t element =
-            theBuilder.createCompositeExtract(elemTypeId, vec, {i});
-        scalars.emplace_back(element, elemType);
-      }
+      for (uint32_t i = 0; i < rowCount; ++i)
+        for (uint32_t j = 0; j < colCount; ++j) {
+          const uint32_t element =
+              theBuilder.createCompositeExtract(elemTypeId, mat, {i, j});
+          scalars.emplace_back(element, elemType);
+        }
     }
     }
+  } else if (type->isStructureType()) {
+    llvm_unreachable("struct initializer should already been handled");
   } else {
   } else {
     emitError("decomposing type %0 in initializer list unimplemented") << type;
     emitError("decomposing type %0 in initializer list unimplemented") << type;
   }
   }
 }
 }
 
 
+void InitListHandler::decomposeVector(uint32_t vec, QualType elemType,
+                                      uint32_t size) {
+  if (size == 1) {
+    // Decomposing of size-1 vector just results in the vector itself.
+    scalars.emplace_back(vec, elemType);
+  } else {
+    const uint32_t elemTypeId = typeTranslator.translateType(elemType);
+    for (uint32_t i = 0; i < size; ++i) {
+      const uint32_t element =
+          theBuilder.createCompositeExtract(elemTypeId, vec, {i});
+      scalars.emplace_back(element, elemType);
+    }
+  }
+}
+
+void InitListHandler::tryToSplitStruct() {
+  if (initializers.empty())
+    return;
+
+  auto *init = const_cast<Expr *>(initializers.back());
+  const QualType initType = init->getType();
+  if (!initType->isStructureType())
+    return;
+
+  // We are certain the current intializer will be replaced by now.
+  initializers.pop_back();
+
+  const auto &context = theEmitter.getASTContext();
+  const auto *structDecl = initType->getAsStructureType()->getDecl();
+
+  // Create MemberExpr for each field of the struct
+  llvm::SmallVector<const Expr *, 4> fields;
+  for (auto *field : structDecl->fields()) {
+    fields.push_back(MemberExpr::Create(
+        context, init, /*isarraw*/ false, /*OperatorLoc*/ {},
+        /*QualifierLoc*/ {}, /*TemplateKWLoc*/ {}, field,
+        DeclAccessPair::make(field, AS_none),
+        DeclarationNameInfo(field->getDeclName(), /*NameLoc*/ {}),
+        /*TemplateArgumentListInfo*/ nullptr, field->getType(),
+        init->getValueKind(), OK_Ordinary));
+  }
+
+  // Push in the reverse order
+  initializers.insert(initializers.end(), fields.rbegin(), fields.rend());
+}
+
 uint32_t InitListHandler::createInitForType(QualType type) {
 uint32_t InitListHandler::createInitForType(QualType type) {
   type = type.getCanonicalType();
   type = type.getCanonicalType();
 
 
@@ -98,6 +165,9 @@ uint32_t InitListHandler::createInitForType(QualType type) {
     return createInitForMatrixType(elemType, rowCount, colCount);
     return createInitForMatrixType(elemType, rowCount, colCount);
   }
   }
 
 
+  if (type->isStructureType())
+    return createInitForStructType(type);
+
   emitError("unimplemented initializer for type '%0'") << type;
   emitError("unimplemented initializer for type '%0'") << type;
   return 0;
   return 0;
 }
 }
@@ -111,8 +181,10 @@ uint32_t InitListHandler::createInitForBuiltinType(QualType type) {
     return theEmitter.castToType(init.first, init.second, type);
     return theEmitter.castToType(init.first, init.second, type);
   }
   }
 
 
-  const Expr *init = initializers.front();
-  initializers.pop_front();
+  tryToSplitStruct();
+
+  const Expr *init = initializers.back();
+  initializers.pop_back();
 
 
   if (!init->getType()->isBuiltinType()) {
   if (!init->getType()->isBuiltinType()) {
     decompose(init);
     decompose(init);
@@ -130,11 +202,14 @@ uint32_t InitListHandler::createInitForVectorType(QualType elemType,
   // directly. For all other cases, we need to construct a new vector as the
   // directly. For all other cases, we need to construct a new vector as the
   // initializer.
   // initializer.
   if (scalars.empty()) {
   if (scalars.empty()) {
-    const Expr *init = initializers.front();
+    // A struct may contain a whole vector.
+    tryToSplitStruct();
+
+    const Expr *init = initializers.back();
 
 
     if (hlsl::IsHLSLVecType(init->getType()) &&
     if (hlsl::IsHLSLVecType(init->getType()) &&
         hlsl::GetHLSLVecSize(init->getType()) == count) {
         hlsl::GetHLSLVecSize(init->getType()) == count) {
-      initializers.pop_front();
+      initializers.pop_back();
       /// HLSL vector types are parameterized templates and we cannot
       /// HLSL vector types are parameterized templates and we cannot
       /// construct them. So we construct an ExtVectorType here instead.
       /// construct them. So we construct an ExtVectorType here instead.
       /// This is unfortunate since it means we need to handle ExtVectorType
       /// This is unfortunate since it means we need to handle ExtVectorType
@@ -169,14 +244,17 @@ uint32_t InitListHandler::createInitForMatrixType(QualType elemType,
   // Same as the vector case, first try to see if we already have a matrix at
   // Same as the vector case, first try to see if we already have a matrix at
   // the beginning of the initializer queue.
   // the beginning of the initializer queue.
   if (scalars.empty()) {
   if (scalars.empty()) {
-    const Expr *init = initializers.front();
+    // A struct may contain a whole matrix.
+    tryToSplitStruct();
+
+    const Expr *init = initializers.back();
 
 
     if (hlsl::IsHLSLMatType(init->getType())) {
     if (hlsl::IsHLSLMatType(init->getType())) {
       uint32_t initRowCount = 0, initColCount = 0;
       uint32_t initRowCount = 0, initColCount = 0;
       hlsl::GetHLSLMatRowColCount(init->getType(), initRowCount, initColCount);
       hlsl::GetHLSLMatRowColCount(init->getType(), initRowCount, initColCount);
 
 
       if (rowCount == initRowCount && colCount == initColCount) {
       if (rowCount == initRowCount && colCount == initColCount) {
-        initializers.pop_front();
+        initializers.pop_back();
         // TODO: We only support FP matrices now. Do type cast here after
         // TODO: We only support FP matrices now. Do type cast here after
         // adding more matrix types.
         // adding more matrix types.
         return theEmitter.loadIfGLValue(init);
         return theEmitter.loadIfGLValue(init);
@@ -204,5 +282,32 @@ uint32_t InitListHandler::createInitForMatrixType(QualType elemType,
   return theBuilder.createCompositeConstruct(matType, vectors);
   return theBuilder.createCompositeConstruct(matType, vectors);
 }
 }
 
 
+uint32_t InitListHandler::createInitForStructType(QualType type) {
+  // Same as the vector case, first try to see if we already have a struct at
+  // the beginning of the initializer queue.
+  if (scalars.empty()) {
+    const Expr *init = initializers.back();
+    // We can only avoid decomposing and reconstructing when the type is
+    // exactly the same.
+    if (type.getCanonicalType() == init->getType().getCanonicalType()) {
+      initializers.pop_back();
+      return theEmitter.loadIfGLValue(init);
+    }
+
+    // Otherwise, if the next initializer is a struct, it is not of the same
+    // type as we expected. Split it.
+    tryToSplitStruct();
+  }
+
+  llvm::SmallVector<uint32_t, 4> fields;
+  const RecordDecl *structDecl = type->getAsStructureType()->getDecl();
+  for (const auto *field : structDecl->fields())
+    fields.push_back(createInitForType(field->getType()));
+
+  const uint32_t typeId = typeTranslator.translateType(type);
+  // TODO: use OpConstantComposite when all components are constants
+  return theBuilder.createCompositeConstruct(typeId, fields);
+}
+
 } // end namespace spirv
 } // end namespace spirv
 } // end namespace clang
 } // end namespace clang

+ 13 - 2
tools/clang/lib/SPIRV/InitListHandler.h

@@ -16,6 +16,7 @@
 
 
 #include <deque>
 #include <deque>
 #include <utility>
 #include <utility>
+#include <vector>
 
 
 #include "clang/AST/Expr.h"
 #include "clang/AST/Expr.h"
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Basic/Diagnostic.h"
@@ -96,6 +97,11 @@ private:
   /// Decomposes the given Expr and puts all elements into the end of the
   /// Decomposes the given Expr and puts all elements into the end of the
   /// scalars queue.
   /// scalars queue.
   void decompose(const Expr *expr);
   void decompose(const Expr *expr);
+  void decomposeVector(uint32_t vec, QualType elemType, uint32_t size);
+
+  /// If the next initializer is a struct, replaces it with MemberExprs to all
+  /// its members. Otherwise, does nothing.
+  void tryToSplitStruct();
 
 
   /// Emits the necessary SPIR-V instructions to create a SPIR-V value of the
   /// Emits the necessary SPIR-V instructions to create a SPIR-V value of the
   /// given type. The scalars and initializers queue will be used to fetch the
   /// given type. The scalars and initializers queue will be used to fetch the
@@ -105,6 +111,7 @@ private:
   uint32_t createInitForVectorType(QualType elemType, uint32_t count);
   uint32_t createInitForVectorType(QualType elemType, uint32_t count);
   uint32_t createInitForMatrixType(QualType elemType, uint32_t rowCount,
   uint32_t createInitForMatrixType(QualType elemType, uint32_t rowCount,
                                    uint32_t colCount);
                                    uint32_t colCount);
+  uint32_t createInitForStructType(QualType type);
 
 
 private:
 private:
   SPIRVEmitter &theEmitter;
   SPIRVEmitter &theEmitter;
@@ -112,8 +119,12 @@ private:
   TypeTranslator &typeTranslator;
   TypeTranslator &typeTranslator;
   DiagnosticsEngine &diags;
   DiagnosticsEngine &diags;
 
 
-  /// A queue keeping track of unused AST nodes for initializers
-  std::deque<const Expr *> initializers;
+  /// A queue keeping track of unused AST nodes for initializers. Since we will
+  /// only comsume initializers from the head of the queue and will not add new
+  /// initializers to the tail of the queue, we use a vector (containing the
+  /// reverse of the original intializer list) here and manipulate its tail.
+  /// This is more efficient than using deque.
+  std::vector<const Expr *> initializers;
   /// A queue keeping track of previously extracted but unused scalars.
   /// A queue keeping track of previously extracted but unused scalars.
   /// Each element is a pair, with the first element as the SPIR-V <result-id>
   /// Each element is a pair, with the first element as the SPIR-V <result-id>
   /// and the second element as the AST type of the scalar value.
   /// and the second element as the AST type of the scalar value.

+ 22 - 0
tools/clang/test/CodeGenSPIRV/constant.matrix.hlsl

@@ -0,0 +1,22 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// TODO: actually emit constant SPIR-V instructions for the following tests.
+
+void main() {
+// CHECK:       OpStore %a %float_1
+    float1x1 a = float1x1(1.);
+
+// CHECK-NEXT: [[b:%\d+]] = OpCompositeConstruct %v2float %float_2 %float_3
+// CHECK-NEXT: OpStore %b [[b]]
+    float1x2 b = float1x2(2., 3.);
+
+// CHECK-NEXT: [[c:%\d+]] = OpCompositeConstruct %v2float %float_4 %float_5
+// CHECK-NEXT: OpStore %c [[c]]
+    float2x1 c = float2x1(4., 5.);
+
+// CHECK-NEXT: [[d0:%\d+]] = OpCompositeConstruct %v3float %float_6 %float_7 %float_8
+// CHECK-NEXT: [[d1:%\d+]] = OpCompositeConstruct %v3float %float_9 %float_10 %float_11
+// CHECK-NEXT: [[d:%\d+]] = OpCompositeConstruct %mat2v3float [[d0]] [[d1]]
+// CHECK-NEXT: OpStore %d [[d]]
+    float2x3 d = float2x3(6., 7., 8., 9., 10., 11.);
+}

+ 24 - 0
tools/clang/test/CodeGenSPIRV/constant.struct.hlsl

@@ -0,0 +1,24 @@
+// Run: %dxc -T vs_6_0 -E main
+
+struct S {
+    uint a;
+    bool2 b;
+    float2x2 c;
+};
+
+struct T {
+    S x;
+    int y;
+};
+
+void main() {
+    // TODO: Okay, we are not acutally generating constants here.
+    // We should optimize to use OpConstantComposite for the following.
+// CHECK:      [[b:%\d+]] = OpCompositeConstruct %v2bool %true %false
+// CHECK-NEXT: [[c1:%\d+]] = OpCompositeConstruct %v2float %float_1 %float_2
+// CHECK-NEXT: [[c2:%\d+]] = OpCompositeConstruct %v2float %float_3 %float_4
+// CHECK-NEXT: [[c:%\d+]] = OpCompositeConstruct %mat2v2float [[c1]] [[c2]]
+// CHECK-NEXT: [[s:%\d+]] = OpCompositeConstruct %S %uint_1 [[b]] [[c]]
+// CHECK-NEXT: {{%\d+}} = OpCompositeConstruct %T [[s]] %int_5
+    T t = {1, true, false, 1.0, 2.0, 3.0, 4.0, 5};
+}

+ 6 - 0
tools/clang/test/CodeGenSPIRV/type.struct.hlsl

@@ -1,5 +1,7 @@
 // Run: %dxc -T vs_6_0 -E main
 // Run: %dxc -T vs_6_0 -E main
 
 
+// CHECK:      OpName %N "N"
+
 // CHECK:      OpName %S "S"
 // CHECK:      OpName %S "S"
 // CHECK-NEXT: OpMemberName %S 0 "a"
 // CHECK-NEXT: OpMemberName %S 0 "a"
 // CHECK-NEXT: OpMemberName %S 1 "b"
 // CHECK-NEXT: OpMemberName %S 1 "b"
@@ -10,6 +12,9 @@
 // CHECK-NEXT: OpMemberName %T 1 "y"
 // CHECK-NEXT: OpMemberName %T 1 "y"
 // CHECK-NEXT: OpMemberName %T 2 "z"
 // CHECK-NEXT: OpMemberName %T 2 "z"
 
 
+// CHECK:      %N = OpTypeStruct
+struct N {};
+
 // CHECK:      %S = OpTypeStruct %uint %v4float %mat2v3float
 // CHECK:      %S = OpTypeStruct %uint %v4float %mat2v3float
 struct S {
 struct S {
     uint a;
     uint a;
@@ -25,6 +30,7 @@ struct T {
 };
 };
 
 
 void main() {
 void main() {
+    N n;
     S s;
     S s;
     T t;
     T t;
 }
 }

+ 4 - 0
tools/clang/test/CodeGenSPIRV/var.init.matrix.1x1.hlsl

@@ -18,4 +18,8 @@ void main() {
 // CHECK-NEXT: [[cv:%\d+]] = OpConvertSToF %float [[scalar]]
 // CHECK-NEXT: [[cv:%\d+]] = OpConvertSToF %float [[scalar]]
 // CHECK-NEXT: OpStore %mat5 [[cv]]
 // CHECK-NEXT: OpStore %mat5 [[cv]]
     float1x1 mat5 = {scalar};
     float1x1 mat5 = {scalar};
+
+// CHECK-NEXT: [[mat5:%\d+]] = OpLoad %float %mat5
+// CHECK-NEXT: OpStore %mat6 [[mat5]]
+    float1x1 mat6 = {mat5};
 }
 }

+ 11 - 0
tools/clang/test/CodeGenSPIRV/var.init.matrix.1xn.hlsl

@@ -31,4 +31,15 @@ void main() {
 // CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v4float [[cv0]] [[cv1]] [[cv2]] [[cv3]]
 // CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v4float [[cv0]] [[cv1]] [[cv2]] [[cv3]]
 // CHECK-NEXT: OpStore %mat5 [[cc0]]
 // CHECK-NEXT: OpStore %mat5 [[cc0]]
     float1x4 mat5 = {scalar, vec2, vec1};
     float1x4 mat5 = {scalar, vec2, vec1};
+
+    float1x2 mat6;
+// CHECK-NEXT: [[mat6:%\d+]] = OpLoad %v2float %mat6
+// CHECK-NEXT: [[ce2:%\d+]] = OpCompositeExtract %float [[mat6]] 0
+// CHECK-NEXT: [[ce3:%\d+]] = OpCompositeExtract %float [[mat6]] 1
+// CHECK-NEXT: [[mat6:%\d+]] = OpLoad %v2float %mat6
+// CHECK-NEXT: [[ce4:%\d+]] = OpCompositeExtract %float [[mat6]] 0
+// CHECK-NEXT: [[ce5:%\d+]] = OpCompositeExtract %float [[mat6]] 1
+// CHECK-NEXT: [[cc1:%\d+]] = OpCompositeConstruct %v4float [[ce2]] [[ce3]] [[ce4]] [[ce5]]
+// CHECK-NEXT: OpStore %mat7 [[cc1]]
+    float1x4 mat7 = {mat6, mat6};
 }
 }

+ 11 - 0
tools/clang/test/CodeGenSPIRV/var.init.matrix.mx1.hlsl

@@ -31,4 +31,15 @@ void main() {
 // CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v4float [[cv0]] [[cv1]] [[cv2]] [[cv3]]
 // CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v4float [[cv0]] [[cv1]] [[cv2]] [[cv3]]
 // CHECK-NEXT: OpStore %mat5 [[cc0]]
 // CHECK-NEXT: OpStore %mat5 [[cc0]]
     float4x1 mat5 = {scalar, vec2, vec1};
     float4x1 mat5 = {scalar, vec2, vec1};
+
+    float2x1 mat6;
+// CHECK-NEXT: [[mat6:%\d+]] = OpLoad %v2float %mat6
+// CHECK-NEXT: [[ce2:%\d+]] = OpCompositeExtract %float [[mat6]] 0
+// CHECK-NEXT: [[ce3:%\d+]] = OpCompositeExtract %float [[mat6]] 1
+// CHECK-NEXT: [[mat6:%\d+]] = OpLoad %v2float %mat6
+// CHECK-NEXT: [[ce4:%\d+]] = OpCompositeExtract %float [[mat6]] 0
+// CHECK-NEXT: [[ce5:%\d+]] = OpCompositeExtract %float [[mat6]] 1
+// CHECK-NEXT: [[cc1:%\d+]] = OpCompositeConstruct %v4float [[ce2]] [[ce3]] [[ce4]] [[ce5]]
+// CHECK-NEXT: OpStore %mat7 [[cc1]]
+    float4x1 mat7 = {mat6, mat6};
 }
 }

+ 38 - 0
tools/clang/test/CodeGenSPIRV/var.init.matrix.mxn.hlsl

@@ -110,4 +110,42 @@ void main() {
                      intScalar, boolScalar,         // [1] - 1 scalar
                      intScalar, boolScalar,         // [1] - 1 scalar
                      boolVec3                       // [2]
                      boolVec3                       // [2]
     };
     };
+
+    // Decomposing matrices
+    float2x2 mat8;
+    float2x4 mat9;
+    float4x1 mat10;
+    // TODO: Optimization opportunity. We are extracting all elements in each
+    // vector and then reconstructing the original vector. Optimally we should
+    // extract vectors from matrices directly.
+
+// CHECK-NEXT: [[mat8:%\d+]] = OpLoad %mat2v2float %mat8
+// CHECK-NEXT: [[mat8_00:%\d+]] = OpCompositeExtract %float [[mat8]] 0 0
+// CHECK-NEXT: [[mat8_01:%\d+]] = OpCompositeExtract %float [[mat8]] 0 1
+// CHECK-NEXT: [[mat8_10:%\d+]] = OpCompositeExtract %float [[mat8]] 1 0
+// CHECK-NEXT: [[mat8_11:%\d+]] = OpCompositeExtract %float [[mat8]] 1 1
+// CHECK-NEXT: [[cc21:%\d+]] = OpCompositeConstruct %v4float [[mat8_00]] [[mat8_01]] [[mat8_10]] [[mat8_11]]
+
+// CHECK-NEXT: [[mat9:%\d+]] = OpLoad %mat2v4float %mat9
+// CHECK-NEXT: [[mat9_00:%\d+]] = OpCompositeExtract %float [[mat9]] 0 0
+// CHECK-NEXT: [[mat9_01:%\d+]] = OpCompositeExtract %float [[mat9]] 0 1
+// CHECK-NEXT: [[mat9_02:%\d+]] = OpCompositeExtract %float [[mat9]] 0 2
+// CHECK-NEXT: [[mat9_03:%\d+]] = OpCompositeExtract %float [[mat9]] 0 3
+// CHECK-NEXT: [[mat9_10:%\d+]] = OpCompositeExtract %float [[mat9]] 1 0
+// CHECK-NEXT: [[mat9_11:%\d+]] = OpCompositeExtract %float [[mat9]] 1 1
+// CHECK-NEXT: [[mat9_12:%\d+]] = OpCompositeExtract %float [[mat9]] 1 2
+// CHECK-NEXT: [[mat9_13:%\d+]] = OpCompositeExtract %float [[mat9]] 1 3
+// CHECK-NEXT: [[cc22:%\d+]] = OpCompositeConstruct %v4float [[mat9_00]] [[mat9_01]] [[mat9_02]] [[mat9_03]]
+// CHECK-NEXT: [[cc23:%\d+]] = OpCompositeConstruct %v4float [[mat9_10]] [[mat9_11]] [[mat9_12]] [[mat9_13]]
+
+// CHECK-NEXT: [[mat10:%\d+]] = OpLoad %v4float %mat10
+// CHECK-NEXT: [[mat10_0:%\d+]] = OpCompositeExtract %float [[mat10]] 0
+// CHECK-NEXT: [[mat10_1:%\d+]] = OpCompositeExtract %float [[mat10]] 1
+// CHECK-NEXT: [[mat10_2:%\d+]] = OpCompositeExtract %float [[mat10]] 2
+// CHECK-NEXT: [[mat10_3:%\d+]] = OpCompositeExtract %float [[mat10]] 3
+// CHECK-NEXT: [[cc24:%\d+]] = OpCompositeConstruct %v4float [[mat10_0]] [[mat10_1]] [[mat10_2]] [[mat10_3]]
+
+// CHECK-NEXT: [[cc25:%\d+]] = OpCompositeConstruct %mat4v4float [[cc21]] [[cc22]] [[cc23]] [[cc24]]
+// CHECK-NEXT: OpStore %mat11 [[cc25]]
+    float4x4 mat11 = {mat8, mat9, mat10};
 }
 }

+ 91 - 0
tools/clang/test/CodeGenSPIRV/var.init.struct.hlsl

@@ -0,0 +1,91 @@
+// Run: %dxc -T vs_6_0 -E main
+
+struct S {
+    int3 a;
+    uint b;
+    float2x2 c;
+};
+
+struct T {
+    // Same fields as S
+    int3 h;
+    uint i;
+    float2x2 j;
+
+    // Additional field
+    bool2 k;
+
+    // Embedded S
+    S l;
+
+    // Similar to S but need some casts
+    float3 m;
+    int n;
+    float2x2 o;
+};
+
+struct O {
+    int x;
+};
+
+struct P {
+    O y;
+    float z;
+};
+
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+    // Flat initializer list
+// CHECK:      [[a:%\d+]] = OpCompositeConstruct %v3int %int_1 %int_2 %int_3
+// CHECK-NEXT: [[c0:%\d+]] = OpCompositeConstruct %v2float %float_1 %float_2
+// CHECK-NEXT: [[c1:%\d+]] = OpCompositeConstruct %v2float %float_3 %float_4
+// CHECK-NEXT: [[c:%\d+]] = OpCompositeConstruct %mat2v2float [[c0]] [[c1]]
+// CHECK-NEXT: {{%\d+}} = OpCompositeConstruct %S [[a]] %uint_42 [[c]]
+    S s1 = {1, 2, 3, 42, 1., 2., 3., 4.};
+
+    // Random parentheses
+// CHECK:      [[a:%\d+]] = OpCompositeConstruct %v3int %int_1 %int_2 %int_3
+// CHECK-NEXT: [[c0:%\d+]] = OpCompositeConstruct %v2float %float_1 %float_2
+// CHECK-NEXT: [[c1:%\d+]] = OpCompositeConstruct %v2float %float_3 %float_4
+// CHECK-NEXT: [[c:%\d+]] = OpCompositeConstruct %mat2v2float [[c0]] [[c1]]
+// CHECK-NEXT: {{%\d+}} = OpCompositeConstruct %S [[a]] %uint_42 [[c]]
+    S s2 = {{1, 2}, 3, {{42}, {{1.}}}, {2., {3., 4.}}};
+
+    // Flat initalizer list for nested structs
+// CHECK:      [[y:%\d+]] = OpCompositeConstruct %O %int_1
+// CHECK-NEXT: {{%\d+}} = OpCompositeConstruct %P [[y]] %float_2
+    P p = {1, 2.};
+
+    // Mixed case: use struct as a whole, decomposing struct, type casting
+// CHECK:      [[s1a:%\d+]] = OpAccessChain %_ptr_Function_v3int %s1 %int_0
+// CHECK-NEXT: [[h:%\d+]] = OpLoad %v3int [[s1a]]
+
+// CHECK-NEXT: [[s1b:%\d+]] = OpAccessChain %_ptr_Function_uint %s1 %int_1
+// CHECK-NEXT: [[i:%\d+]] = OpLoad %uint [[s1b]]
+
+// CHECK-NEXT: [[s1c:%\d+]] = OpAccessChain %_ptr_Function_mat2v2float %s1 %int_2
+// CHECK-NEXT: [[j:%\d+]] = OpLoad %mat2v2float [[s1c]]
+
+// CHECK-NEXT: [[k:%\d+]] = OpCompositeConstruct %v2bool %true %false
+
+// CHECK-NEXT: [[l:%\d+]] = OpLoad %S %s2
+
+// CHECK-NEXT: [[s2a:%\d+]] = OpAccessChain %_ptr_Function_v3int %s2 %int_0
+// CHECK-NEXT: [[s2av:%\d+]] = OpLoad %v3int [[s2a]]
+// CHECK-NEXT: [[m:%\d+]] = OpConvertSToF %v3float [[s2av]]
+
+// CHECK-NEXT: [[s2b:%\d+]] = OpAccessChain %_ptr_Function_uint %s2 %int_1
+// CHECK-NEXT: [[s2bv:%\d+]] = OpLoad %uint [[s2b]]
+// CHECK-NEXT: [[n:%\d+]] = OpBitcast %int [[s2bv]]
+
+// CHECK-NEXT: [[s2c:%\d+]] = OpAccessChain %_ptr_Function_mat2v2float %s2 %int_2
+// CHECK-NEXT: [[o:%\d+]] = OpLoad %mat2v2float %65
+
+// CHECK-NEXT: {{%\d+}} = OpCompositeConstruct %T [[h]] [[i]] [[j]] [[k]] [[l]] [[m]] [[n]] [[o]]
+    T t = {s1,          // Decomposing struct
+           true, false, // constructing field from scalar
+           s2,          // Embedded struct
+           s2           // Decomposing struct + type casting
+          };
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -44,6 +44,8 @@ TEST_F(FileTest, TypedefTypes) { runFileTest("type.typedef.hlsl"); }
 // For constants
 // For constants
 TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }
 TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }
 TEST_F(FileTest, VectorConstants) { runFileTest("constant.vector.hlsl"); }
 TEST_F(FileTest, VectorConstants) { runFileTest("constant.vector.hlsl"); }
+TEST_F(FileTest, MatrixConstants) { runFileTest("constant.matrix.hlsl"); }
+TEST_F(FileTest, StructConstants) { runFileTest("constant.struct.hlsl"); }
 
 
 // For variables
 // For variables
 TEST_F(FileTest, VarInit) { runFileTest("var.init.hlsl"); }
 TEST_F(FileTest, VarInit) { runFileTest("var.init.hlsl"); }
@@ -51,6 +53,7 @@ TEST_F(FileTest, VarInitMatrixMxN) { runFileTest("var.init.matrix.mxn.hlsl"); }
 TEST_F(FileTest, VarInitMatrixMx1) { runFileTest("var.init.matrix.mx1.hlsl"); }
 TEST_F(FileTest, VarInitMatrixMx1) { runFileTest("var.init.matrix.mx1.hlsl"); }
 TEST_F(FileTest, VarInitMatrix1xN) { runFileTest("var.init.matrix.1xn.hlsl"); }
 TEST_F(FileTest, VarInitMatrix1xN) { runFileTest("var.init.matrix.1xn.hlsl"); }
 TEST_F(FileTest, VarInitMatrix1x1) { runFileTest("var.init.matrix.1x1.hlsl"); }
 TEST_F(FileTest, VarInitMatrix1x1) { runFileTest("var.init.matrix.1x1.hlsl"); }
+TEST_F(FileTest, VarInitStruct) { runFileTest("var.init.struct.hlsl"); }
 TEST_F(FileTest, StaticVar) { runFileTest("var.static.hlsl"); }
 TEST_F(FileTest, StaticVar) { runFileTest("var.static.hlsl"); }
 
 
 // For prefix/postfix increment/decrement
 // For prefix/postfix increment/decrement