ソースを参照

Update matrix code gen. (#234)

1. Keep major for matrix pointers.
2. Change matrix values to row major to match hlsl.
   Only ColMatLoad, RowMatrixToColMatrix and col matrix value parameter for entry function are col major matrix value.
   And should only used by ColMatStore and ColMatrixToRowMatrix.
Xiang Li 8 年 前
コミット
e14dc5ed1b

+ 13 - 4
include/dxc/HLSL/HLMatrixLowerHelper.h

@@ -34,10 +34,19 @@ bool IsMatrixArrayPointer(llvm::Type *Ty);
 // Translate matrix array pointer type to vector array pointer type.
 llvm::Type *LowerMatrixArrayPointer(llvm::Type *Ty);
 
-llvm::Value *BuildMatrix(llvm::Type *EltTy, unsigned col, unsigned row,
-                   bool colMajor, llvm::ArrayRef<llvm::Value *> elts,
-                   llvm::IRBuilder<> &Builder);
-
+llvm::Value *BuildVector(llvm::Type *EltTy, unsigned size,
+                         llvm::ArrayRef<llvm::Value *> elts,
+                         llvm::IRBuilder<> &Builder);
+// For case like mat[i][j].
+// IdxList is [i][0], [i][1], [i][2],[i][3].
+// Idx is j.
+// return [i][j] not mat[i][j] because resource ptr and temp ptr need different
+// code gen.
+llvm::Value *
+LowerGEPOnMatIndexListToIndex(llvm::GetElementPtrInst *GEP,
+                              llvm::ArrayRef<llvm::Value *> IdxList);
+unsigned GetColMajorIdx(unsigned r, unsigned c, unsigned row);
+unsigned GetRowMajorIdx(unsigned r, unsigned c, unsigned col);
 } // namespace HLMatrixLower
 
 } // namespace hlsl

+ 2 - 0
include/dxc/HLSL/HLOperations.h

@@ -100,6 +100,8 @@ enum class HLCastOpcode {
   ToUnsignedCast,
   ColMatrixToVecCast,
   RowMatrixToVecCast,
+  ColMatrixToRowMatrix,
+  RowMatrixToColMatrix,
   HandleToResCast,
 };
 

+ 10 - 10
lib/HLSL/DxilGenerationPass.cpp

@@ -926,8 +926,8 @@ static void replaceDirectInputParameter(Value *param, Function *loadInput,
           matElts[matIdx] = input;
         }
       }
-      Value *newVec = HLMatrixLower::BuildMatrix(EltTy, col, row, false,
-                                                 matElts, LocalBuilder);
+      Value *newVec =
+          HLMatrixLower::BuildVector(EltTy, col * row, matElts, LocalBuilder);
       CI->replaceAllUsesWith(newVec);
       CI->eraseFromParent();
     } break;
@@ -949,8 +949,8 @@ static void replaceDirectInputParameter(Value *param, Function *loadInput,
           matElts[matIdx] = input;
         }
       }
-      Value *newVec = HLMatrixLower::BuildMatrix(EltTy, col, row, false,
-                                                 matElts, LocalBuilder);
+      Value *newVec =
+          HLMatrixLower::BuildVector(EltTy, col * row, matElts, LocalBuilder);
       CI->replaceAllUsesWith(newVec);
       CI->eraseFromParent();
     } break;
@@ -1235,8 +1235,8 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
           matElts[matIdx] = input;
         }
       }
-      Value *newVec = HLMatrixLower::BuildMatrix(EltTy, col, row, true, matElts,
-                                                 LocalBuilder);
+      Value *newVec =
+          HLMatrixLower::BuildVector(EltTy, col * row, matElts, LocalBuilder);
       CI->replaceAllUsesWith(newVec);
       CI->eraseFromParent();
     } break;
@@ -1261,8 +1261,8 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
           matElts[matIdx] = input;
         }
       }
-      Value *newVec = HLMatrixLower::BuildMatrix(EltTy, col, row, false,
-                                                 matElts, LocalBuilder);
+      Value *newVec =
+          HLMatrixLower::BuildVector(EltTy, col * row, matElts, LocalBuilder);
       CI->replaceAllUsesWith(newVec);
       CI->eraseFromParent();
     } break;
@@ -1280,7 +1280,7 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
         Value *colIdx = LocalBuilder.CreateAdd(idxVal, constColIdx);
 
         for (unsigned r = 0; r < row; r++) {
-          unsigned matIdx = c * row + r;
+          unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
           Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
           LocalBuilder.CreateCall(ldStFunc,
                                   {OpArg, ID, colIdx, columnConsts[r], Elt});
@@ -1301,7 +1301,7 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
         Constant *constRowIdx = LocalBuilder.getInt32(r);
         Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
         for (unsigned c = 0; c < col; c++) {
-          unsigned matIdx = r * col + c;
+          unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
           Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
           LocalBuilder.CreateCall(ldStFunc,
                                   {OpArg, ID, rowIdx, columnConsts[c], Elt});

+ 226 - 380
lib/HLSL/HLMatrixLowerPass.cpp

@@ -73,10 +73,10 @@ Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
   DXASSERT(IsMatrixType(Ty), "not matrix type");
   StructType *ST = cast<StructType>(Ty);
   Type *EltTy = ST->getElementType(0);
-  Type *ColTy = EltTy->getArrayElementType();
-  col = EltTy->getArrayNumElements();
-  row = ColTy->getVectorNumElements();
-  return ColTy->getVectorElementType();
+  Type *RowTy = EltTy->getArrayElementType();
+  row = EltTy->getArrayNumElements();
+  col = RowTy->getVectorNumElements();
+  return RowTy->getVectorElementType();
 }
 
 bool IsMatrixArrayPointer(llvm::Type *Ty) {
@@ -104,23 +104,49 @@ Type *LowerMatrixArrayPointer(Type *Ty) {
   return PointerType::get(Ty, 0);
 }
 
-Value *BuildMatrix(Type *EltTy, unsigned col, unsigned row,
-                          bool colMajor, ArrayRef<Value *> elts,
-                          IRBuilder<> &Builder) {
-  Value *Result = UndefValue::get(VectorType::get(EltTy, col * row));
-  if (colMajor) {
-    for (unsigned i = 0; i < col * row; i++)
-      Result = Builder.CreateInsertElement(Result, elts[i], i);
+Value *BuildVector(Type *EltTy, unsigned size, ArrayRef<llvm::Value *> elts,
+                   IRBuilder<> &Builder) {
+  Value *Vec = UndefValue::get(VectorType::get(EltTy, size));
+  for (unsigned i = 0; i < size; i++)
+    Vec = Builder.CreateInsertElement(Vec, elts[i], i);
+  return Vec;
+}
+
+Value *LowerGEPOnMatIndexListToIndex(
+    llvm::GetElementPtrInst *GEP, ArrayRef<Value *> IdxList) {
+  IRBuilder<> Builder(GEP);
+  Value *zero = Builder.getInt32(0);
+  DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
+  Value *baseIdx = (GEP->idx_begin())->get();
+  DXASSERT_LOCALVAR(baseIdx, baseIdx == zero, "base index must be 0");
+  Value *Idx = (GEP->idx_begin() + 1)->get();
+
+  if (ConstantInt *immIdx = dyn_cast<ConstantInt>(Idx)) {
+    return IdxList[immIdx->getSExtValue()];
   } else {
-    for (unsigned r = 0; r < row; r++)
-      for (unsigned c = 0; c < col; c++) {
-        unsigned rowMajorIdx = r * col + c;
-        unsigned colMajorIdx = c * row + r;
-        Result =
-            Builder.CreateInsertElement(Result, elts[rowMajorIdx], colMajorIdx);
-      }
+    IRBuilder<> AllocaBuilder(
+        GEP->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
+    unsigned size = IdxList.size();
+    // Store idxList to temp array.
+    ArrayType *AT = ArrayType::get(IdxList[0]->getType(), size);
+    Value *tempArray = AllocaBuilder.CreateAlloca(AT);
+
+    for (unsigned i = 0; i < size; i++) {
+      Value *EltPtr = Builder.CreateGEP(tempArray, {zero, Builder.getInt32(i)});
+      Builder.CreateStore(IdxList[i], EltPtr);
+    }
+    // Load the idx.
+    Value *GEPOffset = Builder.CreateGEP(tempArray, {zero, Idx});
+    return Builder.CreateLoad(GEPOffset);
   }
-  return Result;
+}
+
+
+unsigned GetColMajorIdx(unsigned r, unsigned c, unsigned row) {
+  return c * row + r;
+}
+unsigned GetRowMajorIdx(unsigned r, unsigned c, unsigned col) {
+  return r * col + c;
 }
 
 } // namespace HLMatrixLower
@@ -222,6 +248,8 @@ private:
                                CallInst *castInst);
   void TranslateMatCast(CallInst *matInst, Instruction *vecInst,
                         CallInst *castInst);
+  void TranslateMatMajorCast(CallInst *matInst, Instruction *vecInst,
+                        CallInst *castInst, bool rowToCol);
   // Replace matInst with vecInst in matSubscript
   void TranslateMatSubscript(Value *matInst, Value *vecInst,
                              CallInst *matSubInst);
@@ -236,8 +264,6 @@ private:
                              CallInst *matLdStInst);
   void TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
                              CallInst *matLdStInst);
-  void TranslateMatSubscriptOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
-                             CallInst *matSubInst);
   void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
   void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
 
@@ -263,11 +289,6 @@ ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }
 
 INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
 
-// All calculation on col major.
-static unsigned GetMatIdx(unsigned r, unsigned c, unsigned rowSize) {
-  return (c * rowSize + r);
-}
-
 static Instruction *CreateTypeCast(HLCastOpcode castOp, Type *toTy, Value *src,
                                    IRBuilder<> Builder) {
   // Cast to bool.
@@ -843,8 +864,8 @@ void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
   Value *rMat = matToVecMap[cast<Instruction>(RVal)];
 
   auto CreateOneEltMul = [&](unsigned r, unsigned lc, unsigned c) -> Value * {
-    unsigned lMatIdx = GetMatIdx(r, lc, row);
-    unsigned rMatIdx = GetMatIdx(lc, c, rRow);
+    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
+    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
     Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
     Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
     return isFloat ? Builder.CreateFMul(lMatElt, rMatElt)
@@ -859,8 +880,8 @@ void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
 
   auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
                              Value *acc) -> Value * {
-    unsigned lMatIdx = GetMatIdx(r, lc, row);
-    unsigned rMatIdx = GetMatIdx(lc, c, rRow);
+    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
+    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
     Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
     Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
     return Builder.CreateCall(Mad, {madOpArg, lMatElt, rMatElt, acc});
@@ -874,7 +895,7 @@ void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
       for (lc = 1; lc < col; lc++) {
         tmpVal = CreateOneEltMad(r, lc, c, tmpVal);
       }
-      unsigned matIdx = GetMatIdx(r, c, row);
+      unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, rCol);
       retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
     }
   }
@@ -912,7 +933,7 @@ void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
 
   auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
     Value *vecElt = Builder.CreateExtractElement(vec, c);
-    uint32_t matIdx = GetMatIdx(r, c, row);
+    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
     Value *matElt = Builder.CreateExtractElement(mat, matIdx);
     return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
   };
@@ -920,7 +941,7 @@ void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
   for (unsigned r = 0; r < row; r++) {
     unsigned c = 0;
     Value *vecElt = Builder.CreateExtractElement(vec, c);
-    uint32_t matIdx = GetMatIdx(r, c, row);
+    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
     Value *matElt = Builder.CreateExtractElement(mat, matIdx);
 
     Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
@@ -964,7 +985,7 @@ void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
 
   auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
     Value *vecElt = Builder.CreateExtractElement(vec, r);
-    uint32_t matIdx = GetMatIdx(r, c, row);
+    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
     Value *matElt = Builder.CreateExtractElement(mat, matIdx);
     return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
   };
@@ -972,7 +993,7 @@ void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
   for (unsigned c = 0; c < col; c++) {
     unsigned r = 0;
     Value *vecElt = Builder.CreateExtractElement(vec, r);
-    uint32_t matIdx = GetMatIdx(r, c, row);
+    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
     Value *matElt = Builder.CreateExtractElement(mat, matIdx);
 
     Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
@@ -1008,25 +1029,8 @@ void HLMatrixLowerPass::TranslateMul(CallInst *matInst, Instruction *vecInst,
 void HLMatrixLowerPass::TranslateMatTranspose(CallInst *matInst,
                                               Instruction *vecInst,
                                               CallInst *transposeInst) {
-  unsigned row, col;
-  GetMatrixInfo(transposeInst->getType(), col, row);
-  IRBuilder<> Builder(transposeInst);
-  std::vector<int> transposeMask(col * row);
-  unsigned idx = 0;
-  for (unsigned c = 0; c < col; c++)
-    for (unsigned r = 0; r < row; r++) {
-      // change to row major
-      unsigned matIdx = GetMatIdx(c, r, col);
-      transposeMask[idx++] = (matIdx);
-    }
-  Instruction *shuf = cast<Instruction>(
-      Builder.CreateShuffleVector(vecInst, vecInst, transposeMask));
-  // Replace vec transpose function call with shuf.
-  DXASSERT(matToVecMap.count(transposeInst), "must has vec version");
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[transposeInst]);
-  vecUseInst->replaceAllUsesWith(shuf);
-  AddToDeadInsts(vecUseInst);
-  matToVecMap[transposeInst] = shuf;
+  // Matrix value is row major, transpose is cast it to col major.
+  TranslateMatMajorCast(matInst, vecInst, transposeInst, /*bRowToCol*/ true);
 }
 
 static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
@@ -1144,6 +1148,44 @@ void HLMatrixLowerPass::TrivialMatReplace(CallInst *matInst,
     }
 }
 
+void HLMatrixLowerPass::TranslateMatMajorCast(CallInst *matInst,
+                                              Instruction *vecInst,
+                                              CallInst *castInst,
+                                              bool bRowToCol) {
+  unsigned col, row;
+  GetMatrixInfo(castInst->getType(), col, row);
+  DXASSERT(castInst->getType() == matInst->getType(), "type must match");
+
+  IRBuilder<> Builder(castInst);
+
+  // shuf to change major.
+  SmallVector<int, 16> castMask(col * row);
+  unsigned idx = 0;
+  if (bRowToCol) {
+    for (unsigned c = 0; c < col; c++)
+      for (unsigned r = 0; r < row; r++) {
+        unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
+        castMask[idx++] = matIdx;
+      }
+  } else {
+    for (unsigned r = 0; r < row; r++)
+      for (unsigned c = 0; c < col; c++) {
+        unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
+        castMask[idx++] = matIdx;
+      }
+  }
+
+  Instruction *vecCast = cast<Instruction>(
+      Builder.CreateShuffleVector(vecInst, vecInst, castMask));
+
+  // Replace vec cast function call with vecCast.
+  DXASSERT(matToVecMap.count(castInst), "must has vec version");
+  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
+  vecUseInst->replaceAllUsesWith(vecCast);
+  AddToDeadInsts(vecUseInst);
+  matToVecMap[castInst] = vecCast;
+}
+
 void HLMatrixLowerPass::TranslateMatMatCast(CallInst *matInst,
                                             Instruction *vecInst,
                                             CallInst *castInst) {
@@ -1167,9 +1209,9 @@ void HLMatrixLowerPass::TranslateMatMatCast(CallInst *matInst,
     // shuf first
     std::vector<int> castMask(toCol * toRow);
     unsigned idx = 0;
-    for (unsigned c = 0; c < toCol; c++)
-      for (unsigned r = 0; r < toRow; r++) {
-        unsigned matIdx = GetMatIdx(r, c, fromRow);
+    for (unsigned r = 0; r < toRow; r++)
+      for (unsigned c = 0; c < toCol; c++) {
+        unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, fromCol);
         castMask[idx++] = matIdx;
       }
 
@@ -1232,14 +1274,21 @@ void HLMatrixLowerPass::TranslateMatToOtherCast(CallInst *matInst,
 void HLMatrixLowerPass::TranslateMatCast(CallInst *matInst,
                                          Instruction *vecInst,
                                          CallInst *castInst) {
-  bool ToMat = IsMatrixType(castInst->getType());
-  bool FromMat = IsMatrixType(matInst->getType());
-  if (ToMat && FromMat) {
-    TranslateMatMatCast(matInst, vecInst, castInst);
-  } else if (FromMat)
-    TranslateMatToOtherCast(matInst, vecInst, castInst);
-  else
-    DXASSERT(0, "Not translate as user of matInst");
+  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
+  if (opcode == HLCastOpcode::ColMatrixToRowMatrix ||
+      opcode == HLCastOpcode::RowMatrixToColMatrix) {
+    TranslateMatMajorCast(matInst, vecInst, castInst,
+                          opcode == HLCastOpcode::RowMatrixToColMatrix);
+  } else {
+    bool ToMat = IsMatrixType(castInst->getType());
+    bool FromMat = IsMatrixType(matInst->getType());
+    if (ToMat && FromMat) {
+      TranslateMatMatCast(matInst, vecInst, castInst);
+    } else if (FromMat)
+      TranslateMatToOtherCast(matInst, vecInst, castInst);
+    else
+      DXASSERT(0, "Not translate as user of matInst");
+  }
 }
 
 void HLMatrixLowerPass::MatIntrinsicReplace(CallInst *matInst,
@@ -1290,7 +1339,7 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
                    (matOpcode == HLSubscriptOpcode::RowMatElement);
   Value *mask =
       matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
-  // For temp matrix, all use col major.
+
   if (isElement) {
     Type *resultType = matSubInst->getType()->getPointerElementType();
     unsigned resultSize = 1;
@@ -1298,20 +1347,10 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
       resultSize = resultType->getVectorNumElements();
 
     std::vector<int> shufMask(resultSize);
-    if (ConstantDataSequential *elts = dyn_cast<ConstantDataSequential>(mask)) {
-      unsigned count = elts->getNumElements();
-      for (unsigned i = 0; i < count; i += 2) {
-        unsigned rowIdx = elts->getElementAsInteger(i);
-        unsigned colIdx = elts->getElementAsInteger(i + 1);
-        unsigned matIdx = GetMatIdx(rowIdx, colIdx, row);
-        shufMask[i>>1] = matIdx;
-      }
-    }
-    else {
-      ConstantAggregateZero *zeros = cast<ConstantAggregateZero>(mask);
-      unsigned size = zeros->getNumElements()>>1;
-      for (unsigned i=0;i<size;i++) 
-        shufMask[i] = 0;
+    Constant *EltIdxs = cast<Constant>(mask);
+    for (unsigned i = 0; i < resultSize; i++) {
+      shufMask[i] =
+          cast<ConstantInt>(EltIdxs->getAggregateElement(i))->getLimitedValue();
     }
 
     for (Value::use_iterator CallUI = matSubInst->use_begin(),
@@ -1357,21 +1396,24 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
     Value *tempArray = AllocaBuilder.CreateAlloca(AT);
     Value *zero = AllocaBuilder.getInt32(0);
     bool isDynamicIndexing = !isa<ConstantInt>(mask);
+    SmallVector<Value *, 4> idxList;
+    for (unsigned i = 0; i < col; i++) {
+      idxList.emplace_back(
+          matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + i));
+    }
 
     for (Value::use_iterator CallUI = matSubInst->use_begin(),
                              CallE = matSubInst->use_end();
          CallUI != CallE;) {
       Use &CallUse = *CallUI++;
       Instruction *CallUser = cast<Instruction>(CallUse.getUser());
-      Value *idx = mask;
       IRBuilder<> Builder(CallUser);
       Value *vecLd = Builder.CreateLoad(vecInst);
       if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
         Value *sub = UndefValue::get(ld->getType());
         if (!isDynamicIndexing) {
           for (unsigned i = 0; i < col; i++) {
-            // col major: matIdx = c * row + r;
-            Value *matIdx = Builder.CreateAdd(idx, Builder.getInt32(i * row));
+            Value *matIdx = idxList[i];
             Value *valElt = Builder.CreateExtractElement(vecLd, matIdx);
             sub = Builder.CreateInsertElement(sub, valElt, i);
           }
@@ -1385,8 +1427,7 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
             Builder.CreateStore(Elt, Ptr);
           }
           for (unsigned i = 0; i < col; i++) {
-            // col major: matIdx = c * row + r;
-            Value *matIdx = Builder.CreateAdd(idx, Builder.getInt32(i * row));
+            Value *matIdx = idxList[i];
             Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
             Value *valElt = Builder.CreateLoad(Ptr);
             sub = Builder.CreateInsertElement(sub, valElt, i);
@@ -1397,8 +1438,7 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
         Value *val = st->getValueOperand();
         if (!isDynamicIndexing) {
           for (unsigned i = 0; i < col; i++) {
-            // col major: matIdx = c * row + r;
-            Value *matIdx = Builder.CreateAdd(idx, Builder.getInt32(i * row));
+            Value *matIdx = idxList[i];
             Value *valElt = Builder.CreateExtractElement(val, i);
             vecLd = Builder.CreateInsertElement(vecLd, valElt, matIdx);
           }
@@ -1413,8 +1453,7 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
           }
           // Update array.
           for (unsigned i = 0; i < col; i++) {
-            // col major: matIdx = c * row + r;
-            Value *matIdx = Builder.CreateAdd(idx, Builder.getInt32(i * row));
+            Value *matIdx = idxList[i];
             Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
             Value *valElt = Builder.CreateExtractElement(val, i);
             Builder.CreateStore(valElt, Ptr);
@@ -1430,17 +1469,8 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
         Builder.CreateStore(vecLd, vecInst);
       } else if (GetElementPtrInst *GEP =
                      dyn_cast<GetElementPtrInst>(CallUser)) {
-        // Must be for subscript on vector
-        auto idxIter = GEP->idx_begin();
-        // skip the zero
-        idxIter++;
-        Value *gepIdx = *idxIter;
-        // Col major matIdx = r + c * row;  r is idx, c is gepIdx
-        Value *iMulRow = Builder.CreateMul(gepIdx, Builder.getInt32(row));
-        Value *vecIdx = Builder.CreateAdd(iMulRow, idx);
-
-        llvm::Constant *zero = llvm::ConstantInt::get(vecIdx->getType(), 0);
-        Value *NewGEP = Builder.CreateGEP(vecInst, {zero, vecIdx});
+        Value *GEPOffset = HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
+        Value *NewGEP = Builder.CreateGEP(vecInst, {zero, GEPOffset});
         GEP->replaceAllUsesWith(NewGEP);
       } else
         DXASSERT(0, "matrix subscript should only used by load/store.");
@@ -1458,7 +1488,7 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
     Value *matGlobal, ArrayRef<Value *> vecGlobals,
     CallInst *matLdStInst) {
   // No dynamic indexing on matrix, flatten matrix to scalars.
-  // vecGlobals already in col major.
+  // vecGlobals already in correct major.
   Type *matType = matGlobal->getType()->getPointerElementType();
   unsigned col, row;
   HLMatrixLower::GetMatrixInfo(matType, col, row);
@@ -1491,7 +1521,7 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
 void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
                                                       GlobalVariable *scalarArrayGlobal,
                                                       CallInst *matLdStInst) {
-  // vecGlobals already in col major.
+  // vecGlobals already in correct major.
   const bool bColMajor = true;
   HLMatLoadStoreOpcode opcode =
       static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
@@ -1513,7 +1543,7 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
     }
 
     Value *newVec =
-        HLMatrixLower::BuildMatrix(EltTy, col, row, bColMajor, matElts, Builder);
+        HLMatrixLower::BuildVector(EltTy, col * row, matElts, Builder);
     matLdStInst->replaceAllUsesWith(newVec);
     matLdStInst->eraseFromParent();
   } break;
@@ -1540,9 +1570,10 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
   } break;
   }
 }
-void HLMatrixLowerPass::TranslateMatSubscriptOnGlobal(GlobalVariable *matGlobal,
-                                                      GlobalVariable *scalarArrayGlobal,
-                                                      CallInst *matSubInst) {
+void HLMatrixLowerPass::TranslateMatSubscriptOnGlobalPtr(
+    CallInst *matSubInst, Value *vecPtr) {
+  Value *basePtr =
+      matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
   Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
   IRBuilder<> subBuilder(matSubInst);
   Value *zeroIdx = subBuilder.getInt32(0);
@@ -1550,46 +1581,32 @@ void HLMatrixLowerPass::TranslateMatSubscriptOnGlobal(GlobalVariable *matGlobal,
   HLSubscriptOpcode opcode =
       static_cast<HLSubscriptOpcode>(GetHLOpcode(matSubInst));
 
-  Type *matTy = matGlobal->getType()->getPointerElementType();
+  Type *matTy = basePtr->getType()->getPointerElementType();
   unsigned col, row;
   HLMatrixLower::GetMatrixInfo(matTy, col, row);
 
-  std::vector<Value *> Ptrs;
+  std::vector<Value *> idxList;
   switch (opcode) {
   case HLSubscriptOpcode::ColMatSubscript:
   case HLSubscriptOpcode::RowMatSubscript: {
-    // Use col major for internal matrix.
-    // And subscripts will return a row.
+    // Just use index created in EmitHLSLMatrixSubscript.
     for (unsigned c = 0; c < col; c++) {
-      Value *colIdxBase = subBuilder.getInt32(c * row);
-      Value *matIdx = subBuilder.CreateAdd(colIdxBase, idx);
-      Value *Ptr =
-          subBuilder.CreateInBoundsGEP(scalarArrayGlobal, {zeroIdx, matIdx});
-      Ptrs.emplace_back(Ptr);
+      Value *matIdx =
+          matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + c);
+      idxList.emplace_back(matIdx);
     }
   } break;
   case HLSubscriptOpcode::RowMatElement:
   case HLSubscriptOpcode::ColMatElement: {
-    // Use col major for internal matrix.
-    if (ConstantDataSequential *elts = dyn_cast<ConstantDataSequential>(idx)) {
-      unsigned count = elts->getNumElements();
-
-      for (unsigned i = 0; i < count; i += 2) {
-        unsigned rowIdx = elts->getElementAsInteger(i);
-        unsigned colIdx = elts->getElementAsInteger(i + 1);
-        Value *matIdx = subBuilder.getInt32(colIdx * row + rowIdx);
-        Value *Ptr =
-            subBuilder.CreateInBoundsGEP(scalarArrayGlobal, {zeroIdx, matIdx});
-        Ptrs.emplace_back(Ptr);
-      }
-    } else {
-      ConstantAggregateZero *zeros = cast<ConstantAggregateZero>(idx);
-      unsigned size = zeros->getNumElements() >> 1;
-      for (unsigned i = 0; i < size; i++) {
-        Value *Ptr =
-            subBuilder.CreateInBoundsGEP(scalarArrayGlobal, {zeroIdx, zeroIdx});
-        Ptrs.emplace_back(Ptr);
-      }
+    Type *resultType = matSubInst->getType()->getPointerElementType();
+    unsigned resultSize = 1;
+    if (resultType->isVectorTy())
+      resultSize = resultType->getVectorNumElements();
+    // Just use index created in EmitHLSLMatrixElement.
+    Constant *EltIdxs = cast<Constant>(idx);
+    for (unsigned i = 0; i < resultSize; i++) {
+      Value *matIdx = EltIdxs->getAggregateElement(i);
+      idxList.emplace_back(matIdx);
     }
   } break;
   default:
@@ -1599,90 +1616,54 @@ void HLMatrixLowerPass::TranslateMatSubscriptOnGlobal(GlobalVariable *matGlobal,
 
   // Cannot generate vector pointer
   // Replace all uses with scalar pointers.
-  if (Ptrs.size() == 1)
-    matSubInst->replaceAllUsesWith(Ptrs[0]);
-  else {
+  if (idxList.size() == 1) {
+    Value *Ptr =
+        subBuilder.CreateInBoundsGEP(vecPtr, {zeroIdx, idxList[0]});
+    matSubInst->replaceAllUsesWith(Ptr);
+  } else {
     // Split the use of CI with Ptrs.
     for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
       Instruction *subsUser = cast<Instruction>(*(U++));
       IRBuilder<> userBuilder(subsUser);
       if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
-        DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
-        Value *baseIdx = (GEP->idx_begin())->get();
-        DXASSERT_LOCALVAR(baseIdx, baseIdx == zeroIdx, "base index must be 0");
-        Value *idx = (GEP->idx_begin() + 1)->get();
+        Value *IndexPtr =
+            HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
+        Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
+                                                   {zeroIdx, IndexPtr});
         for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
           Instruction *gepUser = cast<Instruction>(*(gepU++));
           IRBuilder<> gepUserBuilder(gepUser);
           if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
             Value *subData = stUser->getValueOperand();
-            if (ConstantInt *immIdx = dyn_cast<ConstantInt>(idx)) {
-              Value *Ptr = Ptrs[immIdx->getSExtValue()];
-              gepUserBuilder.CreateStore(subData, Ptr);
-            } else {
-              // Create a temp array.
-              IRBuilder<> allocaBuilder(stUser->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
-              Value *tempArray = allocaBuilder.CreateAlloca(
-                  ArrayType::get(subData->getType(), Ptrs.size()));
-              // Store value to temp array.
-              for (unsigned i = 0; i < Ptrs.size(); i++) {
-                Value *Elt = gepUserBuilder.CreateLoad(Ptrs[i]);
-                Value *EltGEP = gepUserBuilder.CreateGEP(tempArray, {zeroIdx, gepUserBuilder.getInt32(i)} );
-                gepUserBuilder.CreateStore(Elt, EltGEP);
-              }
-              // Dynamic indexing.
-              Value *subGEP =
-                  gepUserBuilder.CreateInBoundsGEP(tempArray, {zeroIdx, idx});
-              gepUserBuilder.CreateStore(subData, subGEP);
-              // Store temp array to value.
-              for (unsigned i = 0; i < Ptrs.size(); i++) {
-                Value *EltGEP = gepUserBuilder.CreateGEP(tempArray, {zeroIdx, gepUserBuilder.getInt32(i)} );
-                Value *Elt = gepUserBuilder.CreateLoad(EltGEP);
-                gepUserBuilder.CreateStore(Elt, Ptrs[i]);
-              }
-            }
+            gepUserBuilder.CreateStore(subData, Ptr);
             stUser->eraseFromParent();
-          } else {
-            // Must be load here;
-            LoadInst *ldUser = cast<LoadInst>(gepUser);
-            Value *subData = nullptr;
-            if (ConstantInt *immIdx = dyn_cast<ConstantInt>(idx)) {
-              Value *Ptr = Ptrs[immIdx->getSExtValue()];
-              subData = gepUserBuilder.CreateLoad(Ptr);
-            } else {
-              // Create a temp array.
-              IRBuilder<> allocaBuilder(ldUser->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
-              Value *tempArray = allocaBuilder.CreateAlloca(
-                  ArrayType::get(ldUser->getType(), Ptrs.size()));
-              // Store value to temp array.
-              for (unsigned i = 0; i < Ptrs.size(); i++) {
-                Value *Elt = gepUserBuilder.CreateLoad(Ptrs[i]);
-                Value *EltGEP = gepUserBuilder.CreateGEP(tempArray, {zeroIdx, gepUserBuilder.getInt32(i)} );
-                gepUserBuilder.CreateStore(Elt, EltGEP);
-              }
-              // Dynamic indexing.
-              Value *subGEP =
-                  gepUserBuilder.CreateInBoundsGEP(tempArray, {zeroIdx, idx});
-              subData = gepUserBuilder.CreateLoad(subGEP);
-            }
+          } else if (LoadInst *ldUser = dyn_cast<LoadInst>(gepUser)) {
+            Value *subData = gepUserBuilder.CreateLoad(Ptr);
             ldUser->replaceAllUsesWith(subData);
             ldUser->eraseFromParent();
+          } else {
+            AddrSpaceCastInst *Cast = cast<AddrSpaceCastInst>(gepUser);
+            Cast->setOperand(0, Ptr);
           }
         }
         GEP->eraseFromParent();
       } else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
         Value *val = stUser->getValueOperand();
-        for (unsigned i = 0; i < Ptrs.size(); i++) {
+        for (unsigned i = 0; i < idxList.size(); i++) {
           Value *Elt = userBuilder.CreateExtractElement(val, i);
-          userBuilder.CreateStore(Elt, Ptrs[i]);
+          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
+                                                     {zeroIdx, idxList[i]});
+          userBuilder.CreateStore(Elt, Ptr);
         }
         stUser->eraseFromParent();
       } else {
 
         Value *ldVal =
             UndefValue::get(matSubInst->getType()->getPointerElementType());
-        for (unsigned i = 0; i < Ptrs.size(); i++) {
-          Value *Elt = userBuilder.CreateLoad(Ptrs[i]);
+        for (unsigned i = 0; i < idxList.size(); i++) {
+          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
+                                                     {zeroIdx, idxList[i]});
+          Value *Elt = userBuilder.CreateLoad(Ptr);
           ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
         }
         // Must be load here.
@@ -1695,128 +1676,9 @@ void HLMatrixLowerPass::TranslateMatSubscriptOnGlobal(GlobalVariable *matGlobal,
   matSubInst->eraseFromParent();
 }
 
-void HLMatrixLowerPass::TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr) {
-  // Just translate into vec array here.
-  // DynamicIndexingVectorToArray will change it to scalar array.
-  Value *basePtr = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
-  IRBuilder<> subBuilder(matSubInst);
-  unsigned opcode = hlsl::GetHLOpcode(matSubInst);
-  HLSubscriptOpcode subOp = static_cast<HLSubscriptOpcode>(opcode);
-
-  // Vector array is inside struct.
-  Value *zeroIdx = subBuilder.getInt32(0);
-  Value *vecArrayGep = vecPtr;
-
-  Type *matType = basePtr->getType()->getPointerElementType();
-  unsigned col, row;
-  HLMatrixLower::GetMatrixInfo(matType, col, row);
-
-  Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
-
-  std::vector<Value *> Ptrs;
-  switch (subOp) {
-  case HLSubscriptOpcode::ColMatSubscript:
-  case HLSubscriptOpcode::RowMatSubscript: {
-    // Vector array is row major.
-    // And subscripts will return a row.
-    Value *rowIdx = idx;
-    Value *subPtr = subBuilder.CreateInBoundsGEP(vecArrayGep, {zeroIdx, rowIdx});
-    matSubInst->replaceAllUsesWith(subPtr);
-    matSubInst->eraseFromParent();
-    return;
-  } break;
-  case HLSubscriptOpcode::RowMatElement:
-  case HLSubscriptOpcode::ColMatElement: {
-    // Vector array is row major.
-    if (ConstantDataSequential *elts = dyn_cast<ConstantDataSequential>(idx)) {
-      unsigned count = elts->getNumElements();
-
-      for (unsigned i = 0; i < count; i += 2) {
-        Value *rowIdx = subBuilder.getInt32(elts->getElementAsInteger(i));
-        Value *colIdx = subBuilder.getInt32(elts->getElementAsInteger(i + 1));
-        Value *Ptr =
-            subBuilder.CreateInBoundsGEP(vecArrayGep, {zeroIdx, rowIdx, colIdx});
-        Ptrs.emplace_back(Ptr);
-      }
-    } else {
-      ConstantAggregateZero *zeros = cast<ConstantAggregateZero>(idx);
-      unsigned size = zeros->getNumElements() >> 1;
-      for (unsigned i = 0; i < size; i++) {
-        Value *Ptr =
-            subBuilder.CreateInBoundsGEP(vecArrayGep, {zeroIdx, zeroIdx, zeroIdx});
-        Ptrs.emplace_back(Ptr);
-      }
-    }
-  } break;
-  default:
-    DXASSERT(0, "invalid operation for TranslateMatSubscriptOnGlobalPtr");
-    break;
-  }
-
-  if (Ptrs.size() == 1)
-    matSubInst->replaceAllUsesWith(Ptrs[0]);
-  else {
-    // Split the use of CI with Ptrs.
-    for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
-      Instruction *subsUser = cast<Instruction>(*(U++));
-      IRBuilder<> userBuilder(subsUser);
-      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
-        DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
-        Value *baseIdx = (GEP->idx_begin())->get();
-        DXASSERT_LOCALVAR(baseIdx, baseIdx == zeroIdx, "base index must be 0");
-        Value *idx = (GEP->idx_begin() + 1)->get();
-        for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
-          Instruction *gepUser = cast<Instruction>(*(gepU++));
-          IRBuilder<> gepUserBuilder(gepUser);
-          if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
-            Value *subData = stUser->getValueOperand();
-            // Only element can reach here.
-            // So index must be imm.
-            ConstantInt *immIdx = cast<ConstantInt>(idx);
-            Value *Ptr = Ptrs[immIdx->getSExtValue()];
-            gepUserBuilder.CreateStore(subData, Ptr);
-            stUser->eraseFromParent();
-          } else {
-            // Must be load here;
-            LoadInst *ldUser = cast<LoadInst>(gepUser);
-            Value *subData = nullptr;
-            // Only element can reach here.
-            // So index must be imm.
-            ConstantInt *immIdx = cast<ConstantInt>(idx);
-            Value *Ptr = Ptrs[immIdx->getSExtValue()];
-            subData = gepUserBuilder.CreateLoad(Ptr);
-            ldUser->replaceAllUsesWith(subData);
-            ldUser->eraseFromParent();
-          }
-        }
-        GEP->eraseFromParent();
-      } else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
-        Value *val = stUser->getValueOperand();
-        for (unsigned i = 0; i < Ptrs.size(); i++) {
-          Value *Elt = userBuilder.CreateExtractElement(val, i);
-          userBuilder.CreateStore(Elt, Ptrs[i]);
-        }
-        stUser->eraseFromParent();
-      } else {
-        // Must be load here.
-        LoadInst *ldUser = cast<LoadInst>(subsUser);
-        // reload the value.
-        Value *ldVal =
-            UndefValue::get(matSubInst->getType()->getPointerElementType());
-        for (unsigned i = 0; i < Ptrs.size(); i++) {
-          Value *Elt = userBuilder.CreateLoad(Ptrs[i]);
-          ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
-        }
-        ldUser->replaceAllUsesWith(ldVal);
-        ldUser->eraseFromParent();
-      }
-    }
-  }
-  matSubInst->eraseFromParent();
-}
 void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobalPtr(
     CallInst *matLdStInst, Value *vecPtr) {
-  // Just translate into vec array here.
+  // Just translate into vector here.
   // DynamicIndexingVectorToArray will change it to scalar array.
   IRBuilder<> Builder(matLdStInst);
   unsigned opcode = hlsl::GetHLOpcode(matLdStInst);
@@ -1824,56 +1686,19 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobalPtr(
   switch (matLdStOp) {
   case HLMatLoadStoreOpcode::ColMatLoad:
   case HLMatLoadStoreOpcode::RowMatLoad: {
-    // Load as vector array.
+    // Load as vector.
     Value *newLoad = Builder.CreateLoad(vecPtr);
-    // Then change to vector.
-    // Use col major.
-    Value *Ptr = matLdStInst->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
-    Type *matTy = Ptr->getType()->getPointerElementType();
-    unsigned col, row;
-    HLMatrixLower::GetMatrixInfo(matTy, col, row);
-    Value *NewVal = UndefValue::get(matLdStInst->getType());
-    // Vector array is row major.
-    for (unsigned r = 0; r < row; r++) {
-      Value *eltRow = Builder.CreateExtractValue(newLoad, r);
-      for (unsigned c = 0; c < col; c++) {
-        Value *elt = Builder.CreateExtractElement(eltRow, c);
-        // Vector is col major.
-        unsigned matIdx = c * row + r;
-        NewVal = Builder.CreateInsertElement(NewVal, elt, matIdx);
-      }
-    }
-    matLdStInst->replaceAllUsesWith(NewVal);
+
+    matLdStInst->replaceAllUsesWith(newLoad);
     matLdStInst->eraseFromParent();
   } break;
   case HLMatLoadStoreOpcode::ColMatStore:
   case HLMatLoadStoreOpcode::RowMatStore: {
     // Change value to vector array, then store.
     Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-    Value *Ptr = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
-
-    Type *matTy = Ptr->getType()->getPointerElementType();
-    unsigned col, row;
-    Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
-    Type *rowTy = VectorType::get(EltTy, row);
 
     Value *vecArrayGep = vecPtr;
-
-    Value *NewVal =
-        UndefValue::get(vecArrayGep->getType()->getPointerElementType());
-
-    // Vector array val is row major.
-    for (unsigned r = 0; r < row; r++) {
-      Value *NewElt = UndefValue::get(rowTy);
-      for (unsigned c = 0; c < col; c++) {
-        // Vector val is col major.
-        unsigned matIdx = c * row + r;
-        Value *elt = Builder.CreateExtractElement(Val, matIdx);
-        NewElt = Builder.CreateInsertElement(NewElt, elt, c);
-      }
-      NewVal = Builder.CreateInsertValue(NewVal, NewElt, r);
-    }
-    Builder.CreateStore(NewVal, vecArrayGep);
+    Builder.CreateStore(Val, vecArrayGep);
     matLdStInst->eraseFromParent();
   } break;
   default:
@@ -1928,7 +1753,7 @@ static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
     HLMatrixLower::GetMatrixInfo(valTy, col, row);
     unsigned matSize = col * row;
     val = matToVecMap[cast<Instruction>(val)];
-    // temp matrix all col major
+    // temp matrix all row major
     for (unsigned i = 0; i < matSize; i++) {
       Value *Elt = Builder.CreateExtractElement(val, i);
       elts[idx + i] = Elt;
@@ -2005,15 +1830,13 @@ void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
   }
 
   Value *newInit = UndefValue::get(vecTy);
-  // InitList is row major, the result is col major.
-  for (unsigned c = 0; c < col; c++)
-    for (unsigned r = 0; r < row; r++) {
-      unsigned rowMajorIdx = r * col + c;
-      unsigned colMajorIdx = c * row + r;
-      Constant *vecIdx = Builder.getInt32(colMajorIdx);
-      newInit = InsertElementInst::Create(newInit, elts[rowMajorIdx], vecIdx);
+  // InitList is row major, the result is row major too.
+  for (unsigned i=0;i< col * row;i++) {
+      Constant *vecIdx = Builder.getInt32(i);
+      newInit = InsertElementInst::Create(newInit, elts[i], vecIdx);
       Builder.Insert(cast<Instruction>(newInit));
-    }
+  }
+
   // Replace matInit function call with matInitInst.
   DXASSERT(matToVecMap.count(matInitInst), "must has vec version");
   Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
@@ -2247,25 +2070,36 @@ static bool OnlyUsedByMatrixLdSt(Value *V) {
   return onlyLdSt;
 }
 
-static Constant *LowerMatrixArrayConst(Constant *MA, ArrayType *ResultTy) {
-  if (ArrayType *AT = dyn_cast<ArrayType>(MA->getType())) {
+static Constant *LowerMatrixArrayConst(Constant *MA, Type *ResultTy) {
+  if (ArrayType *AT = dyn_cast<ArrayType>(ResultTy)) {
     std::vector<Constant *> Elts;
-    ArrayType *EltResultTy = cast<ArrayType>(ResultTy->getElementType());
+    Type *EltResultTy = AT->getElementType();
     for (unsigned i = 0; i < AT->getNumElements(); i++) {
       Constant *Elt =
           LowerMatrixArrayConst(MA->getAggregateElement(i), EltResultTy);
       Elts.emplace_back(Elt);
     }
-    return ConstantArray::get(ResultTy, Elts);
+    return ConstantArray::get(AT, Elts);
   } else {
+    // Cast float[row][col] -> float< row * col>.
     // Get float[row][col] from the struct.
-    return MA->getAggregateElement((unsigned)0);
+    Constant *rows = MA->getAggregateElement((unsigned)0);
+    ArrayType *RowAT = cast<ArrayType>(rows->getType());
+    std::vector<Constant *> Elts;
+    for (unsigned r=0;r<RowAT->getArrayNumElements();r++) {
+      Constant *row = rows->getAggregateElement(r);
+      VectorType *VT = cast<VectorType>(row->getType());
+      for (unsigned c = 0; c < VT->getVectorNumElements(); c++) {
+        Elts.emplace_back(row->getAggregateElement(c));
+      }
+    }
+    return ConstantVector::get(Elts);
   }
 }
 
 void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
-  // Lower to array of vector array like float[row][col].
-  // It's row major.
+  // Lower to array of vector array like float[row * col].
+  // It's follow the major of decl.
   // DynamicIndexingVectorToArray will change it to scalar array.
   Type *Ty = GV->getType()->getPointerElementType();
   std::vector<unsigned> arraySizeList;
@@ -2275,8 +2109,7 @@ void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
   }
   unsigned row, col;
   Type *EltTy = GetMatrixInfo(Ty, col, row);
-  Ty = VectorType::get(EltTy, col);
-  Ty = ArrayType::get(Ty, row);
+  Ty = VectorType::get(EltTy, col * row);
 
   for (auto arraySize = arraySizeList.rbegin();
        arraySize != arraySizeList.rend(); arraySize++)
@@ -2348,12 +2181,13 @@ static void FlattenMatConst(Constant *M, std::vector<Constant *> &Elts) {
       Elts.emplace_back(Elt);
   } else {
     M = M->getAggregateElement((unsigned)0);
-    // Initializer is row major.
-    // Make it col major to match temp matrix.
-    for (unsigned c = 0; c < col; c++) {
-      for (unsigned r = 0; r < row; r++) {
-        Constant *R = M->getAggregateElement(r);
-        Elts.emplace_back(R->getAggregateElement(c));
+    // Initializer is already in correct major.
+    // Just read it here.
+    // The type is vector<element, col>[row].
+    for (unsigned r = 0; r < row; r++) {
+      Constant *C = M->getAggregateElement(r);
+      for (unsigned c = 0; c < col; c++) {
+        Elts.emplace_back(C->getAggregateElement(c));
       }
     }
   }
@@ -2441,7 +2275,7 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
       }
       else {
         DXASSERT(group == HLOpcodeGroup::HLSubscript, "Must be subscript operation");
-        TranslateMatSubscriptOnGlobal(GV, arrayMat, CI);
+        TranslateMatSubscriptOnGlobalPtr(CI, arrayMat);
       }
     }
     GV->removeDeadConstantUsers();
@@ -2464,6 +2298,18 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
         } else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
           lowerToVec(&I);
         }
+      } else if (CallInst *CI = dyn_cast<CallInst>(&I)) {
+        HLOpcodeGroup group =
+            hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
+        if (group == HLOpcodeGroup::HLMatLoadStore) {
+          HLMatLoadStoreOpcode opcode =
+              static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
+          DXASSERT(opcode == HLMatLoadStoreOpcode::ColMatStore ||
+                       opcode == HLMatLoadStoreOpcode::RowMatStore,
+                   "Must MatStore here, load will go IsMatrixType path");
+          // Lower it here to make sure it is ready before replace.
+          lowerToVec(&I);
+        }
       }
     }
   }

+ 66 - 167
lib/HLSL/HLOperationLower.cpp

@@ -4305,38 +4305,17 @@ Value *TranslateConstBufMatLd(Type *matType, Value *handle, Value *offset,
   Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
   unsigned matSize = col * row;
   std::vector<Value *> elts(matSize);
-  unsigned EltByteSize = GetEltTypeByteSizeForConstBuf(EltTy, DL);
-  if (colMajor) {
-    // TODO: use real size after change constant buffer into linear layout.
-    unsigned colByteSize = 4 * EltByteSize;
-    for (unsigned c = 0; c < col; c++) {
-      Value *baseOffset = offset;
-      for (unsigned r = 0; r < row; r++) {
-        unsigned matIdx = c * row + r;
-        elts[matIdx] = GenerateCBLoad(handle, baseOffset, EltTy, OP, Builder);
-        baseOffset =
-            Builder.CreateAdd(baseOffset, OP->GetU32Const(EltByteSize));
-      }
-      // Update offset for a column.
-      offset = Builder.CreateAdd(offset, OP->GetU32Const(colByteSize));
-    }
-  } else {
-    // TODO: use real size after change constant buffer into linear layout.
-    unsigned rowByteSize = 4 * EltByteSize;
-    for (unsigned r = 0; r < row; r++) {
-      Value *baseOffset = offset;
-      for (unsigned c = 0; c < col; c++) {
-        unsigned matIdx = r * col + c;
-        elts[matIdx] = GenerateCBLoad(handle, baseOffset, EltTy, OP, Builder);
-        baseOffset =
-            Builder.CreateAdd(baseOffset, OP->GetU32Const(EltByteSize));
-      }
-      // Update offset for a row.
-      offset = Builder.CreateAdd(offset, OP->GetU32Const(rowByteSize));
-    }
+  Value *EltByteSize = ConstantInt::get(
+      offset->getType(), GetEltTypeByteSizeForConstBuf(EltTy, DL));
+
+  // TODO: use real size after change constant buffer into linear layout.
+  Value *baseOffset = offset;
+  for (unsigned i = 0; i < matSize; i++) {
+    elts[i] = GenerateCBLoad(handle, baseOffset, EltTy, OP, Builder);
+    baseOffset = Builder.CreateAdd(baseOffset, EltByteSize);
   }
 
-  return HLMatrixLower::BuildMatrix(EltTy, col, row, colMajor, elts, Builder);
+  return HLMatrixLower::BuildVector(EltTy, col * row, elts, Builder);
 }
 
 void TranslateCBGep(GetElementPtrInst *GEP, Value *handle, Value *baseOffset,
@@ -4405,7 +4384,8 @@ void TranslateCBAddressUser(Instruction *user, Value *handle, Value *baseOffset,
       unsigned col, row;
       Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
 
-      unsigned EltByteSize = GetEltTypeByteSizeForConstBuf(EltTy, DL);
+      Value *EltByteSize = ConstantInt::get(
+          baseOffset->getType(), GetEltTypeByteSizeForConstBuf(EltTy, DL));
 
       Value *idx = CI->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
 
@@ -4418,48 +4398,23 @@ void TranslateCBAddressUser(Instruction *user, Value *handle, Value *baseOffset,
       Value *idxList[16];
 
       switch (subOp) {
-      case HLSubscriptOpcode::ColMatSubscript: {
-        for (unsigned i = 0; i < resultSize; i++) {
-          Value *colBase = Builder.CreateAdd(
-              baseOffset, hlslOP->GetU32Const(i * col * EltByteSize));
-          idxList[i] = Builder.CreateAdd(colBase, idx);
-        }
-      } break;
+      case HLSubscriptOpcode::ColMatSubscript:
       case HLSubscriptOpcode::RowMatSubscript: {
-        // TODO: use real size after change constant buffer into linear layout.
-        col = 4;
-        unsigned rowSize = col * EltByteSize;
-        idx = Builder.CreateMul(idx, hlslOP->GetU32Const(rowSize));
-        idx = Builder.CreateAdd(idx, baseOffset);
         for (unsigned i = 0; i < resultSize; i++) {
-          idxList[i] = idx;
-          idx = Builder.CreateAdd(idx, hlslOP->GetU32Const(EltByteSize));
+          Value *idx =
+              CI->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + i);
+          Value *offset = Builder.CreateMul(idx, EltByteSize);
+          idxList[i] = Builder.CreateAdd(baseOffset, offset);
         }
+
       } break;
       case HLSubscriptOpcode::RowMatElement:
       case HLSubscriptOpcode::ColMatElement: {
-        bool isRowMajor = subOp == HLSubscriptOpcode::RowMatElement;
-        // TODO: use real size after change constant buffer into linear layout.
-        col = 4;
-        row = 4;
-        if (ConstantDataSequential *elts =
-                dyn_cast<ConstantDataSequential>(idx)) {
-          unsigned count = elts->getNumElements();
-          for (unsigned i = 0; i < count; i += 2) {
-            unsigned rowIdx = elts->getElementAsInteger(i);
-            unsigned colIdx = elts->getElementAsInteger(i + 1);
-            unsigned matIdx =
-                isRowMajor ? (rowIdx * col + colIdx) : (colIdx * row + rowIdx);
-            Value *offset = hlslOP->GetU32Const(matIdx * EltByteSize);
-            idxList[i >> 1] = Builder.CreateAdd(baseOffset, offset);
-          }
-        } else {
-          ConstantAggregateZero *zeros = cast<ConstantAggregateZero>(idx);
-          unsigned size = zeros->getNumElements() >> 1;
-          DXASSERT(size <= 16, "up to 4x4 elements in vector or matrix");
-          _Analysis_assume_(size <= 16);
-          for (unsigned i = 0; i < size; i++)
-            idxList[i] = baseOffset;
+        Constant *EltIdxs = cast<Constant>(idx);
+        for (unsigned i = 0; i < resultSize; i++) {
+          Value *offset =
+              Builder.CreateMul(EltIdxs->getAggregateElement(i), EltByteSize);
+          idxList[i] = Builder.CreateAdd(baseOffset, offset);
         }
       } break;
       default:
@@ -4771,7 +4726,7 @@ Value *TranslateConstBufMatLdLegacy(Type *matType, Value *handle,
                                         EltTy, row, OP, Builder);
 
       for (unsigned r = 0; r < row; r++) {
-        unsigned matIdx = c * row + r;
+        unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
         elts[matIdx] = Builder.CreateExtractElement(col, r);
       }
       // Update offset for a column.
@@ -4784,7 +4739,7 @@ Value *TranslateConstBufMatLdLegacy(Type *matType, Value *handle,
       Value *row = GenerateCBLoadLegacy(handle, legacyIdx, /*channelOffset*/ 0,
                                         EltTy, col, OP, Builder);
       for (unsigned c = 0; c < col; c++) {
-        unsigned matIdx = r * col + c;
+        unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
         elts[matIdx] = Builder.CreateExtractElement(row, c);
       }
       // Update offset for a row.
@@ -4792,7 +4747,7 @@ Value *TranslateConstBufMatLdLegacy(Type *matType, Value *handle,
     }
   }
 
-  return HLMatrixLower::BuildMatrix(EltTy, col, row, colMajor, elts, Builder);
+  return HLMatrixLower::BuildVector(EltTy, col * row, elts, Builder);
 }
 
 void TranslateCBGepLegacy(GetElementPtrInst *GEP, Value *handle,
@@ -4831,8 +4786,6 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
                                   DxilTypeSystem &dxilTypeSys,
                                   const DataLayout &DL,
                                   HLObjectOperationLowerHelper *pObjHelper) {
-  Value *zeroIdx = hlslOP->GetU32Const(0);
-
   IRBuilder<> Builder(user);
   if (CallInst *CI = dyn_cast<CallInst>(user)) {
     HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
@@ -4856,7 +4809,7 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
       Type *matType = basePtr->getType()->getPointerElementType();
       unsigned col, row;
       Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
-      
+
       Value *idx = CI->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
 
       Type *resultType = CI->getType()->getPointerElementType();
@@ -4867,7 +4820,7 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
       _Analysis_assume_(resultSize <= 16);
       Value *idxList[16];
       bool colMajor = subOp == HLSubscriptOpcode::ColMatSubscript ||
-          subOp == HLSubscriptOpcode::ColMatElement;
+                      subOp == HLSubscriptOpcode::ColMatElement;
       bool dynamicIndexing = !isa<ConstantInt>(idx) &&
                              !isa<ConstantAggregateZero>(idx) &&
                              !isa<ConstantDataSequential>(idx);
@@ -4876,36 +4829,21 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
       if (!dynamicIndexing) {
         Value *matLd = TranslateConstBufMatLdLegacy(
             matType, handle, legacyIdx, colMajor, hlslOP, DL, Builder);
-        // The matLd is col major, so use col major when calc index.
+        // The matLd is keep original layout, just use the idx calc in
+        // EmitHLSLMatrixElement and EmitHLSLMatrixSubscript.
         switch (subOp) {
         case HLSubscriptOpcode::RowMatSubscript:
         case HLSubscriptOpcode::ColMatSubscript: {
-          // matIdx = idx + i*row;
           for (unsigned i = 0; i < resultSize; i++) {
-            Value *colBase = hlslOP->GetU32Const(i * row);
-            idxList[i] = Builder.CreateAdd(colBase, idx);
+            idxList[i] =
+                CI->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + i);
           }
         } break;
         case HLSubscriptOpcode::RowMatElement:
         case HLSubscriptOpcode::ColMatElement: {
-          if (ConstantDataSequential *elts =
-                  dyn_cast<ConstantDataSequential>(idx)) {
-            unsigned count = elts->getNumElements();
-            DXASSERT(count <= 16, "up to 4x4 elements in vector or matrix");
-
-            for (unsigned i = 0; i < count; i += 2) {
-              unsigned rowIdx = elts->getElementAsInteger(i);
-              unsigned colIdx = elts->getElementAsInteger(i + 1);
-              unsigned matIdx = (colIdx * row + rowIdx);
-              Value *offset = hlslOP->GetU32Const(matIdx);
-              idxList[i >> 1] = offset;
-            }
-          } else {
-            ConstantAggregateZero *zeros = cast<ConstantAggregateZero>(idx);
-            unsigned size = zeros->getNumElements() >> 1;
-            DXASSERT(size <= 16, "up to 4x4 elements in vector or matrix");
-            for (unsigned i = 0; i < size; i++)
-              idxList[i] = zeroIdx;
+          Constant *EltIdxs = cast<Constant>(idx);
+          for (unsigned i = 0; i < resultSize; i++) {
+            idxList[i] = EltIdxs->getAggregateElement(i);
           }
         } break;
         default:
@@ -4923,8 +4861,12 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
           ldData = eltData;
         }
       } else {
+        // Must be matSub here.
         Value *idx = CI->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
+
         if (colMajor) {
+          // idx is c * row + r.
+          // For first col, c is 0, so idx is r.
           Value *one = Builder.getInt32(1);
           // row.x = c[0].[idx]
           // row.y = c[1].[idx]
@@ -4944,18 +4886,17 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
           for (unsigned int c = 0; c < col; c++) {
             Value *ColVal =
                 GenerateCBLoadLegacy(handle, cbufIdx, /*channelOffset*/ 0,
-                                     EltTy, col, hlslOP, Builder);
+                                     EltTy, row, hlslOP, Builder);
             // Convert ColVal to array for indexing.
-            for (unsigned int c = 0; c < col; c++) {
+            for (unsigned int r = 0; r < row; r++) {
               Value *Elt =
-                  Builder.CreateExtractElement(ColVal, Builder.getInt32(c));
+                  Builder.CreateExtractElement(ColVal, Builder.getInt32(r));
               Value *Ptr = Builder.CreateInBoundsGEP(
-                  tempArray, {zero, Builder.getInt32(c)});
+                  tempArray, {zero, Builder.getInt32(r)});
               Builder.CreateStore(Elt, Ptr);
             }
 
-            Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
-                                                   {zero, idx});
+            Value *Ptr = Builder.CreateInBoundsGEP(tempArray, {zero, idx});
             Elts[c] = Builder.CreateLoad(Ptr);
             // Update cbufIdx.
             cbufIdx = Builder.CreateAdd(cbufIdx, one);
@@ -4968,6 +4909,10 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
             ldData = Elts[0];
           }
         } else {
+          // idx is r * col + c;
+          // r = idx / col;
+          Value *cCol = ConstantInt::get(idx->getType(), col);
+          idx = Builder.CreateUDiv(idx, cCol);
           idx = Builder.CreateAdd(idx, legacyIdx);
           // Just return a row.
           ldData = GenerateCBLoadLegacy(handle, idx, /*channelOffset*/ 0, EltTy,
@@ -5014,7 +4959,7 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
       return;
     }
     DXASSERT(!Ty->isAggregateType(), "should be flat in previous pass");
-    
+
     Value *newLd = nullptr;
 
     if (Ty->isVectorTy())
@@ -5023,7 +4968,6 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
     else
       newLd = GenerateCBLoadLegacy(handle, legacyIdx, channelOffset, EltTy,
                                    hlslOP, Builder);
-    
 
     ldInst->replaceAllUsesWith(newLd);
     ldInst->eraseFromParent();
@@ -5460,7 +5404,7 @@ Value *TranslateStructBufMatLd(Type *matType, IRBuilder<> &Builder,
     offset = Builder.CreateAdd(offset, OP->GetU32Const(4 * 4));
   }
 
-  return HLMatrixLower::BuildMatrix(EltTy, col, row, colMajor, elts, Builder);
+  return HLMatrixLower::BuildVector(EltTy, col * row, elts, Builder);
 }
 
 void TranslateStructBufMatSt(Type *matType, IRBuilder<> &Builder, Value *handle,
@@ -5569,7 +5513,8 @@ void TranslateStructBufMatSubscript(CallInst *CI, Value *handle,
   unsigned col, row;
   Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
 
-  unsigned EltByteSize = GetEltTypeByteSizeForConstBuf(EltTy, DL);
+  Value *EltByteSize = ConstantInt::get(
+      baseOffset->getType(), GetEltTypeByteSizeForConstBuf(EltTy, DL));
 
   Value *idx = CI->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
 
@@ -5582,45 +5527,22 @@ void TranslateStructBufMatSubscript(CallInst *CI, Value *handle,
   Value *idxList[16];
 
   switch (subOp) {
-  case HLSubscriptOpcode::ColMatSubscript: {
-    idx = subBuilder.CreateMul(idx, hlslOP->GetU32Const(EltByteSize));
-    for (unsigned i = 0; i < resultSize; i++) {
-      Value *colBase = subBuilder.CreateAdd(
-          baseOffset, hlslOP->GetU32Const(i * row * EltByteSize));
-      idxList[i] = subBuilder.CreateAdd(colBase, idx);
-    }
-  } break;
+  case HLSubscriptOpcode::ColMatSubscript:
   case HLSubscriptOpcode::RowMatSubscript: {
-    unsigned rowSize = col * EltByteSize;
-    idx = subBuilder.CreateMul(idx, hlslOP->GetU32Const(rowSize));
-    idx = subBuilder.CreateAdd(idx, baseOffset);
     for (unsigned i = 0; i < resultSize; i++) {
-      idxList[i] = idx;
-      idx = subBuilder.CreateAdd(idx, hlslOP->GetU32Const(EltByteSize));
+      Value *offset =
+          CI->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + i);
+      offset = subBuilder.CreateMul(offset, EltByteSize);
+      idxList[i] = subBuilder.CreateAdd(baseOffset, offset);
     }
   } break;
   case HLSubscriptOpcode::RowMatElement:
   case HLSubscriptOpcode::ColMatElement: {
-    bool isRowMajor = subOp == HLSubscriptOpcode::RowMatElement;
-    if (ConstantDataSequential *elts = dyn_cast<ConstantDataSequential>(idx)) {
-      unsigned count = elts->getNumElements();
-      DXASSERT(count <= 16, "up to 4x4 elements in vector or matrix");
-      _Analysis_assume_(count <= 16);
-      for (unsigned i = 0; i < count; i += 2) {
-        unsigned rowIdx = elts->getElementAsInteger(i);
-        unsigned colIdx = elts->getElementAsInteger(i + 1);
-        unsigned matIdx =
-            isRowMajor ? (rowIdx * col + colIdx) : (colIdx * row + rowIdx);
-        Value *offset = hlslOP->GetU32Const(matIdx * EltByteSize);
-        idxList[i >> 1] = subBuilder.CreateAdd(baseOffset, offset);
-      }
-    } else {
-      ConstantAggregateZero *zeros = cast<ConstantAggregateZero>(idx);
-      unsigned size = zeros->getNumElements() >> 1;
-      DXASSERT(size <= 16, "up to 4x4 elements in vector or matrix");
-      _Analysis_assume_(size <= 16);
-      for (unsigned i = 0; i < size; i++)
-        idxList[i] = baseOffset;
+    Constant *EltIdxs = cast<Constant>(idx);
+    for (unsigned i = 0; i < resultSize; i++) {
+      Value *offset =
+          subBuilder.CreateMul(EltIdxs->getAggregateElement(i), EltByteSize);
+      idxList[i] = subBuilder.CreateAdd(baseOffset, offset);
     }
   } break;
   default:
@@ -5638,32 +5560,8 @@ void TranslateStructBufMatSubscript(CallInst *CI, Value *handle,
       continue;
     }
     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
-      IRBuilder<> GEPBuilder(GEP);
-      DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
-      Value *baseIdx = (GEP->idx_begin())->get();
-      DXASSERT_LOCALVAR(baseIdx, baseIdx == zeroIdx, "base index must be 0");
-      Value *idx = (GEP->idx_begin() + 1)->get();
-      Value *GEPOffset = nullptr;
-      if (ConstantInt *immIdx = dyn_cast<ConstantInt>(idx)) {
-        GEPOffset = idxList[immIdx->getSExtValue()];
-      } else {
-        IRBuilder<> allocaBuilder(GEP->getParent()
-                                      ->getParent()
-                                      ->getEntryBlock()
-                                      .getFirstInsertionPt());
-        // Store idxList to temp array.
-        ArrayType *AT = ArrayType::get(allocaBuilder.getInt32Ty(), resultSize);
-        Value *tempArray = allocaBuilder.CreateAlloca(AT);
-
-        for (unsigned i = 0; i < resultSize; i++) {
-          Value *EltPtr = GEPBuilder.CreateGEP(
-              tempArray, {zeroIdx, GEPBuilder.getInt32(i)});
-          GEPBuilder.CreateStore(idxList[i], EltPtr);
-        }
-        // Load the idx.
-        GEPOffset = GEPBuilder.CreateGEP(tempArray, {zeroIdx, idx});
-        GEPOffset = GEPBuilder.CreateLoad(GEPOffset);
-      }
+      Value *GEPOffset =
+          HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
 
       for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
         Instruction *gepUserInst = cast<Instruction>(*(gepU++));
@@ -5700,12 +5598,13 @@ void TranslateStructBufMatSubscript(CallInst *CI, Value *handle,
         for (unsigned i = 0; i < resultSize; i++) {
           Value *ResultElt;
           GenerateStructBufLd(handle, bufIdx, idxList[i],
-                                  /*status*/ nullptr, EltTy, ResultElt, hlslOP, ldBuilder);
+                              /*status*/ nullptr, EltTy, ResultElt, hlslOP,
+                              ldBuilder);
           ldData = ldBuilder.CreateInsertElement(ldData, ResultElt, i);
         }
       } else {
         GenerateStructBufLd(handle, bufIdx, idxList[0], /*status*/ nullptr,
-                                EltTy, ldData, hlslOP, ldBuilder);
+                            EltTy, ldData, hlslOP, ldBuilder);
       }
       ldUser->replaceAllUsesWith(ldData);
       ldUser->eraseFromParent();

+ 4 - 0
lib/HLSL/HLOperations.cpp

@@ -243,6 +243,10 @@ llvm::StringRef GetHLOpcodeName(HLCastOpcode Op) {
     return "colMatToVec";
   case HLCastOpcode::RowMatrixToVecCast:
     return "rowMatToVec";
+  case HLCastOpcode::ColMatrixToRowMatrix:
+    return "colMatToRowMat";
+  case HLCastOpcode::RowMatrixToColMatrix:
+    return "rowMatToColMat";
   case HLCastOpcode::HandleToResCast:
     return "handleToRes";
   }

+ 217 - 155
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -122,7 +122,7 @@ struct SROA_HLSL : public FunctionPass {
 
   bool runOnFunction(Function &F) override;
 
-  bool performScalarRepl(Function &F);
+  bool performScalarRepl(Function &F, DxilTypeSystem &typeSys);
   bool performPromotion(Function &F);
   bool markPrecise(Function &F);
 
@@ -248,8 +248,10 @@ public:
 // Simple struct to split memcpy into ld/st
 struct MemcpySplitter {
   llvm::LLVMContext &m_context;
+  DxilTypeSystem &m_typeSys;
 public:
-  MemcpySplitter(llvm::LLVMContext &context) : m_context(context) {}
+  MemcpySplitter(llvm::LLVMContext &context, DxilTypeSystem &typeSys)
+      : m_context(context), m_typeSys(typeSys) {}
   void Split(llvm::Function &F);
 };
 
@@ -1047,11 +1049,15 @@ Value *ConvertToScalarInfo::ConvertScalar_InsertValue(Value *SV, Value *Old,
 //===----------------------------------------------------------------------===//
 
 bool SROA_HLSL::runOnFunction(Function &F) {
+  Module *M = F.getParent();
+  HLModule &HLM = M->GetOrCreateHLModule();
+  DxilTypeSystem &typeSys = HLM.GetTypeSystem();
+
   // change memcpy into ld/st first
-  MemcpySplitter splitter(F.getContext());
+  MemcpySplitter splitter(F.getContext(), typeSys);
   splitter.Split(F);
 
-  bool Changed = performScalarRepl(F);
+  bool Changed = performScalarRepl(F, typeSys);
   Changed |= markPrecise(F);
   Changed |= performPromotion(F);
 
@@ -1502,12 +1508,9 @@ bool SROA_HLSL::ShouldAttemptScalarRepl(AllocaInst *AI) {
 // which runs on all of the alloca instructions in the entry block, removing
 // them if they are only used by getelementptr instructions.
 //
-bool SROA_HLSL::performScalarRepl(Function &F) {
+bool SROA_HLSL::performScalarRepl(Function &F, DxilTypeSystem &typeSys) {
   std::vector<AllocaInst *> AllocaList;
   const DataLayout &DL = F.getParent()->getDataLayout();
-  Module *M = F.getParent();
-  HLModule &HM = M->GetOrCreateHLModule();
-  DxilTypeSystem &typeSys = HM.GetTypeSystem();
 
   // Scan the entry basic block, adding allocas to the worklist.
   BasicBlock &BB = F.getEntryBlock();
@@ -2162,47 +2165,67 @@ static void SimpleCopy(Value *Dest, Value *Src,
 }
 // Split copy into ld/st.
 static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
-                     SmallVector<Value *, 16> &idxList,
-                     bool bAllowReplace,
-                     IRBuilder<> &Builder) {
+                     SmallVector<Value *, 16> &idxList, bool bAllowReplace,
+                     IRBuilder<> &Builder, DxilTypeSystem &typeSys,
+                     DxilFieldAnnotation *fieldAnnotation) {
   if (PointerType *PT = dyn_cast<PointerType>(Ty)) {
     Constant *idx = Constant::getIntegerValue(
         IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
     idxList.emplace_back(idx);
 
-    SplitCpy(PT->getElementType(), Dest, Src, idxList, bAllowReplace, Builder);
+    SplitCpy(PT->getElementType(), Dest, Src, idxList, bAllowReplace, Builder,
+             typeSys, fieldAnnotation);
 
     idxList.pop_back();
   } else if (HLMatrixLower::IsMatrixType(Ty)) {
+    DXASSERT(fieldAnnotation, "require fieldAnnotation here");
+    DXASSERT(fieldAnnotation->HasMatrixAnnotation(),
+             "must has matrix annotation");
     Module *M = Builder.GetInsertPoint()->getModule();
     Value *DestGEP = Builder.CreateInBoundsGEP(Dest, idxList);
     Value *SrcGEP = Builder.CreateInBoundsGEP(Src, idxList);
+    bool bRowMajor = fieldAnnotation->GetMatrixAnnotation().Orientation ==
+                     MatrixOrientation::RowMajor;
+    if (bRowMajor) {
+      Value *Load = HLModule::EmitHLOperationCall(
+          Builder, HLOpcodeGroup::HLMatLoadStore,
+          static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatLoad), Ty, {SrcGEP},
+          *M);
+
+      // Generate Matrix Store.
+      HLModule::EmitHLOperationCall(
+          Builder, HLOpcodeGroup::HLMatLoadStore,
+          static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatStore), Ty,
+          {DestGEP, Load}, *M);
+    } else {
+      Value *Load = HLModule::EmitHLOperationCall(
+          Builder, HLOpcodeGroup::HLMatLoadStore,
+          static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad), Ty, {SrcGEP},
+          *M);
 
-    Value *Load = HLModule::EmitHLOperationCall(
-        Builder, HLOpcodeGroup::HLMatLoadStore,
-        static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad), Ty, {SrcGEP},
-        *M);
-
-    // Generate Matrix Store.
-    HLModule::EmitHLOperationCall(
-        Builder, HLOpcodeGroup::HLMatLoadStore,
-        static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore), Ty,
-        {DestGEP, Load}, *M);
-
+      // Generate Matrix Store.
+      HLModule::EmitHLOperationCall(
+          Builder, HLOpcodeGroup::HLMatLoadStore,
+          static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore), Ty,
+          {DestGEP, Load}, *M);
+    }
   } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
     if (HLModule::IsHLSLObjectType(ST)) {
       // Avoid split HLSL object.
       SimpleCopy(Dest, Src, idxList, bAllowReplace, Builder);
       return;
     }
+    DxilStructAnnotation *STA = typeSys.GetStructAnnotation(ST);
+    DXASSERT(STA, "require annotation here");
     for (uint32_t i = 0; i < ST->getNumElements(); i++) {
       llvm::Type *ET = ST->getElementType(i);
 
       Constant *idx = llvm::Constant::getIntegerValue(
           IntegerType::get(Ty->getContext(), 32), APInt(32, i));
       idxList.emplace_back(idx);
-
-      SplitCpy(ET, Dest, Src, idxList, bAllowReplace, Builder);
+      DxilFieldAnnotation &EltAnnotation = STA->GetFieldAnnotation(i);
+      SplitCpy(ET, Dest, Src, idxList, bAllowReplace, Builder, typeSys,
+               &EltAnnotation);
 
       idxList.pop_back();
     }
@@ -2214,7 +2237,8 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
       Constant *idx = Constant::getIntegerValue(
           IntegerType::get(Ty->getContext(), 32), APInt(32, i));
       idxList.emplace_back(idx);
-      SplitCpy(ET, Dest, Src, idxList, bAllowReplace, Builder);
+      SplitCpy(ET, Dest, Src, idxList, bAllowReplace, Builder, typeSys,
+               fieldAnnotation);
 
       idxList.pop_back();
     }
@@ -2328,7 +2352,10 @@ void MemcpySplitter::Split(llvm::Function &F) {
         }
         llvm::SmallVector<llvm::Value *, 16> idxList;
         // split
-        SplitCpy(Dest->getType(), Dest, Src, idxList, /*bAllowReplace*/true, Builder);
+        // Matrix is treated as scalar type, will not use memcpy.
+        // So use nullptr for fieldAnnotation should be safe here.
+        SplitCpy(Dest->getType(), Dest, Src, idxList, /*bAllowReplace*/ true,
+                 Builder, m_typeSys, /*fieldAnnotation*/ nullptr);
         // delete memcpy
         I->eraseFromParent();
         if (Instruction *op0 = dyn_cast<Instruction>(Op0)) {
@@ -2448,11 +2475,13 @@ void SROA_Helper::RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder) {
   }
 }
 
-static Type *getArrayEltType(Type *T) {
-  while (isa<ArrayType>(T)) {
-    T = T->getArrayElementType();
+static Type *GetArrayEltTy(Type *Ty) {
+  if (isa<PointerType>(Ty))
+    Ty = Ty->getPointerElementType();
+  while (isa<ArrayType>(Ty)) {
+    Ty = Ty->getArrayElementType();
   }
-  return T;
+  return Ty;
 }
 
 /// isVectorOrStructArray - Check if T is array of vector or struct.
@@ -2460,7 +2489,7 @@ static bool isVectorOrStructArray(Type *T) {
   if (!T->isArrayTy())
     return false;
 
-  T = getArrayEltType(T);
+  T = GetArrayEltTy(T);
 
   return T->isStructTy() || T->isVectorTy();
 }
@@ -2557,7 +2586,7 @@ void SROA_Helper::RewriteForLoad(LoadInst *LI) {
           // Generate Matrix Load.
           Load = HLModule::EmitHLOperationCall(
               Builder, HLOpcodeGroup::HLMatLoadStore,
-              static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad), Ty,
+              static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatLoad), Ty,
               {Ptr}, *M);
         }
         LdElts[i] = Load;
@@ -2642,7 +2671,7 @@ void SROA_Helper::RewriteForStore(StoreInst *SI) {
           // Generate Matrix Store.
           HLModule::EmitHLOperationCall(
               Builder, HLOpcodeGroup::HLMatLoadStore,
-              static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore),
+              static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatStore),
               Extract->getType(), {NewElts[i], Extract}, *M);
         }
       }
@@ -3621,15 +3650,6 @@ static DxilFieldAnnotation &GetEltAnnotation(Type *Ty, unsigned idx, DxilFieldAn
   return annotation;  
 }
 
-static Type *GetArrayEltTy(Type *Ty) {
-  if (isa<PointerType>(Ty))
-    Ty = Ty->getPointerElementType();
-  while (isa<ArrayType>(Ty)) {
-    Ty = Ty->getArrayElementType();
-  }
-  return Ty;
-}
-
 // Note: Semantic index allocation.
 // Semantic index is allocated base on linear layout.
 // For following code
@@ -3803,6 +3823,7 @@ void SROA_Parameter_HLSL::flattenArgument(
   DIBuilder DIB(*F->getParent(), /*AllowUnresolved*/ false);
   unsigned debugOffset = 0;
   const DataLayout &DL = F->getParent()->getDataLayout();
+  Module &M = *m_pHLModule->GetModule();
 
   // Process the worklist
   while (!WorkList.empty()) {
@@ -3882,11 +3903,14 @@ void SROA_Parameter_HLSL::flattenArgument(
         // Just save parent semantic here, allocate later.
         annotation.SetSemanticString(semantic);
       }
+      Type *Ty = V->getType();
+      if (Ty->isPointerTy())
+        Ty = Ty->getPointerElementType();
 
       // Flatten array of SV_Target.
       StringRef semanticStr = annotation.GetSemanticString();
       if (semanticStr.upper().find("SV_TARGET") == 0 &&
-          V->getType()->getPointerElementType()->isArrayTy()) {
+          Ty->isArrayTy()) {
         Type *Ty = cast<ArrayType>(V->getType()->getPointerElementType());
         StringRef targetStr;
         unsigned  targetIndex;
@@ -3971,27 +3995,62 @@ void SROA_Parameter_HLSL::flattenArgument(
       flatParamAnnotation.SetMatrixAnnotation(annotation.GetMatrixAnnotation());
       flatParamAnnotation.SetPrecise(annotation.IsPrecise());
 
-      bool updateToRowMajor = annotation.HasMatrixAnnotation() &&
+      bool updateToColMajor = annotation.HasMatrixAnnotation() &&
                               hasShaderInputOutput &&
                               annotation.GetMatrixAnnotation().Orientation ==
-                                  MatrixOrientation::RowMajor;
-
-      if (updateToRowMajor) {
-        for (User *user : V->users()) {
-          CallInst *CI = dyn_cast<CallInst>(user);
-          if (!CI)
-            continue;
-
-          HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-          if (group == HLOpcodeGroup::NotHL)
-            continue;
-          unsigned opcode = GetHLOpcode(CI);
-          unsigned rowOpcode = GetRowMajorOpcode(group, opcode);
-          if (opcode == rowOpcode)
-            continue;
-          // Update matrix function opcode to row major version.
-          Value *rowOpArg = ConstantInt::get(opcodeTy, rowOpcode);
-          CI->setOperand(HLOperandIndex::kOpcodeIdx, rowOpArg);
+                                  MatrixOrientation::ColumnMajor;
+
+      if (updateToColMajor) {
+        if (V->getType()->isPointerTy()) {
+          for (User *user : V->users()) {
+            CallInst *CI = dyn_cast<CallInst>(user);
+            if (!CI)
+              continue;
+
+            HLOpcodeGroup group =
+                GetHLOpcodeGroupByName(CI->getCalledFunction());
+            if (group != HLOpcodeGroup::HLMatLoadStore)
+              continue;
+            HLMatLoadStoreOpcode opcode =
+                static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(CI));
+            switch (opcode) {
+            case HLMatLoadStoreOpcode::RowMatLoad: {
+              // Update matrix function opcode to col major version.
+              Value *rowOpArg = ConstantInt::get(
+                  opcodeTy,
+                  static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad));
+              CI->setOperand(HLOperandIndex::kOpcodeIdx, rowOpArg);
+              // Cast result to row major.
+              CallInst *RowMat = HLModule::EmitHLOperationCall(
+                  Builder, HLOpcodeGroup::HLCast,
+                  (unsigned)HLCastOpcode::ColMatrixToRowMatrix, Ty, {CI}, M);
+              CI->replaceAllUsesWith(RowMat);
+              // Set arg to CI again.
+              RowMat->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, CI);
+            } break;
+            case HLMatLoadStoreOpcode::RowMatStore:
+              // Update matrix function opcode to col major version.
+              Value *rowOpArg = ConstantInt::get(
+                  opcodeTy,
+                  static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore));
+              CI->setOperand(HLOperandIndex::kOpcodeIdx, rowOpArg);
+              Value *Mat = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
+              // Cast value to col major.
+              CallInst *RowMat = HLModule::EmitHLOperationCall(
+                  Builder, HLOpcodeGroup::HLCast,
+                  (unsigned)HLCastOpcode::RowMatrixToColMatrix, Ty, {Mat}, M);
+              CI->setArgOperand(HLOperandIndex::kMatStoreValOpIdx, RowMat);
+              break;
+            }
+          }
+        } else {
+          // Cast V into row major.
+          CallInst *RowMat = HLModule::EmitHLOperationCall(
+              Builder, HLOpcodeGroup::HLCast,
+              (unsigned)HLCastOpcode::ColMatrixToRowMatrix, Ty, {V}, M);
+          V->replaceAllUsesWith(RowMat);
+          // Set arg to V again.
+          RowMat->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, V);
         }
       }
 
@@ -4030,7 +4089,8 @@ void SROA_Parameter_HLSL::flattenArgument(
 
                 llvm::SmallVector<llvm::Value *, 16> idxList;
                 SplitCpy(data->getType(), outputVal, data, idxList,
-                         /*bAllowReplace*/ false, Builder);
+                         /*bAllowReplace*/ false, Builder, dxilTypeSys,
+                         &flatParamAnnotation);
 
                 CI->setArgOperand(HLOperandIndex::kStreamAppendDataOpIndex, outputVal);
               }
@@ -4056,7 +4116,8 @@ void SROA_Parameter_HLSL::flattenArgument(
 
                   llvm::SmallVector<llvm::Value *, 16> idxList;
                   SplitCpy(DataPtr->getType(), EltPtr, DataPtr, idxList,
-                           /*bAllowReplace*/ false, Builder);
+                           /*bAllowReplace*/ false, Builder, dxilTypeSys,
+                           &flatParamAnnotation);
                   CI->setArgOperand(i, EltPtr);
                 }
               }
@@ -4115,7 +4176,8 @@ void SROA_Parameter_HLSL::moveFunctionBody(Function *F, Function *flatF) {
   }
 }
 
-static void SplitArrayCopy(Value *V) {
+static void SplitArrayCopy(Value *V, DxilTypeSystem &typeSys,
+                           DxilFieldAnnotation *fieldAnnotation) {
   for (auto U = V->user_begin(); U != V->user_end();) {
     User *user = *(U++);
     if (StoreInst *ST = dyn_cast<StoreInst>(user)) {
@@ -4123,7 +4185,8 @@ static void SplitArrayCopy(Value *V) {
       Value *val = ST->getValueOperand();
       IRBuilder<> Builder(ST);
       SmallVector<Value *, 16> idxList;
-      SplitCpy(ptr->getType(), ptr, val, idxList, /*bAllowReplace*/ true, Builder);
+      SplitCpy(ptr->getType(), ptr, val, idxList, /*bAllowReplace*/ true,
+               Builder, typeSys, fieldAnnotation);
       ST->eraseFromParent();
     }
   }
@@ -4163,7 +4226,9 @@ static void CheckArgUsage(Value *V, bool &bLoad, bool &bStore) {
   }
 }
 // Support store to input and load from output.
-static void LegalizeDxilInputOutputs(Function *F, DxilFunctionAnnotation *EntryAnnotation) {
+static void LegalizeDxilInputOutputs(Function *F,
+                                     DxilFunctionAnnotation *EntryAnnotation,
+                                     DxilTypeSystem &typeSys) {
   BasicBlock &EntryBlk = F->getEntryBlock();
   Module *M = F->getParent();
   // Map from output to the temp created for it.
@@ -4174,19 +4239,19 @@ static void LegalizeDxilInputOutputs(Function *F, DxilFunctionAnnotation *EntryA
     DxilParameterAnnotation &paramAnnotation = EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
     DxilParamInputQual qual = paramAnnotation.GetParamInputQual();
 
-    bool isRowMajor = false;
+    bool isColMajor = false;
 
     // Skip arg which is not a pointer.
     if (!Ty->isPointerTy()) {
       if (HLMatrixLower::IsMatrixType(Ty)) {
         // Replace matrix arg with cast to vec. It will be lowered in
         // DxilGenerationPass.
-        isRowMajor = paramAnnotation.GetMatrixAnnotation().Orientation ==
-                     MatrixOrientation::RowMajor;
+        isColMajor = paramAnnotation.GetMatrixAnnotation().Orientation ==
+                     MatrixOrientation::ColumnMajor;
         IRBuilder<> Builder(EntryBlk.getFirstInsertionPt());
 
-        HLCastOpcode opcode = isRowMajor ? HLCastOpcode::RowMatrixToVecCast
-                                         : HLCastOpcode::ColMatrixToVecCast;
+        HLCastOpcode opcode = isColMajor ? HLCastOpcode::ColMatrixToVecCast
+                                         : HLCastOpcode::RowMatrixToVecCast;
         Value *undefVal = UndefValue::get(Ty);
 
         Value *Cast = HLModule::EmitHLOperationCall(
@@ -4216,26 +4281,27 @@ static void LegalizeDxilInputOutputs(Function *F, DxilFunctionAnnotation *EntryA
     } else if (qual == DxilParamInputQual::Out && bLoad) {
       bNeedTemp = true;
       bLoadOutputFromTemp = true;
-    } else if (qual == DxilParamInputQual::Inout) {
-      bNeedTemp = true;
-      bLoadOutputFromTemp = true;
-      bStoreInputToTemp = true;
     } else if (bLoad && bStore) {
-      bNeedTemp = true;
       switch (qual) {
       case DxilParamInputQual::InputPrimitive:
       case DxilParamInputQual::InputPatch:
-      case DxilParamInputQual::OutputPatch:
+      case DxilParamInputQual::OutputPatch: {
+        bNeedTemp = true;
         bStoreInputToTemp = true;
+      } break;
+      case DxilParamInputQual::Inout:
         break;
       default:
         DXASSERT(0, "invalid input qual here");
       }
+    } else if (qual == DxilParamInputQual::Inout) {
+      // Only replace inout when (bLoad && bStore) == false.
+      bNeedTemp = true;
+      bLoadOutputFromTemp = true;
+      bStoreInputToTemp = true;
     }
 
     if (HLMatrixLower::IsMatrixType(Ty)) {
-      isRowMajor = paramAnnotation.GetMatrixAnnotation().Orientation ==
-                   MatrixOrientation::RowMajor;
       bNeedTemp = true;
       if (qual == DxilParamInputQual::In)
         bStoreInputToTemp = bLoad;
@@ -4258,34 +4324,8 @@ static void LegalizeDxilInputOutputs(Function *F, DxilFunctionAnnotation *EntryA
       if (bStoreInputToTemp) {
         llvm::SmallVector<llvm::Value *, 16> idxList;
         // split copy.
-        SplitCpy(temp->getType(), temp, &arg, idxList, /*bAllowReplace*/ false, Builder);
-        if (isRowMajor) {
-          auto Iter = Builder.GetInsertPoint();
-          Iter--;
-          while (cast<Instruction>(Iter) != temp) {
-            if (CallInst *CI = dyn_cast<CallInst>(Iter--)) {
-              HLOpcodeGroup group =
-                  GetHLOpcodeGroupByName(CI->getCalledFunction());
-              if (group == HLOpcodeGroup::HLMatLoadStore) {
-                HLMatLoadStoreOpcode opcode =
-                    static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(CI));
-                switch (opcode) {
-                case HLMatLoadStoreOpcode::ColMatLoad: {
-                  CI->setArgOperand(HLOperandIndex::kOpcodeIdx,
-                                    Builder.getInt32(static_cast<unsigned>(
-                                        HLMatLoadStoreOpcode::RowMatLoad)));
-                } break;
-                case HLMatLoadStoreOpcode::ColMatStore:
-                case HLMatLoadStoreOpcode::RowMatStore:
-                  CI->setArgOperand(HLOperandIndex::kOpcodeIdx,
-                                    Builder.getInt32(static_cast<unsigned>(
-                                        HLMatLoadStoreOpcode::RowMatStore)));
-                  break;
-                }
-              }
-            }
-          }
-        }
+        SplitCpy(temp->getType(), temp, &arg, idxList, /*bAllowReplace*/ false,
+                 Builder, typeSys, &paramAnnotation);
       }
 
       // Generate store output, temp later.
@@ -4306,12 +4346,6 @@ static void LegalizeDxilInputOutputs(Function *F, DxilFunctionAnnotation *EntryA
 
         DxilParameterAnnotation &paramAnnotation =
             EntryAnnotation->GetParameterAnnotation(output->getArgNo());
-        bool hasMatrix = paramAnnotation.HasMatrixAnnotation();
-        bool isRowMajor = false;
-        if (hasMatrix) {
-          isRowMajor = paramAnnotation.GetMatrixAnnotation().Orientation ==
-                       MatrixOrientation::RowMajor;
-        }
 
         auto Iter = Builder.GetInsertPoint();
         bool onlyRetBlk = false;
@@ -4321,35 +4355,7 @@ static void LegalizeDxilInputOutputs(Function *F, DxilFunctionAnnotation *EntryA
           onlyRetBlk = true;
         // split copy.
         SplitCpy(output->getType(), output, temp, idxList,
-                 /*bAllowReplace*/ false, Builder);
-        if (isRowMajor) {
-          if (onlyRetBlk)
-            Iter = BB.begin();
-
-          while (cast<Instruction>(Iter) != RI) {
-            if (CallInst *CI = dyn_cast<CallInst>(++Iter)) {
-              HLOpcodeGroup group =
-                  GetHLOpcodeGroupByName(CI->getCalledFunction());
-              if (group == HLOpcodeGroup::HLMatLoadStore) {
-                HLMatLoadStoreOpcode opcode =
-                    static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(CI));
-                switch (opcode) {
-                case HLMatLoadStoreOpcode::ColMatLoad: {
-                  CI->setArgOperand(HLOperandIndex::kOpcodeIdx,
-                                    Builder.getInt32(static_cast<unsigned>(
-                                        HLMatLoadStoreOpcode::RowMatLoad)));
-                } break;
-                case HLMatLoadStoreOpcode::ColMatStore:
-                case HLMatLoadStoreOpcode::RowMatStore:
-                  CI->setArgOperand(HLOperandIndex::kOpcodeIdx,
-                                    Builder.getInt32(static_cast<unsigned>(
-                                        HLMatLoadStoreOpcode::RowMatStore)));
-                  break;
-                }
-              }
-            }
-          }
-        }
+                 /*bAllowReplace*/ false, Builder, typeSys, &paramAnnotation);
       }
       // Clone the return.
       Builder.CreateRet(RI->getReturnValue());
@@ -4359,8 +4365,9 @@ static void LegalizeDxilInputOutputs(Function *F, DxilFunctionAnnotation *EntryA
 }
 
 void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
+  DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
   // Change memcpy into ld/st first
-  MemcpySplitter splitter(F->getContext());
+  MemcpySplitter splitter(F->getContext(), typeSys);
   splitter.Split(*F);
 
   // Skip void (void) function.
@@ -4396,6 +4403,10 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
     Instruction *InsertPt = F->getEntryBlock().getFirstInsertionPt();
     IRBuilder<> Builder(InsertPt);
     Value *retValAddr = Builder.CreateAlloca(retType);
+    DxilParameterAnnotation &retAnnotation =
+        funcAnnotation->GetRetTypeAnnotation();
+    Module &M = *m_pHLModule->GetModule();
+    Type *voidTy = Type::getVoidTy(m_pHLModule->GetCtx());
     // Create DbgDecl for the ret value.
     if (DISubprogram *funcDI = getDISubprogram(F)) {
        DITypeRef RetDITyRef = funcDI->getType()->getTypeArray()[0];
@@ -4413,7 +4424,27 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
       if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
         // Create store for return.
         IRBuilder<> RetBuilder(RI);
-        RetBuilder.CreateStore(RI->getReturnValue(), retValAddr);
+        if (!retAnnotation.HasMatrixAnnotation()) {
+          RetBuilder.CreateStore(RI->getReturnValue(), retValAddr);
+        } else {
+          bool isRowMajor = retAnnotation.GetMatrixAnnotation().Orientation ==
+                            MatrixOrientation::RowMajor;
+          Value *RetVal = RI->getReturnValue();
+          if (!isRowMajor) {
+            // Matrix value is row major. ColMatStore require col major.
+            // Cast before store.
+            RetVal = HLModule::EmitHLOperationCall(
+                RetBuilder, HLOpcodeGroup::HLCast,
+                static_cast<unsigned>(HLCastOpcode::RowMatrixToColMatrix),
+                RetVal->getType(), {RetVal}, M);
+          }
+          unsigned opcode = static_cast<unsigned>(
+              isRowMajor ? HLMatLoadStoreOpcode::RowMatStore
+                         : HLMatLoadStoreOpcode::ColMatStore);
+          HLModule::EmitHLOperationCall(RetBuilder,
+                                        HLOpcodeGroup::HLMatLoadStore, opcode,
+                                        voidTy, {retValAddr, RetVal}, M);
+        }
         // Clone the return.
         ReturnInst *NewRet = RetBuilder.CreateRet(RI->getReturnValue());
         if (RI == InsertPt) {
@@ -4487,7 +4518,7 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
       }
     }
     // Support store to input and load from output.
-    LegalizeDxilInputOutputs(F, funcAnnotation);
+    LegalizeDxilInputOutputs(F, funcAnnotation, typeSys);
     return;
   }
 
@@ -4567,17 +4598,21 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
     if (Arg->getType()->isPointerTy()) {
       Type *Ty = Arg->getType()->getPointerElementType();
       if (Ty->isArrayTy())
-        SplitArrayCopy(Arg);
+        SplitArrayCopy(
+            Arg, typeSys,
+            &flatFuncAnnotation->GetParameterAnnotation(Arg->getArgNo()));
     }
   }
   // Support store to input and load from output.
-  LegalizeDxilInputOutputs(flatF, flatFuncAnnotation);
+  LegalizeDxilInputOutputs(flatF, flatFuncAnnotation, typeSys);
 }
 
 void SROA_Parameter_HLSL::createFlattenedFunctionCall(Function *F, Function *flatF, CallInst *CI) {
   DxilFunctionAnnotation *funcAnnotation = m_pHLModule->GetFunctionAnnotation(F);
   DXASSERT(funcAnnotation, "must find annotation for function");
 
+  DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
+
   std::vector<Value *> FlatParamList;
   std::vector<DxilParameterAnnotation> FlatParamAnnotationList;
 
@@ -4606,8 +4641,30 @@ void SROA_Parameter_HLSL::createFlattenedFunctionCall(Function *F, Function *fla
        DIB.insertDeclare(retValAddr, RetVar, Expr, DL, CI);
     }
 
+    DxilParameterAnnotation &retAnnotation = funcAnnotation->GetRetTypeAnnotation();
     // Load ret value and replace CI.
-    Value *newRetVal = RetBuilder.CreateLoad(retValAddr);
+    Value *newRetVal = nullptr;
+    if (!retAnnotation.HasMatrixAnnotation()) {
+      newRetVal = RetBuilder.CreateLoad(retValAddr);
+    } else {
+      bool isRowMajor = retAnnotation.GetMatrixAnnotation().Orientation ==
+                        MatrixOrientation::RowMajor;
+      unsigned opcode =
+          static_cast<unsigned>(isRowMajor ? HLMatLoadStoreOpcode::RowMatLoad
+                                           : HLMatLoadStoreOpcode::ColMatLoad);
+      newRetVal = HLModule::EmitHLOperationCall(RetBuilder, HLOpcodeGroup::HLMatLoadStore,
+                                    opcode, retType, {retValAddr},
+                                    *m_pHLModule->GetModule());
+      if (!isRowMajor) {
+        // ColMatLoad will return a col major.
+        // Matrix value should be row major.
+        // Cast it here.
+        newRetVal = HLModule::EmitHLOperationCall(
+            RetBuilder, HLOpcodeGroup::HLCast,
+            static_cast<unsigned>(HLCastOpcode::ColMatrixToRowMatrix), retType,
+            {newRetVal}, *m_pHLModule->GetModule());
+      }
+    }
     CI->replaceAllUsesWith(newRetVal);
     // Flat ret val
     flattenArgument(flatF, retValAddr, funcAnnotation->GetRetTypeAnnotation(),
@@ -4627,7 +4684,8 @@ void SROA_Parameter_HLSL::createFlattenedFunctionCall(Function *F, Function *fla
     DxilParameterAnnotation &paramAnnotation =
         funcAnnotation->GetParameterAnnotation(i);
     Value *arg = args[i];
-    if (arg->getType()->isPointerTy()) {
+    Type *Ty = arg->getType();
+    if (Ty->isPointerTy()) {
       // For pointer, alloca another pointer, replace in CI.
       Value *tempArg =
           AllocaBuilder.CreateAlloca(arg->getType()->getPointerElementType());
@@ -4637,15 +4695,19 @@ void SROA_Parameter_HLSL::createFlattenedFunctionCall(Function *F, Function *fla
       if (inputQual == DxilParamInputQual::In ||
           inputQual == DxilParamInputQual::Inout) {
         // Copy in param.
-        Value *v = CallBuilder.CreateLoad(arg);
-        CallBuilder.CreateStore(v, tempArg);
+        llvm::SmallVector<llvm::Value *, 16> idxList;
+        // split copy to avoid load of struct.
+        SplitCpy(Ty, tempArg, arg, idxList, /*bAllowReplace*/ false, CallBuilder,
+                 typeSys, &paramAnnotation);
       }
 
       if (inputQual == DxilParamInputQual::Out ||
           inputQual == DxilParamInputQual::Inout) {
         // Copy out param.
-        Value *v = RetBuilder.CreateLoad(tempArg);
-        RetBuilder.CreateStore(v, arg);
+        llvm::SmallVector<llvm::Value *, 16> idxList;
+        // split copy to avoid load of struct.
+        SplitCpy(Ty, arg, tempArg, idxList, /*bAllowReplace*/ false, RetBuilder,
+                 typeSys, &paramAnnotation);
       }
       arg = tempArg;
       flattenArgument(flatF, arg, paramAnnotation, FlatParamList,

+ 4 - 4
tools/clang/lib/AST/ASTContextHLSL.cpp

@@ -190,15 +190,15 @@ void hlsl::AddHLSLMatrixTemplate(ASTContext& context, ClassTemplateDecl* vectorT
   QualType elementType = context.getTemplateTypeParmType(
       /*templateDepth*/ 0, 0, ParameterPackFalse, elementTemplateParamDecl);
   Expr *sizeExpr = DeclRefExpr::Create(
-      context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl,
+      context, NestedNameSpecifierLoc(), NoLoc, rowCountTemplateParamDecl,
       false,
-      DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
+      DeclarationNameInfo(rowCountTemplateParamDecl->getDeclName(), NoLoc),
       intType, ExprValueKind::VK_RValue);
 
   Expr *rowSizeExpr = DeclRefExpr::Create(
-      context, NestedNameSpecifierLoc(), NoLoc, rowCountTemplateParamDecl,
+      context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl,
       false,
-      DeclarationNameInfo(rowCountTemplateParamDecl->getDeclName(), NoLoc),
+      DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
       intType, ExprValueKind::VK_RValue);
 
   QualType vectorType = context.getDependentSizedExtVectorType(

+ 296 - 66
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -4205,6 +4205,9 @@ void CGMSHLSLRuntime::FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Val
     if (HLMatrixLower::IsMatrixType(valTy)) {
       unsigned col, row;
       llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(valTy, col, row);
+      // All matrix Value should be row major.
+      // Init list is row major in scalar.
+      // So the order is match here, just cast to vector.
       unsigned matSize = col * row;
       bool isRowMajor = IsRowMajorMatrix(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor);
 
@@ -4236,7 +4239,7 @@ void CGMSHLSLRuntime::FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Val
 
 // Cast elements in initlist if not match the target type.
 // idx is current element index in initlist, Ty is target type.
-static void AddMissingCastOpsInInitList(SmallVector<Value *, 4> &elts, SmallVector<QualType, 4> eltTys, unsigned &idx, QualType Ty, CodeGenFunction &CGF) {
+static void AddMissingCastOpsInInitList(SmallVector<Value *, 4> &elts, SmallVector<QualType, 4> &eltTys, unsigned &idx, QualType Ty, CodeGenFunction &CGF) {
   if (Ty->isArrayType()) {
     const clang::ArrayType *AT = Ty->getAsArrayTypeUnsafe();
     // Must be ConstantArrayType here.
@@ -4298,57 +4301,109 @@ static void AddMissingCastOpsInInitList(SmallVector<Value *, 4> &elts, SmallVect
   }
 }
 
-static void StoreInitListToDestPtr(Value *DestPtr, SmallVector<Value *, 4> &elts, unsigned &idx, CGBuilderTy &Builder, llvm::Module &M) {
+static void StoreInitListToDestPtr(Value *DestPtr,
+                                   SmallVector<Value *, 4> &elts, unsigned &idx,
+                                   QualType Type, CodeGenTypes &Types, bool bDefaultRowMajor,
+                                   CGBuilderTy &Builder, llvm::Module &M) {
   llvm::Type *Ty = DestPtr->getType()->getPointerElementType();
   llvm::Type *i32Ty = llvm::Type::getInt32Ty(Ty->getContext());
 
   if (Ty->isVectorTy()) {
     Value *Result = UndefValue::get(Ty);
     for (unsigned i = 0; i < Ty->getVectorNumElements(); i++)
-      Result = Builder.CreateInsertElement(Result, elts[idx+i], i);
+      Result = Builder.CreateInsertElement(Result, elts[idx + i], i);
     Builder.CreateStore(Result, DestPtr);
     idx += Ty->getVectorNumElements();
   } else if (HLMatrixLower::IsMatrixType(Ty)) {
+    bool isRowMajor =
+        IsRowMajorMatrix(Type, bDefaultRowMajor);
+
     unsigned row, col;
     HLMatrixLower::GetMatrixInfo(Ty, col, row);
-    std::vector<Value*> matInitList(col*row);
+    std::vector<Value *> matInitList(col * row);
     for (unsigned i = 0; i < col; i++) {
       for (unsigned r = 0; r < row; r++) {
         unsigned matIdx = i * row + r;
-        matInitList[matIdx] = elts[idx+matIdx];
+        matInitList[matIdx] = elts[idx + matIdx];
       }
     }
-    idx += row*col;
-    
-    Value *matVal = EmitHLSLMatrixOperationCallImp(Builder, HLOpcodeGroup::HLInit,
-        /*opcode*/0, Ty, matInitList, M);
-    EmitHLSLMatrixOperationCallImp(Builder, HLOpcodeGroup::HLMatLoadStore,
-        static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore), Ty,
-        {DestPtr, matVal}, M);
+    idx += row * col;
+    Value *matVal =
+        EmitHLSLMatrixOperationCallImp(Builder, HLOpcodeGroup::HLInit,
+                                       /*opcode*/ 0, Ty, matInitList, M);
+    // matVal return from HLInit is row major.
+    // If DestPtr is row major, just store it directly.
+    if (!isRowMajor) {
+      // ColMatStore need a col major value.
+      // Cast row major matrix into col major.
+      // Then store it.
+      Value *colMatVal = EmitHLSLMatrixOperationCallImp(
+          Builder, HLOpcodeGroup::HLCast,
+          static_cast<unsigned>(HLCastOpcode::RowMatrixToColMatrix), Ty,
+          {matVal}, M);
+      EmitHLSLMatrixOperationCallImp(
+          Builder, HLOpcodeGroup::HLMatLoadStore,
+          static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore), Ty,
+          {DestPtr, colMatVal}, M);
+    } else {
+      EmitHLSLMatrixOperationCallImp(
+          Builder, HLOpcodeGroup::HLMatLoadStore,
+          static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatStore), Ty,
+          {DestPtr, matVal}, M);
+    }
   } else if (Ty->isStructTy()) {
     if (HLModule::IsHLSLObjectType(Ty)) {
       Builder.CreateStore(elts[idx], DestPtr);
       idx++;
     } else {
       Constant *zero = ConstantInt::get(i32Ty, 0);
-      for (unsigned i = 0; i < Ty->getStructNumElements(); i++) {
+
+      const RecordType *RT = Type->getAsStructureType();
+      // For CXXRecord.
+      if (!RT)
+        RT = Type->getAs<RecordType>();
+      RecordDecl *RD = RT->getDecl();
+      const CGRecordLayout &RL = Types.getCGRecordLayout(RD);
+      // Take care base.
+      if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        if (CXXRD->getNumBases()) {
+          for (const auto &I : CXXRD->bases()) {
+            const CXXRecordDecl *BaseDecl = cast<CXXRecordDecl>(
+                I.getType()->castAs<RecordType>()->getDecl());
+            if (BaseDecl->field_empty())
+              continue;
+            QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
+            unsigned i = RL.getNonVirtualBaseLLVMFieldNo(BaseDecl);
+            Constant *gepIdx = ConstantInt::get(i32Ty, i);
+            Value *GEP = Builder.CreateInBoundsGEP(DestPtr, {zero, gepIdx});
+            StoreInitListToDestPtr(GEP, elts, idx, parentTy, Types,
+                                   bDefaultRowMajor, Builder, M);
+          }
+        }
+      }
+      for (FieldDecl *field : RD->fields()) {
+        unsigned i = RL.getLLVMFieldNo(field);
         Constant *gepIdx = ConstantInt::get(i32Ty, i);
         Value *GEP = Builder.CreateInBoundsGEP(DestPtr, {zero, gepIdx});
-        StoreInitListToDestPtr(GEP, elts, idx, Builder, M);
+        StoreInitListToDestPtr(GEP, elts, idx, field->getType(), Types,
+                               bDefaultRowMajor, Builder, M);
       }
     }
   } else if (Ty->isArrayTy()) {
     Constant *zero = ConstantInt::get(i32Ty, 0);
+    QualType EltType = Type->getAsArrayTypeUnsafe()->getElementType();
     for (unsigned i = 0; i < Ty->getArrayNumElements(); i++) {
       Constant *gepIdx = ConstantInt::get(i32Ty, i);
       Value *GEP = Builder.CreateInBoundsGEP(DestPtr, {zero, gepIdx});
-      StoreInitListToDestPtr(GEP, elts, idx, Builder, M);
+      StoreInitListToDestPtr(GEP, elts, idx, EltType, Types, bDefaultRowMajor,
+                             Builder, M);
     }
   } else {
     DXASSERT(Ty->isSingleValueType(), "invalid type");
     llvm::Type *i1Ty = Builder.getInt1Ty();
     Value *V = elts[idx];
-    if (V->getType() == i1Ty && DestPtr->getType()->getPointerElementType() != i1Ty) {
+    if (V->getType() == i1Ty &&
+        DestPtr->getType()->getPointerElementType() != i1Ty) {
       V = Builder.CreateZExt(V, DestPtr->getType()->getPointerElementType());
     }
     Builder.CreateStore(V, DestPtr);
@@ -4400,7 +4455,9 @@ Value *CGMSHLSLRuntime::EmitHLSLInitListExpr(CodeGenFunction &CGF, InitListExpr
     ParamList.emplace_back(DestPtr);
     ParamList.append(EltValList.begin(), EltValList.end());
     idx = 0;
-    StoreInitListToDestPtr(DestPtr, EltValList, idx, CGF.Builder, TheModule);
+    bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
+    StoreInitListToDestPtr(DestPtr, EltValList, idx, ResultTy, CGF.getTypes(),
+                           bDefaultRowMajor, CGF.Builder, TheModule);
     return nullptr;
   }
 
@@ -4418,20 +4475,72 @@ Value *CGMSHLSLRuntime::EmitHLSLInitListExpr(CodeGenFunction &CGF, InitListExpr
   }
 }
 
-static void FlatConstToList(Constant *C,
-                            SmallVector<Constant *, 4> &EltValList) {
+static void FlatConstToList(Constant *C, SmallVector<Constant *, 4> &EltValList,
+                            QualType Type, CodeGenTypes &Types,
+                            bool bDefaultRowMajor) {
   llvm::Type *Ty = C->getType();
   if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty)) {
+    // Type is only for matrix. Keep use Type to next level.
     for (unsigned i = 0; i < VT->getNumElements(); i++) {
-      FlatConstToList(C->getAggregateElement(i), EltValList);
+      FlatConstToList(C->getAggregateElement(i), EltValList, Type, Types,
+                      bDefaultRowMajor);
+    }
+  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+    bool isRowMajor = IsRowMajorMatrix(Type, bDefaultRowMajor);
+    // matrix type is struct { vector<Ty, row> [col] };
+    // Strip the struct level here.
+    Constant *matVal = C->getAggregateElement((unsigned)0);
+    const RecordType *RT = Type->getAs<RecordType>();
+    RecordDecl *RD = RT->getDecl();
+    QualType EltTy = RD->field_begin()->getType();
+    // When scan, init list scalars is row major.
+    if (isRowMajor) {
+      // Don't change the major for row major value.
+      FlatConstToList(matVal, EltValList, EltTy, Types, bDefaultRowMajor);
+    } else {
+      // Save to tmp list.
+      SmallVector<Constant *, 4> matEltList;
+      FlatConstToList(matVal, matEltList, EltTy, Types, bDefaultRowMajor);
+      unsigned row, col;
+      HLMatrixLower::GetMatrixInfo(Ty, col, row);
+      // Change col major value to row major.
+      for (unsigned r = 0; r < row; r++)
+        for (unsigned c = 0; c < col; c++) {
+          unsigned colMajorIdx = c * row + r;
+          EltValList.emplace_back(matEltList[colMajorIdx]);
+        }
     }
   } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
+    QualType EltTy = Type->getAsArrayTypeUnsafe()->getElementType();
     for (unsigned i = 0; i < AT->getNumElements(); i++) {
-      FlatConstToList(C->getAggregateElement(i), EltValList);
+      FlatConstToList(C->getAggregateElement(i), EltValList, EltTy, Types,
+                      bDefaultRowMajor);
     }
   } else if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
-    for (unsigned i = 0; i < ST->getNumElements(); i++) {
-      FlatConstToList(C->getAggregateElement(i), EltValList);
+    RecordDecl *RD = Type->getAsStructureType()->getDecl();
+    const CGRecordLayout &RL = Types.getCGRecordLayout(RD);
+    // Take care base.
+    if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+      if (CXXRD->getNumBases()) {
+        for (const auto &I : CXXRD->bases()) {
+          const CXXRecordDecl *BaseDecl =
+              cast<CXXRecordDecl>(I.getType()->castAs<RecordType>()->getDecl());
+          if (BaseDecl->field_empty())
+            continue;
+          QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
+          unsigned i = RL.getNonVirtualBaseLLVMFieldNo(BaseDecl);
+          FlatConstToList(C->getAggregateElement(i), EltValList, parentTy,
+                          Types, bDefaultRowMajor);
+        }
+      }
+    }
+
+    for (auto fieldIter = RD->field_begin(), fieldEnd = RD->field_end();
+         fieldIter != fieldEnd; ++fieldIter) {
+      unsigned i = RL.getLLVMFieldNo(*fieldIter);
+
+      FlatConstToList(C->getAggregateElement(i), EltValList,
+                      fieldIter->getType(), Types, bDefaultRowMajor);
     }
   } else {
     EltValList.emplace_back(C);
@@ -4439,20 +4548,22 @@ static void FlatConstToList(Constant *C,
 }
 
 static bool ScanConstInitList(CodeGenModule &CGM, InitListExpr *E,
-                              SmallVector<Constant *, 4> &EltValList) {
+                              SmallVector<Constant *, 4> &EltValList,
+                              CodeGenTypes &Types, bool bDefaultRowMajor) {
   unsigned NumInitElements = E->getNumInits();
   for (unsigned i = 0; i != NumInitElements; ++i) {
     Expr *init = E->getInit(i);
     QualType iType = init->getType();
     if (InitListExpr *initList = dyn_cast<InitListExpr>(init)) {
-      if (!ScanConstInitList(CGM, initList, EltValList))
+      if (!ScanConstInitList(CGM, initList, EltValList, Types,
+                             bDefaultRowMajor))
         return false;
     } else if (DeclRefExpr *ref = dyn_cast<DeclRefExpr>(init)) {
       if (VarDecl *D = dyn_cast<VarDecl>(ref->getDecl())) {
         if (!D->hasInit())
           return false;
         if (Constant *initVal = CGM.EmitConstantInit(*D)) {
-          FlatConstToList(initVal, EltValList);
+          FlatConstToList(initVal, EltValList, iType, Types, bDefaultRowMajor);
         } else {
           return false;
         }
@@ -4463,7 +4574,7 @@ static bool ScanConstInitList(CodeGenModule &CGM, InitListExpr *E,
       return false;
     } else if (CodeGenFunction::hasScalarEvaluationKind(iType)) {
       if (Constant *initVal = CGM.EmitConstantExpr(init, iType)) {
-        FlatConstToList(initVal, EltValList);
+        FlatConstToList(initVal, EltValList, iType, Types, bDefaultRowMajor);
       } else {
         return false;
       }
@@ -4476,7 +4587,8 @@ static bool ScanConstInitList(CodeGenModule &CGM, InitListExpr *E,
 
 static Constant *BuildConstInitializer(QualType Type, unsigned &offset,
                                        SmallVector<Constant *, 4> &EltValList,
-                                       CodeGenTypes &Types);
+                                       CodeGenTypes &Types,
+                                       bool bDefaultRowMajor);
 
 static Constant *BuildConstVector(llvm::VectorType *VT, unsigned &offset,
                                   SmallVector<Constant *, 4> &EltValList,
@@ -4484,26 +4596,50 @@ static Constant *BuildConstVector(llvm::VectorType *VT, unsigned &offset,
   SmallVector<Constant *, 4> Elts;
   QualType EltTy = hlsl::GetHLSLVecElementType(Type);
   for (unsigned i = 0; i < VT->getNumElements(); i++) {
-    Elts.emplace_back(BuildConstInitializer(EltTy, offset, EltValList, Types));
+    Elts.emplace_back(BuildConstInitializer(EltTy, offset, EltValList, Types,
+                                            // Vector don't need major.
+                                            /*bDefaultRowMajor*/ false));
   }
   return llvm::ConstantVector::get(Elts);
 }
 
 static Constant *BuildConstMatrix(llvm::Type *Ty, unsigned &offset,
                                   SmallVector<Constant *, 4> &EltValList,
-                                  QualType Type, CodeGenTypes &Types) {
+                                  QualType Type, CodeGenTypes &Types,
+                                  bool bDefaultRowMajor) {
   QualType EltTy = hlsl::GetHLSLMatElementType(Type);
   unsigned col, row;
   HLMatrixLower::GetMatrixInfo(Ty, col, row);
   llvm::ArrayType *AT = cast<llvm::ArrayType>(Ty->getStructElementType(0));
+  // Save initializer elements first.
   // Matrix initializer is row major.
+  SmallVector<Constant *, 16> elts;
+  for (unsigned i = 0; i < col * row; i++) {
+    elts.emplace_back(BuildConstInitializer(EltTy, offset, EltValList, Types,
+                                            bDefaultRowMajor));
+  }
+
+  bool isRowMajor = IsRowMajorMatrix(Type, bDefaultRowMajor);
+
+  SmallVector<Constant *, 16> majorElts(elts.begin(), elts.end());
+  if (!isRowMajor) {
+    // cast row major to col major.
+    for (unsigned c = 0; c < col; c++) {
+      SmallVector<Constant *, 4> rows;
+      for (unsigned r = 0; r < row; r++) {
+        unsigned rowMajorIdx = r * col + c;
+        unsigned colMajorIdx = c * row + r;
+        majorElts[colMajorIdx] = elts[rowMajorIdx];
+      }
+    }
+  }
   // The type is vector<element, col>[row].
   SmallVector<Constant *, 4> rows;
+  unsigned idx = 0;
   for (unsigned r = 0; r < row; r++) {
     SmallVector<Constant *, 4> cols;
     for (unsigned c = 0; c < col; c++) {
-      cols.emplace_back(
-          BuildConstInitializer(EltTy, offset, EltValList, Types));
+      cols.emplace_back(majorElts[idx++]);
     }
     rows.emplace_back(llvm::ConstantVector::get(cols));
   }
@@ -4513,19 +4649,21 @@ static Constant *BuildConstMatrix(llvm::Type *Ty, unsigned &offset,
 
 static Constant *BuildConstArray(llvm::ArrayType *AT, unsigned &offset,
                                  SmallVector<Constant *, 4> &EltValList,
-                                 QualType Type, CodeGenTypes &Types) {
+                                 QualType Type, CodeGenTypes &Types,
+                                 bool bDefaultRowMajor) {
   SmallVector<Constant *, 4> Elts;
   QualType EltType = QualType(Type->getArrayElementTypeNoTypeQual(), 0);
   for (unsigned i = 0; i < AT->getNumElements(); i++) {
-    Elts.emplace_back(
-        BuildConstInitializer(EltType, offset, EltValList, Types));
+    Elts.emplace_back(BuildConstInitializer(EltType, offset, EltValList, Types,
+                                            bDefaultRowMajor));
   }
   return llvm::ConstantArray::get(AT, Elts);
 }
 
 static Constant *BuildConstStruct(llvm::StructType *ST, unsigned &offset,
                                   SmallVector<Constant *, 4> &EltValList,
-                                  QualType Type, CodeGenTypes &Types) {
+                                  QualType Type, CodeGenTypes &Types,
+                                  bool bDefaultRowMajor) {
   SmallVector<Constant *, 4> Elts;
 
   const RecordType *RT = Type->getAsStructureType();
@@ -4544,16 +4682,16 @@ static Constant *BuildConstStruct(llvm::StructType *ST, unsigned &offset,
           continue;
 
         // Add base as a whole constant. Not as element.
-        Elts.emplace_back(
-            BuildConstInitializer(I.getType(), offset, EltValList, Types));
+        Elts.emplace_back(BuildConstInitializer(I.getType(), offset, EltValList,
+                                                Types, bDefaultRowMajor));
       }
     }
   }
 
   for (auto fieldIter = RD->field_begin(), fieldEnd = RD->field_end();
        fieldIter != fieldEnd; ++fieldIter) {
-    Elts.emplace_back(
-        BuildConstInitializer(fieldIter->getType(), offset, EltValList, Types));
+    Elts.emplace_back(BuildConstInitializer(
+        fieldIter->getType(), offset, EltValList, Types, bDefaultRowMajor));
   }
 
   return llvm::ConstantStruct::get(ST, Elts);
@@ -4561,16 +4699,20 @@ static Constant *BuildConstStruct(llvm::StructType *ST, unsigned &offset,
 
 static Constant *BuildConstInitializer(QualType Type, unsigned &offset,
                                        SmallVector<Constant *, 4> &EltValList,
-                                       CodeGenTypes &Types) {
+                                       CodeGenTypes &Types,
+                                       bool bDefaultRowMajor) {
   llvm::Type *Ty = Types.ConvertType(Type);
   if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty)) {
     return BuildConstVector(VT, offset, EltValList, Type, Types);
   } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
-    return BuildConstArray(AT, offset, EltValList, Type, Types);
+    return BuildConstArray(AT, offset, EltValList, Type, Types,
+                           bDefaultRowMajor);
   } else if (HLMatrixLower::IsMatrixType(Ty)) {
-    return BuildConstMatrix(Ty, offset, EltValList, Type, Types);
+    return BuildConstMatrix(Ty, offset, EltValList, Type, Types,
+                            bDefaultRowMajor);
   } else if (StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
-    return BuildConstStruct(ST, offset, EltValList, Type, Types);
+    return BuildConstStruct(ST, offset, EltValList, Type, Types,
+                            bDefaultRowMajor);
   } else {
     // Scalar basic types.
     Constant *Val = EltValList[offset++];
@@ -4591,13 +4733,15 @@ static Constant *BuildConstInitializer(QualType Type, unsigned &offset,
 
 Constant *CGMSHLSLRuntime::EmitHLSLConstInitListExpr(CodeGenModule &CGM,
                                                      InitListExpr *E) {
+  bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
   SmallVector<Constant *, 4> EltValList;
-  if (!ScanConstInitList(CGM, E, EltValList))
+  if (!ScanConstInitList(CGM, E, EltValList, CGM.getTypes(), bDefaultRowMajor))
     return nullptr;
 
   QualType Type = E->getType();
   unsigned offset = 0;
-  return BuildConstInitializer(Type, offset, EltValList, CGM.getTypes());
+  return BuildConstInitializer(Type, offset, EltValList, CGM.getTypes(),
+                               bDefaultRowMajor);
 }
 
 Value *CGMSHLSLRuntime::EmitHLSLMatrixOperationCall(
@@ -4742,58 +4886,144 @@ Value *CGMSHLSLRuntime::EmitHLSLMatrixSubscript(CodeGenFunction &CGF,
                                                 llvm::Value *Ptr,
                                                 llvm::Value *Idx,
                                                 clang::QualType Ty) {
+  bool isRowMajor =
+      IsRowMajorMatrix(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor);
   unsigned opcode =
-      IsRowMajorMatrix(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor)
-          ? static_cast<unsigned>(HLSubscriptOpcode::RowMatSubscript)
-          : static_cast<unsigned>(HLSubscriptOpcode::ColMatSubscript);
+      isRowMajor ? static_cast<unsigned>(HLSubscriptOpcode::RowMatSubscript)
+                 : static_cast<unsigned>(HLSubscriptOpcode::ColMatSubscript);
   Value *matBase = Ptr;
-  if (matBase->getType()->isPointerTy()) {
-    RetType =
-        llvm::PointerType::get(RetType->getPointerElementType(),
-                               matBase->getType()->getPointerAddressSpace());
+  DXASSERT(matBase->getType()->isPointerTy(),
+           "matrix subscript should return pointer");
+
+  RetType =
+      llvm::PointerType::get(RetType->getPointerElementType(),
+                             matBase->getType()->getPointerAddressSpace());
+
+  // Lower mat[Idx] into real idx.
+  SmallVector<Value *, 8> args;
+  args.emplace_back(Ptr);
+  unsigned row, col;
+  hlsl::GetHLSLMatRowColCount(Ty, row, col);
+  if (isRowMajor) {
+    Value *cCol = ConstantInt::get(Idx->getType(), col);
+    Value *Base = CGF.Builder.CreateMul(cCol, Idx);
+    for (unsigned i = 0; i < col; i++) {
+      Value *c = ConstantInt::get(Idx->getType(), i);
+      // r * col + c
+      Value *matIdx = CGF.Builder.CreateAdd(Base, c);
+      args.emplace_back(matIdx);
+    }
+  } else {
+    for (unsigned i = 0; i < col; i++) {
+      Value *cMulRow = ConstantInt::get(Idx->getType(), i * row);
+      // c * row + r
+      Value *matIdx = CGF.Builder.CreateAdd(cMulRow, Idx);
+      args.emplace_back(matIdx);
+    }
   }
-  return EmitHLSLMatrixOperationCallImp(CGF.Builder, HLOpcodeGroup::HLSubscript,
-                                        opcode, RetType, {Ptr, Idx}, TheModule);
+
+  Value *matSub =
+      EmitHLSLMatrixOperationCallImp(CGF.Builder, HLOpcodeGroup::HLSubscript,
+                                     opcode, RetType, args, TheModule);
+  return matSub;
 }
 
 Value *CGMSHLSLRuntime::EmitHLSLMatrixElement(CodeGenFunction &CGF,
                                               llvm::Type *RetType,
                                               ArrayRef<Value *> paramList,
                                               QualType Ty) {
+  bool isRowMajor =
+      IsRowMajorMatrix(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor);
   unsigned opcode =
-      IsRowMajorMatrix(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor)
-          ? static_cast<unsigned>(HLSubscriptOpcode::RowMatElement)
-          : static_cast<unsigned>(HLSubscriptOpcode::ColMatElement);
+      isRowMajor ? static_cast<unsigned>(HLSubscriptOpcode::RowMatElement)
+                 : static_cast<unsigned>(HLSubscriptOpcode::ColMatElement);
 
   Value *matBase = paramList[0];
-  if (matBase->getType()->isPointerTy()) {
-    RetType =
-        llvm::PointerType::get(RetType->getPointerElementType(),
-                               matBase->getType()->getPointerAddressSpace());
+  DXASSERT(matBase->getType()->isPointerTy(),
+           "matrix element should return pointer");
+
+  RetType =
+      llvm::PointerType::get(RetType->getPointerElementType(),
+                             matBase->getType()->getPointerAddressSpace());
+
+  Value *idx = paramList[HLOperandIndex::kMatSubscriptSubOpIdx-1];
+
+  // Lower _m00 into real idx.
+
+  // -1 to avoid opcode param which is added in EmitHLSLMatrixOperationCallImp.
+  Value *args[] = {paramList[HLOperandIndex::kMatSubscriptMatOpIdx - 1],
+                   paramList[HLOperandIndex::kMatSubscriptSubOpIdx - 1]};
+  // For all zero idx. Still all zero idx.
+  if (ConstantAggregateZero *zeros = dyn_cast<ConstantAggregateZero>(idx)) {
+    Constant *zero = zeros->getAggregateElement((unsigned)0);
+    std::vector<Constant *> elts(zeros->getNumElements() >> 1, zero);
+    args[HLOperandIndex::kMatSubscriptSubOpIdx - 1] = ConstantVector::get(elts);
+  } else {
+    ConstantDataSequential *elts = cast<ConstantDataSequential>(idx);
+    unsigned count = elts->getNumElements();
+    unsigned row, col;
+    hlsl::GetHLSLMatRowColCount(Ty, row, col);
+    std::vector<Constant *> idxs(count >> 1);
+    for (unsigned i = 0; i < count; i += 2) {
+      unsigned rowIdx = elts->getElementAsInteger(i);
+      unsigned colIdx = elts->getElementAsInteger(i + 1);
+      unsigned matIdx = 0;
+      if (isRowMajor) {
+        matIdx = rowIdx * col + colIdx;
+      } else {
+        matIdx = colIdx * row + rowIdx;
+      }
+      idxs[i >> 1] = CGF.Builder.getInt32(matIdx);
+    }
+    args[HLOperandIndex::kMatSubscriptSubOpIdx - 1] = ConstantVector::get(idxs);
   }
 
   return EmitHLSLMatrixOperationCallImp(CGF.Builder, HLOpcodeGroup::HLSubscript,
-                                        opcode, RetType, paramList, TheModule);
+                                        opcode, RetType, args, TheModule);
 }
 
 Value *CGMSHLSLRuntime::EmitHLSLMatrixLoad(CGBuilderTy &Builder, Value *Ptr,
                                            QualType Ty) {
+  bool isRowMajor =
+      IsRowMajorMatrix(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor);
   unsigned opcode =
-      IsRowMajorMatrix(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor)
+      isRowMajor
           ? static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatLoad)
           : static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad);
 
-  return EmitHLSLMatrixOperationCallImp(
+  Value *matVal = EmitHLSLMatrixOperationCallImp(
       Builder, HLOpcodeGroup::HLMatLoadStore, opcode,
       Ptr->getType()->getPointerElementType(), {Ptr}, TheModule);
+  if (!isRowMajor) {
+    // ColMatLoad will return a col major matrix.
+    // All matrix Value should be row major.
+    // Cast it to row major.
+    matVal = EmitHLSLMatrixOperationCallImp(
+        Builder, HLOpcodeGroup::HLCast,
+        static_cast<unsigned>(HLCastOpcode::ColMatrixToRowMatrix),
+        matVal->getType(), {matVal}, TheModule);
+  }
+  return matVal;
 }
 void CGMSHLSLRuntime::EmitHLSLMatrixStore(CGBuilderTy &Builder, Value *Val,
                                           Value *DestPtr, QualType Ty) {
+  bool isRowMajor =
+      IsRowMajorMatrix(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor);
   unsigned opcode =
-      IsRowMajorMatrix(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor)
+      isRowMajor
           ? static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatStore)
           : static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore);
 
+  if (!isRowMajor) {
+    // All matrix Value should be row major.
+    // ColMatStore need a col major value.
+    // Cast it to row major.
+    Val = EmitHLSLMatrixOperationCallImp(
+        Builder, HLOpcodeGroup::HLCast,
+        static_cast<unsigned>(HLCastOpcode::RowMatrixToColMatrix),
+        Val->getType(), {Val}, TheModule);
+  }
+
   EmitHLSLMatrixOperationCallImp(Builder, HLOpcodeGroup::HLMatLoadStore, opcode,
                                  Val->getType(), {DestPtr, Val}, TheModule);
 }

+ 18 - 0
tools/clang/test/CodeGenHLSL/matrixIn1.hlsl

@@ -0,0 +1,18 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0
+// CHECK: dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1
+// CHECK: dx.op.loadInput.f32(i32 4, i32 0, i32 1, i8 0
+// CHECK: dx.op.loadInput.f32(i32 4, i32 0, i32 1, i8 1
+
+// fxc will generate v0.x v1.x v0.y v1.y
+
+// CHECK: float %0)
+// CHECK: float %2)
+// CHECK: float %1)
+// CHECK: float %3)
+
+float4 main(float2x2 m : M) : SV_Target
+{
+  return m;
+}

+ 18 - 0
tools/clang/test/CodeGenHLSL/matrixIn2.hlsl

@@ -0,0 +1,18 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0
+// CHECK: dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1
+// CHECK: dx.op.loadInput.f32(i32 4, i32 0, i32 1, i8 0
+// CHECK: dx.op.loadInput.f32(i32 4, i32 0, i32 1, i8 1
+
+// fxc will generate v0.xy v1.xy
+
+// CHECK: float %0)
+// CHECK: float %1)
+// CHECK: float %2)
+// CHECK: float %3)
+
+float4 main(row_major float2x2 m : M) : SV_Target
+{
+  return m;
+}

+ 15 - 0
tools/clang/test/CodeGenHLSL/matrixOut1.hlsl

@@ -0,0 +1,15 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// fxc o1.xy = 5, 7 o2.xy = 6, 8
+
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float 5.000000e+00)
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float 7.000000e+00)
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 0, i32 1, i8 0, float 6.000000e+00)
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 0, i32 1, i8 1, float 8.000000e+00)
+
+float4 main(out float2x2 a : A, int4 b : B) : SV_Position
+{
+  float2x2 m = { 5, 6, 7, 8};
+  a = m;
+  return b;
+}

+ 15 - 0
tools/clang/test/CodeGenHLSL/matrixOut2.hlsl

@@ -0,0 +1,15 @@
+// RUN: %dxc -E main -T vs_6_0 %s
+
+// fxc o1.xy = 5, 6 o2.xy = 7, 8
+
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float 5.000000e+00)
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float 6.000000e+00)
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 0, i32 1, i8 0, float 7.000000e+00)
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 0, i32 1, i8 1, float 8.000000e+00)
+
+float4 main(out row_major float2x2 a : A, int4 b : B) : SV_Position
+{
+  float2x2 m = { 5, 6, 7, 8};
+  a = m;
+  return b;
+}

+ 1 - 2
tools/clang/test/CodeGenHLSL/staticGlobals.hlsl

@@ -9,8 +9,7 @@
 // CHECK: [16 x float] [float 1.500000e+01, float 1.500000e+01, float 1.500000e+01, float 1.500000e+01, float 1.600000e+01, float 1.600000e+01, float 1.600000e+01, float 1.600000e+01, float 1.700000e+01, float 1.700000e+01, float 1.700000e+01, float 1.700000e+01, float 1.800000e+01, float 1.800000e+01, float 1.800000e+01, float 1.800000e+01]
 // CHECK: [16 x float] [float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00]
 // CHECK: [4 x float] [float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00]
-// CHECK: [16 x float] [float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01]
-
+// CHECK: [16 x float] [float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01, float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01, float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01, float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01]
 
 static float4 f0 = {5,6,7,8};
 static float4 f1 = 0;

+ 5 - 2
tools/clang/test/CodeGenHLSL/staticGlobals3.hlsl

@@ -10,9 +10,12 @@
 // t3.c.y
 // CHECK: [3 x i32] [i32 0, i32 28, i32 0]
 // t3.a
-// CHECK: [12 x float] [float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00]
+
+// CHECK: [12 x float] [float 5.000000e+00, float 7.000000e+00, float 6.000000e+00, float 8.000000e+00, float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00]
+
 // t3.t
-// CHECK: [24 x float] [float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00]
+
+// CHECK: [24 x float] [float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01, float 2.500000e+01, float 2.700000e+01, float 2.600000e+01, float 2.800000e+01, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 5.000000e+00, float 7.000000e+00, float 6.000000e+00, float 8.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 5.000000e+00, float 7.000000e+00, float 6.000000e+00, float 8.000000e+00]
 
 
 

+ 1 - 1
tools/clang/test/CodeGenHLSL/swizzleAtomic.hlsl

@@ -15,7 +15,7 @@
 // CHECK: i32 12,
 
 
-groupshared column_major uint2x2 dataC[8*8];
+groupshared row_major uint2x2 dataC[8*8];
 groupshared uint4 a;
 
 RWStructuredBuffer<uint2x2> mats;

+ 20 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -429,7 +429,11 @@ public:
   TEST_METHOD(CodeGenMatInStruct)
   TEST_METHOD(CodeGenMatInStructRet)
   TEST_METHOD(CodeGenMatIn)
+  TEST_METHOD(CodeGenMatIn1)
+  TEST_METHOD(CodeGenMatIn2)
   TEST_METHOD(CodeGenMatOut)
+  TEST_METHOD(CodeGenMatOut1)
+  TEST_METHOD(CodeGenMatOut2)
   TEST_METHOD(CodeGenMatSubscript)
   TEST_METHOD(CodeGenMatSubscript2)
   TEST_METHOD(CodeGenMatSubscript3)
@@ -2532,10 +2536,26 @@ TEST_F(CompilerTest, CodeGenMatIn) {
   CodeGenTest(L"..\\CodeGenHLSL\\matrixIn.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenMatIn1) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\matrixIn1.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenMatIn2) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\matrixIn2.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenMatOut) {
   CodeGenTest(L"..\\CodeGenHLSL\\matrixOut.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenMatOut1) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\matrixOut1.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenMatOut2) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\matrixOut2.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenMatSubscript) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\matSubscript.hlsl");
 }

+ 24 - 24
tools/clang/unittests/HLSL/ValidationTest.cpp

@@ -1407,52 +1407,52 @@ TEST_F(ValidationTest, LocalRes6Dbg) {
 
 TEST_F(ValidationTest, AddrSpaceCast) {
   RewriteAssemblyCheckMsg(L"..\\CodeGenHLSL\\staticGlobals.hlsl", "ps_6_0",
-                          "%([0-9]+) = getelementptr \\[4 x float\\], \\[4 x float\\]\\* %([0-9]+), i32 0, i32 0\n"
-                          "  store float %([0-9]+), float\\* %\\1, align 4",
-                          "%\\1 = getelementptr [4 x float], [4 x float]* %\\2, i32 0, i32 0\n"
-                          "  %X = addrspacecast float* %\\1 to float addrspace(1)*    \n"
-                          "  store float %\\3, float addrspace(1)* %X, align 4",
+                          "%([0-9]+) = getelementptr \\[4 x i32\\], \\[4 x i32\\]\\* %([0-9]+), i32 0, i32 0\n"
+                          "  store i32 %([0-9]+), i32\\* %\\1, align 4",
+                          "%\\1 = getelementptr [4 x i32], [4 x i32]* %\\2, i32 0, i32 0\n"
+                          "  %X = addrspacecast i32* %\\1 to i32 addrspace(1)*    \n"
+                          "  store i32 %\\3, i32 addrspace(1)* %X, align 4",
                           "generic address space",
                           /*bRegex*/true);
 }
 
 TEST_F(ValidationTest, PtrBitCast) {
   RewriteAssemblyCheckMsg(L"..\\CodeGenHLSL\\staticGlobals.hlsl", "ps_6_0",
-                          "%([0-9]+) = getelementptr \\[4 x float\\], \\[4 x float\\]\\* %([0-9]+), i32 0, i32 0\n"
-                          "  store float %([0-9]+), float\\* %\\1, align 4",
-                          "%\\1 = getelementptr [4 x float], [4 x float]* %\\2, i32 0, i32 0\n"
-                          "  %X = bitcast float* %\\1 to double*    \n"
-                          "  store float %\\3, float* %\\1, align 4",
+                          "%([0-9]+) = getelementptr \\[4 x i32\\], \\[4 x i32\\]\\* %([0-9]+), i32 0, i32 0\n"
+                          "  store i32 %([0-9]+), i32\\* %\\1, align 4",
+                          "%\\1 = getelementptr [4 x i32], [4 x i32]* %\\2, i32 0, i32 0\n"
+                          "  %X = bitcast i32* %\\1 to double*    \n"
+                          "  store i32 %\\3, i32* %\\1, align 4",
                           "Pointer type bitcast must be have same size",
                           /*bRegex*/true);
 }
 
 TEST_F(ValidationTest, MinPrecisionBitCast) {
   RewriteAssemblyCheckMsg(L"..\\CodeGenHLSL\\staticGlobals.hlsl", "ps_6_0",
-                          "%([0-9]+) = getelementptr \\[4 x float\\], \\[4 x float\\]\\* %([0-9]+), i32 0, i32 0\n"
-                          "  store float %([0-9]+), float\\* %\\1, align 4",
-                          "%\\1 = getelementptr [4 x float], [4 x float]* %\\2, i32 0, i32 0\n"
-                          "  %X = bitcast float* %\\1 to [2 x half]*    \n"
-                          "  store float %\\3, float* %\\1, align 4",
+                          "%([0-9]+) = getelementptr \\[4 x i32\\], \\[4 x i32\\]\\* %([0-9]+), i32 0, i32 0\n"
+                          "  store i32 %([0-9]+), i32\\* %\\1, align 4",
+                          "%\\1 = getelementptr [4 x i32], [4 x i32]* %\\2, i32 0, i32 0\n"
+                          "  %X = bitcast i32* %\\1 to [2 x half]*    \n"
+                          "  store i32 %\\3, i32* %\\1, align 4",
                           "Bitcast on minprecison types is not allowed",
                           /*bRegex*/true);
 }
 
 TEST_F(ValidationTest, StructBitCast) {
   RewriteAssemblyCheckMsg(L"..\\CodeGenHLSL\\staticGlobals.hlsl", "ps_6_0",
-                          "%([0-9]+) = getelementptr \\[4 x float\\], \\[4 x float\\]\\* %([0-9]+), i32 0, i32 0\n"
-                          "  store float %([0-9]+), float\\* %\\1, align 4",
-                          "%\\1 = getelementptr [4 x float], [4 x float]* %\\2, i32 0, i32 0\n"
-                          "  %X = bitcast float* %\\1 to %dx.types.Handle*    \n"
-                          "  store float %\\3, float* %\\1, align 4",
+                          "%([0-9]+) = getelementptr \\[4 x i32\\], \\[4 x i32\\]\\* %([0-9]+), i32 0, i32 0\n"
+                          "  store i32 %([0-9]+), i32\\* %\\1, align 4",
+                          "%\\1 = getelementptr [4 x i32], [4 x i32]* %\\2, i32 0, i32 0\n"
+                          "  %X = bitcast i32* %\\1 to %dx.types.Handle*    \n"
+                          "  store i32 %\\3, i32* %\\1, align 4",
                           "Bitcast on struct types is not allowed",
                           /*bRegex*/true);
 }
 
 TEST_F(ValidationTest, MultiDimArray) {
   RewriteAssemblyCheckMsg(L"..\\CodeGenHLSL\\staticGlobals.hlsl", "ps_6_0",
-                          "= alloca [4 x float]",
-                          "= alloca [4 x float]\n"
+                          "= alloca [4 x i32]",
+                          "= alloca [4 x i32]\n"
                           "  %md = alloca [2 x [4 x float]]",
                           "Only one dimension allowed for array type");
 }
@@ -1487,8 +1487,8 @@ TEST_F(ValidationTest, NoFunctionParam) {
 
 TEST_F(ValidationTest, I8Type) {
   RewriteAssemblyCheckMsg(L"..\\CodeGenHLSL\\staticGlobals.hlsl", "ps_6_0",
-                          "%([0-9]+) = alloca \\[4 x float\\]",
-                          "%\\1 = alloca [4 x float]\n"
+                          "%([0-9]+) = alloca \\[4 x i32\\]",
+                          "%\\1 = alloca [4 x i32]\n"
                           "  %m8 = alloca i8",
                           "I8 can only used as immediate value for intrinsic",
     /*bRegex*/true);