Explorar o código

Merge pull request #1955 from tex3d/19h1-fixes

merge argument/addrspace cast/HLMatrixLower fixes to 19h1-rel
Tex Riddell %!s(int64=6) %!d(string=hai) anos
pai
achega
bc3beec39f
Modificáronse 48 ficheiros con 2940 adicións e 2880 borrados
  1. 2 0
      include/dxc/DXIL/DxilUtil.h
  2. 1 31
      include/dxc/HLSL/HLMatrixLowerHelper.h
  3. 97 0
      include/dxc/HLSL/HLMatrixType.h
  4. 48 28
      lib/DXIL/DxilUtil.cpp
  5. 3 0
      lib/HLSL/CMakeLists.txt
  6. 5 3
      lib/HLSL/DxilCondenseResources.cpp
  7. 0 1
      lib/HLSL/DxilGenerationPass.cpp
  8. 251 0
      lib/HLSL/HLMatrixBitcastLowerPass.cpp
  9. 1254 2558
      lib/HLSL/HLMatrixLowerPass.cpp
  10. 284 0
      lib/HLSL/HLMatrixSubscriptUseReplacer.cpp
  11. 67 0
      lib/HLSL/HLMatrixSubscriptUseReplacer.h
  12. 179 0
      lib/HLSL/HLMatrixType.cpp
  13. 85 57
      lib/HLSL/HLOperationLower.cpp
  14. 0 1
      lib/HLSL/HLOperationLowerExtension.cpp
  15. 43 53
      lib/HLSL/HLSignatureLower.cpp
  16. 7 1
      lib/Transforms/IPO/PassManagerBuilder.cpp
  17. 1 1
      lib/Transforms/Scalar/DxilLoopUnroll.cpp
  18. 110 47
      lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
  19. 1 2
      tools/clang/include/clang/AST/HlslTypes.h
  20. 6 4
      tools/clang/lib/AST/HlslTypes.cpp
  21. 2 1
      tools/clang/lib/CodeGen/CGClass.cpp
  22. 1 1
      tools/clang/lib/CodeGen/CGDecl.cpp
  23. 15 2
      tools/clang/lib/CodeGen/CGExpr.cpp
  24. 2 2
      tools/clang/lib/CodeGen/CGExprScalar.cpp
  25. 104 60
      tools/clang/lib/CodeGen/CGHLSLMS.cpp
  26. 1 6
      tools/clang/lib/CodeGen/CodeGenModule.cpp
  27. 1 1
      tools/clang/lib/SPIRV/SPIRVEmitter.cpp
  28. 4 1
      tools/clang/lib/Sema/SemaHLSL.cpp
  29. 2 1
      tools/clang/test/CodeGenHLSL/RValSubscript.hlsl
  30. 15 0
      tools/clang/test/CodeGenHLSL/declarations/functions/inout_derived_struct_no_crash.hlsl
  31. 28 0
      tools/clang/test/CodeGenHLSL/quick-test/addrspacecast.hlsl
  32. 29 0
      tools/clang/test/CodeGenHLSL/quick-test/empty_struct3.hlsl
  33. 50 0
      tools/clang/test/CodeGenHLSL/quick-test/groupshared-base-cast.hlsl
  34. 47 0
      tools/clang/test/CodeGenHLSL/quick-test/groupshared-member-matrix-subscript-col.hlsl
  35. 46 0
      tools/clang/test/CodeGenHLSL/quick-test/groupshared-member-matrix-subscript.hlsl
  36. 45 0
      tools/clang/test/CodeGenHLSL/quick-test/groupshared-member-matrix-subscript2.hlsl
  37. 12 0
      tools/clang/test/CodeGenHLSL/quick-test/mat_init_splat.hlsl
  38. 1 1
      tools/clang/test/CodeGenHLSL/quick-test/static_global_copy3.hlsl
  39. 28 0
      tools/clang/test/CodeGenHLSL/quick-test/struct_param_in_mod.hlsl
  40. 38 0
      tools/clang/test/CodeGenHLSL/quick-test/struct_param_in_mod2.hlsl
  41. 8 0
      tools/clang/test/CodeGenHLSL/quick-test/unused_matrix_input_regression.hlsl
  42. 2 1
      tools/clang/test/CodeGenHLSL/shader-compat-suite/lib_arg_flatten/lib_arg_flatten.hlsl
  43. 1 1
      tools/clang/test/CodeGenHLSL/shader-compat-suite/lib_arg_flatten/lib_arg_flatten3.hlsl
  44. 2 2
      tools/clang/test/CodeGenHLSL/shader-compat-suite/lib_arg_flatten/lib_empty_struct_arg.hlsl
  45. 3 6
      tools/clang/test/CodeGenHLSL/share_mem_dbg.hlsl
  46. 1 1
      tools/clang/test/CodeGenHLSL/staticGlobals.hlsl
  47. 4 1
      tools/clang/test/CodeGenHLSL/static_matrix.hlsl
  48. 4 5
      tools/clang/tools/dxcompiler/dxcdisassembler.cpp

+ 2 - 0
include/dxc/DXIL/DxilUtil.h

@@ -102,8 +102,10 @@ namespace dxilutil {
   std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::MemoryBuffer *MB,
     llvm::LLVMContext &Ctx, std::string &DiagStr);
   void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI, void *Context);
+  bool IsIntegerOrFloatingPointType(llvm::Type *Ty);
   // Returns true if type contains HLSL Object type (resource)
   bool ContainsHLSLObjectType(llvm::Type *Ty);
+  bool IsHLSLResourceType(llvm::Type *Ty);
   bool IsHLSLObjectType(llvm::Type *Ty);
   bool IsHLSLMatrixType(llvm::Type *Ty);
   bool IsSplat(llvm::ConstantDataVector *cdv);

+ 1 - 31
include/dxc/HLSL/HLMatrixLowerHelper.h

@@ -26,41 +26,11 @@ class DxilFieldAnnotation;
 class DxilTypeSystem;
 
 namespace HLMatrixLower {
-// TODO: use type annotation.
-DxilFieldAnnotation *FindAnnotationFromMatUser(llvm::Value *Mat,
-                                               DxilTypeSystem &typeSys);
-// Translate matrix type to vector type.
-llvm::Type *LowerMatrixType(llvm::Type *Ty, bool forMem = false);
-// TODO: use type annotation.
-llvm::Type *GetMatrixInfo(llvm::Type *Ty, unsigned &col, unsigned &row);
-// TODO: use type annotation.
-bool IsMatrixArrayPointer(llvm::Type *Ty);
-// Translate matrix array pointer type to vector array pointer type.
-llvm::Type *LowerMatrixArrayPointer(llvm::Type *Ty, bool forMem = false);
 
-llvm::Value *BuildVector(llvm::Type *EltTy, unsigned size,
+llvm::Value *BuildVector(llvm::Type *EltTy,
                          llvm::ArrayRef<llvm::Value *> elts,
                          llvm::IRBuilder<> &Builder);
 
-llvm::Value *VecMatrixMemToReg(llvm::Value *VecVal, llvm::Type *MatType,
-                               llvm::IRBuilder<> &Builder);
-llvm::Value *VecMatrixRegToMem(llvm::Value* VecVal, llvm::Type *MatType,
-                               llvm::IRBuilder<> &Builder);
-llvm::Instruction *CreateVecMatrixLoad(llvm::Value *VecPtr,
-                                       llvm::Type *MatType, llvm::IRBuilder<> &Builder);
-llvm::Instruction *CreateVecMatrixStore(llvm::Value* VecVal, llvm::Value *VecPtr,
-                                        llvm::Type *MatType, llvm::IRBuilder<> &Builder);
-
-// For case like mat[i][j].
-// IdxList is [i][0], [i][1], [i][2],[i][3].
-// Idx is j.
-// return [i][j] not mat[i][j] because resource ptr and temp ptr need different
-// code gen.
-llvm::Value *
-LowerGEPOnMatIndexListToIndex(llvm::GetElementPtrInst *GEP,
-                              llvm::ArrayRef<llvm::Value *> IdxList);
-unsigned GetColMajorIdx(unsigned r, unsigned c, unsigned row);
-unsigned GetRowMajorIdx(unsigned r, unsigned c, unsigned col);
 } // namespace HLMatrixLower
 
 } // namespace hlsl

+ 97 - 0
include/dxc/HLSL/HLMatrixType.h

@@ -0,0 +1,97 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLMatrixType.h                                                            //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+
+#include "llvm/IR/IRBuilder.h"
+
+namespace llvm {
+  template<typename T>
+  class ArrayRef;
+  class Type;
+  class Value;
+  class Constant;
+  class StructType;
+  class VectorType;
+  class StoreInst;
+}
+
+namespace hlsl {
+
+class DxilFieldAnnotation;
+class DxilTypeSystem;
+
+// A high-level matrix type in LLVM IR.
+//
+// Matrices are represented by an llvm struct type of the following form:
+// { [RowCount x <ColCount x RegReprTy>] }
+// Note that the element type is always in its register representation (ie bools are i1s).
+// This allows preserving the original type and is okay since matrix types are only
+// manipulated in an opaque way, through intrinsics.
+//
+// During matrix lowering, matrices are converted to vectors of the following form:
+// <RowCount*ColCount x Ty>
+// At this point, register vs memory representation starts to matter and we have to
+// imitate the codegen for scalar and vector bools: i1s when in llvm registers,
+// and i32s when in memory (allocas, pointers, or in structs/lists, which are always in memory).
+//
+// This class is designed to resemble a llvm::Type-derived class.
+class HLMatrixType
+{
+public:
+  static constexpr const char* StructNamePrefix = "class.matrix.";
+
+  HLMatrixType() : RegReprElemTy(nullptr), NumRows(0), NumColumns(0) {}
+  HLMatrixType(llvm::Type *RegReprElemTy, unsigned NumRows, unsigned NumColumns);
+
+  // We allow default construction to an invalid state to support the dynCast pattern.
+  // This tests whether we have a legit object.
+  operator bool() const { return RegReprElemTy != nullptr; }
+
+  llvm::Type *getElementType(bool MemRepr) const;
+  llvm::Type *getElementTypeForReg() const { return getElementType(false); }
+  llvm::Type *getElementTypeForMem() const { return getElementType(true); }
+  unsigned getNumRows() const { return NumRows; }
+  unsigned getNumColumns() const { return NumColumns; }
+  unsigned getNumElements() const { return NumRows * NumColumns; }
+  unsigned getRowMajorIndex(unsigned RowIdx, unsigned ColIdx) const;
+  unsigned getColumnMajorIndex(unsigned RowIdx, unsigned ColIdx) const;
+  static unsigned getRowMajorIndex(unsigned RowIdx, unsigned ColIdx, unsigned NumRows, unsigned NumColumns);
+  static unsigned getColumnMajorIndex(unsigned RowIdx, unsigned ColIdx, unsigned NumRows, unsigned NumColumns);
+
+  llvm::VectorType *getLoweredVectorType(bool MemRepr) const;
+  llvm::VectorType *getLoweredVectorTypeForReg() const { return getLoweredVectorType(false); }
+  llvm::VectorType *getLoweredVectorTypeForMem() const { return getLoweredVectorType(true); }
+
+  llvm::Value *emitLoweredMemToReg(llvm::Value *Val, llvm::IRBuilder<> &Builder) const;
+  llvm::Value *emitLoweredRegToMem(llvm::Value *Val, llvm::IRBuilder<> &Builder) const;
+  llvm::Value *emitLoweredLoad(llvm::Value *Ptr, llvm::IRBuilder<> &Builder) const;
+  llvm::StoreInst *emitLoweredStore(llvm::Value *Val, llvm::Value *Ptr, llvm::IRBuilder<> &Builder) const;
+
+  llvm::Value *emitLoweredVectorRowToCol(llvm::Value *VecVal, llvm::IRBuilder<> &Builder) const;
+  llvm::Value *emitLoweredVectorColToRow(llvm::Value *VecVal, llvm::IRBuilder<> &Builder) const;
+
+  static bool isa(llvm::Type *Ty);
+  static bool isMatrixPtr(llvm::Type *Ty);
+  static bool isMatrixArray(llvm::Type *Ty);
+  static bool isMatrixArrayPtr(llvm::Type *Ty);
+  static bool isMatrixPtrOrArrayPtr(llvm::Type *Ty);
+  static bool isMatrixOrPtrOrArrayPtr(llvm::Type *Ty);
+
+  static llvm::Type *getLoweredType(llvm::Type *Ty, bool MemRepr = false);
+
+  static HLMatrixType cast(llvm::Type *Ty);
+  static HLMatrixType dyn_cast(llvm::Type *Ty);
+
+private:
+  llvm::Type *RegReprElemTy;
+  unsigned NumRows, NumColumns;
+};
+
+} // namespace hlsl

+ 48 - 28
lib/DXIL/DxilUtil.cpp

@@ -396,16 +396,9 @@ llvm::Instruction *FirstNonAllocaInsertionPt(llvm::Function* F) {
   return SkipAllocas(FindAllocaInsertionPt(F));
 }
 
-bool IsHLSLObjectType(llvm::Type *Ty) {
+bool IsHLSLResourceType(llvm::Type *Ty) {
   if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
     StringRef name = ST->getName();
-    // TODO: don't check names.
-    if (name.startswith("dx.types.wave_t"))
-      return true;
-
-    if (name.endswith("_slice_type"))
-      return false;
-
     name = name.ltrim("class.");
     name = name.ltrim("struct.");
 
@@ -414,13 +407,6 @@ bool IsHLSLObjectType(llvm::Type *Ty) {
     if (name == "SamplerComparisonState")
       return true;
 
-    if (name.startswith("TriangleStream<"))
-      return true;
-    if (name.startswith("PointStream<"))
-      return true;
-    if (name.startswith("LineStream<"))
-      return true;
-
     if (name.startswith("AppendStructuredBuffer<"))
       return true;
     if (name.startswith("ConsumeStructuredBuffer<"))
@@ -441,23 +427,53 @@ bool IsHLSLObjectType(llvm::Type *Ty) {
       return true;
     if (name.startswith("StructuredBuffer<"))
       return true;
-    if (name.startswith("Texture1D<"))
-      return true;
-    if (name.startswith("Texture1DArray<"))
-      return true;
-    if (name.startswith("Texture2D<"))
-      return true;
-    if (name.startswith("Texture2DArray<"))
-      return true;
-    if (name.startswith("Texture3D<"))
+
+    if (name.startswith("Texture")) {
+      name = name.ltrim("Texture");
+      if (name.startswith("1D<"))
+        return true;
+      if (name.startswith("1DArray<"))
+        return true;
+      if (name.startswith("2D<"))
+        return true;
+      if (name.startswith("2DArray<"))
+        return true;
+      if (name.startswith("3D<"))
+        return true;
+      if (name.startswith("Cube<"))
+        return true;
+      if (name.startswith("CubeArray<"))
+        return true;
+      if (name.startswith("2DMS<"))
+        return true;
+      if (name.startswith("2DMSArray<"))
+        return true;
+    }
+  }
+  return false;
+}
+
+bool IsHLSLObjectType(llvm::Type *Ty) {
+  if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
+    StringRef name = ST->getName();
+    // TODO: don't check names.
+    if (name.startswith("dx.types.wave_t"))
       return true;
-    if (name.startswith("TextureCube<"))
+
+    if (name.endswith("_slice_type"))
+      return false;
+
+    if (IsHLSLResourceType(Ty))
       return true;
-    if (name.startswith("TextureCubeArray<"))
+
+    name = name.ltrim("class.");
+    name = name.ltrim("struct.");
+
+    if (name.startswith("TriangleStream<"))
       return true;
-    if (name.startswith("Texture2DMS<"))
+    if (name.startswith("PointStream<"))
       return true;
-    if (name.startswith("Texture2DMSArray<"))
+    if (name.startswith("LineStream<"))
       return true;
   }
   return false;
@@ -477,6 +493,10 @@ bool IsHLSLMatrixType(Type *Ty) {
   return false;
 }
 
+bool IsIntegerOrFloatingPointType(llvm::Type *Ty) {
+  return Ty->isIntegerTy() || Ty->isFloatingPointTy();
+}
+
 bool ContainsHLSLObjectType(llvm::Type *Ty) {
   // Unwrap pointer/array
   while (llvm::isa<llvm::PointerType>(Ty))

+ 3 - 0
lib/HLSL/CMakeLists.txt

@@ -23,7 +23,10 @@ add_llvm_library(LLVMHLSL
   DxilExportMap.cpp
   DxilValidation.cpp
   DxcOptimizer.cpp
+  HLMatrixBitcastLowerPass.cpp
   HLMatrixLowerPass.cpp
+  HLMatrixSubscriptUseReplacer.cpp
+  HLMatrixType.cpp
   HLModule.cpp
   HLOperations.cpp
   HLOperationLower.cpp

+ 5 - 3
lib/HLSL/DxilCondenseResources.cpp

@@ -17,7 +17,7 @@
 #include "dxc/DXIL/DxilTypeSystem.h"
 #include "dxc/DXIL/DxilInstructions.h"
 #include "dxc/HLSL/DxilSpanAllocator.h"
-#include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/DXIL/DxilUtil.h"
 #include "dxc/HLSL/HLModule.h"
 
@@ -1536,8 +1536,10 @@ Type *UpdateFieldTypeForLegacyLayout(Type *Ty, bool IsCBuf,
       return ArrayType::get(UpdatedTy, Ty->getArrayNumElements());
   } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     DXASSERT(annotation.HasMatrixAnnotation(), "must a matrix");
-    unsigned rows, cols;
-    Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, cols, rows);
+    HLMatrixType MatTy = HLMatrixType::cast(Ty);
+    unsigned rows = MatTy.getNumRows();
+    unsigned cols = MatTy.getNumColumns();
+    Type *EltTy = MatTy.getElementTypeForReg();
 
     // Get cols and rows from annotation.
     const DxilMatrixAnnotation &matrix = annotation.GetMatrixAnnotation();

+ 0 - 1
lib/HLSL/DxilGenerationPass.cpp

@@ -15,7 +15,6 @@
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/HLSL/HLOperations.h"
 #include "dxc/DXIL/DxilInstructions.h"
-#include "dxc/HLSL/HLMatrixLowerHelper.h"
 #include "dxc/HlslIntrinsicOp.h"
 #include "dxc/Support/Global.h"
 #include "dxc/DXIL/DxilTypeSystem.h"

+ 251 - 0
lib/HLSL/HLMatrixBitcastLowerPass.cpp

@@ -0,0 +1,251 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLMatrixBitcastLowerPass.cpp                                              //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "dxc/HLSL/HLMatrixLowerPass.h"
+#include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
+#include "dxc/DXIL/DxilUtil.h"
+#include "dxc/Support/Global.h"
+#include "dxc/DXIL/DxilOperations.h"
+#include "dxc/DXIL/DxilModule.h"
+#include "dxc/HLSL/DxilGenerationPass.h"
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Pass.h"
+#include <unordered_set>
+#include <vector>
+
+using namespace llvm;
+using namespace hlsl;
+using namespace hlsl::HLMatrixLower;
+
+// Matrix Bitcast lower.
+// After linking Lower matrix bitcast patterns like:
+//  %169 = bitcast [72 x float]* %0 to [6 x %class.matrix.float.4.3]*
+//  %conv.i = fptoui float %164 to i32
+//  %arrayidx.i = getelementptr inbounds [6 x %class.matrix.float.4.3], [6 x %class.matrix.float.4.3]* %169, i32 0, i32 %conv.i
+//  %170 = bitcast %class.matrix.float.4.3* %arrayidx.i to <12 x float>*
+
+namespace {
+
+// Translate matrix type to array type.
+Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
+  if (HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty)) {
+    Type *EltTy = MatTy.getElementTypeForReg();
+    return ArrayType::get(EltTy, MatTy.getNumElements());
+  }
+  else {
+    return Ty;
+  }
+}
+
+Type *LowerMatrixArrayPointerToOneDimArray(Type *Ty) {
+  unsigned addrSpace = Ty->getPointerAddressSpace();
+  Ty = Ty->getPointerElementType();
+
+  unsigned arraySize = 1;
+  while (Ty->isArrayTy()) {
+    arraySize *= Ty->getArrayNumElements();
+    Ty = Ty->getArrayElementType();
+  }
+
+  HLMatrixType MatTy = HLMatrixType::cast(Ty);
+  arraySize *= MatTy.getNumElements();
+
+  Ty = ArrayType::get(MatTy.getElementTypeForReg(), arraySize);
+  return PointerType::get(Ty, addrSpace);
+}
+
+Type *TryLowerMatTy(Type *Ty) {
+  Type *VecTy = nullptr;
+  if (HLMatrixType::isMatrixArrayPtr(Ty)) {
+    VecTy = LowerMatrixArrayPointerToOneDimArray(Ty);
+  } else if (isa<PointerType>(Ty) &&
+             dxilutil::IsHLSLMatrixType(Ty->getPointerElementType())) {
+    VecTy = LowerMatrixTypeToOneDimArray(
+        Ty->getPointerElementType());
+    VecTy = PointerType::get(VecTy, Ty->getPointerAddressSpace());
+  }
+  return VecTy;
+}
+
+class MatrixBitcastLowerPass : public FunctionPass {
+
+public:
+  static char ID; // Pass identification, replacement for typeid
+  explicit MatrixBitcastLowerPass() : FunctionPass(ID) {}
+
+  const char *getPassName() const override { return "Matrix Bitcast lower"; }
+  bool runOnFunction(Function &F) override {
+    bool bUpdated = false;
+    std::unordered_set<BitCastInst*> matCastSet;
+    for (auto blkIt = F.begin(); blkIt != F.end(); ++blkIt) {
+      BasicBlock *BB = blkIt;
+      for (auto iIt = BB->begin(); iIt != BB->end(); ) {
+        Instruction *I = (iIt++);
+        if (BitCastInst *BCI = dyn_cast<BitCastInst>(I)) {
+          // Mutate mat to vec.
+          Type *ToTy = BCI->getType();
+          if (TryLowerMatTy(ToTy)) {
+            matCastSet.insert(BCI);
+            bUpdated = true;
+          }
+        }
+      }
+    }
+
+    DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
+    // Remove bitcast which has CallInst user.
+    if (DM.GetShaderModel()->IsLib()) {
+      for (auto it = matCastSet.begin(); it != matCastSet.end();) {
+        BitCastInst *BCI = *(it++);
+        if (hasCallUser(BCI)) {
+          matCastSet.erase(BCI);
+        }
+      }
+    }
+
+    // Lower matrix first.
+    for (BitCastInst *BCI : matCastSet) {
+      lowerMatrix(BCI, BCI->getOperand(0));
+    }
+    return bUpdated;
+  }
+private:
+  void lowerMatrix(Instruction *M, Value *A);
+  bool hasCallUser(Instruction *M);
+};
+
+}
+
+bool MatrixBitcastLowerPass::hasCallUser(Instruction *M) {
+  for (auto it = M->user_begin(); it != M->user_end();) {
+    User *U = *(it++);
+    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
+      Type *EltTy = GEP->getType()->getPointerElementType();
+      if (dxilutil::IsHLSLMatrixType(EltTy)) {
+        if (hasCallUser(GEP))
+          return true;
+      } else {
+        DXASSERT(0, "invalid GEP for matrix");
+      }
+    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
+      if (hasCallUser(BCI))
+        return true;
+    } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
+      if (isa<VectorType>(LI->getType())) {
+      } else {
+        DXASSERT(0, "invalid load for matrix");
+      }
+    } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
+      Value *V = ST->getValueOperand();
+      if (isa<VectorType>(V->getType())) {
+      } else {
+        DXASSERT(0, "invalid load for matrix");
+      }
+    } else if (isa<CallInst>(U)) {
+      return true;
+    } else {
+      DXASSERT(0, "invalid use of matrix");
+    }
+  }
+  return false;
+}
+
+namespace {
+Value *CreateEltGEP(Value *A, unsigned i, Value *zeroIdx,
+                    IRBuilder<> &Builder) {
+  Value *GEP = nullptr;
+  if (GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A)) {
+    // A should be gep oneDimArray, 0, index * matSize
+    // Here add eltIdx to index * matSize foreach elt.
+    Instruction *EltGEP = GEPA->clone();
+    unsigned eltIdx = EltGEP->getNumOperands() - 1;
+    Value *NewIdx =
+        Builder.CreateAdd(EltGEP->getOperand(eltIdx), Builder.getInt32(i));
+    EltGEP->setOperand(eltIdx, NewIdx);
+    Builder.Insert(EltGEP);
+    GEP = EltGEP;
+  } else {
+    GEP = Builder.CreateInBoundsGEP(A, {zeroIdx, Builder.getInt32(i)});
+  }
+  return GEP;
+}
+} // namespace
+
+void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
+  for (auto it = M->user_begin(); it != M->user_end();) {
+    User *U = *(it++);
+    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
+      Type *EltTy = GEP->getType()->getPointerElementType();
+      if (dxilutil::IsHLSLMatrixType(EltTy)) {
+        // Change gep matrixArray, 0, index
+        // into
+        //   gep oneDimArray, 0, index * matSize
+        IRBuilder<> Builder(GEP);
+        SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
+        DXASSERT(idxList.size() == 2,
+                 "else not one dim matrix array index to matrix");
+
+        HLMatrixType MatTy = HLMatrixType::cast(EltTy);
+        Value *matSize = Builder.getInt32(MatTy.getNumElements());
+        idxList.back() = Builder.CreateMul(idxList.back(), matSize);
+        Value *NewGEP = Builder.CreateGEP(A, idxList);
+        lowerMatrix(GEP, NewGEP);
+        DXASSERT(GEP->user_empty(), "else lower matrix fail");
+        GEP->eraseFromParent();
+      } else {
+        DXASSERT(0, "invalid GEP for matrix");
+      }
+    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
+      lowerMatrix(BCI, A);
+      DXASSERT(BCI->user_empty(), "else lower matrix fail");
+      BCI->eraseFromParent();
+    } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
+      if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
+        IRBuilder<> Builder(LI);
+        Value *zeroIdx = Builder.getInt32(0);
+        unsigned vecSize = Ty->getNumElements();
+        Value *NewVec = UndefValue::get(LI->getType());
+        for (unsigned i = 0; i < vecSize; i++) {
+          Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
+          Value *Elt = Builder.CreateLoad(GEP);
+          NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
+        }
+        LI->replaceAllUsesWith(NewVec);
+        LI->eraseFromParent();
+      } else {
+        DXASSERT(0, "invalid load for matrix");
+      }
+    } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
+      Value *V = ST->getValueOperand();
+      if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
+        IRBuilder<> Builder(LI);
+        Value *zeroIdx = Builder.getInt32(0);
+        unsigned vecSize = Ty->getNumElements();
+        for (unsigned i = 0; i < vecSize; i++) {
+          Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
+          Value *Elt = Builder.CreateExtractElement(V, i);
+          Builder.CreateStore(Elt, GEP);
+        }
+        ST->eraseFromParent();
+      } else {
+        DXASSERT(0, "invalid load for matrix");
+      }
+    } else {
+      DXASSERT(0, "invalid use of matrix");
+    }
+  }
+}
+
+char MatrixBitcastLowerPass::ID = 0;
+FunctionPass *llvm::createMatrixBitcastLowerPass() { return new MatrixBitcastLowerPass(); }
+
+INITIALIZE_PASS(MatrixBitcastLowerPass, "matrixbitcastlower", "Matrix Bitcast lower", false, false)

+ 1254 - 2558
lib/HLSL/HLMatrixLowerPass.cpp

@@ -9,8 +9,9 @@
 //                                                                           //
 ///////////////////////////////////////////////////////////////////////////////
 
-#include "dxc/HLSL/HLMatrixLowerHelper.h"
 #include "dxc/HLSL/HLMatrixLowerPass.h"
+#include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/HLSL/HLOperations.h"
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/DXIL/DxilUtil.h"
@@ -19,6 +20,7 @@
 #include "dxc/DXIL/DxilOperations.h"
 #include "dxc/DXIL/DxilTypeSystem.h"
 #include "dxc/DXIL/DxilModule.h"
+#include "HLMatrixSubscriptUseReplacer.h"
 
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Module.h"
@@ -36,356 +38,146 @@ using namespace hlsl::HLMatrixLower;
 namespace hlsl {
 namespace HLMatrixLower {
 
-// If user is function call, return param annotation to get matrix major.
-DxilFieldAnnotation *FindAnnotationFromMatUser(Value *Mat,
-                                               DxilTypeSystem &typeSys) {
-  for (User *U : Mat->users()) {
-    if (CallInst *CI = dyn_cast<CallInst>(U)) {
-      Function *F = CI->getCalledFunction();
-      if (DxilFunctionAnnotation *Anno = typeSys.GetFunctionAnnotation(F)) {
-        for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
-          if (CI->getArgOperand(i) == Mat) {
-            return &Anno->GetParameterAnnotation(i);
-          }
-        }
-      }
-    }
-  }
-  return nullptr;
-}
-
-// Translate matrix type to vector type.
-Type *LowerMatrixType(Type *Ty, bool forMem) {
-  // Only translate matrix type and function type which use matrix type.
-  // Not translate struct has matrix or matrix pointer.
-  // Struct should be flattened before.
-  // Pointer could cover by matldst which use vector as value type.
-  if (FunctionType *FT = dyn_cast<FunctionType>(Ty)) {
-    Type *RetTy = LowerMatrixType(FT->getReturnType());
-    SmallVector<Type *, 4> params;
-    for (Type *param : FT->params()) {
-      params.emplace_back(LowerMatrixType(param));
-    }
-    return FunctionType::get(RetTy, params, false);
-  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
-    unsigned row, col;
-    Type *EltTy = GetMatrixInfo(Ty, col, row);
-    if (forMem && EltTy->isIntegerTy(1))
-      EltTy = Type::getInt32Ty(Ty->getContext());
-    return VectorType::get(EltTy, row * col);
-  } else {
-    return Ty;
-  }
-}
-
-// Translate matrix type to array type.
-Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
-  if (dxilutil::IsHLSLMatrixType(Ty)) {
-    unsigned row, col;
-    Type *EltTy = GetMatrixInfo(Ty, col, row);
-    return ArrayType::get(EltTy, row * col);
-  } else {
-    return Ty;
-  }
-}
-
-
-Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
-  DXASSERT(dxilutil::IsHLSLMatrixType(Ty), "not matrix type");
-  StructType *ST = cast<StructType>(Ty);
-  Type *EltTy = ST->getElementType(0);
-  Type *RowTy = EltTy->getArrayElementType();
-  row = EltTy->getArrayNumElements();
-  col = RowTy->getVectorNumElements();
-  return RowTy->getVectorElementType();
-}
-
-bool IsMatrixArrayPointer(llvm::Type *Ty) {
-  if (!Ty->isPointerTy())
-    return false;
-  Ty = Ty->getPointerElementType();
-  if (!Ty->isArrayTy())
-    return false;
-  while (Ty->isArrayTy())
-    Ty = Ty->getArrayElementType();
-  return dxilutil::IsHLSLMatrixType(Ty);
-}
-Type *LowerMatrixArrayPointer(Type *Ty, bool forMem) {
-  unsigned addrSpace = Ty->getPointerAddressSpace();
-  Ty = Ty->getPointerElementType();
-  std::vector<unsigned> arraySizeList;
-  while (Ty->isArrayTy()) {
-    arraySizeList.push_back(Ty->getArrayNumElements());
-    Ty = Ty->getArrayElementType();
-  }
-  Ty = LowerMatrixType(Ty, forMem);
-
-  for (auto arraySize = arraySizeList.rbegin();
-       arraySize != arraySizeList.rend(); arraySize++)
-    Ty = ArrayType::get(Ty, *arraySize);
-  return PointerType::get(Ty, addrSpace);
-}
-
-Type *LowerMatrixArrayPointerToOneDimArray(Type *Ty) {
-  unsigned addrSpace = Ty->getPointerAddressSpace();
-  Ty = Ty->getPointerElementType();
-
-  unsigned arraySize = 1;
-  while (Ty->isArrayTy()) {
-    arraySize *= Ty->getArrayNumElements();
-    Ty = Ty->getArrayElementType();
-  }
-  unsigned row, col;
-  Type *EltTy = GetMatrixInfo(Ty, col, row);
-  arraySize *= row*col;
-
-  Ty = ArrayType::get(EltTy, arraySize);
-  return PointerType::get(Ty, addrSpace);
-}
-Value *BuildVector(Type *EltTy, unsigned size, ArrayRef<llvm::Value *> elts,
-  IRBuilder<> &Builder) {
-  Value *Vec = UndefValue::get(VectorType::get(EltTy, size));
-  for (unsigned i = 0; i < size; i++)
+Value *BuildVector(Type *EltTy, ArrayRef<llvm::Value *> elts, IRBuilder<> &Builder) {
+  Value *Vec = UndefValue::get(VectorType::get(EltTy, static_cast<unsigned>(elts.size())));
+  for (unsigned i = 0; i < elts.size(); i++)
     Vec = Builder.CreateInsertElement(Vec, elts[i], i);
   return Vec;
 }
 
-llvm::Value *VecMatrixMemToReg(llvm::Value *VecVal, llvm::Type *MatType,
-  llvm::IRBuilder<> &Builder)
-{
-  llvm::Type *VecMatRegTy = HLMatrixLower::LowerMatrixType(MatType, /*forMem*/false);
-  if (VecVal->getType() == VecMatRegTy) {
-    return VecVal;
-  }
-
-  DXASSERT(VecMatRegTy->getVectorElementType()->isIntegerTy(1),
-    "Vector matrix mem to reg type mismatch should only happen for bools.");
-  llvm::Type *VecMatMemTy = HLMatrixLower::LowerMatrixType(MatType, /*forMem*/true);
-  return Builder.CreateICmpNE(VecVal, Constant::getNullValue(VecMatMemTy));
-}
-
-llvm::Value *VecMatrixRegToMem(llvm::Value* VecVal, llvm::Type *MatType,
-  llvm::IRBuilder<> &Builder)
-{
-  llvm::Type *VecMatMemTy = HLMatrixLower::LowerMatrixType(MatType, /*forMem*/true);
-  if (VecVal->getType() == VecMatMemTy) {
-    return VecVal;
-  }
-
-  DXASSERT(VecVal->getType()->getVectorElementType()->isIntegerTy(1),
-    "Vector matrix reg to mem type mismatch should only happen for bools.");
-  return Builder.CreateZExt(VecVal, VecMatMemTy);
-}
-
-llvm::Instruction *CreateVecMatrixLoad(
-  llvm::Value *VecPtr, llvm::Type *MatType, llvm::IRBuilder<> &Builder)
-{
-  llvm::Instruction *VecVal = Builder.CreateLoad(VecPtr);
-  return cast<llvm::Instruction>(VecMatrixMemToReg(VecVal, MatType, Builder));
-}
-
-llvm::Instruction *CreateVecMatrixStore(llvm::Value* VecVal, llvm::Value *VecPtr,
-  llvm::Type *MatType, llvm::IRBuilder<> &Builder)
-{
-  llvm::Type *VecMatMemTy = HLMatrixLower::LowerMatrixType(MatType, /*forMem*/true);
-  if (VecVal->getType() == VecMatMemTy) {
-    return Builder.CreateStore(VecVal, VecPtr);
-  }
-
-  // We need to convert to the memory representation, and we want to return
-  // the conversion instruction rather than the store since that's what
-  // accepts the register-typed i1 values.
-
-  // Do not use VecMatrixRegToMem as it may constant fold the conversion
-  // instruction, which is what we want to return.
-  DXASSERT(VecVal->getType()->getVectorElementType()->isIntegerTy(1),
-    "Vector matrix reg to mem type mismatch should only happen for bools.");
-
-  llvm::Instruction *ConvInst = Builder.Insert(new ZExtInst(VecVal, VecMatMemTy));
-  Builder.CreateStore(ConvInst, VecPtr);
-  return ConvInst;
-}
-
-Value *LowerGEPOnMatIndexListToIndex(
-    llvm::GetElementPtrInst *GEP, ArrayRef<Value *> IdxList) {
-  IRBuilder<> Builder(GEP);
-  Value *zero = Builder.getInt32(0);
-  DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
-  Value *baseIdx = (GEP->idx_begin())->get();
-  DXASSERT_LOCALVAR(baseIdx, baseIdx == zero, "base index must be 0");
-  Value *Idx = (GEP->idx_begin() + 1)->get();
-
-  if (ConstantInt *immIdx = dyn_cast<ConstantInt>(Idx)) {
-    return IdxList[immIdx->getSExtValue()];
-  } else {
-    IRBuilder<> AllocaBuilder(
-        GEP->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
-    unsigned size = IdxList.size();
-    // Store idxList to temp array.
-    ArrayType *AT = ArrayType::get(IdxList[0]->getType(), size);
-    Value *tempArray = AllocaBuilder.CreateAlloca(AT);
-
-    for (unsigned i = 0; i < size; i++) {
-      Value *EltPtr = Builder.CreateGEP(tempArray, {zero, Builder.getInt32(i)});
-      Builder.CreateStore(IdxList[i], EltPtr);
-    }
-    // Load the idx.
-    Value *GEPOffset = Builder.CreateGEP(tempArray, {zero, Idx});
-    return Builder.CreateLoad(GEPOffset);
-  }
-}
-
-
-unsigned GetColMajorIdx(unsigned r, unsigned c, unsigned row) {
-  return c * row + r;
-}
-unsigned GetRowMajorIdx(unsigned r, unsigned c, unsigned col) {
-  return r * col + c;
-}
-
 } // namespace HLMatrixLower
 } // namespace hlsl
 
 namespace {
 
-class HLMatrixLowerPass : public ModulePass {
+// Creates and manages a set of temporary overloaded functions keyed on the function type,
+// and which should be destroyed when the pool gets out of scope.
+class TempOverloadPool {
+public:
+  TempOverloadPool(llvm::Module &Module, const char* BaseName)
+    : Module(Module), BaseName(BaseName) {}
+  ~TempOverloadPool() { clear(); }
+
+  Function *get(FunctionType *Ty);
+  bool contains(FunctionType *Ty) const { return Funcs.count(Ty) != 0; }
+  bool contains(Function *Func) const;
+  void clear();
 
+private:
+  llvm::Module &Module;
+  const char* BaseName;
+  llvm::DenseMap<FunctionType*, Function*> Funcs;
+};
+
+Function *TempOverloadPool::get(FunctionType *Ty) {
+  auto It = Funcs.find(Ty);
+  if (It != Funcs.end()) return It->second;
+
+  std::string MangledName;
+  raw_string_ostream MangledNameStream(MangledName);
+  MangledNameStream << BaseName;
+  MangledNameStream << '.';
+  Ty->print(MangledNameStream);
+  MangledNameStream.flush();
+
+  Function* Func = cast<Function>(Module.getOrInsertFunction(MangledName, Ty));
+  Funcs.insert(std::make_pair(Ty, Func));
+  return Func;
+}
+
+bool TempOverloadPool::contains(Function *Func) const {
+  auto It = Funcs.find(Func->getFunctionType());
+  return It != Funcs.end() && It->second == Func;
+}
+
+void TempOverloadPool::clear() {
+  for (auto Entry : Funcs) {
+    DXASSERT(Entry.second->use_empty(), "Temporary function still used during pool destruction.");
+    Entry.second->removeFromParent();
+  }
+  Funcs.clear();
+}
+
+// High-level matrix lowering pass.
+//
+// This pass converts matrices to their lowered vector representations,
+// including global variables, local variables and operations,
+// but not function signatures (arguments and return types) - left to HLSignatureLower and HLMatrixBitcastLower,
+// nor matrices obtained from resources or constant - left to HLOperationLower.
+//
+// Algorithm overview:
+// 1. Find all matrix and matrix array global variables and lower them to vectors.
+//    Walk any GEPs and insert vec-to-mat translation stubs so that consuming
+//    instructions keep dealing with matrix types for the moment.
+// 2. For each function
+// 2a. Lower all matrix and matrix array allocas, just like global variables.
+// 2b. Lower all other instructions producing or consuming matrices
+//
+// Conversion stubs are used to allow converting instructions in isolation,
+// and in an order-independent manner:
+//
+// Initial: MatInst1(MatInst2(MatInst3))
+// After lowering MatInst2: MatInst1(VecToMat(VecInst2(MatToVec(MatInst3))))
+// After lowering MatInst1: VecInst1(VecInst2(MatToVec(MatInst3)))
+// After lowering MatInst3: VecInst1(VecInst2(VecInst3))
+class HLMatrixLowerPass : public ModulePass {
 public:
   static char ID; // Pass identification, replacement for typeid
   explicit HLMatrixLowerPass() : ModulePass(ID) {}
 
   const char *getPassName() const override { return "HL matrix lower"; }
+  bool runOnModule(Module &M) override;
 
-  bool runOnModule(Module &M) override {
-    m_pModule = &M;
-    m_pHLModule = &m_pModule->GetOrCreateHLModule();
-    // Load up debug information, to cross-reference values and the instructions
-    // used to load them.
-    m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
-
-    for (Function &F : M.functions()) {
-
-      if (F.isDeclaration())
-        continue;
-      runOnFunction(F);
-    }
-    std::vector<GlobalVariable*> staticGVs;
-    for (GlobalVariable &GV : M.globals()) {
-      if (dxilutil::IsStaticGlobal(&GV) ||
-          dxilutil::IsSharedMemoryGlobal(&GV)) {
-        staticGVs.emplace_back(&GV);
-      }
-    }
-
-    for (GlobalVariable *GV : staticGVs)
-      runOnGlobal(GV);
-
-    return true;
-  }
+private:
+  void runOnFunction(Function &Func);
+  void addToDeadInsts(Instruction *Inst) { m_deadInsts.emplace_back(Inst); }
+  void deleteDeadInsts();
+
+  void getMatrixAllocasAndOtherInsts(Function &Func,
+    std::vector<AllocaInst*> &MatAllocas, std::vector<Instruction*> &MatInsts);
+  Value *getLoweredByValOperand(Value *Val, IRBuilder<> &Builder, bool DiscardStub = false);
+  Value *tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Builder, bool DiscardStub = false);
+  Value *bitCastValue(Value *SrcVal, Type* DstTy, bool DstTyAlloca, IRBuilder<> &Builder);
+  void replaceAllUsesByLoweredValue(Instruction *MatInst, Value *VecVal);
+  void replaceAllVariableUses(Value* MatPtr, Value* LoweredPtr);
+  void replaceAllVariableUses(SmallVectorImpl<Value*> &GEPIdxStack, Value *StackTopPtr, Value* LoweredPtr);
+
+  void lowerGlobal(GlobalVariable *Global);
+  Constant *lowerConstInitVal(Constant *Val);
+  AllocaInst *lowerAlloca(AllocaInst *MatAlloca);
+  void lowerInstruction(Instruction* Inst);
+  void lowerReturn(ReturnInst* Return);
+  Value *lowerCall(CallInst *Call);
+  Value *lowerNonHLCall(CallInst *Call);
+  Value *lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeGroup);
+  Value *lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode);
+  Value *lowerHLMulIntrinsic(Value* Lhs, Value *Rhs, bool Unsigned, IRBuilder<> &Builder);
+  Value *lowerHLTransposeIntrinsic(Value *MatVal, IRBuilder<> &Builder);
+  Value *lowerHLDeterminantIntrinsic(Value *MatVal, IRBuilder<> &Builder);
+  Value *lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode, IRBuilder<> &Builder);
+  Value *lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode, IRBuilder<> &Builder);
+  Value *lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode);
+  Value *lowerHLLoad(Value *MatPtr, bool RowMajor, IRBuilder<> &Builder);
+  Value *lowerHLStore(Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder);
+  Value *lowerHLCast(Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder);
+  Value *lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
+  Value *lowerHLMatElementSubscript(CallInst *Call, bool RowMajor);
+  Value *lowerHLMatSubscript(CallInst *Call, bool RowMajor);
+  void lowerHLMatSubscript(CallInst *Call, Value *MatPtr, SmallVectorImpl<Value*> &ElemIndices);
+  Value *lowerHLMatResourceSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
+  Value *lowerHLInit(CallInst *Call);
+  Value *lowerHLSelect(CallInst *Call);
 
 private:
   Module *m_pModule;
   HLModule *m_pHLModule;
   bool m_HasDbgInfo;
+
+  // Pools for the translation stubs
+  TempOverloadPool *m_matToVecStubs = nullptr;
+  TempOverloadPool *m_vecToMatStubs = nullptr;
+
   std::vector<Instruction *> m_deadInsts;
-  // For instruction like matrix array init.
-  // May use more than 1 matrix alloca inst.
-  // This set is here to avoid put it into deadInsts more than once.
-  std::unordered_set<Instruction *> m_inDeadInstsSet;
-  // For most matrix insturction users, it will only have one matrix use.
-  // Use vector so save deadInsts because vector is cheap.
-  void AddToDeadInsts(Instruction *I) { m_deadInsts.emplace_back(I); }
-  // In case instruction has more than one matrix use.
-  // Use AddToDeadInstsWithDups to make sure it's not add to deadInsts more than once.
-  void AddToDeadInstsWithDups(Instruction *I) {
-    if (m_inDeadInstsSet.count(I) == 0) {
-      // Only add to deadInsts when it's not inside m_inDeadInstsSet.
-      m_inDeadInstsSet.insert(I);
-      AddToDeadInsts(I);
-    }
-  }
-  void runOnFunction(Function &F);
-  void runOnGlobal(GlobalVariable *GV);
-  void runOnGlobalMatrixArray(GlobalVariable *GV);
-  Instruction *MatCastToVec(CallInst *CI);
-  Instruction *MatLdStToVec(CallInst *CI);
-  Instruction *MatSubscriptToVec(CallInst *CI);
-  Instruction *MatFrExpToVec(CallInst *CI);
-  Instruction *MatIntrinsicToVec(CallInst *CI);
-  Instruction *TrivialMatUnOpToVec(CallInst *CI);
-  // Replace matVal with vecVal on matUseInst.
-  void TrivialMatUnOpReplace(Value *matVal, Value *vecVal,
-                            CallInst *matUseInst);
-  Instruction *TrivialMatBinOpToVec(CallInst *CI);
-  // Replace matVal with vecVal on matUseInst.
-  void TrivialMatBinOpReplace(Value *matVal, Value *vecVal,
-                             CallInst *matUseInst);
-  // Replace matVal with vecVal on mulInst.
-  void TranslateMatMatMul(Value *matVal, Value *vecVal,
-                          CallInst *mulInst, bool isSigned);
-  void TranslateMatVecMul(Value *matVal, Value *vecVal,
-                          CallInst *mulInst, bool isSigned);
-  void TranslateVecMatMul(Value *matVal, Value *vecVal,
-                          CallInst *mulInst, bool isSigned);
-  void TranslateMul(Value *matVal, Value *vecVal, CallInst *mulInst,
-                    bool isSigned);
-  // Replace matVal with vecVal on transposeInst.
-  void TranslateMatTranspose(Value *matVal, Value *vecVal,
-                             CallInst *transposeInst);
-  void TranslateMatDeterminant(Value *matVal, Value *vecVal,
-                             CallInst *determinantInst);
-  void MatIntrinsicReplace(Value *matVal, Value *vecVal,
-                           CallInst *matUseInst);
-  // Replace matVal with vecVal on castInst.
-  void TranslateMatMatCast(Value *matVal, Value *vecVal,
-                           CallInst *castInst);
-  void TranslateMatToOtherCast(Value *matVal, Value *vecVal,
-                               CallInst *castInst);
-  void TranslateMatCast(Value *matVal, Value *vecVal,
-                        CallInst *castInst);
-  void TranslateMatMajorCast(Value *matVal, Value *vecVal,
-                        CallInst *castInst, bool rowToCol, bool transpose);
-  // Replace matVal with vecVal in matSubscript
-  void TranslateMatSubscript(Value *matVal, Value *vecVal,
-                             CallInst *matSubInst);
-  // Replace matInitInst using matToVecMap
-  void TranslateMatInit(CallInst *matInitInst);
-  // Replace matSelectInst using matToVecMap
-  void TranslateMatSelect(CallInst *matSelectInst);
-  // Replace matVal with vecVal on matInitInst.
-  void TranslateMatArrayGEP(Value *matVal, Value *vecVal,
-                            GetElementPtrInst *matGEP);
-  void TranslateMatLoadStoreOnGlobal(Value *matGlobal, ArrayRef<Value *>vecGlobals,
-                             CallInst *matLdStInst);
-  void TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
-                             CallInst *matLdStInst);
-  void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
-  void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
-
-  // Get new matrix value corresponding to vecVal
-  Value *GetMatrixForVec(Value *vecVal, Type *matTy);
-
-  // Translate library function input/output to preserve function signatures
-  void TranslateArgForLibFunc(CallInst *CI);
-  void TranslateArgsForLibFunc(Function &F);
-
-  // Replace matVal with vecVal on matUseInst.
-  void TrivialMatReplace(Value *matVal, Value *vecVal,
-                        CallInst *matUseInst);
-  // Lower a matrix type instruction to a vector type instruction.
-  void lowerToVec(Instruction *matInst);
-  // Lower users of a matrix type instruction.
-  void replaceMatWithVec(Value *matVal, Value *vecVal);
-  // Translate user library function call arguments
-  void castMatrixArgs(Value *matVal, Value *vecVal, CallInst *CI);
-  // Translate mat inst which need all operands ready.
-  void finalMatTranslation(Value *matVal);
-  // Delete dead insts in m_deadInsts.
-  void DeleteDeadInsts();
-  // Map from matrix value to its vector version.
-  DenseMap<Value *, Value *> matToVecMap;
-  // Map from new vector version to matrix version needed by user call or return.
-  DenseMap<Value *, Value *> vecToMatMap;
 };
 }
 
@@ -395,2445 +187,1349 @@ ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }
 
 INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
 
-static Instruction *CreateTypeCast(HLCastOpcode castOp, Type *toTy, Value *src,
-                                   IRBuilder<> Builder) {
-  Type *srcTy = src->getType();
-
-  // Conversions between equivalent types are no-ops,
-  // even between signed/unsigned variants.
-  if (srcTy == toTy) return cast<Instruction>(src);
+bool HLMatrixLowerPass::runOnModule(Module &M) {
+  TempOverloadPool matToVecStubs(M, "hlmatrixlower.mat2vec");
+  TempOverloadPool vecToMatStubs(M, "hlmatrixlower.vec2mat");
+
+  m_pModule = &M;
+  m_pHLModule = &m_pModule->GetOrCreateHLModule();
+  // Load up debug information, to cross-reference values and the instructions
+  // used to load them.
+  m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
+  m_matToVecStubs = &matToVecStubs;
+  m_vecToMatStubs = &vecToMatStubs;
+
+  // First, lower static global variables.
+  // We need to accumulate them locally because we'll be creating new ones as we lower them.
+  std::vector<GlobalVariable*> Globals;
+  for (GlobalVariable &Global : M.globals()) {
+    if ((dxilutil::IsStaticGlobal(&Global) || dxilutil::IsSharedMemoryGlobal(&Global))
+      && HLMatrixType::isMatrixPtrOrArrayPtr(Global.getType())) {
+      Globals.emplace_back(&Global);
+    }
+  }
 
-  bool fromUnsigned = castOp == HLCastOpcode::FromUnsignedCast ||
-                      castOp == HLCastOpcode::UnsignedUnsignedCast;
-  bool toUnsigned = castOp == HLCastOpcode::ToUnsignedCast ||
-                    castOp == HLCastOpcode::UnsignedUnsignedCast;
+  for (GlobalVariable *Global : Globals)
+    lowerGlobal(Global);
 
-  // Conversions to bools are comparisons
-  if (toTy->getScalarSizeInBits() == 1) {
-    // fcmp une is what regular clang uses in C++ for (bool)f;
-    return cast<Instruction>(srcTy->isIntOrIntVectorTy()
-      ? Builder.CreateICmpNE(src, llvm::Constant::getNullValue(srcTy), "tobool")
-      : Builder.CreateFCmpUNE(src, llvm::Constant::getNullValue(srcTy), "tobool"));
+  for (Function &F : M.functions()) {
+    if (F.isDeclaration()) continue;
+    runOnFunction(F);
   }
 
-  // Cast necessary
-  auto CastOp = static_cast<Instruction::CastOps>(HLModule::GetNumericCastOp(
-    srcTy, fromUnsigned, toTy, toUnsigned));
-  return cast<Instruction>(Builder.CreateCast(CastOp, src, toTy));
+  m_pModule = nullptr;
+  m_pHLModule = nullptr;
+  m_matToVecStubs = nullptr;
+  m_vecToMatStubs = nullptr;
+
+  // If you hit an assert during TempOverloadPool destruction,
+  // it means that either a matrix producer was lowered,
+  // causing a translation stub to be created,
+  // but the consumer of that matrix was never (properly) lowered.
+  // Or the opposite: a matrix consumer was lowered and not its producer.
+
+  return true;
 }
 
-Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
-  IRBuilder<> Builder(CI);
-  Value *op = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
-  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(CI));
-
-  bool ToMat = dxilutil::IsHLSLMatrixType(CI->getType());
-  bool FromMat = dxilutil::IsHLSLMatrixType(op->getType());
-  if (ToMat && !FromMat) {
-    // Translate OtherToMat here.
-    // Rest will translated when replace.
-    unsigned col, row;
-    Type *EltTy = GetMatrixInfo(CI->getType(), col, row);
-    unsigned toSize = col * row;
-    Instruction *sizeCast = nullptr;
-    Type *FromTy = op->getType();
-    Type *I32Ty = IntegerType::get(FromTy->getContext(), 32);
-    if (FromTy->isVectorTy()) {
-      std::vector<Constant *> MaskVec(toSize);
-      for (size_t i = 0; i != toSize; ++i)
-        MaskVec[i] = ConstantInt::get(I32Ty, i);
-
-      Value *castMask = ConstantVector::get(MaskVec);
-
-      sizeCast = new ShuffleVectorInst(op, op, castMask);
-      Builder.Insert(sizeCast);
-
-    } else {
-      op = Builder.CreateInsertElement(
-          UndefValue::get(VectorType::get(FromTy, 1)), op, (uint64_t)0);
-      Constant *zero = ConstantInt::get(I32Ty, 0);
-      std::vector<Constant *> MaskVec(toSize, zero);
-      Value *castMask = ConstantVector::get(MaskVec);
-
-      sizeCast = new ShuffleVectorInst(op, op, castMask);
-      Builder.Insert(sizeCast);
-    }
-    Instruction *typeCast = sizeCast;
-    if (EltTy != FromTy->getScalarType()) {
-      typeCast = CreateTypeCast(opcode, VectorType::get(EltTy, toSize),
-                                sizeCast, Builder);
-    }
-    return typeCast;
-  } else if (FromMat && ToMat) {
-    if (isa<Argument>(op)) {
-      // Cast From mat to mat for arugment.
-      IRBuilder<> Builder(CI);
-
-      // Here only lower the return type to vector.
-      Type *RetTy = LowerMatrixType(CI->getType());
-      SmallVector<Type *, 4> params;
-      for (Value *operand : CI->arg_operands()) {
-        params.emplace_back(operand->getType());
+void HLMatrixLowerPass::runOnFunction(Function &Func) {
+  // Skip hl function definition (like createhandle)
+  if (hlsl::GetHLOpcodeGroupByName(&Func) != HLOpcodeGroup::NotHL)
+    return;
+
+  // Save the matrix instructions first since the translation process
+  // will temporarily create other instructions consuming/producing matrix types.
+  std::vector<AllocaInst*> MatAllocas;
+  std::vector<Instruction*> MatInsts;
+  getMatrixAllocasAndOtherInsts(Func, MatAllocas, MatInsts);
+
+  // First lower all allocas and take care of their GEP chains
+  for (AllocaInst* MatAlloca : MatAllocas) {
+    AllocaInst* LoweredAlloca = lowerAlloca(MatAlloca);
+    replaceAllVariableUses(MatAlloca, LoweredAlloca);
+    addToDeadInsts(MatAlloca);
+  }
+
+  // Now lower all other matrix instructions
+  for (Instruction *MatInst : MatInsts)
+    lowerInstruction(MatInst);
+
+  deleteDeadInsts();
+}
+
+void HLMatrixLowerPass::deleteDeadInsts() {
+  while (!m_deadInsts.empty()) {
+    Instruction *Inst = m_deadInsts.back();
+    m_deadInsts.pop_back();
+
+    DXASSERT_NOMSG(Inst->use_empty());
+    for (Value *Operand : Inst->operand_values()) {
+      Instruction *OperandInst = dyn_cast<Instruction>(Operand);
+      if (OperandInst && ++OperandInst->user_begin() == OperandInst->user_end()) {
+        // We were its only user, erase recursively.
+        // This will get rid of translation stubs:
+        // Original: MatConsumer(MatProducer)
+        // Producer lowered: MatConsumer(VecToMat(VecProducer)), MatProducer dead
+        // Consumer lowered: VecConsumer(VecProducer)), MatConsumer(VecToMat) dead
+        // Only by recursing on MatConsumer's operand do we delete the VecToMat stub.
+        DXASSERT_NOMSG(*OperandInst->user_begin() == Inst);
+        m_deadInsts.emplace_back(OperandInst);
       }
+    }
+
+    Inst->eraseFromParent();
+  }
+}
 
-      Type *FT = FunctionType::get(RetTy, params, false);
+// Find all instructions consuming or producing matrices,
+// directly or through pointers/arrays.
+void HLMatrixLowerPass::getMatrixAllocasAndOtherInsts(Function &Func,
+    std::vector<AllocaInst*> &MatAllocas, std::vector<Instruction*> &MatInsts){
+  for (BasicBlock &BasicBlock : Func) {
+    for (Instruction &Inst : BasicBlock) {
+      // Don't lower GEPs directly, we'll handle them as we lower the root pointer,
+      // typically a global variable or alloca.
+      if (isa<GetElementPtrInst>(&Inst)) continue;
 
-      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-      unsigned opcode = GetHLOpcode(CI);
+      if (AllocaInst *Alloca = dyn_cast<AllocaInst>(&Inst)) {
+        if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Alloca->getType())) {
+          MatAllocas.emplace_back(Alloca);
+        }
+        continue;
+      }
+      
+      if (CallInst *Call = dyn_cast<CallInst>(&Inst)) {
+        // Lowering of global variables will have introduced
+        // vec-to-mat translation stubs, which we deal with indirectly,
+        // as we lower the instructions consuming them.
+        if (m_vecToMatStubs->contains(Call->getCalledFunction()))
+          continue;
+
+        // Mat-to-vec stubs should only be introduced during instruction lowering.
+        // Globals lowering won't introduce any because their only operand is
+        // their initializer, which we can fully lower without stubbing since it is constant.
+        DXASSERT(!m_matToVecStubs->contains(Call->getCalledFunction()),
+          "Unexpected mat-to-vec stubbing before function instruction lowering.");
+
+        // Match matrix producers
+        if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Inst.getType())) {
+          MatInsts.emplace_back(Call);
+          continue;
+        }
 
-      Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT),
-                                             group, opcode);
+        // Match matrix consumers
+        for (Value *Operand : Inst.operand_values()) {
+          if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Operand->getType())) {
+            MatInsts.emplace_back(Call);
+            break;
+          }
+        }
+
+        continue;
+      }
 
-      SmallVector<Value *, 4> argList;
-      for (Value *arg : CI->arg_operands()) {
-        argList.emplace_back(arg);
+      if (ReturnInst *Return = dyn_cast<ReturnInst>(&Inst)) {
+        Value *ReturnValue = Return->getReturnValue();
+        if (ReturnValue != nullptr && HLMatrixType::isMatrixOrPtrOrArrayPtr(ReturnValue->getType()))
+          MatInsts.emplace_back(Return);
+        continue;
       }
 
-      return Builder.CreateCall(vecF, argList);
+      // Nothing else should produce or consume matrices
     }
   }
-
-  return MatIntrinsicToVec(CI);
 }
 
-// Return GEP if value is Matrix resulting GEP from UDT alloca
-// UDT alloca must be there for library function args
-static GetElementPtrInst *GetIfMatrixGEPOfUDTAlloca(Value *V) {
-  if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
-    if (dxilutil::IsHLSLMatrixType(GEP->getResultElementType())) {
-      Value *ptr = GEP->getPointerOperand();
-      if (AllocaInst *AI = dyn_cast<AllocaInst>(ptr)) {
-        Type *ATy = AI->getAllocatedType();
-        if (ATy->isStructTy() && !dxilutil::IsHLSLMatrixType(ATy)) {
-          return GEP;
-        }
+// Gets the matrix-lowered representation of a value, potentially adding a translation stub.
+// DiscardStub causes any vec-to-mat translation stubs to be deleted,
+// it should be true only if the original instruction will be modified and kept alive.
+// If a new instruction is created and the original marked as dead,
+// then the remove dead instructions pass will take care of removing the stub.
+Value* HLMatrixLowerPass::getLoweredByValOperand(Value *Val, IRBuilder<> &Builder, bool DiscardStub) {
+  Type *Ty = Val->getType();
+
+  // We're only lowering byval matrices.
+  // Since structs and arrays are always accessed by pointer,
+  // we do not need to worry about a matrix being hidden inside a more complex type.
+  DXASSERT(!Ty->isPointerTy(), "Value cannot be a pointer.");
+  HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty);
+  if (!MatTy) return Val;
+
+  Type *LoweredTy = MatTy.getLoweredVectorTypeForReg();
+  
+  // Check if the value is already a vec-to-mat translation stub
+  if (CallInst *Call = dyn_cast<CallInst>(Val)) {
+    if (m_vecToMatStubs->contains(Call->getCalledFunction())) {
+      if (DiscardStub && Call->getNumUses() == 1) {
+        Call->use_begin()->set(UndefValue::get(Call->getType()));
+        addToDeadInsts(Call);
       }
+
+      Value *LoweredVal = Call->getArgOperand(0);
+      DXASSERT(LoweredVal->getType() == LoweredTy, "Unexpected already-lowered value type.");
+      return LoweredVal;
     }
   }
-  return nullptr;
+
+  // Return a mat-to-vec translation stub
+  FunctionType *TranslationStubTy = FunctionType::get(LoweredTy, { Ty }, /* isVarArg */ false);
+  Function *TranslationStub = m_matToVecStubs->get(TranslationStubTy);
+  return Builder.CreateCall(TranslationStub, { Val });
 }
 
-// Return GEP if value is Matrix resulting GEP from UDT argument of
-// none-graphics functions.
-static GetElementPtrInst *GetIfMatrixGEPOfUDTArg(Value *V, HLModule &HM) {
-  if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
-    if (dxilutil::IsHLSLMatrixType(GEP->getResultElementType())) {
-      Value *ptr = GEP->getPointerOperand();
-      if (Argument *Arg = dyn_cast<Argument>(ptr)) {
-        if (!HM.IsGraphicsShader(Arg->getParent()))
-          return GEP;
+// Attempts to retrieve the lowered vector pointer equivalent to a matrix pointer.
+// Returns nullptr if the pointed-to matrix lives in memory that cannot be lowered at this time,
+// for example a buffer or shader inputs/outputs, which are lowered during signature lowering.
+Value *HLMatrixLowerPass::tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Builder, bool DiscardStub) {
+  if (!HLMatrixType::isMatrixPtrOrArrayPtr(Ptr->getType()))
+    return nullptr;
+
+  // Matrix pointers can only be derived from Allocas, GlobalVariables or resource accesses.
+  // The first two cases are what this pass must be able to lower, and we should already
+  // have replaced their uses by vector to matrix pointer translation stubs.
+  if (CallInst *Call = dyn_cast<CallInst>(Ptr)) {
+    if (m_vecToMatStubs->contains(Call->getCalledFunction())) {
+      if (DiscardStub && Call->getNumUses() == 1) {
+        Call->use_begin()->set(UndefValue::get(Call->getType()));
+        addToDeadInsts(Call);
       }
+      return Call->getArgOperand(0);
     }
   }
-  return nullptr;
-}
 
-Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
-  IRBuilder<> Builder(CI);
-  unsigned opcode = GetHLOpcode(CI);
-  HLMatLoadStoreOpcode matOpcode = static_cast<HLMatLoadStoreOpcode>(opcode);
-  Instruction *result = nullptr;
-  switch (matOpcode) {
-  case HLMatLoadStoreOpcode::ColMatLoad:
-  case HLMatLoadStoreOpcode::RowMatLoad: {
-    Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
-    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr) ||
-        GetIfMatrixGEPOfUDTArg(matPtr, *m_pHLModule)) {
-      Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
-      result = CreateVecMatrixLoad(vecPtr, matPtr->getType()->getPointerElementType(), Builder);
-    } else
-      result = MatIntrinsicToVec(CI);
-  } break;
-  case HLMatLoadStoreOpcode::ColMatStore:
-  case HLMatLoadStoreOpcode::RowMatStore: {
-    Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
-    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr) ||
-        GetIfMatrixGEPOfUDTArg(matPtr, *m_pHLModule)) {
-      Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
-      Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-      Value *vecVal = UndefValue::get(HLMatrixLower::LowerMatrixType(matVal->getType()));
-      result = CreateVecMatrixStore(vecVal, vecPtr, matVal->getType(), Builder);
-    } else
-      result = MatIntrinsicToVec(CI);
-  } break;
+  // There's one more case to handle.
+  // When compiling shader libraries, signatures won't have been lowered yet.
+  // So we can have a matrix in a struct as an argument,
+  // or an alloca'd struct holding the return value of a call and containing a matrix.
+  Value *RootPtr = Ptr;
+  while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
+    RootPtr = GEP->getPointerOperand();
+
+  Argument *Arg = dyn_cast<Argument>(RootPtr);
+  bool IsNonShaderArg = Arg != nullptr && !m_pHLModule->IsGraphicsShader(Arg->getParent());
+  if (IsNonShaderArg || isa<AllocaInst>(RootPtr)) {
+    // Bitcast the matrix pointer to its lowered equivalent.
+    // The HLMatrixBitcast pass will take care of this later.
+    return Builder.CreateBitCast(Ptr, HLMatrixType::getLoweredType(Ptr->getType()));
   }
-  return result;
+
+  // The pointer must be derived from a resource, we don't handle it in this pass.
+  return nullptr;
 }
 
-Instruction *HLMatrixLowerPass::MatSubscriptToVec(CallInst *CI) {
-  IRBuilder<> Builder(CI);
-  Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
-  if (isa<AllocaInst>(matPtr)) {
-    // Here just create a new matSub call which use vec ptr.
-    // Later in TranslateMatSubscript will do the real translation.
-    std::vector<Value *> args(CI->getNumArgOperands());
-    for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
-      args[i] = CI->getArgOperand(i);
-    }
-    // Change mat ptr into vec ptr.
-    args[HLOperandIndex::kMatSubscriptMatOpIdx] =
-        matToVecMap[cast<Instruction>(matPtr)];
-    std::vector<Type *> paramTyList(CI->getNumArgOperands());
-    for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
-      paramTyList[i] = args[i]->getType();
-    }
+// Bitcasts a value from matrix to vector or vice-versa.
+// This is used to convert to/from arguments/return values since we don't
+// lower signatures in this pass. The later HLMatrixBitcastLower pass fixes this.
+Value *HLMatrixLowerPass::bitCastValue(Value *SrcVal, Type* DstTy, bool DstTyAlloca, IRBuilder<> &Builder) {
+  Type *SrcTy = SrcVal->getType();
+  DXASSERT_NOMSG(!SrcTy->isPointerTy());
 
-    FunctionType *funcTy = FunctionType::get(CI->getType(), paramTyList, false);
-    unsigned opcode = GetHLOpcode(CI);
-    Function *opFunc = GetOrCreateHLFunction(*m_pModule, funcTy, HLOpcodeGroup::HLSubscript, opcode);
-    return Builder.CreateCall(opFunc, args);
-  } else
-    return MatIntrinsicToVec(CI);
+  // We store and load from a temporary alloca, bitcasting either on the store pointer
+  // or on the load pointer.
+  IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
+  Value *Alloca = AllocaBuilder.CreateAlloca(DstTyAlloca ? DstTy : SrcTy);
+  Value *BitCastedAlloca = Builder.CreateBitCast(Alloca, (DstTyAlloca ? SrcTy : DstTy)->getPointerTo());
+  Builder.CreateStore(SrcVal, DstTyAlloca ? BitCastedAlloca : Alloca);
+  return Builder.CreateLoad(DstTyAlloca ? Alloca : BitCastedAlloca);
 }
 
-Instruction *HLMatrixLowerPass::MatFrExpToVec(CallInst *CI) {
-  IRBuilder<> Builder(CI);
-  FunctionType *FT = CI->getCalledFunction()->getFunctionType();
-  Type *RetTy = LowerMatrixType(FT->getReturnType());
-  SmallVector<Type *, 4> params;
-  for (Type *param : FT->params()) {
-    if (!param->isPointerTy()) {
-      params.emplace_back(LowerMatrixType(param));
-    } else {
-      // Lower pointer type for frexp.
-      Type *EltTy = LowerMatrixType(param->getPointerElementType());
-      params.emplace_back(
-          PointerType::get(EltTy, param->getPointerAddressSpace()));
-    }
-  }
+// Replaces all uses of a matrix value by its lowered vector form,
+// inserting translation stubs for users which still expect a matrix value.
+void HLMatrixLowerPass::replaceAllUsesByLoweredValue(Instruction* MatInst, Value* VecVal) {
+  if (VecVal == nullptr || VecVal == MatInst) return;
 
-  Type *VecFT = FunctionType::get(RetTy, params, false);
+  DXASSERT(HLMatrixType::getLoweredType(MatInst->getType()) == VecVal->getType(),
+    "Unexpected lowered value type.");
 
-  HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-  Function *vecF =
-      GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(VecFT), group,
-                            static_cast<unsigned>(IntrinsicOp::IOP_frexp));
+  Instruction *VecToMatStub = nullptr;
 
-  SmallVector<Value *, 4> argList;
-  auto paramTyIt = params.begin();
-  for (Value *arg : CI->arg_operands()) {
-    Type *Ty = arg->getType();
-    Type *ParamTy = *(paramTyIt++);
+  while (!MatInst->use_empty()) {
+    Use &ValUse = *MatInst->use_begin();
 
-    if (Ty != ParamTy)
-      argList.emplace_back(UndefValue::get(ParamTy));
-    else
-      argList.emplace_back(arg);
-  }
+    // Handle non-matrix cases, just point to the new value.
+    if (MatInst->getType() == VecVal->getType()) {
+      ValUse.set(VecVal);
+      continue;
+    }
 
-  return Builder.CreateCall(vecF, argList);
-}
+    // If the user is already a matrix-to-vector translation stub,
+    // we can now replace it by the proper vector value.
+    if (CallInst *Call = dyn_cast<CallInst>(ValUse.getUser())) {
+      if (m_matToVecStubs->contains(Call->getCalledFunction())) {
+        Call->replaceAllUsesWith(VecVal);
+        ValUse.set(UndefValue::get(MatInst->getType()));
+        addToDeadInsts(Call);
+        continue;
+      }
+    }
 
-Instruction *HLMatrixLowerPass::MatIntrinsicToVec(CallInst *CI) {
-  IRBuilder<> Builder(CI);
-  unsigned opcode = GetHLOpcode(CI);
+    // Otherwise, the user should point to a vector-to-matrix translation
+    // stub of the new vector value.
+    if (VecToMatStub == nullptr) {
+      FunctionType *TranslationStubTy = FunctionType::get(
+        MatInst->getType(), { VecVal->getType() }, /* isVarArg */ false);
+      Function *TranslationStub = m_vecToMatStubs->get(TranslationStubTy);
 
-  if (opcode == static_cast<unsigned>(IntrinsicOp::IOP_frexp))
-    return MatFrExpToVec(CI);
+      Instruction *PrevInst = dyn_cast<Instruction>(VecVal);
+      if (PrevInst == nullptr) PrevInst = MatInst;
 
-  Type *FT = LowerMatrixType(CI->getCalledFunction()->getFunctionType());
+      IRBuilder<> Builder(dxilutil::SkipAllocas(PrevInst->getNextNode()));
+      VecToMatStub = Builder.CreateCall(TranslationStub, { VecVal });
+    }
 
-  HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
+    ValUse.set(VecToMatStub);
+  }
+}
+
+// Replaces all uses of a matrix or matrix array alloca or global variable by its lowered equivalent.
+// This doesn't lower the users, but will insert a translation stub from the lowered value pointer
+// back to the matrix value pointer, and recreate any GEPs around the new pointer.
+// Before: User(GEP(MatrixArrayAlloca))
+// After: User(VecToMatPtrStub(GEP'(VectorArrayAlloca)))
+void HLMatrixLowerPass::replaceAllVariableUses(Value* MatPtr, Value* LoweredPtr) {
+  DXASSERT_NOMSG(HLMatrixType::isMatrixPtrOrArrayPtr(MatPtr->getType()));
+  DXASSERT_NOMSG(LoweredPtr->getType() == HLMatrixType::getLoweredType(MatPtr->getType()));
+
+  SmallVector<Value*, 4> GEPIdxStack;
+  GEPIdxStack.emplace_back(ConstantInt::get(Type::getInt32Ty(MatPtr->getContext()), 0));
+  replaceAllVariableUses(GEPIdxStack, MatPtr, LoweredPtr);
+}
+
+void HLMatrixLowerPass::replaceAllVariableUses(
+    SmallVectorImpl<Value*> &GEPIdxStack, Value *StackTopPtr, Value* LoweredPtr) {
+  while (!StackTopPtr->use_empty()) {
+    llvm::Use &Use = *StackTopPtr->use_begin();
+    if (GEPOperator *GEP = dyn_cast<GEPOperator>(Use.getUser())) {
+      DXASSERT(GEP->getNumIndices() >= 1, "Unexpected degenerate GEP.");
+      DXASSERT(cast<ConstantInt>(*GEP->idx_begin())->isZero(), "Unexpected non-zero first GEP index.");
+
+      // Recurse in GEP to find actual users
+      for (auto It = GEP->idx_begin() + 1; It != GEP->idx_end(); ++It)
+        GEPIdxStack.emplace_back(*It);
+      replaceAllVariableUses(GEPIdxStack, GEP, LoweredPtr);
+      GEPIdxStack.erase(GEPIdxStack.end() - (GEP->getNumIndices() - 1), GEPIdxStack.end());
+      
+      // Discard the GEP
+      DXASSERT_NOMSG(GEP->use_empty());
+      if (GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(GEP)) {
+        Use.set(UndefValue::get(Use->getType()));
+        addToDeadInsts(GEPInst);
+      } else {
+        // constant GEP
+        cast<Constant>(GEP)->destroyConstant();
+      }
+      continue;
+    }
 
-  Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT), group, opcode);
+    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Use.getUser())) {
+      DXASSERT(CE->getOpcode() == Instruction::AddrSpaceCast,
+               "Unexpected constant user");
+      replaceAllVariableUses(GEPIdxStack, CE, LoweredPtr);
+      DXASSERT_NOMSG(CE->use_empty());
+      CE->destroyConstant();
+      continue;
+    }
 
-  SmallVector<Value *, 4> argList;
-  for (Value *arg : CI->arg_operands()) {
-    Type *Ty = arg->getType();
-    if (dxilutil::IsHLSLMatrixType(Ty)) {
-      argList.emplace_back(UndefValue::get(LowerMatrixType(Ty)));
-    } else
-      argList.emplace_back(arg);
-  }
+    if (AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(Use.getUser())) {
+      replaceAllVariableUses(GEPIdxStack, CI, LoweredPtr);
+      Use.set(UndefValue::get(Use->getType()));
+      addToDeadInsts(CI);
+      continue;
+    }
 
-  return Builder.CreateCall(vecF, argList);
-}
+    // Recreate the same GEP sequence, if any, on the lowered pointer
+    IRBuilder<> Builder(cast<Instruction>(Use.getUser()));
+    Value *LoweredStackTopPtr = GEPIdxStack.size() == 1
+      ? LoweredPtr : Builder.CreateGEP(LoweredPtr, GEPIdxStack);
 
-Instruction *HLMatrixLowerPass::TrivialMatUnOpToVec(CallInst *CI) {
-  Type *ResultTy = LowerMatrixType(CI->getType());
-  UndefValue *tmp = UndefValue::get(ResultTy);
-  IRBuilder<> Builder(CI);
-  HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(CI));
-  bool isFloat = ResultTy->getVectorElementType()->isFloatingPointTy();
-
-  Constant *one = isFloat
-    ? ConstantFP::get(ResultTy->getVectorElementType(), 1)
-    : ConstantInt::get(ResultTy->getVectorElementType(), 1);
-  Constant *oneVec = ConstantVector::getSplat(ResultTy->getVectorNumElements(), one);
-
-  Instruction *Result = nullptr;
-  switch (opcode) {
-  case HLUnaryOpcode::Plus: {
-    // This is actually a no-op, but the structure of the code here requires
-    // that we create an instruction.
-    Constant *zero = Constant::getNullValue(ResultTy);
-    if (isFloat)
-      Result = BinaryOperator::CreateFAdd(tmp, zero);
-    else
-      Result = BinaryOperator::CreateAdd(tmp, zero);
-  } break;
-  case HLUnaryOpcode::Minus: {
-    Constant *zero = Constant::getNullValue(ResultTy);
-    if (isFloat)
-      Result = BinaryOperator::CreateFSub(zero, tmp);
-    else
-      Result = BinaryOperator::CreateSub(zero, tmp);
-  } break;
-  case HLUnaryOpcode::LNot: {
-    Constant *zero = Constant::getNullValue(ResultTy);
-    if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_UEQ, tmp, zero);
-    else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, zero);
-  } break;
-  case HLUnaryOpcode::Not: {
-    Constant *allOneBits = Constant::getAllOnesValue(ResultTy);
-    Result = BinaryOperator::CreateXor(tmp, allOneBits);
-  } break;
-  case HLUnaryOpcode::PostInc:
-  case HLUnaryOpcode::PreInc:
-    if (isFloat)
-      Result = BinaryOperator::CreateFAdd(tmp, oneVec);
-    else
-      Result = BinaryOperator::CreateAdd(tmp, oneVec);
-    break;
-  case HLUnaryOpcode::PostDec:
-  case HLUnaryOpcode::PreDec:
-    if (isFloat)
-      Result = BinaryOperator::CreateFSub(tmp, oneVec);
-    else
-      Result = BinaryOperator::CreateSub(tmp, oneVec);
-    break;
-  default:
-    DXASSERT(0, "not implement");
-    return nullptr;
+    // Generate a stub translating the vector pointer back to a matrix pointer,
+    // such that consuming instructions are unaffected.
+    FunctionType *TranslationStubTy = FunctionType::get(
+      StackTopPtr->getType(), { LoweredStackTopPtr->getType() }, /* isVarArg */ false);
+    Function *TranslationStub = m_vecToMatStubs->get(TranslationStubTy);
+    Use.set(Builder.CreateCall(TranslationStub, { LoweredStackTopPtr }));
   }
-  Builder.Insert(Result);
-  return Result;
 }
 
-Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
-  Type *ResultTy = LowerMatrixType(CI->getType());
-  IRBuilder<> Builder(CI);
-  HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(CI));
-  Type *OpTy = LowerMatrixType(
-      CI->getOperand(HLOperandIndex::kBinaryOpSrc0Idx)->getType());
-  UndefValue *tmp = UndefValue::get(OpTy);
-  bool isFloat = OpTy->getVectorElementType()->isFloatingPointTy();
+void HLMatrixLowerPass::lowerGlobal(GlobalVariable *Global) {
+  if (Global->user_empty()) return;
 
-  Instruction *Result = nullptr;
+  PointerType *LoweredPtrTy = cast<PointerType>(HLMatrixType::getLoweredType(Global->getType()));
+  DXASSERT_NOMSG(LoweredPtrTy != Global->getType());
 
-  switch (opcode) {
-  case HLBinaryOpcode::Add:
-    if (isFloat)
-      Result = BinaryOperator::CreateFAdd(tmp, tmp);
-    else
-      Result = BinaryOperator::CreateAdd(tmp, tmp);
-    break;
-  case HLBinaryOpcode::Sub:
-    if (isFloat)
-      Result = BinaryOperator::CreateFSub(tmp, tmp);
-    else
-      Result = BinaryOperator::CreateSub(tmp, tmp);
-    break;
-  case HLBinaryOpcode::Mul:
-    if (isFloat)
-      Result = BinaryOperator::CreateFMul(tmp, tmp);
-    else
-      Result = BinaryOperator::CreateMul(tmp, tmp);
-    break;
-  case HLBinaryOpcode::Div:
-    if (isFloat)
-      Result = BinaryOperator::CreateFDiv(tmp, tmp);
-    else
-      Result = BinaryOperator::CreateSDiv(tmp, tmp);
-    break;
-  case HLBinaryOpcode::Rem:
-    if (isFloat)
-      Result = BinaryOperator::CreateFRem(tmp, tmp);
-    else
-      Result = BinaryOperator::CreateSRem(tmp, tmp);
-    break;
-  case HLBinaryOpcode::And:
-    Result = BinaryOperator::CreateAnd(tmp, tmp);
-    break;
-  case HLBinaryOpcode::Or:
-    Result = BinaryOperator::CreateOr(tmp, tmp);
-    break;
-  case HLBinaryOpcode::Xor:
-    Result = BinaryOperator::CreateXor(tmp, tmp);
-    break;
-  case HLBinaryOpcode::Shl: {
-    Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-    DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
-                      "must be matrix type here");
-    Result = BinaryOperator::CreateShl(tmp, tmp);
-  } break;
-  case HLBinaryOpcode::Shr: {
-    Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-    DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
-                      "must be matrix type here");
-    Result = BinaryOperator::CreateAShr(tmp, tmp);
-  } break;
-  case HLBinaryOpcode::LT:
-    if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLT, tmp, tmp);
-    else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLT, tmp, tmp);
-    break;
-  case HLBinaryOpcode::GT:
-    if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGT, tmp, tmp);
-    else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT, tmp, tmp);
-    break;
-  case HLBinaryOpcode::LE:
-    if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLE, tmp, tmp);
-    else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLE, tmp, tmp);
-    break;
-  case HLBinaryOpcode::GE:
-    if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGE, tmp, tmp);
-    else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGE, tmp, tmp);
-    break;
-  case HLBinaryOpcode::EQ:
-    if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, tmp);
-    else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, tmp);
-    break;
-  case HLBinaryOpcode::NE:
-    if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, tmp);
-    else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, tmp);
-    break;
-  case HLBinaryOpcode::UDiv:
-    Result = BinaryOperator::CreateUDiv(tmp, tmp);
-    break;
-  case HLBinaryOpcode::URem:
-    Result = BinaryOperator::CreateURem(tmp, tmp);
-    break;
-  case HLBinaryOpcode::UShr: {
-    Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-    DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
-                      "must be matrix type here");
-    Result = BinaryOperator::CreateLShr(tmp, tmp);
-  } break;
-  case HLBinaryOpcode::ULT:
-    Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, tmp, tmp);
-    break;
-  case HLBinaryOpcode::UGT:
-    Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, tmp, tmp);
-    break;
-  case HLBinaryOpcode::ULE:
-    Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE, tmp, tmp);
-    break;
-  case HLBinaryOpcode::UGE:
-    Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGE, tmp, tmp);
-    break;
-  case HLBinaryOpcode::LAnd:
-  case HLBinaryOpcode::LOr: {
-    Value *vecZero = Constant::getNullValue(ResultTy);
-    Instruction *cmpL;
-    if (isFloat)
-      cmpL = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, vecZero);
-    else
-      cmpL = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, vecZero);
-    Builder.Insert(cmpL);
-
-    Instruction *cmpR;
-    if (isFloat)
-      cmpR =
-          CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, vecZero);
-    else
-      cmpR = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, vecZero);
-    Builder.Insert(cmpR);
-
-    // How to map l, r back? Need check opcode
-    if (opcode == HLBinaryOpcode::LOr)
-      Result = BinaryOperator::CreateOr(cmpL, cmpR);
-    else
-      Result = BinaryOperator::CreateAnd(cmpL, cmpR);
-    break;
-  }
-  default:
-    DXASSERT(0, "not implement");
-    return nullptr;
-  }
-  Builder.Insert(Result);
-  return Result;
-}
+  Constant *LoweredInitVal = Global->hasInitializer()
+    ? lowerConstInitVal(Global->getInitializer()) : nullptr;
+  GlobalVariable *LoweredGlobal = new GlobalVariable(*m_pModule, LoweredPtrTy->getElementType(),
+    Global->isConstant(), Global->getLinkage(), LoweredInitVal,
+    Global->getName() + ".v", /*InsertBefore*/ nullptr, Global->getThreadLocalMode(),
+    Global->getType()->getAddressSpace());
 
-// Create BitCast if ptr, otherwise, create alloca of new type, write to bitcast of alloca, and return load from alloca
-// If bOrigAllocaTy is true: create alloca of old type instead, write to alloca, and return load from bitcast of alloca
-static Instruction *BitCastValueOrPtr(Value* V, Instruction *Insert, Type *Ty, bool bOrigAllocaTy = false, const Twine &Name = "") {
-  IRBuilder<> Builder(Insert);
-  if (Ty->isPointerTy()) {
-    // If pointer, we can bitcast directly
-    return cast<Instruction>(Builder.CreateBitCast(V, Ty, Name));
-  } else {
-    // If value, we have to alloca, store to bitcast ptr, and load
-    IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Insert));
-    Type *allocaTy = bOrigAllocaTy ? V->getType() : Ty;
-    Type *otherTy = bOrigAllocaTy ? Ty : V->getType();
-    Instruction *allocaInst = AllocaBuilder.CreateAlloca(allocaTy);
-    Instruction *bitCast = cast<Instruction>(Builder.CreateBitCast(allocaInst, otherTy->getPointerTo()));
-    Builder.CreateStore(V, bOrigAllocaTy ? allocaInst : bitCast);
-    return Builder.CreateLoad(bOrigAllocaTy ? bitCast : allocaInst, Name);
+  // Add debug info.
+  if (m_HasDbgInfo) {
+    DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
+    HLModule::UpdateGlobalVariableDebugInfo(Global, Finder, LoweredGlobal);
   }
+
+  replaceAllVariableUses(Global, LoweredGlobal);
+  Global->removeDeadConstantUsers();
+  Global->eraseFromParent();
 }
 
-void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
-  Value *vecVal = nullptr;
-
-  if (CallInst *CI = dyn_cast<CallInst>(matInst)) {
-    hlsl::HLOpcodeGroup group =
-        hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
-    switch (group) {
-    case HLOpcodeGroup::HLIntrinsic: {
-      vecVal = MatIntrinsicToVec(CI);
-    } break;
-    case HLOpcodeGroup::HLSelect: {
-      vecVal = MatIntrinsicToVec(CI);
-    } break;
-    case HLOpcodeGroup::HLBinOp: {
-      vecVal = TrivialMatBinOpToVec(CI);
-    } break;
-    case HLOpcodeGroup::HLUnOp: {
-      vecVal = TrivialMatUnOpToVec(CI);
-    } break;
-    case HLOpcodeGroup::HLCast: {
-      vecVal = MatCastToVec(CI);
-    } break;
-    case HLOpcodeGroup::HLInit: {
-      vecVal = MatIntrinsicToVec(CI);
-    } break;
-    case HLOpcodeGroup::HLMatLoadStore: {
-      vecVal = MatLdStToVec(CI);
-    } break;
-    case HLOpcodeGroup::HLSubscript: {
-      vecVal = MatSubscriptToVec(CI);
-    } break;
-    case HLOpcodeGroup::NotHL: {
-      // Translate user function return
-      vecVal = BitCastValueOrPtr( matInst,
-                                  matInst->getNextNode(),
-                                  HLMatrixLower::LowerMatrixType(matInst->getType()),
-                                  /*bOrigAllocaTy*/ false,
-                                  matInst->getName());
-      // matrix equivalent of this new vector will be the original, retained user call
-      vecToMatMap[vecVal] = matInst;
-    } break;
-    default:
-      DXASSERT(0, "invalid inst");
-    }
-  } else if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
-    Type *Ty = AI->getAllocatedType();
-    Type *matTy = Ty;
-    
-    IRBuilder<> AllocaBuilder(AI);
-    if (Ty->isArrayTy()) {
-      Type *vecTy = HLMatrixLower::LowerMatrixArrayPointer(AI->getType(), /*forMem*/ true);
-      vecTy = vecTy->getPointerElementType();
-      vecVal = AllocaBuilder.CreateAlloca(vecTy, nullptr, AI->getName());
-    } else {
-      Type *vecTy = HLMatrixLower::LowerMatrixType(matTy, /*forMem*/ true);
-      vecVal = AllocaBuilder.CreateAlloca(vecTy, nullptr, AI->getName());
-    }
-    // Update debug info.
-    DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(AI);
-    if (DDI) {
-      LLVMContext &Context = AI->getContext();
-      Value *DDIVar = MetadataAsValue::get(Context, DDI->getRawVariable());
-      Value *DDIExp = MetadataAsValue::get(Context, DDI->getRawExpression());
-      Value *VMD = MetadataAsValue::get(Context, ValueAsMetadata::get(vecVal));
-      IRBuilder<> debugBuilder(DDI);
-      debugBuilder.CreateCall(DDI->getCalledFunction(), {VMD, DDIVar, DDIExp});
+Constant *HLMatrixLowerPass::lowerConstInitVal(Constant *Val) {
+  Type *Ty = Val->getType();
+
+  // If it's an array of matrices, recurse for each element or nested array
+  if (ArrayType *ArrayTy = dyn_cast<ArrayType>(Ty)) {
+    SmallVector<Constant*, 4> LoweredElems;
+    unsigned NumElems = ArrayTy->getNumElements();
+    LoweredElems.reserve(NumElems);
+    for (unsigned ElemIdx = 0; ElemIdx < NumElems; ++ElemIdx) {
+      Constant *ArrayElem = Val->getAggregateElement(ElemIdx);
+      LoweredElems.emplace_back(lowerConstInitVal(ArrayElem));
     }
 
-    if (HLModule::HasPreciseAttributeWithMetadata(AI))
-      HLModule::MarkPreciseAttributeWithMetadata(cast<Instruction>(vecVal));
-
-  } else if (GetIfMatrixGEPOfUDTAlloca(matInst) ||
-             GetIfMatrixGEPOfUDTArg(matInst, *m_pHLModule)) {
-    // If GEP from alloca of non-matrix UDT, bitcast
-    IRBuilder<> Builder(matInst->getNextNode());
-    vecVal = Builder.CreateBitCast(matInst,
-      HLMatrixLower::LowerMatrixType(
-        matInst->getType()->getPointerElementType() )->getPointerTo());
-    // matrix equivalent of this new vector will be the original, retained GEP
-    vecToMatMap[vecVal] = matInst;
-  } else {
-    DXASSERT(0, "invalid inst");
+    Type *LoweredElemTy = HLMatrixType::getLoweredType(ArrayTy->getElementType());
+    ArrayType *LoweredArrayTy = ArrayType::get(LoweredElemTy, NumElems);
+    return ConstantArray::get(LoweredArrayTy, LoweredElems);
   }
-  if (vecVal) {
-    matToVecMap[matInst] = vecVal;
-  }
-}
 
-// Replace matInst with vecVal on matUseInst.
-void HLMatrixLowerPass::TrivialMatUnOpReplace(Value *matVal,
-                                             Value *vecVal,
-                                             CallInst *matUseInst) {
-  (void)(matVal); // Unused
-  HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(matUseInst));
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
-  switch (opcode) {
-  case HLUnaryOpcode::Plus: // add(x, 0)
-    // Ideally we'd get completely rid of the instruction for +mat,
-    // but matToVecMap needs to point to some instruction.
-  case HLUnaryOpcode::Not: // xor(x, -1)
-  case HLUnaryOpcode::LNot: // cmpeq(x, 0)
-  case HLUnaryOpcode::PostInc:
-  case HLUnaryOpcode::PreInc:
-  case HLUnaryOpcode::PostDec:
-  case HLUnaryOpcode::PreDec:
-    vecUseInst->setOperand(0, vecVal);
-    break;
-  case HLUnaryOpcode::Minus: // sub(0, x)
-    vecUseInst->setOperand(1, vecVal);
-    break;
-  case HLUnaryOpcode::Invalid:
-  case HLUnaryOpcode::NumOfUO:
-    DXASSERT(false, "Unexpected HL unary opcode.");
-    break;
-  }
-}
+  // Otherwise it's a matrix, lower it to a vector
+  HLMatrixType MatTy = HLMatrixType::cast(Ty);
+  DXASSERT_NOMSG(isa<StructType>(Ty));
+  Constant *RowArrayVal = Val->getAggregateElement((unsigned)0);
 
-// Replace matInst with vecVal on matUseInst.
-void HLMatrixLowerPass::TrivialMatBinOpReplace(Value *matVal,
-                                              Value *vecVal,
-                                              CallInst *matUseInst) {
-  HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(matUseInst));
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
-
-  if (opcode != HLBinaryOpcode::LAnd && opcode != HLBinaryOpcode::LOr) {
-    if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) == matVal)
-      vecUseInst->setOperand(0, vecVal);
-    if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) == matVal)
-      vecUseInst->setOperand(1, vecVal);
-  } else {
-    if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) ==
-      matVal) {
-      Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(0));
-      vecCmp->setOperand(0, vecVal);
-    }
-    if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) ==
-      matVal) {
-      Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(1));
-      vecCmp->setOperand(0, vecVal);
+  // Original initializer should have been produced in row/column-major order
+  // depending on the qualifiers of the target variable, so preserve the order.
+  SmallVector<Constant*, 16> MatElems;
+  for (unsigned RowIdx = 0; RowIdx < MatTy.getNumRows(); ++RowIdx) {
+    Constant *RowVal = RowArrayVal->getAggregateElement(RowIdx);
+    for (unsigned ColIdx = 0; ColIdx < MatTy.getNumColumns(); ++ColIdx) {
+      MatElems.emplace_back(RowVal->getAggregateElement(ColIdx));
     }
   }
+
+  Constant *Vec = ConstantVector::get(MatElems);
+  
+  // Matrix elements are always in register representation,
+  // but the lowered global variable is of vector type in
+  // its memory representation, so we must convert here.
+
+  // This will produce a constant so we can use an IRBuilder without a valid insertion point.
+  IRBuilder<> DummyBuilder(Val->getContext());
+  return cast<Constant>(MatTy.emitLoweredRegToMem(Vec, DummyBuilder));
 }
 
-static Function *GetOrCreateMadIntrinsic(Type *Ty, Type *opcodeTy, IntrinsicOp madOp, Module &M) {
-  llvm::FunctionType *MadFuncTy =
-      llvm::FunctionType::get(Ty, { opcodeTy, Ty, Ty, Ty}, false);
+AllocaInst *HLMatrixLowerPass::lowerAlloca(AllocaInst *MatAlloca) {
+  PointerType *LoweredAllocaTy = cast<PointerType>(HLMatrixType::getLoweredType(MatAlloca->getType()));
 
-  Function *MAD =
-      GetOrCreateHLFunction(M, MadFuncTy, HLOpcodeGroup::HLIntrinsic,
-                            (unsigned)madOp);
-  return MAD;
+  IRBuilder<> Builder(MatAlloca);
+  AllocaInst *LoweredAlloca = Builder.CreateAlloca(
+    LoweredAllocaTy->getElementType(), nullptr, MatAlloca->getName());
+
+  // Update debug info.
+  if (DbgDeclareInst *DbgDeclare = llvm::FindAllocaDbgDeclare(MatAlloca)) {
+    LLVMContext &Context = MatAlloca->getContext();
+    Value *DbgDeclareVar = MetadataAsValue::get(Context, DbgDeclare->getRawVariable());
+    Value *DbgDeclareExpr = MetadataAsValue::get(Context, DbgDeclare->getRawExpression());
+    Value *ValueMetadata = MetadataAsValue::get(Context, ValueAsMetadata::get(LoweredAlloca));
+    IRBuilder<> DebugBuilder(DbgDeclare);
+    DebugBuilder.CreateCall(DbgDeclare->getCalledFunction(), { ValueMetadata, DbgDeclareVar, DbgDeclareExpr });
+  }
+
+  if (HLModule::HasPreciseAttributeWithMetadata(MatAlloca))
+    HLModule::MarkPreciseAttributeWithMetadata(LoweredAlloca);
+
+  replaceAllVariableUses(MatAlloca, LoweredAlloca);
+
+  return LoweredAlloca;
 }
 
-void HLMatrixLowerPass::TranslateMatMatMul(Value *matVal,
-                                           Value *vecVal,
-                                           CallInst *mulInst, bool isSigned) {
-  (void)(matVal); // Unused; retrieved from matToVecMap directly
-  DXASSERT(matToVecMap.count(mulInst), "must have vec version");
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[mulInst]);
-  // Already translated.
-  if (!isa<CallInst>(vecUseInst))
-    return;
-  Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
-  Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-
-  unsigned col, row;
-  Type *EltTy = GetMatrixInfo(LVal->getType(), col, row);
-  unsigned rCol, rRow;
-  GetMatrixInfo(RVal->getType(), rCol, rRow);
-  DXASSERT_NOMSG(col == rRow);
-
-  bool isFloat = EltTy->isFloatingPointTy();
-
-  Value *retVal = llvm::UndefValue::get(LowerMatrixType(mulInst->getType()));
-  IRBuilder<> Builder(vecUseInst);
-
-  Value *lMat = matToVecMap[cast<Instruction>(LVal)];
-  Value *rMat = matToVecMap[cast<Instruction>(RVal)];
-
-  auto CreateOneEltMul = [&](unsigned r, unsigned lc, unsigned c) -> Value * {
-    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
-    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
-    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
-    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
-    return isFloat ? Builder.CreateFMul(lMatElt, rMatElt)
-                   : Builder.CreateMul(lMatElt, rMatElt);
-  };
-
-  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
-  Type *opcodeTy = Builder.getInt32Ty();
-  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
-                                          *m_pHLModule->GetModule());
-  Value *madOpArg = Builder.getInt32((unsigned)madOp);
-
-  auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
-                             Value *acc) -> Value * {
-    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
-    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
-    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
-    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
-    return Builder.CreateCall(Mad, {madOpArg, lMatElt, rMatElt, acc});
-  };
-
-  for (unsigned r = 0; r < row; r++) {
-    for (unsigned c = 0; c < rCol; c++) {
-      unsigned lc = 0;
-      Value *tmpVal = CreateOneEltMul(r, lc, c);
-
-      for (lc = 1; lc < col; lc++) {
-        tmpVal = CreateOneEltMad(r, lc, c, tmpVal);
-      }
-      unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, rCol);
-      retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
+void HLMatrixLowerPass::lowerInstruction(Instruction* Inst) {
+  if (CallInst *Call = dyn_cast<CallInst>(Inst)) {
+    Value *LoweredValue = lowerCall(Call);
+
+    // lowerCall returns the lowered value iff we should discard
+    // the original matrix instruction and replace all of its uses
+    // by the lowered value. It returns nullptr to opt-out of this.
+    if (LoweredValue != nullptr) {
+      replaceAllUsesByLoweredValue(Call, LoweredValue);
+      addToDeadInsts(Inst);
     }
   }
+  else if (ReturnInst *Return = dyn_cast<ReturnInst>(Inst)) {
+    lowerReturn(Return);
+  }
+  else
+    llvm_unreachable("Unexpected matrix instruction type.");
+}
+
+void HLMatrixLowerPass::lowerReturn(ReturnInst* Return) {
+  Value *RetVal = Return->getReturnValue();
+  Type *RetTy = RetVal->getType();
+  DXASSERT_LOCALVAR(RetTy, !RetTy->isPointerTy(), "Unexpected matrix returned by pointer.");
+
+  IRBuilder<> Builder(Return);
+  Value *LoweredRetVal = getLoweredByValOperand(RetVal, Builder, /* DiscardStub */ true);
+
+  // Since we're not lowering the signature, we can't return the lowered value directly,
+  // so insert a bitcast, which HLMatrixBitcastLower knows how to eliminate.
+  Value *BitCastedRetVal = bitCastValue(LoweredRetVal, RetVal->getType(), /* DstTyAlloca */ false, Builder);
+  Return->setOperand(0, BitCastedRetVal);
+}
+
+Value *HLMatrixLowerPass::lowerCall(CallInst *Call) {
+  HLOpcodeGroup OpcodeGroup = GetHLOpcodeGroupByName(Call->getCalledFunction());
+  return OpcodeGroup == HLOpcodeGroup::NotHL
+    ? lowerNonHLCall(Call) : lowerHLOperation(Call, OpcodeGroup);
+}
+
+Value *HLMatrixLowerPass::lowerNonHLCall(CallInst *Call) {
+  // First, handle any operand of matrix-derived type
+  // We don't lower the callee's signature in this pass,
+  // so, for any matrix-typed parameter, we create a bitcast from the
+  // lowered vector back to the matrix type, which the later HLMatrixBitcastLower
+  // pass knows how to eliminate.
+  IRBuilder<> PreCallBuilder(Call);
+  unsigned NumArgs = Call->getNumArgOperands();
+  for (unsigned ArgIdx = 0; ArgIdx < NumArgs; ++ArgIdx) {
+    Use &ArgUse = Call->getArgOperandUse(ArgIdx);
+    if (ArgUse->getType()->isPointerTy()) {
+      // Byref arg
+      Value *LoweredArg = tryGetLoweredPtrOperand(ArgUse.get(), PreCallBuilder, /* DiscardStub */ true);
+      if (LoweredArg != nullptr) {
+        // Pointer to a matrix we've lowered, insert a bitcast back to matrix pointer type.
+        Value *BitCastedArg = PreCallBuilder.CreateBitCast(LoweredArg, ArgUse->getType());
+        ArgUse.set(BitCastedArg);
+      }
+    }
+    else {
+      // Byvalue arg
+      Value *LoweredArg = getLoweredByValOperand(ArgUse.get(), PreCallBuilder, /* DiscardStub */ true);
+      if (LoweredArg == ArgUse.get()) continue;
 
-  Instruction *matmatMul = cast<Instruction>(retVal);
-  // Replace vec transpose function call with shuf.
-  vecUseInst->replaceAllUsesWith(matmatMul);
-  AddToDeadInsts(vecUseInst);
-  matToVecMap[mulInst] = matmatMul;
-}
+      Value *BitCastedArg = bitCastValue(LoweredArg, ArgUse->getType(), /* DstTyAlloca */ false, PreCallBuilder);
+      ArgUse.set(BitCastedArg);
+    }
+  }
 
-void HLMatrixLowerPass::TranslateMatVecMul(Value *matVal,
-                                           Value *vecVal,
-                                           CallInst *mulInst, bool isSigned) {
-  // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
-  Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
+  // Now check the return type
+  HLMatrixType RetMatTy = HLMatrixType::dyn_cast(Call->getType());
+  if (!RetMatTy) {
+    DXASSERT(!HLMatrixType::isMatrixPtrOrArrayPtr(Call->getType()),
+      "Unexpected user call returning a matrix by pointer.");
+    // Nothing to replace, other instructions can consume a non-matrix return type.
+    return nullptr;
+  }
 
-  unsigned col, row;
-  Type *EltTy = GetMatrixInfo(matVal->getType(), col, row);
-  DXASSERT_NOMSG(RVal->getType()->getVectorNumElements() == col);
+  // The callee returns a matrix, and we don't lower signatures in this pass.
+  // We perform a sketchy bitcast to the lowered register-representation type,
+  // which the later HLMatrixBitcastLower pass knows how to eliminate.
+  IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Call));
+  Value *LoweredAlloca = AllocaBuilder.CreateAlloca(RetMatTy.getLoweredVectorTypeForReg());
+  
+  IRBuilder<> PostCallBuilder(Call->getNextNode());
+  Value *BitCastedAlloca = PostCallBuilder.CreateBitCast(LoweredAlloca, Call->getType()->getPointerTo());
+  
+  // This is slightly tricky
+  // We want to replace all uses of the matrix-returning call by the bitcasted value,
+  // but the store to the bitcasted pointer itself is a use of that matrix,
+  // so we need to create the load, replace the uses, and then insert the store.
+  LoadInst *LoweredVal = PostCallBuilder.CreateLoad(LoweredAlloca);
+  replaceAllUsesByLoweredValue(Call, LoweredVal);
+
+  // Now we can insert the store. Make sure to do so before the load.
+  PostCallBuilder.SetInsertPoint(LoweredVal);
+  PostCallBuilder.CreateStore(Call, BitCastedAlloca);
+  
+  // Return nullptr since we did our own uses replacement and we don't want
+  // the matrix instruction to be marked as dead since we're still using it.
+  return nullptr;
+}
 
-  bool isFloat = EltTy->isFloatingPointTy();
+Value *HLMatrixLowerPass::lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeGroup) {
+  IRBuilder<> Builder(Call);
+  switch (OpcodeGroup) {
+  case HLOpcodeGroup::HLIntrinsic:
+    return lowerHLIntrinsic(Call, static_cast<IntrinsicOp>(GetHLOpcode(Call)));
 
-  Value *retVal = llvm::UndefValue::get(mulInst->getType());
-  IRBuilder<> Builder(mulInst);
+  case HLOpcodeGroup::HLBinOp:
+    return lowerHLBinaryOperation(
+      Call->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
+      Call->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
+      static_cast<HLBinaryOpcode>(GetHLOpcode(Call)), Builder);
 
-  Value *vec = RVal;
-  Value *mat = vecVal; // vec version of matInst;
+  case HLOpcodeGroup::HLUnOp:
+    return lowerHLUnaryOperation(
+      Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx),
+      static_cast<HLUnaryOpcode>(GetHLOpcode(Call)), Builder);
 
-  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
-  Type *opcodeTy = Builder.getInt32Ty();
-  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
-                                          *m_pHLModule->GetModule());
-  Value *madOpArg = Builder.getInt32((unsigned)madOp);
+  case HLOpcodeGroup::HLMatLoadStore:
+    return lowerHLLoadStore(Call, static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(Call)));
 
-  auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
-    Value *vecElt = Builder.CreateExtractElement(vec, c);
-    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
-    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
-    return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
-  };
+  case HLOpcodeGroup::HLCast:
+    return lowerHLCast(
+      Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Call->getType(),
+      static_cast<HLCastOpcode>(GetHLOpcode(Call)), Builder);
 
-  for (unsigned r = 0; r < row; r++) {
-    unsigned c = 0;
-    Value *vecElt = Builder.CreateExtractElement(vec, c);
-    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
-    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
+  case HLOpcodeGroup::HLSubscript:
+    return lowerHLSubscript(Call, static_cast<HLSubscriptOpcode>(GetHLOpcode(Call)));
 
-    Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
-                            : Builder.CreateMul(vecElt, matElt);
+  case HLOpcodeGroup::HLInit:
+    return lowerHLInit(Call);
 
-    for (c = 1; c < col; c++) {
-      tmpVal = CreateOneEltMad(r, c, tmpVal);
-    }
+  case HLOpcodeGroup::HLSelect:
+    return lowerHLSelect(Call);
 
-    retVal = Builder.CreateInsertElement(retVal, tmpVal, r);
+  default:
+    llvm_unreachable("Unexpected matrix opcode");
   }
-
-  mulInst->replaceAllUsesWith(retVal);
-  AddToDeadInsts(mulInst);
 }
 
-void HLMatrixLowerPass::TranslateVecMatMul(Value *matVal,
-                                           Value *vecVal,
-                                           CallInst *mulInst, bool isSigned) {
-  Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
-  // matVal should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-  Value *RVal = vecVal;
-
-  unsigned col, row;
-  Type *EltTy = GetMatrixInfo(matVal->getType(), col, row);
-  DXASSERT_NOMSG(LVal->getType()->getVectorNumElements() == row);
-
-  bool isFloat = EltTy->isFloatingPointTy();
-
-  Value *retVal = llvm::UndefValue::get(mulInst->getType());
-  IRBuilder<> Builder(mulInst);
-
-  Value *vec = LVal;
-  Value *mat = RVal;
-
-  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
-  Type *opcodeTy = Builder.getInt32Ty();
-  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
-                                          *m_pHLModule->GetModule());
-  Value *madOpArg = Builder.getInt32((unsigned)madOp);
-
-  auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
-    Value *vecElt = Builder.CreateExtractElement(vec, r);
-    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
-    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
-    return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
-  };
-
-  for (unsigned c = 0; c < col; c++) {
-    unsigned r = 0;
-    Value *vecElt = Builder.CreateExtractElement(vec, r);
-    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
-    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
-
-    Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
-                            : Builder.CreateMul(vecElt, matElt);
-
-    for (r = 1; r < row; r++) {
-      tmpVal = CreateOneEltMad(r, c, tmpVal);
-    }
+static Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
+  Type *RetTy, ArrayRef<Value*> Args, IRBuilder<> &Builder) {
+  SmallVector<Type*, 4> ArgTys;
+  ArgTys.reserve(Args.size());
+  for (Value *Arg : Args)
+    ArgTys.emplace_back(Arg->getType());
 
-    retVal = Builder.CreateInsertElement(retVal, tmpVal, c);
-  }
+  FunctionType *FuncTy = FunctionType::get(RetTy, ArgTys, /* isVarArg */ false);
+  Function *Func = GetOrCreateHLFunction(Module, FuncTy, OpcodeGroup, Opcode);
 
-  mulInst->replaceAllUsesWith(retVal);
-  AddToDeadInsts(mulInst);
+  return Builder.CreateCall(Func, Args);
 }
 
-void HLMatrixLowerPass::TranslateMul(Value *matVal, Value *vecVal,
-                                     CallInst *mulInst, bool isSigned) {
-  Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
-  Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-
-  bool LMat = dxilutil::IsHLSLMatrixType(LVal->getType());
-  bool RMat = dxilutil::IsHLSLMatrixType(RVal->getType());
-  if (LMat && RMat) {
-    TranslateMatMatMul(matVal, vecVal, mulInst, isSigned);
-  } else if (LMat) {
-    TranslateMatVecMul(matVal, vecVal, mulInst, isSigned);
-  } else {
-    TranslateVecMatMul(matVal, vecVal, mulInst, isSigned);
-  }
-}
+Value *HLMatrixLowerPass::lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode) {
+  IRBuilder<> Builder(Call);
 
-void HLMatrixLowerPass::TranslateMatTranspose(Value *matVal,
-                                              Value *vecVal,
-                                              CallInst *transposeInst) {
-  // Matrix value is row major, transpose is cast it to col major.
-  TranslateMatMajorCast(matVal, vecVal, transposeInst,
-      /*bRowToCol*/ true, /*bTranspose*/ true);
-}
+  // See if this is a matrix-specific intrinsic which we should expand here
+  switch (Opcode) {
+  case IntrinsicOp::IOP_umul:
+  case IntrinsicOp::IOP_mul:
+    return lowerHLMulIntrinsic(
+      Call->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
+      Call->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
+      /* Unsigned */ Opcode == IntrinsicOp::IOP_umul, Builder);
+  case IntrinsicOp::IOP_transpose:
+    return lowerHLTransposeIntrinsic(Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Builder);
+  case IntrinsicOp::IOP_determinant:
+    return lowerHLDeterminantIntrinsic(Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Builder);
+  }
+
+  // Delegate to a lowered intrinsic call
+  SmallVector<Value*, 4> LoweredArgs;
+  LoweredArgs.reserve(Call->getNumArgOperands());
+  for (Value *Arg : Call->arg_operands()) {
+    if (Arg->getType()->isPointerTy()) {
+      // ByRef parameter (for example, frexp's second parameter)
+      // If the argument points to a lowered matrix variable, replace it here,
+      // otherwise preserve the matrix type and let further passes handle the lowering.
+      Value *LoweredArg = tryGetLoweredPtrOperand(Arg, Builder);
+      if (LoweredArg == nullptr) LoweredArg = Arg;
+      LoweredArgs.emplace_back(LoweredArg);
+    }
+    else {
+      LoweredArgs.emplace_back(getLoweredByValOperand(Arg, Builder));
+    }
+  }
 
-static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
-                             IRBuilder<> &Builder) {
-  Value *mul0 = Builder.CreateFMul(m00, m11);
-  Value *mul1 = Builder.CreateFMul(m01, m10);
-  return Builder.CreateFSub(mul0, mul1);
+  Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
+  return callHLFunction(*m_pModule, HLOpcodeGroup::HLIntrinsic, static_cast<unsigned>(Opcode), 
+    LoweredRetTy, LoweredArgs, Builder);
 }
 
-static Value *Determinant3x3(Value *m00, Value *m01, Value *m02,
-                             Value *m10, Value *m11, Value *m12,
-                             Value *m20, Value *m21, Value *m22,
-                             IRBuilder<> &Builder) {
-  Value *deter00 = Determinant2x2(m11, m12, m21, m22, Builder);
-  Value *deter01 = Determinant2x2(m10, m12, m20, m22, Builder);
-  Value *deter02 = Determinant2x2(m10, m11, m20, m21, Builder);
-  deter00 = Builder.CreateFMul(m00, deter00);
-  deter01 = Builder.CreateFMul(m01, deter01);
-  deter02 = Builder.CreateFMul(m02, deter02);
-  Value *result = Builder.CreateFSub(deter00, deter01);
-  result = Builder.CreateFAdd(result, deter02);
-  return result;
-}
+Value *HLMatrixLowerPass::lowerHLMulIntrinsic(Value* Lhs, Value *Rhs,
+    bool Unsigned, IRBuilder<> &Builder) {
+  HLMatrixType LhsMatTy = HLMatrixType::dyn_cast(Lhs->getType());
+  HLMatrixType RhsMatTy = HLMatrixType::dyn_cast(Rhs->getType());
+  Value* LoweredLhs = getLoweredByValOperand(Lhs, Builder);
+  Value* LoweredRhs = getLoweredByValOperand(Rhs, Builder);
 
-static Value *Determinant4x4(Value *m00, Value *m01, Value *m02, Value *m03,
-                             Value *m10, Value *m11, Value *m12, Value *m13,
-                             Value *m20, Value *m21, Value *m22, Value *m23,
-                             Value *m30, Value *m31, Value *m32, Value *m33,
-                             IRBuilder<> &Builder) {
-  Value *deter00 = Determinant3x3(m11, m12, m13, m21, m22, m23, m31, m32, m33, Builder);
-  Value *deter01 = Determinant3x3(m10, m12, m13, m20, m22, m23, m30, m32, m33, Builder);
-  Value *deter02 = Determinant3x3(m10, m11, m13, m20, m21, m23, m30, m31, m33, Builder);
-  Value *deter03 = Determinant3x3(m10, m11, m12, m20, m21, m22, m30, m31, m32, Builder);
-  deter00 = Builder.CreateFMul(m00, deter00);
-  deter01 = Builder.CreateFMul(m01, deter01);
-  deter02 = Builder.CreateFMul(m02, deter02);
-  deter03 = Builder.CreateFMul(m03, deter03);
-  Value *result = Builder.CreateFSub(deter00, deter01);
-  result = Builder.CreateFAdd(result, deter02);
-  result = Builder.CreateFSub(result, deter03);
-  return result;
-}
+  DXASSERT(LoweredLhs->getType()->getScalarType() == LoweredRhs->getType()->getScalarType(),
+    "Unexpected element type mismatch in mul intrinsic.");
+  DXASSERT(cast<VectorType>(LoweredLhs->getType()) && cast<VectorType>(LoweredLhs->getType()),
+    "Unexpected scalar in lowered matrix mul intrinsic operands.");
 
+  Type* ElemTy = LoweredLhs->getType()->getScalarType();
 
-void HLMatrixLowerPass::TranslateMatDeterminant(Value *matVal, Value *vecVal,
-    CallInst *determinantInst) {
-  unsigned row, col;
-  GetMatrixInfo(matVal->getType(), col, row);
-  IRBuilder<> Builder(determinantInst);
-  // when row == 1, result is vecVal.
-  Value *Result = vecVal;
-  if (row == 2) {
-    Value *m00 = Builder.CreateExtractElement(vecVal, (uint64_t)0);
-    Value *m01 = Builder.CreateExtractElement(vecVal, 1);
-    Value *m10 = Builder.CreateExtractElement(vecVal, 2);
-    Value *m11 = Builder.CreateExtractElement(vecVal, 3);
-    Result = Determinant2x2(m00, m01, m10, m11, Builder);
+  // Figure out the dimensions of each side
+  unsigned LhsNumRows, LhsNumCols, RhsNumRows, RhsNumCols;
+  if (LhsMatTy && RhsMatTy) {
+    LhsNumRows = LhsMatTy.getNumRows();
+    LhsNumCols = LhsMatTy.getNumColumns();
+    RhsNumRows = RhsMatTy.getNumRows();
+    RhsNumCols = RhsMatTy.getNumColumns();
   }
-  else if (row == 3) {
-    Value *m00 = Builder.CreateExtractElement(vecVal, (uint64_t)0);
-    Value *m01 = Builder.CreateExtractElement(vecVal, 1);
-    Value *m02 = Builder.CreateExtractElement(vecVal, 2);
-    Value *m10 = Builder.CreateExtractElement(vecVal, 3);
-    Value *m11 = Builder.CreateExtractElement(vecVal, 4);
-    Value *m12 = Builder.CreateExtractElement(vecVal, 5);
-    Value *m20 = Builder.CreateExtractElement(vecVal, 6);
-    Value *m21 = Builder.CreateExtractElement(vecVal, 7);
-    Value *m22 = Builder.CreateExtractElement(vecVal, 8);
-    Result = Determinant3x3(m00, m01, m02, 
-                            m10, m11, m12, 
-                            m20, m21, m22, Builder);
+  else if (LhsMatTy) {
+    LhsNumRows = LhsMatTy.getNumRows();
+    LhsNumCols = LhsMatTy.getNumColumns();
+    RhsNumRows = LoweredRhs->getType()->getVectorNumElements();
+    RhsNumCols = 1;
   }
-  else if (row == 4) {
-    Value *m00 = Builder.CreateExtractElement(vecVal, (uint64_t)0);
-    Value *m01 = Builder.CreateExtractElement(vecVal, 1);
-    Value *m02 = Builder.CreateExtractElement(vecVal, 2);
-    Value *m03 = Builder.CreateExtractElement(vecVal, 3);
-
-    Value *m10 = Builder.CreateExtractElement(vecVal, 4);
-    Value *m11 = Builder.CreateExtractElement(vecVal, 5);
-    Value *m12 = Builder.CreateExtractElement(vecVal, 6);
-    Value *m13 = Builder.CreateExtractElement(vecVal, 7);
-
-    Value *m20 = Builder.CreateExtractElement(vecVal, 8);
-    Value *m21 = Builder.CreateExtractElement(vecVal, 9);
-    Value *m22 = Builder.CreateExtractElement(vecVal, 10);
-    Value *m23 = Builder.CreateExtractElement(vecVal, 11);
-
-    Value *m30 = Builder.CreateExtractElement(vecVal, 12);
-    Value *m31 = Builder.CreateExtractElement(vecVal, 13);
-    Value *m32 = Builder.CreateExtractElement(vecVal, 14);
-    Value *m33 = Builder.CreateExtractElement(vecVal, 15);
-
-    Result = Determinant4x4(m00, m01, m02, m03,
-                            m10, m11, m12, m13,
-                            m20, m21, m22, m23,
-                            m30, m31, m32, m33,
-                            Builder);
-  } else {
-    DXASSERT(row == 1, "invalid matrix type");
-    Result = Builder.CreateExtractElement(Result, (uint64_t)0);
+  else if (RhsMatTy) {
+    LhsNumRows = 1;
+    LhsNumCols = LoweredLhs->getType()->getVectorNumElements();
+    RhsNumRows = RhsMatTy.getNumRows();
+    RhsNumCols = RhsMatTy.getNumColumns();
   }
-  determinantInst->replaceAllUsesWith(Result);
-  AddToDeadInsts(determinantInst);
+  else {
+    llvm_unreachable("mul intrinsic was identified as a matrix operation but neither operand is a matrix.");
+  }
+
+  DXASSERT(LhsNumCols == RhsNumRows, "Matrix mul intrinsic operands dimensions mismatch.");
+  HLMatrixType ResultMatTy(ElemTy, LhsNumRows, RhsNumCols);
+  unsigned AccCount = LhsNumCols;
+
+  // Get the multiply-and-add intrinsic function, we'll need it
+  IntrinsicOp MadOpcode = Unsigned ? IntrinsicOp::IOP_umad : IntrinsicOp::IOP_mad;
+  FunctionType *MadFuncTy = FunctionType::get(ElemTy, { Builder.getInt32Ty(), ElemTy, ElemTy, ElemTy }, false);
+  Function *MadFunc = GetOrCreateHLFunction(*m_pModule, MadFuncTy, HLOpcodeGroup::HLIntrinsic, (unsigned)MadOpcode);
+  Constant *MadOpcodeVal = Builder.getInt32((unsigned)MadOpcode);
+
+  // Perform the multiplication!
+  Value *Result = UndefValue::get(VectorType::get(ElemTy, LhsNumRows * RhsNumCols));
+  for (unsigned ResultRowIdx = 0; ResultRowIdx < ResultMatTy.getNumRows(); ++ResultRowIdx) {
+    for (unsigned ResultColIdx = 0; ResultColIdx < ResultMatTy.getNumColumns(); ++ResultColIdx) {
+      unsigned ResultElemIdx = ResultMatTy.getRowMajorIndex(ResultRowIdx, ResultColIdx);
+      Value *ResultElem = nullptr;
+
+      for (unsigned AccIdx = 0; AccIdx < AccCount; ++AccIdx) {
+        unsigned LhsElemIdx = HLMatrixType::getRowMajorIndex(ResultRowIdx, AccIdx, LhsNumRows, LhsNumCols);
+        unsigned RhsElemIdx = HLMatrixType::getRowMajorIndex(AccIdx, ResultColIdx, RhsNumRows, RhsNumCols);
+        Value* LhsElem = Builder.CreateExtractElement(LoweredLhs, static_cast<uint64_t>(LhsElemIdx));
+        Value* RhsElem = Builder.CreateExtractElement(LoweredRhs, static_cast<uint64_t>(RhsElemIdx));
+        if (ResultElem == nullptr) {
+          ResultElem = ElemTy->isFloatingPointTy()
+            ? Builder.CreateFMul(LhsElem, RhsElem)
+            : Builder.CreateMul(LhsElem, RhsElem);
+        }
+        else {
+          ResultElem = Builder.CreateCall(MadFunc, { MadOpcodeVal, LhsElem, RhsElem, ResultElem });
+        }
+      }
+
+      Result = Builder.CreateInsertElement(Result, ResultElem, static_cast<uint64_t>(ResultElemIdx));
+    }
+  }
+
+  return Result;
 }
 
-void HLMatrixLowerPass::TrivialMatReplace(Value *matVal,
-                                         Value *vecVal,
-                                         CallInst *matUseInst) {
-  CallInst *vecUseInst = cast<CallInst>(matToVecMap[matUseInst]);
+Value *HLMatrixLowerPass::lowerHLTransposeIntrinsic(Value* MatVal, IRBuilder<> &Builder) {
+  HLMatrixType MatTy = HLMatrixType::cast(MatVal->getType());
+  Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
+  return MatTy.emitLoweredVectorRowToCol(LoweredVal, Builder);
+}
 
-  for (unsigned i = 0; i < matUseInst->getNumArgOperands(); i++)
-    if (matUseInst->getArgOperand(i) == matVal) {
-      vecUseInst->setArgOperand(i, vecVal);
-    }
+static Value *determinant2x2(Value *M00, Value *M01, Value *M10, Value *M11, IRBuilder<> &Builder) {
+  Value *Mul0 = Builder.CreateFMul(M00, M11);
+  Value *Mul1 = Builder.CreateFMul(M01, M10);
+  return Builder.CreateFSub(Mul0, Mul1);
 }
 
-static Instruction *CreateTransposeShuffle(IRBuilder<> &Builder, Value *vecVal, unsigned toRows, unsigned toCols) {
-  SmallVector<int, 16> castMask(toCols * toRows);
-  unsigned idx = 0;
-  for (unsigned r = 0; r < toRows; r++)
-    for (unsigned c = 0; c < toCols; c++)
-      castMask[idx++] = c * toRows + r;
-  return cast<Instruction>(
-    Builder.CreateShuffleVector(vecVal, vecVal, castMask));
+static Value *determinant3x3(Value *M00, Value *M01, Value *M02,
+    Value *M10, Value *M11, Value *M12,
+    Value *M20, Value *M21, Value *M22,
+    IRBuilder<> &Builder) {
+  Value *Det00 = determinant2x2(M11, M12, M21, M22, Builder);
+  Value *Det01 = determinant2x2(M10, M12, M20, M22, Builder);
+  Value *Det02 = determinant2x2(M10, M11, M20, M21, Builder);
+  Det00 = Builder.CreateFMul(M00, Det00);
+  Det01 = Builder.CreateFMul(M01, Det01);
+  Det02 = Builder.CreateFMul(M02, Det02);
+  Value *Result = Builder.CreateFSub(Det00, Det01);
+  Result = Builder.CreateFAdd(Result, Det02);
+  return Result;
 }
 
-void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal,
-                                              Value *vecVal,
-                                              CallInst *castInst,
-                                              bool bRowToCol,
-                                              bool bTranspose) {
-  unsigned col, row;
-  if (!bTranspose) {
-    GetMatrixInfo(castInst->getType(), col, row);
-    DXASSERT(castInst->getType() == matVal->getType(), "type must match");
-  } else {
-    unsigned castCol, castRow;
-    Type *castTy = GetMatrixInfo(castInst->getType(), castCol, castRow);
-    unsigned srcCol, srcRow;
-    Type *srcTy = GetMatrixInfo(matVal->getType(), srcCol, srcRow);
-    DXASSERT_LOCALVAR((castTy == srcTy), srcTy == castTy, "type must match");
-    DXASSERT(castCol == srcRow && castRow == srcCol, "col row must match");
-    col = srcCol;
-    row = srcRow;
-  }
+static Value *determinant4x4(Value *M00, Value *M01, Value *M02, Value *M03,
+    Value *M10, Value *M11, Value *M12, Value *M13,
+    Value *M20, Value *M21, Value *M22, Value *M23,
+    Value *M30, Value *M31, Value *M32, Value *M33,
+    IRBuilder<> &Builder) {
+  Value *Det00 = determinant3x3(M11, M12, M13, M21, M22, M23, M31, M32, M33, Builder);
+  Value *Det01 = determinant3x3(M10, M12, M13, M20, M22, M23, M30, M32, M33, Builder);
+  Value *Det02 = determinant3x3(M10, M11, M13, M20, M21, M23, M30, M31, M33, Builder);
+  Value *Det03 = determinant3x3(M10, M11, M12, M20, M21, M22, M30, M31, M32, Builder);
+  Det00 = Builder.CreateFMul(M00, Det00);
+  Det01 = Builder.CreateFMul(M01, Det01);
+  Det02 = Builder.CreateFMul(M02, Det02);
+  Det03 = Builder.CreateFMul(M03, Det03);
+  Value *Result = Builder.CreateFSub(Det00, Det01);
+  Result = Builder.CreateFAdd(Result, Det02);
+  Result = Builder.CreateFSub(Result, Det03);
+  return Result;
+}
 
-  DXASSERT(matToVecMap.count(castInst), "must have vec version");
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
-  // Create before vecUseInst to prevent instructions being inserted after uses.
-  IRBuilder<> Builder(vecUseInst);
+Value *HLMatrixLowerPass::lowerHLDeterminantIntrinsic(Value* MatVal, IRBuilder<> &Builder) {
+  HLMatrixType MatTy = HLMatrixType::cast(MatVal->getType());
+  DXASSERT_NOMSG(MatTy.getNumColumns() == MatTy.getNumRows());
 
-  if (bRowToCol)
-    std::swap(row, col);
-  Instruction *vecCast = CreateTransposeShuffle(Builder, vecVal, row, col);
+  Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
 
-  // Replace vec cast function call with vecCast.
-  vecUseInst->replaceAllUsesWith(vecCast);
-  AddToDeadInsts(vecUseInst);
-  matToVecMap[castInst] = vecCast;
-}
+  // Extract all matrix elements
+  SmallVector<Value*, 16> Elems;
+  for (unsigned ElemIdx = 0; ElemIdx < MatTy.getNumElements(); ++ElemIdx)
+    Elems.emplace_back(Builder.CreateExtractElement(LoweredVal, static_cast<uint64_t>(ElemIdx)));
 
-void HLMatrixLowerPass::TranslateMatMatCast(Value *matVal,
-                                            Value *vecVal,
-                                            CallInst *castInst) {
-  unsigned toCol, toRow;
-  Type *ToEltTy = GetMatrixInfo(castInst->getType(), toCol, toRow);
-  unsigned fromCol, fromRow;
-  Type *FromEltTy = GetMatrixInfo(matVal->getType(), fromCol, fromRow);
-  unsigned fromSize = fromCol * fromRow;
-  unsigned toSize = toCol * toRow;
-  DXASSERT(fromSize >= toSize, "cannot extend matrix");
-  DXASSERT(matToVecMap.count(castInst), "must have vec version");
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
-
-  IRBuilder<> Builder(vecUseInst);
-  Instruction *vecCast = nullptr;
-
-  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
-
-  if (fromSize == toSize) {
-    vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), vecVal,
-                             Builder);
-  } else {
-    // shuf first
-    std::vector<int> castMask(toCol * toRow);
-    unsigned idx = 0;
-    for (unsigned r = 0; r < toRow; r++)
-      for (unsigned c = 0; c < toCol; c++) {
-        unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, fromCol);
-        castMask[idx++] = matIdx;
-      }
+  // Delegate to appropriate determinant function
+  switch (MatTy.getNumColumns()) {
+  case 1:
+    return Elems[0];
 
-    Instruction *shuf = cast<Instruction>(
-        Builder.CreateShuffleVector(vecVal, vecVal, castMask));
+  case 2:
+    return determinant2x2(
+      Elems[0], Elems[1],
+      Elems[2], Elems[3],
+      Builder);
 
-    if (ToEltTy != FromEltTy)
-      vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), shuf,
-                               Builder);
-    else
-      vecCast = shuf;
-  }
-  // Replace vec cast function call with vecCast.
-  vecUseInst->replaceAllUsesWith(vecCast);
-  AddToDeadInsts(vecUseInst);
-  matToVecMap[castInst] = vecCast;
-}
+  case 3:
+    return determinant3x3(
+      Elems[0], Elems[1], Elems[2],
+      Elems[3], Elems[4], Elems[5],
+      Elems[6], Elems[7], Elems[8],
+      Builder);
 
-void HLMatrixLowerPass::TranslateMatToOtherCast(Value *matVal,
-                                                Value *vecVal,
-                                                CallInst *castInst) {
-  unsigned col, row;
-  Type *EltTy = GetMatrixInfo(matVal->getType(), col, row);
-  unsigned fromSize = col * row;
-
-  IRBuilder<> Builder(castInst);
-  Value *sizeCast = nullptr;
-
-  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
-
-  Type *ToTy = castInst->getType();
-  if (ToTy->isVectorTy()) {
-    unsigned toSize = ToTy->getVectorNumElements();
-    if (fromSize != toSize) {
-      std::vector<int> castMask(fromSize);
-      for (unsigned c = 0; c < toSize; c++)
-        castMask[c] = c;
-
-      sizeCast = Builder.CreateShuffleVector(vecVal, vecVal, castMask);
-    } else
-      sizeCast = vecVal;
-  } else {
-    DXASSERT(ToTy->isSingleValueType(), "must scalar here");
-    sizeCast = Builder.CreateExtractElement(vecVal, (uint64_t)0);
-  }
+  case 4:
+    return determinant4x4(
+      Elems[0], Elems[1], Elems[2], Elems[3],
+      Elems[4], Elems[5], Elems[6], Elems[7],
+      Elems[8], Elems[9], Elems[10], Elems[11],
+      Elems[12], Elems[13], Elems[14], Elems[15],
+      Builder);
 
-  Value *typeCast = sizeCast;
-  if (EltTy != ToTy->getScalarType()) {
-    typeCast = CreateTypeCast(opcode, ToTy, typeCast, Builder);
+  default:
+    llvm_unreachable("Unexpected matrix dimensions.");
   }
-  // Replace cast function call with typeCast.
-  castInst->replaceAllUsesWith(typeCast);
-  AddToDeadInsts(castInst);
 }
 
-void HLMatrixLowerPass::TranslateMatCast(Value *matVal,
-                                         Value *vecVal,
-                                         CallInst *castInst) {
-  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
-  if (opcode == HLCastOpcode::ColMatrixToRowMatrix ||
-      opcode == HLCastOpcode::RowMatrixToColMatrix) {
-    TranslateMatMajorCast(matVal, vecVal, castInst,
-                          opcode == HLCastOpcode::RowMatrixToColMatrix,
-                          /*bTranspose*/false);
-  } else {
-    bool ToMat = dxilutil::IsHLSLMatrixType(castInst->getType());
-    bool FromMat = dxilutil::IsHLSLMatrixType(matVal->getType());
-    if (ToMat && FromMat) {
-      TranslateMatMatCast(matVal, vecVal, castInst);
-    } else if (FromMat)
-      TranslateMatToOtherCast(matVal, vecVal, castInst);
+Value *HLMatrixLowerPass::lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode, IRBuilder<> &Builder) {
+  Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
+  VectorType *VecTy = cast<VectorType>(LoweredVal->getType());
+  bool IsFloat = VecTy->getElementType()->isFloatingPointTy();
+  
+  switch (Opcode) {
+  case HLUnaryOpcode::Plus: return LoweredVal; // No-op
+
+  case HLUnaryOpcode::Minus:
+    return IsFloat
+      ? Builder.CreateFSub(Constant::getNullValue(VecTy), LoweredVal)
+      : Builder.CreateSub(Constant::getNullValue(VecTy), LoweredVal);
+
+  case HLUnaryOpcode::LNot:
+    return IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_UEQ, LoweredVal, Constant::getNullValue(VecTy))
+      : Builder.CreateICmp(CmpInst::ICMP_EQ, LoweredVal, Constant::getNullValue(VecTy));
+
+  case HLUnaryOpcode::Not:
+    return Builder.CreateXor(LoweredVal, Constant::getAllOnesValue(VecTy));
+
+  case HLUnaryOpcode::PostInc:
+  case HLUnaryOpcode::PreInc:
+  case HLUnaryOpcode::PostDec:
+  case HLUnaryOpcode::PreDec: {
+    Constant *ScalarOne = IsFloat
+      ? ConstantFP::get(VecTy->getElementType(), 1)
+      : ConstantInt::get(VecTy->getElementType(), 1);
+    Constant *VecOne = ConstantVector::getSplat(VecTy->getNumElements(), ScalarOne);
+    // BUGBUG: This implementation has incorrect semantics (GitHub #1780)
+    if (Opcode == HLUnaryOpcode::PostInc || Opcode == HLUnaryOpcode::PreInc) {
+      return IsFloat
+        ? Builder.CreateFAdd(LoweredVal, VecOne)
+        : Builder.CreateAdd(LoweredVal, VecOne);
+    }
     else {
-      DXASSERT(0, "Not translate as user of matInst");
+      return IsFloat
+        ? Builder.CreateFSub(LoweredVal, VecOne)
+        : Builder.CreateSub(LoweredVal, VecOne);
     }
   }
-}
-
-void HLMatrixLowerPass::MatIntrinsicReplace(Value *matVal,
-                                            Value *vecVal,
-                                            CallInst *matUseInst) {
-  IRBuilder<> Builder(matUseInst);
-  IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(matUseInst));
-  switch (opcode) {
-  case IntrinsicOp::IOP_umul:
-    TranslateMul(matVal, vecVal, matUseInst, /*isSigned*/false);
-    break;
-  case IntrinsicOp::IOP_mul:
-    TranslateMul(matVal, vecVal, matUseInst, /*isSigned*/true);
-    break;
-  case IntrinsicOp::IOP_transpose:
-    TranslateMatTranspose(matVal, vecVal, matUseInst);
-    break;
-  case IntrinsicOp::IOP_determinant:
-    TranslateMatDeterminant(matVal, vecVal, matUseInst);
-    break;
   default:
-    CallInst *useInst = matUseInst;
-    if (matToVecMap.count(matUseInst))
-      useInst = cast<CallInst>(matToVecMap[matUseInst]);
-    for (unsigned i = 0; i < useInst->getNumArgOperands(); i++) {
-      if (matUseInst->getArgOperand(i) == matVal)
-        useInst->setArgOperand(i, vecVal);
-    }
-    break;
+    llvm_unreachable("Unsupported unary matrix operator");
   }
 }
 
-void HLMatrixLowerPass::TranslateMatSubscript(Value *matVal, Value *vecVal,
-                                              CallInst *matSubInst) {
-  unsigned opcode = GetHLOpcode(matSubInst);
-  HLSubscriptOpcode matOpcode = static_cast<HLSubscriptOpcode>(opcode);
-  assert(matOpcode != HLSubscriptOpcode::DefaultSubscript &&
-         "matrix don't use default subscript");
-
-  Type *matType = matVal->getType()->getPointerElementType();
-  unsigned col, row;
-  Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
-
-  bool isElement = (matOpcode == HLSubscriptOpcode::ColMatElement) |
-                   (matOpcode == HLSubscriptOpcode::RowMatElement);
-  Value *mask =
-      matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
-
-  if (isElement) {
-    Type *resultType = matSubInst->getType()->getPointerElementType();
-    unsigned resultSize = 1;
-    if (resultType->isVectorTy())
-      resultSize = resultType->getVectorNumElements();
-
-    std::vector<int> shufMask(resultSize);
-    Constant *EltIdxs = cast<Constant>(mask);
-    for (unsigned i = 0; i < resultSize; i++) {
-      shufMask[i] =
-          cast<ConstantInt>(EltIdxs->getAggregateElement(i))->getLimitedValue();
-    }
+Value *HLMatrixLowerPass::lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode, IRBuilder<> &Builder) {
+  Value *LoweredLhs = getLoweredByValOperand(Lhs, Builder);
+  Value *LoweredRhs = getLoweredByValOperand(Rhs, Builder);
 
-    for (Value::use_iterator CallUI = matSubInst->use_begin(),
-                             CallE = matSubInst->use_end();
-         CallUI != CallE;) {
-      Use &CallUse = *CallUI++;
-      Instruction *CallUser = cast<Instruction>(CallUse.getUser());
-      IRBuilder<> Builder(CallUser);
-      Value *vecLd = Builder.CreateLoad(vecVal);
-      if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
-        if (resultSize > 1) {
-          Value *shuf = Builder.CreateShuffleVector(vecLd, vecLd, shufMask);
-          ld->replaceAllUsesWith(shuf);
-        } else {
-          Value *elt = Builder.CreateExtractElement(vecLd, shufMask[0]);
-          ld->replaceAllUsesWith(elt);
-        }
-      } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
-        Value *val = st->getValueOperand();
-        if (resultSize > 1) {
-          for (unsigned i = 0; i < shufMask.size(); i++) {
-            unsigned idx = shufMask[i];
-            Value *valElt = Builder.CreateExtractElement(val, i);
-            vecLd = Builder.CreateInsertElement(vecLd, valElt, idx);
-          }
-          Builder.CreateStore(vecLd, vecVal);
-        } else {
-          vecLd = Builder.CreateInsertElement(vecLd, val, shufMask[0]);
-          Builder.CreateStore(vecLd, vecVal);
-        }
-      } else {
-        DXASSERT(0, "matrix element should only used by load/store.");
-      }
-      AddToDeadInsts(CallUser);
-    }
-  } else {
-    // Subscript.
-    // Return a row.
-    // Use insertElement and extractElement.
-    ArrayType *AT = ArrayType::get(EltTy, col*row);
-
-    IRBuilder<> AllocaBuilder(
-        matSubInst->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
-    Value *tempArray = AllocaBuilder.CreateAlloca(AT);
-    Value *zero = AllocaBuilder.getInt32(0);
-    bool isDynamicIndexing = !isa<ConstantInt>(mask);
-    SmallVector<Value *, 4> idxList;
-    for (unsigned i = 0; i < col; i++) {
-      idxList.emplace_back(
-          matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + i));
-    }
+  DXASSERT(LoweredLhs->getType()->isVectorTy() && LoweredRhs->getType()->isVectorTy(),
+    "Expected lowered binary operation operands to be vectors");
+  DXASSERT(LoweredLhs->getType() == LoweredRhs->getType(),
+    "Expected lowered binary operation operands to have matching types.");
 
-    for (Value::use_iterator CallUI = matSubInst->use_begin(),
-                             CallE = matSubInst->use_end();
-         CallUI != CallE;) {
-      Use &CallUse = *CallUI++;
-      Instruction *CallUser = cast<Instruction>(CallUse.getUser());
-      IRBuilder<> Builder(CallUser);
-      Value *vecLd = Builder.CreateLoad(vecVal);
-      if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
-        Value *sub = UndefValue::get(ld->getType());
-        if (!isDynamicIndexing) {
-          for (unsigned i = 0; i < col; i++) {
-            Value *matIdx = idxList[i];
-            Value *valElt = Builder.CreateExtractElement(vecLd, matIdx);
-            sub = Builder.CreateInsertElement(sub, valElt, i);
-          }
-        } else {
-          // Copy vec to array.
-          for (unsigned int i = 0; i < row*col; i++) {
-            Value *Elt =
-                Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
-            Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
-                                                   {zero, Builder.getInt32(i)});
-            Builder.CreateStore(Elt, Ptr);
-          }
-          for (unsigned i = 0; i < col; i++) {
-            Value *matIdx = idxList[i];
-            Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
-            Value *valElt = Builder.CreateLoad(Ptr);
-            sub = Builder.CreateInsertElement(sub, valElt, i);
-          }
-        }
-        ld->replaceAllUsesWith(sub);
-      } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
-        Value *val = st->getValueOperand();
-        if (!isDynamicIndexing) {
-          for (unsigned i = 0; i < col; i++) {
-            Value *matIdx = idxList[i];
-            Value *valElt = Builder.CreateExtractElement(val, i);
-            vecLd = Builder.CreateInsertElement(vecLd, valElt, matIdx);
-          }
-        } else {
-          // Copy vec to array.
-          for (unsigned int i = 0; i < row * col; i++) {
-            Value *Elt =
-                Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
-            Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
-                                                   {zero, Builder.getInt32(i)});
-            Builder.CreateStore(Elt, Ptr);
-          }
-          // Update array.
-          for (unsigned i = 0; i < col; i++) {
-            Value *matIdx = idxList[i];
-            Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
-            Value *valElt = Builder.CreateExtractElement(val, i);
-            Builder.CreateStore(valElt, Ptr);
-          }
-          // Copy array to vec.
-          for (unsigned int i = 0; i < row * col; i++) {
-            Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
-                                                   {zero, Builder.getInt32(i)});
-            Value *Elt = Builder.CreateLoad(Ptr);
-            vecLd = Builder.CreateInsertElement(vecLd, Elt, i);
-          }
-        }
-        Builder.CreateStore(vecLd, vecVal);
-      } else if (GetElementPtrInst *GEP =
-                     dyn_cast<GetElementPtrInst>(CallUser)) {
-        Value *GEPOffset = HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
-        Value *NewGEP = Builder.CreateGEP(vecVal, {zero, GEPOffset});
-        GEP->replaceAllUsesWith(NewGEP);
-      } else {
-        DXASSERT(0, "matrix subscript should only used by load/store.");
-      }
-      AddToDeadInsts(CallUser);
-    }
-  }
-  // Check vec version.
-  DXASSERT(matToVecMap.count(matSubInst) == 0, "should not have vec version");
-  // All the user should have been removed.
-  matSubInst->replaceAllUsesWith(UndefValue::get(matSubInst->getType()));
-  AddToDeadInsts(matSubInst);
-}
+  bool IsFloat = LoweredLhs->getType()->getVectorElementType()->isFloatingPointTy();
 
-void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
-    Value *matGlobal, ArrayRef<Value *> vecGlobals,
-    CallInst *matLdStInst) {
-  // No dynamic indexing on matrix, flatten matrix to scalars.
-  // vecGlobals already in correct major.
-  Type *matType = matGlobal->getType()->getPointerElementType();
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(matType, col, row);
-  Type *vecType = HLMatrixLower::LowerMatrixType(matType);
+  switch (Opcode) {
+  case HLBinaryOpcode::Add:
+    return IsFloat
+      ? Builder.CreateFAdd(LoweredLhs, LoweredRhs)
+      : Builder.CreateAdd(LoweredLhs, LoweredRhs);
 
-  IRBuilder<> Builder(matLdStInst);
+  case HLBinaryOpcode::Sub:
+    return IsFloat
+      ? Builder.CreateFSub(LoweredLhs, LoweredRhs)
+      : Builder.CreateSub(LoweredLhs, LoweredRhs);
 
-  HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
-  switch (opcode) {
-  case HLMatLoadStoreOpcode::ColMatLoad:
-  case HLMatLoadStoreOpcode::RowMatLoad: {
-    Value *Result = UndefValue::get(vecType);
-    for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
-      Value *Elt = Builder.CreateLoad(vecGlobals[matIdx]);
-      Result = Builder.CreateInsertElement(Result, Elt, matIdx);
-    }
-    matLdStInst->replaceAllUsesWith(Result);
-  } break;
-  case HLMatLoadStoreOpcode::ColMatStore:
-  case HLMatLoadStoreOpcode::RowMatStore: {
-    Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-    for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
-      Value *Elt = Builder.CreateExtractElement(Val, matIdx);
-      Builder.CreateStore(Elt, vecGlobals[matIdx]);
-    }
-  } break;
-  }
-}
+  case HLBinaryOpcode::Mul:
+    return IsFloat
+      ? Builder.CreateFMul(LoweredLhs, LoweredRhs)
+      : Builder.CreateMul(LoweredLhs, LoweredRhs);
 
-void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
-                                                      GlobalVariable *scalarArrayGlobal,
-                                                      CallInst *matLdStInst) {
-  // vecGlobals already in correct major.
-  HLMatLoadStoreOpcode opcode =
-      static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
-  switch (opcode) {
-  case HLMatLoadStoreOpcode::ColMatLoad:
-  case HLMatLoadStoreOpcode::RowMatLoad: {
-    IRBuilder<> Builder(matLdStInst);
-    Type *matTy = matGlobal->getType()->getPointerElementType();
-    unsigned col, row;
-    Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
-    Value *zeroIdx = Builder.getInt32(0);
-
-    std::vector<Value *> matElts(col * row);
-
-    for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
-      Value *GEP = Builder.CreateInBoundsGEP(
-          scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
-      matElts[matIdx] = Builder.CreateLoad(GEP);
-    }
+  case HLBinaryOpcode::Div:
+    return IsFloat
+      ? Builder.CreateFDiv(LoweredLhs, LoweredRhs)
+      : Builder.CreateSDiv(LoweredLhs, LoweredRhs);
 
-    Value *newVec =
-        HLMatrixLower::BuildVector(EltTy, col * row, matElts, Builder);
-    matLdStInst->replaceAllUsesWith(newVec);
-    matLdStInst->eraseFromParent();
-  } break;
-  case HLMatLoadStoreOpcode::ColMatStore:
-  case HLMatLoadStoreOpcode::RowMatStore: {
-    Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-
-    IRBuilder<> Builder(matLdStInst);
-    Type *matTy = matGlobal->getType()->getPointerElementType();
-    unsigned col, row;
-    HLMatrixLower::GetMatrixInfo(matTy, col, row);
-    Value *zeroIdx = Builder.getInt32(0);
-
-    std::vector<Value *> matElts(col * row);
-
-    for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
-      Value *GEP = Builder.CreateInBoundsGEP(
-          scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
-      Value *Elt = Builder.CreateExtractElement(Val, matIdx);
-      Builder.CreateStore(Elt, GEP);
-    }
+  case HLBinaryOpcode::Rem:
+    return IsFloat
+      ? Builder.CreateFRem(LoweredLhs, LoweredRhs)
+      : Builder.CreateSRem(LoweredLhs, LoweredRhs);
 
-    matLdStInst->eraseFromParent();
-  } break;
-  }
-}
-void HLMatrixLowerPass::TranslateMatSubscriptOnGlobalPtr(
-    CallInst *matSubInst, Value *vecPtr) {
-  Value *basePtr =
-      matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
-  Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
-  IRBuilder<> subBuilder(matSubInst);
-  Value *zeroIdx = subBuilder.getInt32(0);
-
-  HLSubscriptOpcode opcode =
-      static_cast<HLSubscriptOpcode>(GetHLOpcode(matSubInst));
-
-  Type *matTy = basePtr->getType()->getPointerElementType();
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(matTy, col, row);
-
-  std::vector<Value *> idxList;
-  switch (opcode) {
-  case HLSubscriptOpcode::ColMatSubscript:
-  case HLSubscriptOpcode::RowMatSubscript: {
-    // Just use index created in EmitHLSLMatrixSubscript.
-    for (unsigned c = 0; c < col; c++) {
-      Value *matIdx =
-          matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + c);
-      idxList.emplace_back(matIdx);
-    }
-  } break;
-  case HLSubscriptOpcode::RowMatElement:
-  case HLSubscriptOpcode::ColMatElement: {
-    Type *resultType = matSubInst->getType()->getPointerElementType();
-    unsigned resultSize = 1;
-    if (resultType->isVectorTy())
-      resultSize = resultType->getVectorNumElements();
-    // Just use index created in EmitHLSLMatrixElement.
-    Constant *EltIdxs = cast<Constant>(idx);
-    for (unsigned i = 0; i < resultSize; i++) {
-      Value *matIdx = EltIdxs->getAggregateElement(i);
-      idxList.emplace_back(matIdx);
-    }
-  } break;
-  default:
-    DXASSERT(0, "invalid operation");
-    break;
-  }
+  case HLBinaryOpcode::And:
+    return Builder.CreateAnd(LoweredLhs, LoweredRhs);
 
-  // Cannot generate vector pointer
-  // Replace all uses with scalar pointers.
-  if (!matSubInst->getType()->getPointerElementType()->isVectorTy()) {
-    DXASSERT(idxList.size() == 1, "Expected a single matrix element index if the result is not a vector");
-    Value *Ptr =
-      subBuilder.CreateInBoundsGEP(vecPtr, { zeroIdx, idxList[0] });
-    matSubInst->replaceAllUsesWith(Ptr);
-  } else {
-    // Split the use of CI with Ptrs.
-    for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
-      Instruction *subsUser = cast<Instruction>(*(U++));
-      IRBuilder<> userBuilder(subsUser);
-      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
-        Value *IndexPtr =
-            HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
-        Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
-                                                   {zeroIdx, IndexPtr});
-        for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
-          Instruction *gepUser = cast<Instruction>(*(gepU++));
-          IRBuilder<> gepUserBuilder(gepUser);
-          if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
-            Value *subData = stUser->getValueOperand();
-            gepUserBuilder.CreateStore(subData, Ptr);
-            stUser->eraseFromParent();
-          } else if (LoadInst *ldUser = dyn_cast<LoadInst>(gepUser)) {
-            Value *subData = gepUserBuilder.CreateLoad(Ptr);
-            ldUser->replaceAllUsesWith(subData);
-            ldUser->eraseFromParent();
-          } else {
-            AddrSpaceCastInst *Cast = cast<AddrSpaceCastInst>(gepUser);
-            Cast->setOperand(0, Ptr);
-          }
-        }
-        GEP->eraseFromParent();
-      } else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
-        Value *val = stUser->getValueOperand();
-        for (unsigned i = 0; i < idxList.size(); i++) {
-          Value *Elt = userBuilder.CreateExtractElement(val, i);
-          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
-                                                     {zeroIdx, idxList[i]});
-          userBuilder.CreateStore(Elt, Ptr);
-        }
-        stUser->eraseFromParent();
-      } else {
+  case HLBinaryOpcode::Or:
+    return Builder.CreateOr(LoweredLhs, LoweredRhs);
 
-        Value *ldVal =
-            UndefValue::get(matSubInst->getType()->getPointerElementType());
-        for (unsigned i = 0; i < idxList.size(); i++) {
-          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
-                                                     {zeroIdx, idxList[i]});
-          Value *Elt = userBuilder.CreateLoad(Ptr);
-          ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
-        }
-        // Must be load here.
-        LoadInst *ldUser = cast<LoadInst>(subsUser);
-        ldUser->replaceAllUsesWith(ldVal);
-        ldUser->eraseFromParent();
-      }
-    }
-  }
-  matSubInst->eraseFromParent();
-}
+  case HLBinaryOpcode::Xor:
+    return Builder.CreateXor(LoweredLhs, LoweredRhs);
 
-void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobalPtr(
-    CallInst *matLdStInst, Value *vecPtr) {
-  // Just translate into vector here.
-  // DynamicIndexingVectorToArray will change it to scalar array.
-  IRBuilder<> Builder(matLdStInst);
-  unsigned opcode = hlsl::GetHLOpcode(matLdStInst);
-  HLMatLoadStoreOpcode matLdStOp = static_cast<HLMatLoadStoreOpcode>(opcode);
-  switch (matLdStOp) {
-  case HLMatLoadStoreOpcode::ColMatLoad:
-  case HLMatLoadStoreOpcode::RowMatLoad: {
-    // Load as vector.
-    Value *newLoad = Builder.CreateLoad(vecPtr);
+  case HLBinaryOpcode::Shl:
+    return Builder.CreateShl(LoweredLhs, LoweredRhs);
 
-    matLdStInst->replaceAllUsesWith(newLoad);
-    matLdStInst->eraseFromParent();
-  } break;
-  case HLMatLoadStoreOpcode::ColMatStore:
-  case HLMatLoadStoreOpcode::RowMatStore: {
-    // Change value to vector array, then store.
-    Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-
-    Value *vecArrayGep = vecPtr;
-    Builder.CreateStore(Val, vecArrayGep);
-    matLdStInst->eraseFromParent();
-  } break;
-  default:
-    DXASSERT(0, "invalid operation");
-    break;
-  }
-}
+  case HLBinaryOpcode::Shr:
+    return Builder.CreateAShr(LoweredLhs, LoweredRhs);
 
-// Flatten values inside init list to scalar elements.
-static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
-                            Value *val,
-                            DenseMap<Value *, Value *> &matToVecMap,
-                            IRBuilder<> &Builder) {
-  Type *valTy = val->getType();
-
-  if (valTy->isPointerTy()) {
-    if (HLMatrixLower::IsMatrixArrayPointer(valTy)) {
-      if (matToVecMap.count(cast<Instruction>(val))) {
-        val = matToVecMap[cast<Instruction>(val)];
-      } else {
-        // Convert to vec array with bitcast.
-        Type *vecArrayPtrTy = HLMatrixLower::LowerMatrixArrayPointer(valTy);
-        val = Builder.CreateBitCast(val, vecArrayPtrTy);
-      }
-    }
-    Type *valEltTy = val->getType()->getPointerElementType();
-    if (valEltTy->isVectorTy() || dxilutil::IsHLSLMatrixType(valEltTy) ||
-        valEltTy->isSingleValueType()) {
-      Value *ldVal = Builder.CreateLoad(val);
-      IterateInitList(elts, idx, ldVal, matToVecMap, Builder);
-    } else {
-      Type *i32Ty = Type::getInt32Ty(valTy->getContext());
-      Value *zero = ConstantInt::get(i32Ty, 0);
-      if (ArrayType *AT = dyn_cast<ArrayType>(valEltTy)) {
-        for (unsigned i = 0; i < AT->getArrayNumElements(); i++) {
-          Value *gepIdx = ConstantInt::get(i32Ty, i);
-          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
-          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
-        }
-      } else {
-        // Struct.
-        StructType *ST = cast<StructType>(valEltTy);
-        for (unsigned i = 0; i < ST->getNumElements(); i++) {
-          Value *gepIdx = ConstantInt::get(i32Ty, i);
-          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
-          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
-        }
-      }
-    }
-  } else if (dxilutil::IsHLSLMatrixType(valTy)) {
-    unsigned col, row;
-    HLMatrixLower::GetMatrixInfo(valTy, col, row);
-    unsigned matSize = col * row;
-    val = matToVecMap[cast<Instruction>(val)];
-    // temp matrix all row major
-    for (unsigned i = 0; i < matSize; i++) {
-      Value *Elt = Builder.CreateExtractElement(val, i);
-      elts[idx + i] = Elt;
-    }
-    idx += matSize;
-  } else {
-    if (valTy->isVectorTy()) {
-      unsigned vecSize = valTy->getVectorNumElements();
-      for (unsigned i = 0; i < vecSize; i++) {
-        Value *Elt = Builder.CreateExtractElement(val, i);
-        elts[idx + i] = Elt;
-      }
-      idx += vecSize;
-    } else {
-      DXASSERT(valTy->isSingleValueType(), "must be single value type here");
-      elts[idx++] = val;
-    }
-  }
-}
+  case HLBinaryOpcode::LT:
+    return IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_OLT, LoweredLhs, LoweredRhs)
+      : Builder.CreateICmp(CmpInst::ICMP_SLT, LoweredLhs, LoweredRhs);
 
-void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
-  // Array matrix init will be translated in TranslateMatArrayInitReplace.
-  if (matInitInst->getType()->isVoidTy())
-    return;
+  case HLBinaryOpcode::GT:
+    return IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_OGT, LoweredLhs, LoweredRhs)
+      : Builder.CreateICmp(CmpInst::ICMP_SGT, LoweredLhs, LoweredRhs);
 
-  DXASSERT(matToVecMap.count(matInitInst), "must have vec version");
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
-  IRBuilder<> Builder(vecUseInst);
-  unsigned col, row;
-  Type *EltTy = GetMatrixInfo(matInitInst->getType(), col, row);
-
-  Type *vecTy = VectorType::get(EltTy, col * row);
-  unsigned vecSize = vecTy->getVectorNumElements();
-  unsigned idx = 0;
-  std::vector<Value *> elts(vecSize);
-  // Skip opcode arg.
-  for (unsigned i = 1; i < matInitInst->getNumArgOperands(); i++) {
-    Value *val = matInitInst->getArgOperand(i);
-
-    IterateInitList(elts, idx, val, matToVecMap, Builder);
-  }
+  case HLBinaryOpcode::LE:
+    return IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_OLE, LoweredLhs, LoweredRhs)
+      : Builder.CreateICmp(CmpInst::ICMP_SLE, LoweredLhs, LoweredRhs);
 
-  Value *newInit = UndefValue::get(vecTy);
-  // InitList is row major, the result is row major too.
-  for (unsigned i=0;i< col * row;i++) {
-      Constant *vecIdx = Builder.getInt32(i);
-      newInit = InsertElementInst::Create(newInit, elts[i], vecIdx);
-      Builder.Insert(cast<Instruction>(newInit));
-  }
+  case HLBinaryOpcode::GE:
+    return IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_OGE, LoweredLhs, LoweredRhs)
+      : Builder.CreateICmp(CmpInst::ICMP_SGE, LoweredLhs, LoweredRhs);
 
-  // Replace matInit function call with matInitInst.
-  vecUseInst->replaceAllUsesWith(newInit);
-  AddToDeadInsts(vecUseInst);
-  matToVecMap[matInitInst] = newInit;
-}
+  case HLBinaryOpcode::EQ:
+    return IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_OEQ, LoweredLhs, LoweredRhs)
+      : Builder.CreateICmp(CmpInst::ICMP_EQ, LoweredLhs, LoweredRhs);
 
-void HLMatrixLowerPass::TranslateMatSelect(CallInst *matSelectInst) {
-  unsigned col, row;
-  Type *EltTy = GetMatrixInfo(matSelectInst->getType(), col, row);
+  case HLBinaryOpcode::NE:
+    return IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredLhs, LoweredRhs)
+      : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredLhs, LoweredRhs);
 
-  Type *vecTy = VectorType::get(EltTy, col * row);
-  unsigned vecSize = vecTy->getVectorNumElements();
+  case HLBinaryOpcode::UDiv:
+    return Builder.CreateUDiv(LoweredLhs, LoweredRhs);
 
-  CallInst *vecUseInst = cast<CallInst>(matToVecMap[matSelectInst]);
-  Instruction *LHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx));
-  Instruction *RHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx));
+  case HLBinaryOpcode::URem:
+    return Builder.CreateURem(LoweredLhs, LoweredRhs);
 
-  IRBuilder<> Builder(vecUseInst);
+  case HLBinaryOpcode::UShr:
+    return Builder.CreateLShr(LoweredLhs, LoweredRhs);
 
-  Value *Cond = vecUseInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
-  bool isVecCond = Cond->getType()->isVectorTy();
-  if (isVecCond) {
-    Instruction *MatCond = cast<Instruction>(
-        matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx));
-    DXASSERT_NOMSG(matToVecMap.count(MatCond));
-    Cond = matToVecMap[MatCond];
+  case HLBinaryOpcode::ULT:
+    return Builder.CreateICmp(CmpInst::ICMP_ULT, LoweredLhs, LoweredRhs);
+
+  case HLBinaryOpcode::UGT:
+    return Builder.CreateICmp(CmpInst::ICMP_UGT, LoweredLhs, LoweredRhs);
+
+  case HLBinaryOpcode::ULE:
+    return Builder.CreateICmp(CmpInst::ICMP_ULE, LoweredLhs, LoweredRhs);
+
+  case HLBinaryOpcode::UGE:
+    return Builder.CreateICmp(CmpInst::ICMP_UGE, LoweredLhs, LoweredRhs);
+
+  case HLBinaryOpcode::LAnd:
+  case HLBinaryOpcode::LOr: {
+    Value *Zero = Constant::getNullValue(LoweredLhs->getType());
+    Value *LhsCmp = IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredLhs, Zero)
+      : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredLhs, Zero);
+    Value *RhsCmp = IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredRhs, Zero)
+      : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredRhs, Zero);
+    return Opcode == HLBinaryOpcode::LOr
+      ? Builder.CreateOr(LhsCmp, RhsCmp)
+      : Builder.CreateAnd(LhsCmp, RhsCmp);
   }
-  DXASSERT_NOMSG(matToVecMap.count(LHS));
-  Value *VLHS = matToVecMap[LHS];
-  DXASSERT_NOMSG(matToVecMap.count(RHS));
-  Value *VRHS = matToVecMap[RHS];
-
-  Value *VecSelect = UndefValue::get(vecTy);
-  for (unsigned i = 0; i < vecSize; i++) {
-    llvm::Value *EltCond = Cond;
-    if (isVecCond)
-      EltCond = Builder.CreateExtractElement(Cond, i);
-    llvm::Value *EltL = Builder.CreateExtractElement(VLHS, i);
-    llvm::Value *EltR = Builder.CreateExtractElement(VRHS, i);
-    llvm::Value *EltSelect = Builder.CreateSelect(EltCond, EltL, EltR);
-    VecSelect = Builder.CreateInsertElement(VecSelect, EltSelect, i);
+  default:
+    llvm_unreachable("Unsupported binary matrix operator");
   }
-  AddToDeadInsts(vecUseInst);
-  vecUseInst->replaceAllUsesWith(VecSelect);
-  matToVecMap[matSelectInst] = VecSelect;
 }
 
-void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
-                                             Value *vecVal,
-                                             GetElementPtrInst *matGEP) {
-  SmallVector<Value *, 4> idxList(matGEP->idx_begin(), matGEP->idx_end());
-
-  IRBuilder<> GEPBuilder(matGEP);
-  Value *newGEP = GEPBuilder.CreateInBoundsGEP(vecVal, idxList);
-  // Only used by mat subscript and mat ld/st.
-  for (Value::user_iterator user = matGEP->user_begin();
-       user != matGEP->user_end();) {
-    Instruction *useInst = cast<Instruction>(*(user++));
-    IRBuilder<> Builder(useInst);
-    // Skip return here.
-    if (isa<ReturnInst>(useInst))
-      continue;
-    if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
-      // Function call.
-      hlsl::HLOpcodeGroup group =
-          hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
-      switch (group) {
-      case HLOpcodeGroup::HLMatLoadStore: {
-        unsigned opcode = GetHLOpcode(useCall);
-        HLMatLoadStoreOpcode matOpcode =
-            static_cast<HLMatLoadStoreOpcode>(opcode);
-        switch (matOpcode) {
-        case HLMatLoadStoreOpcode::ColMatLoad:
-        case HLMatLoadStoreOpcode::RowMatLoad: {
-          // Skip the vector version.
-          if (useCall->getType()->isVectorTy())
-            continue;
-          Type *matTy = useCall->getType();
-          Value *newLd = CreateVecMatrixLoad(newGEP, matTy, Builder);
-          DXASSERT(matToVecMap.count(useCall), "must have vec version");
-          Value *oldLd = matToVecMap[useCall];
-          // Delete the oldLd.
-          AddToDeadInsts(cast<Instruction>(oldLd));
-          oldLd->replaceAllUsesWith(newLd);
-          matToVecMap[useCall] = newLd;
-        } break;
-        case HLMatLoadStoreOpcode::ColMatStore:
-        case HLMatLoadStoreOpcode::RowMatStore: {
-          Value *vecPtr = newGEP;
-          
-          Value *matVal = useCall->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-          // Skip the vector version.
-          if (matVal->getType()->isVectorTy()) {
-            AddToDeadInsts(useCall);
-            continue;
-          }
+Value *HLMatrixLowerPass::lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode) {
+  IRBuilder<> Builder(Call);
+  switch (Opcode) {
+  case HLMatLoadStoreOpcode::RowMatLoad:
+  case HLMatLoadStoreOpcode::ColMatLoad:
+    return lowerHLLoad(Call->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx),
+      /* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatLoad, Builder);
 
-          Instruction *matInst = cast<Instruction>(matVal);
+  case HLMatLoadStoreOpcode::RowMatStore:
+  case HLMatLoadStoreOpcode::ColMatStore:
+    return lowerHLStore(
+      Call->getArgOperand(HLOperandIndex::kMatStoreValOpIdx),
+      Call->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx),
+      /* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatStore,
+      /* Return */ !Call->getType()->isVoidTy(), Builder);
 
-          DXASSERT(matToVecMap.count(matInst), "must have vec version");
-          Value *vecVal = matToVecMap[matInst];
-          CreateVecMatrixStore(vecVal, vecPtr, matVal->getType(), Builder);
-        } break;
-        }
-      } break;
-      case HLOpcodeGroup::HLSubscript: {
-        TranslateMatSubscript(matGEP, newGEP, useCall);
-      } break;
-      default:
-        DXASSERT(0, "invalid operation");
-        break;
-      }
-    } else if (dyn_cast<BitCastInst>(useInst)) {
-      // Just replace the src with vec version.
-      useInst->setOperand(0, newGEP);
-    } else {
-      // Must be GEP.
-      GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
-      TranslateMatArrayGEP(matGEP, cast<Instruction>(newGEP), GEP);
-    }
+  default:
+    llvm_unreachable("Unsupported matrix load/store operation");
   }
-  AddToDeadInsts(matGEP);
 }
 
-Value *HLMatrixLowerPass::GetMatrixForVec(Value *vecVal, Type *matTy) {
-  Value *newMatVal = nullptr;
-  if (vecToMatMap.count(vecVal)) {
-    newMatVal = vecToMatMap[vecVal];
-  } else {
-    // create conversion instructions if necessary, caching result for subsequent replacements.
-    // do so right after the vecVal def so it's available to all potential uses.
-    newMatVal = BitCastValueOrPtr(vecVal,
-      cast<Instruction>(vecVal)->getNextNode(), // vecVal must be instruction
-      matTy,
-      /*bOrigAllocaTy*/true,
-      vecVal->getName());
-    vecToMatMap[vecVal] = newMatVal;
-  }
-  return newMatVal;
-}
+Value *HLMatrixLowerPass::lowerHLLoad(Value *MatPtr, bool RowMajor, IRBuilder<> &Builder) {
+  HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
 
-void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
-                                          Value *vecVal) {
-  for (Value::user_iterator user = matVal->user_begin();
-       user != matVal->user_end();) {
-    Instruction *useInst = cast<Instruction>(*(user++));
-    // User must be function call.
-    if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
-      hlsl::HLOpcodeGroup group =
-          hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
-      switch (group) {
-      case HLOpcodeGroup::HLIntrinsic: {
-        if (CallInst *matCI = dyn_cast<CallInst>(matVal)) {
-          MatIntrinsicReplace(matCI, vecVal, useCall);
-        } else {
-          IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(useCall));
-          if (opcode == IntrinsicOp::MOP_Append) {
-            // Replace matrix with vector representation and update intrinsic signature
-            // We don't care about matrix orientation here, since that will need to be
-            // taken into account anyways when generating the store output calls.
-            SmallVector<Value *, 4> flatArgs;
-            SmallVector<Type *, 4> flatParamTys;
-            for (Value *arg : useCall->arg_operands()) {
-              Value *flagArg = arg == matVal ? vecVal : arg;
-              flatArgs.emplace_back(arg == matVal ? vecVal : arg);
-              flatParamTys.emplace_back(flagArg->getType());
-            }
-
-            // Don't need flat return type for Append.
-            FunctionType *flatFuncTy =
-              FunctionType::get(useInst->getType(), flatParamTys, false);
-            Function *flatF = GetOrCreateHLFunction(*m_pModule, flatFuncTy, group, static_cast<unsigned int>(opcode));
-            
-            // Append returns void, so the old call should have no users
-            DXASSERT(useInst->getType()->isVoidTy(), "Unexpected MOP_Append intrinsic return type");
-            DXASSERT(useInst->use_empty(), "Unexpected users of MOP_Append intrinsic return value");
-            IRBuilder<> Builder(useCall);
-            Builder.CreateCall(flatF, flatArgs);
-            AddToDeadInsts(useCall);
-          }
-          else if (opcode == IntrinsicOp::IOP_frexp) {
-            // NOTE: because out param use copy out semantic, so the operand of
-            // out must be temp alloca.
-            DXASSERT(isa<AllocaInst>(matVal), "else invalid mat ptr for frexp");
-            auto it = matToVecMap.find(useCall);
-            DXASSERT(it != matToVecMap.end(),
-              "else fail to create vec version of useCall");
-            CallInst *vecUseInst = cast<CallInst>(it->second);
-
-            for (unsigned i = 0; i < vecUseInst->getNumArgOperands(); i++) {
-              if (useCall->getArgOperand(i) == matVal) {
-                vecUseInst->setArgOperand(i, vecVal);
-              }
-            }
-          } else {
-            DXASSERT(false, "Unexpected matrix user intrinsic.");
-          }
-        }
-      } break;
-      case HLOpcodeGroup::HLSelect: {
-        MatIntrinsicReplace(matVal, vecVal, useCall);
-      } break;
-      case HLOpcodeGroup::HLBinOp: {
-        TrivialMatBinOpReplace(matVal, vecVal, useCall);
-      } break;
-      case HLOpcodeGroup::HLUnOp: {
-        TrivialMatUnOpReplace(matVal, vecVal, useCall);
-      } break;
-      case HLOpcodeGroup::HLCast: {
-        TranslateMatCast(matVal, vecVal, useCall);
-      } break;
-      case HLOpcodeGroup::HLMatLoadStore: {
-        DXASSERT(matToVecMap.count(useCall), "must have vec version");
-        Value *vecUser = matToVecMap[useCall];
-        if (isa<AllocaInst>(matVal) || GetIfMatrixGEPOfUDTAlloca(matVal) ||
-            GetIfMatrixGEPOfUDTArg(matVal, *m_pHLModule)) {
-          // Load Already translated in lowerToVec.
-          // Store val operand will be set by the val use.
-          // Do nothing here.
-        } else if (StoreInst *stInst = dyn_cast<StoreInst>(vecUser)) {
-          DXASSERT(vecVal->getType() == stInst->getValueOperand()->getType(),
-            "Mismatched vector matrix store value types.");
-          stInst->setOperand(0, vecVal);
-        } else if (ZExtInst *zextInst = dyn_cast<ZExtInst>(vecUser)) {
-          // This happens when storing bool matrices,
-          // which must first undergo conversion from i1's to i32's.
-          DXASSERT(vecVal->getType() == zextInst->getOperand(0)->getType(),
-            "Mismatched vector matrix store value types.");
-          zextInst->setOperand(0, vecVal);
-        } else
-          TrivialMatReplace(matVal, vecVal, useCall);
-
-      } break;
-      case HLOpcodeGroup::HLSubscript: {
-        if (AllocaInst *AI = dyn_cast<AllocaInst>(vecVal))
-          TranslateMatSubscript(matVal, vecVal, useCall);
-        else if (BitCastInst *BCI = dyn_cast<BitCastInst>(vecVal))
-          TranslateMatSubscript(matVal, vecVal, useCall);
-        else
-          TrivialMatReplace(matVal, vecVal, useCall);
-
-      } break;
-      case HLOpcodeGroup::HLInit: {
-        DXASSERT(!isa<AllocaInst>(matVal), "array of matrix init should lowered in StoreInitListToDestPtr at CGHLSLMS.cpp");
-        TranslateMatInit(useCall);
-      } break;
-      case HLOpcodeGroup::NotHL: {
-        castMatrixArgs(matVal, vecVal, useCall);
-      } break;
-      case HLOpcodeGroup::HLExtIntrinsic:
-      case HLOpcodeGroup::HLCreateHandle:
-      case HLOpcodeGroup::NumOfHLOps:
-      // No vector equivalents for these ops.
-        break;
-      }
-    } else if (dyn_cast<BitCastInst>(useInst)) {
-      // Just replace the src with vec version.
-      if (useInst != vecVal)
-        useInst->setOperand(0, vecVal);
-    } else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
-      Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
-      RI->setOperand(0, newMatVal);
-    } else if (isa<StoreInst>(useInst)) {
-      DXASSERT(vecToMatMap.count(vecVal) && vecToMatMap[vecVal] == matVal, "matrix store should only be used with preserved matrix values");
-    } else {
-      // Must be GEP on mat array alloca.
-      GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
-      AllocaInst *AI = cast<AllocaInst>(matVal);
-      TranslateMatArrayGEP(AI, vecVal, GEP);
-    }
+  Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder);
+  if (LoweredPtr == nullptr) {
+    // Can't lower this here, defer to HL signature lower
+    HLMatLoadStoreOpcode Opcode = RowMajor ? HLMatLoadStoreOpcode::RowMatLoad : HLMatLoadStoreOpcode::ColMatLoad;
+    return callHLFunction(
+      *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
+      MatTy.getLoweredVectorTypeForReg(), { Builder.getInt32((uint32_t)Opcode), MatPtr }, Builder);
   }
-}
 
-void HLMatrixLowerPass::castMatrixArgs(Value *matVal, Value *vecVal, CallInst *CI) {
-  // translate user function parameters as necessary
-  Type *Ty = matVal->getType();
-  if (Ty->isPointerTy()) {
-    IRBuilder<> Builder(CI);
-    Value *newMatVal = Builder.CreateBitCast(vecVal, Ty);
-    CI->replaceUsesOfWith(matVal, newMatVal);
-  } else {
-    Value *newMatVal = GetMatrixForVec(vecVal, Ty);
-    CI->replaceUsesOfWith(matVal, newMatVal);
-  }
+  return MatTy.emitLoweredLoad(LoweredPtr, Builder);
 }
 
-void HLMatrixLowerPass::finalMatTranslation(Value *matVal) {
-  // Translate matInit.
-  if (CallInst *CI = dyn_cast<CallInst>(matVal)) {
-    hlsl::HLOpcodeGroup group =
-        hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
-    switch (group) {
-    case HLOpcodeGroup::HLInit: {
-      TranslateMatInit(CI);
-    } break;
-    case HLOpcodeGroup::HLSelect: {
-      TranslateMatSelect(CI);
-    } break;
-    default:
-      // Skip group already translated.
-      break;
-    }
-  }
-}
+Value *HLMatrixLowerPass::lowerHLStore(Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder) {
+  DXASSERT(MatVal->getType() == MatPtr->getType()->getPointerElementType(),
+    "Matrix store value/pointer type mismatch.");
 
-void HLMatrixLowerPass::DeleteDeadInsts() {
-  // Delete the matrix version insts.
-  for (Instruction *deadInst : m_deadInsts) {
-    // Replace with undef and remove it.
-    deadInst->replaceAllUsesWith(UndefValue::get(deadInst->getType()));
-    deadInst->eraseFromParent();
+  Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder);
+  Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
+  if (LoweredPtr == nullptr) {
+    // Can't lower the pointer here, defer to HL signature lower
+    HLMatLoadStoreOpcode Opcode = RowMajor ? HLMatLoadStoreOpcode::RowMatStore : HLMatLoadStoreOpcode::ColMatStore;
+    return callHLFunction(
+      *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
+      Return ? LoweredVal->getType() : Builder.getVoidTy(),
+      { Builder.getInt32((uint32_t)Opcode), MatPtr, LoweredVal }, Builder);
   }
-  m_deadInsts.clear();
-  m_inDeadInstsSet.clear();
+
+  HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
+  StoreInst *LoweredStore = MatTy.emitLoweredStore(LoweredVal, LoweredPtr, Builder);
+
+  // If the intrinsic returned a value, return the stored lowered value
+  return Return ? LoweredVal : LoweredStore;
 }
 
-static bool OnlyUsedByMatrixLdSt(Value *V) {
-  bool onlyLdSt = true;
-  for (User *user : V->users()) {
-    if (isa<Constant>(user) && user->use_empty())
-      continue;
+static Value *convertScalarOrVector(Value *SrcVal, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> Builder) {
+  DXASSERT(SrcVal->getType()->isVectorTy() == DstTy->isVectorTy(),
+    "Scalar/vector type mismatch in numerical conversion.");
+  Type *SrcTy = SrcVal->getType();
 
-    CallInst *CI = cast<CallInst>(user);
-    if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
-        HLOpcodeGroup::HLMatLoadStore)
-      continue;
+  // Conversions between equivalent types are no-ops,
+  // even between signed/unsigned variants.
+  if (SrcTy == DstTy) return SrcVal;
 
-    onlyLdSt = false;
-    break;
+  // Conversions to bools are comparisons
+  if (DstTy->getScalarSizeInBits() == 1) {
+    // fcmp une is what regular clang uses in C++ for (bool)f;
+    return cast<Instruction>(SrcTy->isIntOrIntVectorTy()
+      ? Builder.CreateICmpNE(SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool")
+      : Builder.CreateFCmpUNE(SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool"));
   }
-  return onlyLdSt;
-}
 
-static Constant *LowerMatrixArrayConst(Constant *MA, Type *ResultTy) {
-  if (ArrayType *AT = dyn_cast<ArrayType>(ResultTy)) {
-    std::vector<Constant *> Elts;
-    Type *EltResultTy = AT->getElementType();
-    for (unsigned i = 0; i < AT->getNumElements(); i++) {
-      Constant *Elt =
-          LowerMatrixArrayConst(MA->getAggregateElement(i), EltResultTy);
-      Elts.emplace_back(Elt);
-    }
-    return ConstantArray::get(AT, Elts);
-  } else {
-    // Cast float[row][col] -> float< row * col>.
-    // Get float[row][col] from the struct.
-    Constant *rows = MA->getAggregateElement((unsigned)0);
-    ArrayType *RowAT = cast<ArrayType>(rows->getType());
-    std::vector<Constant *> Elts;
-    for (unsigned r=0;r<RowAT->getArrayNumElements();r++) {
-      Constant *row = rows->getAggregateElement(r);
-      VectorType *VT = cast<VectorType>(row->getType());
-      for (unsigned c = 0; c < VT->getVectorNumElements(); c++) {
-        Elts.emplace_back(row->getAggregateElement(c));
-      }
+  // Cast necessary
+  bool SrcIsUnsigned = Opcode == HLCastOpcode::FromUnsignedCast ||
+    Opcode == HLCastOpcode::UnsignedUnsignedCast;
+  bool DstIsUnsigned = Opcode == HLCastOpcode::ToUnsignedCast ||
+    Opcode == HLCastOpcode::UnsignedUnsignedCast;
+  auto CastOp = static_cast<Instruction::CastOps>(HLModule::GetNumericCastOp(
+    SrcTy, SrcIsUnsigned, DstTy, DstIsUnsigned));
+  return cast<Instruction>(Builder.CreateCast(CastOp, SrcVal, DstTy));
+}
+
+Value *HLMatrixLowerPass::lowerHLCast(Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder) {
+  // The opcode really doesn't mean much here, the types involved are what drive most of the casting.
+  DXASSERT(Opcode != HLCastOpcode::HandleToResCast, "Unexpected matrix cast opcode.");
+
+  if (dxilutil::IsIntegerOrFloatingPointType(Src->getType())) {
+    // Scalar to matrix splat
+    HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
+
+    // Apply element conversion
+    Value *Result = convertScalarOrVector(Src,
+      MatDstTy.getElementTypeForReg(), Opcode, Builder);
+
+    // Splat to a vector
+    Result = Builder.CreateInsertElement(
+      UndefValue::get(VectorType::get(Result->getType(), 1)),
+      Result, static_cast<uint64_t>(0));
+    return Builder.CreateShuffleVector(Result, Result,
+      ConstantVector::getSplat(MatDstTy.getNumElements(), Builder.getInt32(0)));
+  }
+  else if (VectorType *SrcVecTy = dyn_cast<VectorType>(Src->getType())) {
+    // Vector to matrix
+    HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
+    Value *Result = Src;
+
+    // We might need to truncate
+    if (MatDstTy.getNumElements() < SrcVecTy->getNumElements()) {
+      SmallVector<int, 4> ShuffleIndices;
+      for (unsigned Idx = 0; Idx < MatDstTy.getNumElements(); ++Idx)
+        ShuffleIndices.emplace_back(static_cast<int>(Idx));
+      Result = Builder.CreateShuffleVector(Src, Src, ShuffleIndices);
     }
-    return ConstantVector::get(Elts);
-  }
-}
 
-void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
-  // Lower to array of vector array like float[row * col].
-  // It's follow the major of decl.
-  // DynamicIndexingVectorToArray will change it to scalar array.
-  Type *Ty = GV->getType()->getPointerElementType();
-  std::vector<unsigned> arraySizeList;
-  while (Ty->isArrayTy()) {
-    arraySizeList.push_back(Ty->getArrayNumElements());
-    Ty = Ty->getArrayElementType();
-  }
-  unsigned row, col;
-  Type *EltTy = GetMatrixInfo(Ty, col, row);
-  Ty = VectorType::get(EltTy, col * row);
-
-  for (auto arraySize = arraySizeList.rbegin();
-       arraySize != arraySizeList.rend(); arraySize++)
-    Ty = ArrayType::get(Ty, *arraySize);
-
-  Type *VecArrayTy = Ty;
-  Constant *InitVal = nullptr;
-  if (GV->hasInitializer()) {
-    Constant *OldInitVal = GV->getInitializer();
-    InitVal = isa<UndefValue>(OldInitVal)
-      ? UndefValue::get(VecArrayTy)
-      : LowerMatrixArrayConst(OldInitVal, cast<ArrayType>(VecArrayTy));
+    // Apply element conversion
+    return convertScalarOrVector(Result,
+      MatDstTy.getLoweredVectorTypeForReg(), Opcode, Builder);
   }
 
-  bool isConst = GV->isConstant();
-  GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
-  unsigned AddressSpace = GV->getType()->getAddressSpace();
-  GlobalValue::LinkageTypes linkage = GV->getLinkage();
+  // Source must now be a matrix
+  HLMatrixType MatSrcTy = HLMatrixType::cast(Src->getType());
+  VectorType* LoweredSrcTy = MatSrcTy.getLoweredVectorTypeForReg();
 
-  Module *M = GV->getParent();
-  GlobalVariable *VecGV =
-      new llvm::GlobalVariable(*M, VecArrayTy, /*IsConstant*/ isConst, linkage,
-                               /*InitVal*/ InitVal, GV->getName() + ".v",
-                               /*InsertBefore*/ nullptr, TLMode, AddressSpace);
-  // Add debug info.
-  if (m_HasDbgInfo) {
-    DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
-    HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, VecGV);
+  Value *LoweredSrc;
+  if (isa<Argument>(Src)) {
+    // Function arguments are lowered in HLSignatureLower.
+    // Initial codegen first generates those cast intrinsics to tell us how to lower them into vectors.
+    // Preserve them, but change the return type to vector.
+    DXASSERT(Opcode == HLCastOpcode::ColMatrixToVecCast || Opcode == HLCastOpcode::RowMatrixToVecCast,
+      "Unexpected cast of matrix argument.");
+    LoweredSrc = callHLFunction(*m_pModule, HLOpcodeGroup::HLCast, static_cast<unsigned>(Opcode),
+      LoweredSrcTy, { Builder.getInt32((uint32_t)Opcode), Src }, Builder);
   }
-
-  for (User *U : GV->users()) {
-    Value *VecGEP = nullptr;
-    // Must be GEP or GEPOperator.
-    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
-      IRBuilder<> Builder(GEP);
-      SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
-      VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
-      AddToDeadInsts(GEP);
-    } else {
-      GEPOperator *GEPOP = cast<GEPOperator>(U);
-      IRBuilder<> Builder(GV->getContext());
-      SmallVector<Value *, 4> idxList(GEPOP->idx_begin(), GEPOP->idx_end());
-      VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
+  else {
+    LoweredSrc = getLoweredByValOperand(Src, Builder);
+  }
+  DXASSERT_NOMSG(LoweredSrc->getType() == LoweredSrcTy);
+
+  Value* Result = LoweredSrc;
+  Type* LoweredDstTy = DstTy;
+  if (dxilutil::IsIntegerOrFloatingPointType(DstTy)) {
+    // Matrix to scalar
+    Result = Builder.CreateExtractElement(LoweredSrc, static_cast<uint64_t>(0));
+  }
+  else if (DstTy->isVectorTy()) {
+    // Matrix to vector
+    VectorType *DstVecTy = cast<VectorType>(DstTy);
+    DXASSERT(DstVecTy->getNumElements() <= LoweredSrcTy->getNumElements(),
+      "Cannot cast matrix to a larger vector.");
+
+    // We might have to truncate
+    if (DstTy->getVectorNumElements() < LoweredSrcTy->getNumElements()) {
+      SmallVector<int, 3> ShuffleIndices;
+      for (unsigned Idx = 0; Idx < DstVecTy->getNumElements(); ++Idx)
+        ShuffleIndices.emplace_back(static_cast<int>(Idx));
+      Result = Builder.CreateShuffleVector(Result, Result, ShuffleIndices);
     }
-
-    for (auto user = U->user_begin(); user != U->user_end();) {
-      CallInst *CI = cast<CallInst>(*(user++));
-      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-      if (group == HLOpcodeGroup::HLMatLoadStore) {
-        TranslateMatLoadStoreOnGlobalPtr(CI, VecGEP);
-      } else if (group == HLOpcodeGroup::HLSubscript) {
-        TranslateMatSubscriptOnGlobalPtr(CI, VecGEP);
-      } else {
-        DXASSERT(0, "invalid operation");
-      }
+  }
+  else {
+    // Destination must now be a matrix too
+    HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
+
+    // Apply any changes at the matrix level: orientation changes and truncation
+    if (Opcode == HLCastOpcode::ColMatrixToRowMatrix)
+      Result = MatSrcTy.emitLoweredVectorColToRow(Result, Builder);
+    else if (Opcode == HLCastOpcode::RowMatrixToColMatrix)
+      Result = MatSrcTy.emitLoweredVectorRowToCol(Result, Builder);
+    else if (MatDstTy.getNumRows() != MatSrcTy.getNumRows()
+      || MatDstTy.getNumColumns() != MatSrcTy.getNumColumns()) {
+      // Apply truncation
+      DXASSERT(MatDstTy.getNumRows() <= MatSrcTy.getNumRows()
+        && MatDstTy.getNumColumns() <= MatSrcTy.getNumColumns(),
+        "Unexpected matrix cast between incompatible dimensions.");
+      SmallVector<int, 16> ShuffleIndices;
+      for (unsigned RowIdx = 0; RowIdx < MatDstTy.getNumRows(); ++RowIdx)
+        for (unsigned ColIdx = 0; ColIdx < MatDstTy.getNumColumns(); ++ColIdx)
+          ShuffleIndices.emplace_back(static_cast<int>(MatSrcTy.getRowMajorIndex(RowIdx, ColIdx)));
+      Result = Builder.CreateShuffleVector(Result, Result, ShuffleIndices);
     }
+
+    LoweredDstTy = MatDstTy.getLoweredVectorTypeForReg();
+    DXASSERT(Result->getType()->getVectorNumElements() == LoweredDstTy->getVectorNumElements(),
+      "Unexpected matrix src/dst lowered element count mismatch after truncation.");
   }
 
-  DeleteDeadInsts();
-  GV->removeDeadConstantUsers();
-  GV->eraseFromParent();
+  // Apply element conversion
+  return convertScalarOrVector(Result, LoweredDstTy, Opcode, Builder);
 }
 
-static void FlattenMatConst(Constant *M, std::vector<Constant *> &Elts) {
-  unsigned row, col;
-  Type *EltTy = HLMatrixLower::GetMatrixInfo(M->getType(), col, row);
-  if (isa<UndefValue>(M)) {
-    Constant *Elt = UndefValue::get(EltTy);
-    for (unsigned i=0;i<col*row;i++)
-      Elts.emplace_back(Elt);
-  } else {
-    M = M->getAggregateElement((unsigned)0);
-    // Initializer is already in correct major.
-    // Just read it here.
-    // The type is vector<element, col>[row].
-    for (unsigned r = 0; r < row; r++) {
-      Constant *C = M->getAggregateElement(r);
-      for (unsigned c = 0; c < col; c++) {
-        Elts.emplace_back(C->getAggregateElement(c));
-      }
-    }
+Value *HLMatrixLowerPass::lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode) {
+  switch (Opcode) {
+  case HLSubscriptOpcode::RowMatElement:
+  case HLSubscriptOpcode::ColMatElement:
+    return lowerHLMatElementSubscript(Call,
+      /* RowMajor */ Opcode == HLSubscriptOpcode::RowMatElement);
+
+  case HLSubscriptOpcode::RowMatSubscript:
+  case HLSubscriptOpcode::ColMatSubscript:
+    return lowerHLMatSubscript(Call,
+      /* RowMajor */ Opcode == HLSubscriptOpcode::RowMatSubscript);
+
+  case HLSubscriptOpcode::DefaultSubscript:
+  case HLSubscriptOpcode::CBufferSubscript:
+    // Those get lowered during HLOperationLower,
+    // and the return type must stay unchanged (as a matrix)
+    // to provide the metadata to properly emit the loads.
+    return nullptr;
+
+  default:
+    llvm_unreachable("Unexpected matrix subscript opcode.");
   }
 }
 
-void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
-  if (HLMatrixLower::IsMatrixArrayPointer(GV->getType())) {
-    runOnGlobalMatrixArray(GV);
-    return;
-  }
+Value *HLMatrixLowerPass::lowerHLMatElementSubscript(CallInst *Call, bool RowMajor) {
+  (void)RowMajor; // It doesn't look like we actually need this?
 
-  Type *Ty = GV->getType()->getPointerElementType();
-  if (!dxilutil::IsHLSLMatrixType(Ty))
-    return;
+  Value *MatPtr = Call->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
+  Constant *IdxVec = cast<Constant>(Call->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx));
+  VectorType *IdxVecTy = cast<VectorType>(IdxVec->getType());
 
-  bool onlyLdSt = OnlyUsedByMatrixLdSt(GV);
-
-  bool isConst = GV->isConstant();
-
-  Type *vecTy = HLMatrixLower::LowerMatrixType(Ty);
-  Module *M = GV->getParent();
-  const DataLayout &DL = M->getDataLayout();
-
-  std::vector<Constant *> Elts;
-  // Lower to vector or array for scalar matrix.
-  // Make it col major so don't need shuffle when load/store.
-  FlattenMatConst(GV->getInitializer(), Elts);
-
-  if (onlyLdSt) {
-    Type *EltTy = vecTy->getVectorElementType();
-    unsigned vecSize = vecTy->getVectorNumElements();
-    std::vector<Value *> vecGlobals(vecSize);
-
-    GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
-    unsigned AddressSpace = GV->getType()->getAddressSpace();
-    GlobalValue::LinkageTypes linkage = GV->getLinkage();
-    unsigned debugOffset = 0;
-    unsigned size = DL.getTypeAllocSizeInBits(EltTy);
-    unsigned align = DL.getPrefTypeAlignment(EltTy);
-    for (int i = 0, e = vecSize; i != e; ++i) {
-      Constant *InitVal = Elts[i];
-      GlobalVariable *EltGV = new llvm::GlobalVariable(
-          *M, EltTy, /*IsConstant*/ isConst, linkage,
-          /*InitVal*/ InitVal, GV->getName() + "." + Twine(i),
-          /*InsertBefore*/nullptr,
-          TLMode, AddressSpace);
-      // Add debug info.
-      if (m_HasDbgInfo) {
-        DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
-        HLModule::CreateElementGlobalVariableDebugInfo(
-            GV, Finder, EltGV, size, align, debugOffset,
-            EltGV->getName().ltrim(GV->getName()));
-        debugOffset += size;
-      }
-      vecGlobals[i] = EltGV;
-    }
-    for (User *user : GV->users()) {
-      if (isa<Constant>(user) && user->use_empty())
-        continue;
-      CallInst *CI = cast<CallInst>(user);
-      TranslateMatLoadStoreOnGlobal(GV, vecGlobals, CI);
-      AddToDeadInsts(CI);
-    }
-    DeleteDeadInsts();
-    GV->eraseFromParent();
+  // Get the loaded lowered vector element indices
+  SmallVector<Value*, 4> ElemIndices;
+  ElemIndices.reserve(IdxVecTy->getNumElements());
+  for (unsigned VecIdx = 0; VecIdx < IdxVecTy->getNumElements(); ++VecIdx) {
+    ElemIndices.emplace_back(IdxVec->getAggregateElement(VecIdx));
   }
-  else {
-    // lower to array of scalar here.
-    ArrayType *AT = ArrayType::get(vecTy->getVectorElementType(), vecTy->getVectorNumElements());
-    Constant *InitVal = ConstantArray::get(AT, Elts);
-    GlobalVariable *arrayMat = new llvm::GlobalVariable(
-      *M, AT, /*IsConstant*/ false, llvm::GlobalValue::InternalLinkage,
-      /*InitVal*/ InitVal, GV->getName());
-    // Add debug info.
-    if (m_HasDbgInfo) {
-      DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
-      HLModule::UpdateGlobalVariableDebugInfo(GV, Finder,
-                                                     arrayMat);
-    }
 
-    for (auto U = GV->user_begin(); U != GV->user_end();) {
-      Value *user = *(U++);
-      CallInst *CI = cast<CallInst>(user);
-      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-      if (group == HLOpcodeGroup::HLMatLoadStore) {
-        TranslateMatLoadStoreOnGlobal(GV, arrayMat, CI);
-      }
-      else {
-        DXASSERT(group == HLOpcodeGroup::HLSubscript, "Must be subscript operation");
-        TranslateMatSubscriptOnGlobalPtr(CI, arrayMat);
-      }
-    }
-    GV->removeDeadConstantUsers();
-    GV->eraseFromParent();
-  }
+  lowerHLMatSubscript(Call, MatPtr, ElemIndices);
+
+  // We did our own replacement of uses, opt-out of having the caller does it for us.
+  return nullptr;
 }
 
-void HLMatrixLowerPass::runOnFunction(Function &F) {
-  // Skip hl function definition (like createhandle)
-  if (hlsl::GetHLOpcodeGroupByName(&F) != HLOpcodeGroup::NotHL)
-    return;
+Value *HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, bool RowMajor) {
+  (void)RowMajor; // It doesn't look like we actually need this?
 
-  // Create vector version of matrix instructions first.
-  // The matrix operands will be undefval for these instructions.
-  for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
-    BasicBlock *BB = BBI;
-    for (auto II = BB->begin(); II != BB->end(); ) {
-      Instruction &I = *(II++);
-      if (dxilutil::IsHLSLMatrixType(I.getType())) {
-        lowerToVec(&I);
-      } else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
-        Type *Ty = AI->getAllocatedType();
-        if (dxilutil::IsHLSLMatrixType(Ty)) {
-          lowerToVec(&I);
-        } else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
-          lowerToVec(&I);
-        }
-      } else if (CallInst *CI = dyn_cast<CallInst>(&I)) {
-        HLOpcodeGroup group =
-            hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
-        if (group == HLOpcodeGroup::HLMatLoadStore) {
-          HLMatLoadStoreOpcode opcode =
-              static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
-          DXASSERT_LOCALVAR(opcode,
-                            opcode == HLMatLoadStoreOpcode::ColMatStore ||
-                            opcode == HLMatLoadStoreOpcode::RowMatStore,
-                            "Must MatStore here, load will go IsMatrixType path");
-          // Lower it here to make sure it is ready before replace.
-          lowerToVec(&I);
-        }
-      } else if (GetIfMatrixGEPOfUDTAlloca(&I) ||
-                 GetIfMatrixGEPOfUDTArg(&I, *m_pHLModule)) {
-        lowerToVec(&I);
-      }
-    }
-  }
+  Value *MatPtr = Call->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
 
-  // Update the use of matrix inst with the vector version.
-  for (auto matToVecIter = matToVecMap.begin();
-       matToVecIter != matToVecMap.end();) {
-    auto matToVec = matToVecIter++;
-    replaceMatWithVec(matToVec->first, cast<Instruction>(matToVec->second));
+  // Gather the indices, checking if they are all constant
+  SmallVector<Value*, 4> ElemIndices;
+  for (unsigned Idx = HLOperandIndex::kMatSubscriptSubOpIdx; Idx < Call->getNumArgOperands(); ++Idx) {
+    ElemIndices.emplace_back(Call->getArgOperand(Idx));
   }
 
-  // Translate mat inst which require all operands ready.
-  for (auto matToVecIter = matToVecMap.begin();
-       matToVecIter != matToVecMap.end();) {
-    auto matToVec = matToVecIter++;
-    if (isa<Instruction>(matToVec->first))
-      finalMatTranslation(matToVec->first);
-  }
+  lowerHLMatSubscript(Call, MatPtr, ElemIndices);
 
-  // Remove matrix targets of vecToMatMap from matToVecMap before adding the rest to dead insts.
-  for (auto &it : vecToMatMap) {
-    matToVecMap.erase(it.second);
-  }
+  // We did our own replacement of uses, opt-out of having the caller does it for us.
+  return nullptr;
+}
 
-  // Delete the matrix version insts.
-  for (auto matToVecIter = matToVecMap.begin();
-       matToVecIter != matToVecMap.end();) {
-    auto matToVec = matToVecIter++;
-    // Add to m_deadInsts.
-    if (Instruction *matInst = dyn_cast<Instruction>(matToVec->first))
-      AddToDeadInsts(matInst);
-  }
+void HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, Value *MatPtr, SmallVectorImpl<Value*> &ElemIndices) {
+  DXASSERT_NOMSG(HLMatrixType::isMatrixPtr(MatPtr->getType()));
 
-  DeleteDeadInsts();
+  IRBuilder<> CallBuilder(Call);
+  Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, CallBuilder);
+  if (LoweredPtr == nullptr) return;
+
+  // For global variables, we can GEP directly into the lowered vector pointer.
+  // This is necessary to support group shared memory atomics and the likes.
+  Value *RootPtr = LoweredPtr;
+  while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
+    RootPtr = GEP->getPointerOperand();
+  bool AllowLoweredPtrGEPs = isa<GlobalVariable>(RootPtr);
   
-  matToVecMap.clear();
-  vecToMatMap.clear();
+  // Just constructing this does all the work
+  HLMatrixSubscriptUseReplacer UseReplacer(Call, LoweredPtr, ElemIndices, AllowLoweredPtrGEPs, m_deadInsts);
 
-  return;
+  DXASSERT(Call->use_empty(), "Expected all matrix subscript uses to have been replaced.");
+  addToDeadInsts(Call);
 }
 
-// Matrix Bitcast lower.
-// After linking Lower matrix bitcast patterns like:
-//  %169 = bitcast [72 x float]* %0 to [6 x %class.matrix.float.4.3]*
-//  %conv.i = fptoui float %164 to i32
-//  %arrayidx.i = getelementptr inbounds [6 x %class.matrix.float.4.3], [6 x %class.matrix.float.4.3]* %169, i32 0, i32 %conv.i
-//  %170 = bitcast %class.matrix.float.4.3* %arrayidx.i to <12 x float>*
+// Lowers StructuredBuffer<matrix>[index] or similar with constant buffers
+Value *HLMatrixLowerPass::lowerHLMatResourceSubscript(CallInst *Call, HLSubscriptOpcode Opcode) {
+  // Just replace the intrinsic by its equivalent with a lowered return type
+  IRBuilder<> Builder(Call);
 
-namespace {
+  SmallVector<Value*, 4> Args;
+  Args.reserve(Call->getNumArgOperands());
+  for (Value *Arg : Call->arg_operands())
+    Args.emplace_back(Arg);
 
-Type *TryLowerMatTy(Type *Ty) {
-  Type *VecTy = nullptr;
-  if (HLMatrixLower::IsMatrixArrayPointer(Ty)) {
-    VecTy = HLMatrixLower::LowerMatrixArrayPointerToOneDimArray(Ty);
-  } else if (isa<PointerType>(Ty) &&
-             dxilutil::IsHLSLMatrixType(Ty->getPointerElementType())) {
-    VecTy = HLMatrixLower::LowerMatrixTypeToOneDimArray(
-        Ty->getPointerElementType());
-    VecTy = PointerType::get(VecTy, Ty->getPointerAddressSpace());
-  }
-  return VecTy;
+  Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
+  return callHLFunction(*m_pModule, HLOpcodeGroup::HLSubscript, static_cast<unsigned>(Opcode),
+    LoweredRetTy, Args, Builder);
 }
 
-class MatrixBitcastLowerPass : public FunctionPass {
+Value *HLMatrixLowerPass::lowerHLInit(CallInst *Call) {
+  DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix init opcode.");
 
-public:
-  static char ID; // Pass identification, replacement for typeid
-  explicit MatrixBitcastLowerPass() : FunctionPass(ID) {}
-
-  const char *getPassName() const override { return "Matrix Bitcast lower"; }
-  bool runOnFunction(Function &F) override {
-    bool bUpdated = false;
-    std::unordered_set<BitCastInst*> matCastSet;
-    for (auto blkIt = F.begin(); blkIt != F.end(); ++blkIt) {
-      BasicBlock *BB = blkIt;
-      for (auto iIt = BB->begin(); iIt != BB->end(); ) {
-        Instruction *I = (iIt++);
-        if (BitCastInst *BCI = dyn_cast<BitCastInst>(I)) {
-          // Mutate mat to vec.
-          Type *ToTy = BCI->getType();
-          if (Type *ToVecTy = TryLowerMatTy(ToTy)) {
-            matCastSet.insert(BCI);
-            bUpdated = true;
-          }
-        }
-      }
-    }
+  // Figure out the result type
+  HLMatrixType MatTy = HLMatrixType::cast(Call->getType());
+  VectorType *LoweredTy = MatTy.getLoweredVectorTypeForReg();
 
-    DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
-    // Remove bitcast which has CallInst user.
-    if (DM.GetShaderModel()->IsLib()) {
-      for (auto it = matCastSet.begin(); it != matCastSet.end();) {
-        BitCastInst *BCI = *(it++);
-        if (hasCallUser(BCI)) {
-          matCastSet.erase(BCI);
-        }
-      }
-    }
-
-    // Lower matrix first.
-    for (BitCastInst *BCI : matCastSet) {
-      lowerMatrix(BCI, BCI->getOperand(0));
-    }
-    return bUpdated;
+  // Handle case where produced by EmitHLSLFlatConversion where there's one
+  // vector argument, instead of scalar arguments.
+  if (1 == Call->getNumArgOperands() - HLOperandIndex::kInitFirstArgOpIdx &&
+      Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx)->
+              getType()->isVectorTy()) {
+    Value *LoweredVec = Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
+    DXASSERT(LoweredTy->getNumElements() ==
+                LoweredVec->getType()->getVectorNumElements(),
+             "Invalid matrix init argument vector element count.");
+    return LoweredVec;
   }
-private:
-  void lowerMatrix(Instruction *M, Value *A);
-  bool hasCallUser(Instruction *M);
-};
 
-}
+  DXASSERT(LoweredTy->getNumElements() == Call->getNumArgOperands() - HLOperandIndex::kInitFirstArgOpIdx,
+    "Invalid matrix init argument count.");
 
-bool MatrixBitcastLowerPass::hasCallUser(Instruction *M) {
-  for (auto it = M->user_begin(); it != M->user_end();) {
-    User *U = *(it++);
-    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
-      Type *EltTy = GEP->getType()->getPointerElementType();
-      if (dxilutil::IsHLSLMatrixType(EltTy)) {
-        if (hasCallUser(GEP))
-          return true;
-      } else {
-        DXASSERT(0, "invalid GEP for matrix");
-      }
-    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
-      if (hasCallUser(BCI))
-        return true;
-    } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
-      if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
-      } else {
-        DXASSERT(0, "invalid load for matrix");
-      }
-    } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
-      Value *V = ST->getValueOperand();
-      if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
-      } else {
-        DXASSERT(0, "invalid load for matrix");
-      }
-    } else if (isa<CallInst>(U)) {
-      return true;
-    } else {
-      DXASSERT(0, "invalid use of matrix");
-    }
+  // Build the result vector from the init args.
+  // Both the args and the result vector are in row-major order, so no shuffling is necessary.
+  IRBuilder<> Builder(Call);
+  Value *LoweredVec = UndefValue::get(LoweredTy);
+  for (unsigned VecElemIdx = 0; VecElemIdx < LoweredTy->getNumElements(); ++VecElemIdx) {
+    Value *ArgVal = Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx + VecElemIdx);
+    DXASSERT(dxilutil::IsIntegerOrFloatingPointType(ArgVal->getType()),
+      "Expected only scalars in matrix initialization.");
+    LoweredVec = Builder.CreateInsertElement(LoweredVec, ArgVal, static_cast<uint64_t>(VecElemIdx));
   }
-  return false;
-}
 
-namespace {
-Value *CreateEltGEP(Value *A, unsigned i, Value *zeroIdx,
-                    IRBuilder<> &Builder) {
-  Value *GEP = nullptr;
-  if (GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A)) {
-    // A should be gep oneDimArray, 0, index * matSize
-    // Here add eltIdx to index * matSize foreach elt.
-    Instruction *EltGEP = GEPA->clone();
-    unsigned eltIdx = EltGEP->getNumOperands() - 1;
-    Value *NewIdx =
-        Builder.CreateAdd(EltGEP->getOperand(eltIdx), Builder.getInt32(i));
-    EltGEP->setOperand(eltIdx, NewIdx);
-    Builder.Insert(EltGEP);
-    GEP = EltGEP;
-  } else {
-    GEP = Builder.CreateInBoundsGEP(A, {zeroIdx, Builder.getInt32(i)});
-  }
-  return GEP;
+  return LoweredVec;
 }
-} // namespace
-
-void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
-  for (auto it = M->user_begin(); it != M->user_end();) {
-    User *U = *(it++);
-    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
-      Type *EltTy = GEP->getType()->getPointerElementType();
-      if (dxilutil::IsHLSLMatrixType(EltTy)) {
-        // Change gep matrixArray, 0, index
-        // into
-        //   gep oneDimArray, 0, index * matSize
-        IRBuilder<> Builder(GEP);
-        SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
-        DXASSERT(idxList.size() == 2,
-                 "else not one dim matrix array index to matrix");
-        unsigned col = 0;
-        unsigned row = 0;
-        HLMatrixLower::GetMatrixInfo(EltTy, col, row);
-        Value *matSize = Builder.getInt32(col * row);
-        idxList.back() = Builder.CreateMul(idxList.back(), matSize);
-        Value *NewGEP = Builder.CreateGEP(A, idxList);
-        lowerMatrix(GEP, NewGEP);
-        DXASSERT(GEP->user_empty(), "else lower matrix fail");
-        GEP->eraseFromParent();
-      } else {
-        DXASSERT(0, "invalid GEP for matrix");
-      }
-    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
-      lowerMatrix(BCI, A);
-      DXASSERT(BCI->user_empty(), "else lower matrix fail");
-      BCI->eraseFromParent();
-    } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
-      if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
-        IRBuilder<> Builder(LI);
-        Value *zeroIdx = Builder.getInt32(0);
-        unsigned vecSize = Ty->getNumElements();
-        Value *NewVec = UndefValue::get(LI->getType());
-        for (unsigned i = 0; i < vecSize; i++) {
-          Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
-          Value *Elt = Builder.CreateLoad(GEP);
-          NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
-        }
-        LI->replaceAllUsesWith(NewVec);
-        LI->eraseFromParent();
-      } else {
-        DXASSERT(0, "invalid load for matrix");
-      }
-    } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
-      Value *V = ST->getValueOperand();
-      if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
-        IRBuilder<> Builder(LI);
-        Value *zeroIdx = Builder.getInt32(0);
-        unsigned vecSize = Ty->getNumElements();
-        for (unsigned i = 0; i < vecSize; i++) {
-          Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
-          Value *Elt = Builder.CreateExtractElement(V, i);
-          Builder.CreateStore(Elt, GEP);
-        }
-        ST->eraseFromParent();
-      } else {
-        DXASSERT(0, "invalid load for matrix");
-      }
-    } else {
-      DXASSERT(0, "invalid use of matrix");
-    }
+
+Value *HLMatrixLowerPass::lowerHLSelect(CallInst *Call) {
+  DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix init opcode.");
+
+  Value *Cond = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
+  Value *TrueMat = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx);
+  Value *FalseMat = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx);
+
+  DXASSERT(TrueMat->getType() == FalseMat->getType(),
+    "Unexpected type mismatch between matrix ternary operator values.");
+
+#ifndef NDEBUG
+  // Assert that if the condition is a matrix, it matches the dimensions of the values
+  if (HLMatrixType MatCondTy = HLMatrixType::dyn_cast(Cond->getType())) {
+    HLMatrixType ValMatTy = HLMatrixType::cast(TrueMat->getType());
+    DXASSERT(MatCondTy.getNumRows() == ValMatTy.getNumRows()
+      && MatCondTy.getNumColumns() == ValMatTy.getNumColumns(),
+      "Unexpected mismatch between ternary operator condition and value matrix dimensions.");
   }
-}
+#endif
 
-#include "dxc/HLSL/DxilGenerationPass.h"
-char MatrixBitcastLowerPass::ID = 0;
-FunctionPass *llvm::createMatrixBitcastLowerPass() { return new MatrixBitcastLowerPass(); }
+  IRBuilder<> Builder(Call);
+  Value *LoweredCond = getLoweredByValOperand(Cond, Builder);
+  Value *LoweredTrueVec = getLoweredByValOperand(TrueMat, Builder);
+  Value *LoweredFalseVec = getLoweredByValOperand(FalseMat, Builder);
+  Value *Result = UndefValue::get(LoweredTrueVec->getType());
 
-INITIALIZE_PASS(MatrixBitcastLowerPass, "matrixbitcastlower", "Matrix Bitcast lower", false, false)
+  bool IsScalarCond = !LoweredCond->getType()->isVectorTy();
+
+  unsigned NumElems = Result->getType()->getVectorNumElements();
+  for (uint64_t ElemIdx = 0; ElemIdx < NumElems; ++ElemIdx) {
+    Value *ElemCond = IsScalarCond ? LoweredCond
+      : Builder.CreateExtractElement(LoweredCond, ElemIdx);
+    Value *ElemTrueVal = Builder.CreateExtractElement(LoweredTrueVec, ElemIdx);
+    Value *ElemFalseVal = Builder.CreateExtractElement(LoweredFalseVec, ElemIdx);
+    Value *ResultElem = Builder.CreateSelect(ElemCond, ElemTrueVal, ElemFalseVal);
+    Result = Builder.CreateInsertElement(Result, ResultElem, ElemIdx);
+  }
+
+  return Result;
+}

+ 284 - 0
lib/HLSL/HLMatrixSubscriptUseReplacer.cpp

@@ -0,0 +1,284 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLMatrixSubscriptUseReplacer.cpp                                          //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "HLMatrixSubscriptUseReplacer.h"
+#include "dxc/DXIL/DxilUtil.h"
+#include "dxc/Support/Global.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+
+using namespace llvm;
+using namespace hlsl;
+
+HLMatrixSubscriptUseReplacer::HLMatrixSubscriptUseReplacer(CallInst* Call, Value *LoweredPtr,
+  SmallVectorImpl<Value*> &ElemIndices, bool AllowLoweredPtrGEPs, std::vector<Instruction*> &DeadInsts)
+  : LoweredPtr(LoweredPtr), ElemIndices(ElemIndices), DeadInsts(DeadInsts), AllowLoweredPtrGEPs(AllowLoweredPtrGEPs)
+{
+  HasScalarResult = !Call->getType()->getPointerElementType()->isVectorTy();
+
+  for (Value *ElemIdx : ElemIndices) {
+    if (!isa<Constant>(ElemIdx)) {
+      HasDynamicElemIndex = true;
+      break;
+    }
+  }
+
+  replaceUses(Call, /* GEPIdx */ nullptr);
+}
+
+void HLMatrixSubscriptUseReplacer::replaceUses(Instruction* PtrInst, Value* SubIdxVal) {
+  // We handle any number of load/stores of the subscript,
+  // whether through a GEP or not, but there should really only be one.
+  while (!PtrInst->use_empty()) {
+    llvm::Use &Use = *PtrInst->use_begin();
+    Instruction *UserInst = cast<Instruction>(Use.getUser());
+
+    bool DeleteUserInst = true;
+    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(UserInst)) {
+      // Recurse on GEPs
+      DXASSERT(GEP->getNumIndices() >= 1 && GEP->getNumIndices() <= 2,
+        "Unexpected GEP on constant matrix subscript.");
+      DXASSERT(cast<ConstantInt>(GEP->idx_begin()->get())->isZero(),
+        "Unexpected nonzero first index of constant matrix subscript GEP.");
+
+      Value *NewSubIdxVal = SubIdxVal;
+      if (GEP->getNumIndices() == 2) {
+        DXASSERT(!HasScalarResult && SubIdxVal == nullptr,
+          "Unexpected GEP on matrix subscript scalar value.");
+        NewSubIdxVal = (GEP->idx_begin() + 1)->get();
+      }
+
+      replaceUses(GEP, NewSubIdxVal);
+    }
+    else {
+      IRBuilder<> UserBuilder(UserInst);
+
+      if (Value *ScalarElemIdx = tryGetScalarIndex(SubIdxVal, UserBuilder)) {
+        // We are accessing a scalar element
+        if (AllowLoweredPtrGEPs) {
+          // Simply make the instruction point to the element in the lowered pointer
+          DeleteUserInst = false;
+          Value *ElemPtr = UserBuilder.CreateGEP(LoweredPtr, { UserBuilder.getInt32(0), ScalarElemIdx });
+          Use.set(ElemPtr);
+        }
+        else {
+          bool IsDynamicIndex = !isa<Constant>(ScalarElemIdx);
+          cacheLoweredMatrix(IsDynamicIndex, UserBuilder);
+          if (LoadInst *Load = dyn_cast<LoadInst>(UserInst)) {
+            Value *Elem = loadElem(ScalarElemIdx, UserBuilder);
+            Load->replaceAllUsesWith(Elem);
+          }
+          else if (StoreInst *Store = dyn_cast<StoreInst>(UserInst)) {
+            storeElem(ScalarElemIdx, Store->getValueOperand(), UserBuilder);
+            flushLoweredMatrix(UserBuilder);
+          }
+          else {
+            llvm_unreachable("Unexpected matrix subscript use.");
+          }
+        }
+      }
+      else {
+        // We are accessing a vector given by ElemIndices
+        cacheLoweredMatrix(HasDynamicElemIndex, UserBuilder);
+        if (LoadInst *Load = dyn_cast<LoadInst>(UserInst)) {
+          Value *Vector = loadVector(UserBuilder);
+          Load->replaceAllUsesWith(Vector);
+        }
+        else if (StoreInst *Store = dyn_cast<StoreInst>(UserInst)) {
+          storeVector(Store->getValueOperand(), UserBuilder);
+          flushLoweredMatrix(UserBuilder);
+        }
+        else {
+          llvm_unreachable("Unexpected matrix subscript use.");
+        }
+      }
+    }
+
+    // We replaced this use, mark it dead
+    if (DeleteUserInst) {
+      DXASSERT(UserInst->use_empty(), "Matrix subscript user should be dead at this point.");
+      Use.set(UndefValue::get(Use->getType()));
+      DeadInsts.emplace_back(UserInst);
+    }
+  }
+}
+
+Value *HLMatrixSubscriptUseReplacer::tryGetScalarIndex(Value *SubIdxVal, IRBuilder<> &Builder) {
+  if (SubIdxVal == nullptr) {
+    // mat[0] case, returns a vector
+    if (!HasScalarResult) return nullptr;
+
+    // mat._11 case
+    DXASSERT_NOMSG(ElemIndices.size() == 1);
+    return ElemIndices[0];
+  }
+
+  if (ConstantInt *SubIdxConst = dyn_cast<ConstantInt>(SubIdxVal)) {
+    // mat[0][0], mat[i][0] or mat._11_12[0] cases.
+    uint64_t SubIdx = SubIdxConst->getLimitedValue();
+    DXASSERT(SubIdx < ElemIndices.size(), "Unexpected out of range constant matrix subindex.");
+    return ElemIndices[SubIdx];
+  }
+
+  // mat[0][j] or mat[i][j] case.
+  // We need to dynamically index into the level 1 element indices
+  if (LazyTempElemIndicesArrayAlloca == nullptr) {
+    // The level 2 index is dynamic, use it to index a temporary array of the level 1 indices.
+    IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
+    ArrayType *ArrayTy = ArrayType::get(AllocaBuilder.getInt32Ty(), ElemIndices.size());
+    LazyTempElemIndicesArrayAlloca = AllocaBuilder.CreateAlloca(ArrayTy);
+  }
+
+  // Store level 1 indices in the temporary array
+  Value *GEPIndices[2] = { Builder.getInt32(0), nullptr };
+  for (unsigned SubIdx = 0; SubIdx < ElemIndices.size(); ++SubIdx) {
+    GEPIndices[1] = Builder.getInt32(SubIdx);
+    Value *TempArrayElemPtr = Builder.CreateGEP(LazyTempElemIndicesArrayAlloca, GEPIndices);
+    Builder.CreateStore(ElemIndices[SubIdx], TempArrayElemPtr);
+  }
+
+  // Dynamically index using the subindex
+  GEPIndices[1] = SubIdxVal;
+  Value *ElemIdxPtr = Builder.CreateGEP(LazyTempElemIndicesArrayAlloca, GEPIndices);
+  return Builder.CreateLoad(ElemIdxPtr);
+}
+
+// Unless we are allowed to GEP directly into the lowered matrix,
+// we must load the vector in memory in order to read or write any elements.
+// If we're going to dynamically index, we need to copy the vector into a temporary array.
+// Further loadElem/storeElem calls depend on how we cached the matrix here.
+void HLMatrixSubscriptUseReplacer::cacheLoweredMatrix(bool ForDynamicIndexing, IRBuilder<> &Builder) {
+  // If we can GEP right into the lowered pointer, no need for caching
+  if (AllowLoweredPtrGEPs) return;
+
+  // Load without memory to register representation conversion,
+  // since the point is to mimic pointer semantics
+  TempLoweredMatrix = Builder.CreateLoad(LoweredPtr);
+
+  if (!ForDynamicIndexing) return;
+
+  // To handle mat[i] cases, we need to copy the matrix elements to
+  // an array which we can dynamically index.
+  VectorType *MatVecTy = cast<VectorType>(TempLoweredMatrix->getType());
+
+  // Lazily create the temporary array alloca
+  if (LazyTempElemArrayAlloca == nullptr) {
+    ArrayType *TempElemArrayTy = ArrayType::get(MatVecTy->getElementType(), MatVecTy->getNumElements());
+    IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
+    LazyTempElemArrayAlloca = AllocaBuilder.CreateAlloca(TempElemArrayTy);
+  }
+
+  // Copy the matrix elements to the temporary array
+  Value *GEPIndices[2] = { Builder.getInt32(0), nullptr };
+  for (unsigned ElemIdx = 0; ElemIdx < MatVecTy->getNumElements(); ++ElemIdx) {
+    Value *VecElem = Builder.CreateExtractElement(TempLoweredMatrix, static_cast<uint64_t>(ElemIdx));
+    GEPIndices[1] = Builder.getInt32(ElemIdx);
+    Value *TempArrayElemPtr = Builder.CreateGEP(LazyTempElemArrayAlloca, GEPIndices);
+    Builder.CreateStore(VecElem, TempArrayElemPtr);
+  }
+
+  // Null out the vector form so we know to use the array
+  TempLoweredMatrix = nullptr;
+}
+
+Value *HLMatrixSubscriptUseReplacer::loadElem(Value *Idx, IRBuilder<> &Builder) {
+  if (AllowLoweredPtrGEPs) {
+    Value *ElemPtr = Builder.CreateGEP(LoweredPtr, { Builder.getInt32(0), Idx });
+    return Builder.CreateLoad(ElemPtr);
+  }
+  else if (TempLoweredMatrix == nullptr) {
+    DXASSERT_NOMSG(LazyTempElemArrayAlloca != nullptr);
+
+    Value *TempArrayElemPtr = Builder.CreateGEP(LazyTempElemArrayAlloca, { Builder.getInt32(0), Idx });
+    return Builder.CreateLoad(TempArrayElemPtr);
+  }
+  else {
+    DXASSERT_NOMSG(isa<ConstantInt>(Idx));
+    return Builder.CreateExtractElement(TempLoweredMatrix, Idx);
+  }
+}
+
+void HLMatrixSubscriptUseReplacer::storeElem(Value *Idx, Value *Elem, IRBuilder<> &Builder) {
+  if (AllowLoweredPtrGEPs) {
+    Value *ElemPtr = Builder.CreateGEP(LoweredPtr, { Builder.getInt32(0), Idx });
+    Builder.CreateStore(Elem, ElemPtr);
+  }
+  else if (TempLoweredMatrix == nullptr) {
+    DXASSERT_NOMSG(LazyTempElemArrayAlloca != nullptr);
+
+    Value *GEPIndices[2] = { Builder.getInt32(0), Idx };
+    Value *TempArrayElemPtr = Builder.CreateGEP(LazyTempElemArrayAlloca, GEPIndices);
+    Builder.CreateStore(Elem, TempArrayElemPtr);
+  }
+  else {
+    DXASSERT_NOMSG(isa<ConstantInt>(Idx));
+    TempLoweredMatrix = Builder.CreateInsertElement(TempLoweredMatrix, Elem, Idx);
+  }
+}
+
+Value *HLMatrixSubscriptUseReplacer::loadVector(IRBuilder<> &Builder) {
+  if (TempLoweredMatrix != nullptr) {
+    // We can optimize this as a shuffle
+    SmallVector<Constant*, 4> ShuffleIndices;
+    ShuffleIndices.reserve(ElemIndices.size());
+    for (Value *ElemIdx : ElemIndices)
+      ShuffleIndices.emplace_back(cast<Constant>(ElemIdx));
+    Constant* ShuffleVector = ConstantVector::get(ShuffleIndices);
+    return Builder.CreateShuffleVector(TempLoweredMatrix, TempLoweredMatrix, ShuffleVector);
+  }
+
+  // Otherwise load elements one by one
+  Type* ElemTy = LoweredPtr->getType()->getPointerElementType()->getScalarType();
+  VectorType *VecTy = VectorType::get(ElemTy, static_cast<unsigned>(ElemIndices.size()));
+  Value *Result = UndefValue::get(VecTy);
+  for (unsigned SubIdx = 0; SubIdx < ElemIndices.size(); ++SubIdx) {
+    Value *Elem = loadElem(ElemIndices[SubIdx], Builder);
+    Result = Builder.CreateInsertElement(Result, Elem, static_cast<uint64_t>(SubIdx));
+  }
+
+  return Result;
+}
+
+void HLMatrixSubscriptUseReplacer::storeVector(Value *Vec, IRBuilder<> &Builder) {
+  // We can't shuffle vectors of different sizes together, so insert one by one.
+  DXASSERT(Vec->getType()->getVectorNumElements() == ElemIndices.size(),
+    "Matrix subscript stored vector element count mismatch.");
+
+  for (unsigned SubIdx = 0; SubIdx < ElemIndices.size(); ++SubIdx) {
+    Value *Elem = Builder.CreateExtractElement(Vec, static_cast<uint64_t>(SubIdx));
+    storeElem(ElemIndices[SubIdx], Elem, Builder);
+  }
+}
+
+void HLMatrixSubscriptUseReplacer::flushLoweredMatrix(IRBuilder<> &Builder) {
+  // If GEPs are allowed, no flushing is necessary, we modified the source elements directly.
+  if (AllowLoweredPtrGEPs) return;
+
+  if (TempLoweredMatrix == nullptr) {
+    // First re-create the vector from the temporary array
+    DXASSERT_NOMSG(LazyTempElemArrayAlloca != nullptr);
+
+    VectorType *LoweredMatrixTy = cast<VectorType>(LoweredPtr->getType()->getPointerElementType());
+    TempLoweredMatrix = UndefValue::get(LoweredMatrixTy);
+    Value *GEPIndices[2] = { Builder.getInt32(0), nullptr };
+    for (unsigned ElemIdx = 0; ElemIdx < LoweredMatrixTy->getNumElements(); ++ElemIdx) {
+      GEPIndices[1] = Builder.getInt32(ElemIdx);
+      Value *TempArrayElemPtr = Builder.CreateGEP(LazyTempElemArrayAlloca, GEPIndices);
+      Value *NewElem = Builder.CreateLoad(TempArrayElemPtr);
+      TempLoweredMatrix = Builder.CreateInsertElement(TempLoweredMatrix, NewElem, static_cast<uint64_t>(ElemIdx));
+    }
+  }
+
+  // Store back the lowered matrix to its pointer
+  Builder.CreateStore(TempLoweredMatrix, LoweredPtr);
+  TempLoweredMatrix = nullptr;
+}

+ 67 - 0
lib/HLSL/HLMatrixSubscriptUseReplacer.h

@@ -0,0 +1,67 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLMatrixSubscriptUseReplacer.h                                            //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/IRBuilder.h"
+#include <vector>
+
+namespace llvm {
+class Value;
+class AllocaInst;
+class CallInst;
+class Instruction;
+class Function;
+} // namespace llvm
+
+namespace hlsl {
+// Implements recursive replacement of a matrix subscript's uses,
+// from a pointer to a matrix value to a pointer to its lowered vector version,
+// whether directly or through GEPs in the case of two-level indexing like mat[i][j].
+// This has to handle one or two levels of indices, each of which either
+// constant or dynamic: mat[0], mat[i], mat[0][0], mat[i][0], mat[0][j], mat[i][j],
+// plus the equivalent element accesses: mat._11, mat._11_12, mat._11_12[0], mat._11_12[i]
+class HLMatrixSubscriptUseReplacer {
+public:
+  // The constructor does everything
+  HLMatrixSubscriptUseReplacer(llvm::CallInst* Call, llvm::Value *LoweredPtr,
+    llvm::SmallVectorImpl<llvm::Value*> &ElemIndices, bool AllowLoweredPtrGEPs,
+    std::vector<llvm::Instruction*> &DeadInsts);
+
+private:
+  void replaceUses(llvm::Instruction* PtrInst, llvm::Value* SubIdxVal);
+  llvm::Value *tryGetScalarIndex(llvm::Value *SubIdxVal, llvm::IRBuilder<> &Builder);
+  void cacheLoweredMatrix(bool ForDynamicIndexing, llvm::IRBuilder<> &Builder);
+  llvm::Value *loadElem(llvm::Value *Idx, llvm::IRBuilder<> &Builder);
+  void storeElem(llvm::Value *Idx, llvm::Value *Elem, llvm::IRBuilder<> &Builder);
+  llvm::Value *loadVector(llvm::IRBuilder<> &Builder);
+  void storeVector(llvm::Value *Vec, llvm::IRBuilder<> &Builder);
+  void flushLoweredMatrix(llvm::IRBuilder<> &Builder);
+
+private:
+  llvm::Value *LoweredPtr;
+  llvm::SmallVectorImpl<llvm::Value*> &ElemIndices;
+  std::vector<llvm::Instruction*> &DeadInsts;
+  bool AllowLoweredPtrGEPs = false;
+  bool HasScalarResult = false;
+  bool HasDynamicElemIndex = false;
+
+  // The entire lowered matrix as loaded from LoweredPtr,
+  // nullptr if we copied it to a temporary array.
+  llvm::Value *TempLoweredMatrix = nullptr;
+
+  // We allocate this if the level 1 indices are not all constants,
+  // so we can dynamically index the lowered matrix vector.
+  llvm::AllocaInst *LazyTempElemArrayAlloca = nullptr;
+
+  // We'll allocate this lazily if we have a dynamic level 2 index (mat[0][i]),
+  // so we can dynamically index the level 1 indices.
+  llvm::AllocaInst *LazyTempElemIndicesArrayAlloca = nullptr;
+};
+} // namespace hlsl

+ 179 - 0
lib/HLSL/HLMatrixType.cpp

@@ -0,0 +1,179 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLMatrixType.cpp                                                          //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "dxc/HLSL/HLMatrixType.h"
+#include "dxc/Support/Global.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Value.h"
+
+using namespace llvm;
+using namespace hlsl;
+
+HLMatrixType::HLMatrixType(Type *RegReprElemTy, unsigned NumRows, unsigned NumColumns)
+  : RegReprElemTy(RegReprElemTy), NumRows(NumRows), NumColumns(NumColumns) {
+  DXASSERT(RegReprElemTy != nullptr && (RegReprElemTy->isIntegerTy() || RegReprElemTy->isFloatingPointTy()),
+    "Invalid matrix element type.");
+  DXASSERT(NumRows >= 1 && NumRows <= 4 && NumColumns >= 1 && NumColumns <= 4,
+    "Invalid matrix dimensions.");
+}
+
+Type *HLMatrixType::getElementType(bool MemRepr) const {
+  // Bool i1s become i32s
+  return MemRepr && RegReprElemTy->isIntegerTy(1)
+    ? IntegerType::get(RegReprElemTy->getContext(), 32)
+    : RegReprElemTy;
+}
+
+unsigned HLMatrixType::getRowMajorIndex(unsigned RowIdx, unsigned ColIdx) const {
+  return getRowMajorIndex(RowIdx, ColIdx, NumRows, NumColumns);
+}
+
+unsigned HLMatrixType::getColumnMajorIndex(unsigned RowIdx, unsigned ColIdx) const {
+  return getColumnMajorIndex(RowIdx, ColIdx, NumRows, NumColumns);
+}
+
+unsigned HLMatrixType::getRowMajorIndex(unsigned RowIdx, unsigned ColIdx, unsigned NumRows, unsigned NumColumns) {
+  DXASSERT_NOMSG(RowIdx < NumRows && ColIdx < NumColumns);
+  return RowIdx * NumColumns + ColIdx;
+}
+
+unsigned HLMatrixType::getColumnMajorIndex(unsigned RowIdx, unsigned ColIdx, unsigned NumRows, unsigned NumColumns) {
+  DXASSERT_NOMSG(RowIdx < NumRows && ColIdx < NumColumns);
+  return ColIdx * NumRows + RowIdx;
+}
+
+VectorType *HLMatrixType::getLoweredVectorType(bool MemRepr) const {
+  return VectorType::get(getElementType(MemRepr), getNumElements());
+}
+
+Value *HLMatrixType::emitLoweredMemToReg(Value *Val, IRBuilder<> &Builder) const {
+  DXASSERT(Val->getType()->getScalarType() == getElementTypeForMem(), "Lowered matrix type mismatch.");
+  if (RegReprElemTy->isIntegerTy(1)) {
+    Val = Builder.CreateICmpNE(Val, Constant::getNullValue(Val->getType()), "tobool");
+  }
+  return Val;
+}
+
+Value *HLMatrixType::emitLoweredRegToMem(Value *Val, IRBuilder<> &Builder) const {
+  DXASSERT(Val->getType()->getScalarType() == RegReprElemTy, "Lowered matrix type mismatch.");
+  if (RegReprElemTy->isIntegerTy(1)) {
+    Type *MemReprTy = Val->getType()->isVectorTy() ? getLoweredVectorTypeForMem() : getElementTypeForMem();
+    Val = Builder.CreateZExt(Val, MemReprTy, "frombool");
+  }
+  return Val;
+}
+
+Value *HLMatrixType::emitLoweredLoad(Value *Ptr, IRBuilder<> &Builder) const {
+  return emitLoweredMemToReg(Builder.CreateLoad(Ptr), Builder);
+}
+
+StoreInst *HLMatrixType::emitLoweredStore(Value *Val, Value *Ptr, IRBuilder<> &Builder) const {
+  return Builder.CreateStore(emitLoweredRegToMem(Val, Builder), Ptr);
+}
+
+Value *HLMatrixType::emitLoweredVectorRowToCol(Value *VecVal, IRBuilder<> &Builder) const {
+  DXASSERT(VecVal->getType() == getLoweredVectorTypeForReg(), "Lowered matrix type mismatch.");
+  if (NumRows == 1 || NumColumns == 1) return VecVal;
+
+  SmallVector<int, 16> ShuffleIndices;
+  for (unsigned ColIdx = 0; ColIdx < NumColumns; ++ColIdx)
+    for (unsigned RowIdx = 0; RowIdx < NumRows; ++RowIdx)
+      ShuffleIndices.emplace_back((int)getRowMajorIndex(RowIdx, ColIdx));
+  return Builder.CreateShuffleVector(VecVal, VecVal, ShuffleIndices, "row2col");
+}
+
+Value *HLMatrixType::emitLoweredVectorColToRow(Value *VecVal, IRBuilder<> &Builder) const {
+  DXASSERT(VecVal->getType() == getLoweredVectorTypeForReg(), "Lowered matrix type mismatch.");
+  if (NumRows == 1 || NumColumns == 1) return VecVal;
+
+  SmallVector<int, 16> ShuffleIndices;
+  for (unsigned RowIdx = 0; RowIdx < NumRows; ++RowIdx)
+    for (unsigned ColIdx = 0; ColIdx < NumColumns; ++ColIdx)
+      ShuffleIndices.emplace_back((int)getColumnMajorIndex(RowIdx, ColIdx));
+  return Builder.CreateShuffleVector(VecVal, VecVal, ShuffleIndices, "col2row");
+}
+
+bool HLMatrixType::isa(Type *Ty) {
+  StructType *StructTy = llvm::dyn_cast<StructType>(Ty);
+  return StructTy != nullptr && StructTy->getName().startswith(StructNamePrefix);
+}
+
+bool HLMatrixType::isMatrixPtr(Type *Ty) {
+  PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty);
+  return PtrTy != nullptr && isa(PtrTy->getElementType());
+}
+
+bool HLMatrixType::isMatrixArray(Type *Ty) {
+  ArrayType *ArrayTy = llvm::dyn_cast<ArrayType>(Ty);
+  if (ArrayTy == nullptr) return false;
+  while (ArrayType *NestedArrayTy = llvm::dyn_cast<ArrayType>(ArrayTy->getElementType()))
+    ArrayTy = NestedArrayTy;
+  return isa(ArrayTy->getElementType());
+}
+
+bool HLMatrixType::isMatrixArrayPtr(Type *Ty) {
+  PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty);
+  if (PtrTy == nullptr) return false;
+  return isMatrixArray(PtrTy->getElementType());
+}
+
+bool HLMatrixType::isMatrixPtrOrArrayPtr(Type *Ty) {
+  PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty);
+  if (PtrTy == nullptr) return false;
+  Ty = PtrTy->getElementType();
+  while (ArrayType *ArrayTy = llvm::dyn_cast<ArrayType>(Ty))
+    Ty = Ty->getArrayElementType();
+  return isa(Ty);
+}
+
+bool HLMatrixType::isMatrixOrPtrOrArrayPtr(Type *Ty) {
+  if (PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty)) Ty = PtrTy->getElementType();
+  while (ArrayType *ArrayTy = llvm::dyn_cast<ArrayType>(Ty)) Ty = ArrayTy->getElementType();
+  return isa(Ty);
+}
+
+// Converts a matrix, matrix pointer, or matrix array pointer type to its lowered equivalent.
+// If the type is not matrix-derived, the original type is returned.
+// Does not lower struct types containing matrices.
+Type *HLMatrixType::getLoweredType(Type *Ty, bool MemRepr) {
+  if (PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty)) {
+    // Pointees are always in memory representation
+    Type *LoweredElemTy = getLoweredType(PtrTy->getElementType(), /* MemRepr */ true);
+    return LoweredElemTy == PtrTy->getElementType()
+      ? Ty : PointerType::get(LoweredElemTy, PtrTy->getAddressSpace());
+  }
+  else if (ArrayType *ArrayTy = llvm::dyn_cast<ArrayType>(Ty)) {
+    // Arrays are always in memory and so their elements are in memory representation
+    Type *LoweredElemTy = getLoweredType(ArrayTy->getElementType(), /* MemRepr */ true);
+    return LoweredElemTy == ArrayTy->getElementType()
+      ? Ty : ArrayType::get(LoweredElemTy, ArrayTy->getNumElements());
+  }
+  else if (HLMatrixType MatrixTy = HLMatrixType::dyn_cast(Ty)) {
+    return MatrixTy.getLoweredVectorType(MemRepr);
+  }
+  else return Ty;
+}
+
+HLMatrixType HLMatrixType::cast(Type *Ty) {
+  DXASSERT_NOMSG(isa(Ty));
+  StructType *StructTy = llvm::cast<StructType>(Ty);
+  DXASSERT_NOMSG(Ty->getNumContainedTypes() == 1);
+  ArrayType *RowArrayTy = llvm::cast<ArrayType>(StructTy->getElementType(0));
+  DXASSERT_NOMSG(RowArrayTy->getNumElements() >= 1 && RowArrayTy->getNumElements() <= 4);
+  VectorType *RowTy = llvm::cast<VectorType>(RowArrayTy->getElementType());
+  DXASSERT_NOMSG(RowTy->getNumElements() >= 1 && RowTy->getNumElements() <= 4);
+  return HLMatrixType(RowTy->getElementType(), RowArrayTy->getNumElements(), RowTy->getNumElements());
+}
+
+HLMatrixType HLMatrixType::dyn_cast(Type *Ty) {
+  return isa(Ty) ? cast(Ty) : HLMatrixType();
+}

+ 85 - 57
lib/HLSL/HLOperationLower.cpp

@@ -16,6 +16,7 @@
 #include "dxc/DXIL/DxilModule.h"
 #include "dxc/DXIL/DxilOperations.h"
 #include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/DXIL/DxilUtil.h"
 #include "dxc/HLSL/HLOperationLower.h"
@@ -4985,10 +4986,9 @@ Value *GenerateCBLoad(Value *handle, Value *offset, Type *EltTy, OP *hlslOP,
 Value *TranslateConstBufMatLd(Type *matType, Value *handle, Value *offset,
                               bool colMajor, OP *OP, const DataLayout &DL,
                               IRBuilder<> &Builder) {
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(matType, col, row);
-  Type *EltTy = HLMatrixLower::LowerMatrixType(matType, /*forMem*/true)->getVectorElementType();
-  unsigned matSize = col * row;
+  HLMatrixType MatTy = HLMatrixType::cast(matType);
+  Type *EltTy = MatTy.getElementTypeForMem();
+  unsigned matSize = MatTy.getNumElements();
   std::vector<Value *> elts(matSize);
   Value *EltByteSize = ConstantInt::get(
       offset->getType(), GetEltTypeByteSizeForConstBuf(EltTy, DL));
@@ -5000,8 +5000,8 @@ Value *TranslateConstBufMatLd(Type *matType, Value *handle, Value *offset,
     baseOffset = Builder.CreateAdd(baseOffset, EltByteSize);
   }
 
-  Value* Vec = HLMatrixLower::BuildVector(EltTy, col * row, elts, Builder);
-  Vec = HLMatrixLower::VecMatrixMemToReg(Vec, matType, Builder);
+  Value* Vec = HLMatrixLower::BuildVector(EltTy, elts, Builder);
+  Vec = MatTy.emitLoweredMemToReg(Vec, Builder);
   return Vec;
 }
 
@@ -5069,9 +5069,8 @@ void TranslateCBAddressUser(Instruction *user, Value *handle, Value *baseOffset,
     } else if (group == HLOpcodeGroup::HLSubscript) {
       HLSubscriptOpcode subOp = static_cast<HLSubscriptOpcode>(opcode);
       Value *basePtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
-      Type *matType = basePtr->getType()->getPointerElementType();
-      unsigned col, row;
-      Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
+      HLMatrixType MatTy = HLMatrixType::cast(basePtr->getType()->getPointerElementType());
+      Type *EltTy = MatTy.getElementTypeForReg();
 
       Value *EltByteSize = ConstantInt::get(
           baseOffset->getType(), GetEltTypeByteSizeForConstBuf(EltTy, DL));
@@ -5397,26 +5396,24 @@ Value *GenerateCBLoadLegacy(Value *handle, Value *legacyIdx,
   }
 }
 
-Value *TranslateConstBufMatLdLegacy(Type *matType, Value *handle,
+Value *TranslateConstBufMatLdLegacy(HLMatrixType MatTy, Value *handle,
                                     Value *legacyIdx, bool colMajor, OP *OP,
                                     bool memElemRepr, const DataLayout &DL,
                                     IRBuilder<> &Builder) {
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(matType, col, row);
-  Type *EltTy = HLMatrixLower::LowerMatrixType(matType, /*forMem*/true)->getVectorElementType();
+  Type *EltTy = MatTy.getElementTypeForMem();
 
-  unsigned matSize = col * row;
+  unsigned matSize = MatTy.getNumElements();
   std::vector<Value *> elts(matSize);
   unsigned EltByteSize = GetEltTypeByteSizeForConstBuf(EltTy, DL);
   if (colMajor) {
     unsigned colByteSize = 4 * EltByteSize;
     unsigned colRegSize = (colByteSize + 15) >> 4;
-    for (unsigned c = 0; c < col; c++) {
+    for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
       Value *col = GenerateCBLoadLegacy(handle, legacyIdx, /*channelOffset*/ 0,
-                                        EltTy, row, OP, Builder);
+                                        EltTy, MatTy.getNumRows(), OP, Builder);
 
-      for (unsigned r = 0; r < row; r++) {
-        unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
+      for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+        unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
         elts[matIdx] = Builder.CreateExtractElement(col, r);
       }
       // Update offset for a column.
@@ -5425,11 +5422,11 @@ Value *TranslateConstBufMatLdLegacy(Type *matType, Value *handle,
   } else {
     unsigned rowByteSize = 4 * EltByteSize;
     unsigned rowRegSize = (rowByteSize + 15) >> 4;
-    for (unsigned r = 0; r < row; r++) {
+    for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
       Value *row = GenerateCBLoadLegacy(handle, legacyIdx, /*channelOffset*/ 0,
-                                        EltTy, col, OP, Builder);
-      for (unsigned c = 0; c < col; c++) {
-        unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
+                                        EltTy, MatTy.getNumColumns(), OP, Builder);
+      for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
+        unsigned matIdx = MatTy.getRowMajorIndex(r, c);
         elts[matIdx] = Builder.CreateExtractElement(row, c);
       }
       // Update offset for a row.
@@ -5437,9 +5434,9 @@ Value *TranslateConstBufMatLdLegacy(Type *matType, Value *handle,
     }
   }
 
-  Value *Vec = HLMatrixLower::BuildVector(EltTy, col * row, elts, Builder);
+  Value *Vec = HLMatrixLower::BuildVector(EltTy, elts, Builder);
   if (!memElemRepr)
-    Vec = HLMatrixLower::VecMatrixMemToReg(Vec, matType, Builder);
+    Vec = MatTy.emitLoweredMemToReg(Vec, Builder);
   return Vec;
 }
 
@@ -5489,20 +5486,19 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
       DXASSERT(matOp == HLMatLoadStoreOpcode::ColMatLoad ||
                    matOp == HLMatLoadStoreOpcode::RowMatLoad,
                "No store on cbuffer");
-      Type *matType = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
-                          ->getType()
-                          ->getPointerElementType();
+      HLMatrixType MatTy = HLMatrixType::cast(
+        CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
+          ->getType()->getPointerElementType());
       // This will replace a call, so we should use the register representation of elements
       Value *newLd = TranslateConstBufMatLdLegacy(
-          matType, handle, legacyIdx, colMajor, hlslOP, /*memElemRepr*/false, DL, Builder);
+        MatTy, handle, legacyIdx, colMajor, hlslOP, /*memElemRepr*/false, DL, Builder);
       CI->replaceAllUsesWith(newLd);
       CI->eraseFromParent();
     } else if (group == HLOpcodeGroup::HLSubscript) {
       HLSubscriptOpcode subOp = static_cast<HLSubscriptOpcode>(opcode);
       Value *basePtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
-      Type *matType = basePtr->getType()->getPointerElementType();
-      unsigned col, row;
-      Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
+      HLMatrixType MatTy = HLMatrixType::cast(basePtr->getType()->getPointerElementType());
+      Type *EltTy = MatTy.getElementTypeForReg();
 
       Value *idx = CI->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
 
@@ -5523,7 +5519,7 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
       if (!dynamicIndexing) {
         // This will replace a load or GEP, so we should use the memory representation of elements
         Value *matLd = TranslateConstBufMatLdLegacy(
-            matType, handle, legacyIdx, colMajor, hlslOP, /*memElemRepr*/true, DL, Builder);
+          MatTy, handle, legacyIdx, colMajor, hlslOP, /*memElemRepr*/true, DL, Builder);
         // The matLd is keep original layout, just use the idx calc in
         // EmitHLSLMatrixElement and EmitHLSLMatrixSubscript.
         switch (subOp) {
@@ -5568,7 +5564,7 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
           // row.z = c[2].[idx]
           // row.w = c[3].[idx]
           Value *Elts[4];
-          ArrayType *AT = ArrayType::get(EltTy, col);
+          ArrayType *AT = ArrayType::get(EltTy, MatTy.getNumColumns());
 
           IRBuilder<> AllocaBuilder(user->getParent()
                                         ->getParent()
@@ -5578,12 +5574,12 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
           Value *tempArray = AllocaBuilder.CreateAlloca(AT);
           Value *zero = AllocaBuilder.getInt32(0);
           Value *cbufIdx = legacyIdx;
-          for (unsigned int c = 0; c < col; c++) {
+          for (unsigned int c = 0; c < MatTy.getNumColumns(); c++) {
             Value *ColVal =
                 GenerateCBLoadLegacy(handle, cbufIdx, /*channelOffset*/ 0,
-                                     EltTy, row, hlslOP, Builder);
+                                     EltTy, MatTy.getNumRows(), hlslOP, Builder);
             // Convert ColVal to array for indexing.
-            for (unsigned int r = 0; r < row; r++) {
+            for (unsigned int r = 0; r < MatTy.getNumRows(); r++) {
               Value *Elt =
                   Builder.CreateExtractElement(ColVal, Builder.getInt32(r));
               Value *Ptr = Builder.CreateInBoundsGEP(
@@ -5597,7 +5593,7 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
             cbufIdx = Builder.CreateAdd(cbufIdx, one);
           }
           if (resultType->isVectorTy()) {
-            for (unsigned int c = 0; c < col; c++) {
+            for (unsigned int c = 0; c < MatTy.getNumColumns(); c++) {
               ldData = Builder.CreateInsertElement(ldData, Elts[c], c);
             }
           } else {
@@ -5606,12 +5602,12 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
         } else {
           // idx is r * col + c;
           // r = idx / col;
-          Value *cCol = ConstantInt::get(idx->getType(), col);
+          Value *cCol = ConstantInt::get(idx->getType(), MatTy.getNumColumns());
           idx = Builder.CreateUDiv(idx, cCol);
           idx = Builder.CreateAdd(idx, legacyIdx);
           // Just return a row; 'col' is the number of columns in the row.
           ldData = GenerateCBLoadLegacy(handle, idx, /*channelOffset*/ 0, EltTy,
-                                        col, hlslOP, Builder);
+            MatTy.getNumColumns(), hlslOP, Builder);
         }
         if (!resultType->isVectorTy()) {
           ldData = Builder.CreateExtractElement(ldData, Builder.getInt32(0));
@@ -5989,9 +5985,8 @@ Value *TranslateStructBufMatLd(Type *matType, IRBuilder<> &Builder,
                                Value *handle, hlsl::OP *OP, Value *status,
                                Value *bufIdx, Value *baseOffset,
                                const DataLayout &DL) {
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(matType, col, row);
-  Type *EltTy = HLMatrixLower::LowerMatrixType(matType, /*forMem*/true)->getVectorElementType();
+  HLMatrixType MatTy = HLMatrixType::cast(matType);
+  Type *EltTy = MatTy.getElementTypeForMem();
   unsigned  EltSize = DL.getTypeAllocSize(EltTy);
   Constant* alignment = OP->GetI32Const(EltSize);
 
@@ -5999,7 +5994,7 @@ Value *TranslateStructBufMatLd(Type *matType, IRBuilder<> &Builder,
   if (baseOffset == nullptr)
     offset = OP->GetU32Const(0);
 
-  unsigned matSize = col * row;
+  unsigned matSize = MatTy.getNumElements();
   std::vector<Value *> elts(matSize);
 
   unsigned rest = (matSize % 4);
@@ -6023,19 +6018,18 @@ Value *TranslateStructBufMatLd(Type *matType, IRBuilder<> &Builder,
     offset = Builder.CreateAdd(offset, OP->GetU32Const(4 * EltSize));
   }
 
-  Value *Vec = HLMatrixLower::BuildVector(EltTy, col * row, elts, Builder);
-  Vec = HLMatrixLower::VecMatrixMemToReg(Vec, matType, Builder);
+  Value *Vec = HLMatrixLower::BuildVector(EltTy, elts, Builder);
+  Vec = MatTy.emitLoweredMemToReg(Vec, Builder);
   return Vec;
 }
 
 void TranslateStructBufMatSt(Type *matType, IRBuilder<> &Builder, Value *handle,
                              hlsl::OP *OP, Value *bufIdx, Value *baseOffset,
                              Value *val, const DataLayout &DL) {
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(matType, col, row);
-  Type *EltTy = HLMatrixLower::LowerMatrixType(matType, /*forMem*/true)->getVectorElementType();
+  HLMatrixType MatTy = HLMatrixType::cast(matType);
+  Type *EltTy = MatTy.getElementTypeForMem();
 
-  val = HLMatrixLower::VecMatrixRegToMem(val, matType, Builder);
+  val = MatTy.emitLoweredRegToMem(val, Builder);
 
   unsigned EltSize = DL.getTypeAllocSize(EltTy);
   Constant *Alignment = OP->GetI32Const(EltSize);
@@ -6043,7 +6037,7 @@ void TranslateStructBufMatSt(Type *matType, IRBuilder<> &Builder, Value *handle,
   if (baseOffset == nullptr)
     offset = OP->GetU32Const(0);
 
-  unsigned matSize = col * row;
+  unsigned matSize = MatTy.getNumElements();
   Value *undefElt = UndefValue::get(EltTy);
 
   unsigned storeSize = matSize;
@@ -6106,6 +6100,41 @@ void TranslateStructBufSubscriptUser(Instruction *user, Value *handle,
                                      Value *bufIdx, Value *baseOffset,
                                      Value *status, hlsl::OP *OP, const DataLayout &DL);
 
+// For case like mat[i][j].
+// IdxList is [i][0], [i][1], [i][2],[i][3].
+// Idx is j.
+// return [i][j] not mat[i][j] because resource ptr and temp ptr need different
+// code gen.
+static Value *LowerGEPOnMatIndexListToIndex(
+  llvm::GetElementPtrInst *GEP, ArrayRef<Value *> IdxList) {
+  IRBuilder<> Builder(GEP);
+  Value *zero = Builder.getInt32(0);
+  DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
+  Value *baseIdx = (GEP->idx_begin())->get();
+  DXASSERT_LOCALVAR(baseIdx, baseIdx == zero, "base index must be 0");
+  Value *Idx = (GEP->idx_begin() + 1)->get();
+
+  if (ConstantInt *immIdx = dyn_cast<ConstantInt>(Idx)) {
+    return IdxList[immIdx->getSExtValue()];
+  }
+  else {
+    IRBuilder<> AllocaBuilder(
+      GEP->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
+    unsigned size = IdxList.size();
+    // Store idxList to temp array.
+    ArrayType *AT = ArrayType::get(IdxList[0]->getType(), size);
+    Value *tempArray = AllocaBuilder.CreateAlloca(AT);
+
+    for (unsigned i = 0; i < size; i++) {
+      Value *EltPtr = Builder.CreateGEP(tempArray, { zero, Builder.getInt32(i) });
+      Builder.CreateStore(IdxList[i], EltPtr);
+    }
+    // Load the idx.
+    Value *GEPOffset = Builder.CreateGEP(tempArray, { zero, Idx });
+    return Builder.CreateLoad(GEPOffset);
+  }
+}
+
 // subscript operator for matrix of struct element.
 void TranslateStructBufMatSubscript(CallInst *CI, Value *handle,
                                     hlsl::OP *hlslOP, Value *bufIdx,
@@ -6118,9 +6147,8 @@ void TranslateStructBufMatSubscript(CallInst *CI, Value *handle,
   IRBuilder<> subBuilder(CI);
   HLSubscriptOpcode subOp = static_cast<HLSubscriptOpcode>(opcode);
   Value *basePtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
-  Type *matType = basePtr->getType()->getPointerElementType();
-  unsigned col, row;
-  Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
+  HLMatrixType MatTy = HLMatrixType::cast(basePtr->getType()->getPointerElementType());
+  Type *EltTy = MatTy.getElementTypeForReg();
   Constant *alignment = hlslOP->GetI32Const(DL.getTypeAllocSize(EltTy));
 
   Value *EltByteSize = ConstantInt::get(
@@ -6170,8 +6198,7 @@ void TranslateStructBufMatSubscript(CallInst *CI, Value *handle,
       continue;
     }
     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
-      Value *GEPOffset =
-          HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
+      Value *GEPOffset = LowerGEPOnMatIndexListToIndex(GEP, idxList);
 
       for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
         Instruction *gepUserInst = cast<Instruction>(*(gepU++));
@@ -6998,8 +7025,9 @@ void TranslateHLBuiltinOperation(Function *F, HLOperationLowerHelper &helper,
               /*bOrigAllocaTy*/false,
               matVal->getName());
             if (bTranspose) {
-              unsigned row, col;
-              HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
+              HLMatrixType MatTy = HLMatrixType::cast(matVal->getType());
+              unsigned row = MatTy.getNumRows();
+              unsigned col = MatTy.getNumColumns();
               if (bColDest) std::swap(row, col);
               vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
             }

+ 0 - 1
lib/HLSL/HLOperationLowerExtension.cpp

@@ -11,7 +11,6 @@
 
 #include "dxc/DXIL/DxilModule.h"
 #include "dxc/DXIL/DxilOperations.h"
-#include "dxc/HLSL/HLMatrixLowerHelper.h"
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/HLSL/HLOperationLower.h"
 #include "dxc/HLSL/HLOperations.h"

+ 43 - 53
lib/HLSL/HLSignatureLower.cpp

@@ -18,6 +18,7 @@
 #include "dxc/DXIL/DxilSemantic.h"
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/HlslIntrinsicOp.h"
 #include "dxc/DXIL/DxilUtil.h"
 #include "dxc/HLSL/DxilPackSignatureElement.h"
@@ -632,8 +633,7 @@ void replaceDirectInputParameter(Value *param, Function *loadInput,
         GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
     param->replaceAllUsesWith(input);
   } else if (dxilutil::IsHLSLMatrixType(Ty)) {
-    Value *colIdx = hlslOP->GetU8Const(0);
-    (void)colIdx;
+    if (param->use_empty()) return;
     DXASSERT(param->hasOneUse(),
              "matrix arg should only has one use as matrix to vec");
     CallInst *CI = cast<CallInst>(param->user_back());
@@ -645,49 +645,45 @@ void replaceDirectInputParameter(Value *param, Function *loadInput,
     switch (matOp) {
     case HLCastOpcode::ColMatrixToVecCast: {
       IRBuilder<> LocalBuilder(CI);
-      Type *matTy =
-          CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx)->getType();
-      unsigned col, row;
-      Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
-      std::vector<Value *> matElts(col * row);
-      for (unsigned c = 0; c < col; c++) {
+      HLMatrixType MatTy = HLMatrixType::cast(
+          CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx)->getType());
+      Type *EltTy = MatTy.getElementTypeForReg();
+      std::vector<Value *> matElts(MatTy.getNumElements());
+      for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
         Value *rowIdx = hlslOP->GetI32Const(c);
         args[DXIL::OperandIndex::kLoadInputRowOpIdx] = rowIdx;
-        for (unsigned r = 0; r < row; r++) {
+        for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
           Value *colIdx = hlslOP->GetU8Const(r);
           args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
           Value *input =
               GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
-          unsigned matIdx = c * row + r;
-          matElts[matIdx] = input;
+          matElts[MatTy.getColumnMajorIndex(r, c)] = input;
         }
       }
       Value *newVec =
-          HLMatrixLower::BuildVector(EltTy, col * row, matElts, LocalBuilder);
+          HLMatrixLower::BuildVector(EltTy, matElts, LocalBuilder);
       CI->replaceAllUsesWith(newVec);
       CI->eraseFromParent();
     } break;
     case HLCastOpcode::RowMatrixToVecCast: {
       IRBuilder<> LocalBuilder(CI);
-      Type *matTy =
-          CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx)->getType();
-      unsigned col, row;
-      Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
-      std::vector<Value *> matElts(col * row);
-      for (unsigned r = 0; r < row; r++) {
+      HLMatrixType MatTy = HLMatrixType::cast(
+          CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx)->getType());
+      Type *EltTy = MatTy.getElementTypeForReg();
+      std::vector<Value *> matElts(MatTy.getNumElements());
+      for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
         Value *rowIdx = hlslOP->GetI32Const(r);
         args[DXIL::OperandIndex::kLoadInputRowOpIdx] = rowIdx;
-        for (unsigned c = 0; c < col; c++) {
+        for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
           Value *colIdx = hlslOP->GetU8Const(c);
           args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
           Value *input =
               GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
-          unsigned matIdx = r * col + c;
-          matElts[matIdx] = input;
+          matElts[MatTy.getRowMajorIndex(r, c)] = input;
         }
       }
       Value *newVec =
-          HLMatrixLower::BuildVector(EltTy, col * row, matElts, LocalBuilder);
+          HLMatrixLower::BuildVector(EltTy, matElts, LocalBuilder);
       CI->replaceAllUsesWith(newVec);
       CI->eraseFromParent();
     } break;
@@ -784,12 +780,10 @@ void collectInputOutputAccessInfo(
               vectorIdx = GEPIt.getOperand();
             }
           }
-          if (dxilutil::IsHLSLMatrixType(*GEPIt)) {
-            unsigned row, col;
-            HLMatrixLower::GetMatrixInfo(*GEPIt, col, row);
-            Constant *arraySize = ConstantInt::get(idxTy, col);
+          if (HLMatrixType MatTy = HLMatrixType::dyn_cast(*GEPIt)) {
+            Constant *arraySize = ConstantInt::get(idxTy, MatTy.getNumColumns());
             if (bRowMajor) {
-              arraySize = ConstantInt::get(idxTy, row);
+              arraySize = ConstantInt::get(idxTy, MatTy.getNumRows());
             }
             rowIdx = Builder.CreateMul(rowIdx, arraySize);
           }
@@ -915,46 +909,44 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
     case HLMatLoadStoreOpcode::ColMatLoad:
     case HLMatLoadStoreOpcode::RowMatLoad: {
       IRBuilder<> LocalBuilder(CI);
-      Type *matTy = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
-                        ->getType()
-                        ->getPointerElementType();
-      unsigned col, row;
-      HLMatrixLower::GetMatrixInfo(matTy, col, row);
-      std::vector<Value *> matElts(col * row);
+      HLMatrixType MatTy = HLMatrixType::cast(
+        CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
+          ->getType()->getPointerElementType());
+      std::vector<Value *> matElts(MatTy.getNumElements());
 
       if (matOp == HLMatLoadStoreOpcode::ColMatLoad) {
-        for (unsigned c = 0; c < col; c++) {
+        for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
           Constant *constRowIdx = LocalBuilder.getInt32(c);
           Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-          for (unsigned r = 0; r < row; r++) {
+          for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
             SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[r] };
             if (vertexID)
               args.emplace_back(vertexID);
 
             Value *input = LocalBuilder.CreateCall(ldStFunc, args);
-            unsigned matIdx = c * row + r;
+            unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
             matElts[matIdx] = input;
           }
         }
       } else {
-        for (unsigned r = 0; r < row; r++) {
+        for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
           Constant *constRowIdx = LocalBuilder.getInt32(r);
           Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-          for (unsigned c = 0; c < col; c++) {
+          for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
             SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[c] };
             if (vertexID)
               args.emplace_back(vertexID);
 
             Value *input = LocalBuilder.CreateCall(ldStFunc, args);
-            unsigned matIdx = r * col + c;
+            unsigned matIdx = MatTy.getRowMajorIndex(r, c);
             matElts[matIdx] = input;
           }
         }
       }
 
       Value *newVec =
-          HLMatrixLower::BuildVector(matElts[0]->getType(), col * row, matElts, LocalBuilder);
-      newVec = HLMatrixLower::VecMatrixMemToReg(newVec, matTy, LocalBuilder);
+          HLMatrixLower::BuildVector(matElts[0]->getType(), matElts, LocalBuilder);
+      newVec = MatTy.emitLoweredMemToReg(newVec, LocalBuilder);
 
       CI->replaceAllUsesWith(newVec);
       CI->eraseFromParent();
@@ -963,32 +955,30 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
     case HLMatLoadStoreOpcode::RowMatStore: {
       IRBuilder<> LocalBuilder(CI);
       Value *Val = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-      Type *matTy = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
-                        ->getType()
-                        ->getPointerElementType();
-      unsigned col, row;
-      HLMatrixLower::GetMatrixInfo(matTy, col, row);
+      HLMatrixType MatTy = HLMatrixType::cast(
+        CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
+          ->getType()->getPointerElementType());
 
-      Val = HLMatrixLower::VecMatrixRegToMem(Val, matTy, LocalBuilder);
+      Val = MatTy.emitLoweredRegToMem(Val, LocalBuilder);
 
       if (matOp == HLMatLoadStoreOpcode::ColMatStore) {
-        for (unsigned c = 0; c < col; c++) {
+        for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
           Constant *constColIdx = LocalBuilder.getInt32(c);
           Value *colIdx = LocalBuilder.CreateAdd(idxVal, constColIdx);
 
-          for (unsigned r = 0; r < row; r++) {
-            unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
+          for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+            unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
             Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
             LocalBuilder.CreateCall(ldStFunc,
               { OpArg, ID, colIdx, columnConsts[r], Elt });
           }
         }
       } else {
-        for (unsigned r = 0; r < row; r++) {
+        for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
           Constant *constRowIdx = LocalBuilder.getInt32(r);
           Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-          for (unsigned c = 0; c < col; c++) {
-            unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
+          for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
+            unsigned matIdx = MatTy.getRowMajorIndex(r, c);
             Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
             LocalBuilder.CreateCall(ldStFunc,
               { OpArg, ID, rowIdx, columnConsts[c], Elt });

+ 7 - 1
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -236,12 +236,18 @@ static void addHLSLPasses(bool HLSLHighLevel, unsigned OptLevel, hlsl::HLSLExten
     // Do this before change vector to array.
     MPM.add(createDxilLegalizeEvalOperationsPass());
   }
+  else {
+    // This should go between matrix lower and dynamic indexing vector to array,
+    // because matrix lower may create dynamically indexed global vectors,
+    // which should become locals. If they are turned into arrays first,
+    // this pass will ignore them as it only works on scalars and vectors.
+    MPM.add(createLowerStaticGlobalIntoAlloca());
+  }
 
   // Change dynamic indexing vector to array.
   MPM.add(createDynamicIndexingVectorToArrayPass(NoOpt));
 
   if (!NoOpt) {
-    MPM.add(createLowerStaticGlobalIntoAlloca());
     // mem2reg
     MPM.add(createPromoteMemoryToRegisterPass());
 

+ 1 - 1
lib/Transforms/Scalar/DxilLoopUnroll.cpp

@@ -956,7 +956,7 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
     // they're in entry block so deleting loop blocks don't 
     // kill them too.
     for (AllocaInst *AI : ProblemAllocas)
-      DXASSERT(AI->getParent() == &F->getEntryBlock(), "Alloca is not in entry block.");
+      DXASSERT_LOCALVAR(AI, AI->getParent() == &F->getEntryBlock(), "Alloca is not in entry block.");
 
     LoopIteration &FirstIteration = *Iterations.front().get();
     // Make the predecessor branch to the first new header.

+ 110 - 47
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -56,6 +56,7 @@
 #include "dxc/HlslIntrinsicOp.h"
 #include "dxc/DXIL/DxilTypeSystem.h"
 #include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/DXIL/DxilOperations.h"
 #include <deque>
 #include <unordered_map>
@@ -112,7 +113,7 @@ private:
 
   void RewriteForConstExpr(ConstantExpr *user, IRBuilder<> &Builder);
   void RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder);
-  void RewriteForAddrSpaceCast(ConstantExpr *user, IRBuilder<> &Builder);
+  void RewriteForAddrSpaceCast(Value *user, IRBuilder<> &Builder);
   void RewriteForLoad(LoadInst *loadInst);
   void RewriteForStore(StoreInst *storeInst);
   void RewriteMemIntrin(MemIntrinsic *MI, Value *OldV);
@@ -2442,8 +2443,8 @@ static unsigned MatchSizeByCheckElementType(Type *Ty, const DataLayout &DL, unsi
   unsigned ptrSize = DL.getTypeAllocSize(Ty);
   // Size match, return current level.
   if (ptrSize == size) {
-    // Not go deeper for matrix.
-    if (dxilutil::IsHLSLMatrixType(Ty))
+    // Do not go deeper for matrix or object.
+    if (dxilutil::IsHLSLMatrixType(Ty) || dxilutil::IsHLSLObjectType(Ty))
       return level;
     // For struct, go deeper if size not change.
     // This will leave memcpy to deeper level when flatten.
@@ -2568,6 +2569,24 @@ static void DeleteMemcpy(MemCpyInst *MI) {
   }
 }
 
+// If user is function call, return param annotation to get matrix major.
+static DxilFieldAnnotation *FindAnnotationFromMatUser(Value *Mat,
+  DxilTypeSystem &typeSys) {
+  for (User *U : Mat->users()) {
+    if (CallInst *CI = dyn_cast<CallInst>(U)) {
+      Function *F = CI->getCalledFunction();
+      if (DxilFunctionAnnotation *Anno = typeSys.GetFunctionAnnotation(F)) {
+        for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
+          if (CI->getArgOperand(i) == Mat) {
+            return &Anno->GetParameterAnnotation(i);
+          }
+        }
+      }
+    }
+  }
+  return nullptr;
+}
+
 void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
                                  DxilFieldAnnotation *fieldAnnotation,
                                  DxilTypeSystem &typeSys, const bool bEltMemCpy) {
@@ -2597,7 +2616,7 @@ void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
   if (!fieldAnnotation) {
     Type *EltTy = dxilutil::GetArrayEltTy(DestTy);
     if (dxilutil::IsHLSLMatrixType(EltTy)) {
-      fieldAnnotation = HLMatrixLower::FindAnnotationFromMatUser(Dest, typeSys);
+      fieldAnnotation = FindAnnotationFromMatUser(Dest, typeSys);
     }
   }
 
@@ -3284,17 +3303,17 @@ void SROA_Helper::RewriteCall(CallInst *CI) {
   }
 }
 
-/// RewriteForConstExpr - Rewrite the GEP which is ConstantExpr.
-void SROA_Helper::RewriteForAddrSpaceCast(ConstantExpr *CE,
+/// RewriteForAddrSpaceCast - Rewrite the AddrSpaceCast, either ConstExpr or Inst.
+void SROA_Helper::RewriteForAddrSpaceCast(Value *CE,
                                           IRBuilder<> &Builder) {
   SmallVector<Value *, 8> NewCasts;
   // create new AddrSpaceCast.
   for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
-    Value *NewGEP = Builder.CreateAddrSpaceCast(
+    Value *NewCast = Builder.CreateAddrSpaceCast(
         NewElts[i],
         PointerType::get(NewElts[i]->getType()->getPointerElementType(),
                          CE->getType()->getPointerAddressSpace()));
-    NewCasts.emplace_back(NewGEP);
+    NewCasts.emplace_back(NewCast);
   }
   SROA_Helper helper(CE, NewCasts, DeadInsts, typeSys, DL);
   helper.RewriteForScalarRepl(CE, Builder);
@@ -3360,7 +3379,9 @@ void SROA_Helper::RewriteForScalarRepl(Value *V, IRBuilder<> &Builder) {
       RewriteCall(CI);
     else if (BitCastInst *BCI = dyn_cast<BitCastInst>(User))
       RewriteBitCast(BCI);
-    else {
+    else if (AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(User)) {
+      RewriteForAddrSpaceCast(CI, Builder);
+    } else {
       assert(0 && "not support.");
     }
   }
@@ -4186,9 +4207,28 @@ bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
 
 /// MarkEmptyStructUsers - Add instruction related to Empty struct to DeadInsts.
 void SROA_Helper::MarkEmptyStructUsers(Value *V, SmallVector<Value *, 32> &DeadInsts) {
-  for (User *U : V->users()) {
-    MarkEmptyStructUsers(U, DeadInsts);
+  UndefValue *undef = UndefValue::get(V->getType());
+  for (auto itU = V->user_begin(), E = V->user_end(); itU != E;) {
+    Value *U = *(itU++);
+    // Kill memcpy, set operands to undef for call and ret, and recurse
+    if (MemCpyInst *MC = dyn_cast<MemCpyInst>(U)) {
+      DeadInsts.emplace_back(MC);
+    } else if (CallInst *CI = dyn_cast<CallInst>(U)) {
+      for (auto &operand : CI->operands()) {
+        if (operand == V)
+          operand.set(undef);
+      }
+    } else if (ReturnInst *Ret = dyn_cast<ReturnInst>(U)) {
+      Ret->setOperand(0, undef);
+    } else if (isa<Constant>(U) || isa<GetElementPtrInst>(U) ||
+               isa<BitCastInst>(U) || isa<LoadInst>(U) || isa<StoreInst>(U)) {
+      // Recurse users
+      MarkEmptyStructUsers(U, DeadInsts);
+    } else {
+      DXASSERT(false, "otherwise, recursing unexpected empty struct user");
+    }
   }
+
   if (Instruction *I = dyn_cast<Instruction>(V)) {
     // Only need to add no use inst here.
     // DeleteDeadInst will delete everything.
@@ -4619,6 +4659,12 @@ void SROA_Parameter_HLSL::flattenGlobal(GlobalVariable *GV) {
 
     // Flat Global vector if no dynamic vector indexing.
     bool bFlatVector = !hasDynamicVectorIndexing(EltGV);
+
+    // Disable scalarization of groupshared vector arrays
+    if (GV->getType()->getAddressSpace() == DXIL::kTGSMAddrSpace &&
+        Ty->isArrayTy())
+      bFlatVector = false;
+
     std::vector<Value *> Elts;
     bool SROAed = SROA_Helper::DoScalarReplacement(
         EltGV, Elts, Builder, bFlatVector,
@@ -4872,23 +4918,19 @@ static void CopyEltsPtrToVectorPtr(ArrayRef<Value *> elts, Value *VecPtr,
 static void CopyMatToArrayPtr(Value *Mat, Value *ArrayPtr,
                               unsigned arrayBaseIdx, HLModule &HLM,
                               IRBuilder<> &Builder, bool bRowMajor) {
-  Type *Ty = Mat->getType();
   // Mat val is row major.
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(Mat->getType(), col, row);
-  Type *VecTy = HLMatrixLower::LowerMatrixType(Ty);
+  HLMatrixType MatTy = HLMatrixType::cast(Mat->getType());
+  Type *VecTy = MatTy.getLoweredVectorTypeForReg();
   Value *Vec =
       HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLCast,
                               (unsigned)HLCastOpcode::RowMatrixToVecCast, VecTy,
                               {Mat}, *HLM.GetModule());
   Value *zero = Builder.getInt32(0);
 
-  for (unsigned r = 0; r < row; r++) {
-    for (unsigned c = 0; c < col; c++) {
-      unsigned rowMatIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
-      Value *Elt = Builder.CreateExtractElement(Vec, rowMatIdx);
-      unsigned matIdx =
-          bRowMajor ? rowMatIdx :  HLMatrixLower::GetColMajorIdx(r, c, row);
+  for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+    for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
+      unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
+      Value *Elt = Builder.CreateExtractElement(Vec, matIdx);
       Value *Ptr = Builder.CreateInBoundsGEP(
           ArrayPtr, {zero, Builder.getInt32(arrayBaseIdx + matIdx)});
       Builder.CreateStore(Elt, Ptr);
@@ -4919,15 +4961,15 @@ static void CopyMatPtrToArrayPtr(Value *MatPtr, Value *ArrayPtr,
 static Value *LoadArrayPtrToMat(Value *ArrayPtr, unsigned arrayBaseIdx,
                                 Type *Ty, HLModule &HLM, IRBuilder<> &Builder,
                                 bool bRowMajor) {
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(Ty, col, row);
+  HLMatrixType MatTy = HLMatrixType::cast(Ty);
   // HLInit operands are in row major.
   SmallVector<Value *, 16> Elts;
   Value *zero = Builder.getInt32(0);
-  for (unsigned r = 0; r < row; r++) {
-    for (unsigned c = 0; c < col; c++) {
-      unsigned matIdx = bRowMajor ? HLMatrixLower::GetRowMajorIdx(r, c, col)
-                                  : HLMatrixLower::GetColMajorIdx(r, c, row);
+  for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+    for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
+      unsigned matIdx = bRowMajor
+        ? MatTy.getRowMajorIndex(r, c)
+        : MatTy.getColumnMajorIndex(r, c);
       Value *Ptr = Builder.CreateInBoundsGEP(
           ArrayPtr, {zero, Builder.getInt32(arrayBaseIdx + matIdx)});
       Value *Elt = Builder.CreateLoad(Ptr);
@@ -4981,12 +5023,10 @@ CastCopyArrayMultiDimTo1Dim(Value *FromArray, Value *ToArray, Type *CurFromTy,
       Value *Elt = Builder.CreateExtractElement(V, i);
       Builder.CreateStore(Elt, ToPtr);
     }
-  } else if (dxilutil::IsHLSLMatrixType(CurFromTy)) {
+  } else if (HLMatrixType MatTy = HLMatrixType::dyn_cast(CurFromTy)) {
     // Copy matrix to array.
-    unsigned col, row;
-    HLMatrixLower::GetMatrixInfo(CurFromTy, col, row);
     // Calculate the offset.
-    unsigned offset = calcIdx * col * row;
+    unsigned offset = calcIdx * MatTy.getNumElements();
     Value *FromPtr = Builder.CreateInBoundsGEP(FromArray, idxList);
     CopyMatPtrToArrayPtr(FromPtr, ToArray, offset, HLM, Builder, bRowMajor);
   } else if (!CurFromTy->isArrayTy()) {
@@ -5028,12 +5068,10 @@ CastCopyArray1DimToMultiDim(Value *FromArray, Value *ToArray, Type *CurToTy,
       V = Builder.CreateInsertElement(V, Elt, i);
     }
     Builder.CreateStore(V, ToPtr);
-  } else if (dxilutil::IsHLSLMatrixType(CurToTy)) {
+  } else if (HLMatrixType MatTy = HLMatrixType::cast(CurToTy)) {
     // Copy array to matrix.
-    unsigned col, row;
-    HLMatrixLower::GetMatrixInfo(CurToTy, col, row);
     // Calculate the offset.
-    unsigned offset = calcIdx * col * row;
+    unsigned offset = calcIdx * MatTy.getNumElements();
     Value *ToPtr = Builder.CreateInBoundsGEP(ToArray, idxList);
     CopyArrayPtrToMatPtr(FromArray, offset, ToPtr, HLM, Builder, bRowMajor);
   } else if (!CurToTy->isArrayTy()) {
@@ -6636,6 +6674,8 @@ private:
   void ReplaceVectorWithArray(Value *Vec, Value *Array);
   void ReplaceVectorArrayWithArray(Value *VecArray, Value *Array);
   void ReplaceStaticIndexingOnVector(Value *V);
+  void ReplaceAddrSpaceCast(ConstantExpr *CE,
+                            Value *A, IRBuilder<> &Builder);
 };
 
 void DynamicIndexingVectorToArray::applyOptions(PassOptions O) {
@@ -6738,13 +6778,23 @@ void DynamicIndexingVectorToArray::ReplaceVecGEP(Value *GEP, ArrayRef<Value *> i
   }
 }
 
+void DynamicIndexingVectorToArray::ReplaceAddrSpaceCast(ConstantExpr *CE,
+                                              Value *A, IRBuilder<> &Builder) {
+  // create new AddrSpaceCast.
+  Value *NewAddrSpaceCast = Builder.CreateAddrSpaceCast(
+    A,
+    PointerType::get(A->getType()->getPointerElementType(),
+                      CE->getType()->getPointerAddressSpace()));
+  ReplaceVectorWithArray(CE, NewAddrSpaceCast);
+}
+
 void DynamicIndexingVectorToArray::ReplaceVectorWithArray(Value *Vec, Value *A) {
   unsigned size = Vec->getType()->getPointerElementType()->getVectorNumElements();
   for (auto U = Vec->user_begin(); U != Vec->user_end();) {
     User *User = (*U++);
 
     // GlobalVariable user.
-    if (isa<ConstantExpr>(User)) {
+    if (ConstantExpr * CE = dyn_cast<ConstantExpr>(User)) {
       if (User->user_empty())
         continue;
       if (GEPOperator *GEP = dyn_cast<GEPOperator>(User)) {
@@ -6752,7 +6802,12 @@ void DynamicIndexingVectorToArray::ReplaceVectorWithArray(Value *Vec, Value *A)
         SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
         ReplaceVecGEP(GEP, idxList, A, Builder);
         continue;
+      } else if (CE->getOpcode() == Instruction::AddrSpaceCast) {
+        IRBuilder<> Builder(Vec->getContext());
+        ReplaceAddrSpaceCast(CE, A, Builder);
+        continue;
       }
+      DXASSERT(0, "not implemented yet");
     }
     // Instrution user.
     Instruction *UserInst = cast<Instruction>(User);
@@ -6961,7 +7016,7 @@ void ReplaceMultiDimGEP(User *GEP, Value *OneDim, IRBuilder<> &Builder) {
 void MultiDimArrayToOneDimArray::lowerUseWithNewValue(Value *MultiDim, Value *OneDim) {
   LLVMContext &Context = MultiDim->getContext();
   // All users should be element type.
-  // Replace users of AI.
+  // Replace users of AI or GV.
   for (auto it = MultiDim->user_begin(); it != MultiDim->user_end();) {
     User *U = *(it++);
     if (U->user_empty())
@@ -6970,21 +7025,29 @@ void MultiDimArrayToOneDimArray::lowerUseWithNewValue(Value *MultiDim, Value *On
       BCI->setOperand(0, OneDim);
       continue;
     }
-    // Must be GEP.
-    GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U);
 
-    if (!GEP) {
-      DXASSERT_NOMSG(isa<GEPOperator>(U));
-      // NewGEP must be GEPOperator too.
-      // No instruction will be build.
+    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U)) {
       IRBuilder<> Builder(Context);
-      ReplaceMultiDimGEP(U, OneDim, Builder);
-    } else {
+      if (GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
+        // NewGEP must be GEPOperator too.
+        // No instruction will be build.
+        ReplaceMultiDimGEP(U, OneDim, Builder);
+      } else if (CE->getOpcode() == Instruction::AddrSpaceCast) {
+        Value *NewAddrSpaceCast = Builder.CreateAddrSpaceCast(
+          OneDim,
+          PointerType::get(OneDim->getType()->getPointerElementType(),
+                           CE->getType()->getPointerAddressSpace()));
+        lowerUseWithNewValue(CE, NewAddrSpaceCast);
+      } else {
+        DXASSERT(0, "not implemented");
+      }
+    } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
       IRBuilder<> Builder(GEP);
       ReplaceMultiDimGEP(U, OneDim, Builder);
-    }
-    if (GEP)
       GEP->eraseFromParent();
+    } else {
+      DXASSERT(0, "not implemented");
+    }
   }
 }
 

+ 1 - 2
tools/clang/include/clang/AST/HlslTypes.h

@@ -382,9 +382,8 @@ bool IsHLSLLineStreamType(clang::QualType type);
 bool IsHLSLTriangleStreamType(clang::QualType type);
 bool IsHLSLStreamOutputType(clang::QualType type);
 bool IsHLSLResourceType(clang::QualType type);
-bool IsHLSLNumeric(clang::QualType type);
 bool IsHLSLNumericUserDefinedType(clang::QualType type);
-bool IsHLSLAggregateType(clang::ASTContext& context, clang::QualType type);
+bool IsHLSLAggregateType(clang::QualType type);
 clang::QualType GetHLSLResourceResultType(clang::QualType type);
 bool IsIncompleteHLSLResourceArrayType(clang::ASTContext& context, clang::QualType type);
 clang::QualType GetHLSLInputPatchElementType(clang::QualType type);

+ 6 - 4
tools/clang/lib/AST/HlslTypes.cpp

@@ -91,7 +91,7 @@ bool IsHLSLVecType(clang::QualType type) {
   return false;
 }
 
-bool IsHLSLNumeric(clang::QualType type) {
+static bool IsHLSLNumeric(clang::QualType type) {
   const clang::Type *Ty = type.getCanonicalType().getTypePtr();
   if (isa<RecordType>(Ty)) {
     if (IsHLSLVecMatType(type))
@@ -125,9 +125,11 @@ bool IsHLSLNumericUserDefinedType(clang::QualType type) {
   return false;
 }
 
-bool IsHLSLAggregateType(clang::ASTContext& context, clang::QualType type) {
-  // Aggregate types are arrays and user-defined structs
-  if (context.getAsArrayType(type) != nullptr) return true;
+// Aggregate types are arrays and user-defined structs
+bool IsHLSLAggregateType(clang::QualType type) {
+  type = type.getCanonicalType();
+  if (isa<clang::ArrayType>(type)) return true;
+
   const RecordType *Record = dyn_cast<RecordType>(type);
   return Record != nullptr
     && !IsHLSLVecMatType(type) && !IsHLSLResourceType(type)

+ 2 - 1
tools/clang/lib/CodeGen/CGClass.cpp

@@ -171,7 +171,8 @@ llvm::Value *CodeGenFunction::GetAddressOfBaseClass(
 
   // Get the base pointer type.
   llvm::Type *BasePtrTy =
-    ConvertType((PathEnd[-1])->getType())->getPointerTo();
+    ConvertType((PathEnd[-1])->getType())->getPointerTo(
+      Value->getType()->getPointerAddressSpace()); // HLSL Change: match address space
 
   QualType DerivedTy = getContext().getRecordType(Derived);
   CharUnits DerivedAlign = getContext().getTypeAlignInChars(DerivedTy);

+ 1 - 1
tools/clang/lib/CodeGen/CGDecl.cpp

@@ -1795,7 +1795,7 @@ void CodeGenFunction::EmitParmDecl(const VarDecl &D, llvm::Value *Arg,
   }
 
   LValue lv = MakeAddrLValue(DeclPtr, Ty, Align);
-  if (IsScalar) {
+  if (!getLangOpts().HLSL && IsScalar) {  // HLSL Change: not ObjC
     Qualifiers qs = Ty.getQualifiers();
     if (Qualifiers::ObjCLifetime lt = qs.getObjCLifetime()) {
       // We honor __attribute__((ns_consumed)) for types with lifetime.

+ 15 - 2
tools/clang/lib/CodeGen/CGExpr.cpp

@@ -1397,7 +1397,7 @@ RValue CodeGenFunction::EmitLoadOfLValue(LValue LV, SourceLocation Loc) {
       }
     }
 
-    if (hlsl::IsHLSLAggregateType(getContext(), LV.getType())) {
+    if (hlsl::IsHLSLAggregateType(LV.getType())) {
       // We cannot load the value because we don't expect to ever have
       // user-defined struct or array-typed llvm registers, only pointers to them.
       // To preserve the snapshot semantics of LValue loads, we copy the
@@ -3285,10 +3285,23 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
     LValue LV = EmitLValue(E->getSubExpr());
     QualType ToType = getContext().getLValueReferenceType(E->getType());
 
+    llvm::Value *FromValue = LV.getAddress();
+    llvm::Type *FromTy = FromValue->getType();
     llvm::Type *RetTy = ConvertType(ToType);
     // type not changed, LValueToRValue, CStyleCast may go this path
-    if (LV.getAddress()->getType() == RetTy)
+    if (FromTy == RetTy) {
       return LV;
+    // If only address space changed, add address space cast
+    }
+    if (FromTy->getPointerAddressSpace() != RetTy->getPointerAddressSpace()) {
+      llvm::Type *ConvertedFromTy = llvm::PointerType::get(
+        FromTy->getPointerElementType(), RetTy->getPointerAddressSpace());
+      assert(ConvertedFromTy == RetTy &&
+             "otherwise, more than just address space changing in one step");
+      llvm::Value *cast =
+          Builder.CreateAddrSpaceCast(FromValue, ConvertedFromTy);
+      return MakeAddrLValue(cast, ToType);
+    }
     llvm::Value *cast = CGM.getHLSLRuntime().EmitHLSLMatrixOperationCall(*this, E, RetTy, { LV.getAddress() });
     return MakeAddrLValue(cast, ToType);
   }

+ 2 - 2
tools/clang/lib/CodeGen/CGExprScalar.cpp

@@ -1825,7 +1825,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     // If the aggregate type is the cast source, it should be a pointer.
     // Aggregate to aggregate casts are handled in CGExprAgg.cpp
     auto areCompoundAndNumeric = [this](QualType lhs, QualType rhs) {
-      return hlsl::IsHLSLAggregateType(CGF.getContext(), lhs)
+      return hlsl::IsHLSLAggregateType(lhs)
         && (rhs->isBuiltinType() || hlsl::IsHLSLVecMatType(rhs));
     };
     assert(Src->getType()->isPointerTy()
@@ -1843,7 +1843,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
       return CGF.CGM.getHLSLRuntime().EmitHLSLMatrixLoad(CGF, DstPtr, DestTy);
     
     // Structs/arrays are pointers to temporaries
-    if (hlsl::IsHLSLAggregateType(CGF.getContext(), DestTy))
+    if (hlsl::IsHLSLAggregateType(DestTy))
       return DstPtr;
     
     // Scalars/vectors are loaded regularly

+ 104 - 60
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -15,7 +15,7 @@
 #include "CodeGenModule.h"
 #include "CGRecordLayout.h"
 #include "dxc/HlslIntrinsicOp.h"
-#include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/DXIL/DxilUtil.h"
 #include "dxc/HLSL/HLOperations.h"
@@ -981,8 +981,7 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
   unsigned size = dataLayout.getTypeAllocSize(Type);
 
   if (IsHLSLMatType(Ty)) {
-    unsigned col, row;
-    llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(Type, col, row);
+    llvm::Type *EltTy = HLMatrixType::cast(Type).getElementTypeForReg();
     bool b64Bit = dataLayout.getTypeAllocSize(EltTy) == 8;
     size = GetMatrixSizeInCB(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor,
                              b64Bit);
@@ -2644,7 +2643,7 @@ bool CGMSHLSLRuntime::SetUAVSRV(SourceLocation loc,
       EltTy = hlsl::GetHLSLVecElementType(Ty);
     } else if (hlsl::IsHLSLMatType(Ty)) {
       EltTy = hlsl::GetHLSLMatElementType(Ty);
-    } else if (resultTy->isAggregateType()) {
+    } else if (hlsl::IsHLSLAggregateType(resultTy)) {
       // Struct or array in a none-struct resource.
       std::vector<QualType> ScalarTys;
       CollectScalarTypes(ScalarTys, resultTy);
@@ -5331,13 +5330,12 @@ void CGMSHLSLRuntime::FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Val
       }
     }
   } else {
-    if (dxilutil::IsHLSLMatrixType(valTy)) {
-      unsigned col, row;
-      llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(valTy, col, row);
+    if (HLMatrixType MatTy = HLMatrixType::dyn_cast(valTy)) {
+      llvm::Type *EltTy = MatTy.getElementTypeForReg();
       // All matrix Value should be row major.
       // Init list is row major in scalar.
       // So the order is match here, just cast to vector.
-      unsigned matSize = col * row;
+      unsigned matSize = MatTy.getNumElements();
       bool isRowMajor = hlsl::IsHLSLMatRowMajor(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor);
 
       HLCastOpcode opcode = isRowMajor ? HLCastOpcode::RowMatrixToVecCast
@@ -5477,18 +5475,16 @@ static void StoreInitListToDestPtr(Value *DestPtr,
     Result = CGF.EmitToMemory(Result, Type);
     Builder.CreateStore(Result, DestPtr);
     idx += Ty->getVectorNumElements();
-  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
+  } else if (HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty)) {
     bool isRowMajor = hlsl::IsHLSLMatRowMajor(Type, bDefaultRowMajor);
-    unsigned row, col;
-    HLMatrixLower::GetMatrixInfo(Ty, col, row);
-    std::vector<Value *> matInitList(col * row);
-    for (unsigned i = 0; i < col; i++) {
-      for (unsigned r = 0; r < row; r++) {
-        unsigned matIdx = i * row + r;
+    std::vector<Value *> matInitList(MatTy.getNumElements());
+    for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
+      for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+        unsigned matIdx = c * MatTy.getNumRows() + r;
         matInitList[matIdx] = elts[idx + matIdx];
       }
     }
-    idx += row * col;
+    idx += MatTy.getNumElements();
     Value *matVal =
         EmitHLSLMatrixOperationCallImp(Builder, HLOpcodeGroup::HLInit,
                                        /*opcode*/ 0, Ty, matInitList, M);
@@ -6518,10 +6514,9 @@ void CGMSHLSLRuntime::FlattenAggregatePtrToGepList(
                                  GepList, EltTyList);
 
     idxList.pop_back();
-  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
+  } else if (HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty)) {
     // Use matLd/St for matrix.
-    unsigned col, row;
-    llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
+    llvm::Type *EltTy = MatTy.getElementTypeForReg();
     llvm::PointerType *EltPtrTy =
         llvm::PointerType::get(EltTy, Ptr->getType()->getPointerAddressSpace());
     QualType EltQualTy = hlsl::GetHLSLMatElementType(Type);
@@ -6529,8 +6524,8 @@ void CGMSHLSLRuntime::FlattenAggregatePtrToGepList(
     Value *matPtr = CGF.Builder.CreateInBoundsGEP(Ptr, idxList);
 
     // Flatten matrix to elements.
-    for (unsigned r = 0; r < row; r++) {
-      for (unsigned c = 0; c < col; c++) {
+    for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+      for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
         ConstantInt *cRow = CGF.Builder.getInt32(r);
         ConstantInt *cCol = CGF.Builder.getInt32(c);
         Constant *CV = llvm::ConstantVector::get({cRow, cCol});
@@ -6694,7 +6689,7 @@ void CGMSHLSLRuntime::EmitHLSLAggregateCopy(
     // Memcpy struct.
     CGF.Builder.CreateMemCpy(dstGEP, srcGEP, size, 1);
   } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
-    if (!HLMatrixLower::IsMatrixArrayPointer(llvm::PointerType::get(Ty,0))) {
+    if (!HLMatrixType::isMatrixArray(Ty)) {
       Value *srcGEP = CGF.Builder.CreateInBoundsGEP(SrcPtr, idxList);
       Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
       unsigned size = this->TheModule.getDataLayout().getTypeAllocSize(AT);
@@ -6775,7 +6770,7 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversionAggregateCopy(CodeGenFunction &CGF,
   bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
   if (SrcPtrTy == DestPtrTy) {
     bool bMatArrayRotate = false;
-    if (HLMatrixLower::IsMatrixArrayPointer(SrcPtr->getType())) {
+    if (HLMatrixType::isMatrixArrayPtr(SrcPtr->getType())) {
       QualType SrcEltTy = GetArrayEltType(SrcTy);
       QualType DestEltTy = GetArrayEltType(DestTy);
       if (GetMatrixMajor(SrcEltTy, bDefaultRowMajor) !=
@@ -6871,11 +6866,10 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversion(
                                       SrcType, PT->getElementType());
 
     idxList.pop_back();
-  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
+  } else if (HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty)) {
     // Use matLd/St for matrix.
     Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
-    unsigned row, col;
-    llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
+    llvm::Type *EltTy = MatTy.getElementTypeForReg();
 
     llvm::VectorType *VT1 = llvm::VectorType::get(EltTy, 1);
     SrcVal = ConvertScalarOrVector(CGF, SrcVal, SrcType, hlsl::GetHLSLMatElementType(Type));
@@ -6883,7 +6877,7 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversion(
     // Splat the value
     Value *V1 = CGF.Builder.CreateInsertElement(UndefValue::get(VT1), SrcVal,
                                                 (uint64_t)0);
-    std::vector<int> shufIdx(col * row, 0);
+    std::vector<int> shufIdx(MatTy.getNumElements(), 0);
     Value *VecMat = CGF.Builder.CreateShuffleVector(V1, V1, shufIdx);
     Value *MatInit = EmitHLSLMatrixOperationCallImp(
         CGF.Builder, HLOpcodeGroup::HLInit, 0, Ty, {VecMat}, TheModule);
@@ -7010,9 +7004,15 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
     const ParmVarDecl *Param = FD->getParamDecl(i);
     const Expr *Arg = E->getArg(i+ArgsToSkip);
     QualType ParamTy = Param->getType().getNonReferenceType();
+    bool isObject = dxilutil::IsHLSLObjectType(CGF.ConvertTypeForMem(ParamTy));
+    bool isAggregateType = !isObject &&
+      (ParamTy->isArrayType() || ParamTy->isRecordType()) &&
+      !hlsl::IsHLSLVecMatType(ParamTy);
+
+    bool EmitRValueAgg = false;
     bool RValOnRef = false;
     if (!Param->isModifierOut()) {
-      if (!ParamTy->isAggregateType() || hlsl::IsHLSLMatType(ParamTy)) {
+      if (!isAggregateType && !isObject) {
         if (Arg->isRValue() && Param->getType()->isReferenceType()) {
           // RValue on a reference type.
           if (const CStyleCastExpr *cCast = dyn_cast<CStyleCastExpr>(Arg)) {
@@ -7035,28 +7035,62 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
         } else {
           continue;
         }
+      } else if (isAggregateType) {
+        // aggregate in-only - emit RValue, unless LValueToRValue cast
+        EmitRValueAgg = true;
+        if (const ImplicitCastExpr *cast =
+                dyn_cast<ImplicitCastExpr>(Arg)) {
+          if (cast->getCastKind() == CastKind::CK_LValueToRValue) {
+            EmitRValueAgg = false;
+          }
+        }
+      } else {
+        // Must be object
+        DXASSERT(isObject, "otherwise, flow condition changed, breaking assumption");
+        // in-only objects should be skipped to preserve previous behavior.
+        continue;
       }
     }
 
-    // get original arg
-    LValue argLV = CGF.EmitLValue(Arg);
+    // Skip unbounded array, since we cannot preserve copy-in copy-out
+    // semantics for these.
+    if (ParamTy->isIncompleteArrayType()) {
+      continue;
+    }
 
     if (!Param->isModifierOut() && !RValOnRef) {
-      bool isDefaultAddrSpace = true;
-      if (argLV.isSimple()) {
-        isDefaultAddrSpace =
-            argLV.getAddress()->getType()->getPointerAddressSpace() ==
-            DXIL::kDefaultAddrSpace;
-      }
-      bool isHLSLIntrinsic = false;
+      // No need to copy arg to in-only param for hlsl intrinsic.
       if (const FunctionDecl *Callee = E->getDirectCallee()) {
-        isHLSLIntrinsic = Callee->hasAttr<HLSLIntrinsicAttr>();
+        if (Callee->hasAttr<HLSLIntrinsicAttr>())
+          continue;
       }
-      // Copy in arg which is not default address space and not on hlsl intrinsic.
-      if (isDefaultAddrSpace || isHLSLIntrinsic)
-        continue;
     }
 
+
+    // get original arg
+    // FIXME: This will not emit in correct argument order with the other
+    //        arguments. This should be integrated into
+    //        CodeGenFunction::EmitCallArg if possible.
+    RValue argRV; // emit this if aggregate arg on in-only param
+    LValue argLV; // otherwise, we may emit this
+    llvm::Value *argAddr = nullptr;
+    QualType argType = Arg->getType();
+    CharUnits argAlignment;
+    if (EmitRValueAgg) {
+      argRV = CGF.EmitAnyExprToTemp(Arg);
+      argAddr = argRV.getAggregateAddr(); // must be alloca
+      argAlignment = CharUnits::fromQuantity(cast<AllocaInst>(argAddr)->getAlignment());
+      argLV = LValue::MakeAddr(argAddr, ParamTy, argAlignment, CGF.getContext());
+    } else {
+      argLV = CGF.EmitLValue(Arg);
+      if (argLV.isSimple())
+        argAddr = argLV.getAddress();
+      argType = argLV.getType();  // TBD: Can this be different than Arg->getType()?
+      argAlignment = argLV.getAlignment();
+    }
+    // After emit Arg, we must update the argList[i],
+    // otherwise we get double emit of the expression.
+
     // create temp Var
     VarDecl *tmpArg =
         VarDecl::Create(CGF.getContext(), const_cast<FunctionDecl *>(FD),
@@ -7065,17 +7099,26 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
                         CGF.getContext().getTrivialTypeSourceInfo(ParamTy),
                         StorageClass::SC_Auto);
 
+    bool isEmptyAggregate = false;
+    if (isAggregateType) {
+      DXASSERT(argAddr, "should be RV or simple LV");
+      llvm::Type *ElTy = argAddr->getType()->getPointerElementType();
+      while (ElTy->isArrayTy())
+        ElTy = ElTy->getArrayElementType();
+      if (llvm::StructType *ST = dyn_cast<StructType>(ElTy)) {
+        DxilStructAnnotation *SA = m_pHLModule->GetTypeSystem().GetStructAnnotation(ST);
+        isEmptyAggregate = SA && SA->IsEmptyStruct();
+      }
+    }
+
     // Aggregate type will be indirect param convert to pointer type.
     // So don't update to ReferenceType, use RValue for it.
-    bool isAggregateType = (ParamTy->isArrayType() || ParamTy->isRecordType()) &&
-      !hlsl::IsHLSLVecMatType(ParamTy);
-
     const DeclRefExpr *tmpRef = DeclRefExpr::Create(
         CGF.getContext(), NestedNameSpecifierLoc(), SourceLocation(), tmpArg,
         /*enclosing*/ false, tmpArg->getLocation(), ParamTy,
-        isAggregateType ? VK_RValue : VK_LValue);
+        (isAggregateType || isObject) ? VK_RValue : VK_LValue);
 
-    // update the arg
+    // must update the arg, since we did emit Arg, else we get double emit.
     argList[i] = tmpRef;
 
     // create alloc for the tmp arg
@@ -7090,7 +7133,12 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
     // add it to local decl map
     TmpArgMap(tmpArg, tmpArgAddr);
 
-    LValue tmpLV = LValue::MakeAddr(tmpArgAddr, ParamTy, argLV.getAlignment(),
+    // If param is empty, copy in/out will just create problems.
+    // No copy will result in undef, which is fine.
+    if (isEmptyAggregate)
+      continue;
+
+    LValue tmpLV = LValue::MakeAddr(tmpArgAddr, ParamTy, argAlignment,
                                     CGF.getContext());
 
     // save for cast after call
@@ -7099,22 +7147,18 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
       castArgList.emplace_back(argLV);
     }
 
-    bool isObject = dxilutil::IsHLSLObjectType(
-        tmpArgAddr->getType()->getPointerElementType());
-
     // cast before the call
     if (Param->isModifierIn() &&
         // Don't copy object
         !isObject) {
       QualType ArgTy = Arg->getType();
       Value *outVal = nullptr;
-      bool isAggregateTy = ParamTy->isAggregateType() && !IsHLSLVecMatType(ParamTy);
-      if (!isAggregateTy) {
+      if (!isAggregateType) {
         if (!IsHLSLMatType(ParamTy)) {
           RValue outRVal = CGF.EmitLoadOfLValue(argLV, SourceLocation());
           outVal = outRVal.getScalarVal();
         } else {
-          Value *argAddr = argLV.getAddress();
+          DXASSERT(argAddr, "should be RV or simple LV");
           outVal = EmitHLSLMatrixLoad(CGF, argAddr, ArgTy);
         }
 
@@ -7124,15 +7168,16 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
           EmitHLSLMatrixStore(CGF, castVal, tmpArgAddr, ParamTy);
         }
         else {
-          Value *castVal = ConvertScalarOrVector(CGF, outVal, argLV.getType(), tmpLV.getType());
-          castVal = CGF.EmitToMemory(castVal, tmpLV.getType());
+          Value *castVal = ConvertScalarOrVector(CGF, outVal, argType, ParamTy);
+          castVal = CGF.EmitToMemory(castVal, ParamTy);
           CGF.Builder.CreateStore(castVal, tmpArgAddr);
         }
       } else {
+        DXASSERT(argAddr, "should be RV or simple LV");
         SmallVector<Value *, 4> idxList;
-        EmitHLSLAggregateCopy(CGF, argLV.getAddress(), tmpLV.getAddress(),
+        EmitHLSLAggregateCopy(CGF, argAddr, tmpArgAddr,
                               idxList, ArgTy, ParamTy,
-                              argLV.getAddress()->getType());
+                              argAddr->getType());
       }
     }
   }
@@ -7151,13 +7196,12 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionCopyBack(
     
     Value *outVal = nullptr;
 
-    bool isAggrageteTy = ArgTy->isAggregateType();
-    isAggrageteTy &= !IsHLSLVecMatType(ArgTy);
+    bool isAggregateTy = hlsl::IsHLSLAggregateType(ArgTy);
 
     bool isObject = dxilutil::IsHLSLObjectType(
        tmpArgAddr->getType()->getPointerElementType());
     if (!isObject) {
-      if (!isAggrageteTy) {
+      if (!isAggregateTy) {
         if (!IsHLSLMatType(ParamTy))
           outVal = CGF.Builder.CreateLoad(tmpArgAddr);
         else

+ 1 - 6
tools/clang/lib/CodeGen/CodeGenModule.cpp

@@ -1811,11 +1811,6 @@ CodeGenModule::GetOrCreateLLVMGlobal(StringRef MangledName,
 
     // Make sure the result is of the correct type.
     if (Entry->getType()->getAddressSpace() != Ty->getAddressSpace()) {
-      // HLSL Change Begins
-      // TODO: do we put address space in type?
-      if (LangOpts.HLSL) return Entry;
-      else
-      // HLSL Change Ends
       return llvm::ConstantExpr::getAddrSpaceCast(Entry, Ty);
     }
 
@@ -1869,7 +1864,7 @@ CodeGenModule::GetOrCreateLLVMGlobal(StringRef MangledName,
       GV->setSection(".cp.rodata");
   }
 
-  if (AddrSpace != Ty->getAddressSpace() && !LangOpts.HLSL) // HLSL Change -TODO: do we put address space in type?
+  if (AddrSpace != Ty->getAddressSpace())
     return llvm::ConstantExpr::getAddrSpaceCast(GV, Ty);
 
 

+ 1 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -6813,7 +6813,7 @@ SPIRVEmitter::processIntrinsicInterlockedMethod(const CallExpr *expr,
   const uint32_t zero = theBuilder.getConstantUint32(0);
   const uint32_t scope = theBuilder.getConstantUint32(1); // Device
   const auto *dest = expr->getArg(0);
-  const auto baseType = dest->getType();
+  const auto baseType = dest->getType()->getCanonicalTypeUnqualified();
 
   if (!baseType->isIntegerType()) {
     emitError("can only perform atomic operations on scalar integer values",

+ 4 - 1
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -4639,7 +4639,7 @@ public:
 
     // Change return type to rvalue reference type for aggregate types
     QualType retTy = parameterTypes[0];
-    if (retTy->isAggregateType() && !IsHLSLVecMatType(retTy))
+    if (hlsl::IsHLSLAggregateType(retTy))
       parameterTypes[0] = m_context->getRValueReferenceType(retTy);
 
     // Create a new specialization.
@@ -10806,6 +10806,9 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
   case AttributeList::AT_HLSLGroupShared:
     declAttr = ::new (S.Context) HLSLGroupSharedAttr(A.getRange(), S.Context,
       A.getAttributeSpellingListIndex());
+    if (VarDecl *VD = dyn_cast<VarDecl>(D)) {
+      VD->setType(S.Context.getAddrSpaceQualType(VD->getType(), DXIL::kTGSMAddrSpace));
+    }
     break;
   case AttributeList::AT_HLSLUniform:
     declAttr = ::new (S.Context) HLSLUniformAttr(A.getRange(), S.Context,

+ 2 - 1
tools/clang/test/CodeGenHLSL/RValSubscript.hlsl

@@ -1,5 +1,7 @@
 // RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
 
+// CHECK: alloca [16 x i32]
+
 // For b4[2]
 // CHECK: cbufferLoadLegacy
 // CHECK: i32 5)
@@ -47,7 +49,6 @@
 // CHECK: fcmp fast oeq
 // CHECK: fcmp fast oeq
 // CHECK: fcmp fast oeq
-// CHECK: alloca [16 x i32]
 
 
 float4x4 xt;

+ 15 - 0
tools/clang/test/CodeGenHLSL/declarations/functions/inout_derived_struct_no_crash.hlsl

@@ -0,0 +1,15 @@
+// RUN: %dxc -E main -T vs_6_2 %s | FileCheck %s
+
+// Regression test for GitHub #1929, where we used the C++ definition
+// of an aggregate type and failed to match derived structs.
+
+// CHECK: ret void
+
+struct Base {};
+struct Derived : Base {};
+void f(inout Derived d) {}
+void main()
+{
+    Derived d;
+    f(d);
+}

+ 28 - 0
tools/clang/test/CodeGenHLSL/quick-test/addrspacecast.hlsl

@@ -0,0 +1,28 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// Make sure generate addrspacecast.
+// CHECK: addrspacecast ([6 x float] addrspace(3)*
+
+struct ST
+{
+  float3 a; // center
+  float3 b; // half extents
+
+  void func(float3 x, float3 y)
+  {
+    a = x + y;
+    b = x * y;
+  }
+};
+
+groupshared ST myST[2];
+StructuredBuffer<ST> buf0;
+float3 a;
+float3 b;
+RWBuffer<float3> buf1;
+[numthreads(8,8,1)]
+void main() {
+  myST[0] = buf0[0];
+  myST[0].func(a, b);
+  buf1[0] = myST[0].b;
+}

+ 29 - 0
tools/clang/test/CodeGenHLSL/quick-test/empty_struct3.hlsl

@@ -0,0 +1,29 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// Make sure nested empty struct works.  Also test related paths such as
+// derived, multi-dim array in constant buffer, and argument passing.
+// CHECK: main
+
+struct KillerStruct {};
+
+struct InnerStruct {
+  KillerStruct s;
+};
+
+struct OuterStruct {
+  InnerStruct s;
+};
+
+class Derived : OuterStruct {
+  InnerStruct s2;
+};
+
+cbuffer Params_cbuffer : register(b0) {
+  Derived constants[2][3];
+};
+
+float4 foo(Derived s) { return (float4)0; }
+
+float4 main() : SV_POSITION {
+  return foo(constants[1][2]);
+}

+ 50 - 0
tools/clang/test/CodeGenHLSL/quick-test/groupshared-base-cast.hlsl

@@ -0,0 +1,50 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// This tests cast of derived to base when derived is groupshared.
+// Different use cases can hit different code paths, hence the variety of
+// uses here:
+//    - calling base method
+//    - vector element assignment on base member
+//    - casting to base and passing to function
+// The barrier and write to RWBuf prevents optimizations from eliminating
+// groupshared use, considering this dead-code, or detecting a race condition.
+
+// CHECK: @[[gs0:.+]] = addrspace(3) global i32 undef
+// CHECK: @[[gs1:.+]] = addrspace(3) global i32 undef
+// CHECK: @[[gs2:.+]] = addrspace(3) global i32 undef
+// CHECK: store i32 1, i32 addrspace(3)* @[[gs0]], align 4
+// CHECK: store i32 2, i32 addrspace(3)* @[[gs1]], align 4
+// CHECK: store i32 3, i32 addrspace(3)* @[[gs2]], align 4
+
+// CHECK: %[[l0:[^ ]+]] = load i32, i32 addrspace(3)* @[[gs0]], align 4
+// CHECK: %[[l1:[^ ]+]] = load i32, i32 addrspace(3)* @[[gs1]], align 4
+// CHECK: %[[l2:[^ ]+]] = load i32, i32 addrspace(3)* @[[gs2]], align 4
+// CHECK: call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %{{.+}}, i32 %{{.+}}, i32 undef, i32 %[[l0]], i32 %[[l1]], i32 %[[l2]], i32 undef, i8 7)
+
+
+class Base {
+  uint3 u;
+  void set_u_y(uint value) { u.y = value; }
+};
+class Derived : Base {
+  float bar;
+};
+
+groupshared Derived gs_derived;
+RWByteAddressBuffer RWBuf;
+
+void UpdateBase_z(inout Base b, uint value) {
+  b.u.z = value;
+}
+
+[numthreads(2, 1, 1)]
+void main(uint3 groupThreadID: SV_GroupThreadID) {
+  if (groupThreadID.x == 0) {
+    gs_derived.u.x = 1;
+    gs_derived.set_u_y(2);
+    UpdateBase_z((Base)gs_derived, 3);
+  }
+  GroupMemoryBarrierWithGroupSync();
+  uint addr = groupThreadID.x * 4;
+  RWBuf.Store3(addr, gs_derived.u);
+}

+ 47 - 0
tools/clang/test/CodeGenHLSL/quick-test/groupshared-member-matrix-subscript-col.hlsl

@@ -0,0 +1,47 @@
+// RUN: %dxc -E main -T cs_6_0 -Zpc %s | FileCheck %s
+
+// CHECK: %[[cb0:[^ ]+]] = call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %{{.*}}, i32 0)
+// CHECK: %[[cb0x:[^ ]+]] = extractvalue %dx.types.CBufRet.f32 %[[cb0]], 0
+// CHECK: store float %[[cb0x]], float addrspace(3)* getelementptr inbounds ([4 x float], [4 x float] addrspace(3)* @[[obj:[^,]+]], i32 0, i32 0), align 4
+
+// CHECK: %[[cb1:[^ ]+]] = call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %{{.*}}, i32 1)
+// CHECK: %[[cb1x:[^ ]+]] = extractvalue %dx.types.CBufRet.f32 %[[cb1]], 0
+// CHECK: store float %[[cb1x]], float addrspace(3)* getelementptr inbounds ([4 x float], [4 x float] addrspace(3)* @[[obj]], i32 0, i32 1), align 4
+
+// CHECK: %[[_25:[^ ]+]] = getelementptr [4 x float], [4 x float] addrspace(3)* @[[obj]], i32 0, i32 %{{.+}}
+// CHECK: %[[_26:[^ ]+]] = load float, float addrspace(3)* %[[_25]], align 4
+// CHECK: %[[_27:[^ ]+]] = getelementptr [4 x float], [4 x float] addrspace(3)* @[[obj]], i32 0, i32 %{{.+}}
+// CHECK: %[[_28:[^ ]+]] = load float, float addrspace(3)* %[[_27]], align 4
+
+// CHECK: %[[_33:[^ ]+]] = bitcast float %[[_26]] to i32
+// CHECK: %[[_34:[^ ]+]] = bitcast float %[[_28]] to i32
+
+// CHECK: call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %{{[^,]+}}, i32 %{{.+}}, i32 undef, i32 %[[_33]], i32 %[[_34]], i32 undef, i32 undef, i8 3)
+
+float2 rows[2];
+
+void set_row(inout float2 row, uint i) {
+  row = rows[i];
+}
+
+class Obj {
+  float2x2 mat;
+  void set() {
+    set_row(mat[0], 0);
+    set_row(mat[1], 1);
+  }
+};
+
+RWByteAddressBuffer RWBuf;
+groupshared Obj obj;
+
+[numthreads(2, 1, 1)]
+void main(uint3 groupThreadID: SV_GroupThreadID) {
+  if (groupThreadID.x == 0) {
+    obj.set();
+  }
+  GroupMemoryBarrierWithGroupSync();
+  float2 row = obj.mat[groupThreadID.x];
+  uint addr = groupThreadID.x * 8;
+  RWBuf.Store2(addr, uint2(asuint(row.x), asuint(row.y)));
+}

+ 46 - 0
tools/clang/test/CodeGenHLSL/quick-test/groupshared-member-matrix-subscript.hlsl

@@ -0,0 +1,46 @@
+// RUN: %dxc -E main -T cs_6_0 -Zpr %s | FileCheck %s
+
+// CHECK: %[[cb0:[^ ]+]] = call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %{{.*}}, i32 0)
+// CHECK: %[[cb0x:[^ ]+]] = extractvalue %dx.types.CBufRet.f32 %[[cb0]], 0
+// CHECK: %[[cb0y:[^ ]+]] = extractvalue %dx.types.CBufRet.f32 %[[cb0]], 1
+
+// CHECK: store float %[[cb0x]], float addrspace(3)* getelementptr inbounds ([4 x float], [4 x float] addrspace(3)* @[[obj:[^,]+]], i32 0, i32 0), align 4
+// CHECK: store float %[[cb0y]], float addrspace(3)* getelementptr inbounds ([4 x float], [4 x float] addrspace(3)* @[[obj]], i32 0, i32 1), align 4
+
+// CHECK: %[[_25:[^ ]+]] = getelementptr [4 x float], [4 x float] addrspace(3)* @[[obj]], i32 0, i32 %{{.+}}
+// CHECK: %[[_26:[^ ]+]] = load float, float addrspace(3)* %[[_25]], align 4
+// CHECK: %[[_27:[^ ]+]] = getelementptr [4 x float], [4 x float] addrspace(3)* @[[obj]], i32 0, i32 %{{.+}}
+// CHECK: %[[_28:[^ ]+]] = load float, float addrspace(3)* %[[_27]], align 4
+
+// CHECK: %[[_33:[^ ]+]] = bitcast float %[[_26]] to i32
+// CHECK: %[[_34:[^ ]+]] = bitcast float %[[_28]] to i32
+
+// CHECK: call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %{{.*}}, i32 %{{.+}}, i32 undef, i32 %[[_33]], i32 %[[_34]], i32 undef, i32 undef, i8 3)
+
+float2 rows[2];
+
+void set_row(inout float2 row, uint i) {
+  row = rows[i];
+}
+
+class Obj {
+  float2x2 mat;
+  void set() {
+    set_row(mat[0], 0);
+    set_row(mat[1], 1);
+  }
+};
+
+RWByteAddressBuffer RWBuf;
+groupshared Obj obj;
+
+[numthreads(2, 1, 1)]
+void main(uint3 groupThreadID: SV_GroupThreadID) {
+  if (groupThreadID.x == 0) {
+    obj.set();
+  }
+  GroupMemoryBarrierWithGroupSync();
+  float2 row = obj.mat[groupThreadID.x];
+  uint addr = groupThreadID.x * 8;
+  RWBuf.Store2(addr, uint2(asuint(row.x), asuint(row.y)));
+}

+ 45 - 0
tools/clang/test/CodeGenHLSL/quick-test/groupshared-member-matrix-subscript2.hlsl

@@ -0,0 +1,45 @@
+// RUN: %dxc -E main -T cs_6_0 -Zpr %s | FileCheck %s
+
+// Make sure non-const gep/addrspace cast in codegen is translated properly
+
+// CHECK: @[[obj:[^,]+]] = addrspace(3) global [8 x float] undef
+
+// CHECK: %[[_6:[^ ]+]] = getelementptr [8 x float], [8 x float] addrspace(3)* @[[obj]], i32 0, i32 %{{.+}}
+// CHECK: store float %{{.+}}, float addrspace(3)* %[[_6]], align 16
+
+// Skip next three stores to get to loads
+// CHECK: store
+// CHECK: store
+// CHECK: store
+
+// CHECK: %[[_23:[^ ]+]] = getelementptr [8 x float], [8 x float] addrspace(3)* @[[obj]], i32 0, i32 %{{.+}}
+// CHECK: %{{.+}} = load float, float addrspace(3)* %[[_23]], align 8
+
+
+float4 rows[2];
+
+void set_row(inout float2 row, uint i) {
+  row = rows[i];
+}
+
+class Obj {
+  float2x2 mat;
+  void set() {
+    set_row(mat[0], 0);
+    set_row(mat[1], 1);
+  }
+};
+
+RWByteAddressBuffer RWBuf;
+
+// Dynamic index array to generate non-const gep/addrspace cast
+groupshared Obj obj[2];
+
+[numthreads(2, 1, 1)]
+void main(uint3 groupThreadID: SV_GroupThreadID) {
+  obj[groupThreadID.x].set();
+  GroupMemoryBarrierWithGroupSync();
+  float2 row = obj[1 - groupThreadID.x].mat[groupThreadID.x];
+  uint addr = groupThreadID.x * 8;
+  RWBuf.Store2(addr, uint2(asuint(row.x), asuint(row.y)));
+}

+ 12 - 0
tools/clang/test/CodeGenHLSL/quick-test/mat_init_splat.hlsl

@@ -0,0 +1,12 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float 1.000000e+00)
+
+struct MyStruct {
+  float3x3 mat;
+};
+
+float main() : OUT {
+  MyStruct st = (MyStruct)1;
+  return st.mat[0].x;
+}

+ 1 - 1
tools/clang/test/CodeGenHLSL/quick-test/static_global_copy3.hlsl

@@ -17,7 +17,7 @@ A a;
 
 static A a2;
 
-void set(A aa) {
+void set(out A aa) {
    aa = a;
 }
 

+ 28 - 0
tools/clang/test/CodeGenHLSL/quick-test/struct_param_in_mod.hlsl

@@ -0,0 +1,28 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Verify that modificaion of in-only struct parameter does not modify
+// the value passed in by the caller.
+
+// CHECK-DAG: [[f:%[^ ]*]] = call float @dx.op.loadInput.f32(i32 4, i32 1, i32 0, i8 0,
+// CHECK-DAG: [[p:%[^ ]*]] = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0,
+// CHECK-DAG: [[o1:%[^ ]*]] = fmul fast float [[p]], [[f]]
+// CHECK-DAG: [[ret:%[^ ]*]] = fadd fast float [[o1]], [[p]]
+// CHECK-DAG: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float [[ret]])
+
+struct PayloadStruct {
+  float Color;
+};
+
+PayloadStruct MulPayload(in PayloadStruct Payload, in float x)
+{
+  Payload.Color *= x;
+  return Payload;
+}
+
+void main(PayloadStruct p : Payload,
+          float f : INPUT,
+          out PayloadStruct o : SV_Target) {
+
+  o = MulPayload(p, f);
+  o.Color += p.Color;
+}

+ 38 - 0
tools/clang/test/CodeGenHLSL/quick-test/struct_param_in_mod2.hlsl

@@ -0,0 +1,38 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Verify that passing struct result of call as arg to another call does not
+// generate extra call.
+
+// CHECK-DAG: [[f:%[^ ]*]] = call float @dx.op.loadInput.f32(i32 4, i32 1, i32 0, i8 0,
+// CHECK-DAG: [[p:%[^ ]*]] = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0,
+// CHECK-DAG: [[factor:%[^ ]*]] = fmul fast float [[f]], 2.000000e+00
+// CHECK-DAG: [[factor2:%[^ ]*]] = fadd fast float [[factor]], 1.000000e+00
+// CHECK-DAG: [[ret:%[^ ]*]] = fmul fast float [[factor2]], [[p]]
+// CHECK-DAG: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float [[ret]])
+
+struct PayloadStruct {
+  float Color;
+};
+
+static float factor = 1.0;
+
+PayloadStruct MulPayload(in PayloadStruct Payload)
+{
+  Payload.Color *= factor;
+  factor += 1.0;
+  return Payload;
+}
+
+PayloadStruct AddPayload(in PayloadStruct Payload0, in PayloadStruct Payload1)
+{
+  Payload0.Color += Payload1.Color;
+  return Payload0;
+}
+
+void main(PayloadStruct p : Payload,
+	  	  float f : INPUT,
+          out PayloadStruct OutputPayload : SV_Target) {
+  factor = f;
+  OutputPayload = AddPayload(MulPayload(p),
+                             MulPayload(p));
+}

+ 8 - 0
tools/clang/test/CodeGenHLSL/quick-test/unused_matrix_input_regression.hlsl

@@ -0,0 +1,8 @@
+// RUN: %dxc /T vs_6_2 /E main %s | FileCheck %s
+
+// Regression test for GitHub #1947, where matrix input parameters were expected to have
+// exactly one use in HLSignatureLower, instead of zero or one, leading to a crash.
+
+// CHECK: ret void
+
+void main(int2x2 mat : IN) {}

+ 2 - 1
tools/clang/test/CodeGenHLSL/shader-compat-suite/lib_arg_flatten/lib_arg_flatten.hlsl

@@ -1,7 +1,8 @@
 // RUN: %dxc -T lib_6_3 -auto-binding-space 11 -default-linkage external %s | FileCheck %s
 
 // Make sure function call on external function has correct type.
-// CHECK: call float @"\01?test_extern@@YAMUT@@Y01U1@U1@AIAV?$matrix@M$01$01@@@Z"(%struct.T* {{.*}}, [2 x %struct.T]* {{.*}}, %struct.T* nonnull {{.*}}, %class.matrix.float.2.2* dereferenceable(16) {{.*}})
+// CHECK: call float @"\01?test_extern@@YAMUT@@Y01U1@U1@AIAV?$matrix@M$01$01@@@Z"(%struct.T* {{.*}}, [2 x %struct.T]* {{.*}}, %struct.T* {{.*}}, %class.matrix.float.2.2* dereferenceable(16) {{.*}})
+
 struct T {
   float a;
   float b;

+ 1 - 1
tools/clang/test/CodeGenHLSL/shader-compat-suite/lib_arg_flatten/lib_arg_flatten3.hlsl

@@ -2,7 +2,7 @@
 
 // Make sure function call on external function has correct type.
 
-// CHECK: call float @"\01?test_extern@@YAMUT@@@Z"(%struct.T* nonnull %tmp) #2
+// CHECK: call float @"\01?test_extern@@YAMUT@@@Z"(%struct.T* nonnull {{.*}}) #2
 
 struct T {
   float a;

+ 2 - 2
tools/clang/test/CodeGenHLSL/shader-compat-suite/lib_arg_flatten/lib_empty_struct_arg.hlsl

@@ -1,7 +1,7 @@
 // RUN: %dxc -T lib_6_3 -auto-binding-space 11 -default-linkage external %s | FileCheck %s
 
-// Make sure empty struct arg works.
-// CHECK: call float @"\01?test@@YAMUT@@@Z"(%struct.T* %t)
+// Make sure empty struct arg is replaced with undef.
+// CHECK: call float @"\01?test@@YAMUT@@@Z"(%struct.T* undef)
 
 struct T {
 };

+ 3 - 6
tools/clang/test/CodeGenHLSL/share_mem_dbg.hlsl

@@ -12,12 +12,9 @@
 // CHECK: !dx.source.mainFileName
 // CHECK: !dx.source.args
 
-// CHECK: DIGlobalVariable(name: "dataC.1.0"
-// CHECK: DIDerivedType(tag: DW_TAG_member, name: ".1.0"
-// CHECK: DIGlobalVariable(name: "dataC.1.1"
-// CHECK: DIDerivedType(tag: DW_TAG_member, name: ".1.1"
-// CHECK: DIGlobalVariable(name: "dataC.0
-// CHECK: DIDerivedType(tag: DW_TAG_member, name: ".0"
+// CHECK: DIGlobalVariable(name: "dataC"
+// CHECK: DIDerivedType(tag: DW_TAG_member, name: "d"
+// CHECK: DIDerivedType(tag: DW_TAG_member, name: "b"
 
 // Make sure source info contents exist.
 // CHECK: share_mem_dbg.hlsl", !"// RUN: %dxc

+ 1 - 1
tools/clang/test/CodeGenHLSL/staticGlobals.hlsl

@@ -6,9 +6,9 @@
 // CHECK: [3 x float] [float 6.000000e+00, float 0.000000e+00, float 0.000000e+00]
 // CHECK: [3 x float] [float 7.000000e+00, float 0.000000e+00, float 0.000000e+00]
 // CHECK: [3 x float] [float 8.000000e+00, float 0.000000e+00, float 0.000000e+00]
+// CHECK: [4 x float] [float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00]
 // CHECK: [16 x float] [float 1.500000e+01, float 1.500000e+01, float 1.500000e+01, float 1.500000e+01, float 1.600000e+01, float 1.600000e+01, float 1.600000e+01, float 1.600000e+01, float 1.700000e+01, float 1.700000e+01, float 1.700000e+01, float 1.700000e+01, float 1.800000e+01, float 1.800000e+01, float 1.800000e+01, float 1.800000e+01]
 // CHECK: [16 x float] [float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00]
-// CHECK: [4 x float] [float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00]
 // CHECK: [16 x float] [float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01, float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01, float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01, float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01]
 
 static float4 f0 = {5,6,7,8};

+ 4 - 1
tools/clang/test/CodeGenHLSL/static_matrix.hlsl

@@ -1,6 +1,9 @@
 // RUN: %dxc -E not_main -T ps_6_0 %s | FileCheck %s
 
-// Make sure internal global is removed.
+// Tests that the static-global-to-alloca pass
+// will turn a static matrix lowered into a static vector
+// into a local variable, even in the absence of the GVN pass.
+
 // CHECK-NOT:  = internal
 
 static float2x2 a;

+ 4 - 5
tools/clang/tools/dxcompiler/dxcdisassembler.cpp

@@ -17,7 +17,7 @@
 #include "dxc/DXIL/DxilShaderModel.h"
 #include "dxc/DXIL/DxilModule.h"
 #include "dxc/DXIL/DxilResource.h"
-#include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/DXIL/DxilConstants.h"
 #include "dxc/DXIL/DxilOperations.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -808,10 +808,9 @@ void PrintFieldLayout(llvm::Type *Ty, DxilFieldAnnotation &annotation,
       }
       if (EltTy->isVectorTy()) {
         EltTy = EltTy->getVectorElementType();
-      } else if (EltTy->isStructTy()) {
-        unsigned col, row;
-        EltTy = HLMatrixLower::GetMatrixInfo(EltTy, col, row);
-      }
+      } else if (EltTy->isStructTy())
+        EltTy = HLMatrixType::cast(EltTy).getElementTypeForReg();
+
       if (arrayLevel == 1)
         arraySize = 0;
     }