Просмотр исходного кода

Update the HLMatrixLower pass to make it deterministic and correct

Updates the HLMatrixLower pass logic to eliminate the nondeterminism and randomly bogus codegen that came from iterating over a pointer-keyed map and modifying that map during the iteration. The new approach makes a single pass over every instruction consuming or producing matrices, replacing each one with its vector equivalent. Any consumed matrix is replaced by a temporary mat-to-vec translation stub, and any formerly produced matrix is emitted as a temporary vec-to-mat translation stub. Stubs are removed once both ends of a consumer-producer dependency have been lowered.
Tristan Labelle 6 лет назад
Родитель
Сommit
32fe0936b7

+ 1 - 0
include/dxc/DXIL/DxilUtil.h

@@ -102,6 +102,7 @@ namespace dxilutil {
   std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::MemoryBuffer *MB,
     llvm::LLVMContext &Ctx, std::string &DiagStr);
   void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI, void *Context);
+  bool IsIntegerOrFloatingPointType(llvm::Type *Ty);
   // Returns true if type contains HLSL Object type (resource)
   bool ContainsHLSLObjectType(llvm::Type *Ty);
   bool IsHLSLObjectType(llvm::Type *Ty);

+ 1 - 31
include/dxc/HLSL/HLMatrixLowerHelper.h

@@ -26,41 +26,11 @@ class DxilFieldAnnotation;
 class DxilTypeSystem;
 
 namespace HLMatrixLower {
-// TODO: use type annotation.
-DxilFieldAnnotation *FindAnnotationFromMatUser(llvm::Value *Mat,
-                                               DxilTypeSystem &typeSys);
-// Translate matrix type to vector type.
-llvm::Type *LowerMatrixType(llvm::Type *Ty, bool forMem = false);
-// TODO: use type annotation.
-llvm::Type *GetMatrixInfo(llvm::Type *Ty, unsigned &col, unsigned &row);
-// TODO: use type annotation.
-bool IsMatrixArrayPointer(llvm::Type *Ty);
-// Translate matrix array pointer type to vector array pointer type.
-llvm::Type *LowerMatrixArrayPointer(llvm::Type *Ty, bool forMem = false);
 
-llvm::Value *BuildVector(llvm::Type *EltTy, unsigned size,
+llvm::Value *BuildVector(llvm::Type *EltTy,
                          llvm::ArrayRef<llvm::Value *> elts,
                          llvm::IRBuilder<> &Builder);
 
-llvm::Value *VecMatrixMemToReg(llvm::Value *VecVal, llvm::Type *MatType,
-                               llvm::IRBuilder<> &Builder);
-llvm::Value *VecMatrixRegToMem(llvm::Value* VecVal, llvm::Type *MatType,
-                               llvm::IRBuilder<> &Builder);
-llvm::Instruction *CreateVecMatrixLoad(llvm::Value *VecPtr,
-                                       llvm::Type *MatType, llvm::IRBuilder<> &Builder);
-llvm::Instruction *CreateVecMatrixStore(llvm::Value* VecVal, llvm::Value *VecPtr,
-                                        llvm::Type *MatType, llvm::IRBuilder<> &Builder);
-
-// For case like mat[i][j].
-// IdxList is [i][0], [i][1], [i][2],[i][3].
-// Idx is j.
-// return [i][j] not mat[i][j] because resource ptr and temp ptr need different
-// code gen.
-llvm::Value *
-LowerGEPOnMatIndexListToIndex(llvm::GetElementPtrInst *GEP,
-                              llvm::ArrayRef<llvm::Value *> IdxList);
-unsigned GetColMajorIdx(unsigned r, unsigned c, unsigned row);
-unsigned GetRowMajorIdx(unsigned r, unsigned c, unsigned col);
 } // namespace HLMatrixLower
 
 } // namespace hlsl

+ 97 - 0
include/dxc/HLSL/HLMatrixType.h

@@ -0,0 +1,97 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLMatrixType.h                                                            //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+
+#include "llvm/IR/IRBuilder.h"
+
+namespace llvm {
+  template<typename T>
+  class ArrayRef;
+  class Type;
+  class Value;
+  class Constant;
+  class StructType;
+  class VectorType;
+  class StoreInst;
+}
+
+namespace hlsl {
+
+class DxilFieldAnnotation;
+class DxilTypeSystem;
+
+// A high-level matrix type in LLVM IR.
+//
+// Matrices are represented by an llvm struct type of the following form:
+// { [RowCount x <ColCount x RegReprTy>] }
+// Note that the element type is always in its register representation (ie bools are i1s).
+// This allows preserving the original type and is okay since matrix types are only
+// manipulated in an opaque way, through intrinsics.
+//
+// During matrix lowering, matrices are converted to vectors of the following form:
+// <RowCount*ColCount x Ty>
+// At this point, register vs memory representation starts to matter and we have to
+// imitate the codegen for scalar and vector bools: i1s when in llvm registers,
+// and i32s when in memory (allocas, pointers, or in structs/lists, which are always in memory).
+//
+// This class is designed to resemble a llvm::Type-derived class.
+class HLMatrixType
+{
+public:
+  static constexpr const char* StructNamePrefix = "class.matrix.";
+
+  HLMatrixType() : RegReprElemTy(nullptr), NumRows(0), NumColumns(0) {} // Invalid object: null element type, 0x0.
+  HLMatrixType(llvm::Type *RegReprElemTy, unsigned NumRows, unsigned NumColumns);
+
+  // Default construction yields an invalid object so that dyn_cast can signal failure by value.
+  // This conversion is true only when the object holds a non-null element type.
+  operator bool() const { return RegReprElemTy != nullptr; }
+
+  llvm::Type *getElementType(bool MemRepr) const; // The ForReg/ForMem wrappers below pass MemRepr = false/true.
+  llvm::Type *getElementTypeForReg() const { return getElementType(false); }
+  llvm::Type *getElementTypeForMem() const { return getElementType(true); }
+  unsigned getNumRows() const { return NumRows; }
+  unsigned getNumColumns() const { return NumColumns; }
+  unsigned getNumElements() const { return NumRows * NumColumns; }
+  unsigned getRowMajorIndex(unsigned RowIdx, unsigned ColIdx) const;
+  unsigned getColumnMajorIndex(unsigned RowIdx, unsigned ColIdx) const;
+  static unsigned getRowMajorIndex(unsigned RowIdx, unsigned ColIdx, unsigned NumRows, unsigned NumColumns);
+  static unsigned getColumnMajorIndex(unsigned RowIdx, unsigned ColIdx, unsigned NumRows, unsigned NumColumns);
+
+  llvm::VectorType *getLoweredVectorType(bool MemRepr) const; // The ForReg/ForMem wrappers below pass MemRepr = false/true.
+  llvm::VectorType *getLoweredVectorTypeForReg() const { return getLoweredVectorType(false); }
+  llvm::VectorType *getLoweredVectorTypeForMem() const { return getLoweredVectorType(true); }
+
+  llvm::Value *emitLoweredMemToReg(llvm::Value *Val, llvm::IRBuilder<> &Builder) const;
+  llvm::Value *emitLoweredRegToMem(llvm::Value *Val, llvm::IRBuilder<> &Builder) const;
+  llvm::Value *emitLoweredLoad(llvm::Value *Ptr, llvm::IRBuilder<> &Builder) const;
+  llvm::StoreInst *emitLoweredStore(llvm::Value *Val, llvm::Value *Ptr, llvm::IRBuilder<> &Builder) const;
+
+  llvm::Value *emitLoweredVectorRowToCol(llvm::Value *VecVal, llvm::IRBuilder<> &Builder) const;
+  llvm::Value *emitLoweredVectorColToRow(llvm::Value *VecVal, llvm::IRBuilder<> &Builder) const;
+
+  static bool isa(llvm::Type *Ty);
+  static bool isMatrixPtr(llvm::Type *Ty);
+  static bool isMatrixArray(llvm::Type *Ty);
+  static bool isMatrixArrayPtr(llvm::Type *Ty);
+  static bool isMatrixPtrOrArrayPtr(llvm::Type *Ty);
+  static bool isMatrixOrPtrOrArrayPtr(llvm::Type *Ty);
+
+  static llvm::Type *getLoweredType(llvm::Type *Ty, bool MemRepr = false);
+
+  static HLMatrixType cast(llvm::Type *Ty);
+  static HLMatrixType dyn_cast(llvm::Type *Ty);
+
+private:
+  llvm::Type *RegReprElemTy; // Element type in its register representation; nullptr marks an invalid object.
+  unsigned NumRows, NumColumns;
+};
+
+} // namespace hlsl

+ 4 - 0
lib/DXIL/DxilUtil.cpp

@@ -477,6 +477,10 @@ bool IsHLSLMatrixType(Type *Ty) {
   return false;
 }
 
+// Returns true when Ty is an integer type or a floating-point type.
+bool IsIntegerOrFloatingPointType(llvm::Type *Ty) {
+  if (Ty->isIntegerTy())
+    return true;
+  return Ty->isFloatingPointTy();
+}
+
 bool ContainsHLSLObjectType(llvm::Type *Ty) {
   // Unwrap pointer/array
   while (llvm::isa<llvm::PointerType>(Ty))

+ 3 - 0
lib/HLSL/CMakeLists.txt

@@ -23,7 +23,10 @@ add_llvm_library(LLVMHLSL
   DxilExportMap.cpp
   DxilValidation.cpp
   DxcOptimizer.cpp
+  HLMatrixBitcastLowerPass.cpp
   HLMatrixLowerPass.cpp
+  HLMatrixSubscriptUseReplacer.cpp
+  HLMatrixType.cpp
   HLModule.cpp
   HLOperations.cpp
   HLOperationLower.cpp

+ 5 - 3
lib/HLSL/DxilCondenseResources.cpp

@@ -17,7 +17,7 @@
 #include "dxc/DXIL/DxilTypeSystem.h"
 #include "dxc/DXIL/DxilInstructions.h"
 #include "dxc/HLSL/DxilSpanAllocator.h"
-#include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/DXIL/DxilUtil.h"
 #include "dxc/HLSL/HLModule.h"
 
@@ -1536,8 +1536,10 @@ Type *UpdateFieldTypeForLegacyLayout(Type *Ty, bool IsCBuf,
       return ArrayType::get(UpdatedTy, Ty->getArrayNumElements());
   } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     DXASSERT(annotation.HasMatrixAnnotation(), "must a matrix");
-    unsigned rows, cols;
-    Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, cols, rows);
+    HLMatrixType MatTy = HLMatrixType::cast(Ty);
+    unsigned rows = MatTy.getNumRows();
+    unsigned cols = MatTy.getNumColumns();
+    Type *EltTy = MatTy.getElementTypeForReg();
 
     // Get cols and rows from annotation.
     const DxilMatrixAnnotation &matrix = annotation.GetMatrixAnnotation();

+ 0 - 1
lib/HLSL/DxilGenerationPass.cpp

@@ -15,7 +15,6 @@
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/HLSL/HLOperations.h"
 #include "dxc/DXIL/DxilInstructions.h"
-#include "dxc/HLSL/HLMatrixLowerHelper.h"
 #include "dxc/HlslIntrinsicOp.h"
 #include "dxc/Support/Global.h"
 #include "dxc/DXIL/DxilTypeSystem.h"

+ 251 - 0
lib/HLSL/HLMatrixBitcastLowerPass.cpp

@@ -0,0 +1,251 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLMatrixBitcastLowerPass.cpp                                              //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "dxc/HLSL/HLMatrixLowerPass.h"
+#include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
+#include "dxc/DXIL/DxilUtil.h"
+#include "dxc/Support/Global.h"
+#include "dxc/DXIL/DxilOperations.h"
+#include "dxc/DXIL/DxilModule.h"
+#include "dxc/HLSL/DxilGenerationPass.h"
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Pass.h"
+#include <unordered_set>
+#include <vector>
+
+using namespace llvm;
+using namespace hlsl;
+using namespace hlsl::HLMatrixLower;
+
+// Matrix Bitcast lower.
+// After linking, lower matrix bitcast patterns like:
+//  %169 = bitcast [72 x float]* %0 to [6 x %class.matrix.float.4.3]*
+//  %conv.i = fptoui float %164 to i32
+//  %arrayidx.i = getelementptr inbounds [6 x %class.matrix.float.4.3], [6 x %class.matrix.float.4.3]* %169, i32 0, i32 %conv.i
+//  %170 = bitcast %class.matrix.float.4.3* %arrayidx.i to <12 x float>*
+
+namespace {
+
+// Translate matrix type to array type.
+// Maps a matrix type to a flat array of its register element type, with one
+// array slot per matrix element. Any non-matrix type is returned unchanged.
+Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
+  HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty);
+  if (!MatTy)
+    return Ty;
+
+  return ArrayType::get(MatTy.getElementTypeForReg(), MatTy.getNumElements());
+}
+
+// Maps a pointer to a (possibly nested) array of matrices to a pointer to a
+// single one-dimensional array of the matrix register element type.
+// The flat length is the product of every array dimension and the number of
+// matrix elements; the pointer address space is preserved.
+Type *LowerMatrixArrayPointerToOneDimArray(Type *Ty) {
+  const unsigned AddrSpace = Ty->getPointerAddressSpace();
+
+  // Peel the array levels, accumulating the total number of matrices.
+  Type *Inner = Ty->getPointerElementType();
+  unsigned FlatCount = 1;
+  for (; Inner->isArrayTy(); Inner = Inner->getArrayElementType())
+    FlatCount *= Inner->getArrayNumElements();
+
+  // What remains must be the matrix type itself.
+  HLMatrixType MatTy = HLMatrixType::cast(Inner);
+  FlatCount *= MatTy.getNumElements();
+
+  Type *FlatArrayTy = ArrayType::get(MatTy.getElementTypeForReg(), FlatCount);
+  return PointerType::get(FlatArrayTy, AddrSpace);
+}
+
+// Returns the flattened pointer type for a matrix-array pointer or a matrix
+// pointer, or nullptr when Ty is neither (i.e. there is nothing to lower).
+Type *TryLowerMatTy(Type *Ty) {
+  if (HLMatrixType::isMatrixArrayPtr(Ty))
+    return LowerMatrixArrayPointerToOneDimArray(Ty);
+
+  if (!isa<PointerType>(Ty))
+    return nullptr;
+
+  Type *Pointee = Ty->getPointerElementType();
+  if (!dxilutil::IsHLSLMatrixType(Pointee))
+    return nullptr;
+
+  Type *FlatTy = LowerMatrixTypeToOneDimArray(Pointee);
+  return PointerType::get(FlatTy, Ty->getPointerAddressSpace());
+}
+
+// Function pass that rewrites the users of bitcasts producing matrix (or
+// matrix array) pointers so that they address the bitcast's source operand
+// element by element instead.
+class MatrixBitcastLowerPass : public FunctionPass {
+
+public:
+  static char ID; // Pass identification, replacement for typeid
+  explicit MatrixBitcastLowerPass() : FunctionPass(ID) {}
+
+  const char *getPassName() const override { return "Matrix Bitcast lower"; }
+
+  // Returns true when at least one candidate bitcast was found in F.
+  bool runOnFunction(Function &F) override {
+    // Gather every bitcast whose result type is a lowerable matrix pointer.
+    std::unordered_set<BitCastInst*> candidates;
+    for (BasicBlock &Block : F) {
+      for (Instruction &Inst : Block) {
+        BitCastInst *Cast = dyn_cast<BitCastInst>(&Inst);
+        if (Cast && TryLowerMatTy(Cast->getType()))
+          candidates.insert(Cast);
+      }
+    }
+    // The result reflects the candidates found, even if some are dropped below.
+    const bool foundAny = !candidates.empty();
+
+    DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
+    // Remove bitcast which has CallInst user.
+    if (DM.GetShaderModel()->IsLib()) {
+      auto cur = candidates.begin();
+      while (cur != candidates.end()) {
+        if (hasCallUser(*cur))
+          cur = candidates.erase(cur);
+        else
+          ++cur;
+      }
+    }
+
+    // Lower matrix first.
+    for (BitCastInst *Cast : candidates)
+      lowerMatrix(Cast, Cast->getOperand(0));
+
+    return foundAny;
+  }
+private:
+  void lowerMatrix(Instruction *M, Value *A);
+  bool hasCallUser(Instruction *M);
+};
+
+}
+
+// Returns true if M, or anything reached from M through matrix-typed GEPs and
+// bitcasts, is used by a call instruction. Vector loads and stores are the
+// only other accepted users; anything else trips an assert.
+bool MatrixBitcastLowerPass::hasCallUser(Instruction *M) {
+  // Nothing is modified here, so the use list can be walked directly.
+  for (User *U : M->users()) {
+    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
+      Type *EltTy = GEP->getType()->getPointerElementType();
+      if (dxilutil::IsHLSLMatrixType(EltTy)) {
+        if (hasCallUser(GEP))
+          return true;
+      } else {
+        DXASSERT(0, "invalid GEP for matrix");
+      }
+    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
+      if (hasCallUser(BCI))
+        return true;
+    } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
+      if (!isa<VectorType>(LI->getType())) {
+        DXASSERT(0, "invalid load for matrix");
+      }
+    } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
+      Value *V = ST->getValueOperand();
+      if (!isa<VectorType>(V->getType())) {
+        // Was reported as an invalid *load*; this branch checks a store.
+        DXASSERT(0, "invalid store for matrix");
+      }
+    } else if (isa<CallInst>(U)) {
+      return true;
+    } else {
+      DXASSERT(0, "invalid use of matrix");
+    }
+  }
+  return false;
+}
+
+namespace {
+// Builds a pointer to element i relative to A.
+// When A is itself a GEP, a clone of it is emitted with i added to its last
+// index; otherwise an inbounds GEP {zeroIdx, i} on A is created.
+Value *CreateEltGEP(Value *A, unsigned i, Value *zeroIdx,
+                    IRBuilder<> &Builder) {
+  GetElementPtrInst *BaseGEP = dyn_cast<GetElementPtrInst>(A);
+  if (!BaseGEP)
+    return Builder.CreateInBoundsGEP(A, {zeroIdx, Builder.getInt32(i)});
+
+  // A should be gep oneDimArray, 0, index * matSize
+  // Here add eltIdx to index * matSize foreach elt.
+  Instruction *Clone = BaseGEP->clone();
+  const unsigned lastOperand = Clone->getNumOperands() - 1;
+  Value *Shifted =
+      Builder.CreateAdd(Clone->getOperand(lastOperand), Builder.getInt32(i));
+  Clone->setOperand(lastOperand, Shifted);
+  Builder.Insert(Clone);
+  return Clone;
+}
+} // namespace
+
+// Rewrites every user of the matrix-typed pointer M so that it addresses A
+// instead (A is expected to be the flattened one-dimensional array pointer
+// that M was cast from, or a GEP into it).
+//  - GEPs to a matrix have their last index scaled by the matrix element
+//    count and are re-emitted on A, then lowered recursively and erased.
+//  - Bitcasts are lowered recursively against the same A and erased.
+//  - Vector loads/stores are expanded into per-element loads/stores.
+// Any other user is reported through DXASSERT.
+void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
+  for (auto it = M->user_begin(); it != M->user_end();) {
+    // Advance first: the branches below erase U, which would otherwise
+    // invalidate the iterator.
+    User *U = *(it++);
+    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
+      Type *EltTy = GEP->getType()->getPointerElementType();
+      if (dxilutil::IsHLSLMatrixType(EltTy)) {
+        // Change gep matrixArray, 0, index
+        // into
+        //   gep oneDimArray, 0, index * matSize
+        IRBuilder<> Builder(GEP);
+        SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
+        DXASSERT(idxList.size() == 2,
+                 "else not one dim matrix array index to matrix");
+
+        HLMatrixType MatTy = HLMatrixType::cast(EltTy);
+        Value *matSize = Builder.getInt32(MatTy.getNumElements());
+        idxList.back() = Builder.CreateMul(idxList.back(), matSize);
+        Value *NewGEP = Builder.CreateGEP(A, idxList);
+        lowerMatrix(GEP, NewGEP);
+        DXASSERT(GEP->user_empty(), "else lower matrix fail");
+        GEP->eraseFromParent();
+      } else {
+        DXASSERT(0, "invalid GEP for matrix");
+      }
+    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
+      lowerMatrix(BCI, A);
+      DXASSERT(BCI->user_empty(), "else lower matrix fail");
+      BCI->eraseFromParent();
+    } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
+      if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
+        // Rebuild the vector from one scalar load per element.
+        IRBuilder<> Builder(LI);
+        Value *zeroIdx = Builder.getInt32(0);
+        unsigned vecSize = Ty->getNumElements();
+        Value *NewVec = UndefValue::get(LI->getType());
+        for (unsigned i = 0; i < vecSize; i++) {
+          Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
+          Value *Elt = Builder.CreateLoad(GEP);
+          NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
+        }
+        LI->replaceAllUsesWith(NewVec);
+        LI->eraseFromParent();
+      } else {
+        DXASSERT(0, "invalid load for matrix");
+      }
+    } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
+      Value *V = ST->getValueOperand();
+      if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
+        // Insert before the store being replaced. The previous code passed
+        // LI here, which is always null in this branch (the LoadInst
+        // dyn_cast above has just failed).
+        IRBuilder<> Builder(ST);
+        Value *zeroIdx = Builder.getInt32(0);
+        unsigned vecSize = Ty->getNumElements();
+        // Scatter the vector with one scalar store per element.
+        for (unsigned i = 0; i < vecSize; i++) {
+          Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
+          Value *Elt = Builder.CreateExtractElement(V, i);
+          Builder.CreateStore(Elt, GEP);
+        }
+        ST->eraseFromParent();
+      } else {
+        DXASSERT(0, "invalid store for matrix");
+      }
+    } else {
+      DXASSERT(0, "invalid use of matrix");
+    }
+  }
+}
+
+char MatrixBitcastLowerPass::ID = 0; // Definition of the static pass identifier declared in the class.
+FunctionPass *llvm::createMatrixBitcastLowerPass() { return new MatrixBitcastLowerPass(); } // Factory: returns a newly allocated pass instance.
+
+INITIALIZE_PASS(MatrixBitcastLowerPass, "matrixbitcastlower", "Matrix Bitcast lower", false, false) // Pass registration under the argument name "matrixbitcastlower".

+ 1224 - 2561
lib/HLSL/HLMatrixLowerPass.cpp

@@ -9,8 +9,9 @@
 //                                                                           //
 ///////////////////////////////////////////////////////////////////////////////
 
-#include "dxc/HLSL/HLMatrixLowerHelper.h"
 #include "dxc/HLSL/HLMatrixLowerPass.h"
+#include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/HLSL/HLOperations.h"
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/DXIL/DxilUtil.h"
@@ -19,6 +20,7 @@
 #include "dxc/DXIL/DxilOperations.h"
 #include "dxc/DXIL/DxilTypeSystem.h"
 #include "dxc/DXIL/DxilModule.h"
+#include "HLMatrixSubscriptUseReplacer.h"
 
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Module.h"
@@ -36,356 +38,146 @@ using namespace hlsl::HLMatrixLower;
 namespace hlsl {
 namespace HLMatrixLower {
 
-// If user is function call, return param annotation to get matrix major.
-DxilFieldAnnotation *FindAnnotationFromMatUser(Value *Mat,
-                                               DxilTypeSystem &typeSys) {
-  for (User *U : Mat->users()) {
-    if (CallInst *CI = dyn_cast<CallInst>(U)) {
-      Function *F = CI->getCalledFunction();
-      if (DxilFunctionAnnotation *Anno = typeSys.GetFunctionAnnotation(F)) {
-        for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
-          if (CI->getArgOperand(i) == Mat) {
-            return &Anno->GetParameterAnnotation(i);
-          }
-        }
-      }
-    }
-  }
-  return nullptr;
-}
-
-// Translate matrix type to vector type.
-Type *LowerMatrixType(Type *Ty, bool forMem) {
-  // Only translate matrix type and function type which use matrix type.
-  // Not translate struct has matrix or matrix pointer.
-  // Struct should be flattened before.
-  // Pointer could cover by matldst which use vector as value type.
-  if (FunctionType *FT = dyn_cast<FunctionType>(Ty)) {
-    Type *RetTy = LowerMatrixType(FT->getReturnType());
-    SmallVector<Type *, 4> params;
-    for (Type *param : FT->params()) {
-      params.emplace_back(LowerMatrixType(param));
-    }
-    return FunctionType::get(RetTy, params, false);
-  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
-    unsigned row, col;
-    Type *EltTy = GetMatrixInfo(Ty, col, row);
-    if (forMem && EltTy->isIntegerTy(1))
-      EltTy = Type::getInt32Ty(Ty->getContext());
-    return VectorType::get(EltTy, row * col);
-  } else {
-    return Ty;
-  }
-}
-
-// Translate matrix type to array type.
-Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
-  if (dxilutil::IsHLSLMatrixType(Ty)) {
-    unsigned row, col;
-    Type *EltTy = GetMatrixInfo(Ty, col, row);
-    return ArrayType::get(EltTy, row * col);
-  } else {
-    return Ty;
-  }
-}
-
-
-Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
-  DXASSERT(dxilutil::IsHLSLMatrixType(Ty), "not matrix type");
-  StructType *ST = cast<StructType>(Ty);
-  Type *EltTy = ST->getElementType(0);
-  Type *RowTy = EltTy->getArrayElementType();
-  row = EltTy->getArrayNumElements();
-  col = RowTy->getVectorNumElements();
-  return RowTy->getVectorElementType();
-}
-
-bool IsMatrixArrayPointer(llvm::Type *Ty) {
-  if (!Ty->isPointerTy())
-    return false;
-  Ty = Ty->getPointerElementType();
-  if (!Ty->isArrayTy())
-    return false;
-  while (Ty->isArrayTy())
-    Ty = Ty->getArrayElementType();
-  return dxilutil::IsHLSLMatrixType(Ty);
-}
-Type *LowerMatrixArrayPointer(Type *Ty, bool forMem) {
-  unsigned addrSpace = Ty->getPointerAddressSpace();
-  Ty = Ty->getPointerElementType();
-  std::vector<unsigned> arraySizeList;
-  while (Ty->isArrayTy()) {
-    arraySizeList.push_back(Ty->getArrayNumElements());
-    Ty = Ty->getArrayElementType();
-  }
-  Ty = LowerMatrixType(Ty, forMem);
-
-  for (auto arraySize = arraySizeList.rbegin();
-       arraySize != arraySizeList.rend(); arraySize++)
-    Ty = ArrayType::get(Ty, *arraySize);
-  return PointerType::get(Ty, addrSpace);
-}
-
-Type *LowerMatrixArrayPointerToOneDimArray(Type *Ty) {
-  unsigned addrSpace = Ty->getPointerAddressSpace();
-  Ty = Ty->getPointerElementType();
-
-  unsigned arraySize = 1;
-  while (Ty->isArrayTy()) {
-    arraySize *= Ty->getArrayNumElements();
-    Ty = Ty->getArrayElementType();
-  }
-  unsigned row, col;
-  Type *EltTy = GetMatrixInfo(Ty, col, row);
-  arraySize *= row*col;
-
-  Ty = ArrayType::get(EltTy, arraySize);
-  return PointerType::get(Ty, addrSpace);
-}
-Value *BuildVector(Type *EltTy, unsigned size, ArrayRef<llvm::Value *> elts,
-  IRBuilder<> &Builder) {
-  Value *Vec = UndefValue::get(VectorType::get(EltTy, size));
-  for (unsigned i = 0; i < size; i++)
+// Builds a vector of EltTy with one lane per entry of elts by inserting each
+// value, in order, into an initially undefined vector.
+Value *BuildVector(Type *EltTy, ArrayRef<llvm::Value *> elts, IRBuilder<> &Builder) {
+  const unsigned NumElts = static_cast<unsigned>(elts.size());
+  Value *Result = UndefValue::get(VectorType::get(EltTy, NumElts));
+  for (unsigned Lane = 0; Lane != NumElts; ++Lane)
+    Result = Builder.CreateInsertElement(Result, elts[Lane], Lane);
+  return Result;
 }
 
-llvm::Value *VecMatrixMemToReg(llvm::Value *VecVal, llvm::Type *MatType,
-  llvm::IRBuilder<> &Builder)
-{
-  llvm::Type *VecMatRegTy = HLMatrixLower::LowerMatrixType(MatType, /*forMem*/false);
-  if (VecVal->getType() == VecMatRegTy) {
-    return VecVal;
-  }
-
-  DXASSERT(VecMatRegTy->getVectorElementType()->isIntegerTy(1),
-    "Vector matrix mem to reg type mismatch should only happen for bools.");
-  llvm::Type *VecMatMemTy = HLMatrixLower::LowerMatrixType(MatType, /*forMem*/true);
-  return Builder.CreateICmpNE(VecVal, Constant::getNullValue(VecMatMemTy));
-}
-
-llvm::Value *VecMatrixRegToMem(llvm::Value* VecVal, llvm::Type *MatType,
-  llvm::IRBuilder<> &Builder)
-{
-  llvm::Type *VecMatMemTy = HLMatrixLower::LowerMatrixType(MatType, /*forMem*/true);
-  if (VecVal->getType() == VecMatMemTy) {
-    return VecVal;
-  }
-
-  DXASSERT(VecVal->getType()->getVectorElementType()->isIntegerTy(1),
-    "Vector matrix reg to mem type mismatch should only happen for bools.");
-  return Builder.CreateZExt(VecVal, VecMatMemTy);
-}
-
-llvm::Instruction *CreateVecMatrixLoad(
-  llvm::Value *VecPtr, llvm::Type *MatType, llvm::IRBuilder<> &Builder)
-{
-  llvm::Instruction *VecVal = Builder.CreateLoad(VecPtr);
-  return cast<llvm::Instruction>(VecMatrixMemToReg(VecVal, MatType, Builder));
-}
-
-llvm::Instruction *CreateVecMatrixStore(llvm::Value* VecVal, llvm::Value *VecPtr,
-  llvm::Type *MatType, llvm::IRBuilder<> &Builder)
-{
-  llvm::Type *VecMatMemTy = HLMatrixLower::LowerMatrixType(MatType, /*forMem*/true);
-  if (VecVal->getType() == VecMatMemTy) {
-    return Builder.CreateStore(VecVal, VecPtr);
-  }
-
-  // We need to convert to the memory representation, and we want to return
-  // the conversion instruction rather than the store since that's what
-  // accepts the register-typed i1 values.
-
-  // Do not use VecMatrixRegToMem as it may constant fold the conversion
-  // instruction, which is what we want to return.
-  DXASSERT(VecVal->getType()->getVectorElementType()->isIntegerTy(1),
-    "Vector matrix reg to mem type mismatch should only happen for bools.");
-
-  llvm::Instruction *ConvInst = Builder.Insert(new ZExtInst(VecVal, VecMatMemTy));
-  Builder.CreateStore(ConvInst, VecPtr);
-  return ConvInst;
-}
-
-Value *LowerGEPOnMatIndexListToIndex(
-    llvm::GetElementPtrInst *GEP, ArrayRef<Value *> IdxList) {
-  IRBuilder<> Builder(GEP);
-  Value *zero = Builder.getInt32(0);
-  DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
-  Value *baseIdx = (GEP->idx_begin())->get();
-  DXASSERT_LOCALVAR(baseIdx, baseIdx == zero, "base index must be 0");
-  Value *Idx = (GEP->idx_begin() + 1)->get();
-
-  if (ConstantInt *immIdx = dyn_cast<ConstantInt>(Idx)) {
-    return IdxList[immIdx->getSExtValue()];
-  } else {
-    IRBuilder<> AllocaBuilder(
-        GEP->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
-    unsigned size = IdxList.size();
-    // Store idxList to temp array.
-    ArrayType *AT = ArrayType::get(IdxList[0]->getType(), size);
-    Value *tempArray = AllocaBuilder.CreateAlloca(AT);
-
-    for (unsigned i = 0; i < size; i++) {
-      Value *EltPtr = Builder.CreateGEP(tempArray, {zero, Builder.getInt32(i)});
-      Builder.CreateStore(IdxList[i], EltPtr);
-    }
-    // Load the idx.
-    Value *GEPOffset = Builder.CreateGEP(tempArray, {zero, Idx});
-    return Builder.CreateLoad(GEPOffset);
-  }
-}
-
-
-unsigned GetColMajorIdx(unsigned r, unsigned c, unsigned row) {
-  return c * row + r;
-}
-unsigned GetRowMajorIdx(unsigned r, unsigned c, unsigned col) {
-  return r * col + c;
-}
-
 } // namespace HLMatrixLower
 } // namespace hlsl
 
 namespace {
 
-class HLMatrixLowerPass : public ModulePass {
+// Creates and manages a set of temporary overloaded functions keyed on the function type,
+// and which should be destroyed when the pool gets out of scope.
+//
+// The pool is the sole manager of the functions it creates: its destructor
+// calls clear(). It is therefore non-copyable, since a copy would tear down
+// the same functions a second time.
+class TempOverloadPool {
+public:
+  TempOverloadPool(llvm::Module &Module, const char* BaseName)
+    : Module(Module), BaseName(BaseName) {}
+  TempOverloadPool(const TempOverloadPool &) = delete;
+  TempOverloadPool &operator=(const TempOverloadPool &) = delete;
+  ~TempOverloadPool() { clear(); }
 
+  // Returns the pool's function for the given type, creating it on first use.
+  Function *get(FunctionType *Ty);
+  // True if the pool already holds a function of this type.
+  bool contains(FunctionType *Ty) const { return Funcs.count(Ty) != 0; }
+  // True if Func is the function this pool holds for Func's own type.
+  bool contains(Function *Func) const;
+  // Releases every function held by the pool; they must be unused by then.
+  void clear();
+
+private:
+  llvm::Module &Module;
+  const char* BaseName; // Prefix of the mangled function names; not owned.
+  llvm::DenseMap<FunctionType*, Function*> Funcs;
+};
+
+// Returns the pool's function for type Ty. On first request the function is
+// obtained from the module under the name "<BaseName>.<printed type>" and
+// remembered for later calls.
+Function *TempOverloadPool::get(FunctionType *Ty) {
+  // Stored entries are never null, so a null lookup result means "absent".
+  if (Function *Existing = Funcs.lookup(Ty))
+    return Existing;
+
+  std::string Name;
+  raw_string_ostream NameStream(Name);
+  NameStream << BaseName << '.';
+  Ty->print(NameStream);
+
+  // str() flushes the stream into Name before it is used.
+  Function *Created =
+      cast<Function>(Module.getOrInsertFunction(NameStream.str(), Ty));
+  Funcs[Ty] = Created;
+  return Created;
+}
+
+// True when Func is exactly the function this pool holds for Func's type.
+bool TempOverloadPool::contains(Function *Func) const {
+  // lookup() yields nullptr for a missing key, which never equals Func.
+  return Funcs.lookup(Func->getFunctionType()) == Func;
+}
+
+// Destroys every function held by the pool and empties it.
+// Each function is expected to have no remaining uses at this point.
+void TempOverloadPool::clear() {
+  for (auto Entry : Funcs) {
+    DXASSERT(Entry.second->use_empty(), "Temporary function still used during pool destruction.");
+    // eraseFromParent unlinks the function from the module AND deletes it.
+    // removeFromParent only unlinks it, which leaked every temporary function.
+    Entry.second->eraseFromParent();
+  }
+  Funcs.clear();
+}
+
+// High-level matrix lowering pass.
+//
+// This pass converts matrices to their lowered vector representations,
+// including global variables, local variables and operations,
+// but not function signatures (arguments and return types) - left to HLSignatureLower and HLMatrixBitcastLower,
+// nor matrices obtained from resources or constant - left to HLOperationLower.
+//
+// Algorithm overview:
+// 1. Find all matrix and matrix array global variables and lower them to vectors.
+//    Walk any GEPs and insert vec-to-mat translation stubs so that consuming
+//    instructions keep dealing with matrix types for the moment.
+// 2. For each function
+// 2a. Lower all matrix and matrix array allocas, just like global variables.
+// 2b. Lower all other instructions producing or consuming matrices
+//
+// Conversion stubs are used to allow converting instructions in isolation,
+// and in an order-independent manner:
+//
+// Initial: MatInst1(MatInst2(MatInst3))
+// After lowering MatInst2: MatInst1(VecToMat(VecInst2(MatToVec(MatInst3))))
+// After lowering MatInst1: VecInst1(VecInst2(MatToVec(MatInst3)))
+// After lowering MatInst3: VecInst1(VecInst2(VecInst3))
+class HLMatrixLowerPass : public ModulePass {
 public:
   static char ID; // Pass identification, replacement for typeid
   explicit HLMatrixLowerPass() : ModulePass(ID) {}
 
   const char *getPassName() const override { return "HL matrix lower"; }
+  bool runOnModule(Module &M) override;
 
-  bool runOnModule(Module &M) override {
-    m_pModule = &M;
-    m_pHLModule = &m_pModule->GetOrCreateHLModule();
-    // Load up debug information, to cross-reference values and the instructions
-    // used to load them.
-    m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
-
-    for (Function &F : M.functions()) {
-
-      if (F.isDeclaration())
-        continue;
-      runOnFunction(F);
-    }
-    std::vector<GlobalVariable*> staticGVs;
-    for (GlobalVariable &GV : M.globals()) {
-      if (dxilutil::IsStaticGlobal(&GV) ||
-          dxilutil::IsSharedMemoryGlobal(&GV)) {
-        staticGVs.emplace_back(&GV);
-      }
-    }
-
-    for (GlobalVariable *GV : staticGVs)
-      runOnGlobal(GV);
-
-    return true;
-  }
+private:
+  void runOnFunction(Function &Func);
+  void addToDeadInsts(Instruction *Inst) { m_deadInsts.emplace_back(Inst); }
+  void deleteDeadInsts();
+
+  void getMatrixAllocasAndOtherInsts(Function &Func,
+    std::vector<AllocaInst*> &MatAllocas, std::vector<Instruction*> &MatInsts);
+  Value *getLoweredByValOperand(Value *Val, IRBuilder<> &Builder, bool DiscardStub = false);
+  Value *tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Builder, bool DiscardStub = false);
+  Value *bitCastValue(Value *SrcVal, Type* DstTy, bool DstTyAlloca, IRBuilder<> &Builder);
+  void replaceAllUsesByLoweredValue(Instruction *MatInst, Value *VecVal);
+  void replaceAllVariableUses(Value* MatPtr, Value* LoweredPtr);
+  void replaceAllVariableUses(SmallVectorImpl<Value*> &GEPIdxStack, Value *StackTopPtr, Value* LoweredPtr);
+
+  void lowerGlobal(GlobalVariable *Global);
+  Constant *lowerConstInitVal(Constant *Val);
+  AllocaInst *lowerAlloca(AllocaInst *MatAlloca);
+  void lowerInstruction(Instruction* Inst);
+  void lowerReturn(ReturnInst* Return);
+  Value *lowerCall(CallInst *Call);
+  Value *lowerNonHLCall(CallInst *Call);
+  Value *lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeGroup);
+  Value *lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode);
+  Value *lowerHLMulIntrinsic(Value* Lhs, Value *Rhs, bool Unsigned, IRBuilder<> &Builder);
+  Value *lowerHLTransposeIntrinsic(Value *MatVal, IRBuilder<> &Builder);
+  Value *lowerHLDeterminantIntrinsic(Value *MatVal, IRBuilder<> &Builder);
+  Value *lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode, IRBuilder<> &Builder);
+  Value *lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode, IRBuilder<> &Builder);
+  Value *lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode);
+  Value *lowerHLLoad(Value *MatPtr, bool RowMajor, IRBuilder<> &Builder);
+  Value *lowerHLStore(Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder);
+  Value *lowerHLCast(Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder);
+  Value *lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
+  Value *lowerHLMatElementSubscript(CallInst *Call, bool RowMajor);
+  Value *lowerHLMatSubscript(CallInst *Call, bool RowMajor);
+  void lowerHLMatSubscript(CallInst *Call, Value *MatPtr, SmallVectorImpl<Value*> &ElemIndices);
+  Value *lowerHLMatResourceSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
+  Value *lowerHLInit(CallInst *Call);
+  Value *lowerHLSelect(CallInst *Call);
 
 private:
   Module *m_pModule;
   HLModule *m_pHLModule;
   bool m_HasDbgInfo;
+
+  // Pools for the translation stubs
+  TempOverloadPool *m_matToVecStubs = nullptr;
+  TempOverloadPool *m_vecToMatStubs = nullptr;
+
   std::vector<Instruction *> m_deadInsts;
-  // For instruction like matrix array init.
-  // May use more than 1 matrix alloca inst.
-  // This set is here to avoid put it into deadInsts more than once.
-  std::unordered_set<Instruction *> m_inDeadInstsSet;
-  // For most matrix insturction users, it will only have one matrix use.
-  // Use vector so save deadInsts because vector is cheap.
-  void AddToDeadInsts(Instruction *I) { m_deadInsts.emplace_back(I); }
-  // In case instruction has more than one matrix use.
-  // Use AddToDeadInstsWithDups to make sure it's not add to deadInsts more than once.
-  void AddToDeadInstsWithDups(Instruction *I) {
-    if (m_inDeadInstsSet.count(I) == 0) {
-      // Only add to deadInsts when it's not inside m_inDeadInstsSet.
-      m_inDeadInstsSet.insert(I);
-      AddToDeadInsts(I);
-    }
-  }
-  void runOnFunction(Function &F);
-  void runOnGlobal(GlobalVariable *GV);
-  void runOnGlobalMatrixArray(GlobalVariable *GV);
-  Instruction *MatCastToVec(CallInst *CI);
-  Instruction *MatLdStToVec(CallInst *CI);
-  Instruction *MatSubscriptToVec(CallInst *CI);
-  Instruction *MatFrExpToVec(CallInst *CI);
-  Instruction *MatIntrinsicToVec(CallInst *CI);
-  Instruction *TrivialMatUnOpToVec(CallInst *CI);
-  // Replace matVal with vecVal on matUseInst.
-  void TrivialMatUnOpReplace(Value *matVal, Value *vecVal,
-                            CallInst *matUseInst);
-  Instruction *TrivialMatBinOpToVec(CallInst *CI);
-  // Replace matVal with vecVal on matUseInst.
-  void TrivialMatBinOpReplace(Value *matVal, Value *vecVal,
-                             CallInst *matUseInst);
-  // Replace matVal with vecVal on mulInst.
-  void TranslateMatMatMul(Value *matVal, Value *vecVal,
-                          CallInst *mulInst, bool isSigned);
-  void TranslateMatVecMul(Value *matVal, Value *vecVal,
-                          CallInst *mulInst, bool isSigned);
-  void TranslateVecMatMul(Value *matVal, Value *vecVal,
-                          CallInst *mulInst, bool isSigned);
-  void TranslateMul(Value *matVal, Value *vecVal, CallInst *mulInst,
-                    bool isSigned);
-  // Replace matVal with vecVal on transposeInst.
-  void TranslateMatTranspose(Value *matVal, Value *vecVal,
-                             CallInst *transposeInst);
-  void TranslateMatDeterminant(Value *matVal, Value *vecVal,
-                             CallInst *determinantInst);
-  void MatIntrinsicReplace(Value *matVal, Value *vecVal,
-                           CallInst *matUseInst);
-  // Replace matVal with vecVal on castInst.
-  void TranslateMatMatCast(Value *matVal, Value *vecVal,
-                           CallInst *castInst);
-  void TranslateMatToOtherCast(Value *matVal, Value *vecVal,
-                               CallInst *castInst);
-  void TranslateMatCast(Value *matVal, Value *vecVal,
-                        CallInst *castInst);
-  void TranslateMatMajorCast(Value *matVal, Value *vecVal,
-                        CallInst *castInst, bool rowToCol, bool transpose);
-  // Replace matVal with vecVal in matSubscript
-  void TranslateMatSubscript(Value *matVal, Value *vecVal,
-                             CallInst *matSubInst);
-  // Replace matInitInst using matToVecMap
-  void TranslateMatInit(CallInst *matInitInst);
-  // Replace matSelectInst using matToVecMap
-  void TranslateMatSelect(CallInst *matSelectInst);
-  // Replace matVal with vecVal on matInitInst.
-  void TranslateMatArrayGEP(Value *matVal, Value *vecVal,
-                            GetElementPtrInst *matGEP);
-  void TranslateMatLoadStoreOnGlobal(Value *matGlobal, ArrayRef<Value *>vecGlobals,
-                             CallInst *matLdStInst);
-  void TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
-                             CallInst *matLdStInst);
-  void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
-  void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
-
-  // Get new matrix value corresponding to vecVal
-  Value *GetMatrixForVec(Value *vecVal, Type *matTy);
-
-  // Translate library function input/output to preserve function signatures
-  void TranslateArgForLibFunc(CallInst *CI);
-  void TranslateArgsForLibFunc(Function &F);
-
-  // Replace matVal with vecVal on matUseInst.
-  void TrivialMatReplace(Value *matVal, Value *vecVal,
-                        CallInst *matUseInst);
-  // Lower a matrix type instruction to a vector type instruction.
-  void lowerToVec(Instruction *matInst);
-  // Lower users of a matrix type instruction.
-  void replaceMatWithVec(Value *matVal, Value *vecVal);
-  // Translate user library function call arguments
-  void castMatrixArgs(Value *matVal, Value *vecVal, CallInst *CI);
-  // Translate mat inst which need all operands ready.
-  void finalMatTranslation(Value *matVal);
-  // Delete dead insts in m_deadInsts.
-  void DeleteDeadInsts();
-  // Map from matrix value to its vector version.
-  DenseMap<Value *, Value *> matToVecMap;
-  // Map from new vector version to matrix version needed by user call or return.
-  DenseMap<Value *, Value *> vecToMatMap;
 };
 }
 
@@ -395,2445 +187,1316 @@ ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }
 
 INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
 
-static Instruction *CreateTypeCast(HLCastOpcode castOp, Type *toTy, Value *src,
-                                   IRBuilder<> Builder) {
-  Type *srcTy = src->getType();
-
-  // Conversions between equivalent types are no-ops,
-  // even between signed/unsigned variants.
-  if (srcTy == toTy) return cast<Instruction>(src);
+bool HLMatrixLowerPass::runOnModule(Module &M) {
+  TempOverloadPool matToVecStubs(M, "hlmatrixlower.mat2vec");
+  TempOverloadPool vecToMatStubs(M, "hlmatrixlower.vec2mat");
+
+  m_pModule = &M;
+  m_pHLModule = &m_pModule->GetOrCreateHLModule();
+  // Load up debug information, to cross-reference values and the instructions
+  // used to load them.
+  m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
+  m_matToVecStubs = &matToVecStubs;
+  m_vecToMatStubs = &vecToMatStubs;
+
+  // First, lower static global variables.
+  // We need to accumulate them locally because we'll be creating new ones as we lower them.
+  std::vector<GlobalVariable*> Globals;
+  for (GlobalVariable &Global : M.globals()) {
+    if ((dxilutil::IsStaticGlobal(&Global) || dxilutil::IsSharedMemoryGlobal(&Global))
+      && HLMatrixType::isMatrixPtrOrArrayPtr(Global.getType())) {
+      Globals.emplace_back(&Global);
+    }
+  }
 
-  bool fromUnsigned = castOp == HLCastOpcode::FromUnsignedCast ||
-                      castOp == HLCastOpcode::UnsignedUnsignedCast;
-  bool toUnsigned = castOp == HLCastOpcode::ToUnsignedCast ||
-                    castOp == HLCastOpcode::UnsignedUnsignedCast;
+  for (GlobalVariable *Global : Globals)
+    lowerGlobal(Global);
 
-  // Conversions to bools are comparisons
-  if (toTy->getScalarSizeInBits() == 1) {
-    // fcmp une is what regular clang uses in C++ for (bool)f;
-    return cast<Instruction>(srcTy->isIntOrIntVectorTy()
-      ? Builder.CreateICmpNE(src, llvm::Constant::getNullValue(srcTy), "tobool")
-      : Builder.CreateFCmpUNE(src, llvm::Constant::getNullValue(srcTy), "tobool"));
+  for (Function &F : M.functions()) {
+    if (F.isDeclaration()) continue;
+    runOnFunction(F);
   }
 
-  // Cast necessary
-  auto CastOp = static_cast<Instruction::CastOps>(HLModule::GetNumericCastOp(
-    srcTy, fromUnsigned, toTy, toUnsigned));
-  return cast<Instruction>(Builder.CreateCast(CastOp, src, toTy));
+  m_pModule = nullptr;
+  m_pHLModule = nullptr;
+  m_matToVecStubs = nullptr;
+  m_vecToMatStubs = nullptr;
+
+  // If you hit an assert during TempOverloadPool destruction,
+  // it means that either a matrix producer was lowered,
+  // causing a translation stub to be created,
+  // but the consumer of that matrix was never (properly) lowered,
+  // or the opposite: a matrix consumer was lowered but its producer was not.
+
+  return true;
 }
 
-Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
-  IRBuilder<> Builder(CI);
-  Value *op = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
-  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(CI));
-
-  bool ToMat = dxilutil::IsHLSLMatrixType(CI->getType());
-  bool FromMat = dxilutil::IsHLSLMatrixType(op->getType());
-  if (ToMat && !FromMat) {
-    // Translate OtherToMat here.
-    // Rest will translated when replace.
-    unsigned col, row;
-    Type *EltTy = GetMatrixInfo(CI->getType(), col, row);
-    unsigned toSize = col * row;
-    Instruction *sizeCast = nullptr;
-    Type *FromTy = op->getType();
-    Type *I32Ty = IntegerType::get(FromTy->getContext(), 32);
-    if (FromTy->isVectorTy()) {
-      std::vector<Constant *> MaskVec(toSize);
-      for (size_t i = 0; i != toSize; ++i)
-        MaskVec[i] = ConstantInt::get(I32Ty, i);
-
-      Value *castMask = ConstantVector::get(MaskVec);
-
-      sizeCast = new ShuffleVectorInst(op, op, castMask);
-      Builder.Insert(sizeCast);
-
-    } else {
-      op = Builder.CreateInsertElement(
-          UndefValue::get(VectorType::get(FromTy, 1)), op, (uint64_t)0);
-      Constant *zero = ConstantInt::get(I32Ty, 0);
-      std::vector<Constant *> MaskVec(toSize, zero);
-      Value *castMask = ConstantVector::get(MaskVec);
-
-      sizeCast = new ShuffleVectorInst(op, op, castMask);
-      Builder.Insert(sizeCast);
-    }
-    Instruction *typeCast = sizeCast;
-    if (EltTy != FromTy->getScalarType()) {
-      typeCast = CreateTypeCast(opcode, VectorType::get(EltTy, toSize),
-                                sizeCast, Builder);
-    }
-    return typeCast;
-  } else if (FromMat && ToMat) {
-    if (isa<Argument>(op)) {
-      // Cast From mat to mat for arugment.
-      IRBuilder<> Builder(CI);
-
-      // Here only lower the return type to vector.
-      Type *RetTy = LowerMatrixType(CI->getType());
-      SmallVector<Type *, 4> params;
-      for (Value *operand : CI->arg_operands()) {
-        params.emplace_back(operand->getType());
+void HLMatrixLowerPass::runOnFunction(Function &Func) {
+  // Skip hl function definition (like createhandle)
+  if (hlsl::GetHLOpcodeGroupByName(&Func) != HLOpcodeGroup::NotHL)
+    return;
+
+  // Save the matrix instructions first since the translation process
+  // will temporarily create other instructions consuming/producing matrix types.
+  std::vector<AllocaInst*> MatAllocas;
+  std::vector<Instruction*> MatInsts;
+  getMatrixAllocasAndOtherInsts(Func, MatAllocas, MatInsts);
+
+  // First lower all allocas and take care of their GEP chains
+  for (AllocaInst* MatAlloca : MatAllocas) {
+    AllocaInst* LoweredAlloca = lowerAlloca(MatAlloca);
+    replaceAllVariableUses(MatAlloca, LoweredAlloca);
+    addToDeadInsts(MatAlloca);
+  }
+
+  // Now lower all other matrix instructions
+  for (Instruction *MatInst : MatInsts)
+    lowerInstruction(MatInst);
+
+  deleteDeadInsts();
+}
+
+void HLMatrixLowerPass::deleteDeadInsts() {
+  while (!m_deadInsts.empty()) {
+    Instruction *Inst = m_deadInsts.back();
+    m_deadInsts.pop_back();
+
+    DXASSERT_NOMSG(Inst->use_empty());
+    for (Value *Operand : Inst->operand_values()) {
+      Instruction *OperandInst = dyn_cast<Instruction>(Operand);
+      if (OperandInst && ++OperandInst->user_begin() == OperandInst->user_end()) {
+        // We were its only user, erase recursively.
+        // This will get rid of translation stubs:
+        // Original: MatConsumer(MatProducer)
+        // Producer lowered: MatConsumer(VecToMat(VecProducer)), MatProducer dead
+        // Consumer lowered: VecConsumer(VecProducer), MatConsumer(VecToMat) dead
+        // Only by recursing on MatConsumer's operand do we delete the VecToMat stub.
+        DXASSERT_NOMSG(*OperandInst->user_begin() == Inst);
+        m_deadInsts.emplace_back(OperandInst);
       }
+    }
 
-      Type *FT = FunctionType::get(RetTy, params, false);
+    Inst->eraseFromParent();
+  }
+}
 
-      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-      unsigned opcode = GetHLOpcode(CI);
+// Find all instructions consuming or producing matrices,
+// directly or through pointers/arrays.
+void HLMatrixLowerPass::getMatrixAllocasAndOtherInsts(Function &Func,
+    std::vector<AllocaInst*> &MatAllocas, std::vector<Instruction*> &MatInsts){
+  for (BasicBlock &BasicBlock : Func) {
+    for (Instruction &Inst : BasicBlock) {
+      // Don't lower GEPs directly, we'll handle them as we lower the root pointer,
+      // typically a global variable or alloca.
+      if (isa<GetElementPtrInst>(&Inst)) continue;
 
-      Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT),
-                                             group, opcode);
+      if (AllocaInst *Alloca = dyn_cast<AllocaInst>(&Inst)) {
+        if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Alloca->getType())) {
+          MatAllocas.emplace_back(Alloca);
+        }
+        continue;
+      }
+      
+      if (CallInst *Call = dyn_cast<CallInst>(&Inst)) {
+        // Lowering of global variables will have introduced
+        // vec-to-mat translation stubs, which we deal with indirectly,
+        // as we lower the instructions consuming them.
+        if (m_vecToMatStubs->contains(Call->getCalledFunction()))
+          continue;
+
+        // Mat-to-vec stubs should only be introduced during instruction lowering.
+        // Globals lowering won't introduce any because their only operand is
+        // their initializer, which we can fully lower without stubbing since it is constant.
+        DXASSERT(!m_matToVecStubs->contains(Call->getCalledFunction()),
+          "Unexpected mat-to-vec stubbing before function instruction lowering.");
+
+        // Match matrix producers
+        if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Inst.getType())) {
+          MatInsts.emplace_back(Call);
+          continue;
+        }
 
-      SmallVector<Value *, 4> argList;
-      for (Value *arg : CI->arg_operands()) {
-        argList.emplace_back(arg);
+        // Match matrix consumers
+        for (Value *Operand : Inst.operand_values()) {
+          if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Operand->getType())) {
+            MatInsts.emplace_back(Call);
+            break;
+          }
+        }
+
+        continue;
       }
 
-      return Builder.CreateCall(vecF, argList);
+      if (ReturnInst *Return = dyn_cast<ReturnInst>(&Inst)) {
+        Value *ReturnValue = Return->getReturnValue();
+        if (ReturnValue != nullptr && HLMatrixType::isMatrixOrPtrOrArrayPtr(ReturnValue->getType()))
+          MatInsts.emplace_back(Return);
+        continue;
+      }
+
+      // Nothing else should produce or consume matrices
     }
   }
-
-  return MatIntrinsicToVec(CI);
 }
 
-// Return GEP if value is Matrix resulting GEP from UDT alloca
-// UDT alloca must be there for library function args
-static GetElementPtrInst *GetIfMatrixGEPOfUDTAlloca(Value *V) {
-  if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
-    if (dxilutil::IsHLSLMatrixType(GEP->getResultElementType())) {
-      Value *ptr = GEP->getPointerOperand();
-      if (AllocaInst *AI = dyn_cast<AllocaInst>(ptr)) {
-        Type *ATy = AI->getAllocatedType();
-        if (ATy->isStructTy() && !dxilutil::IsHLSLMatrixType(ATy)) {
-          return GEP;
-        }
+// Gets the matrix-lowered representation of a value, potentially adding a translation stub.
+// DiscardStub causes a vec-to-mat translation stub to be discarded if this is its only use;
+// it should be true only if the original instruction will be modified and kept alive.
+// If a new instruction is created and the original marked as dead,
+// then deleteDeadInsts will take care of removing the stub.
+Value* HLMatrixLowerPass::getLoweredByValOperand(Value *Val, IRBuilder<> &Builder, bool DiscardStub) {
+  Type *Ty = Val->getType();
+
+  // We're only lowering byval matrices.
+  // Since structs and arrays are always accessed by pointer,
+  // we do not need to worry about a matrix being hidden inside a more complex type.
+  DXASSERT(!Ty->isPointerTy(), "Value cannot be a pointer.");
+  HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty);
+  if (!MatTy) return Val;
+
+  Type *LoweredTy = MatTy.getLoweredVectorTypeForReg();
+  
+  // Check if the value is already a vec-to-mat translation stub
+  if (CallInst *Call = dyn_cast<CallInst>(Val)) {
+    if (m_vecToMatStubs->contains(Call->getCalledFunction())) {
+      if (DiscardStub && Call->getNumUses() == 1) {
+        Call->use_begin()->set(UndefValue::get(Call->getType()));
+        addToDeadInsts(Call);
       }
+
+      Value *LoweredVal = Call->getArgOperand(0);
+      DXASSERT(LoweredVal->getType() == LoweredTy, "Unexpected already-lowered value type.");
+      return LoweredVal;
     }
   }
-  return nullptr;
+
+  // Return a mat-to-vec translation stub
+  FunctionType *TranslationStubTy = FunctionType::get(LoweredTy, { Ty }, /* isVarArg */ false);
+  Function *TranslationStub = m_matToVecStubs->get(TranslationStubTy);
+  return Builder.CreateCall(TranslationStub, { Val });
 }
 
-// Return GEP if value is Matrix resulting GEP from UDT argument of
-// none-graphics functions.
-static GetElementPtrInst *GetIfMatrixGEPOfUDTArg(Value *V, HLModule &HM) {
-  if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
-    if (dxilutil::IsHLSLMatrixType(GEP->getResultElementType())) {
-      Value *ptr = GEP->getPointerOperand();
-      if (Argument *Arg = dyn_cast<Argument>(ptr)) {
-        if (!HM.IsGraphicsShader(Arg->getParent()))
-          return GEP;
+// Attempts to retrieve the lowered vector pointer equivalent to a matrix pointer.
+// Returns nullptr if the pointed-to matrix lives in memory that cannot be lowered at this time,
+// for example a buffer or shader inputs/outputs, which are lowered during signature lowering.
+Value *HLMatrixLowerPass::tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Builder, bool DiscardStub) {
+  if (!HLMatrixType::isMatrixPtrOrArrayPtr(Ptr->getType()))
+    return nullptr;
+
+  // Matrix pointers can only be derived from Allocas, GlobalVariables or resource accesses.
+  // The first two cases are what this pass must be able to lower, and we should already
+  // have replaced their uses by vector to matrix pointer translation stubs.
+  if (CallInst *Call = dyn_cast<CallInst>(Ptr)) {
+    if (m_vecToMatStubs->contains(Call->getCalledFunction())) {
+      if (DiscardStub && Call->getNumUses() == 1) {
+        Call->use_begin()->set(UndefValue::get(Call->getType()));
+        addToDeadInsts(Call);
       }
+      return Call->getArgOperand(0);
     }
   }
-  return nullptr;
-}
 
-Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
-  IRBuilder<> Builder(CI);
-  unsigned opcode = GetHLOpcode(CI);
-  HLMatLoadStoreOpcode matOpcode = static_cast<HLMatLoadStoreOpcode>(opcode);
-  Instruction *result = nullptr;
-  switch (matOpcode) {
-  case HLMatLoadStoreOpcode::ColMatLoad:
-  case HLMatLoadStoreOpcode::RowMatLoad: {
-    Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
-    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr) ||
-        GetIfMatrixGEPOfUDTArg(matPtr, *m_pHLModule)) {
-      Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
-      result = CreateVecMatrixLoad(vecPtr, matPtr->getType()->getPointerElementType(), Builder);
-    } else
-      result = MatIntrinsicToVec(CI);
-  } break;
-  case HLMatLoadStoreOpcode::ColMatStore:
-  case HLMatLoadStoreOpcode::RowMatStore: {
-    Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
-    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr) ||
-        GetIfMatrixGEPOfUDTArg(matPtr, *m_pHLModule)) {
-      Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
-      Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-      Value *vecVal = UndefValue::get(HLMatrixLower::LowerMatrixType(matVal->getType()));
-      result = CreateVecMatrixStore(vecVal, vecPtr, matVal->getType(), Builder);
-    } else
-      result = MatIntrinsicToVec(CI);
-  } break;
+  // There's one more case to handle.
+  // When compiling shader libraries, signatures won't have been lowered yet.
+  // So we can have a matrix in a struct as an argument,
+  // or an alloca'd struct holding the return value of a call and containing a matrix.
+  Value *RootPtr = Ptr;
+  while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
+    RootPtr = GEP->getPointerOperand();
+
+  Argument *Arg = dyn_cast<Argument>(RootPtr);
+  bool IsNonShaderArg = Arg != nullptr && !m_pHLModule->IsGraphicsShader(Arg->getParent());
+  if (IsNonShaderArg || isa<AllocaInst>(RootPtr)) {
+    // Bitcast the matrix pointer to its lowered equivalent.
+    // The later HLMatrixBitcastLower pass will take care of this.
+    return Builder.CreateBitCast(Ptr, HLMatrixType::getLoweredType(Ptr->getType()));
   }
-  return result;
+
+  // The pointer must be derived from a resource, we don't handle it in this pass.
+  return nullptr;
 }
 
-Instruction *HLMatrixLowerPass::MatSubscriptToVec(CallInst *CI) {
-  IRBuilder<> Builder(CI);
-  Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
-  if (isa<AllocaInst>(matPtr)) {
-    // Here just create a new matSub call which use vec ptr.
-    // Later in TranslateMatSubscript will do the real translation.
-    std::vector<Value *> args(CI->getNumArgOperands());
-    for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
-      args[i] = CI->getArgOperand(i);
-    }
-    // Change mat ptr into vec ptr.
-    args[HLOperandIndex::kMatSubscriptMatOpIdx] =
-        matToVecMap[cast<Instruction>(matPtr)];
-    std::vector<Type *> paramTyList(CI->getNumArgOperands());
-    for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
-      paramTyList[i] = args[i]->getType();
-    }
+// Bitcasts a value from matrix to vector or vice-versa.
+// This is used to convert to/from arguments/return values since we don't
+// lower signatures in this pass. The later HLMatrixBitcastLower pass fixes this.
+Value *HLMatrixLowerPass::bitCastValue(Value *SrcVal, Type* DstTy, bool DstTyAlloca, IRBuilder<> &Builder) {
+  Type *SrcTy = SrcVal->getType();
+  DXASSERT_NOMSG(!SrcTy->isPointerTy());
 
-    FunctionType *funcTy = FunctionType::get(CI->getType(), paramTyList, false);
-    unsigned opcode = GetHLOpcode(CI);
-    Function *opFunc = GetOrCreateHLFunction(*m_pModule, funcTy, HLOpcodeGroup::HLSubscript, opcode);
-    return Builder.CreateCall(opFunc, args);
-  } else
-    return MatIntrinsicToVec(CI);
+  // We store and load from a temporary alloca, bitcasting either on the store pointer
+  // or on the load pointer.
+  IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
+  Value *Alloca = AllocaBuilder.CreateAlloca(DstTyAlloca ? DstTy : SrcTy);
+  Value *BitCastedAlloca = Builder.CreateBitCast(Alloca, (DstTyAlloca ? SrcTy : DstTy)->getPointerTo());
+  Builder.CreateStore(SrcVal, DstTyAlloca ? BitCastedAlloca : Alloca);
+  return Builder.CreateLoad(DstTyAlloca ? Alloca : BitCastedAlloca);
 }
 
-Instruction *HLMatrixLowerPass::MatFrExpToVec(CallInst *CI) {
-  IRBuilder<> Builder(CI);
-  FunctionType *FT = CI->getCalledFunction()->getFunctionType();
-  Type *RetTy = LowerMatrixType(FT->getReturnType());
-  SmallVector<Type *, 4> params;
-  for (Type *param : FT->params()) {
-    if (!param->isPointerTy()) {
-      params.emplace_back(LowerMatrixType(param));
-    } else {
-      // Lower pointer type for frexp.
-      Type *EltTy = LowerMatrixType(param->getPointerElementType());
-      params.emplace_back(
-          PointerType::get(EltTy, param->getPointerAddressSpace()));
-    }
-  }
+// Replaces all uses of a matrix value by its lowered vector form,
+// inserting translation stubs for users which still expect a matrix value.
+void HLMatrixLowerPass::replaceAllUsesByLoweredValue(Instruction* MatInst, Value* VecVal) {
+  if (VecVal == nullptr || VecVal == MatInst) return;
 
-  Type *VecFT = FunctionType::get(RetTy, params, false);
+  DXASSERT(HLMatrixType::getLoweredType(MatInst->getType()) == VecVal->getType(),
+    "Unexpected lowered value type.");
 
-  HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-  Function *vecF =
-      GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(VecFT), group,
-                            static_cast<unsigned>(IntrinsicOp::IOP_frexp));
+  Instruction *VecToMatStub = nullptr;
 
-  SmallVector<Value *, 4> argList;
-  auto paramTyIt = params.begin();
-  for (Value *arg : CI->arg_operands()) {
-    Type *Ty = arg->getType();
-    Type *ParamTy = *(paramTyIt++);
+  while (!MatInst->use_empty()) {
+    Use &ValUse = *MatInst->use_begin();
 
-    if (Ty != ParamTy)
-      argList.emplace_back(UndefValue::get(ParamTy));
-    else
-      argList.emplace_back(arg);
-  }
+    // Handle non-matrix cases, just point to the new value.
+    if (MatInst->getType() == VecVal->getType()) {
+      ValUse.set(VecVal);
+      continue;
+    }
 
-  return Builder.CreateCall(vecF, argList);
-}
+    // If the user is already a matrix-to-vector translation stub,
+    // we can now replace it by the proper vector value.
+    if (CallInst *Call = dyn_cast<CallInst>(ValUse.getUser())) {
+      if (m_matToVecStubs->contains(Call->getCalledFunction())) {
+        Call->replaceAllUsesWith(VecVal);
+        ValUse.set(UndefValue::get(MatInst->getType()));
+        addToDeadInsts(Call);
+        continue;
+      }
+    }
 
-Instruction *HLMatrixLowerPass::MatIntrinsicToVec(CallInst *CI) {
-  IRBuilder<> Builder(CI);
-  unsigned opcode = GetHLOpcode(CI);
+    // Otherwise, the user should point to a vector-to-matrix translation
+    // stub of the new vector value.
+    if (VecToMatStub == nullptr) {
+      FunctionType *TranslationStubTy = FunctionType::get(
+        MatInst->getType(), { VecVal->getType() }, /* isVarArg */ false);
+      Function *TranslationStub = m_vecToMatStubs->get(TranslationStubTy);
 
-  if (opcode == static_cast<unsigned>(IntrinsicOp::IOP_frexp))
-    return MatFrExpToVec(CI);
+      Instruction *PrevInst = dyn_cast<Instruction>(VecVal);
+      if (PrevInst == nullptr) PrevInst = MatInst;
 
-  Type *FT = LowerMatrixType(CI->getCalledFunction()->getFunctionType());
+      IRBuilder<> Builder(dxilutil::SkipAllocas(PrevInst->getNextNode()));
+      VecToMatStub = Builder.CreateCall(TranslationStub, { VecVal });
+    }
 
-  HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
+    ValUse.set(VecToMatStub);
+  }
+}
+
+// Replaces all uses of a matrix or matrix array alloca or global variable by its lowered equivalent.
+// This doesn't lower the users, but will insert a translation stub from the lowered value pointer
+// back to the matrix value pointer, and recreate any GEPs around the new pointer.
+// Before: User(GEP(MatrixArrayAlloca))
+// After: User(VecToMatPtrStub(GEP'(VectorArrayAlloca)))
+void HLMatrixLowerPass::replaceAllVariableUses(Value* MatPtr, Value* LoweredPtr) {
+  DXASSERT_NOMSG(HLMatrixType::isMatrixPtrOrArrayPtr(MatPtr->getType()));
+  DXASSERT_NOMSG(LoweredPtr->getType() == HLMatrixType::getLoweredType(MatPtr->getType()));
+
+  SmallVector<Value*, 4> GEPIdxStack;
+  GEPIdxStack.emplace_back(ConstantInt::get(Type::getInt32Ty(MatPtr->getContext()), 0));
+  replaceAllVariableUses(GEPIdxStack, MatPtr, LoweredPtr);
+}
+
+void HLMatrixLowerPass::replaceAllVariableUses(
+    SmallVectorImpl<Value*> &GEPIdxStack, Value *StackTopPtr, Value* LoweredPtr) {
+  while (!StackTopPtr->use_empty()) {
+    llvm::Use &Use = *StackTopPtr->use_begin();
+    if (GEPOperator *GEP = dyn_cast<GEPOperator>(Use.getUser())) {
+      DXASSERT(GEP->getNumIndices() >= 1, "Unexpected degenerate GEP.");
+      DXASSERT(cast<ConstantInt>(*GEP->idx_begin())->isZero(), "Unexpected non-zero first GEP index.");
+
+      // Recurse in GEP to find actual users
+      for (auto It = GEP->idx_begin() + 1; It != GEP->idx_end(); ++It)
+        GEPIdxStack.emplace_back(*It);
+      replaceAllVariableUses(GEPIdxStack, GEP, LoweredPtr);
+      GEPIdxStack.erase(GEPIdxStack.end() - (GEP->getNumIndices() - 1), GEPIdxStack.end());
+      
+      // Discard the GEP
+      DXASSERT_NOMSG(GEP->use_empty());
+      Use.set(UndefValue::get(Use->getType()));
+      if (GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(GEP))
+        addToDeadInsts(GEPInst);
+      continue;
+    }
 
-  Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT), group, opcode);
+    // Recreate the same GEP sequence, if any, on the lowered pointer
+    IRBuilder<> Builder(cast<Instruction>(Use.getUser()));
+    Value *LoweredStackTopPtr = GEPIdxStack.size() == 1
+      ? LoweredPtr : Builder.CreateGEP(LoweredPtr, GEPIdxStack);
 
-  SmallVector<Value *, 4> argList;
-  for (Value *arg : CI->arg_operands()) {
-    Type *Ty = arg->getType();
-    if (dxilutil::IsHLSLMatrixType(Ty)) {
-      argList.emplace_back(UndefValue::get(LowerMatrixType(Ty)));
-    } else
-      argList.emplace_back(arg);
+    // Generate a stub translating the vector pointer back to a matrix pointer,
+    // such that consuming instructions are unaffected.
+    FunctionType *TranslationStubTy = FunctionType::get(
+      StackTopPtr->getType(), { LoweredStackTopPtr->getType() }, /* isVarArg */ false);
+    Function *TranslationStub = m_vecToMatStubs->get(TranslationStubTy);
+    Use.set(Builder.CreateCall(TranslationStub, { LoweredStackTopPtr }));
   }
-
-  return Builder.CreateCall(vecF, argList);
 }
 
-Instruction *HLMatrixLowerPass::TrivialMatUnOpToVec(CallInst *CI) {
-  Type *ResultTy = LowerMatrixType(CI->getType());
-  UndefValue *tmp = UndefValue::get(ResultTy);
-  IRBuilder<> Builder(CI);
-  HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(CI));
-  bool isFloat = ResultTy->getVectorElementType()->isFloatingPointTy();
-
-  Constant *one = isFloat
-    ? ConstantFP::get(ResultTy->getVectorElementType(), 1)
-    : ConstantInt::get(ResultTy->getVectorElementType(), 1);
-  Constant *oneVec = ConstantVector::getSplat(ResultTy->getVectorNumElements(), one);
-
-  Instruction *Result = nullptr;
-  switch (opcode) {
-  case HLUnaryOpcode::Plus: {
-    // This is actually a no-op, but the structure of the code here requires
-    // that we create an instruction.
-    Constant *zero = Constant::getNullValue(ResultTy);
-    if (isFloat)
-      Result = BinaryOperator::CreateFAdd(tmp, zero);
-    else
-      Result = BinaryOperator::CreateAdd(tmp, zero);
-  } break;
-  case HLUnaryOpcode::Minus: {
-    Constant *zero = Constant::getNullValue(ResultTy);
-    if (isFloat)
-      Result = BinaryOperator::CreateFSub(zero, tmp);
-    else
-      Result = BinaryOperator::CreateSub(zero, tmp);
-  } break;
-  case HLUnaryOpcode::LNot: {
-    Constant *zero = Constant::getNullValue(ResultTy);
-    if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_UEQ, tmp, zero);
-    else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, zero);
-  } break;
-  case HLUnaryOpcode::Not: {
-    Constant *allOneBits = Constant::getAllOnesValue(ResultTy);
-    Result = BinaryOperator::CreateXor(tmp, allOneBits);
-  } break;
-  case HLUnaryOpcode::PostInc:
-  case HLUnaryOpcode::PreInc:
-    if (isFloat)
-      Result = BinaryOperator::CreateFAdd(tmp, oneVec);
-    else
-      Result = BinaryOperator::CreateAdd(tmp, oneVec);
-    break;
-  case HLUnaryOpcode::PostDec:
-  case HLUnaryOpcode::PreDec:
-    if (isFloat)
-      Result = BinaryOperator::CreateFSub(tmp, oneVec);
-    else
-      Result = BinaryOperator::CreateSub(tmp, oneVec);
-    break;
-  default:
-    DXASSERT(0, "not implement");
-    return nullptr;
-  }
-  Builder.Insert(Result);
-  return Result;
-}
+void HLMatrixLowerPass::lowerGlobal(GlobalVariable *Global) {
+  if (Global->user_empty()) return;
 
-Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
-  Type *ResultTy = LowerMatrixType(CI->getType());
-  IRBuilder<> Builder(CI);
-  HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(CI));
-  Type *OpTy = LowerMatrixType(
-      CI->getOperand(HLOperandIndex::kBinaryOpSrc0Idx)->getType());
-  UndefValue *tmp = UndefValue::get(OpTy);
-  bool isFloat = OpTy->getVectorElementType()->isFloatingPointTy();
+  PointerType *LoweredPtrTy = cast<PointerType>(HLMatrixType::getLoweredType(Global->getType()));
+  DXASSERT_NOMSG(LoweredPtrTy != Global->getType());
 
-  Instruction *Result = nullptr;
+  Constant *LoweredInitVal = Global->hasInitializer()
+    ? lowerConstInitVal(Global->getInitializer()) : nullptr;
+  GlobalVariable *LoweredGlobal = new GlobalVariable(*m_pModule, LoweredPtrTy->getElementType(),
+    Global->isConstant(), Global->getLinkage(), LoweredInitVal,
+    Global->getName() + ".v", /*InsertBefore*/ nullptr, Global->getThreadLocalMode(),
+    Global->getType()->getAddressSpace());
 
-  switch (opcode) {
-  case HLBinaryOpcode::Add:
-    if (isFloat)
-      Result = BinaryOperator::CreateFAdd(tmp, tmp);
-    else
-      Result = BinaryOperator::CreateAdd(tmp, tmp);
-    break;
-  case HLBinaryOpcode::Sub:
-    if (isFloat)
-      Result = BinaryOperator::CreateFSub(tmp, tmp);
-    else
-      Result = BinaryOperator::CreateSub(tmp, tmp);
-    break;
-  case HLBinaryOpcode::Mul:
-    if (isFloat)
-      Result = BinaryOperator::CreateFMul(tmp, tmp);
-    else
-      Result = BinaryOperator::CreateMul(tmp, tmp);
-    break;
-  case HLBinaryOpcode::Div:
-    if (isFloat)
-      Result = BinaryOperator::CreateFDiv(tmp, tmp);
-    else
-      Result = BinaryOperator::CreateSDiv(tmp, tmp);
-    break;
-  case HLBinaryOpcode::Rem:
-    if (isFloat)
-      Result = BinaryOperator::CreateFRem(tmp, tmp);
-    else
-      Result = BinaryOperator::CreateSRem(tmp, tmp);
-    break;
-  case HLBinaryOpcode::And:
-    Result = BinaryOperator::CreateAnd(tmp, tmp);
-    break;
-  case HLBinaryOpcode::Or:
-    Result = BinaryOperator::CreateOr(tmp, tmp);
-    break;
-  case HLBinaryOpcode::Xor:
-    Result = BinaryOperator::CreateXor(tmp, tmp);
-    break;
-  case HLBinaryOpcode::Shl: {
-    Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-    DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
-                      "must be matrix type here");
-    Result = BinaryOperator::CreateShl(tmp, tmp);
-  } break;
-  case HLBinaryOpcode::Shr: {
-    Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-    DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
-                      "must be matrix type here");
-    Result = BinaryOperator::CreateAShr(tmp, tmp);
-  } break;
-  case HLBinaryOpcode::LT:
-    if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLT, tmp, tmp);
-    else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLT, tmp, tmp);
-    break;
-  case HLBinaryOpcode::GT:
-    if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGT, tmp, tmp);
-    else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT, tmp, tmp);
-    break;
-  case HLBinaryOpcode::LE:
-    if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLE, tmp, tmp);
-    else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLE, tmp, tmp);
-    break;
-  case HLBinaryOpcode::GE:
-    if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGE, tmp, tmp);
-    else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGE, tmp, tmp);
-    break;
-  case HLBinaryOpcode::EQ:
-    if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, tmp);
-    else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, tmp);
-    break;
-  case HLBinaryOpcode::NE:
-    if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, tmp);
-    else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, tmp);
-    break;
-  case HLBinaryOpcode::UDiv:
-    Result = BinaryOperator::CreateUDiv(tmp, tmp);
-    break;
-  case HLBinaryOpcode::URem:
-    Result = BinaryOperator::CreateURem(tmp, tmp);
-    break;
-  case HLBinaryOpcode::UShr: {
-    Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-    DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
-                      "must be matrix type here");
-    Result = BinaryOperator::CreateLShr(tmp, tmp);
-  } break;
-  case HLBinaryOpcode::ULT:
-    Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, tmp, tmp);
-    break;
-  case HLBinaryOpcode::UGT:
-    Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, tmp, tmp);
-    break;
-  case HLBinaryOpcode::ULE:
-    Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE, tmp, tmp);
-    break;
-  case HLBinaryOpcode::UGE:
-    Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGE, tmp, tmp);
-    break;
-  case HLBinaryOpcode::LAnd:
-  case HLBinaryOpcode::LOr: {
-    Value *vecZero = Constant::getNullValue(ResultTy);
-    Instruction *cmpL;
-    if (isFloat)
-      cmpL = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, vecZero);
-    else
-      cmpL = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, vecZero);
-    Builder.Insert(cmpL);
-
-    Instruction *cmpR;
-    if (isFloat)
-      cmpR =
-          CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, vecZero);
-    else
-      cmpR = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, vecZero);
-    Builder.Insert(cmpR);
-
-    // How to map l, r back? Need check opcode
-    if (opcode == HLBinaryOpcode::LOr)
-      Result = BinaryOperator::CreateOr(cmpL, cmpR);
-    else
-      Result = BinaryOperator::CreateAnd(cmpL, cmpR);
-    break;
-  }
-  default:
-    DXASSERT(0, "not implement");
-    return nullptr;
+  // Add debug info.
+  if (m_HasDbgInfo) {
+    DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
+    HLModule::UpdateGlobalVariableDebugInfo(Global, Finder, LoweredGlobal);
   }
-  Builder.Insert(Result);
-  return Result;
-}
 
-// Create BitCast if ptr, otherwise, create alloca of new type, write to bitcast of alloca, and return load from alloca
-// If bOrigAllocaTy is true: create alloca of old type instead, write to alloca, and return load from bitcast of alloca
-static Instruction *BitCastValueOrPtr(Value* V, Instruction *Insert, Type *Ty, bool bOrigAllocaTy = false, const Twine &Name = "") {
-  IRBuilder<> Builder(Insert);
-  if (Ty->isPointerTy()) {
-    // If pointer, we can bitcast directly
-    return cast<Instruction>(Builder.CreateBitCast(V, Ty, Name));
-  } else {
-    // If value, we have to alloca, store to bitcast ptr, and load
-    IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Insert));
-    Type *allocaTy = bOrigAllocaTy ? V->getType() : Ty;
-    Type *otherTy = bOrigAllocaTy ? Ty : V->getType();
-    Instruction *allocaInst = AllocaBuilder.CreateAlloca(allocaTy);
-    Instruction *bitCast = cast<Instruction>(Builder.CreateBitCast(allocaInst, otherTy->getPointerTo()));
-    Builder.CreateStore(V, bOrigAllocaTy ? allocaInst : bitCast);
-    return Builder.CreateLoad(bOrigAllocaTy ? bitCast : allocaInst, Name);
-  }
+  replaceAllVariableUses(Global, LoweredGlobal);
+  Global->removeDeadConstantUsers();
+  Global->eraseFromParent();
 }
 
-void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
-  Value *vecVal = nullptr;
-
-  if (CallInst *CI = dyn_cast<CallInst>(matInst)) {
-    hlsl::HLOpcodeGroup group =
-        hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
-    switch (group) {
-    case HLOpcodeGroup::HLIntrinsic: {
-      vecVal = MatIntrinsicToVec(CI);
-    } break;
-    case HLOpcodeGroup::HLSelect: {
-      vecVal = MatIntrinsicToVec(CI);
-    } break;
-    case HLOpcodeGroup::HLBinOp: {
-      vecVal = TrivialMatBinOpToVec(CI);
-    } break;
-    case HLOpcodeGroup::HLUnOp: {
-      vecVal = TrivialMatUnOpToVec(CI);
-    } break;
-    case HLOpcodeGroup::HLCast: {
-      vecVal = MatCastToVec(CI);
-    } break;
-    case HLOpcodeGroup::HLInit: {
-      vecVal = MatIntrinsicToVec(CI);
-    } break;
-    case HLOpcodeGroup::HLMatLoadStore: {
-      vecVal = MatLdStToVec(CI);
-    } break;
-    case HLOpcodeGroup::HLSubscript: {
-      vecVal = MatSubscriptToVec(CI);
-    } break;
-    case HLOpcodeGroup::NotHL: {
-      // Translate user function return
-      vecVal = BitCastValueOrPtr( matInst,
-                                  matInst->getNextNode(),
-                                  HLMatrixLower::LowerMatrixType(matInst->getType()),
-                                  /*bOrigAllocaTy*/ false,
-                                  matInst->getName());
-      // matrix equivalent of this new vector will be the original, retained user call
-      vecToMatMap[vecVal] = matInst;
-    } break;
-    default:
-      DXASSERT(0, "invalid inst");
-    }
-  } else if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
-    Type *Ty = AI->getAllocatedType();
-    Type *matTy = Ty;
-    
-    IRBuilder<> AllocaBuilder(AI);
-    if (Ty->isArrayTy()) {
-      Type *vecTy = HLMatrixLower::LowerMatrixArrayPointer(AI->getType(), /*forMem*/ true);
-      vecTy = vecTy->getPointerElementType();
-      vecVal = AllocaBuilder.CreateAlloca(vecTy, nullptr, AI->getName());
-    } else {
-      Type *vecTy = HLMatrixLower::LowerMatrixType(matTy, /*forMem*/ true);
-      vecVal = AllocaBuilder.CreateAlloca(vecTy, nullptr, AI->getName());
-    }
-    // Update debug info.
-    DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(AI);
-    if (DDI) {
-      LLVMContext &Context = AI->getContext();
-      Value *DDIVar = MetadataAsValue::get(Context, DDI->getRawVariable());
-      Value *DDIExp = MetadataAsValue::get(Context, DDI->getRawExpression());
-      Value *VMD = MetadataAsValue::get(Context, ValueAsMetadata::get(vecVal));
-      IRBuilder<> debugBuilder(DDI);
-      debugBuilder.CreateCall(DDI->getCalledFunction(), {VMD, DDIVar, DDIExp});
+Constant *HLMatrixLowerPass::lowerConstInitVal(Constant *Val) { // Recursively converts a matrix or array-of-matrices constant into its lowered constant.
+  Type *Ty = Val->getType();
+
+  // Array case: lower every element (itself a matrix or a nested array) and rebuild the array.
+  if (ArrayType *ArrayTy = dyn_cast<ArrayType>(Ty)) {
+    SmallVector<Constant*, 4> LoweredElems;
+    unsigned NumElems = ArrayTy->getNumElements();
+    LoweredElems.reserve(NumElems);
+    for (unsigned ElemIdx = 0; ElemIdx < NumElems; ++ElemIdx) {
+      Constant *ArrayElem = Val->getAggregateElement(ElemIdx);
+      LoweredElems.emplace_back(lowerConstInitVal(ArrayElem));
     }
 
-    if (HLModule::HasPreciseAttributeWithMetadata(AI))
-      HLModule::MarkPreciseAttributeWithMetadata(cast<Instruction>(vecVal));
-
-  } else if (GetIfMatrixGEPOfUDTAlloca(matInst) ||
-             GetIfMatrixGEPOfUDTArg(matInst, *m_pHLModule)) {
-    // If GEP from alloca of non-matrix UDT, bitcast
-    IRBuilder<> Builder(matInst->getNextNode());
-    vecVal = Builder.CreateBitCast(matInst,
-      HLMatrixLower::LowerMatrixType(
-        matInst->getType()->getPointerElementType() )->getPointerTo());
-    // matrix equivalent of this new vector will be the original, retained GEP
-    vecToMatMap[vecVal] = matInst;
-  } else {
-    DXASSERT(0, "invalid inst");
-  }
-  if (vecVal) {
-    matToVecMap[matInst] = vecVal;
+    Type *LoweredElemTy = HLMatrixType::getLoweredType(ArrayTy->getElementType());
+    ArrayType *LoweredArrayTy = ArrayType::get(LoweredElemTy, NumElems); // same element count, lowered element type
+    return ConstantArray::get(LoweredArrayTy, LoweredElems);
   }
-}
 
-// Replace matInst with vecVal on matUseInst.
-void HLMatrixLowerPass::TrivialMatUnOpReplace(Value *matVal,
-                                             Value *vecVal,
-                                             CallInst *matUseInst) {
-  (void)(matVal); // Unused
-  HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(matUseInst));
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
-  switch (opcode) {
-  case HLUnaryOpcode::Plus: // add(x, 0)
-    // Ideally we'd get completely rid of the instruction for +mat,
-    // but matToVecMap needs to point to some instruction.
-  case HLUnaryOpcode::Not: // xor(x, -1)
-  case HLUnaryOpcode::LNot: // cmpeq(x, 0)
-  case HLUnaryOpcode::PostInc:
-  case HLUnaryOpcode::PreInc:
-  case HLUnaryOpcode::PostDec:
-  case HLUnaryOpcode::PreDec:
-    vecUseInst->setOperand(0, vecVal);
-    break;
-  case HLUnaryOpcode::Minus: // sub(0, x)
-    vecUseInst->setOperand(1, vecVal);
-    break;
-  case HLUnaryOpcode::Invalid:
-  case HLUnaryOpcode::NumOfUO:
-    DXASSERT(false, "Unexpected HL unary opcode.");
-    break;
-  }
-}
+  // Base case: anything that is not an array must be a single matrix; flatten it into a vector.
+  HLMatrixType MatTy = HLMatrixType::cast(Ty);
+  DXASSERT_NOMSG(isa<StructType>(Ty));
+  Constant *RowArrayVal = Val->getAggregateElement((unsigned)0); // struct member 0, read below as an array of rows
 
-// Replace matInst with vecVal on matUseInst.
-void HLMatrixLowerPass::TrivialMatBinOpReplace(Value *matVal,
-                                              Value *vecVal,
-                                              CallInst *matUseInst) {
-  HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(matUseInst));
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
-
-  if (opcode != HLBinaryOpcode::LAnd && opcode != HLBinaryOpcode::LOr) {
-    if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) == matVal)
-      vecUseInst->setOperand(0, vecVal);
-    if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) == matVal)
-      vecUseInst->setOperand(1, vecVal);
-  } else {
-    if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) ==
-      matVal) {
-      Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(0));
-      vecCmp->setOperand(0, vecVal);
-    }
-    if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) ==
-      matVal) {
-      Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(1));
-      vecCmp->setOperand(0, vecVal);
+  // The original initializer should already be laid out in the row/column-major order
+  // requested by the qualifiers of the target variable, so elements are copied in storage order.
+  SmallVector<Constant*, 16> MatElems;
+  for (unsigned RowIdx = 0; RowIdx < MatTy.getNumRows(); ++RowIdx) {
+    Constant *RowVal = RowArrayVal->getAggregateElement(RowIdx);
+    for (unsigned ColIdx = 0; ColIdx < MatTy.getNumColumns(); ++ColIdx) {
+      MatElems.emplace_back(RowVal->getAggregateElement(ColIdx));
     }
   }
+
+  Constant *Vec = ConstantVector::get(MatElems);
+  
+  // Matrix elements are always in register representation,
+  // whereas the lowered global variable has a vector type in
+  // its memory representation, so convert register -> memory here.
+
+  // The conversion is expected to fold to a constant (enforced by the cast<Constant> below), so a builder without a valid insertion point is used.
+  IRBuilder<> DummyBuilder(Val->getContext());
+  return cast<Constant>(MatTy.emitLoweredRegToMem(Vec, DummyBuilder));
 }
 
-static Function *GetOrCreateMadIntrinsic(Type *Ty, Type *opcodeTy, IntrinsicOp madOp, Module &M) {
-  llvm::FunctionType *MadFuncTy =
-      llvm::FunctionType::get(Ty, { opcodeTy, Ty, Ty, Ty}, false);
+AllocaInst *HLMatrixLowerPass::lowerAlloca(AllocaInst *MatAlloca) { // Creates the lowered-type twin of MatAlloca, reroutes its uses, and returns the twin.
+  PointerType *LoweredAllocaTy = cast<PointerType>(HLMatrixType::getLoweredType(MatAlloca->getType())); // cast<> asserts the lowered type is still a pointer
+
+  IRBuilder<> Builder(MatAlloca); // insert right before the original alloca
+  AllocaInst *LoweredAlloca = Builder.CreateAlloca(
+    LoweredAllocaTy->getElementType(), nullptr, MatAlloca->getName());
+
+  // If a dbg.declare describes the matrix alloca, emit an equivalent one (same variable and expression) for the lowered alloca.
+  if (DbgDeclareInst *DbgDeclare = llvm::FindAllocaDbgDeclare(MatAlloca)) {
+    LLVMContext &Context = MatAlloca->getContext();
+    Value *DbgDeclareVar = MetadataAsValue::get(Context, DbgDeclare->getRawVariable());
+    Value *DbgDeclareExpr = MetadataAsValue::get(Context, DbgDeclare->getRawExpression());
+    Value *ValueMetadata = MetadataAsValue::get(Context, ValueAsMetadata::get(LoweredAlloca));
+    IRBuilder<> DebugBuilder(DbgDeclare); // the new declare is placed before the existing one
+    DebugBuilder.CreateCall(DbgDeclare->getCalledFunction(), { ValueMetadata, DbgDeclareVar, DbgDeclareExpr });
+  }
 
-  Function *MAD =
-      GetOrCreateHLFunction(M, MadFuncTy, HLOpcodeGroup::HLIntrinsic,
-                            (unsigned)madOp);
-  return MAD;
+  if (HLModule::HasPreciseAttributeWithMetadata(MatAlloca)) // carry the precise marker over to the replacement
+    HLModule::MarkPreciseAttributeWithMetadata(LoweredAlloca);
+
+  replaceAllVariableUses(MatAlloca, LoweredAlloca);
+
+  return LoweredAlloca;
 }
 
-void HLMatrixLowerPass::TranslateMatMatMul(Value *matVal,
-                                           Value *vecVal,
-                                           CallInst *mulInst, bool isSigned) {
-  (void)(matVal); // Unused; retrieved from matToVecMap directly
-  DXASSERT(matToVecMap.count(mulInst), "must have vec version");
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[mulInst]);
-  // Already translated.
-  if (!isa<CallInst>(vecUseInst))
-    return;
-  Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
-  Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-
-  unsigned col, row;
-  Type *EltTy = GetMatrixInfo(LVal->getType(), col, row);
-  unsigned rCol, rRow;
-  GetMatrixInfo(RVal->getType(), rCol, rRow);
-  DXASSERT_NOMSG(col == rRow);
-
-  bool isFloat = EltTy->isFloatingPointTy();
-
-  Value *retVal = llvm::UndefValue::get(LowerMatrixType(mulInst->getType()));
-  IRBuilder<> Builder(vecUseInst);
-
-  Value *lMat = matToVecMap[cast<Instruction>(LVal)];
-  Value *rMat = matToVecMap[cast<Instruction>(RVal)];
-
-  auto CreateOneEltMul = [&](unsigned r, unsigned lc, unsigned c) -> Value * {
-    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
-    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
-    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
-    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
-    return isFloat ? Builder.CreateFMul(lMatElt, rMatElt)
-                   : Builder.CreateMul(lMatElt, rMatElt);
-  };
-
-  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
-  Type *opcodeTy = Builder.getInt32Ty();
-  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
-                                          *m_pHLModule->GetModule());
-  Value *madOpArg = Builder.getInt32((unsigned)madOp);
-
-  auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
-                             Value *acc) -> Value * {
-    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
-    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
-    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
-    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
-    return Builder.CreateCall(Mad, {madOpArg, lMatElt, rMatElt, acc});
-  };
-
-  for (unsigned r = 0; r < row; r++) {
-    for (unsigned c = 0; c < rCol; c++) {
-      unsigned lc = 0;
-      Value *tmpVal = CreateOneEltMul(r, lc, c);
-
-      for (lc = 1; lc < col; lc++) {
-        tmpVal = CreateOneEltMad(r, lc, c, tmpVal);
-      }
-      unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, rCol);
-      retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
+void HLMatrixLowerPass::lowerInstruction(Instruction* Inst) { // Lowers one matrix instruction; only calls and returns are accepted.
+  if (CallInst *Call = dyn_cast<CallInst>(Inst)) {
+    Value *LoweredValue = lowerCall(Call);
+
+    // A non-null result from lowerCall means the original matrix call is
+    // superseded: redirect all of its uses to the lowered value and hand
+    // the call to addToDeadInsts. A null result means lowerCall opted out.
+    if (LoweredValue != nullptr) {
+      replaceAllUsesByLoweredValue(Call, LoweredValue);
+      addToDeadInsts(Inst);
     }
   }
+  else if (ReturnInst *Return = dyn_cast<ReturnInst>(Inst)) {
+    lowerReturn(Return);
+  }
+  else
+    llvm_unreachable("Unexpected matrix instruction type.");
+}
+
+void HLMatrixLowerPass::lowerReturn(ReturnInst* Return) { // Makes a by-value matrix return go through its lowered vector.
+  Value *RetVal = Return->getReturnValue();
+  Type *RetTy = RetVal->getType();
+  DXASSERT(!RetTy->isPointerTy(), "Unexpected matrix returned by pointer.");
+
+  IRBuilder<> Builder(Return); // new instructions are inserted just before the ret
+  Value *LoweredRetVal = getLoweredByValOperand(RetVal, Builder, /* DiscardStub */ true);
+
+  // Function signatures are not lowered here, so the ret must still yield the matrix type:
+  // cast the lowered value back to it; HLMatrixBitcastLower knows how to eliminate that bitcast.
+  Value *BitCastedRetVal = bitCastValue(LoweredRetVal, RetVal->getType(), /* DstTyAlloca */ false, Builder);
+  Return->setOperand(0, BitCastedRetVal);
+}
+
+Value *HLMatrixLowerPass::lowerCall(CallInst *Call) { // Picks the lowering routine from the callee's HL opcode group and returns its result.
+  HLOpcodeGroup OpcodeGroup = GetHLOpcodeGroupByName(Call->getCalledFunction());
+  return OpcodeGroup == HLOpcodeGroup::NotHL
+    ? lowerNonHLCall(Call) : lowerHLOperation(Call, OpcodeGroup); // NotHL goes to lowerNonHLCall, every other group to lowerHLOperation
+}
+
+Value *HLMatrixLowerPass::lowerNonHLCall(CallInst *Call) { // Adapts a call to a non-HL function in place; always returns nullptr.
+  // Step 1: rewrite every operand of matrix-derived type.
+  // Callee signatures are not lowered in this pass,
+  // so each such argument is fetched in lowered form and then bitcast
+  // back to the matrix type the callee expects; the later HLMatrixBitcastLower
+  // pass knows how to eliminate those bitcasts.
+  IRBuilder<> PreCallBuilder(Call);
+  unsigned NumArgs = Call->getNumArgOperands();
+  for (unsigned ArgIdx = 0; ArgIdx < NumArgs; ++ArgIdx) {
+    Use &ArgUse = Call->getArgOperandUse(ArgIdx);
+    if (ArgUse->getType()->isPointerTy()) {
+      // By-reference argument
+      Value *LoweredArg = tryGetLoweredPtrOperand(ArgUse.get(), PreCallBuilder, /* DiscardStub */ true);
+      if (LoweredArg != nullptr) {
+        // Pointer to a matrix that has been lowered: bitcast it back to the matrix pointer type.
+        Value *BitCastedArg = PreCallBuilder.CreateBitCast(LoweredArg, ArgUse->getType());
+        ArgUse.set(BitCastedArg);
+      }
+    }
+    else {
+      // By-value argument
+      Value *LoweredArg = getLoweredByValOperand(ArgUse.get(), PreCallBuilder, /* DiscardStub */ true);
+      if (LoweredArg == ArgUse.get()) continue; // operand was handed back unchanged, nothing to rewrite
 
-  Instruction *matmatMul = cast<Instruction>(retVal);
-  // Replace vec transpose function call with shuf.
-  vecUseInst->replaceAllUsesWith(matmatMul);
-  AddToDeadInsts(vecUseInst);
-  matToVecMap[mulInst] = matmatMul;
-}
+      Value *BitCastedArg = bitCastValue(LoweredArg, ArgUse->getType(), /* DstTyAlloca */ false, PreCallBuilder);
+      ArgUse.set(BitCastedArg);
+    }
+  }
 
-void HLMatrixLowerPass::TranslateMatVecMul(Value *matVal,
-                                           Value *vecVal,
-                                           CallInst *mulInst, bool isSigned) {
-  // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
-  Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
+  // Step 2: handle the return type.
+  HLMatrixType RetMatTy = HLMatrixType::dyn_cast(Call->getType());
+  if (!RetMatTy) {
+    DXASSERT(!HLMatrixType::isMatrixPtrOrArrayPtr(Call->getType()),
+      "Unexpected user call returning a matrix by pointer.");
+    // Nothing to replace: other instructions can consume a non-matrix return value directly.
+    return nullptr;
+  }
 
-  unsigned col, row;
-  Type *EltTy = GetMatrixInfo(matVal->getType(), col, row);
-  DXASSERT_NOMSG(RVal->getType()->getVectorNumElements() == col);
+  // The callee returns a matrix by value, and signatures are not lowered in this pass.
+  // The result is reinterpreted through memory: an alloca of the lowered register type is
+  // bitcast to a matrix pointer, which the later HLMatrixBitcastLower pass knows how to eliminate.
+  IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Call));
+  Value *LoweredAlloca = AllocaBuilder.CreateAlloca(RetMatTy.getLoweredVectorTypeForReg());
+  
+  IRBuilder<> PostCallBuilder(Call->getNextNode());
+  Value *BitCastedAlloca = PostCallBuilder.CreateBitCast(LoweredAlloca, Call->getType()->getPointerTo());
+  
+  // Statement order matters here.
+  // Every use of the matrix-returning call is to be replaced by the loaded vector,
+  // but the store into the bitcasted pointer is itself a use of that call,
+  // so: create the load, replace the uses, and only then insert the store.
+  LoadInst *LoweredVal = PostCallBuilder.CreateLoad(LoweredAlloca);
+  replaceAllUsesByLoweredValue(Call, LoweredVal);
+
+  // Now the store can be inserted; it must be placed before the load.
+  PostCallBuilder.SetInsertPoint(LoweredVal);
+  PostCallBuilder.CreateStore(Call, BitCastedAlloca);
+  
+  // Return nullptr: uses were already replaced above, and the original call
+  // must not be marked as dead because the store still consumes it.
+  return nullptr;
+}
 
-  bool isFloat = EltTy->isFloatingPointTy();
+Value *HLMatrixLowerPass::lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeGroup) { // Forwards an HL operation call to the routine for its opcode group and returns that routine's result.
+  IRBuilder<> Builder(Call); // inserts before Call; passed to the helpers below that take a builder
+  switch (OpcodeGroup) {
+  case HLOpcodeGroup::HLIntrinsic:
+    return lowerHLIntrinsic(Call, static_cast<IntrinsicOp>(GetHLOpcode(Call)));
 
-  Value *retVal = llvm::UndefValue::get(mulInst->getType());
-  IRBuilder<> Builder(mulInst);
+  case HLOpcodeGroup::HLBinOp:
+    return lowerHLBinaryOperation(
+      Call->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
+      Call->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
+      static_cast<HLBinaryOpcode>(GetHLOpcode(Call)), Builder);
 
-  Value *vec = RVal;
-  Value *mat = vecVal; // vec version of matInst;
+  case HLOpcodeGroup::HLUnOp:
+    return lowerHLUnaryOperation(
+      Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx),
+      static_cast<HLUnaryOpcode>(GetHLOpcode(Call)), Builder);
 
-  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
-  Type *opcodeTy = Builder.getInt32Ty();
-  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
-                                          *m_pHLModule->GetModule());
-  Value *madOpArg = Builder.getInt32((unsigned)madOp);
+  case HLOpcodeGroup::HLMatLoadStore:
+    return lowerHLLoadStore(Call, static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(Call)));
 
-  auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
-    Value *vecElt = Builder.CreateExtractElement(vec, c);
-    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
-    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
-    return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
-  };
+  case HLOpcodeGroup::HLCast:
+    return lowerHLCast(
+      Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Call->getType(), // source operand and destination type
+      static_cast<HLCastOpcode>(GetHLOpcode(Call)), Builder);
 
-  for (unsigned r = 0; r < row; r++) {
-    unsigned c = 0;
-    Value *vecElt = Builder.CreateExtractElement(vec, c);
-    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
-    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
+  case HLOpcodeGroup::HLSubscript:
+    return lowerHLSubscript(Call, static_cast<HLSubscriptOpcode>(GetHLOpcode(Call)));
 
-    Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
-                            : Builder.CreateMul(vecElt, matElt);
+  case HLOpcodeGroup::HLInit:
+    return lowerHLInit(Call);
 
-    for (c = 1; c < col; c++) {
-      tmpVal = CreateOneEltMad(r, c, tmpVal);
-    }
+  case HLOpcodeGroup::HLSelect:
+    return lowerHLSelect(Call);
 
-    retVal = Builder.CreateInsertElement(retVal, tmpVal, r);
+  default: // any opcode group not listed above is treated as a programming error
+    llvm_unreachable("Unexpected matrix opcode");
   }
-
-  mulInst->replaceAllUsesWith(retVal);
-  AddToDeadInsts(mulInst);
 }
 
-void HLMatrixLowerPass::TranslateVecMatMul(Value *matVal,
-                                           Value *vecVal,
-                                           CallInst *mulInst, bool isSigned) {
-  Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
-  // matVal should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-  Value *RVal = vecVal;
-
-  unsigned col, row;
-  Type *EltTy = GetMatrixInfo(matVal->getType(), col, row);
-  DXASSERT_NOMSG(LVal->getType()->getVectorNumElements() == row);
-
-  bool isFloat = EltTy->isFloatingPointTy();
-
-  Value *retVal = llvm::UndefValue::get(mulInst->getType());
-  IRBuilder<> Builder(mulInst);
-
-  Value *vec = LVal;
-  Value *mat = RVal;
-
-  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
-  Type *opcodeTy = Builder.getInt32Ty();
-  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
-                                          *m_pHLModule->GetModule());
-  Value *madOpArg = Builder.getInt32((unsigned)madOp);
-
-  auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
-    Value *vecElt = Builder.CreateExtractElement(vec, r);
-    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
-    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
-    return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
-  };
-
-  for (unsigned c = 0; c < col; c++) {
-    unsigned r = 0;
-    Value *vecElt = Builder.CreateExtractElement(vec, r);
-    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
-    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
-
-    Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
-                            : Builder.CreateMul(vecElt, matElt);
-
-    for (r = 1; r < row; r++) {
-      tmpVal = CreateOneEltMad(r, c, tmpVal);
-    }
+static Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
+  Type *RetTy, ArrayRef<Value*> Args, IRBuilder<> &Builder) { // Emits, via Builder, a call to the HL function for (OpcodeGroup, Opcode) and returns the call.
+  SmallVector<Type*, 4> ArgTys; // parameter types are taken from the actual arguments
+  ArgTys.reserve(Args.size());
+  for (Value *Arg : Args)
+    ArgTys.emplace_back(Arg->getType());
 
-    retVal = Builder.CreateInsertElement(retVal, tmpVal, c);
-  }
+  FunctionType *FuncTy = FunctionType::get(RetTy, ArgTys, /* isVarArg */ false);
+  Function *Func = GetOrCreateHLFunction(Module, FuncTy, OpcodeGroup, Opcode); // looked up or declared in Module for this exact signature
 
-  mulInst->replaceAllUsesWith(retVal);
-  AddToDeadInsts(mulInst);
+  return Builder.CreateCall(Func, Args);
 }
 
-void HLMatrixLowerPass::TranslateMul(Value *matVal, Value *vecVal,
-                                     CallInst *mulInst, bool isSigned) {
-  Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
-  Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-
-  bool LMat = dxilutil::IsHLSLMatrixType(LVal->getType());
-  bool RMat = dxilutil::IsHLSLMatrixType(RVal->getType());
-  if (LMat && RMat) {
-    TranslateMatMatMul(matVal, vecVal, mulInst, isSigned);
-  } else if (LMat) {
-    TranslateMatVecMul(matVal, vecVal, mulInst, isSigned);
-  } else {
-    TranslateVecMatMul(matVal, vecVal, mulInst, isSigned);
-  }
-}
+// Lowers a call to an HL intrinsic.
+// mul/umul, transpose and determinant are expanded inline by dedicated helpers.
+// Every other opcode is re-emitted as a call to the same intrinsic opcode whose
+// by-value operands and return type are replaced by their lowered equivalents.
+// New instructions are inserted before Call; the newly emitted value is returned.
+Value *HLMatrixLowerPass::lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode) {
+  IRBuilder<> Builder(Call);
 
-void HLMatrixLowerPass::TranslateMatTranspose(Value *matVal,
-                                              Value *vecVal,
-                                              CallInst *transposeInst) {
-  // Matrix value is row major, transpose is cast it to col major.
-  TranslateMatMajorCast(matVal, vecVal, transposeInst,
-      /*bRowToCol*/ true, /*bTranspose*/ true);
-}
+  // See if this is a matrix-specific intrinsic which we should expand here
+  switch (Opcode) {
+  case IntrinsicOp::IOP_umul:
+  case IntrinsicOp::IOP_mul:
+    return lowerHLMulIntrinsic(
+      Call->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
+      Call->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
+      /* Unsigned */ Opcode == IntrinsicOp::IOP_umul, Builder);
+  case IntrinsicOp::IOP_transpose:
+    return lowerHLTransposeIntrinsic(Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Builder);
+  case IntrinsicOp::IOP_determinant:
+    return lowerHLDeterminantIntrinsic(Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Builder);
+  default:
+    // Not a matrix-specific intrinsic: handled by the generic lowering below.
+    break;
+  }
+
+  // Delegate to a lowered intrinsic call
+  SmallVector<Value*, 4> LoweredArgs;
+  LoweredArgs.reserve(Call->getNumArgOperands());
+  for (Value *Arg : Call->arg_operands()) {
+    if (Arg->getType()->isPointerTy()) {
+      // ByRef parameter (for example, frexp's second parameter)
+      // If the argument points to a lowered matrix variable, replace it here,
+      // otherwise preserve the matrix type and let further passes handle the lowering.
+      Value *LoweredArg = tryGetLoweredPtrOperand(Arg, Builder);
+      if (LoweredArg == nullptr) LoweredArg = Arg;
+      LoweredArgs.emplace_back(LoweredArg);
+    }
+    else {
+      LoweredArgs.emplace_back(getLoweredByValOperand(Arg, Builder));
+    }
+  }
 
-static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
-                             IRBuilder<> &Builder) {
-  Value *mul0 = Builder.CreateFMul(m00, m11);
-  Value *mul1 = Builder.CreateFMul(m01, m10);
-  return Builder.CreateFSub(mul0, mul1);
+  // The replacement call keeps the opcode but uses the lowered return type.
+  Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
+  return callHLFunction(*m_pModule, HLOpcodeGroup::HLIntrinsic, static_cast<unsigned>(Opcode),
+    LoweredRetTy, LoweredArgs, Builder);
 }
 
-static Value *Determinant3x3(Value *m00, Value *m01, Value *m02,
-                             Value *m10, Value *m11, Value *m12,
-                             Value *m20, Value *m21, Value *m22,
-                             IRBuilder<> &Builder) {
-  Value *deter00 = Determinant2x2(m11, m12, m21, m22, Builder);
-  Value *deter01 = Determinant2x2(m10, m12, m20, m22, Builder);
-  Value *deter02 = Determinant2x2(m10, m11, m20, m21, Builder);
-  deter00 = Builder.CreateFMul(m00, deter00);
-  deter01 = Builder.CreateFMul(m01, deter01);
-  deter02 = Builder.CreateFMul(m02, deter02);
-  Value *result = Builder.CreateFSub(deter00, deter01);
-  result = Builder.CreateFAdd(result, deter02);
-  return result;
-}
+// Expands a mul/umul intrinsic in which at least one operand is a matrix
+// (matrix*matrix, matrix*vector or vector*matrix) into scalar operations on the
+// lowered vector operands. A vector on the right is treated as N rows by 1 column,
+// a vector on the left as 1 row by N columns. Each result element starts as a plain
+// multiply and accumulates the remaining products through the mad/umad HL intrinsic
+// (umad when Unsigned is true).
+// Returns a vector of LhsNumRows * RhsNumCols elements indexed in row-major order.
+Value *HLMatrixLowerPass::lowerHLMulIntrinsic(Value* Lhs, Value *Rhs,
+    bool Unsigned, IRBuilder<> &Builder) {
+  HLMatrixType LhsMatTy = HLMatrixType::dyn_cast(Lhs->getType());
+  HLMatrixType RhsMatTy = HLMatrixType::dyn_cast(Rhs->getType());
+  Value* LoweredLhs = getLoweredByValOperand(Lhs, Builder);
+  Value* LoweredRhs = getLoweredByValOperand(Rhs, Builder);
 
-static Value *Determinant4x4(Value *m00, Value *m01, Value *m02, Value *m03,
-                             Value *m10, Value *m11, Value *m12, Value *m13,
-                             Value *m20, Value *m21, Value *m22, Value *m23,
-                             Value *m30, Value *m31, Value *m32, Value *m33,
-                             IRBuilder<> &Builder) {
-  Value *deter00 = Determinant3x3(m11, m12, m13, m21, m22, m23, m31, m32, m33, Builder);
-  Value *deter01 = Determinant3x3(m10, m12, m13, m20, m22, m23, m30, m32, m33, Builder);
-  Value *deter02 = Determinant3x3(m10, m11, m13, m20, m21, m23, m30, m31, m33, Builder);
-  Value *deter03 = Determinant3x3(m10, m11, m12, m20, m21, m22, m30, m31, m32, Builder);
-  deter00 = Builder.CreateFMul(m00, deter00);
-  deter01 = Builder.CreateFMul(m01, deter01);
-  deter02 = Builder.CreateFMul(m02, deter02);
-  deter03 = Builder.CreateFMul(m03, deter03);
-  Value *result = Builder.CreateFSub(deter00, deter01);
-  result = Builder.CreateFAdd(result, deter02);
-  result = Builder.CreateFSub(result, deter03);
-  return result;
-}
+  DXASSERT(LoweredLhs->getType()->getScalarType() == LoweredRhs->getType()->getScalarType(),
+    "Unexpected element type mismatch in mul intrinsic.");
+  // Check both operands with isVectorTy(): cast<> asserts on its own and never
+  // yields null, so it cannot serve as an assertion condition.
+  DXASSERT(LoweredLhs->getType()->isVectorTy() && LoweredRhs->getType()->isVectorTy(),
+    "Unexpected scalar in lowered matrix mul intrinsic operands.");
 
+  Type* ElemTy = LoweredLhs->getType()->getScalarType();
 
-void HLMatrixLowerPass::TranslateMatDeterminant(Value *matVal, Value *vecVal,
-    CallInst *determinantInst) {
-  unsigned row, col;
-  GetMatrixInfo(matVal->getType(), col, row);
-  IRBuilder<> Builder(determinantInst);
-  // when row == 1, result is vecVal.
-  Value *Result = vecVal;
-  if (row == 2) {
-    Value *m00 = Builder.CreateExtractElement(vecVal, (uint64_t)0);
-    Value *m01 = Builder.CreateExtractElement(vecVal, 1);
-    Value *m10 = Builder.CreateExtractElement(vecVal, 2);
-    Value *m11 = Builder.CreateExtractElement(vecVal, 3);
-    Result = Determinant2x2(m00, m01, m10, m11, Builder);
+  // Figure out the dimensions of each side
+  unsigned LhsNumRows, LhsNumCols, RhsNumRows, RhsNumCols;
+  if (LhsMatTy && RhsMatTy) {
+    LhsNumRows = LhsMatTy.getNumRows();
+    LhsNumCols = LhsMatTy.getNumColumns();
+    RhsNumRows = RhsMatTy.getNumRows();
+    RhsNumCols = RhsMatTy.getNumColumns();
   }
-  else if (row == 3) {
-    Value *m00 = Builder.CreateExtractElement(vecVal, (uint64_t)0);
-    Value *m01 = Builder.CreateExtractElement(vecVal, 1);
-    Value *m02 = Builder.CreateExtractElement(vecVal, 2);
-    Value *m10 = Builder.CreateExtractElement(vecVal, 3);
-    Value *m11 = Builder.CreateExtractElement(vecVal, 4);
-    Value *m12 = Builder.CreateExtractElement(vecVal, 5);
-    Value *m20 = Builder.CreateExtractElement(vecVal, 6);
-    Value *m21 = Builder.CreateExtractElement(vecVal, 7);
-    Value *m22 = Builder.CreateExtractElement(vecVal, 8);
-    Result = Determinant3x3(m00, m01, m02, 
-                            m10, m11, m12, 
-                            m20, m21, m22, Builder);
+  else if (LhsMatTy) {
+    // matrix * vector: the right-hand vector acts as a single column.
+    LhsNumRows = LhsMatTy.getNumRows();
+    LhsNumCols = LhsMatTy.getNumColumns();
+    RhsNumRows = LoweredRhs->getType()->getVectorNumElements();
+    RhsNumCols = 1;
   }
-  else if (row == 4) {
-    Value *m00 = Builder.CreateExtractElement(vecVal, (uint64_t)0);
-    Value *m01 = Builder.CreateExtractElement(vecVal, 1);
-    Value *m02 = Builder.CreateExtractElement(vecVal, 2);
-    Value *m03 = Builder.CreateExtractElement(vecVal, 3);
-
-    Value *m10 = Builder.CreateExtractElement(vecVal, 4);
-    Value *m11 = Builder.CreateExtractElement(vecVal, 5);
-    Value *m12 = Builder.CreateExtractElement(vecVal, 6);
-    Value *m13 = Builder.CreateExtractElement(vecVal, 7);
-
-    Value *m20 = Builder.CreateExtractElement(vecVal, 8);
-    Value *m21 = Builder.CreateExtractElement(vecVal, 9);
-    Value *m22 = Builder.CreateExtractElement(vecVal, 10);
-    Value *m23 = Builder.CreateExtractElement(vecVal, 11);
-
-    Value *m30 = Builder.CreateExtractElement(vecVal, 12);
-    Value *m31 = Builder.CreateExtractElement(vecVal, 13);
-    Value *m32 = Builder.CreateExtractElement(vecVal, 14);
-    Value *m33 = Builder.CreateExtractElement(vecVal, 15);
-
-    Result = Determinant4x4(m00, m01, m02, m03,
-                            m10, m11, m12, m13,
-                            m20, m21, m22, m23,
-                            m30, m31, m32, m33,
-                            Builder);
-  } else {
-    DXASSERT(row == 1, "invalid matrix type");
-    Result = Builder.CreateExtractElement(Result, (uint64_t)0);
+  else if (RhsMatTy) {
+    // vector * matrix: the left-hand vector acts as a single row.
+    LhsNumRows = 1;
+    LhsNumCols = LoweredLhs->getType()->getVectorNumElements();
+    RhsNumRows = RhsMatTy.getNumRows();
+    RhsNumCols = RhsMatTy.getNumColumns();
   }
-  determinantInst->replaceAllUsesWith(Result);
-  AddToDeadInsts(determinantInst);
+  else {
+    llvm_unreachable("mul intrinsic was identified as a matrix operation but neither operand is a matrix.");
+  }
+
+  DXASSERT(LhsNumCols == RhsNumRows, "Matrix mul intrinsic operands dimensions mismatch.");
+  HLMatrixType ResultMatTy(ElemTy, LhsNumRows, RhsNumCols);
+  // Number of products summed for each result element.
+  unsigned AccCount = LhsNumCols;
+  // Loop-invariant: selects fmul vs. mul for the first product of each element.
+  const bool IsFloat = ElemTy->isFloatingPointTy();
+
+  // Get the multiply-and-add intrinsic function, we'll need it
+  IntrinsicOp MadOpcode = Unsigned ? IntrinsicOp::IOP_umad : IntrinsicOp::IOP_mad;
+  FunctionType *MadFuncTy = FunctionType::get(ElemTy, { Builder.getInt32Ty(), ElemTy, ElemTy, ElemTy }, false);
+  Function *MadFunc = GetOrCreateHLFunction(*m_pModule, MadFuncTy, HLOpcodeGroup::HLIntrinsic, (unsigned)MadOpcode);
+  Constant *MadOpcodeVal = Builder.getInt32((unsigned)MadOpcode);
+
+  // Perform the multiplication!
+  Value *Result = UndefValue::get(VectorType::get(ElemTy, LhsNumRows * RhsNumCols));
+  for (unsigned ResultRowIdx = 0; ResultRowIdx < ResultMatTy.getNumRows(); ++ResultRowIdx) {
+    for (unsigned ResultColIdx = 0; ResultColIdx < ResultMatTy.getNumColumns(); ++ResultColIdx) {
+      unsigned ResultElemIdx = ResultMatTy.getRowMajorIndex(ResultRowIdx, ResultColIdx);
+      Value *ResultElem = nullptr;
+
+      for (unsigned AccIdx = 0; AccIdx < AccCount; ++AccIdx) {
+        unsigned LhsElemIdx = HLMatrixType::getRowMajorIndex(ResultRowIdx, AccIdx, LhsNumRows, LhsNumCols);
+        unsigned RhsElemIdx = HLMatrixType::getRowMajorIndex(AccIdx, ResultColIdx, RhsNumRows, RhsNumCols);
+        Value* LhsElem = Builder.CreateExtractElement(LoweredLhs, static_cast<uint64_t>(LhsElemIdx));
+        Value* RhsElem = Builder.CreateExtractElement(LoweredRhs, static_cast<uint64_t>(RhsElemIdx));
+        if (ResultElem == nullptr) {
+          // First product: nothing to accumulate into yet.
+          ResultElem = IsFloat
+            ? Builder.CreateFMul(LhsElem, RhsElem)
+            : Builder.CreateMul(LhsElem, RhsElem);
+        }
+        else {
+          // Subsequent products: mad(Lhs, Rhs, running sum).
+          ResultElem = Builder.CreateCall(MadFunc, { MadOpcodeVal, LhsElem, RhsElem, ResultElem });
+        }
+      }
+
+      Result = Builder.CreateInsertElement(Result, ResultElem, static_cast<uint64_t>(ResultElemIdx));
+    }
+  }
+
+  return Result;
 }
 
-void HLMatrixLowerPass::TrivialMatReplace(Value *matVal,
-                                         Value *vecVal,
-                                         CallInst *matUseInst) {
-  CallInst *vecUseInst = cast<CallInst>(matToVecMap[matUseInst]);
+// Lowers a transpose intrinsic: fetches the lowered vector for the matrix operand
+// and returns the result of the matrix type's emitLoweredVectorRowToCol on it.
+Value *HLMatrixLowerPass::lowerHLTransposeIntrinsic(Value* MatVal, IRBuilder<> &Builder) {
+  // The operand must have a matrix type; resolve it before emitting anything.
+  HLMatrixType SrcMatTy = HLMatrixType::cast(MatVal->getType());
+  Value *SrcVec = getLoweredByValOperand(MatVal, Builder);
+  Value *Transposed = SrcMatTy.emitLoweredVectorRowToCol(SrcVec, Builder);
+  return Transposed;
+}
 
-  for (unsigned i = 0; i < matUseInst->getNumArgOperands(); i++)
-    if (matUseInst->getArgOperand(i) == matVal) {
-      vecUseInst->setArgOperand(i, vecVal);
-    }
+// Emits the floating-point determinant of the 2x2 matrix
+//   | M00 M01 |
+//   | M10 M11 |
+// as M00 * M11 - M01 * M10.
+static Value *determinant2x2(Value *M00, Value *M01, Value *M10, Value *M11, IRBuilder<> &Builder) {
+  Value *MainDiagonal = Builder.CreateFMul(M00, M11);
+  Value *AntiDiagonal = Builder.CreateFMul(M01, M10);
+  Value *Det = Builder.CreateFSub(MainDiagonal, AntiDiagonal);
+  return Det;
 }
 
-static Instruction *CreateTransposeShuffle(IRBuilder<> &Builder, Value *vecVal, unsigned toRows, unsigned toCols) {
-  SmallVector<int, 16> castMask(toCols * toRows);
-  unsigned idx = 0;
-  for (unsigned r = 0; r < toRows; r++)
-    for (unsigned c = 0; c < toCols; c++)
-      castMask[idx++] = c * toRows + r;
-  return cast<Instruction>(
-    Builder.CreateShuffleVector(vecVal, vecVal, castMask));
+// Emits the floating-point determinant of a 3x3 matrix by expanding along the
+// first row: M00 * minor00 - M01 * minor01 + M02 * minor02.
+static Value *determinant3x3(Value *M00, Value *M01, Value *M02,
+    Value *M10, Value *M11, Value *M12,
+    Value *M20, Value *M21, Value *M22,
+    IRBuilder<> &Builder) {
+  // 2x2 minors of the first-row elements (row 0 and one column removed).
+  Value *Minor00 = determinant2x2(M11, M12, M21, M22, Builder);
+  Value *Minor01 = determinant2x2(M10, M12, M20, M22, Builder);
+  Value *Minor02 = determinant2x2(M10, M11, M20, M21, Builder);
+  Value *Term00 = Builder.CreateFMul(M00, Minor00);
+  Value *Term01 = Builder.CreateFMul(M01, Minor01);
+  Value *Term02 = Builder.CreateFMul(M02, Minor02);
+  // Alternating signs: + - +
+  Value *Partial = Builder.CreateFSub(Term00, Term01);
+  return Builder.CreateFAdd(Partial, Term02);
 }
 
-void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal,
-                                              Value *vecVal,
-                                              CallInst *castInst,
-                                              bool bRowToCol,
-                                              bool bTranspose) {
-  unsigned col, row;
-  if (!bTranspose) {
-    GetMatrixInfo(castInst->getType(), col, row);
-    DXASSERT(castInst->getType() == matVal->getType(), "type must match");
-  } else {
-    unsigned castCol, castRow;
-    Type *castTy = GetMatrixInfo(castInst->getType(), castCol, castRow);
-    unsigned srcCol, srcRow;
-    Type *srcTy = GetMatrixInfo(matVal->getType(), srcCol, srcRow);
-    DXASSERT_LOCALVAR((castTy == srcTy), srcTy == castTy, "type must match");
-    DXASSERT(castCol == srcRow && castRow == srcCol, "col row must match");
-    col = srcCol;
-    row = srcRow;
-  }
+// Emits the floating-point determinant of a 4x4 matrix by expanding along the
+// first row: M00 * minor00 - M01 * minor01 + M02 * minor02 - M03 * minor03.
+static Value *determinant4x4(Value *M00, Value *M01, Value *M02, Value *M03,
+    Value *M10, Value *M11, Value *M12, Value *M13,
+    Value *M20, Value *M21, Value *M22, Value *M23,
+    Value *M30, Value *M31, Value *M32, Value *M33,
+    IRBuilder<> &Builder) {
+  // 3x3 minors of the first-row elements (row 0 and one column removed).
+  Value *Minor00 = determinant3x3(M11, M12, M13, M21, M22, M23, M31, M32, M33, Builder);
+  Value *Minor01 = determinant3x3(M10, M12, M13, M20, M22, M23, M30, M32, M33, Builder);
+  Value *Minor02 = determinant3x3(M10, M11, M13, M20, M21, M23, M30, M31, M33, Builder);
+  Value *Minor03 = determinant3x3(M10, M11, M12, M20, M21, M22, M30, M31, M32, Builder);
+  Value *Term00 = Builder.CreateFMul(M00, Minor00);
+  Value *Term01 = Builder.CreateFMul(M01, Minor01);
+  Value *Term02 = Builder.CreateFMul(M02, Minor02);
+  Value *Term03 = Builder.CreateFMul(M03, Minor03);
+  // Alternating signs: + - + -
+  Value *Partial = Builder.CreateFSub(Term00, Term01);
+  Partial = Builder.CreateFAdd(Partial, Term02);
+  return Builder.CreateFSub(Partial, Term03);
+}
 
-  DXASSERT(matToVecMap.count(castInst), "must have vec version");
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
-  // Create before vecUseInst to prevent instructions being inserted after uses.
-  IRBuilder<> Builder(vecUseInst);
+// Lowers a determinant intrinsic on a square matrix of dimension 1 to 4.
+// Every element of the lowered vector is extracted, then the elements are handed
+// to the determinant helper matching the dimension (a 1x1 matrix yields its only
+// element). Any other dimension is treated as unreachable.
+Value *HLMatrixLowerPass::lowerHLDeterminantIntrinsic(Value* MatVal, IRBuilder<> &Builder) {
+  HLMatrixType SquareMatTy = HLMatrixType::cast(MatVal->getType());
+  DXASSERT_NOMSG(SquareMatTy.getNumColumns() == SquareMatTy.getNumRows());
 
-  if (bRowToCol)
-    std::swap(row, col);
-  Instruction *vecCast = CreateTransposeShuffle(Builder, vecVal, row, col);
+  Value *LoweredMat = getLoweredByValOperand(MatVal, Builder);
 
-  // Replace vec cast function call with vecCast.
-  vecUseInst->replaceAllUsesWith(vecCast);
-  AddToDeadInsts(vecUseInst);
-  matToVecMap[castInst] = vecCast;
-}
+  // Pull every element out of the lowered vector, in vector order
+  SmallVector<Value*, 16> M;
+  const unsigned NumElems = SquareMatTy.getNumElements();
+  for (unsigned Idx = 0; Idx != NumElems; ++Idx) {
+    Value *Elem = Builder.CreateExtractElement(LoweredMat, static_cast<uint64_t>(Idx));
+    M.emplace_back(Elem);
+  }
 
-void HLMatrixLowerPass::TranslateMatMatCast(Value *matVal,
-                                            Value *vecVal,
-                                            CallInst *castInst) {
-  unsigned toCol, toRow;
-  Type *ToEltTy = GetMatrixInfo(castInst->getType(), toCol, toRow);
-  unsigned fromCol, fromRow;
-  Type *FromEltTy = GetMatrixInfo(matVal->getType(), fromCol, fromRow);
-  unsigned fromSize = fromCol * fromRow;
-  unsigned toSize = toCol * toRow;
-  DXASSERT(fromSize >= toSize, "cannot extend matrix");
-  DXASSERT(matToVecMap.count(castInst), "must have vec version");
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
-
-  IRBuilder<> Builder(vecUseInst);
-  Instruction *vecCast = nullptr;
-
-  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
-
-  if (fromSize == toSize) {
-    vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), vecVal,
-                             Builder);
-  } else {
-    // shuf first
-    std::vector<int> castMask(toCol * toRow);
-    unsigned idx = 0;
-    for (unsigned r = 0; r < toRow; r++)
-      for (unsigned c = 0; c < toCol; c++) {
-        unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, fromCol);
-        castMask[idx++] = matIdx;
-      }
+  // Dispatch on the (square) dimension
+  const unsigned Dim = SquareMatTy.getNumColumns();
+  switch (Dim) {
+  case 1:
+    return M[0];
 
-    Instruction *shuf = cast<Instruction>(
-        Builder.CreateShuffleVector(vecVal, vecVal, castMask));
+  case 2:
+    return determinant2x2(
+      M[0], M[1],
+      M[2], M[3],
+      Builder);
 
-    if (ToEltTy != FromEltTy)
-      vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), shuf,
-                               Builder);
-    else
-      vecCast = shuf;
-  }
-  // Replace vec cast function call with vecCast.
-  vecUseInst->replaceAllUsesWith(vecCast);
-  AddToDeadInsts(vecUseInst);
-  matToVecMap[castInst] = vecCast;
-}
+  case 3:
+    return determinant3x3(
+      M[0], M[1], M[2],
+      M[3], M[4], M[5],
+      M[6], M[7], M[8],
+      Builder);
 
-void HLMatrixLowerPass::TranslateMatToOtherCast(Value *matVal,
-                                                Value *vecVal,
-                                                CallInst *castInst) {
-  unsigned col, row;
-  Type *EltTy = GetMatrixInfo(matVal->getType(), col, row);
-  unsigned fromSize = col * row;
-
-  IRBuilder<> Builder(castInst);
-  Value *sizeCast = nullptr;
-
-  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
-
-  Type *ToTy = castInst->getType();
-  if (ToTy->isVectorTy()) {
-    unsigned toSize = ToTy->getVectorNumElements();
-    if (fromSize != toSize) {
-      std::vector<int> castMask(fromSize);
-      for (unsigned c = 0; c < toSize; c++)
-        castMask[c] = c;
-
-      sizeCast = Builder.CreateShuffleVector(vecVal, vecVal, castMask);
-    } else
-      sizeCast = vecVal;
-  } else {
-    DXASSERT(ToTy->isSingleValueType(), "must scalar here");
-    sizeCast = Builder.CreateExtractElement(vecVal, (uint64_t)0);
-  }
+  case 4:
+    return determinant4x4(
+      M[0], M[1], M[2], M[3],
+      M[4], M[5], M[6], M[7],
+      M[8], M[9], M[10], M[11],
+      M[12], M[13], M[14], M[15],
+      Builder);
 
-  Value *typeCast = sizeCast;
-  if (EltTy != ToTy->getScalarType()) {
-    typeCast = CreateTypeCast(opcode, ToTy, typeCast, Builder);
+  default:
+    llvm_unreachable("Unexpected matrix dimensions.");
   }
-  // Replace cast function call with typeCast.
-  castInst->replaceAllUsesWith(typeCast);
-  AddToDeadInsts(castInst);
 }
 
-void HLMatrixLowerPass::TranslateMatCast(Value *matVal,
-                                         Value *vecVal,
-                                         CallInst *castInst) {
-  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
-  if (opcode == HLCastOpcode::ColMatrixToRowMatrix ||
-      opcode == HLCastOpcode::RowMatrixToColMatrix) {
-    TranslateMatMajorCast(matVal, vecVal, castInst,
-                          opcode == HLCastOpcode::RowMatrixToColMatrix,
-                          /*bTranspose*/false);
-  } else {
-    bool ToMat = dxilutil::IsHLSLMatrixType(castInst->getType());
-    bool FromMat = dxilutil::IsHLSLMatrixType(matVal->getType());
-    if (ToMat && FromMat) {
-      TranslateMatMatCast(matVal, vecVal, castInst);
-    } else if (FromMat)
-      TranslateMatToOtherCast(matVal, vecVal, castInst);
+// Lowers a unary operator applied to a matrix into the matching operation on the
+// lowered vector, choosing the floating-point or integer instruction from the
+// vector's element type. Unhandled opcodes are treated as unreachable.
+Value *HLMatrixLowerPass::lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode, IRBuilder<> &Builder) {
+  Value *Vec = getLoweredByValOperand(MatVal, Builder);
+  VectorType *LoweredTy = cast<VectorType>(Vec->getType());
+  Type *ElemTy = LoweredTy->getElementType();
+  const bool IsFloat = ElemTy->isFloatingPointTy();
+
+  switch (Opcode) {
+  case HLUnaryOpcode::Plus:
+    // Unary plus leaves the value untouched
+    return Vec;
+
+  case HLUnaryOpcode::Minus: {
+    // Negation is emitted as (0 - x).
+    // NOTE(review): for floats, 0.0 - x gives +0.0 (not -0.0) when x is +0.0; confirm this is intended.
+    Constant *Zero = Constant::getNullValue(LoweredTy);
+    if (IsFloat)
+      return Builder.CreateFSub(Zero, Vec);
+    return Builder.CreateSub(Zero, Vec);
+  }
+
+  case HLUnaryOpcode::LNot: {
+    // Logical not is a per-element comparison against zero.
+    // NOTE(review): FCMP_UEQ is "unordered or equal", so NaN elements also compare true; confirm.
+    Constant *Zero = Constant::getNullValue(LoweredTy);
+    if (IsFloat)
+      return Builder.CreateFCmp(CmpInst::FCMP_UEQ, Vec, Zero);
+    return Builder.CreateICmp(CmpInst::ICMP_EQ, Vec, Zero);
+  }
+
+  case HLUnaryOpcode::Not:
+    // Bitwise not: xor with all ones
+    return Builder.CreateXor(Vec, Constant::getAllOnesValue(LoweredTy));
+
+  case HLUnaryOpcode::PostInc:
+  case HLUnaryOpcode::PreInc:
+  case HLUnaryOpcode::PostDec:
+  case HLUnaryOpcode::PreDec: {
+    // Splat of the constant one in the element type
+    Constant *One = nullptr;
+    if (IsFloat)
+      One = ConstantFP::get(ElemTy, 1);
+    else
+      One = ConstantInt::get(ElemTy, 1);
+    Constant *OneSplat = ConstantVector::getSplat(LoweredTy->getNumElements(), One);
+    const bool IsIncrement = Opcode == HLUnaryOpcode::PostInc || Opcode == HLUnaryOpcode::PreInc;
+    // BUGBUG: This implementation has incorrect semantics (GitHub #1780)
+    if (IsIncrement) {
+      if (IsFloat)
+        return Builder.CreateFAdd(Vec, OneSplat);
+      return Builder.CreateAdd(Vec, OneSplat);
+    }
     else {
-      DXASSERT(0, "Not translate as user of matInst");
+      if (IsFloat)
+        return Builder.CreateFSub(Vec, OneSplat);
+      return Builder.CreateSub(Vec, OneSplat);
     }
   }
-}
-
-void HLMatrixLowerPass::MatIntrinsicReplace(Value *matVal,
-                                            Value *vecVal,
-                                            CallInst *matUseInst) {
-  IRBuilder<> Builder(matUseInst);
-  IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(matUseInst));
-  switch (opcode) {
-  case IntrinsicOp::IOP_umul:
-    TranslateMul(matVal, vecVal, matUseInst, /*isSigned*/false);
-    break;
-  case IntrinsicOp::IOP_mul:
-    TranslateMul(matVal, vecVal, matUseInst, /*isSigned*/true);
-    break;
-  case IntrinsicOp::IOP_transpose:
-    TranslateMatTranspose(matVal, vecVal, matUseInst);
-    break;
-  case IntrinsicOp::IOP_determinant:
-    TranslateMatDeterminant(matVal, vecVal, matUseInst);
-    break;
   default:
-    CallInst *useInst = matUseInst;
-    if (matToVecMap.count(matUseInst))
-      useInst = cast<CallInst>(matToVecMap[matUseInst]);
-    for (unsigned i = 0; i < useInst->getNumArgOperands(); i++) {
-      if (matUseInst->getArgOperand(i) == matVal)
-        useInst->setArgOperand(i, vecVal);
-    }
-    break;
+    llvm_unreachable("Unsupported unary matrix operator");
   }
 }
 
-void HLMatrixLowerPass::TranslateMatSubscript(Value *matVal, Value *vecVal,
-                                              CallInst *matSubInst) {
-  unsigned opcode = GetHLOpcode(matSubInst);
-  HLSubscriptOpcode matOpcode = static_cast<HLSubscriptOpcode>(opcode);
-  assert(matOpcode != HLSubscriptOpcode::DefaultSubscript &&
-         "matrix don't use default subscript");
-
-  Type *matType = matVal->getType()->getPointerElementType();
-  unsigned col, row;
-  Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
-
-  bool isElement = (matOpcode == HLSubscriptOpcode::ColMatElement) |
-                   (matOpcode == HLSubscriptOpcode::RowMatElement);
-  Value *mask =
-      matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
-
-  if (isElement) {
-    Type *resultType = matSubInst->getType()->getPointerElementType();
-    unsigned resultSize = 1;
-    if (resultType->isVectorTy())
-      resultSize = resultType->getVectorNumElements();
-
-    std::vector<int> shufMask(resultSize);
-    Constant *EltIdxs = cast<Constant>(mask);
-    for (unsigned i = 0; i < resultSize; i++) {
-      shufMask[i] =
-          cast<ConstantInt>(EltIdxs->getAggregateElement(i))->getLimitedValue();
-    }
+Value *HLMatrixLowerPass::lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode, IRBuilder<> &Builder) {
+  Value *LoweredLhs = getLoweredByValOperand(Lhs, Builder);
+  Value *LoweredRhs = getLoweredByValOperand(Rhs, Builder);
 
-    for (Value::use_iterator CallUI = matSubInst->use_begin(),
-                             CallE = matSubInst->use_end();
-         CallUI != CallE;) {
-      Use &CallUse = *CallUI++;
-      Instruction *CallUser = cast<Instruction>(CallUse.getUser());
-      IRBuilder<> Builder(CallUser);
-      Value *vecLd = Builder.CreateLoad(vecVal);
-      if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
-        if (resultSize > 1) {
-          Value *shuf = Builder.CreateShuffleVector(vecLd, vecLd, shufMask);
-          ld->replaceAllUsesWith(shuf);
-        } else {
-          Value *elt = Builder.CreateExtractElement(vecLd, shufMask[0]);
-          ld->replaceAllUsesWith(elt);
-        }
-      } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
-        Value *val = st->getValueOperand();
-        if (resultSize > 1) {
-          for (unsigned i = 0; i < shufMask.size(); i++) {
-            unsigned idx = shufMask[i];
-            Value *valElt = Builder.CreateExtractElement(val, i);
-            vecLd = Builder.CreateInsertElement(vecLd, valElt, idx);
-          }
-          Builder.CreateStore(vecLd, vecVal);
-        } else {
-          vecLd = Builder.CreateInsertElement(vecLd, val, shufMask[0]);
-          Builder.CreateStore(vecLd, vecVal);
-        }
-      } else {
-        DXASSERT(0, "matrix element should only used by load/store.");
-      }
-      AddToDeadInsts(CallUser);
-    }
-  } else {
-    // Subscript.
-    // Return a row.
-    // Use insertElement and extractElement.
-    ArrayType *AT = ArrayType::get(EltTy, col*row);
-
-    IRBuilder<> AllocaBuilder(
-        matSubInst->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
-    Value *tempArray = AllocaBuilder.CreateAlloca(AT);
-    Value *zero = AllocaBuilder.getInt32(0);
-    bool isDynamicIndexing = !isa<ConstantInt>(mask);
-    SmallVector<Value *, 4> idxList;
-    for (unsigned i = 0; i < col; i++) {
-      idxList.emplace_back(
-          matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + i));
-    }
+  DXASSERT(LoweredLhs->getType()->isVectorTy() && LoweredRhs->getType()->isVectorTy(),
+    "Expected lowered binary operation operands to be vectors");
+  DXASSERT(LoweredLhs->getType() == LoweredRhs->getType(),
+    "Expected lowered binary operation operands to have matching types.");
 
-    for (Value::use_iterator CallUI = matSubInst->use_begin(),
-                             CallE = matSubInst->use_end();
-         CallUI != CallE;) {
-      Use &CallUse = *CallUI++;
-      Instruction *CallUser = cast<Instruction>(CallUse.getUser());
-      IRBuilder<> Builder(CallUser);
-      Value *vecLd = Builder.CreateLoad(vecVal);
-      if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
-        Value *sub = UndefValue::get(ld->getType());
-        if (!isDynamicIndexing) {
-          for (unsigned i = 0; i < col; i++) {
-            Value *matIdx = idxList[i];
-            Value *valElt = Builder.CreateExtractElement(vecLd, matIdx);
-            sub = Builder.CreateInsertElement(sub, valElt, i);
-          }
-        } else {
-          // Copy vec to array.
-          for (unsigned int i = 0; i < row*col; i++) {
-            Value *Elt =
-                Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
-            Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
-                                                   {zero, Builder.getInt32(i)});
-            Builder.CreateStore(Elt, Ptr);
-          }
-          for (unsigned i = 0; i < col; i++) {
-            Value *matIdx = idxList[i];
-            Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
-            Value *valElt = Builder.CreateLoad(Ptr);
-            sub = Builder.CreateInsertElement(sub, valElt, i);
-          }
-        }
-        ld->replaceAllUsesWith(sub);
-      } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
-        Value *val = st->getValueOperand();
-        if (!isDynamicIndexing) {
-          for (unsigned i = 0; i < col; i++) {
-            Value *matIdx = idxList[i];
-            Value *valElt = Builder.CreateExtractElement(val, i);
-            vecLd = Builder.CreateInsertElement(vecLd, valElt, matIdx);
-          }
-        } else {
-          // Copy vec to array.
-          for (unsigned int i = 0; i < row * col; i++) {
-            Value *Elt =
-                Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
-            Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
-                                                   {zero, Builder.getInt32(i)});
-            Builder.CreateStore(Elt, Ptr);
-          }
-          // Update array.
-          for (unsigned i = 0; i < col; i++) {
-            Value *matIdx = idxList[i];
-            Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
-            Value *valElt = Builder.CreateExtractElement(val, i);
-            Builder.CreateStore(valElt, Ptr);
-          }
-          // Copy array to vec.
-          for (unsigned int i = 0; i < row * col; i++) {
-            Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
-                                                   {zero, Builder.getInt32(i)});
-            Value *Elt = Builder.CreateLoad(Ptr);
-            vecLd = Builder.CreateInsertElement(vecLd, Elt, i);
-          }
-        }
-        Builder.CreateStore(vecLd, vecVal);
-      } else if (GetElementPtrInst *GEP =
-                     dyn_cast<GetElementPtrInst>(CallUser)) {
-        Value *GEPOffset = HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
-        Value *NewGEP = Builder.CreateGEP(vecVal, {zero, GEPOffset});
-        GEP->replaceAllUsesWith(NewGEP);
-      } else {
-        DXASSERT(0, "matrix subscript should only used by load/store.");
-      }
-      AddToDeadInsts(CallUser);
-    }
-  }
-  // Check vec version.
-  DXASSERT(matToVecMap.count(matSubInst) == 0, "should not have vec version");
-  // All the user should have been removed.
-  matSubInst->replaceAllUsesWith(UndefValue::get(matSubInst->getType()));
-  AddToDeadInsts(matSubInst);
-}
+  bool IsFloat = LoweredLhs->getType()->getVectorElementType()->isFloatingPointTy();
 
-void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
-    Value *matGlobal, ArrayRef<Value *> vecGlobals,
-    CallInst *matLdStInst) {
-  // No dynamic indexing on matrix, flatten matrix to scalars.
-  // vecGlobals already in correct major.
-  Type *matType = matGlobal->getType()->getPointerElementType();
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(matType, col, row);
-  Type *vecType = HLMatrixLower::LowerMatrixType(matType);
+  switch (Opcode) {
+  case HLBinaryOpcode::Add:
+    return IsFloat
+      ? Builder.CreateFAdd(LoweredLhs, LoweredRhs)
+      : Builder.CreateAdd(LoweredLhs, LoweredRhs);
 
-  IRBuilder<> Builder(matLdStInst);
+  case HLBinaryOpcode::Sub:
+    return IsFloat
+      ? Builder.CreateFSub(LoweredLhs, LoweredRhs)
+      : Builder.CreateSub(LoweredLhs, LoweredRhs);
 
-  HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
-  switch (opcode) {
-  case HLMatLoadStoreOpcode::ColMatLoad:
-  case HLMatLoadStoreOpcode::RowMatLoad: {
-    Value *Result = UndefValue::get(vecType);
-    for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
-      Value *Elt = Builder.CreateLoad(vecGlobals[matIdx]);
-      Result = Builder.CreateInsertElement(Result, Elt, matIdx);
-    }
-    matLdStInst->replaceAllUsesWith(Result);
-  } break;
-  case HLMatLoadStoreOpcode::ColMatStore:
-  case HLMatLoadStoreOpcode::RowMatStore: {
-    Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-    for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
-      Value *Elt = Builder.CreateExtractElement(Val, matIdx);
-      Builder.CreateStore(Elt, vecGlobals[matIdx]);
-    }
-  } break;
-  }
-}
+  case HLBinaryOpcode::Mul:
+    return IsFloat
+      ? Builder.CreateFMul(LoweredLhs, LoweredRhs)
+      : Builder.CreateMul(LoweredLhs, LoweredRhs);
 
-void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
-                                                      GlobalVariable *scalarArrayGlobal,
-                                                      CallInst *matLdStInst) {
-  // vecGlobals already in correct major.
-  HLMatLoadStoreOpcode opcode =
-      static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
-  switch (opcode) {
-  case HLMatLoadStoreOpcode::ColMatLoad:
-  case HLMatLoadStoreOpcode::RowMatLoad: {
-    IRBuilder<> Builder(matLdStInst);
-    Type *matTy = matGlobal->getType()->getPointerElementType();
-    unsigned col, row;
-    Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
-    Value *zeroIdx = Builder.getInt32(0);
-
-    std::vector<Value *> matElts(col * row);
-
-    for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
-      Value *GEP = Builder.CreateInBoundsGEP(
-          scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
-      matElts[matIdx] = Builder.CreateLoad(GEP);
-    }
+  case HLBinaryOpcode::Div:
+    return IsFloat
+      ? Builder.CreateFDiv(LoweredLhs, LoweredRhs)
+      : Builder.CreateSDiv(LoweredLhs, LoweredRhs);
 
-    Value *newVec =
-        HLMatrixLower::BuildVector(EltTy, col * row, matElts, Builder);
-    matLdStInst->replaceAllUsesWith(newVec);
-    matLdStInst->eraseFromParent();
-  } break;
-  case HLMatLoadStoreOpcode::ColMatStore:
-  case HLMatLoadStoreOpcode::RowMatStore: {
-    Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-
-    IRBuilder<> Builder(matLdStInst);
-    Type *matTy = matGlobal->getType()->getPointerElementType();
-    unsigned col, row;
-    HLMatrixLower::GetMatrixInfo(matTy, col, row);
-    Value *zeroIdx = Builder.getInt32(0);
-
-    std::vector<Value *> matElts(col * row);
-
-    for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
-      Value *GEP = Builder.CreateInBoundsGEP(
-          scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
-      Value *Elt = Builder.CreateExtractElement(Val, matIdx);
-      Builder.CreateStore(Elt, GEP);
-    }
+  case HLBinaryOpcode::Rem:
+    return IsFloat
+      ? Builder.CreateFRem(LoweredLhs, LoweredRhs)
+      : Builder.CreateSRem(LoweredLhs, LoweredRhs);
 
-    matLdStInst->eraseFromParent();
-  } break;
-  }
-}
-void HLMatrixLowerPass::TranslateMatSubscriptOnGlobalPtr(
-    CallInst *matSubInst, Value *vecPtr) {
-  Value *basePtr =
-      matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
-  Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
-  IRBuilder<> subBuilder(matSubInst);
-  Value *zeroIdx = subBuilder.getInt32(0);
-
-  HLSubscriptOpcode opcode =
-      static_cast<HLSubscriptOpcode>(GetHLOpcode(matSubInst));
-
-  Type *matTy = basePtr->getType()->getPointerElementType();
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(matTy, col, row);
-
-  std::vector<Value *> idxList;
-  switch (opcode) {
-  case HLSubscriptOpcode::ColMatSubscript:
-  case HLSubscriptOpcode::RowMatSubscript: {
-    // Just use index created in EmitHLSLMatrixSubscript.
-    for (unsigned c = 0; c < col; c++) {
-      Value *matIdx =
-          matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + c);
-      idxList.emplace_back(matIdx);
-    }
-  } break;
-  case HLSubscriptOpcode::RowMatElement:
-  case HLSubscriptOpcode::ColMatElement: {
-    Type *resultType = matSubInst->getType()->getPointerElementType();
-    unsigned resultSize = 1;
-    if (resultType->isVectorTy())
-      resultSize = resultType->getVectorNumElements();
-    // Just use index created in EmitHLSLMatrixElement.
-    Constant *EltIdxs = cast<Constant>(idx);
-    for (unsigned i = 0; i < resultSize; i++) {
-      Value *matIdx = EltIdxs->getAggregateElement(i);
-      idxList.emplace_back(matIdx);
-    }
-  } break;
-  default:
-    DXASSERT(0, "invalid operation");
-    break;
-  }
+  case HLBinaryOpcode::And:
+    return Builder.CreateAnd(LoweredLhs, LoweredRhs);
 
-  // Cannot generate vector pointer
-  // Replace all uses with scalar pointers.
-  if (!matSubInst->getType()->getPointerElementType()->isVectorTy()) {
-    DXASSERT(idxList.size() == 1, "Expected a single matrix element index if the result is not a vector");
-    Value *Ptr =
-      subBuilder.CreateInBoundsGEP(vecPtr, { zeroIdx, idxList[0] });
-    matSubInst->replaceAllUsesWith(Ptr);
-  } else {
-    // Split the use of CI with Ptrs.
-    for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
-      Instruction *subsUser = cast<Instruction>(*(U++));
-      IRBuilder<> userBuilder(subsUser);
-      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
-        Value *IndexPtr =
-            HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
-        Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
-                                                   {zeroIdx, IndexPtr});
-        for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
-          Instruction *gepUser = cast<Instruction>(*(gepU++));
-          IRBuilder<> gepUserBuilder(gepUser);
-          if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
-            Value *subData = stUser->getValueOperand();
-            gepUserBuilder.CreateStore(subData, Ptr);
-            stUser->eraseFromParent();
-          } else if (LoadInst *ldUser = dyn_cast<LoadInst>(gepUser)) {
-            Value *subData = gepUserBuilder.CreateLoad(Ptr);
-            ldUser->replaceAllUsesWith(subData);
-            ldUser->eraseFromParent();
-          } else {
-            AddrSpaceCastInst *Cast = cast<AddrSpaceCastInst>(gepUser);
-            Cast->setOperand(0, Ptr);
-          }
-        }
-        GEP->eraseFromParent();
-      } else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
-        Value *val = stUser->getValueOperand();
-        for (unsigned i = 0; i < idxList.size(); i++) {
-          Value *Elt = userBuilder.CreateExtractElement(val, i);
-          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
-                                                     {zeroIdx, idxList[i]});
-          userBuilder.CreateStore(Elt, Ptr);
-        }
-        stUser->eraseFromParent();
-      } else {
-
-        Value *ldVal =
-            UndefValue::get(matSubInst->getType()->getPointerElementType());
-        for (unsigned i = 0; i < idxList.size(); i++) {
-          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
-                                                     {zeroIdx, idxList[i]});
-          Value *Elt = userBuilder.CreateLoad(Ptr);
-          ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
-        }
-        // Must be load here.
-        LoadInst *ldUser = cast<LoadInst>(subsUser);
-        ldUser->replaceAllUsesWith(ldVal);
-        ldUser->eraseFromParent();
-      }
-    }
-  }
-  matSubInst->eraseFromParent();
-}
+  case HLBinaryOpcode::Or:
+    return Builder.CreateOr(LoweredLhs, LoweredRhs);
 
-void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobalPtr(
-    CallInst *matLdStInst, Value *vecPtr) {
-  // Just translate into vector here.
-  // DynamicIndexingVectorToArray will change it to scalar array.
-  IRBuilder<> Builder(matLdStInst);
-  unsigned opcode = hlsl::GetHLOpcode(matLdStInst);
-  HLMatLoadStoreOpcode matLdStOp = static_cast<HLMatLoadStoreOpcode>(opcode);
-  switch (matLdStOp) {
-  case HLMatLoadStoreOpcode::ColMatLoad:
-  case HLMatLoadStoreOpcode::RowMatLoad: {
-    // Load as vector.
-    Value *newLoad = Builder.CreateLoad(vecPtr);
+  case HLBinaryOpcode::Xor:
+    return Builder.CreateXor(LoweredLhs, LoweredRhs);
 
-    matLdStInst->replaceAllUsesWith(newLoad);
-    matLdStInst->eraseFromParent();
-  } break;
-  case HLMatLoadStoreOpcode::ColMatStore:
-  case HLMatLoadStoreOpcode::RowMatStore: {
-    // Change value to vector array, then store.
-    Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-
-    Value *vecArrayGep = vecPtr;
-    Builder.CreateStore(Val, vecArrayGep);
-    matLdStInst->eraseFromParent();
-  } break;
-  default:
-    DXASSERT(0, "invalid operation");
-    break;
-  }
-}
+  case HLBinaryOpcode::Shl:
+    return Builder.CreateShl(LoweredLhs, LoweredRhs);
 
-// Flatten values inside init list to scalar elements.
-static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
-                            Value *val,
-                            DenseMap<Value *, Value *> &matToVecMap,
-                            IRBuilder<> &Builder) {
-  Type *valTy = val->getType();
-
-  if (valTy->isPointerTy()) {
-    if (HLMatrixLower::IsMatrixArrayPointer(valTy)) {
-      if (matToVecMap.count(cast<Instruction>(val))) {
-        val = matToVecMap[cast<Instruction>(val)];
-      } else {
-        // Convert to vec array with bitcast.
-        Type *vecArrayPtrTy = HLMatrixLower::LowerMatrixArrayPointer(valTy);
-        val = Builder.CreateBitCast(val, vecArrayPtrTy);
-      }
-    }
-    Type *valEltTy = val->getType()->getPointerElementType();
-    if (valEltTy->isVectorTy() || dxilutil::IsHLSLMatrixType(valEltTy) ||
-        valEltTy->isSingleValueType()) {
-      Value *ldVal = Builder.CreateLoad(val);
-      IterateInitList(elts, idx, ldVal, matToVecMap, Builder);
-    } else {
-      Type *i32Ty = Type::getInt32Ty(valTy->getContext());
-      Value *zero = ConstantInt::get(i32Ty, 0);
-      if (ArrayType *AT = dyn_cast<ArrayType>(valEltTy)) {
-        for (unsigned i = 0; i < AT->getArrayNumElements(); i++) {
-          Value *gepIdx = ConstantInt::get(i32Ty, i);
-          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
-          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
-        }
-      } else {
-        // Struct.
-        StructType *ST = cast<StructType>(valEltTy);
-        for (unsigned i = 0; i < ST->getNumElements(); i++) {
-          Value *gepIdx = ConstantInt::get(i32Ty, i);
-          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
-          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
-        }
-      }
-    }
-  } else if (dxilutil::IsHLSLMatrixType(valTy)) {
-    unsigned col, row;
-    HLMatrixLower::GetMatrixInfo(valTy, col, row);
-    unsigned matSize = col * row;
-    val = matToVecMap[cast<Instruction>(val)];
-    // temp matrix all row major
-    for (unsigned i = 0; i < matSize; i++) {
-      Value *Elt = Builder.CreateExtractElement(val, i);
-      elts[idx + i] = Elt;
-    }
-    idx += matSize;
-  } else {
-    if (valTy->isVectorTy()) {
-      unsigned vecSize = valTy->getVectorNumElements();
-      for (unsigned i = 0; i < vecSize; i++) {
-        Value *Elt = Builder.CreateExtractElement(val, i);
-        elts[idx + i] = Elt;
-      }
-      idx += vecSize;
-    } else {
-      DXASSERT(valTy->isSingleValueType(), "must be single value type here");
-      elts[idx++] = val;
-    }
-  }
-}
+  case HLBinaryOpcode::Shr:
+    return Builder.CreateAShr(LoweredLhs, LoweredRhs);
 
-void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
-  // Array matrix init will be translated in TranslateMatArrayInitReplace.
-  if (matInitInst->getType()->isVoidTy())
-    return;
+  case HLBinaryOpcode::LT:
+    return IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_OLT, LoweredLhs, LoweredRhs)
+      : Builder.CreateICmp(CmpInst::ICMP_SLT, LoweredLhs, LoweredRhs);
 
-  DXASSERT(matToVecMap.count(matInitInst), "must have vec version");
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
-  IRBuilder<> Builder(vecUseInst);
-  unsigned col, row;
-  Type *EltTy = GetMatrixInfo(matInitInst->getType(), col, row);
-
-  Type *vecTy = VectorType::get(EltTy, col * row);
-  unsigned vecSize = vecTy->getVectorNumElements();
-  unsigned idx = 0;
-  std::vector<Value *> elts(vecSize);
-  // Skip opcode arg.
-  for (unsigned i = 1; i < matInitInst->getNumArgOperands(); i++) {
-    Value *val = matInitInst->getArgOperand(i);
-
-    IterateInitList(elts, idx, val, matToVecMap, Builder);
-  }
+  case HLBinaryOpcode::GT:
+    return IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_OGT, LoweredLhs, LoweredRhs)
+      : Builder.CreateICmp(CmpInst::ICMP_SGT, LoweredLhs, LoweredRhs);
 
-  Value *newInit = UndefValue::get(vecTy);
-  // InitList is row major, the result is row major too.
-  for (unsigned i=0;i< col * row;i++) {
-      Constant *vecIdx = Builder.getInt32(i);
-      newInit = InsertElementInst::Create(newInit, elts[i], vecIdx);
-      Builder.Insert(cast<Instruction>(newInit));
-  }
+  case HLBinaryOpcode::LE:
+    return IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_OLE, LoweredLhs, LoweredRhs)
+      : Builder.CreateICmp(CmpInst::ICMP_SLE, LoweredLhs, LoweredRhs);
 
-  // Replace matInit function call with matInitInst.
-  vecUseInst->replaceAllUsesWith(newInit);
-  AddToDeadInsts(vecUseInst);
-  matToVecMap[matInitInst] = newInit;
-}
+  case HLBinaryOpcode::GE:
+    return IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_OGE, LoweredLhs, LoweredRhs)
+      : Builder.CreateICmp(CmpInst::ICMP_SGE, LoweredLhs, LoweredRhs);
+
+  case HLBinaryOpcode::EQ:
+    return IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_OEQ, LoweredLhs, LoweredRhs)
+      : Builder.CreateICmp(CmpInst::ICMP_EQ, LoweredLhs, LoweredRhs);
 
-void HLMatrixLowerPass::TranslateMatSelect(CallInst *matSelectInst) {
-  unsigned col, row;
-  Type *EltTy = GetMatrixInfo(matSelectInst->getType(), col, row);
+  case HLBinaryOpcode::NE:
+    return IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredLhs, LoweredRhs)
+      : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredLhs, LoweredRhs);
+
+  case HLBinaryOpcode::UDiv:
+    return Builder.CreateUDiv(LoweredLhs, LoweredRhs);
+
+  case HLBinaryOpcode::URem:
+    return Builder.CreateURem(LoweredLhs, LoweredRhs);
 
-  Type *vecTy = VectorType::get(EltTy, col * row);
-  unsigned vecSize = vecTy->getVectorNumElements();
+  case HLBinaryOpcode::UShr:
+    return Builder.CreateLShr(LoweredLhs, LoweredRhs);
 
-  CallInst *vecUseInst = cast<CallInst>(matToVecMap[matSelectInst]);
-  Instruction *LHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx));
-  Instruction *RHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx));
+  case HLBinaryOpcode::ULT:
+    return Builder.CreateICmp(CmpInst::ICMP_ULT, LoweredLhs, LoweredRhs);
 
-  IRBuilder<> Builder(vecUseInst);
+  case HLBinaryOpcode::UGT:
+    return Builder.CreateICmp(CmpInst::ICMP_UGT, LoweredLhs, LoweredRhs);
 
-  Value *Cond = vecUseInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
-  bool isVecCond = Cond->getType()->isVectorTy();
-  if (isVecCond) {
-    Instruction *MatCond = cast<Instruction>(
-        matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx));
-    DXASSERT_NOMSG(matToVecMap.count(MatCond));
-    Cond = matToVecMap[MatCond];
+  case HLBinaryOpcode::ULE:
+    return Builder.CreateICmp(CmpInst::ICMP_ULE, LoweredLhs, LoweredRhs);
+
+  case HLBinaryOpcode::UGE:
+    return Builder.CreateICmp(CmpInst::ICMP_UGE, LoweredLhs, LoweredRhs);
+
+  case HLBinaryOpcode::LAnd:
+  case HLBinaryOpcode::LOr: {
+    Value *Zero = Constant::getNullValue(LoweredLhs->getType());
+    Value *LhsCmp = IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredLhs, Zero)
+      : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredLhs, Zero);
+    Value *RhsCmp = IsFloat
+      ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredRhs, Zero)
+      : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredRhs, Zero);
+    return Opcode == HLBinaryOpcode::LOr
+      ? Builder.CreateOr(LhsCmp, RhsCmp)
+      : Builder.CreateAnd(LhsCmp, RhsCmp);
   }
-  DXASSERT_NOMSG(matToVecMap.count(LHS));
-  Value *VLHS = matToVecMap[LHS];
-  DXASSERT_NOMSG(matToVecMap.count(RHS));
-  Value *VRHS = matToVecMap[RHS];
-
-  Value *VecSelect = UndefValue::get(vecTy);
-  for (unsigned i = 0; i < vecSize; i++) {
-    llvm::Value *EltCond = Cond;
-    if (isVecCond)
-      EltCond = Builder.CreateExtractElement(Cond, i);
-    llvm::Value *EltL = Builder.CreateExtractElement(VLHS, i);
-    llvm::Value *EltR = Builder.CreateExtractElement(VRHS, i);
-    llvm::Value *EltSelect = Builder.CreateSelect(EltCond, EltL, EltR);
-    VecSelect = Builder.CreateInsertElement(VecSelect, EltSelect, i);
+  default:
+    llvm_unreachable("Unsupported binary matrix operator");
   }
-  AddToDeadInsts(vecUseInst);
-  vecUseInst->replaceAllUsesWith(VecSelect);
-  matToVecMap[matSelectInst] = VecSelect;
 }
 
-void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
-                                             Value *vecVal,
-                                             GetElementPtrInst *matGEP) {
-  SmallVector<Value *, 4> idxList(matGEP->idx_begin(), matGEP->idx_end());
-
-  IRBuilder<> GEPBuilder(matGEP);
-  Value *newGEP = GEPBuilder.CreateInBoundsGEP(vecVal, idxList);
-  // Only used by mat subscript and mat ld/st.
-  for (Value::user_iterator user = matGEP->user_begin();
-       user != matGEP->user_end();) {
-    Instruction *useInst = cast<Instruction>(*(user++));
-    IRBuilder<> Builder(useInst);
-    // Skip return here.
-    if (isa<ReturnInst>(useInst))
-      continue;
-    if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
-      // Function call.
-      hlsl::HLOpcodeGroup group =
-          hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
-      switch (group) {
-      case HLOpcodeGroup::HLMatLoadStore: {
-        unsigned opcode = GetHLOpcode(useCall);
-        HLMatLoadStoreOpcode matOpcode =
-            static_cast<HLMatLoadStoreOpcode>(opcode);
-        switch (matOpcode) {
-        case HLMatLoadStoreOpcode::ColMatLoad:
-        case HLMatLoadStoreOpcode::RowMatLoad: {
-          // Skip the vector version.
-          if (useCall->getType()->isVectorTy())
-            continue;
-          Type *matTy = useCall->getType();
-          Value *newLd = CreateVecMatrixLoad(newGEP, matTy, Builder);
-          DXASSERT(matToVecMap.count(useCall), "must have vec version");
-          Value *oldLd = matToVecMap[useCall];
-          // Delete the oldLd.
-          AddToDeadInsts(cast<Instruction>(oldLd));
-          oldLd->replaceAllUsesWith(newLd);
-          matToVecMap[useCall] = newLd;
-        } break;
-        case HLMatLoadStoreOpcode::ColMatStore:
-        case HLMatLoadStoreOpcode::RowMatStore: {
-          Value *vecPtr = newGEP;
-          
-          Value *matVal = useCall->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-          // Skip the vector version.
-          if (matVal->getType()->isVectorTy()) {
-            AddToDeadInsts(useCall);
-            continue;
-          }
+Value *HLMatrixLowerPass::lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode) { // Dispatches a matrix load/store HL call to lowerHLLoad or lowerHLStore according to Opcode and returns whatever that helper returns.
+  IRBuilder<> Builder(Call); // Builder positioned at Call; handed by reference to the helpers below.
+  switch (Opcode) {
+  case HLMatLoadStoreOpcode::RowMatLoad: // Both load opcodes share one path; they differ only in the RowMajor flag.
+  case HLMatLoadStoreOpcode::ColMatLoad:
+    return lowerHLLoad(Call->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx),
+      /* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatLoad, Builder);
 
-          Instruction *matInst = cast<Instruction>(matVal);
+  case HLMatLoadStoreOpcode::RowMatStore: // Both store opcodes share one path; they differ only in the RowMajor flag.
+  case HLMatLoadStoreOpcode::ColMatStore:
+    return lowerHLStore(
+      Call->getArgOperand(HLOperandIndex::kMatStoreValOpIdx),
+      Call->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx),
+      /* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatStore,
+      /* Return */ !Call->getType()->isVoidTy(), Builder); // Return is true only when the original call has a non-void result type.
 
-          DXASSERT(matToVecMap.count(matInst), "must have vec version");
-          Value *vecVal = matToVecMap[matInst];
-          CreateVecMatrixStore(vecVal, vecPtr, matVal->getType(), Builder);
-        } break;
-        }
-      } break;
-      case HLOpcodeGroup::HLSubscript: {
-        TranslateMatSubscript(matGEP, newGEP, useCall);
-      } break;
-      default:
-        DXASSERT(0, "invalid operation");
-        break;
-      }
-    } else if (dyn_cast<BitCastInst>(useInst)) {
-      // Just replace the src with vec version.
-      useInst->setOperand(0, newGEP);
-    } else {
-      // Must be GEP.
-      GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
-      TranslateMatArrayGEP(matGEP, cast<Instruction>(newGEP), GEP);
-    }
+  default: // Any opcode other than the four handled above is treated as a programming error.
+    llvm_unreachable("Unsupported matrix load/store operation");
   }
-  AddToDeadInsts(matGEP);
 }
 
-Value *HLMatrixLowerPass::GetMatrixForVec(Value *vecVal, Type *matTy) {
-  Value *newMatVal = nullptr;
-  if (vecToMatMap.count(vecVal)) {
-    newMatVal = vecToMatMap[vecVal];
-  } else {
-    // create conversion instructions if necessary, caching result for subsequent replacements.
-    // do so right after the vecVal def so it's available to all potential uses.
-    newMatVal = BitCastValueOrPtr(vecVal,
-      cast<Instruction>(vecVal)->getNextNode(), // vecVal must be instruction
-      matTy,
-      /*bOrigAllocaTy*/true,
-      vecVal->getName());
-    vecToMatMap[vecVal] = newMatVal;
-  }
-  return newMatVal;
-}
+Value *HLMatrixLowerPass::lowerHLLoad(Value *MatPtr, bool RowMajor, IRBuilder<> &Builder) { // Lowers a matrix load through MatPtr. Returns the lowered load, or a re-emitted HL load call when the pointer cannot be lowered here. RowMajor is only consulted on that deferred path.
+  HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType()); // Matrix type taken from the pointee; presumably requires MatPtr to point at a matrix - TODO confirm against HLMatrixType::cast.
 
-void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
-                                          Value *vecVal) {
-  for (Value::user_iterator user = matVal->user_begin();
-       user != matVal->user_end();) {
-    Instruction *useInst = cast<Instruction>(*(user++));
-    // User must be function call.
-    if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
-      hlsl::HLOpcodeGroup group =
-          hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
-      switch (group) {
-      case HLOpcodeGroup::HLIntrinsic: {
-        if (CallInst *matCI = dyn_cast<CallInst>(matVal)) {
-          MatIntrinsicReplace(matCI, vecVal, useCall);
-        } else {
-          IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(useCall));
-          if (opcode == IntrinsicOp::MOP_Append) {
-            // Replace matrix with vector representation and update intrinsic signature
-            // We don't care about matrix orientation here, since that will need to be
-            // taken into account anyways when generating the store output calls.
-            SmallVector<Value *, 4> flatArgs;
-            SmallVector<Type *, 4> flatParamTys;
-            for (Value *arg : useCall->arg_operands()) {
-              Value *flagArg = arg == matVal ? vecVal : arg;
-              flatArgs.emplace_back(arg == matVal ? vecVal : arg);
-              flatParamTys.emplace_back(flagArg->getType());
-            }
-
-            // Don't need flat return type for Append.
-            FunctionType *flatFuncTy =
-              FunctionType::get(useInst->getType(), flatParamTys, false);
-            Function *flatF = GetOrCreateHLFunction(*m_pModule, flatFuncTy, group, static_cast<unsigned int>(opcode));
-            
-            // Append returns void, so the old call should have no users
-            DXASSERT(useInst->getType()->isVoidTy(), "Unexpected MOP_Append intrinsic return type");
-            DXASSERT(useInst->use_empty(), "Unexpected users of MOP_Append intrinsic return value");
-            IRBuilder<> Builder(useCall);
-            Builder.CreateCall(flatF, flatArgs);
-            AddToDeadInsts(useCall);
-          }
-          else if (opcode == IntrinsicOp::IOP_frexp) {
-            // NOTE: because out param use copy out semantic, so the operand of
-            // out must be temp alloca.
-            DXASSERT(isa<AllocaInst>(matVal), "else invalid mat ptr for frexp");
-            auto it = matToVecMap.find(useCall);
-            DXASSERT(it != matToVecMap.end(),
-              "else fail to create vec version of useCall");
-            CallInst *vecUseInst = cast<CallInst>(it->second);
-
-            for (unsigned i = 0; i < vecUseInst->getNumArgOperands(); i++) {
-              if (useCall->getArgOperand(i) == matVal) {
-                vecUseInst->setArgOperand(i, vecVal);
-              }
-            }
-          } else {
-            DXASSERT(false, "Unexpected matrix user intrinsic.");
-          }
-        }
-      } break;
-      case HLOpcodeGroup::HLSelect: {
-        MatIntrinsicReplace(matVal, vecVal, useCall);
-      } break;
-      case HLOpcodeGroup::HLBinOp: {
-        TrivialMatBinOpReplace(matVal, vecVal, useCall);
-      } break;
-      case HLOpcodeGroup::HLUnOp: {
-        TrivialMatUnOpReplace(matVal, vecVal, useCall);
-      } break;
-      case HLOpcodeGroup::HLCast: {
-        TranslateMatCast(matVal, vecVal, useCall);
-      } break;
-      case HLOpcodeGroup::HLMatLoadStore: {
-        DXASSERT(matToVecMap.count(useCall), "must have vec version");
-        Value *vecUser = matToVecMap[useCall];
-        if (isa<AllocaInst>(matVal) || GetIfMatrixGEPOfUDTAlloca(matVal) ||
-            GetIfMatrixGEPOfUDTArg(matVal, *m_pHLModule)) {
-          // Load Already translated in lowerToVec.
-          // Store val operand will be set by the val use.
-          // Do nothing here.
-        } else if (StoreInst *stInst = dyn_cast<StoreInst>(vecUser)) {
-          DXASSERT(vecVal->getType() == stInst->getValueOperand()->getType(),
-            "Mismatched vector matrix store value types.");
-          stInst->setOperand(0, vecVal);
-        } else if (ZExtInst *zextInst = dyn_cast<ZExtInst>(vecUser)) {
-          // This happens when storing bool matrices,
-          // which must first undergo conversion from i1's to i32's.
-          DXASSERT(vecVal->getType() == zextInst->getOperand(0)->getType(),
-            "Mismatched vector matrix store value types.");
-          zextInst->setOperand(0, vecVal);
-        } else
-          TrivialMatReplace(matVal, vecVal, useCall);
-
-      } break;
-      case HLOpcodeGroup::HLSubscript: {
-        if (AllocaInst *AI = dyn_cast<AllocaInst>(vecVal))
-          TranslateMatSubscript(matVal, vecVal, useCall);
-        else if (BitCastInst *BCI = dyn_cast<BitCastInst>(vecVal))
-          TranslateMatSubscript(matVal, vecVal, useCall);
-        else
-          TrivialMatReplace(matVal, vecVal, useCall);
-
-      } break;
-      case HLOpcodeGroup::HLInit: {
-        DXASSERT(!isa<AllocaInst>(matVal), "array of matrix init should lowered in StoreInitListToDestPtr at CGHLSLMS.cpp");
-        TranslateMatInit(useCall);
-      } break;
-      case HLOpcodeGroup::NotHL: {
-        castMatrixArgs(matVal, vecVal, useCall);
-      } break;
-      case HLOpcodeGroup::HLExtIntrinsic:
-      case HLOpcodeGroup::HLCreateHandle:
-      case HLOpcodeGroup::NumOfHLOps:
-      // No vector equivalents for these ops.
-        break;
-      }
-    } else if (dyn_cast<BitCastInst>(useInst)) {
-      // Just replace the src with vec version.
-      if (useInst != vecVal)
-        useInst->setOperand(0, vecVal);
-    } else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
-      Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
-      RI->setOperand(0, newMatVal);
-    } else if (isa<StoreInst>(useInst)) {
-      DXASSERT(vecToMatMap.count(vecVal) && vecToMatMap[vecVal] == matVal, "matrix store should only be used with preserved matrix values");
-    } else {
-      // Must be GEP on mat array alloca.
-      GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
-      AllocaInst *AI = cast<AllocaInst>(matVal);
-      TranslateMatArrayGEP(AI, vecVal, GEP);
-    }
+  Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder); // May yield nullptr, which is handled right below.
+  if (LoweredPtr == nullptr) {
+    // Can't lower this here, defer to HL signature lower
+    HLMatLoadStoreOpcode Opcode = RowMajor ? HLMatLoadStoreOpcode::RowMatLoad : HLMatLoadStoreOpcode::ColMatLoad; // Rebuild the load opcode from the RowMajor flag.
+    return callHLFunction(
+      *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
+      MatTy.getLoweredVectorTypeForReg(), { Builder.getInt32((uint32_t)Opcode), MatPtr }, Builder); // The call returns the lowered register vector type but still takes the original matrix pointer.
   }
-}
 
-void HLMatrixLowerPass::castMatrixArgs(Value *matVal, Value *vecVal, CallInst *CI) {
-  // translate user function parameters as necessary
-  Type *Ty = matVal->getType();
-  if (Ty->isPointerTy()) {
-    IRBuilder<> Builder(CI);
-    Value *newMatVal = Builder.CreateBitCast(vecVal, Ty);
-    CI->replaceUsesOfWith(matVal, newMatVal);
-  } else {
-    Value *newMatVal = GetMatrixForVec(vecVal, Ty);
-    CI->replaceUsesOfWith(matVal, newMatVal);
-  }
+  return MatTy.emitLoweredLoad(LoweredPtr, Builder); // Pointer was lowered: emit the load through the matrix type helper.
 }
 
-void HLMatrixLowerPass::finalMatTranslation(Value *matVal) {
-  // Translate matInit.
-  if (CallInst *CI = dyn_cast<CallInst>(matVal)) {
-    hlsl::HLOpcodeGroup group =
-        hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
-    switch (group) {
-    case HLOpcodeGroup::HLInit: {
-      TranslateMatInit(CI);
-    } break;
-    case HLOpcodeGroup::HLSelect: {
-      TranslateMatSelect(CI);
-    } break;
-    default:
-      // Skip group already translated.
-      break;
-    }
-  }
-}
+Value *HLMatrixLowerPass::lowerHLStore(Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder) { // Lowers a store of MatVal through MatPtr. Returns the lowered value when Return is true, otherwise the emitted store; if the pointer cannot be lowered, returns a re-emitted HL store call instead. RowMajor is only consulted on that deferred path.
+  DXASSERT(MatVal->getType() == MatPtr->getType()->getPointerElementType(), // The stored value's type must equal the pointee type of the destination.
+    "Matrix store value/pointer type mismatch.");
 
-void HLMatrixLowerPass::DeleteDeadInsts() {
-  // Delete the matrix version insts.
-  for (Instruction *deadInst : m_deadInsts) {
-    // Replace with undef and remove it.
-    deadInst->replaceAllUsesWith(UndefValue::get(deadInst->getType()));
-    deadInst->eraseFromParent();
+  Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder); // May yield nullptr, which is handled below.
+  Value *LoweredVal = getLoweredByValOperand(MatVal, Builder); // The lowered value is needed on both the deferred and the direct path.
+  if (LoweredPtr == nullptr) {
+    // Can't lower the pointer here, defer to HL signature lower
+    HLMatLoadStoreOpcode Opcode = RowMajor ? HLMatLoadStoreOpcode::RowMatStore : HLMatLoadStoreOpcode::ColMatStore; // Rebuild the store opcode from the RowMajor flag.
+    return callHLFunction(
+      *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
+      Return ? LoweredVal->getType() : Builder.getVoidTy(), // Result type is the lowered value type when a result is wanted, void otherwise.
+      { Builder.getInt32((uint32_t)Opcode), MatPtr, LoweredVal }, Builder); // The pointer is passed in its original matrix form; the value is passed lowered.
   }
-  m_deadInsts.clear();
-  m_inDeadInstsSet.clear();
+
+  HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType()); // Matrix type taken from the destination pointee.
+  StoreInst *LoweredStore = MatTy.emitLoweredStore(LoweredVal, LoweredPtr, Builder); // Pointer was lowered: emit the store through the matrix type helper.
+
+  // If the intrinsic returned a value, return the stored lowered value
+  return Return ? LoweredVal : LoweredStore;
 }
 
-static bool OnlyUsedByMatrixLdSt(Value *V) {
-  bool onlyLdSt = true;
-  for (User *user : V->users()) {
-    if (isa<Constant>(user) && user->use_empty())
-      continue;
+static Value *convertScalarOrVector(Value *SrcVal, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> Builder) { // Numerically converts a scalar or vector value to DstTy.
+  DXASSERT(SrcVal->getType()->isVectorTy() == DstTy->isVectorTy(),
+    "Scalar/vector type mismatch in numerical conversion.");
+  Type *SrcTy = SrcVal->getType();
 
-    CallInst *CI = cast<CallInst>(user);
-    if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
-        HLOpcodeGroup::HLMatLoadStore)
-      continue;
+  // Conversions between equivalent types are no-ops,
+  // even between signed/unsigned variants.
+  if (SrcTy == DstTy) return SrcVal;
 
-    onlyLdSt = false;
-    break;
+  // Conversions to bools are comparisons against zero
+  if (DstTy->getScalarSizeInBits() == 1) {
+    // fcmp une is what regular clang uses in C++ for (bool)f; the builder may fold this to a Constant.
+    return SrcTy->isIntOrIntVectorTy()
+      ? Builder.CreateICmpNE(SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool")
+      : Builder.CreateFCmpUNE(SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool");
   }
-  return onlyLdSt;
-}
 
-static Constant *LowerMatrixArrayConst(Constant *MA, Type *ResultTy) {
-  if (ArrayType *AT = dyn_cast<ArrayType>(ResultTy)) {
-    std::vector<Constant *> Elts;
-    Type *EltResultTy = AT->getElementType();
-    for (unsigned i = 0; i < AT->getNumElements(); i++) {
-      Constant *Elt =
-          LowerMatrixArrayConst(MA->getAggregateElement(i), EltResultTy);
-      Elts.emplace_back(Elt);
-    }
-    return ConstantArray::get(AT, Elts);
-  } else {
-    // Cast float[row][col] -> float< row * col>.
-    // Get float[row][col] from the struct.
-    Constant *rows = MA->getAggregateElement((unsigned)0);
-    ArrayType *RowAT = cast<ArrayType>(rows->getType());
-    std::vector<Constant *> Elts;
-    for (unsigned r=0;r<RowAT->getArrayNumElements();r++) {
-      Constant *row = rows->getAggregateElement(r);
-      VectorType *VT = cast<VectorType>(row->getType());
-      for (unsigned c = 0; c < VT->getVectorNumElements(); c++) {
-        Elts.emplace_back(row->getAggregateElement(c));
-      }
+  // Cast necessary; the signedness of each side is derived from the HL cast opcode
+  bool SrcIsUnsigned = Opcode == HLCastOpcode::FromUnsignedCast ||
+    Opcode == HLCastOpcode::UnsignedUnsignedCast;
+  bool DstIsUnsigned = Opcode == HLCastOpcode::ToUnsignedCast ||
+    Opcode == HLCastOpcode::UnsignedUnsignedCast;
+  auto CastOp = static_cast<Instruction::CastOps>(HLModule::GetNumericCastOp(
+    SrcTy, SrcIsUnsigned, DstTy, DstIsUnsigned));
+  return Builder.CreateCast(CastOp, SrcVal, DstTy); // Returned as a Value: a folded cast is a Constant, not an Instruction
+}
+
+Value *HLMatrixLowerPass::lowerHLCast(Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder) { // Lowers an HL cast with a matrix source and/or destination to its vector form.
+  // The opcode really doesn't mean much here, the types involved are what drive most of the casting.
+  DXASSERT(Opcode != HLCastOpcode::HandleToResCast, "Unexpected matrix cast opcode.");
+
+  if (dxilutil::IsIntegerOrFloatingPointType(Src->getType())) {
+    // Scalar to matrix splat
+    HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
+
+    // Apply element conversion
+    Value *Result = convertScalarOrVector(Src,
+      MatDstTy.getElementTypeForReg(), Opcode, Builder);
+
+    // Splat to a vector: insert into a 1-element vector, then shuffle with an all-zero mask
+    Result = Builder.CreateInsertElement(
+      UndefValue::get(VectorType::get(Result->getType(), 1)),
+      Result, static_cast<uint64_t>(0));
+    return Builder.CreateShuffleVector(Result, Result,
+      ConstantVector::getSplat(MatDstTy.getNumElements(), Builder.getInt32(0)));
+  }
+  else if (VectorType *SrcVecTy = dyn_cast<VectorType>(Src->getType())) {
+    // Vector to matrix
+    HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
+    Value *Result = Src;
+
+    // We might need to truncate: keep only the leading elements the matrix needs
+    if (MatDstTy.getNumElements() < SrcVecTy->getNumElements()) {
+      SmallVector<int, 4> ShuffleIndices;
+      for (unsigned Idx = 0; Idx < MatDstTy.getNumElements(); ++Idx)
+        ShuffleIndices.emplace_back(static_cast<int>(Idx));
+      Result = Builder.CreateShuffleVector(Src, Src, ShuffleIndices);
     }
-    return ConstantVector::get(Elts);
-  }
-}
 
-void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
-  // Lower to array of vector array like float[row * col].
-  // It's follow the major of decl.
-  // DynamicIndexingVectorToArray will change it to scalar array.
-  Type *Ty = GV->getType()->getPointerElementType();
-  std::vector<unsigned> arraySizeList;
-  while (Ty->isArrayTy()) {
-    arraySizeList.push_back(Ty->getArrayNumElements());
-    Ty = Ty->getArrayElementType();
-  }
-  unsigned row, col;
-  Type *EltTy = GetMatrixInfo(Ty, col, row);
-  Ty = VectorType::get(EltTy, col * row);
-
-  for (auto arraySize = arraySizeList.rbegin();
-       arraySize != arraySizeList.rend(); arraySize++)
-    Ty = ArrayType::get(Ty, *arraySize);
-
-  Type *VecArrayTy = Ty;
-  Constant *InitVal = nullptr;
-  if (GV->hasInitializer()) {
-    Constant *OldInitVal = GV->getInitializer();
-    InitVal = isa<UndefValue>(OldInitVal)
-      ? UndefValue::get(VecArrayTy)
-      : LowerMatrixArrayConst(OldInitVal, cast<ArrayType>(VecArrayTy));
+    // Apply element conversion
+    return convertScalarOrVector(Result,
+      MatDstTy.getLoweredVectorTypeForReg(), Opcode, Builder);
   }
 
-  bool isConst = GV->isConstant();
-  GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
-  unsigned AddressSpace = GV->getType()->getAddressSpace();
-  GlobalValue::LinkageTypes linkage = GV->getLinkage();
+  // Source must now be a matrix
+  HLMatrixType MatSrcTy = HLMatrixType::cast(Src->getType());
+  VectorType* LoweredSrcTy = MatSrcTy.getLoweredVectorTypeForReg();
 
-  Module *M = GV->getParent();
-  GlobalVariable *VecGV =
-      new llvm::GlobalVariable(*M, VecArrayTy, /*IsConstant*/ isConst, linkage,
-                               /*InitVal*/ InitVal, GV->getName() + ".v",
-                               /*InsertBefore*/ nullptr, TLMode, AddressSpace);
-  // Add debug info.
-  if (m_HasDbgInfo) {
-    DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
-    HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, VecGV);
+  Value *LoweredSrc;
+  if (isa<Argument>(Src)) {
+    // Function arguments are lowered in HLSignatureLower.
+    // Initial codegen first generates those cast intrinsics to tell us how to lower them into vectors.
+    // Preserve them, but change the return type to vector.
+    DXASSERT(Opcode == HLCastOpcode::ColMatrixToVecCast || Opcode == HLCastOpcode::RowMatrixToVecCast,
+      "Unexpected cast of matrix argument.");
+    LoweredSrc = callHLFunction(*m_pModule, HLOpcodeGroup::HLCast, static_cast<unsigned>(Opcode),
+      LoweredSrcTy, { Builder.getInt32((uint32_t)Opcode), Src }, Builder);
   }
-
-  for (User *U : GV->users()) {
-    Value *VecGEP = nullptr;
-    // Must be GEP or GEPOperator.
-    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
-      IRBuilder<> Builder(GEP);
-      SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
-      VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
-      AddToDeadInsts(GEP);
-    } else {
-      GEPOperator *GEPOP = cast<GEPOperator>(U);
-      IRBuilder<> Builder(GV->getContext());
-      SmallVector<Value *, 4> idxList(GEPOP->idx_begin(), GEPOP->idx_end());
-      VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
+  else {
+    LoweredSrc = getLoweredByValOperand(Src, Builder);
+  }
+  DXASSERT_NOMSG(LoweredSrc->getType() == LoweredSrcTy);
+
+  Value* Result = LoweredSrc;
+  Type* LoweredDstTy = DstTy;
+  if (dxilutil::IsIntegerOrFloatingPointType(DstTy)) {
+    // Matrix to scalar: take element 0 of the lowered vector
+    Result = Builder.CreateExtractElement(LoweredSrc, static_cast<uint64_t>(0));
+  }
+  else if (DstTy->isVectorTy()) {
+    // Matrix to vector
+    VectorType *DstVecTy = cast<VectorType>(DstTy);
+    DXASSERT(DstVecTy->getNumElements() <= LoweredSrcTy->getNumElements(),
+      "Cannot cast matrix to a larger vector.");
+
+    // We might have to truncate: keep only the leading elements of the lowered vector
+    if (DstTy->getVectorNumElements() < LoweredSrcTy->getNumElements()) {
+      SmallVector<int, 3> ShuffleIndices;
+      for (unsigned Idx = 0; Idx < DstVecTy->getNumElements(); ++Idx)
+        ShuffleIndices.emplace_back(static_cast<int>(Idx));
+      Result = Builder.CreateShuffleVector(Result, Result, ShuffleIndices);
     }
-
-    for (auto user = U->user_begin(); user != U->user_end();) {
-      CallInst *CI = cast<CallInst>(*(user++));
-      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-      if (group == HLOpcodeGroup::HLMatLoadStore) {
-        TranslateMatLoadStoreOnGlobalPtr(CI, VecGEP);
-      } else if (group == HLOpcodeGroup::HLSubscript) {
-        TranslateMatSubscriptOnGlobalPtr(CI, VecGEP);
-      } else {
-        DXASSERT(0, "invalid operation");
-      }
+  }
+  else {
+    // Destination must now be a matrix too
+    HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
+
+    // Apply at most one matrix-level change: an orientation change or a truncation (branches are exclusive)
+    if (Opcode == HLCastOpcode::ColMatrixToRowMatrix)
+      Result = MatSrcTy.emitLoweredVectorColToRow(Result, Builder);
+    else if (Opcode == HLCastOpcode::RowMatrixToColMatrix)
+      Result = MatSrcTy.emitLoweredVectorRowToCol(Result, Builder);
+    else if (MatDstTy.getNumRows() != MatSrcTy.getNumRows()
+      || MatDstTy.getNumColumns() != MatSrcTy.getNumColumns()) {
+      // Apply truncation: pick the source elements of the top-left sub-matrix, by row-major index
+      DXASSERT(MatDstTy.getNumRows() <= MatSrcTy.getNumRows()
+        && MatDstTy.getNumColumns() <= MatSrcTy.getNumColumns(),
+        "Unexpected matrix cast between incompatible dimensions.");
+      SmallVector<int, 16> ShuffleIndices;
+      for (unsigned RowIdx = 0; RowIdx < MatDstTy.getNumRows(); ++RowIdx)
+        for (unsigned ColIdx = 0; ColIdx < MatDstTy.getNumColumns(); ++ColIdx)
+          ShuffleIndices.emplace_back(static_cast<int>(MatSrcTy.getRowMajorIndex(RowIdx, ColIdx)));
+      Result = Builder.CreateShuffleVector(Result, Result, ShuffleIndices);
     }
+
+    LoweredDstTy = MatDstTy.getLoweredVectorTypeForReg();
+    DXASSERT(Result->getType()->getVectorNumElements() == LoweredDstTy->getVectorNumElements(),
+      "Unexpected matrix src/dst lowered element count mismatch after truncation.");
   }
 
-  DeleteDeadInsts();
-  GV->removeDeadConstantUsers();
-  GV->eraseFromParent();
+  // Apply element conversion
+  return convertScalarOrVector(Result, LoweredDstTy, Opcode, Builder);
 }
 
-static void FlattenMatConst(Constant *M, std::vector<Constant *> &Elts) {
-  unsigned row, col;
-  Type *EltTy = HLMatrixLower::GetMatrixInfo(M->getType(), col, row);
-  if (isa<UndefValue>(M)) {
-    Constant *Elt = UndefValue::get(EltTy);
-    for (unsigned i=0;i<col*row;i++)
-      Elts.emplace_back(Elt);
-  } else {
-    M = M->getAggregateElement((unsigned)0);
-    // Initializer is already in correct major.
-    // Just read it here.
-    // The type is vector<element, col>[row].
-    for (unsigned r = 0; r < row; r++) {
-      Constant *C = M->getAggregateElement(r);
-      for (unsigned c = 0; c < col; c++) {
-        Elts.emplace_back(C->getAggregateElement(c));
-      }
-    }
+Value *HLMatrixLowerPass::lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode) { // Dispatches a matrix subscript intrinsic to the lowering routine for its opcode.
+  switch (Opcode) {
+  case HLSubscriptOpcode::RowMatElement:
+  case HLSubscriptOpcode::ColMatElement:
+    return lowerHLMatElementSubscript(Call,
+      /* RowMajor */ Opcode == HLSubscriptOpcode::RowMatElement);
+
+  case HLSubscriptOpcode::RowMatSubscript:
+  case HLSubscriptOpcode::ColMatSubscript:
+    return lowerHLMatSubscript(Call,
+      /* RowMajor */ Opcode == HLSubscriptOpcode::RowMatSubscript);
+
+  case HLSubscriptOpcode::DefaultSubscript:
+  case HLSubscriptOpcode::CBufferSubscript:
+    // These get lowered during HLOperationLower,
+    // and the return type must stay unchanged (as a matrix)
+    // to provide the metadata to properly emit the loads.
+    return nullptr;
+
+  default:
+    llvm_unreachable("Unexpected matrix subscript opcode.");
   }
 }
 
-void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
-  if (HLMatrixLower::IsMatrixArrayPointer(GV->getType())) {
-    runOnGlobalMatrixArray(GV);
-    return;
-  }
+Value *HLMatrixLowerPass::lowerHLMatElementSubscript(CallInst *Call, bool RowMajor) { // Lowers a matrix element subscript whose indices come as one constant vector operand.
+  (void)RowMajor; // Unused: the orientation is not consulted by this lowering.
 
-  Type *Ty = GV->getType()->getPointerElementType();
-  if (!dxilutil::IsHLSLMatrixType(Ty))
-    return;
+  Value *MatPtr = Call->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
+  Constant *IdxVec = cast<Constant>(Call->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx));
+  VectorType *IdxVecTy = cast<VectorType>(IdxVec->getType());
 
-  bool onlyLdSt = OnlyUsedByMatrixLdSt(GV);
-
-  bool isConst = GV->isConstant();
-
-  Type *vecTy = HLMatrixLower::LowerMatrixType(Ty);
-  Module *M = GV->getParent();
-  const DataLayout &DL = M->getDataLayout();
-
-  std::vector<Constant *> Elts;
-  // Lower to vector or array for scalar matrix.
-  // Make it col major so don't need shuffle when load/store.
-  FlattenMatConst(GV->getInitializer(), Elts);
-
-  if (onlyLdSt) {
-    Type *EltTy = vecTy->getVectorElementType();
-    unsigned vecSize = vecTy->getVectorNumElements();
-    std::vector<Value *> vecGlobals(vecSize);
-
-    GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
-    unsigned AddressSpace = GV->getType()->getAddressSpace();
-    GlobalValue::LinkageTypes linkage = GV->getLinkage();
-    unsigned debugOffset = 0;
-    unsigned size = DL.getTypeAllocSizeInBits(EltTy);
-    unsigned align = DL.getPrefTypeAlignment(EltTy);
-    for (int i = 0, e = vecSize; i != e; ++i) {
-      Constant *InitVal = Elts[i];
-      GlobalVariable *EltGV = new llvm::GlobalVariable(
-          *M, EltTy, /*IsConstant*/ isConst, linkage,
-          /*InitVal*/ InitVal, GV->getName() + "." + Twine(i),
-          /*InsertBefore*/nullptr,
-          TLMode, AddressSpace);
-      // Add debug info.
-      if (m_HasDbgInfo) {
-        DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
-        HLModule::CreateElementGlobalVariableDebugInfo(
-            GV, Finder, EltGV, size, align, debugOffset,
-            EltGV->getName().ltrim(GV->getName()));
-        debugOffset += size;
-      }
-      vecGlobals[i] = EltGV;
-    }
-    for (User *user : GV->users()) {
-      if (isa<Constant>(user) && user->use_empty())
-        continue;
-      CallInst *CI = cast<CallInst>(user);
-      TranslateMatLoadStoreOnGlobal(GV, vecGlobals, CI);
-      AddToDeadInsts(CI);
-    }
-    DeleteDeadInsts();
-    GV->eraseFromParent();
+  // Unpack the constant index vector into individual element indices
+  SmallVector<Value*, 4> ElemIndices;
+  ElemIndices.reserve(IdxVecTy->getNumElements());
+  for (unsigned VecIdx = 0; VecIdx < IdxVecTy->getNumElements(); ++VecIdx) {
+    ElemIndices.emplace_back(IdxVec->getAggregateElement(VecIdx));
   }
-  else {
-    // lower to array of scalar here.
-    ArrayType *AT = ArrayType::get(vecTy->getVectorElementType(), vecTy->getVectorNumElements());
-    Constant *InitVal = ConstantArray::get(AT, Elts);
-    GlobalVariable *arrayMat = new llvm::GlobalVariable(
-      *M, AT, /*IsConstant*/ false, llvm::GlobalValue::InternalLinkage,
-      /*InitVal*/ InitVal, GV->getName());
-    // Add debug info.
-    if (m_HasDbgInfo) {
-      DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
-      HLModule::UpdateGlobalVariableDebugInfo(GV, Finder,
-                                                     arrayMat);
-    }
 
-    for (auto U = GV->user_begin(); U != GV->user_end();) {
-      Value *user = *(U++);
-      CallInst *CI = cast<CallInst>(user);
-      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-      if (group == HLOpcodeGroup::HLMatLoadStore) {
-        TranslateMatLoadStoreOnGlobal(GV, arrayMat, CI);
-      }
-      else {
-        DXASSERT(group == HLOpcodeGroup::HLSubscript, "Must be subscript operation");
-        TranslateMatSubscriptOnGlobalPtr(CI, arrayMat);
-      }
-    }
-    GV->removeDeadConstantUsers();
-    GV->eraseFromParent();
-  }
+  lowerHLMatSubscript(Call, MatPtr, ElemIndices);
+
+  // Uses were replaced above (or the call was left alone if its pointer could not be lowered); opt out of having the caller do it for us.
+  return nullptr;
 }
 
-void HLMatrixLowerPass::runOnFunction(Function &F) {
-  // Skip hl function definition (like createhandle)
-  if (hlsl::GetHLOpcodeGroupByName(&F) != HLOpcodeGroup::NotHL)
-    return;
+Value *HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, bool RowMajor) { // Lowers a matrix subscript whose element indices are passed as separate trailing operands.
+  (void)RowMajor; // Unused: the orientation is not consulted by this lowering.
 
-  // Create vector version of matrix instructions first.
-  // The matrix operands will be undefval for these instructions.
-  for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
-    BasicBlock *BB = BBI;
-    for (auto II = BB->begin(); II != BB->end(); ) {
-      Instruction &I = *(II++);
-      if (dxilutil::IsHLSLMatrixType(I.getType())) {
-        lowerToVec(&I);
-      } else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
-        Type *Ty = AI->getAllocatedType();
-        if (dxilutil::IsHLSLMatrixType(Ty)) {
-          lowerToVec(&I);
-        } else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
-          lowerToVec(&I);
-        }
-      } else if (CallInst *CI = dyn_cast<CallInst>(&I)) {
-        HLOpcodeGroup group =
-            hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
-        if (group == HLOpcodeGroup::HLMatLoadStore) {
-          HLMatLoadStoreOpcode opcode =
-              static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
-          DXASSERT_LOCALVAR(opcode,
-                            opcode == HLMatLoadStoreOpcode::ColMatStore ||
-                            opcode == HLMatLoadStoreOpcode::RowMatStore,
-                            "Must MatStore here, load will go IsMatrixType path");
-          // Lower it here to make sure it is ready before replace.
-          lowerToVec(&I);
-        }
-      } else if (GetIfMatrixGEPOfUDTAlloca(&I) ||
-                 GetIfMatrixGEPOfUDTArg(&I, *m_pHLModule)) {
-        lowerToVec(&I);
-      }
-    }
-  }
+  Value *MatPtr = Call->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
 
-  // Update the use of matrix inst with the vector version.
-  for (auto matToVecIter = matToVecMap.begin();
-       matToVecIter != matToVecMap.end();) {
-    auto matToVec = matToVecIter++;
-    replaceMatWithVec(matToVec->first, cast<Instruction>(matToVec->second));
+  // Gather the element indices from the trailing call operands (they may be constant or dynamic)
+  SmallVector<Value*, 4> ElemIndices;
+  for (unsigned Idx = HLOperandIndex::kMatSubscriptSubOpIdx; Idx < Call->getNumArgOperands(); ++Idx) {
+    ElemIndices.emplace_back(Call->getArgOperand(Idx));
   }
 
-  // Translate mat inst which require all operands ready.
-  for (auto matToVecIter = matToVecMap.begin();
-       matToVecIter != matToVecMap.end();) {
-    auto matToVec = matToVecIter++;
-    if (isa<Instruction>(matToVec->first))
-      finalMatTranslation(matToVec->first);
-  }
+  lowerHLMatSubscript(Call, MatPtr, ElemIndices);
 
-  // Remove matrix targets of vecToMatMap from matToVecMap before adding the rest to dead insts.
-  for (auto &it : vecToMatMap) {
-    matToVecMap.erase(it.second);
-  }
+  // Uses were replaced above (or the call was left alone if its pointer could not be lowered); opt out of having the caller do it for us.
+  return nullptr;
+}
 
-  // Delete the matrix version insts.
-  for (auto matToVecIter = matToVecMap.begin();
-       matToVecIter != matToVecMap.end();) {
-    auto matToVec = matToVecIter++;
-    // Add to m_deadInsts.
-    if (Instruction *matInst = dyn_cast<Instruction>(matToVec->first))
-      AddToDeadInsts(matInst);
-  }
+void HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, Value *MatPtr, SmallVectorImpl<Value*> &ElemIndices) { // Shared worker: rewrites all uses of a matrix subscript call against the lowered pointer.
+  DXASSERT_NOMSG(HLMatrixType::isMatrixPtr(MatPtr->getType()));
+
+  IRBuilder<> CallBuilder(Call);
+  Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, CallBuilder);
+  if (LoweredPtr == nullptr) return; // No lowered pointer available: leave the call and its uses untouched
 
-  DeleteDeadInsts();
+  // For global variables, we can GEP directly into the lowered vector pointer.
+  // This is necessary to support group shared memory atomics and the like.
+  Value *RootPtr = LoweredPtr;
+  while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
+    RootPtr = GEP->getPointerOperand();
+  bool AllowLoweredPtrGEPs = isa<GlobalVariable>(RootPtr);
   
-  matToVecMap.clear();
-  vecToMatMap.clear();
+  // Just constructing this does all the work (the constructor performs the use replacement)
+  HLMatrixSubscriptUseReplacer UseReplacer(Call, LoweredPtr, ElemIndices, AllowLoweredPtrGEPs, m_deadInsts);
 
-  return;
+  DXASSERT(Call->use_empty(), "Expected all matrix subscript uses to have been replaced.");
+  addToDeadInsts(Call);
 }
 
-// Matrix Bitcast lower.
-// After linking Lower matrix bitcast patterns like:
-//  %169 = bitcast [72 x float]* %0 to [6 x %class.matrix.float.4.3]*
-//  %conv.i = fptoui float %164 to i32
-//  %arrayidx.i = getelementptr inbounds [6 x %class.matrix.float.4.3], [6 x %class.matrix.float.4.3]* %169, i32 0, i32 %conv.i
-//  %170 = bitcast %class.matrix.float.4.3* %arrayidx.i to <12 x float>*
+// Lowers StructuredBuffer<matrix>[index] (or the constant buffer equivalent) by re-emitting the intrinsic with a lowered return type.
+Value *HLMatrixLowerPass::lowerHLMatResourceSubscript(CallInst *Call, HLSubscriptOpcode Opcode) {
+  // Just replace the intrinsic by its equivalent with a lowered return type
+  IRBuilder<> Builder(Call);
 
-namespace {
+  SmallVector<Value*, 4> Args; // All original call arguments, forwarded unchanged
+  Args.reserve(Call->getNumArgOperands());
+  for (Value *Arg : Call->arg_operands())
+    Args.emplace_back(Arg);
 
-Type *TryLowerMatTy(Type *Ty) {
-  Type *VecTy = nullptr;
-  if (HLMatrixLower::IsMatrixArrayPointer(Ty)) {
-    VecTy = HLMatrixLower::LowerMatrixArrayPointerToOneDimArray(Ty);
-  } else if (isa<PointerType>(Ty) &&
-             dxilutil::IsHLSLMatrixType(Ty->getPointerElementType())) {
-    VecTy = HLMatrixLower::LowerMatrixTypeToOneDimArray(
-        Ty->getPointerElementType());
-    VecTy = PointerType::get(VecTy, Ty->getPointerAddressSpace());
-  }
-  return VecTy;
+  Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
+  return callHLFunction(*m_pModule, HLOpcodeGroup::HLSubscript, static_cast<unsigned>(Opcode),
+    LoweredRetTy, Args, Builder);
 }
 
-class MatrixBitcastLowerPass : public FunctionPass {
+Value *HLMatrixLowerPass::lowerHLInit(CallInst *Call) { // Lowers a matrix init intrinsic into a vector assembled from its scalar arguments.
+  DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix init opcode.");
 
-public:
-  static char ID; // Pass identification, replacement for typeid
-  explicit MatrixBitcastLowerPass() : FunctionPass(ID) {}
-
-  const char *getPassName() const override { return "Matrix Bitcast lower"; }
-  bool runOnFunction(Function &F) override {
-    bool bUpdated = false;
-    std::unordered_set<BitCastInst*> matCastSet;
-    for (auto blkIt = F.begin(); blkIt != F.end(); ++blkIt) {
-      BasicBlock *BB = blkIt;
-      for (auto iIt = BB->begin(); iIt != BB->end(); ) {
-        Instruction *I = (iIt++);
-        if (BitCastInst *BCI = dyn_cast<BitCastInst>(I)) {
-          // Mutate mat to vec.
-          Type *ToTy = BCI->getType();
-          if (Type *ToVecTy = TryLowerMatTy(ToTy)) {
-            matCastSet.insert(BCI);
-            bUpdated = true;
-          }
-        }
-      }
-    }
+  // Figure out the result type; there must be exactly one init argument per lowered vector element
+  HLMatrixType MatTy = HLMatrixType::cast(Call->getType());
+  VectorType *LoweredTy = MatTy.getLoweredVectorTypeForReg();
+  DXASSERT(LoweredTy->getNumElements() == Call->getNumArgOperands() - HLOperandIndex::kInitFirstArgOpIdx,
+    "Invalid matrix init argument count.");
 
-    DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
-    // Remove bitcast which has CallInst user.
-    if (DM.GetShaderModel()->IsLib()) {
-      for (auto it = matCastSet.begin(); it != matCastSet.end();) {
-        BitCastInst *BCI = *(it++);
-        if (hasCallUser(BCI)) {
-          matCastSet.erase(BCI);
-        }
-      }
-    }
-
-    // Lower matrix first.
-    for (BitCastInst *BCI : matCastSet) {
-      lowerMatrix(BCI, BCI->getOperand(0));
-    }
-    return bUpdated;
+  // Build the result vector from the init args.
+  // Both the args and the result vector are in row-major order, so no shuffling is necessary.
+  IRBuilder<> Builder(Call);
+  Value *LoweredVec = UndefValue::get(LoweredTy);
+  for (unsigned VecElemIdx = 0; VecElemIdx < LoweredTy->getNumElements(); ++VecElemIdx) {
+    Value *ArgVal = Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx + VecElemIdx);
+    DXASSERT(dxilutil::IsIntegerOrFloatingPointType(ArgVal->getType()),
+      "Expected only scalars in matrix initialization.");
+    LoweredVec = Builder.CreateInsertElement(LoweredVec, ArgVal, static_cast<uint64_t>(VecElemIdx));
   }
-private:
-  void lowerMatrix(Instruction *M, Value *A);
-  bool hasCallUser(Instruction *M);
-};
 
+  return LoweredVec;
 }
 
-bool MatrixBitcastLowerPass::hasCallUser(Instruction *M) {
-  for (auto it = M->user_begin(); it != M->user_end();) {
-    User *U = *(it++);
-    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
-      Type *EltTy = GEP->getType()->getPointerElementType();
-      if (dxilutil::IsHLSLMatrixType(EltTy)) {
-        if (hasCallUser(GEP))
-          return true;
-      } else {
-        DXASSERT(0, "invalid GEP for matrix");
-      }
-    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
-      if (hasCallUser(BCI))
-        return true;
-    } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
-      if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
-      } else {
-        DXASSERT(0, "invalid load for matrix");
-      }
-    } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
-      Value *V = ST->getValueOperand();
-      if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
-      } else {
-        DXASSERT(0, "invalid load for matrix");
-      }
-    } else if (isa<CallInst>(U)) {
-      return true;
-    } else {
-      DXASSERT(0, "invalid use of matrix");
-    }
-  }
-  return false;
-}
+Value *HLMatrixLowerPass::lowerHLSelect(CallInst *Call) { // Lowers a matrix ternary (select) intrinsic into per-element selects on the lowered vectors.
+  DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix select opcode.");
 
-namespace {
-Value *CreateEltGEP(Value *A, unsigned i, Value *zeroIdx,
-                    IRBuilder<> &Builder) {
-  Value *GEP = nullptr;
-  if (GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A)) {
-    // A should be gep oneDimArray, 0, index * matSize
-    // Here add eltIdx to index * matSize foreach elt.
-    Instruction *EltGEP = GEPA->clone();
-    unsigned eltIdx = EltGEP->getNumOperands() - 1;
-    Value *NewIdx =
-        Builder.CreateAdd(EltGEP->getOperand(eltIdx), Builder.getInt32(i));
-    EltGEP->setOperand(eltIdx, NewIdx);
-    Builder.Insert(EltGEP);
-    GEP = EltGEP;
-  } else {
-    GEP = Builder.CreateInBoundsGEP(A, {zeroIdx, Builder.getInt32(i)});
-  }
-  return GEP;
-}
-} // namespace
-
-void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
-  for (auto it = M->user_begin(); it != M->user_end();) {
-    User *U = *(it++);
-    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
-      Type *EltTy = GEP->getType()->getPointerElementType();
-      if (dxilutil::IsHLSLMatrixType(EltTy)) {
-        // Change gep matrixArray, 0, index
-        // into
-        //   gep oneDimArray, 0, index * matSize
-        IRBuilder<> Builder(GEP);
-        SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
-        DXASSERT(idxList.size() == 2,
-                 "else not one dim matrix array index to matrix");
-        unsigned col = 0;
-        unsigned row = 0;
-        HLMatrixLower::GetMatrixInfo(EltTy, col, row);
-        Value *matSize = Builder.getInt32(col * row);
-        idxList.back() = Builder.CreateMul(idxList.back(), matSize);
-        Value *NewGEP = Builder.CreateGEP(A, idxList);
-        lowerMatrix(GEP, NewGEP);
-        DXASSERT(GEP->user_empty(), "else lower matrix fail");
-        GEP->eraseFromParent();
-      } else {
-        DXASSERT(0, "invalid GEP for matrix");
-      }
-    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
-      lowerMatrix(BCI, A);
-      DXASSERT(BCI->user_empty(), "else lower matrix fail");
-      BCI->eraseFromParent();
-    } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
-      if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
-        IRBuilder<> Builder(LI);
-        Value *zeroIdx = Builder.getInt32(0);
-        unsigned vecSize = Ty->getNumElements();
-        Value *NewVec = UndefValue::get(LI->getType());
-        for (unsigned i = 0; i < vecSize; i++) {
-          Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
-          Value *Elt = Builder.CreateLoad(GEP);
-          NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
-        }
-        LI->replaceAllUsesWith(NewVec);
-        LI->eraseFromParent();
-      } else {
-        DXASSERT(0, "invalid load for matrix");
-      }
-    } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
-      Value *V = ST->getValueOperand();
-      if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
-        IRBuilder<> Builder(LI);
-        Value *zeroIdx = Builder.getInt32(0);
-        unsigned vecSize = Ty->getNumElements();
-        for (unsigned i = 0; i < vecSize; i++) {
-          Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
-          Value *Elt = Builder.CreateExtractElement(V, i);
-          Builder.CreateStore(Elt, GEP);
-        }
-        ST->eraseFromParent();
-      } else {
-        DXASSERT(0, "invalid load for matrix");
-      }
-    } else {
-      DXASSERT(0, "invalid use of matrix");
-    }
+  Value *Cond = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
+  Value *TrueMat = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx);
+  Value *FalseMat = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx);
+
+  DXASSERT(TrueMat->getType() == FalseMat->getType(),
+    "Unexpected type mismatch between matrix ternary operator values.");
+
+#ifndef NDEBUG
+  // Assert that if the condition is a matrix, it matches the dimensions of the values
+  if (HLMatrixType MatCondTy = HLMatrixType::dyn_cast(Cond->getType())) {
+    HLMatrixType ValMatTy = HLMatrixType::cast(TrueMat->getType());
+    DXASSERT(MatCondTy.getNumRows() == ValMatTy.getNumRows()
+      && MatCondTy.getNumColumns() == ValMatTy.getNumColumns(),
+      "Unexpected mismatch between ternary operator condition and value matrix dimensions.");
   }
-}
+#endif
 
-#include "dxc/HLSL/DxilGenerationPass.h"
-char MatrixBitcastLowerPass::ID = 0;
-FunctionPass *llvm::createMatrixBitcastLowerPass() { return new MatrixBitcastLowerPass(); }
+  IRBuilder<> Builder(Call);
+  Value *LoweredCond = getLoweredByValOperand(Cond, Builder);
+  Value *LoweredTrueVec = getLoweredByValOperand(TrueMat, Builder);
+  Value *LoweredFalseVec = getLoweredByValOperand(FalseMat, Builder);
+  Value *Result = UndefValue::get(LoweredTrueVec->getType());
 
-INITIALIZE_PASS(MatrixBitcastLowerPass, "matrixbitcastlower", "Matrix Bitcast lower", false, false)
+  bool IsScalarCond = !LoweredCond->getType()->isVectorTy(); // A scalar condition is reused for every element
+
+  unsigned NumElems = Result->getType()->getVectorNumElements();
+  for (uint64_t ElemIdx = 0; ElemIdx < NumElems; ++ElemIdx) { // Select element by element, rebuilding the result vector
+    Value *ElemCond = IsScalarCond ? LoweredCond
+      : Builder.CreateExtractElement(LoweredCond, ElemIdx);
+    Value *ElemTrueVal = Builder.CreateExtractElement(LoweredTrueVec, ElemIdx);
+    Value *ElemFalseVal = Builder.CreateExtractElement(LoweredFalseVec, ElemIdx);
+    Value *ResultElem = Builder.CreateSelect(ElemCond, ElemTrueVal, ElemFalseVal);
+    Result = Builder.CreateInsertElement(Result, ResultElem, ElemIdx);
+  }
+
+  return Result;
+}

+ 284 - 0
lib/HLSL/HLMatrixSubscriptUseReplacer.cpp

@@ -0,0 +1,284 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLMatrixSubscriptUseReplacer.cpp                                          //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "HLMatrixSubscriptUseReplacer.h"
+#include "dxc/DXIL/DxilUtil.h"
+#include "dxc/Support/Global.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+
+using namespace llvm;
+using namespace hlsl;
+
+HLMatrixSubscriptUseReplacer::HLMatrixSubscriptUseReplacer(CallInst* Call, Value *LoweredPtr,
+  SmallVectorImpl<Value*> &ElemIndices, bool AllowLoweredPtrGEPs, std::vector<Instruction*> &DeadInsts)
+  : LoweredPtr(LoweredPtr), ElemIndices(ElemIndices), DeadInsts(DeadInsts), AllowLoweredPtrGEPs(AllowLoweredPtrGEPs)
+{ // Does all the work: classifies the subscript, then rewrites every use of Call against LoweredPtr.
+  HasScalarResult = !Call->getType()->getPointerElementType()->isVectorTy(); // Call returns a pointer; a non-vector pointee means one element is addressed
+
+  for (Value *ElemIdx : ElemIndices) { // A single non-constant index is enough to need dynamic indexing for vector accesses
+    if (!isa<Constant>(ElemIdx)) {
+      HasDynamicElemIndex = true;
+      break;
+    }
+  }
+
+  replaceUses(Call, /* GEPIdx */ nullptr); // nullptr: no second-level (GEP) subindex has been seen yet
+}
+
+void HLMatrixSubscriptUseReplacer::replaceUses(Instruction* PtrInst, Value* SubIdxVal) {
+  // Rewrites every use of PtrInst (the subscript call or a GEP on it); SubIdxVal is the index taken from a two-index GEP, or nullptr.
+  // We handle any number of load/stores of the subscript, whether through a GEP or not, but there should really only be one.
+  while (!PtrInst->use_empty()) {
+    llvm::Use &Use = *PtrInst->use_begin();
+    Instruction *UserInst = cast<Instruction>(Use.getUser());
+
+    bool DeleteUserInst = true; // Cleared only when the user is kept and merely retargeted to a lowered element pointer
+    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(UserInst)) {
+      // Recurse on GEPs
+      DXASSERT(GEP->getNumIndices() >= 1 && GEP->getNumIndices() <= 2,
+        "Unexpected GEP on constant matrix subscript.");
+      DXASSERT(cast<ConstantInt>(GEP->idx_begin()->get())->isZero(),
+        "Unexpected nonzero first index of constant matrix subscript GEP.");
+
+      Value *NewSubIdxVal = SubIdxVal; // A single-index GEP adds no subindex, so the inherited one is passed along
+      if (GEP->getNumIndices() == 2) {
+        DXASSERT(!HasScalarResult && SubIdxVal == nullptr,
+          "Unexpected GEP on matrix subscript scalar value.");
+        NewSubIdxVal = (GEP->idx_begin() + 1)->get(); // The second GEP index later selects an entry of ElemIndices
+      }
+
+      replaceUses(GEP, NewSubIdxVal);
+    }
+    else {
+      IRBuilder<> UserBuilder(UserInst); // All replacement code is emitted right before the user
+
+      if (Value *ScalarElemIdx = tryGetScalarIndex(SubIdxVal, UserBuilder)) {
+        // We are accessing a scalar element
+        if (AllowLoweredPtrGEPs) {
+          // Simply make the instruction point to the element in the lowered pointer
+          DeleteUserInst = false;
+          Value *ElemPtr = UserBuilder.CreateGEP(LoweredPtr, { UserBuilder.getInt32(0), ScalarElemIdx });
+          Use.set(ElemPtr); // This also removes the use from PtrInst, so the loop advances
+        }
+        else {
+          bool IsDynamicIndex = !isa<Constant>(ScalarElemIdx);
+          cacheLoweredMatrix(IsDynamicIndex, UserBuilder);
+          if (LoadInst *Load = dyn_cast<LoadInst>(UserInst)) {
+            Value *Elem = loadElem(ScalarElemIdx, UserBuilder);
+            Load->replaceAllUsesWith(Elem);
+          }
+          else if (StoreInst *Store = dyn_cast<StoreInst>(UserInst)) {
+            storeElem(ScalarElemIdx, Store->getValueOperand(), UserBuilder);
+            flushLoweredMatrix(UserBuilder); // Write the updated matrix back to LoweredPtr
+          }
+          else {
+            llvm_unreachable("Unexpected matrix subscript use.");
+          }
+        }
+      }
+      else {
+        // We are accessing a vector given by ElemIndices
+        cacheLoweredMatrix(HasDynamicElemIndex, UserBuilder);
+        if (LoadInst *Load = dyn_cast<LoadInst>(UserInst)) {
+          Value *Vector = loadVector(UserBuilder);
+          Load->replaceAllUsesWith(Vector);
+        }
+        else if (StoreInst *Store = dyn_cast<StoreInst>(UserInst)) {
+          storeVector(Store->getValueOperand(), UserBuilder);
+          flushLoweredMatrix(UserBuilder); // Write the updated matrix back to LoweredPtr
+        }
+        else {
+          llvm_unreachable("Unexpected matrix subscript use.");
+        }
+      }
+    }
+
+    // We replaced this use, mark it dead
+    if (DeleteUserInst) {
+      DXASSERT(UserInst->use_empty(), "Matrix subscript user should be dead at this point.");
+      Use.set(UndefValue::get(Use->getType())); // Detach the user from PtrInst so the while loop makes progress
+      DeadInsts.emplace_back(UserInst); // Only recorded here; this function does not erase the instruction itself
+    }
+  }
+}
+
+Value *HLMatrixSubscriptUseReplacer::tryGetScalarIndex(Value *SubIdxVal, IRBuilder<> &Builder) { // Index of the single accessed element, or nullptr for a vector access
+  if (SubIdxVal == nullptr) {
+    // mat[0] case, returns a vector
+    if (!HasScalarResult) return nullptr;
+
+    // mat._11 case
+    DXASSERT_NOMSG(ElemIndices.size() == 1);
+    return ElemIndices[0];
+  }
+
+  if (ConstantInt *SubIdxConst = dyn_cast<ConstantInt>(SubIdxVal)) {
+    // mat[0][0], mat[i][0] or mat._11_12[0] cases.
+    uint64_t SubIdx = SubIdxConst->getLimitedValue();
+    DXASSERT(SubIdx < ElemIndices.size(), "Unexpected out of range constant matrix subindex.");
+    return ElemIndices[SubIdx]; // May itself be a non-constant value (mat[i][0]); callers check isa<Constant>
+  }
+
+  // mat[0][j] or mat[i][j] case.
+  // We need to dynamically index into the level 1 element indices
+  if (LazyTempElemIndicesArrayAlloca == nullptr) {
+    // The level 2 index is dynamic, use it to index a temporary array of the level 1 indices.
+    IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
+    ArrayType *ArrayTy = ArrayType::get(AllocaBuilder.getInt32Ty(), ElemIndices.size());
+    LazyTempElemIndicesArrayAlloca = AllocaBuilder.CreateAlloca(ArrayTy); // Created once and reused by later calls on this replacer
+  }
+
+  // Store level 1 indices in the temporary array
+  Value *GEPIndices[2] = { Builder.getInt32(0), nullptr };
+  for (unsigned SubIdx = 0; SubIdx < ElemIndices.size(); ++SubIdx) { // These stores are emitted at Builder's position on every call
+    GEPIndices[1] = Builder.getInt32(SubIdx);
+    Value *TempArrayElemPtr = Builder.CreateGEP(LazyTempElemIndicesArrayAlloca, GEPIndices);
+    Builder.CreateStore(ElemIndices[SubIdx], TempArrayElemPtr);
+  }
+
+  // Dynamically index using the subindex
+  GEPIndices[1] = SubIdxVal;
+  Value *ElemIdxPtr = Builder.CreateGEP(LazyTempElemIndicesArrayAlloca, GEPIndices);
+  return Builder.CreateLoad(ElemIdxPtr); // The result is a load instruction, so callers treat it as a dynamic index
+}
+
+// Unless we are allowed to GEP directly into the lowered matrix,
+// we must load the whole lowered vector from memory in order to read or write any elements.
+// If we're going to dynamically index, we need to copy the vector into a temporary array.
+// Further loadElem/storeElem calls depend on how we cached the matrix here (vector value vs. temporary array).
+void HLMatrixSubscriptUseReplacer::cacheLoweredMatrix(bool ForDynamicIndexing, IRBuilder<> &Builder) {
+  // If we can GEP right into the lowered pointer, no need for caching
+  if (AllowLoweredPtrGEPs) return;
+
+  // Load without memory to register representation conversion,
+  // since the point is to mimic pointer semantics
+  TempLoweredMatrix = Builder.CreateLoad(LoweredPtr);
+
+  if (!ForDynamicIndexing) return; // Constant indices can be served directly from the loaded vector value
+
+  // To handle mat[i] cases, we need to copy the matrix elements to
+  // an array which we can dynamically index.
+  VectorType *MatVecTy = cast<VectorType>(TempLoweredMatrix->getType());
+
+  // Lazily create the temporary array alloca
+  if (LazyTempElemArrayAlloca == nullptr) {
+    ArrayType *TempElemArrayTy = ArrayType::get(MatVecTy->getElementType(), MatVecTy->getNumElements());
+    IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
+    LazyTempElemArrayAlloca = AllocaBuilder.CreateAlloca(TempElemArrayTy); // Reused by every later dynamic access on this replacer
+  }
+
+  // Copy the matrix elements to the temporary array
+  Value *GEPIndices[2] = { Builder.getInt32(0), nullptr };
+  for (unsigned ElemIdx = 0; ElemIdx < MatVecTy->getNumElements(); ++ElemIdx) {
+    Value *VecElem = Builder.CreateExtractElement(TempLoweredMatrix, static_cast<uint64_t>(ElemIdx));
+    GEPIndices[1] = Builder.getInt32(ElemIdx);
+    Value *TempArrayElemPtr = Builder.CreateGEP(LazyTempElemArrayAlloca, GEPIndices);
+    Builder.CreateStore(VecElem, TempArrayElemPtr);
+  }
+
+  // Null out the vector form so we know to use the array
+  TempLoweredMatrix = nullptr;
+}
+
+Value *HLMatrixSubscriptUseReplacer::loadElem(Value *Idx, IRBuilder<> &Builder) { // Emits a read of lowered-matrix element Idx and returns the value
+  if (AllowLoweredPtrGEPs) { // Read straight from memory through a GEP on LoweredPtr
+    Value *ElemPtr = Builder.CreateGEP(LoweredPtr, { Builder.getInt32(0), Idx });
+    return Builder.CreateLoad(ElemPtr);
+  }
+  else if (TempLoweredMatrix == nullptr) { // cacheLoweredMatrix copied the matrix into the temporary array
+    DXASSERT_NOMSG(LazyTempElemArrayAlloca != nullptr);
+
+    Value *TempArrayElemPtr = Builder.CreateGEP(LazyTempElemArrayAlloca, { Builder.getInt32(0), Idx });
+    return Builder.CreateLoad(TempArrayElemPtr);
+  }
+  else { // The matrix is cached as a vector value, which this path only indexes with constants
+    DXASSERT_NOMSG(isa<ConstantInt>(Idx));
+    return Builder.CreateExtractElement(TempLoweredMatrix, Idx);
+  }
+}
+
+void HLMatrixSubscriptUseReplacer::storeElem(Value *Idx, Value *Elem, IRBuilder<> &Builder) { // Emits a write of Elem to lowered-matrix element Idx
+  if (AllowLoweredPtrGEPs) { // Write straight to memory through a GEP on LoweredPtr
+    Value *ElemPtr = Builder.CreateGEP(LoweredPtr, { Builder.getInt32(0), Idx });
+    Builder.CreateStore(Elem, ElemPtr);
+  }
+  else if (TempLoweredMatrix == nullptr) { // Matrix lives in the temporary array; LoweredPtr is only updated by flushLoweredMatrix
+    DXASSERT_NOMSG(LazyTempElemArrayAlloca != nullptr);
+
+    Value *GEPIndices[2] = { Builder.getInt32(0), Idx };
+    Value *TempArrayElemPtr = Builder.CreateGEP(LazyTempElemArrayAlloca, GEPIndices);
+    Builder.CreateStore(Elem, TempArrayElemPtr);
+  }
+  else { // Matrix is cached as a vector value; LoweredPtr is only updated by flushLoweredMatrix
+    DXASSERT_NOMSG(isa<ConstantInt>(Idx));
+    TempLoweredMatrix = Builder.CreateInsertElement(TempLoweredMatrix, Elem, Idx);
+  }
+}
+
+Value *HLMatrixSubscriptUseReplacer::loadVector(IRBuilder<> &Builder) { // Builds the vector made of the elements listed in ElemIndices
+  if (TempLoweredMatrix != nullptr) {
+    // We can optimize this as a shuffle
+    SmallVector<Constant*, 4> ShuffleIndices;
+    ShuffleIndices.reserve(ElemIndices.size());
+    for (Value *ElemIdx : ElemIndices)
+      ShuffleIndices.emplace_back(cast<Constant>(ElemIdx)); // The vector form is only kept when caching was not for dynamic indexing
+    Constant* ShuffleVector = ConstantVector::get(ShuffleIndices);
+    return Builder.CreateShuffleVector(TempLoweredMatrix, TempLoweredMatrix, ShuffleVector);
+  }
+
+  // Otherwise load elements one by one
+  Type* ElemTy = LoweredPtr->getType()->getPointerElementType()->getScalarType(); // Element type as stored behind LoweredPtr
+  VectorType *VecTy = VectorType::get(ElemTy, static_cast<unsigned>(ElemIndices.size()));
+  Value *Result = UndefValue::get(VecTy);
+  for (unsigned SubIdx = 0; SubIdx < ElemIndices.size(); ++SubIdx) {
+    Value *Elem = loadElem(ElemIndices[SubIdx], Builder);
+    Result = Builder.CreateInsertElement(Result, Elem, static_cast<uint64_t>(SubIdx));
+  }
+
+  return Result;
+}
+
+void HLMatrixSubscriptUseReplacer::storeVector(Value *Vec, IRBuilder<> &Builder) { // Writes Vec's elements to the matrix elements listed in ElemIndices
+  // We can't shuffle vectors of different sizes together, so insert one by one.
+  DXASSERT(Vec->getType()->getVectorNumElements() == ElemIndices.size(),
+    "Matrix subscript stored vector element count mismatch.");
+
+  for (unsigned SubIdx = 0; SubIdx < ElemIndices.size(); ++SubIdx) { // Element SubIdx of Vec goes to matrix element ElemIndices[SubIdx]
+    Value *Elem = Builder.CreateExtractElement(Vec, static_cast<uint64_t>(SubIdx));
+    storeElem(ElemIndices[SubIdx], Elem, Builder);
+  }
+}
+
+void HLMatrixSubscriptUseReplacer::flushLoweredMatrix(IRBuilder<> &Builder) { // Stores the cached (possibly modified) matrix back to LoweredPtr
+  // If GEPs are allowed, no flushing is necessary, we modified the source elements directly.
+  if (AllowLoweredPtrGEPs) return;
+
+  if (TempLoweredMatrix == nullptr) {
+    // First re-create the vector from the temporary array
+    DXASSERT_NOMSG(LazyTempElemArrayAlloca != nullptr);
+
+    VectorType *LoweredMatrixTy = cast<VectorType>(LoweredPtr->getType()->getPointerElementType());
+    TempLoweredMatrix = UndefValue::get(LoweredMatrixTy);
+    Value *GEPIndices[2] = { Builder.getInt32(0), nullptr };
+    for (unsigned ElemIdx = 0; ElemIdx < LoweredMatrixTy->getNumElements(); ++ElemIdx) {
+      GEPIndices[1] = Builder.getInt32(ElemIdx);
+      Value *TempArrayElemPtr = Builder.CreateGEP(LazyTempElemArrayAlloca, GEPIndices);
+      Value *NewElem = Builder.CreateLoad(TempArrayElemPtr);
+      TempLoweredMatrix = Builder.CreateInsertElement(TempLoweredMatrix, NewElem, static_cast<uint64_t>(ElemIdx));
+    }
+  }
+
+  // Store back the lowered matrix to its pointer
+  Builder.CreateStore(TempLoweredMatrix, LoweredPtr);
+  TempLoweredMatrix = nullptr; // Drop the cached value; the next access calls cacheLoweredMatrix again
+}

+ 67 - 0
lib/HLSL/HLMatrixSubscriptUseReplacer.h

@@ -0,0 +1,67 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLMatrixSubscriptUseReplacer.h                                            //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/IRBuilder.h"
+#include <vector>
+
+namespace llvm {
+class Value;
+class AllocaInst;
+class CallInst;
+class Instruction;
+class Function;
+} // namespace llvm
+
+namespace hlsl {
+// Implements recursive replacement of a matrix subscript's uses,
+// from a pointer to a matrix value to a pointer to its lowered vector version,
+// whether directly or through GEPs in the case of two-level indexing like mat[i][j].
+// This has to handle one or two levels of indices, each of which is either
+// constant or dynamic: mat[0], mat[i], mat[0][0], mat[i][0], mat[0][j], mat[i][j],
+// plus the equivalent element accesses: mat._11, mat._11_12, mat._11_12[0], mat._11_12[i]
+class HLMatrixSubscriptUseReplacer {
+public:
+  // The constructor does everything
+  HLMatrixSubscriptUseReplacer(llvm::CallInst* Call, llvm::Value *LoweredPtr,
+    llvm::SmallVectorImpl<llvm::Value*> &ElemIndices, bool AllowLoweredPtrGEPs,
+    std::vector<llvm::Instruction*> &DeadInsts);
+
+private:
+  void replaceUses(llvm::Instruction* PtrInst, llvm::Value* SubIdxVal); // Rewrites all uses of PtrInst, recursing through GEPs
+  llvm::Value *tryGetScalarIndex(llvm::Value *SubIdxVal, llvm::IRBuilder<> &Builder); // nullptr when the access is a vector access
+  void cacheLoweredMatrix(bool ForDynamicIndexing, llvm::IRBuilder<> &Builder); // Loads the matrix as a vector value or into a temporary array
+  llvm::Value *loadElem(llvm::Value *Idx, llvm::IRBuilder<> &Builder);
+  void storeElem(llvm::Value *Idx, llvm::Value *Elem, llvm::IRBuilder<> &Builder);
+  llvm::Value *loadVector(llvm::IRBuilder<> &Builder);
+  void storeVector(llvm::Value *Vec, llvm::IRBuilder<> &Builder);
+  void flushLoweredMatrix(llvm::IRBuilder<> &Builder); // Writes the cached matrix back through LoweredPtr
+
+private:
+  llvm::Value *LoweredPtr; // Pointer to the lowered (vector) form of the matrix
+  llvm::SmallVectorImpl<llvm::Value*> &ElemIndices; // Lowered-matrix indices of the elements the subscript selects
+  std::vector<llvm::Instruction*> &DeadInsts; // Receives the user instructions that were replaced
+  bool AllowLoweredPtrGEPs = false; // When true, elements are accessed by GEPs on LoweredPtr instead of load/modify/store
+  bool HasScalarResult = false; // True when the subscript call's pointee type is not a vector
+  bool HasDynamicElemIndex = false; // True when at least one entry of ElemIndices is not a constant
+
+  // The entire lowered matrix as loaded from LoweredPtr,
+  // nullptr if we copied it to a temporary array.
+  llvm::Value *TempLoweredMatrix = nullptr;
+
+  // We allocate this if the level 1 indices are not all constants,
+  // so we can dynamically index the lowered matrix vector.
+  llvm::AllocaInst *LazyTempElemArrayAlloca = nullptr;
+
+  // We'll allocate this lazily if we have a dynamic level 2 index (mat[0][i]),
+  // so we can dynamically index the level 1 indices.
+  llvm::AllocaInst *LazyTempElemIndicesArrayAlloca = nullptr;
+};
+} // namespace hlsl

+ 179 - 0
lib/HLSL/HLMatrixType.cpp

@@ -0,0 +1,179 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLMatrixType.cpp                                                          //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "dxc/HLSL/HLMatrixType.h"
+#include "dxc/Support/Global.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Value.h"
+
+using namespace llvm;
+using namespace hlsl;
+
+HLMatrixType::HLMatrixType(Type *RegReprElemTy, unsigned NumRows, unsigned NumColumns)
+  : RegReprElemTy(RegReprElemTy), NumRows(NumRows), NumColumns(NumColumns) { // Stores the register-representation element type and the dimensions
+  DXASSERT(RegReprElemTy != nullptr && (RegReprElemTy->isIntegerTy() || RegReprElemTy->isFloatingPointTy()),
+    "Invalid matrix element type.");
+  DXASSERT(NumRows >= 1 && NumRows <= 4 && NumColumns >= 1 && NumColumns <= 4, // Only 1x1 through 4x4 are accepted
+    "Invalid matrix dimensions.");
+}
+
+Type *HLMatrixType::getElementType(bool MemRepr) const {
+  // In memory representation bool (i1) elements become i32s; every other case returns RegReprElemTy unchanged
+  return MemRepr && RegReprElemTy->isIntegerTy(1)
+    ? IntegerType::get(RegReprElemTy->getContext(), 32)
+    : RegReprElemTy;
+}
+
+unsigned HLMatrixType::getRowMajorIndex(unsigned RowIdx, unsigned ColIdx) const { // Row-major index using this matrix's own dimensions
+  return getRowMajorIndex(RowIdx, ColIdx, NumRows, NumColumns);
+}
+
+unsigned HLMatrixType::getColumnMajorIndex(unsigned RowIdx, unsigned ColIdx) const { // Column-major index using this matrix's own dimensions
+  return getColumnMajorIndex(RowIdx, ColIdx, NumRows, NumColumns);
+}
+
+unsigned HLMatrixType::getRowMajorIndex(unsigned RowIdx, unsigned ColIdx, unsigned NumRows, unsigned NumColumns) { // Flat index when rows are stored one after another
+  DXASSERT_NOMSG(RowIdx < NumRows && ColIdx < NumColumns);
+  return RowIdx * NumColumns + ColIdx;
+}
+
+unsigned HLMatrixType::getColumnMajorIndex(unsigned RowIdx, unsigned ColIdx, unsigned NumRows, unsigned NumColumns) { // Flat index when columns are stored one after another
+  DXASSERT_NOMSG(RowIdx < NumRows && ColIdx < NumColumns);
+  return ColIdx * NumRows + RowIdx;
+}
+
+VectorType *HLMatrixType::getLoweredVectorType(bool MemRepr) const { // Flat vector type holding every element, in register or memory representation
+  return VectorType::get(getElementType(MemRepr), getNumElements());
+}
+
+Value *HLMatrixType::emitLoweredMemToReg(Value *Val, IRBuilder<> &Builder) const { // Converts a scalar or vector from memory to register representation
+  DXASSERT(Val->getType()->getScalarType() == getElementTypeForMem(), "Lowered matrix type mismatch.");
+  if (RegReprElemTy->isIntegerTy(1)) { // Only bool matrices differ between the two representations
+    Val = Builder.CreateICmpNE(Val, Constant::getNullValue(Val->getType()), "tobool"); // Nonzero -> true
+  }
+  return Val;
+}
+
+Value *HLMatrixType::emitLoweredRegToMem(Value *Val, IRBuilder<> &Builder) const { // Converts a scalar or vector from register to memory representation
+  DXASSERT(Val->getType()->getScalarType() == RegReprElemTy, "Lowered matrix type mismatch.");
+  if (RegReprElemTy->isIntegerTy(1)) { // Only bool matrices differ between the two representations
+    Type *MemReprTy = Val->getType()->isVectorTy() ? getLoweredVectorTypeForMem() : getElementTypeForMem(); // NOTE(review): a vector Val is widened to the full-matrix vector type - presumably callers only pass whole matrices; confirm
+    Val = Builder.CreateZExt(Val, MemReprTy, "frombool");
+  }
+  return Val;
+}
+
+Value *HLMatrixType::emitLoweredLoad(Value *Ptr, IRBuilder<> &Builder) const { // Loads from Ptr and converts the result to register representation
+  return emitLoweredMemToReg(Builder.CreateLoad(Ptr), Builder);
+}
+
+StoreInst *HLMatrixType::emitLoweredStore(Value *Val, Value *Ptr, IRBuilder<> &Builder) const { // Converts Val to memory representation, stores it to Ptr, returns the store
+  return Builder.CreateStore(emitLoweredRegToMem(Val, Builder), Ptr);
+}
+
+Value *HLMatrixType::emitLoweredVectorRowToCol(Value *VecVal, IRBuilder<> &Builder) const { // Reorders a row-major lowered vector into column-major element order
+  DXASSERT(VecVal->getType() == getLoweredVectorTypeForReg(), "Lowered matrix type mismatch.");
+  if (NumRows == 1 || NumColumns == 1) return VecVal; // With a single row or column both orders coincide
+
+  SmallVector<int, 16> ShuffleIndices;
+  for (unsigned ColIdx = 0; ColIdx < NumColumns; ++ColIdx) // Output positions advance column by column...
+    for (unsigned RowIdx = 0; RowIdx < NumRows; ++RowIdx)
+      ShuffleIndices.emplace_back((int)getRowMajorIndex(RowIdx, ColIdx)); // ...each reading from its row-major source position
+  return Builder.CreateShuffleVector(VecVal, VecVal, ShuffleIndices, "row2col");
+}
+
+Value *HLMatrixType::emitLoweredVectorColToRow(Value *VecVal, IRBuilder<> &Builder) const { // Reorders a column-major lowered vector into row-major element order
+  DXASSERT(VecVal->getType() == getLoweredVectorTypeForReg(), "Lowered matrix type mismatch.");
+  if (NumRows == 1 || NumColumns == 1) return VecVal; // With a single row or column both orders coincide
+
+  SmallVector<int, 16> ShuffleIndices;
+  for (unsigned RowIdx = 0; RowIdx < NumRows; ++RowIdx) // Output positions advance row by row...
+    for (unsigned ColIdx = 0; ColIdx < NumColumns; ++ColIdx)
+      ShuffleIndices.emplace_back((int)getColumnMajorIndex(RowIdx, ColIdx)); // ...each reading from its column-major source position
+  return Builder.CreateShuffleVector(VecVal, VecVal, ShuffleIndices, "col2row");
+}
+
+bool HLMatrixType::isa(Type *Ty) { // A matrix type is recognized purely by name: a struct whose name starts with StructNamePrefix
+  StructType *StructTy = llvm::dyn_cast<StructType>(Ty);
+  return StructTy != nullptr && StructTy->getName().startswith(StructNamePrefix);
+}
+
+bool HLMatrixType::isMatrixPtr(Type *Ty) { // True for a pointer whose pointee is a matrix type
+  PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty);
+  return PtrTy != nullptr && isa(PtrTy->getElementType());
+}
+
+bool HLMatrixType::isMatrixArray(Type *Ty) { // True for an array, possibly nested, whose innermost element type is a matrix
+  ArrayType *ArrayTy = llvm::dyn_cast<ArrayType>(Ty);
+  if (ArrayTy == nullptr) return false; // A bare matrix is not a matrix array
+  while (ArrayType *NestedArrayTy = llvm::dyn_cast<ArrayType>(ArrayTy->getElementType()))
+    ArrayTy = NestedArrayTy; // Descend to the innermost array
+  return isa(ArrayTy->getElementType());
+}
+
+bool HLMatrixType::isMatrixArrayPtr(Type *Ty) { // True for a pointer to a (possibly nested) array of matrices
+  PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty);
+  if (PtrTy == nullptr) return false;
+  return isMatrixArray(PtrTy->getElementType());
+}
+
+bool HLMatrixType::isMatrixPtrOrArrayPtr(Type *Ty) { // True for a pointer to a matrix or to a (possibly nested) array of matrices
+  PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty);
+  if (PtrTy == nullptr) return false; // Non-pointer types are rejected
+  Ty = PtrTy->getElementType();
+  while (ArrayType *ArrayTy = llvm::dyn_cast<ArrayType>(Ty)) // ArrayTy only serves as the loop condition
+    Ty = Ty->getArrayElementType(); // Peel array levels down to the innermost element type
+  return isa(Ty);
+}
+
+bool HLMatrixType::isMatrixOrPtrOrArrayPtr(Type *Ty) { // True if Ty, after removing at most one pointer level and any array levels, is a matrix
+  if (PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty)) Ty = PtrTy->getElementType(); // The pointer level is optional, so non-pointer arrays of matrices also match
+  while (ArrayType *ArrayTy = llvm::dyn_cast<ArrayType>(Ty)) Ty = ArrayTy->getElementType();
+  return isa(Ty);
+}
+
+// Converts a matrix, matrix pointer, or matrix array pointer type to its lowered equivalent.
+// If the type is not matrix-derived, the original type is returned.
+// Does not lower struct types containing matrices. MemRepr only matters for a bare matrix type.
+Type *HLMatrixType::getLoweredType(Type *Ty, bool MemRepr) {
+  if (PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty)) {
+    // Pointees are always in memory representation
+    Type *LoweredElemTy = getLoweredType(PtrTy->getElementType(), /* MemRepr */ true);
+    return LoweredElemTy == PtrTy->getElementType()
+      ? Ty : PointerType::get(LoweredElemTy, PtrTy->getAddressSpace()); // The address space is preserved
+  }
+  else if (ArrayType *ArrayTy = llvm::dyn_cast<ArrayType>(Ty)) {
+    // Arrays are always in memory and so their elements are in memory representation
+    Type *LoweredElemTy = getLoweredType(ArrayTy->getElementType(), /* MemRepr */ true);
+    return LoweredElemTy == ArrayTy->getElementType()
+      ? Ty : ArrayType::get(LoweredElemTy, ArrayTy->getNumElements()); // The element count is preserved
+  }
+  else if (HLMatrixType MatrixTy = HLMatrixType::dyn_cast(Ty)) {
+    return MatrixTy.getLoweredVectorType(MemRepr);
+  }
+  else return Ty;
+}
+
+HLMatrixType HLMatrixType::cast(Type *Ty) { // Decodes element type and dimensions from a matrix struct type; Ty must satisfy isa()
+  DXASSERT_NOMSG(isa(Ty));
+  StructType *StructTy = llvm::cast<StructType>(Ty);
+  DXASSERT_NOMSG(Ty->getNumContainedTypes() == 1); // The struct is expected to wrap exactly one field
+  ArrayType *RowArrayTy = llvm::cast<ArrayType>(StructTy->getElementType(0)); // That field is an array with one entry per row
+  DXASSERT_NOMSG(RowArrayTy->getNumElements() >= 1 && RowArrayTy->getNumElements() <= 4);
+  VectorType *RowTy = llvm::cast<VectorType>(RowArrayTy->getElementType()); // Each row is a vector with one element per column
+  DXASSERT_NOMSG(RowTy->getNumElements() >= 1 && RowTy->getNumElements() <= 4);
+  return HLMatrixType(RowTy->getElementType(), RowArrayTy->getNumElements(), RowTy->getNumElements());
+}
+
+HLMatrixType HLMatrixType::dyn_cast(Type *Ty) { // Like cast(), but yields a default-constructed HLMatrixType when Ty is not a matrix
+  return isa(Ty) ? cast(Ty) : HLMatrixType();
+}

+ 85 - 57
lib/HLSL/HLOperationLower.cpp

@@ -16,6 +16,7 @@
 #include "dxc/DXIL/DxilModule.h"
 #include "dxc/DXIL/DxilOperations.h"
 #include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/DXIL/DxilUtil.h"
 #include "dxc/HLSL/HLOperationLower.h"
@@ -4985,10 +4986,9 @@ Value *GenerateCBLoad(Value *handle, Value *offset, Type *EltTy, OP *hlslOP,
 Value *TranslateConstBufMatLd(Type *matType, Value *handle, Value *offset,
                               bool colMajor, OP *OP, const DataLayout &DL,
                               IRBuilder<> &Builder) {
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(matType, col, row);
-  Type *EltTy = HLMatrixLower::LowerMatrixType(matType, /*forMem*/true)->getVectorElementType();
-  unsigned matSize = col * row;
+  HLMatrixType MatTy = HLMatrixType::cast(matType);
+  Type *EltTy = MatTy.getElementTypeForMem();
+  unsigned matSize = MatTy.getNumElements();
   std::vector<Value *> elts(matSize);
   Value *EltByteSize = ConstantInt::get(
       offset->getType(), GetEltTypeByteSizeForConstBuf(EltTy, DL));
@@ -5000,8 +5000,8 @@ Value *TranslateConstBufMatLd(Type *matType, Value *handle, Value *offset,
     baseOffset = Builder.CreateAdd(baseOffset, EltByteSize);
   }
 
-  Value* Vec = HLMatrixLower::BuildVector(EltTy, col * row, elts, Builder);
-  Vec = HLMatrixLower::VecMatrixMemToReg(Vec, matType, Builder);
+  Value* Vec = HLMatrixLower::BuildVector(EltTy, elts, Builder);
+  Vec = MatTy.emitLoweredMemToReg(Vec, Builder);
   return Vec;
 }
 
@@ -5069,9 +5069,8 @@ void TranslateCBAddressUser(Instruction *user, Value *handle, Value *baseOffset,
     } else if (group == HLOpcodeGroup::HLSubscript) {
       HLSubscriptOpcode subOp = static_cast<HLSubscriptOpcode>(opcode);
       Value *basePtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
-      Type *matType = basePtr->getType()->getPointerElementType();
-      unsigned col, row;
-      Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
+      HLMatrixType MatTy = HLMatrixType::cast(basePtr->getType()->getPointerElementType());
+      Type *EltTy = MatTy.getElementTypeForReg();
 
       Value *EltByteSize = ConstantInt::get(
           baseOffset->getType(), GetEltTypeByteSizeForConstBuf(EltTy, DL));
@@ -5397,26 +5396,24 @@ Value *GenerateCBLoadLegacy(Value *handle, Value *legacyIdx,
   }
 }
 
-Value *TranslateConstBufMatLdLegacy(Type *matType, Value *handle,
+Value *TranslateConstBufMatLdLegacy(HLMatrixType MatTy, Value *handle,
                                     Value *legacyIdx, bool colMajor, OP *OP,
                                     bool memElemRepr, const DataLayout &DL,
                                     IRBuilder<> &Builder) {
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(matType, col, row);
-  Type *EltTy = HLMatrixLower::LowerMatrixType(matType, /*forMem*/true)->getVectorElementType();
+  Type *EltTy = MatTy.getElementTypeForMem();
 
-  unsigned matSize = col * row;
+  unsigned matSize = MatTy.getNumElements();
   std::vector<Value *> elts(matSize);
   unsigned EltByteSize = GetEltTypeByteSizeForConstBuf(EltTy, DL);
   if (colMajor) {
     unsigned colByteSize = 4 * EltByteSize;
     unsigned colRegSize = (colByteSize + 15) >> 4;
-    for (unsigned c = 0; c < col; c++) {
+    for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
       Value *col = GenerateCBLoadLegacy(handle, legacyIdx, /*channelOffset*/ 0,
-                                        EltTy, row, OP, Builder);
+                                        EltTy, MatTy.getNumRows(), OP, Builder);
 
-      for (unsigned r = 0; r < row; r++) {
-        unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
+      for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+        unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
         elts[matIdx] = Builder.CreateExtractElement(col, r);
       }
       // Update offset for a column.
@@ -5425,11 +5422,11 @@ Value *TranslateConstBufMatLdLegacy(Type *matType, Value *handle,
   } else {
     unsigned rowByteSize = 4 * EltByteSize;
     unsigned rowRegSize = (rowByteSize + 15) >> 4;
-    for (unsigned r = 0; r < row; r++) {
+    for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
       Value *row = GenerateCBLoadLegacy(handle, legacyIdx, /*channelOffset*/ 0,
-                                        EltTy, col, OP, Builder);
-      for (unsigned c = 0; c < col; c++) {
-        unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
+                                        EltTy, MatTy.getNumColumns(), OP, Builder);
+      for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
+        unsigned matIdx = MatTy.getRowMajorIndex(r, c);
         elts[matIdx] = Builder.CreateExtractElement(row, c);
       }
       // Update offset for a row.
@@ -5437,9 +5434,9 @@ Value *TranslateConstBufMatLdLegacy(Type *matType, Value *handle,
     }
   }
 
-  Value *Vec = HLMatrixLower::BuildVector(EltTy, col * row, elts, Builder);
+  Value *Vec = HLMatrixLower::BuildVector(EltTy, elts, Builder);
   if (!memElemRepr)
-    Vec = HLMatrixLower::VecMatrixMemToReg(Vec, matType, Builder);
+    Vec = MatTy.emitLoweredMemToReg(Vec, Builder);
   return Vec;
 }
 
@@ -5489,20 +5486,19 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
       DXASSERT(matOp == HLMatLoadStoreOpcode::ColMatLoad ||
                    matOp == HLMatLoadStoreOpcode::RowMatLoad,
                "No store on cbuffer");
-      Type *matType = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
-                          ->getType()
-                          ->getPointerElementType();
+      HLMatrixType MatTy = HLMatrixType::cast(
+        CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
+          ->getType()->getPointerElementType());
       // This will replace a call, so we should use the register representation of elements
       Value *newLd = TranslateConstBufMatLdLegacy(
-          matType, handle, legacyIdx, colMajor, hlslOP, /*memElemRepr*/false, DL, Builder);
+        MatTy, handle, legacyIdx, colMajor, hlslOP, /*memElemRepr*/false, DL, Builder);
       CI->replaceAllUsesWith(newLd);
       CI->eraseFromParent();
     } else if (group == HLOpcodeGroup::HLSubscript) {
       HLSubscriptOpcode subOp = static_cast<HLSubscriptOpcode>(opcode);
       Value *basePtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
-      Type *matType = basePtr->getType()->getPointerElementType();
-      unsigned col, row;
-      Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
+      HLMatrixType MatTy = HLMatrixType::cast(basePtr->getType()->getPointerElementType());
+      Type *EltTy = MatTy.getElementTypeForReg();
 
       Value *idx = CI->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
 
@@ -5523,7 +5519,7 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
       if (!dynamicIndexing) {
         // This will replace a load or GEP, so we should use the memory representation of elements
         Value *matLd = TranslateConstBufMatLdLegacy(
-            matType, handle, legacyIdx, colMajor, hlslOP, /*memElemRepr*/true, DL, Builder);
+          MatTy, handle, legacyIdx, colMajor, hlslOP, /*memElemRepr*/true, DL, Builder);
         // The matLd is keep original layout, just use the idx calc in
         // EmitHLSLMatrixElement and EmitHLSLMatrixSubscript.
         switch (subOp) {
@@ -5568,7 +5564,7 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
           // row.z = c[2].[idx]
           // row.w = c[3].[idx]
           Value *Elts[4];
-          ArrayType *AT = ArrayType::get(EltTy, col);
+          ArrayType *AT = ArrayType::get(EltTy, MatTy.getNumColumns());
 
           IRBuilder<> AllocaBuilder(user->getParent()
                                         ->getParent()
@@ -5578,12 +5574,12 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
           Value *tempArray = AllocaBuilder.CreateAlloca(AT);
           Value *zero = AllocaBuilder.getInt32(0);
           Value *cbufIdx = legacyIdx;
-          for (unsigned int c = 0; c < col; c++) {
+          for (unsigned int c = 0; c < MatTy.getNumColumns(); c++) {
             Value *ColVal =
                 GenerateCBLoadLegacy(handle, cbufIdx, /*channelOffset*/ 0,
-                                     EltTy, row, hlslOP, Builder);
+                                     EltTy, MatTy.getNumRows(), hlslOP, Builder);
             // Convert ColVal to array for indexing.
-            for (unsigned int r = 0; r < row; r++) {
+            for (unsigned int r = 0; r < MatTy.getNumRows(); r++) {
               Value *Elt =
                   Builder.CreateExtractElement(ColVal, Builder.getInt32(r));
               Value *Ptr = Builder.CreateInBoundsGEP(
@@ -5597,7 +5593,7 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
             cbufIdx = Builder.CreateAdd(cbufIdx, one);
           }
           if (resultType->isVectorTy()) {
-            for (unsigned int c = 0; c < col; c++) {
+            for (unsigned int c = 0; c < MatTy.getNumColumns(); c++) {
               ldData = Builder.CreateInsertElement(ldData, Elts[c], c);
             }
           } else {
@@ -5606,12 +5602,12 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
         } else {
           // idx is r * col + c;
           // r = idx / col;
-          Value *cCol = ConstantInt::get(idx->getType(), col);
+          Value *cCol = ConstantInt::get(idx->getType(), MatTy.getNumColumns());
           idx = Builder.CreateUDiv(idx, cCol);
           idx = Builder.CreateAdd(idx, legacyIdx);
           // Just return a row; 'col' is the number of columns in the row.
           ldData = GenerateCBLoadLegacy(handle, idx, /*channelOffset*/ 0, EltTy,
-                                        col, hlslOP, Builder);
+            MatTy.getNumColumns(), hlslOP, Builder);
         }
         if (!resultType->isVectorTy()) {
           ldData = Builder.CreateExtractElement(ldData, Builder.getInt32(0));
@@ -5989,9 +5985,8 @@ Value *TranslateStructBufMatLd(Type *matType, IRBuilder<> &Builder,
                                Value *handle, hlsl::OP *OP, Value *status,
                                Value *bufIdx, Value *baseOffset,
                                const DataLayout &DL) {
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(matType, col, row);
-  Type *EltTy = HLMatrixLower::LowerMatrixType(matType, /*forMem*/true)->getVectorElementType();
+  HLMatrixType MatTy = HLMatrixType::cast(matType);
+  Type *EltTy = MatTy.getElementTypeForMem();
   unsigned  EltSize = DL.getTypeAllocSize(EltTy);
   Constant* alignment = OP->GetI32Const(EltSize);
 
@@ -5999,7 +5994,7 @@ Value *TranslateStructBufMatLd(Type *matType, IRBuilder<> &Builder,
   if (baseOffset == nullptr)
     offset = OP->GetU32Const(0);
 
-  unsigned matSize = col * row;
+  unsigned matSize = MatTy.getNumElements();
   std::vector<Value *> elts(matSize);
 
   unsigned rest = (matSize % 4);
@@ -6023,19 +6018,18 @@ Value *TranslateStructBufMatLd(Type *matType, IRBuilder<> &Builder,
     offset = Builder.CreateAdd(offset, OP->GetU32Const(4 * EltSize));
   }
 
-  Value *Vec = HLMatrixLower::BuildVector(EltTy, col * row, elts, Builder);
-  Vec = HLMatrixLower::VecMatrixMemToReg(Vec, matType, Builder);
+  Value *Vec = HLMatrixLower::BuildVector(EltTy, elts, Builder);
+  Vec = MatTy.emitLoweredMemToReg(Vec, Builder);
   return Vec;
 }
 
 void TranslateStructBufMatSt(Type *matType, IRBuilder<> &Builder, Value *handle,
                              hlsl::OP *OP, Value *bufIdx, Value *baseOffset,
                              Value *val, const DataLayout &DL) {
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(matType, col, row);
-  Type *EltTy = HLMatrixLower::LowerMatrixType(matType, /*forMem*/true)->getVectorElementType();
+  HLMatrixType MatTy = HLMatrixType::cast(matType);
+  Type *EltTy = MatTy.getElementTypeForMem();
 
-  val = HLMatrixLower::VecMatrixRegToMem(val, matType, Builder);
+  val = MatTy.emitLoweredRegToMem(val, Builder);
 
   unsigned EltSize = DL.getTypeAllocSize(EltTy);
   Constant *Alignment = OP->GetI32Const(EltSize);
@@ -6043,7 +6037,7 @@ void TranslateStructBufMatSt(Type *matType, IRBuilder<> &Builder, Value *handle,
   if (baseOffset == nullptr)
     offset = OP->GetU32Const(0);
 
-  unsigned matSize = col * row;
+  unsigned matSize = MatTy.getNumElements();
   Value *undefElt = UndefValue::get(EltTy);
 
   unsigned storeSize = matSize;
@@ -6106,6 +6100,41 @@ void TranslateStructBufSubscriptUser(Instruction *user, Value *handle,
                                      Value *bufIdx, Value *baseOffset,
                                      Value *status, hlsl::OP *OP, const DataLayout &DL);
 
+// For a GEP like mat[i][j].
+// IdxList holds the indices for [i][0], [i][1], [i][2], [i][3],
+// and j is the second index of the GEP.
+// The result is the index for [i][j] rather than the mat[i][j] pointer,
+// because resource pointers and temp pointers need different code gen.
+static Value *LowerGEPOnMatIndexListToIndex(
+  llvm::GetElementPtrInst *GEP, ArrayRef<Value *> IdxList) {
+  IRBuilder<> Builder(GEP);
+  Value *zero = Builder.getInt32(0);
+  DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
+  auto IdxIt = GEP->idx_begin();
+  Value *baseIdx = IdxIt->get();
+  DXASSERT_LOCALVAR(baseIdx, baseIdx == zero, "base index must be 0");
+  Value *SubIdx = (IdxIt + 1)->get();
+
+  // A constant subscript selects its entry of IdxList directly.
+  if (ConstantInt *ConstIdx = dyn_cast<ConstantInt>(SubIdx))
+    return IdxList[ConstIdx->getSExtValue()];
+
+  // A dynamic subscript: spill IdxList into a temp array allocated in the
+  // entry block, then load the entry selected by the subscript.
+  Function *ParentFn = GEP->getParent()->getParent();
+  IRBuilder<> AllocaBuilder(ParentFn->getEntryBlock().getFirstInsertionPt());
+  const unsigned NumIdx = IdxList.size();
+  ArrayType *TempArrayTy = ArrayType::get(IdxList[0]->getType(), NumIdx);
+  Value *TempArray = AllocaBuilder.CreateAlloca(TempArrayTy);
+
+  for (unsigned i = 0; i != NumIdx; ++i) {
+    Value *SlotPtr = Builder.CreateGEP(TempArray, { zero, Builder.getInt32(i) });
+    Builder.CreateStore(IdxList[i], SlotPtr);
+  }
+
+  return Builder.CreateLoad(Builder.CreateGEP(TempArray, { zero, SubIdx }));
+}
+
 // subscript operator for matrix of struct element.
 void TranslateStructBufMatSubscript(CallInst *CI, Value *handle,
                                     hlsl::OP *hlslOP, Value *bufIdx,
@@ -6118,9 +6147,8 @@ void TranslateStructBufMatSubscript(CallInst *CI, Value *handle,
   IRBuilder<> subBuilder(CI);
   HLSubscriptOpcode subOp = static_cast<HLSubscriptOpcode>(opcode);
   Value *basePtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
-  Type *matType = basePtr->getType()->getPointerElementType();
-  unsigned col, row;
-  Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
+  HLMatrixType MatTy = HLMatrixType::cast(basePtr->getType()->getPointerElementType());
+  Type *EltTy = MatTy.getElementTypeForReg();
   Constant *alignment = hlslOP->GetI32Const(DL.getTypeAllocSize(EltTy));
 
   Value *EltByteSize = ConstantInt::get(
@@ -6170,8 +6198,7 @@ void TranslateStructBufMatSubscript(CallInst *CI, Value *handle,
       continue;
     }
     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
-      Value *GEPOffset =
-          HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
+      Value *GEPOffset = LowerGEPOnMatIndexListToIndex(GEP, idxList);
 
       for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
         Instruction *gepUserInst = cast<Instruction>(*(gepU++));
@@ -6998,8 +7025,9 @@ void TranslateHLBuiltinOperation(Function *F, HLOperationLowerHelper &helper,
               /*bOrigAllocaTy*/false,
               matVal->getName());
             if (bTranspose) {
-              unsigned row, col;
-              HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
+              HLMatrixType MatTy = HLMatrixType::cast(matVal->getType());
+              unsigned row = MatTy.getNumRows();
+              unsigned col = MatTy.getNumColumns();
               if (bColDest) std::swap(row, col);
               vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
             }

+ 0 - 1
lib/HLSL/HLOperationLowerExtension.cpp

@@ -11,7 +11,6 @@
 
 #include "dxc/DXIL/DxilModule.h"
 #include "dxc/DXIL/DxilOperations.h"
-#include "dxc/HLSL/HLMatrixLowerHelper.h"
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/HLSL/HLOperationLower.h"
 #include "dxc/HLSL/HLOperations.h"

+ 42 - 51
lib/HLSL/HLSignatureLower.cpp

@@ -18,6 +18,7 @@
 #include "dxc/DXIL/DxilSemantic.h"
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/HlslIntrinsicOp.h"
 #include "dxc/DXIL/DxilUtil.h"
 #include "dxc/HLSL/DxilPackSignatureElement.h"
@@ -645,49 +646,45 @@ void replaceDirectInputParameter(Value *param, Function *loadInput,
     switch (matOp) {
     case HLCastOpcode::ColMatrixToVecCast: {
       IRBuilder<> LocalBuilder(CI);
-      Type *matTy =
-          CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx)->getType();
-      unsigned col, row;
-      Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
-      std::vector<Value *> matElts(col * row);
-      for (unsigned c = 0; c < col; c++) {
+      HLMatrixType MatTy = HLMatrixType::cast(
+          CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx)->getType());
+      Type *EltTy = MatTy.getElementTypeForReg();
+      std::vector<Value *> matElts(MatTy.getNumElements());
+      for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
         Value *rowIdx = hlslOP->GetI32Const(c);
         args[DXIL::OperandIndex::kLoadInputRowOpIdx] = rowIdx;
-        for (unsigned r = 0; r < row; r++) {
+        for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
           Value *colIdx = hlslOP->GetU8Const(r);
           args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
           Value *input =
               GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
-          unsigned matIdx = c * row + r;
-          matElts[matIdx] = input;
+          matElts[MatTy.getColumnMajorIndex(r, c)] = input;
         }
       }
       Value *newVec =
-          HLMatrixLower::BuildVector(EltTy, col * row, matElts, LocalBuilder);
+          HLMatrixLower::BuildVector(EltTy, matElts, LocalBuilder);
       CI->replaceAllUsesWith(newVec);
       CI->eraseFromParent();
     } break;
     case HLCastOpcode::RowMatrixToVecCast: {
       IRBuilder<> LocalBuilder(CI);
-      Type *matTy =
-          CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx)->getType();
-      unsigned col, row;
-      Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
-      std::vector<Value *> matElts(col * row);
-      for (unsigned r = 0; r < row; r++) {
+      HLMatrixType MatTy = HLMatrixType::cast(
+          CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx)->getType());
+      Type *EltTy = MatTy.getElementTypeForReg();
+      std::vector<Value *> matElts(MatTy.getNumElements());
+      for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
         Value *rowIdx = hlslOP->GetI32Const(r);
         args[DXIL::OperandIndex::kLoadInputRowOpIdx] = rowIdx;
-        for (unsigned c = 0; c < col; c++) {
+        for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
           Value *colIdx = hlslOP->GetU8Const(c);
           args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
           Value *input =
               GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
-          unsigned matIdx = r * col + c;
-          matElts[matIdx] = input;
+          matElts[MatTy.getRowMajorIndex(r, c)] = input;
         }
       }
       Value *newVec =
-          HLMatrixLower::BuildVector(EltTy, col * row, matElts, LocalBuilder);
+          HLMatrixLower::BuildVector(EltTy, matElts, LocalBuilder);
       CI->replaceAllUsesWith(newVec);
       CI->eraseFromParent();
     } break;
@@ -784,12 +781,10 @@ void collectInputOutputAccessInfo(
               vectorIdx = GEPIt.getOperand();
             }
           }
-          if (dxilutil::IsHLSLMatrixType(*GEPIt)) {
-            unsigned row, col;
-            HLMatrixLower::GetMatrixInfo(*GEPIt, col, row);
-            Constant *arraySize = ConstantInt::get(idxTy, col);
+          if (HLMatrixType MatTy = HLMatrixType::dyn_cast(*GEPIt)) {
+            Constant *arraySize = ConstantInt::get(idxTy, MatTy.getNumColumns());
             if (bRowMajor) {
-              arraySize = ConstantInt::get(idxTy, row);
+              arraySize = ConstantInt::get(idxTy, MatTy.getNumRows());
             }
             rowIdx = Builder.CreateMul(rowIdx, arraySize);
           }
@@ -915,46 +910,44 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
     case HLMatLoadStoreOpcode::ColMatLoad:
     case HLMatLoadStoreOpcode::RowMatLoad: {
       IRBuilder<> LocalBuilder(CI);
-      Type *matTy = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
-                        ->getType()
-                        ->getPointerElementType();
-      unsigned col, row;
-      HLMatrixLower::GetMatrixInfo(matTy, col, row);
-      std::vector<Value *> matElts(col * row);
+      HLMatrixType MatTy = HLMatrixType::cast(
+        CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
+          ->getType()->getPointerElementType());
+      std::vector<Value *> matElts(MatTy.getNumElements());
 
       if (matOp == HLMatLoadStoreOpcode::ColMatLoad) {
-        for (unsigned c = 0; c < col; c++) {
+        for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
           Constant *constRowIdx = LocalBuilder.getInt32(c);
           Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-          for (unsigned r = 0; r < row; r++) {
+          for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
             SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[r] };
             if (vertexID)
               args.emplace_back(vertexID);
 
             Value *input = LocalBuilder.CreateCall(ldStFunc, args);
-            unsigned matIdx = c * row + r;
+            unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
             matElts[matIdx] = input;
           }
         }
       } else {
-        for (unsigned r = 0; r < row; r++) {
+        for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
           Constant *constRowIdx = LocalBuilder.getInt32(r);
           Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-          for (unsigned c = 0; c < col; c++) {
+          for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
             SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[c] };
             if (vertexID)
               args.emplace_back(vertexID);
 
             Value *input = LocalBuilder.CreateCall(ldStFunc, args);
-            unsigned matIdx = r * col + c;
+            unsigned matIdx = MatTy.getRowMajorIndex(r, c);
             matElts[matIdx] = input;
           }
         }
       }
 
       Value *newVec =
-          HLMatrixLower::BuildVector(matElts[0]->getType(), col * row, matElts, LocalBuilder);
-      newVec = HLMatrixLower::VecMatrixMemToReg(newVec, matTy, LocalBuilder);
+          HLMatrixLower::BuildVector(matElts[0]->getType(), matElts, LocalBuilder);
+      newVec = MatTy.emitLoweredMemToReg(newVec, LocalBuilder);
 
       CI->replaceAllUsesWith(newVec);
       CI->eraseFromParent();
@@ -963,32 +956,30 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
     case HLMatLoadStoreOpcode::RowMatStore: {
       IRBuilder<> LocalBuilder(CI);
       Value *Val = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-      Type *matTy = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
-                        ->getType()
-                        ->getPointerElementType();
-      unsigned col, row;
-      HLMatrixLower::GetMatrixInfo(matTy, col, row);
+      HLMatrixType MatTy = HLMatrixType::cast(
+        CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
+          ->getType()->getPointerElementType());
 
-      Val = HLMatrixLower::VecMatrixRegToMem(Val, matTy, LocalBuilder);
+      Val = MatTy.emitLoweredRegToMem(Val, LocalBuilder);
 
       if (matOp == HLMatLoadStoreOpcode::ColMatStore) {
-        for (unsigned c = 0; c < col; c++) {
+        for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
           Constant *constColIdx = LocalBuilder.getInt32(c);
           Value *colIdx = LocalBuilder.CreateAdd(idxVal, constColIdx);
 
-          for (unsigned r = 0; r < row; r++) {
-            unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
+          for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+            unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
             Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
             LocalBuilder.CreateCall(ldStFunc,
               { OpArg, ID, colIdx, columnConsts[r], Elt });
           }
         }
       } else {
-        for (unsigned r = 0; r < row; r++) {
+        for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
           Constant *constRowIdx = LocalBuilder.getInt32(r);
           Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-          for (unsigned c = 0; c < col; c++) {
-            unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
+          for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
+            unsigned matIdx = MatTy.getRowMajorIndex(r, c);
             Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
             LocalBuilder.CreateCall(ldStFunc,
               { OpArg, ID, rowIdx, columnConsts[c], Elt });

+ 7 - 1
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -236,12 +236,18 @@ static void addHLSLPasses(bool HLSLHighLevel, unsigned OptLevel, hlsl::HLSLExten
     // Do this before change vector to array.
     MPM.add(createDxilLegalizeEvalOperationsPass());
   }
+  else {
+    // This should go between matrix lower and dynamic indexing vector to array,
+    // because matrix lower may create dynamically indexed global vectors,
+    // which should become locals. If they are turned into arrays first,
+    // this pass will ignore them as it only works on scalars and vectors.
+    MPM.add(createLowerStaticGlobalIntoAlloca());
+  }
 
   // Change dynamic indexing vector to array.
   MPM.add(createDynamicIndexingVectorToArrayPass(NoOpt));
 
   if (!NoOpt) {
-    MPM.add(createLowerStaticGlobalIntoAlloca());
     // mem2reg
     MPM.add(createPromoteMemoryToRegisterPass());
 

+ 36 - 25
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -56,6 +56,7 @@
 #include "dxc/HlslIntrinsicOp.h"
 #include "dxc/DXIL/DxilTypeSystem.h"
 #include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/DXIL/DxilOperations.h"
 #include <deque>
 #include <unordered_map>
@@ -2568,6 +2569,24 @@ static void DeleteMemcpy(MemCpyInst *MI) {
   }
 }
 
+// Scans the users of Mat for a call that passes Mat as an argument and whose
+// callee has a function annotation; returns the annotation of that parameter
+// (used to get the matrix major), or nullptr when no such call is found.
+static DxilFieldAnnotation *FindAnnotationFromMatUser(Value *Mat,
+  DxilTypeSystem &typeSys) {
+  for (User *MatUser : Mat->users()) {
+    CallInst *Call = dyn_cast<CallInst>(MatUser);
+    if (!Call) continue;
+    DxilFunctionAnnotation *FuncAnno =
+      typeSys.GetFunctionAnnotation(Call->getCalledFunction());
+    if (!FuncAnno) continue;
+    for (unsigned ArgIdx = 0; ArgIdx < Call->getNumArgOperands(); ArgIdx++)
+      if (Call->getArgOperand(ArgIdx) == Mat)
+        return &FuncAnno->GetParameterAnnotation(ArgIdx);
+  }
+  return nullptr;
+}
+
 void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
                                  DxilFieldAnnotation *fieldAnnotation,
                                  DxilTypeSystem &typeSys, const bool bEltMemCpy) {
@@ -2597,7 +2616,7 @@ void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
   if (!fieldAnnotation) {
     Type *EltTy = dxilutil::GetArrayEltTy(DestTy);
     if (dxilutil::IsHLSLMatrixType(EltTy)) {
-      fieldAnnotation = HLMatrixLower::FindAnnotationFromMatUser(Dest, typeSys);
+      fieldAnnotation = FindAnnotationFromMatUser(Dest, typeSys);
     }
   }
 
@@ -4872,23 +4891,19 @@ static void CopyEltsPtrToVectorPtr(ArrayRef<Value *> elts, Value *VecPtr,
 static void CopyMatToArrayPtr(Value *Mat, Value *ArrayPtr,
                               unsigned arrayBaseIdx, HLModule &HLM,
                               IRBuilder<> &Builder, bool bRowMajor) {
-  Type *Ty = Mat->getType();
   // Mat val is row major.
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(Mat->getType(), col, row);
-  Type *VecTy = HLMatrixLower::LowerMatrixType(Ty);
+  HLMatrixType MatTy = HLMatrixType::cast(Mat->getType());
+  Type *VecTy = MatTy.getLoweredVectorTypeForReg();
   Value *Vec =
       HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLCast,
                               (unsigned)HLCastOpcode::RowMatrixToVecCast, VecTy,
                               {Mat}, *HLM.GetModule());
   Value *zero = Builder.getInt32(0);
 
-  for (unsigned r = 0; r < row; r++) {
-    for (unsigned c = 0; c < col; c++) {
-      unsigned rowMatIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
-      Value *Elt = Builder.CreateExtractElement(Vec, rowMatIdx);
-      unsigned matIdx =
-          bRowMajor ? rowMatIdx :  HLMatrixLower::GetColMajorIdx(r, c, row);
+  for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+    for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
+      unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
+      Value *Elt = Builder.CreateExtractElement(Vec, matIdx);
       Value *Ptr = Builder.CreateInBoundsGEP(
           ArrayPtr, {zero, Builder.getInt32(arrayBaseIdx + matIdx)});
       Builder.CreateStore(Elt, Ptr);
@@ -4919,15 +4934,15 @@ static void CopyMatPtrToArrayPtr(Value *MatPtr, Value *ArrayPtr,
 static Value *LoadArrayPtrToMat(Value *ArrayPtr, unsigned arrayBaseIdx,
                                 Type *Ty, HLModule &HLM, IRBuilder<> &Builder,
                                 bool bRowMajor) {
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(Ty, col, row);
+  HLMatrixType MatTy = HLMatrixType::cast(Ty);
   // HLInit operands are in row major.
   SmallVector<Value *, 16> Elts;
   Value *zero = Builder.getInt32(0);
-  for (unsigned r = 0; r < row; r++) {
-    for (unsigned c = 0; c < col; c++) {
-      unsigned matIdx = bRowMajor ? HLMatrixLower::GetRowMajorIdx(r, c, col)
-                                  : HLMatrixLower::GetColMajorIdx(r, c, row);
+  for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+    for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
+      unsigned matIdx = bRowMajor
+        ? MatTy.getRowMajorIndex(r, c)
+        : MatTy.getColumnMajorIndex(r, c);
       Value *Ptr = Builder.CreateInBoundsGEP(
           ArrayPtr, {zero, Builder.getInt32(arrayBaseIdx + matIdx)});
       Value *Elt = Builder.CreateLoad(Ptr);
@@ -4981,12 +4996,10 @@ CastCopyArrayMultiDimTo1Dim(Value *FromArray, Value *ToArray, Type *CurFromTy,
       Value *Elt = Builder.CreateExtractElement(V, i);
       Builder.CreateStore(Elt, ToPtr);
     }
-  } else if (dxilutil::IsHLSLMatrixType(CurFromTy)) {
+  } else if (HLMatrixType MatTy = HLMatrixType::dyn_cast(CurFromTy)) {
     // Copy matrix to array.
-    unsigned col, row;
-    HLMatrixLower::GetMatrixInfo(CurFromTy, col, row);
     // Calculate the offset.
-    unsigned offset = calcIdx * col * row;
+    unsigned offset = calcIdx * MatTy.getNumElements();
     Value *FromPtr = Builder.CreateInBoundsGEP(FromArray, idxList);
     CopyMatPtrToArrayPtr(FromPtr, ToArray, offset, HLM, Builder, bRowMajor);
   } else if (!CurFromTy->isArrayTy()) {
@@ -5028,12 +5041,10 @@ CastCopyArray1DimToMultiDim(Value *FromArray, Value *ToArray, Type *CurToTy,
       V = Builder.CreateInsertElement(V, Elt, i);
     }
     Builder.CreateStore(V, ToPtr);
-  } else if (dxilutil::IsHLSLMatrixType(CurToTy)) {
+  } else if (HLMatrixType MatTy = HLMatrixType::dyn_cast(CurToTy)) {
     // Copy array to matrix.
-    unsigned col, row;
-    HLMatrixLower::GetMatrixInfo(CurToTy, col, row);
     // Calculate the offset.
-    unsigned offset = calcIdx * col * row;
+    unsigned offset = calcIdx * MatTy.getNumElements();
     Value *ToPtr = Builder.CreateInBoundsGEP(ToArray, idxList);
     CopyArrayPtrToMatPtr(FromArray, offset, ToPtr, HLM, Builder, bRowMajor);
   } else if (!CurToTy->isArrayTy()) {

+ 20 - 26
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -15,7 +15,7 @@
 #include "CodeGenModule.h"
 #include "CGRecordLayout.h"
 #include "dxc/HlslIntrinsicOp.h"
-#include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/DXIL/DxilUtil.h"
 #include "dxc/HLSL/HLOperations.h"
@@ -981,8 +981,7 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
   unsigned size = dataLayout.getTypeAllocSize(Type);
 
   if (IsHLSLMatType(Ty)) {
-    unsigned col, row;
-    llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(Type, col, row);
+    llvm::Type *EltTy = HLMatrixType::cast(Type).getElementTypeForReg();
     bool b64Bit = dataLayout.getTypeAllocSize(EltTy) == 8;
     size = GetMatrixSizeInCB(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor,
                              b64Bit);
@@ -5331,13 +5330,12 @@ void CGMSHLSLRuntime::FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Val
       }
     }
   } else {
-    if (dxilutil::IsHLSLMatrixType(valTy)) {
-      unsigned col, row;
-      llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(valTy, col, row);
+    if (HLMatrixType MatTy = HLMatrixType::dyn_cast(valTy)) {
+      llvm::Type *EltTy = MatTy.getElementTypeForReg();
       // All matrix Value should be row major.
       // Init list is row major in scalar.
       // So the order is match here, just cast to vector.
-      unsigned matSize = col * row;
+      unsigned matSize = MatTy.getNumElements();
       bool isRowMajor = hlsl::IsHLSLMatRowMajor(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor);
 
       HLCastOpcode opcode = isRowMajor ? HLCastOpcode::RowMatrixToVecCast
@@ -5477,18 +5475,16 @@ static void StoreInitListToDestPtr(Value *DestPtr,
     Result = CGF.EmitToMemory(Result, Type);
     Builder.CreateStore(Result, DestPtr);
     idx += Ty->getVectorNumElements();
-  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
+  } else if (HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty)) {
     bool isRowMajor = hlsl::IsHLSLMatRowMajor(Type, bDefaultRowMajor);
-    unsigned row, col;
-    HLMatrixLower::GetMatrixInfo(Ty, col, row);
-    std::vector<Value *> matInitList(col * row);
-    for (unsigned i = 0; i < col; i++) {
-      for (unsigned r = 0; r < row; r++) {
-        unsigned matIdx = i * row + r;
+    std::vector<Value *> matInitList(MatTy.getNumElements());
+    for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
+      for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+        unsigned matIdx = c * MatTy.getNumRows() + r;
         matInitList[matIdx] = elts[idx + matIdx];
       }
     }
-    idx += row * col;
+    idx += MatTy.getNumElements();
     Value *matVal =
         EmitHLSLMatrixOperationCallImp(Builder, HLOpcodeGroup::HLInit,
                                        /*opcode*/ 0, Ty, matInitList, M);
@@ -6518,10 +6514,9 @@ void CGMSHLSLRuntime::FlattenAggregatePtrToGepList(
                                  GepList, EltTyList);
 
     idxList.pop_back();
-  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
+  } else if (HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty)) {
     // Use matLd/St for matrix.
-    unsigned col, row;
-    llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
+    llvm::Type *EltTy = MatTy.getElementTypeForReg();
     llvm::PointerType *EltPtrTy =
         llvm::PointerType::get(EltTy, Ptr->getType()->getPointerAddressSpace());
     QualType EltQualTy = hlsl::GetHLSLMatElementType(Type);
@@ -6529,8 +6524,8 @@ void CGMSHLSLRuntime::FlattenAggregatePtrToGepList(
     Value *matPtr = CGF.Builder.CreateInBoundsGEP(Ptr, idxList);
 
     // Flatten matrix to elements.
-    for (unsigned r = 0; r < row; r++) {
-      for (unsigned c = 0; c < col; c++) {
+    for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+      for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
         ConstantInt *cRow = CGF.Builder.getInt32(r);
         ConstantInt *cCol = CGF.Builder.getInt32(c);
         Constant *CV = llvm::ConstantVector::get({cRow, cCol});
@@ -6694,7 +6689,7 @@ void CGMSHLSLRuntime::EmitHLSLAggregateCopy(
     // Memcpy struct.
     CGF.Builder.CreateMemCpy(dstGEP, srcGEP, size, 1);
   } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
-    if (!HLMatrixLower::IsMatrixArrayPointer(llvm::PointerType::get(Ty,0))) {
+    if (!HLMatrixType::isMatrixArray(Ty)) {
       Value *srcGEP = CGF.Builder.CreateInBoundsGEP(SrcPtr, idxList);
       Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
       unsigned size = this->TheModule.getDataLayout().getTypeAllocSize(AT);
@@ -6775,7 +6770,7 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversionAggregateCopy(CodeGenFunction &CGF,
   bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
   if (SrcPtrTy == DestPtrTy) {
     bool bMatArrayRotate = false;
-    if (HLMatrixLower::IsMatrixArrayPointer(SrcPtr->getType())) {
+    if (HLMatrixType::isMatrixArrayPtr(SrcPtr->getType())) {
       QualType SrcEltTy = GetArrayEltType(SrcTy);
       QualType DestEltTy = GetArrayEltType(DestTy);
       if (GetMatrixMajor(SrcEltTy, bDefaultRowMajor) !=
@@ -6871,11 +6866,10 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversion(
                                       SrcType, PT->getElementType());
 
     idxList.pop_back();
-  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
+  } else if (HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty)) {
     // Use matLd/St for matrix.
     Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
-    unsigned row, col;
-    llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
+    llvm::Type *EltTy = MatTy.getElementTypeForReg();
 
     llvm::VectorType *VT1 = llvm::VectorType::get(EltTy, 1);
     SrcVal = ConvertScalarOrVector(CGF, SrcVal, SrcType, hlsl::GetHLSLMatElementType(Type));
@@ -6883,7 +6877,7 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversion(
     // Splat the value
     Value *V1 = CGF.Builder.CreateInsertElement(UndefValue::get(VT1), SrcVal,
                                                 (uint64_t)0);
-    std::vector<int> shufIdx(col * row, 0);
+    std::vector<int> shufIdx(MatTy.getNumElements(), 0);
     Value *VecMat = CGF.Builder.CreateShuffleVector(V1, V1, shufIdx);
     Value *MatInit = EmitHLSLMatrixOperationCallImp(
         CGF.Builder, HLOpcodeGroup::HLInit, 0, Ty, {VecMat}, TheModule);

+ 2 - 1
tools/clang/test/CodeGenHLSL/RValSubscript.hlsl

@@ -1,5 +1,7 @@
 // RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
 
+// CHECK: alloca [16 x i32]
+
 // For b4[2]
 // CHECK: cbufferLoadLegacy
 // CHECK: i32 5)
@@ -47,7 +49,6 @@
 // CHECK: fcmp fast oeq
 // CHECK: fcmp fast oeq
 // CHECK: fcmp fast oeq
-// CHECK: alloca [16 x i32]
 
 
 float4x4 xt;

+ 1 - 1
tools/clang/test/CodeGenHLSL/staticGlobals.hlsl

@@ -6,9 +6,9 @@
 // CHECK: [3 x float] [float 6.000000e+00, float 0.000000e+00, float 0.000000e+00]
 // CHECK: [3 x float] [float 7.000000e+00, float 0.000000e+00, float 0.000000e+00]
 // CHECK: [3 x float] [float 8.000000e+00, float 0.000000e+00, float 0.000000e+00]
+// CHECK: [4 x float] [float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00]
 // CHECK: [16 x float] [float 1.500000e+01, float 1.500000e+01, float 1.500000e+01, float 1.500000e+01, float 1.600000e+01, float 1.600000e+01, float 1.600000e+01, float 1.600000e+01, float 1.700000e+01, float 1.700000e+01, float 1.700000e+01, float 1.700000e+01, float 1.800000e+01, float 1.800000e+01, float 1.800000e+01, float 1.800000e+01]
 // CHECK: [16 x float] [float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00]
-// CHECK: [4 x float] [float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00]
 // CHECK: [16 x float] [float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01, float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01, float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01, float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01]
 
 static float4 f0 = {5,6,7,8};

+ 4 - 1
tools/clang/test/CodeGenHLSL/static_matrix.hlsl

@@ -1,6 +1,9 @@
 // RUN: %dxc -E not_main -T ps_6_0 %s | FileCheck %s
 
-// Make sure internal global is removed.
+// Tests that the static-global-to-alloca pass
+// will turn a static matrix lowered into a static vector
+// into a local variable, even in the absence of the GVN pass.
+
 // CHECK-NOT:  = internal
 
 static float2x2 a;

+ 4 - 5
tools/clang/tools/dxcompiler/dxcdisassembler.cpp

@@ -17,7 +17,7 @@
 #include "dxc/DXIL/DxilShaderModel.h"
 #include "dxc/DXIL/DxilModule.h"
 #include "dxc/DXIL/DxilResource.h"
-#include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/DXIL/DxilConstants.h"
 #include "dxc/DXIL/DxilOperations.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -808,10 +808,9 @@ void PrintFieldLayout(llvm::Type *Ty, DxilFieldAnnotation &annotation,
       }
       if (EltTy->isVectorTy()) {
         EltTy = EltTy->getVectorElementType();
-      } else if (EltTy->isStructTy()) {
-        unsigned col, row;
-        EltTy = HLMatrixLower::GetMatrixInfo(EltTy, col, row);
-      }
+      } else if (EltTy->isStructTy())
+        EltTy = HLMatrixType::cast(EltTy).getElementTypeForReg();
+
       if (arrayLevel == 1)
         arraySize = 0;
     }