فهرست منبع

Enable sroa for hlsl to remove static indexing array. (#1893)

Xiang Li 6 سال پیش
والد
کامیت
c2e744dcf4

+ 1 - 0
include/dxc/DXIL/DxilUtil.h

@@ -105,6 +105,7 @@ namespace dxilutil {
   // Returns true if type contains HLSL Object type (resource)
   bool ContainsHLSLObjectType(llvm::Type *Ty);
   bool IsHLSLObjectType(llvm::Type *Ty);
+  bool IsHLSLMatrixType(llvm::Type *Ty);
   bool IsSplat(llvm::ConstantDataVector *cdv);
 }
 

+ 0 - 1
include/dxc/HLSL/HLMatrixLowerHelper.h

@@ -27,7 +27,6 @@ class DxilTypeSystem;
 
 namespace HLMatrixLower {
 // TODO: use type annotation.
-bool IsMatrixType(llvm::Type *Ty);
 DxilFieldAnnotation *FindAnnotationFromMatUser(llvm::Value *Mat,
                                                DxilTypeSystem &typeSys);
 // Translate matrix type to vector type.

+ 2 - 1
include/llvm/Transforms/Scalar.h

@@ -94,7 +94,8 @@ FunctionPass *createBitTrackingDCEPass();
 //
 // SROA - Replace aggregates or pieces of aggregates with scalar SSA values.
 //
-FunctionPass *createSROAPass(bool RequiresDomTree = true);
+FunctionPass *createSROAPass(bool RequiresDomTree = true,
+                             bool SkipHLSLMat = true);
 
 //===----------------------------------------------------------------------===//
 //

+ 14 - 0
lib/DXIL/DxilUtil.cpp

@@ -463,6 +463,20 @@ bool IsHLSLObjectType(llvm::Type *Ty) {
   return false;
 }
 
+bool IsHLSLMatrixType(Type *Ty) {
+  if (StructType *ST = dyn_cast<StructType>(Ty)) {
+    Type *EltTy = ST->getElementType(0);
+    if (!ST->getName().startswith("class.matrix"))
+      return false;
+
+    bool isVecArray =
+        EltTy->isArrayTy() && EltTy->getArrayElementType()->isVectorTy();
+
+    return isVecArray && EltTy->getArrayNumElements() <= 4;
+  }
+  return false;
+}
+
 bool ContainsHLSLObjectType(llvm::Type *Ty) {
   // Unwrap pointer/array
   while (llvm::isa<llvm::PointerType>(Ty))

+ 3 - 2
lib/HLSL/DxcOptimizer.cpp

@@ -204,7 +204,7 @@ static ArrayRef<LPCSTR> GetPassArgNames(LPCSTR passName) {
   static const LPCSTR LowerExpectIntrinsicArgs[] = { "likely-branch-weight", "unlikely-branch-weight" };
   static const LPCSTR MergeFunctionsArgs[] = { "mergefunc-sanity" };
   static const LPCSTR RewriteSymbolsArgs[] = { "DL", "rewrite-map-file" };
-  static const LPCSTR SROAArgs[] = { "RequiresDomTree", "force-ssa-updater", "sroa-random-shuffle-slices", "sroa-strict-inbounds" };
+  static const LPCSTR SROAArgs[] = { "RequiresDomTree", "SkipHLSLMat", "force-ssa-updater", "sroa-random-shuffle-slices", "sroa-strict-inbounds" };
   static const LPCSTR SROA_DTArgs[] = { "Threshold", "StructMemberThreshold", "ArrayElementThreshold", "ScalarLoadThreshold" };
   static const LPCSTR SROA_SSAUpArgs[] = { "Threshold", "StructMemberThreshold", "ArrayElementThreshold", "ScalarLoadThreshold" };
   static const LPCSTR SampleProfileLoaderArgs[] = { "sample-profile-file", "sample-profile-max-propagate-iterations" };
@@ -277,7 +277,7 @@ static ArrayRef<LPCSTR> GetPassArgDescriptions(LPCSTR passName) {
   static const LPCSTR LowerExpectIntrinsicArgs[] = { "Weight of the branch likely to be taken (default = 64)", "Weight of the branch unlikely to be taken (default = 4)" };
   static const LPCSTR MergeFunctionsArgs[] = { "How many functions in module could be used for MergeFunctions pass sanity check. '0' disables this check. Works only with '-debug' key." };
   static const LPCSTR RewriteSymbolsArgs[] = { "None", "None" };
-  static const LPCSTR SROAArgs[] = { "None", "Force the pass to not use DomTree and mem2reg, insteadforming SSA values through the SSAUpdater infrastructure.", "Enable randomly shuffling the slices to help uncover instability in their order.", "Experiment with completely strict handling of inbounds GEPs." };
+  static const LPCSTR SROAArgs[] = { "None", "None", "Force the pass to not use DomTree and mem2reg, insteadforming SSA values through the SSAUpdater infrastructure.", "Enable randomly shuffling the slices to help uncover instability in their order.", "Experiment with completely strict handling of inbounds GEPs." };
   static const LPCSTR SROA_DTArgs[] = { "None", "None", "None", "None" };
   static const LPCSTR SROA_SSAUpArgs[] = { "None", "None", "None", "None" };
   static const LPCSTR SampleProfileLoaderArgs[] = { "None", "None" };
@@ -342,6 +342,7 @@ static bool IsPassOptionName(StringRef S) {
     ||  S.equals("RequiresDomTree")
     ||  S.equals("Runtime")
     ||  S.equals("ScalarLoadThreshold")
+    ||  S.equals("SkipHLSLMat")
     ||  S.equals("StructMemberThreshold")
     ||  S.equals("TIRA")
     ||  S.equals("TLIImpl")

+ 1 - 1
lib/HLSL/DxilCondenseResources.cpp

@@ -1534,7 +1534,7 @@ Type *UpdateFieldTypeForLegacyLayout(Type *Ty, bool IsCBuf,
       return Ty;
     else
       return ArrayType::get(UpdatedTy, Ty->getArrayNumElements());
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     DXASSERT(annotation.HasMatrixAnnotation(), "must a matrix");
     unsigned rows, cols;
     Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, cols, rows);

+ 1 - 1
lib/HLSL/DxilContainerReflection.cpp

@@ -765,7 +765,7 @@ HRESULT CShaderReflectionType::Initialize(
     llvm::Type* elementType = type->getArrayElementType();
 
     // Note: At this point an HLSL matrix type may appear as an ordinary
-    // array (not wrapped in a `struct`), so `HLMatrixLower::IsMatrixType()`
+    // array (not wrapped in a `struct`), so `dxilutil::IsHLSLMatrixType()`
     // is not sufficient. Instead we need to check the field annotation.
     //
     // We might have an array of matrices, though, so we only exit if

+ 1 - 1
lib/HLSL/DxilLinker.cpp

@@ -1027,7 +1027,7 @@ void DxilLinkJob::RunPreparePass(Module &M) {
   PM.add(createDxilDeadFunctionEliminationPass());
 
   // SROA
-  PM.add(createSROAPass(/*RequiresDomTree*/false));
+  PM.add(createSROAPass(/*RequiresDomTree*/false, /*SkipHLSLMat*/false));
 
   // Remove MultiDimArray from function call arg.
   PM.add(createMultiDimArrayToOneDimArrayPass());

+ 28 - 39
lib/HLSL/HLMatrixLowerPass.cpp

@@ -36,20 +36,6 @@ using namespace hlsl::HLMatrixLower;
 namespace hlsl {
 namespace HLMatrixLower {
 
-bool IsMatrixType(Type *Ty) {
-  if (StructType *ST = dyn_cast<StructType>(Ty)) {
-    Type *EltTy = ST->getElementType(0);
-    if (!ST->getName().startswith("class.matrix"))
-      return false;
-
-    bool isVecArray = EltTy->isArrayTy() &&
-           EltTy->getArrayElementType()->isVectorTy();
-
-    return isVecArray && EltTy->getArrayNumElements() <= 4;
-  }
-  return false;
-}
-
 // If user is function call, return param annotation to get matrix major.
 DxilFieldAnnotation *FindAnnotationFromMatUser(Value *Mat,
                                                DxilTypeSystem &typeSys) {
@@ -81,7 +67,7 @@ Type *LowerMatrixType(Type *Ty, bool forMem) {
       params.emplace_back(LowerMatrixType(param));
     }
     return FunctionType::get(RetTy, params, false);
-  } else if (IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     unsigned row, col;
     Type *EltTy = GetMatrixInfo(Ty, col, row);
     if (forMem && EltTy->isIntegerTy(1))
@@ -94,7 +80,7 @@ Type *LowerMatrixType(Type *Ty, bool forMem) {
 
 // Translate matrix type to array type.
 Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
-  if (IsMatrixType(Ty)) {
+  if (dxilutil::IsHLSLMatrixType(Ty)) {
     unsigned row, col;
     Type *EltTy = GetMatrixInfo(Ty, col, row);
     return ArrayType::get(EltTy, row * col);
@@ -105,7 +91,7 @@ Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
 
 
 Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
-  DXASSERT(IsMatrixType(Ty), "not matrix type");
+  DXASSERT(dxilutil::IsHLSLMatrixType(Ty), "not matrix type");
   StructType *ST = cast<StructType>(Ty);
   Type *EltTy = ST->getElementType(0);
   Type *RowTy = EltTy->getArrayElementType();
@@ -122,7 +108,7 @@ bool IsMatrixArrayPointer(llvm::Type *Ty) {
     return false;
   while (Ty->isArrayTy())
     Ty = Ty->getArrayElementType();
-  return IsMatrixType(Ty);
+  return dxilutil::IsHLSLMatrixType(Ty);
 }
 Type *LowerMatrixArrayPointer(Type *Ty, bool forMem) {
   unsigned addrSpace = Ty->getPointerAddressSpace();
@@ -441,8 +427,8 @@ Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
   Value *op = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
   HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(CI));
 
-  bool ToMat = IsMatrixType(CI->getType());
-  bool FromMat = IsMatrixType(op->getType());
+  bool ToMat = dxilutil::IsHLSLMatrixType(CI->getType());
+  bool FromMat = dxilutil::IsHLSLMatrixType(op->getType());
   if (ToMat && !FromMat) {
     // Translate OtherToMat here.
     // Rest will translated when replace.
@@ -514,11 +500,11 @@ Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
 // UDT alloca must be there for library function args
 static GetElementPtrInst *GetIfMatrixGEPOfUDTAlloca(Value *V) {
   if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
-    if (IsMatrixType(GEP->getResultElementType())) {
+    if (dxilutil::IsHLSLMatrixType(GEP->getResultElementType())) {
       Value *ptr = GEP->getPointerOperand();
       if (AllocaInst *AI = dyn_cast<AllocaInst>(ptr)) {
         Type *ATy = AI->getAllocatedType();
-        if (ATy->isStructTy() && !IsMatrixType(ATy)) {
+        if (ATy->isStructTy() && !dxilutil::IsHLSLMatrixType(ATy)) {
           return GEP;
         }
       }
@@ -531,7 +517,7 @@ static GetElementPtrInst *GetIfMatrixGEPOfUDTAlloca(Value *V) {
 // none-graphics functions.
 static GetElementPtrInst *GetIfMatrixGEPOfUDTArg(Value *V, HLModule &HM) {
   if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
-    if (IsMatrixType(GEP->getResultElementType())) {
+    if (dxilutil::IsHLSLMatrixType(GEP->getResultElementType())) {
       Value *ptr = GEP->getPointerOperand();
       if (Argument *Arg = dyn_cast<Argument>(ptr)) {
         if (!HM.IsGraphicsShader(Arg->getParent()))
@@ -654,7 +640,7 @@ Instruction *HLMatrixLowerPass::MatIntrinsicToVec(CallInst *CI) {
   SmallVector<Value *, 4> argList;
   for (Value *arg : CI->arg_operands()) {
     Type *Ty = arg->getType();
-    if (IsMatrixType(Ty)) {
+    if (dxilutil::IsHLSLMatrixType(Ty)) {
       argList.emplace_back(UndefValue::get(LowerMatrixType(Ty)));
     } else
       argList.emplace_back(arg);
@@ -779,12 +765,14 @@ Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
     break;
   case HLBinaryOpcode::Shl: {
     Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-    DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
+    DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
+                      "must be matrix type here");
     Result = BinaryOperator::CreateShl(tmp, tmp);
   } break;
   case HLBinaryOpcode::Shr: {
     Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-    DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
+    DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
+                      "must be matrix type here");
     Result = BinaryOperator::CreateAShr(tmp, tmp);
   } break;
   case HLBinaryOpcode::LT:
@@ -831,7 +819,8 @@ Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
     break;
   case HLBinaryOpcode::UShr: {
     Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-    DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
+    DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
+                      "must be matrix type here");
     Result = BinaryOperator::CreateLShr(tmp, tmp);
   } break;
   case HLBinaryOpcode::ULT:
@@ -1229,8 +1218,8 @@ void HLMatrixLowerPass::TranslateMul(Value *matVal, Value *vecVal,
   Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
   Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
 
-  bool LMat = IsMatrixType(LVal->getType());
-  bool RMat = IsMatrixType(RVal->getType());
+  bool LMat = dxilutil::IsHLSLMatrixType(LVal->getType());
+  bool RMat = dxilutil::IsHLSLMatrixType(RVal->getType());
   if (LMat && RMat) {
     TranslateMatMatMul(matVal, vecVal, mulInst, isSigned);
   } else if (LMat) {
@@ -1501,8 +1490,8 @@ void HLMatrixLowerPass::TranslateMatCast(Value *matVal,
                           opcode == HLCastOpcode::RowMatrixToColMatrix,
                           /*bTranspose*/false);
   } else {
-    bool ToMat = IsMatrixType(castInst->getType());
-    bool FromMat = IsMatrixType(matVal->getType());
+    bool ToMat = dxilutil::IsHLSLMatrixType(castInst->getType());
+    bool FromMat = dxilutil::IsHLSLMatrixType(matVal->getType());
     if (ToMat && FromMat) {
       TranslateMatMatCast(matVal, vecVal, castInst);
     } else if (FromMat)
@@ -1946,7 +1935,7 @@ static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
       }
     }
     Type *valEltTy = val->getType()->getPointerElementType();
-    if (valEltTy->isVectorTy() || HLMatrixLower::IsMatrixType(valEltTy) ||
+    if (valEltTy->isVectorTy() || dxilutil::IsHLSLMatrixType(valEltTy) ||
         valEltTy->isSingleValueType()) {
       Value *ldVal = Builder.CreateLoad(val);
       IterateInitList(elts, idx, ldVal, matToVecMap, Builder);
@@ -1969,7 +1958,7 @@ static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
         }
       }
     }
-  } else if (HLMatrixLower::IsMatrixType(valTy)) {
+  } else if (dxilutil::IsHLSLMatrixType(valTy)) {
     unsigned col, row;
     HLMatrixLower::GetMatrixInfo(valTy, col, row);
     unsigned matSize = col * row;
@@ -2489,7 +2478,7 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
   }
 
   Type *Ty = GV->getType()->getPointerElementType();
-  if (!HLMatrixLower::IsMatrixType(Ty))
+  if (!dxilutil::IsHLSLMatrixType(Ty))
     return;
 
   bool onlyLdSt = OnlyUsedByMatrixLdSt(GV);
@@ -2585,11 +2574,11 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
     BasicBlock *BB = BBI;
     for (auto II = BB->begin(); II != BB->end(); ) {
       Instruction &I = *(II++);
-      if (IsMatrixType(I.getType())) {
+      if (dxilutil::IsHLSLMatrixType(I.getType())) {
         lowerToVec(&I);
       } else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
         Type *Ty = AI->getAllocatedType();
-        if (HLMatrixLower::IsMatrixType(Ty)) {
+        if (dxilutil::IsHLSLMatrixType(Ty)) {
           lowerToVec(&I);
         } else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
           lowerToVec(&I);
@@ -2665,7 +2654,7 @@ Type *TryLowerMatTy(Type *Ty) {
   if (HLMatrixLower::IsMatrixArrayPointer(Ty)) {
     VecTy = HLMatrixLower::LowerMatrixArrayPointerToOneDimArray(Ty);
   } else if (isa<PointerType>(Ty) &&
-             HLMatrixLower::IsMatrixType(Ty->getPointerElementType())) {
+             dxilutil::IsHLSLMatrixType(Ty->getPointerElementType())) {
     VecTy = HLMatrixLower::LowerMatrixTypeToOneDimArray(
         Ty->getPointerElementType());
     VecTy = PointerType::get(VecTy, Ty->getPointerAddressSpace());
@@ -2727,7 +2716,7 @@ bool MatrixBitcastLowerPass::hasCallUser(Instruction *M) {
     User *U = *(it++);
     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
       Type *EltTy = GEP->getType()->getPointerElementType();
-      if (HLMatrixLower::IsMatrixType(EltTy)) {
+      if (dxilutil::IsHLSLMatrixType(EltTy)) {
         if (hasCallUser(GEP))
           return true;
       } else {
@@ -2782,7 +2771,7 @@ void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
     User *U = *(it++);
     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
       Type *EltTy = GEP->getType()->getPointerElementType();
-      if (HLMatrixLower::IsMatrixType(EltTy)) {
+      if (dxilutil::IsHLSLMatrixType(EltTy)) {
         // Change gep matrixArray, 0, index
         // into
         //   gep oneDimArray, 0, index * matSize

+ 3 - 3
lib/HLSL/HLSignatureLower.cpp

@@ -624,14 +624,14 @@ void replaceDirectInputParameter(Value *param, Function *loadInput,
       newVec = Builder.CreateInsertElement(newVec, input, col);
     }
     param->replaceAllUsesWith(newVec);
-  } else if (!Ty->isArrayTy() && !HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (!Ty->isArrayTy() && !dxilutil::IsHLSLMatrixType(Ty)) {
     DXASSERT(cols == 1, "only support scalar here");
     Value *colIdx = hlslOP->GetU8Const(0);
     args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
     Value *input =
         GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
     param->replaceAllUsesWith(input);
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     Value *colIdx = hlslOP->GetU8Const(0);
     (void)colIdx;
     DXASSERT(param->hasOneUse(),
@@ -784,7 +784,7 @@ void collectInputOutputAccessInfo(
               vectorIdx = GEPIt.getOperand();
             }
           }
-          if (HLMatrixLower::IsMatrixType(*GEPIt)) {
+          if (dxilutil::IsHLSLMatrixType(*GEPIt)) {
             unsigned row, col;
             HLMatrixLower::GetMatrixInfo(*GEPIt, col, row);
             Constant *arraySize = ConstantInt::get(idxTy, col);

+ 1 - 4
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -373,14 +373,11 @@ void PassManagerBuilder::populateModulePassManager(
 
   // Start of function pass.
   // Break up aggregate allocas, using SSAUpdater.
-  // HLSL Change - don't run SROA. 
-  // HLSL uses special SROA added in addHLSLPasses.
-  if (HLSLHighLevel) { // HLSL Change
   if (UseNewSROA)
     MPM.add(createSROAPass(/*RequiresDomTree*/ false));
   else
     MPM.add(createScalarReplAggregatesPass(-1, false));
-  }
+
   // HLSL Change. MPM.add(createEarlyCSEPass());              // Catch trivial redundancies
   // HLSL Change. MPM.add(createJumpThreadingPass());         // Thread jumps.
   MPM.add(createCorrelatedValuePropagationPass()); // Propagate conditionals

+ 83 - 15
lib/Transforms/Scalar/SROA.cpp

@@ -222,7 +222,8 @@ namespace {
 class AllocaSlices {
 public:
   /// \brief Construct the slices of a particular alloca.
-  AllocaSlices(const DataLayout &DL, AllocaInst &AI);
+  AllocaSlices(const DataLayout &DL, AllocaInst &AI,
+               const bool SkipHLSLMat); // HLSL Change - not sroa matrix type.
 
   /// \brief Test whether a pointer to the allocation escapes our analysis.
   ///
@@ -633,6 +634,7 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
   friend class InstVisitor<SliceBuilder>;
   typedef PtrUseVisitor<SliceBuilder> Base;
 
+  const bool SkipHLSLMat; // HLSL Change - not sroa matrix type.
   const uint64_t AllocSize;
   AllocaSlices &AS;
 
@@ -643,8 +645,10 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
   SmallPtrSet<Instruction *, 4> VisitedDeadInsts;
 
 public:
-  SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &AS)
+  SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &AS,
+               const bool SkipHLSLMat)
       : PtrUseVisitor<SliceBuilder>(DL),
+        SkipHLSLMat(SkipHLSLMat), // HLSL Change - not sroa matrix type.
         AllocSize(DL.getTypeAllocSize(AI.getAllocatedType())), AS(AS) {}
 
 private:
@@ -690,7 +694,24 @@ private:
   void visitBitCastInst(BitCastInst &BC) {
     if (BC.use_empty())
       return markAsDead(BC);
-
+    // HLSL Change Begin - not sroa matrix type.
+    if (PointerType *PT = dyn_cast<PointerType>(BC.getType())) {
+      Type *EltTy = PT->getElementType();
+      if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(EltTy)) ||
+          hlsl::dxilutil::IsHLSLObjectType(EltTy)) {
+        AS.PointerEscapingInstr = &BC;
+        return;
+      }
+      if (PointerType *SrcPT = dyn_cast<PointerType>(BC.getSrcTy())) {
+        Type *SrcEltTy = SrcPT->getElementType();
+        if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(SrcEltTy)) ||
+            hlsl::dxilutil::IsHLSLObjectType(SrcEltTy)) {
+          AS.PointerEscapingInstr = &BC;
+          return;
+        }
+      }
+    }
+    // HLSL Change End.
     return Base::visitBitCastInst(BC);
   }
 
@@ -751,9 +772,15 @@ private:
   }
 
   void visitLoadInst(LoadInst &LI) {
+    // HLSL Change Begin - not sroa matrix type.
+    if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(LI.getType())) ||
+        hlsl::dxilutil::IsHLSLObjectType(LI.getType()))
+      return PI.setEscapedAndAborted(&LI);
+    // HLSL Change End.
     assert((!LI.isSimple() || LI.getType()->isSingleValueType()) &&
            "All simple FCA loads should have been pre-split");
 
+
     if (!IsOffsetKnown)
       return PI.setAborted(&LI);
 
@@ -766,6 +793,12 @@ private:
     Value *ValOp = SI.getValueOperand();
     if (ValOp == *U)
       return PI.setEscapedAndAborted(&SI);
+    // HLSL Change Begin - not sroa matrix type.
+    if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(ValOp->getType())) ||
+        hlsl::dxilutil::IsHLSLObjectType(ValOp->getType()))
+      return PI.setEscapedAndAborted(&SI);
+    // HLSL Change End.
+
     if (!IsOffsetKnown)
       return PI.setAborted(&SI);
 
@@ -1002,13 +1035,15 @@ private:
   void visitInstruction(Instruction &I) { PI.setAborted(&I); }
 };
 
-AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI)
+AllocaSlices::AllocaSlices(
+    const DataLayout &DL, AllocaInst &AI,
+    const bool SkipHLSLMat) // HLSL Change - not sroa matrix type.
     :
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
       AI(AI),
 #endif
       PointerEscapingInstr(nullptr) {
-  SliceBuilder PB(DL, AI, *this);
+  SliceBuilder PB(DL, AI, *this, SkipHLSLMat);
   SliceBuilder::PtrInfo PtrI = PB.visitPtr(AI);
   if (PtrI.isEscaped() || PtrI.isAborted()) {
     // FIXME: We should sink the escape vs. abort info into the caller nicely,
@@ -1204,6 +1239,7 @@ namespace {
 ///    SSA vector values.
 class SROA : public FunctionPass {
   const bool RequiresDomTree;
+  const bool SkipHLSLMat; // HLSL Change - not sroa matrix type.
 
   LLVMContext *C;
   DominatorTree *DT;
@@ -1252,9 +1288,10 @@ class SROA : public FunctionPass {
   SetVector<SelectInst *, SmallVector<SelectInst *, 2>> SpeculatableSelects;
 
 public:
-  SROA(bool RequiresDomTree = true)
-      : FunctionPass(ID), RequiresDomTree(RequiresDomTree), C(nullptr),
-        DT(nullptr) {
+  SROA(bool RequiresDomTree = true, bool SkipHLSLMat = true)
+      : FunctionPass(ID), RequiresDomTree(RequiresDomTree),
+        SkipHLSLMat(SkipHLSLMat), // HLSL Change - not sroa matrix type.
+        C(nullptr), DT(nullptr) {
     initializeSROAPass(*PassRegistry::getPassRegistry());
   }
   bool runOnFunction(Function &F) override;
@@ -1280,8 +1317,8 @@ private:
 
 char SROA::ID = 0;
 
-FunctionPass *llvm::createSROAPass(bool RequiresDomTree) {
-  return new SROA(RequiresDomTree);
+FunctionPass *llvm::createSROAPass(bool RequiresDomTree, bool SkipHLSLMat) {
+  return new SROA(RequiresDomTree, SkipHLSLMat);
 }
 
 INITIALIZE_PASS_BEGIN(SROA, "sroa", "Scalar Replacement Of Aggregates", false,
@@ -3191,6 +3228,7 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
   friend class llvm::InstVisitor<AggLoadStoreRewriter, bool>;
 
   const DataLayout &DL;
+  const bool SkipHLSLMat; // HLSL Change - not sroa matrix type.
 
   /// Queue of pointer uses to analyze and potentially rewrite.
   SmallVector<Use *, 8> Queue;
@@ -3203,7 +3241,9 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
   Use *U;
 
 public:
-  AggLoadStoreRewriter(const DataLayout &DL) : DL(DL) {}
+  AggLoadStoreRewriter(const DataLayout &DL, const bool SkipHLSLMat)
+      // HLSL Change - not sroa matrix type.
+      : DL(DL), SkipHLSLMat(SkipHLSLMat) {}
 
   /// Rewrite loads and stores through a pointer and all pointers derived from
   /// it.
@@ -3323,6 +3363,11 @@ private:
     assert(LI.getPointerOperand() == *U);
     if (!LI.isSimple() || LI.getType()->isSingleValueType())
       return false;
+    // HLSL Change Begin - not sroa matrix type.
+    if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(LI.getType())) ||
+        hlsl::dxilutil::IsHLSLObjectType(LI.getType()))
+      return false;
+    // HLSL Change End.
 
     // We have an aggregate being loaded, split it apart.
     DEBUG(dbgs() << "    original: " << LI << "\n");
@@ -3357,7 +3402,11 @@ private:
     Value *V = SI.getValueOperand();
     if (V->getType()->isSingleValueType())
       return false;
-
+    // HLSL Change Begin - not sroa matrix type.
+    if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(V->getType())) ||
+        hlsl::dxilutil::IsHLSLObjectType(V->getType()))
+      return false;
+    // HLSL Change End.
     // We have an aggregate being stored, split it apart.
     DEBUG(dbgs() << "    original: " << SI << "\n");
     StoreOpSplitter Splitter(&SI, *U);
@@ -3367,6 +3416,20 @@ private:
   }
 
   bool visitBitCastInst(BitCastInst &BC) {
+    // HLSL Change Begin - not sroa matrix type.
+    if (PointerType *PT = dyn_cast<PointerType>(BC.getType())) {
+      Type *EltTy = PT->getElementType();
+      if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(EltTy)) ||
+          hlsl::dxilutil::IsHLSLObjectType(EltTy))
+        return false;
+      if (PointerType *SrcPT = dyn_cast<PointerType>(BC.getSrcTy())) {
+        Type *SrcEltTy = SrcPT->getElementType();
+        if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(SrcEltTy)) ||
+            hlsl::dxilutil::IsHLSLObjectType(SrcEltTy))
+          return false;
+      }
+    }
+    // HLSL Change End.
     enqueueUsers(BC);
     return false;
   }
@@ -4310,7 +4373,12 @@ bool SROA::runOnAlloca(AllocaInst &AI) {
 
   // Skip alloca forms that this analysis can't handle.
   if (AI.isArrayAllocation() || !AI.getAllocatedType()->isSized() ||
-      hlsl::dxilutil::IsHLSLObjectType(AI.getAllocatedType()) || // HLSL Change - not sroa resource type.
+      hlsl::dxilutil::IsHLSLObjectType(
+          AI.getAllocatedType()) || // HLSL Change - not sroa resource type.
+      // HLSL Change Begin - not sroa matrix type.
+      (SkipHLSLMat &&
+       hlsl::dxilutil::IsHLSLMatrixType(AI.getAllocatedType())) ||
+      // HLSL Change End.
       DL.getTypeAllocSize(AI.getAllocatedType()) == 0)
     return false;
 
@@ -4318,11 +4386,11 @@ bool SROA::runOnAlloca(AllocaInst &AI) {
 
   // First, split any FCA loads and stores touching this alloca to promote
   // better splitting and promotion opportunities.
-  AggLoadStoreRewriter AggRewriter(DL);
+  AggLoadStoreRewriter AggRewriter(DL, SkipHLSLMat);
   Changed |= AggRewriter.rewrite(AI);
 
   // Build the slices using a recursive instruction-visiting builder.
-  AllocaSlices AS(DL, AI);
+  AllocaSlices AS(DL, AI, SkipHLSLMat);
   DEBUG(AS.print(dbgs()));
   if (AS.isEscaped())
     return Changed;

+ 26 - 26
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -1639,7 +1639,7 @@ bool SROA_HLSL::performScalarRepl(Function &F, DxilTypeSystem &typeSys) {
         Type *Ty = AI->getAllocatedType();
         // Skip empty struct parameters.
         if (StructType *ST = dyn_cast<StructType>(Ty)) {
-          if (!HLMatrixLower::IsMatrixType(Ty)) {
+          if (!dxilutil::IsHLSLMatrixType(Ty)) {
             DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
             if (SA && SA->IsEmptyStruct()) {
               for (User *U : AI->users()) {
@@ -1884,7 +1884,7 @@ void SROA_HLSL::isSafeGEP(GetElementPtrInst *GEPI, uint64_t &Offset,
 
   for (;GEPIt != E; ++GEPIt) {
     Type *Ty = *GEPIt;
-    if (Ty->isStructTy() && !HLMatrixLower::IsMatrixType(Ty)) {
+    if (Ty->isStructTy() && !dxilutil::IsHLSLMatrixType(Ty)) {
       // Don't go inside struct when mark hasArrayIndexing and hasVectorIndexing.
       // The following level won't affect scalar repl on the struct.
       break;
@@ -2250,7 +2250,7 @@ static void EltMemCpy(Type *Ty, Value *Dest, Value *Src,
 static bool IsMemCpyTy(Type *Ty, DxilTypeSystem &typeSys) {
   if (!Ty->isAggregateType())
     return false;
-  if (HLMatrixLower::IsMatrixType(Ty))
+  if (dxilutil::IsHLSLMatrixType(Ty))
     return false;
   if (dxilutil::IsHLSLObjectType(Ty))
     return false;
@@ -2282,7 +2282,7 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
              fieldAnnotation, bEltMemCpy);
 
     idxList.pop_back();
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     // If no fieldAnnotation, use row major as default.
     // Only load then store immediately should be fine.
     bool bRowMajor = true;
@@ -2389,7 +2389,7 @@ static void SplitPtr(Value *Ptr, // The root value pointer
   }
   
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
-    if (!HLMatrixLower::IsMatrixType(Ty) && !dxilutil::IsHLSLObjectType(ST)) {
+    if (!dxilutil::IsHLSLMatrixType(Ty) && !dxilutil::IsHLSLObjectType(ST)) {
       const DxilStructAnnotation* SA = TypeSys.GetStructAnnotation(ST);
 
       for (uint32_t i = 0; i < ST->getNumElements(); i++) {
@@ -2425,7 +2425,7 @@ static void SplitPtr(Value *Ptr, // The root value pointer
       ElTy = ElAT->getElementType();
     }
 
-    if (ElTy->isStructTy() && !HLMatrixLower::IsMatrixType(ElTy)) {
+    if (ElTy->isStructTy() && !dxilutil::IsHLSLMatrixType(ElTy)) {
       DXASSERT(0, "Not support array of struct when split pointers.");
       return;
     }
@@ -2443,7 +2443,7 @@ static unsigned MatchSizeByCheckElementType(Type *Ty, const DataLayout &DL, unsi
   // Size match, return current level.
   if (ptrSize == size) {
     // Not go deeper for matrix.
-    if (HLMatrixLower::IsMatrixType(Ty))
+    if (dxilutil::IsHLSLMatrixType(Ty))
       return level;
     // For struct, go deeper if size not change.
     // This will leave memcpy to deeper level when flatten.
@@ -2596,7 +2596,7 @@ void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
   // Try to find fieldAnnotation from user of Dest/Src.
   if (!fieldAnnotation) {
     Type *EltTy = dxilutil::GetArrayEltTy(DestTy);
-    if (HLMatrixLower::IsMatrixType(EltTy)) {
+    if (dxilutil::IsHLSLMatrixType(EltTy)) {
       fieldAnnotation = HLMatrixLower::FindAnnotationFromMatUser(Dest, typeSys);
     }
   }
@@ -2845,7 +2845,7 @@ void SROA_Helper::RewriteForLoad(LoadInst *LI) {
         Value *Ptr = NewElts[i];
         Type *Ty = Ptr->getType()->getPointerElementType();
         Value *Load = nullptr;
-        if (!HLMatrixLower::IsMatrixType(Ty))
+        if (!dxilutil::IsHLSLMatrixType(Ty))
           Load = Builder.CreateLoad(Ptr, "load");
         else {
           // Generate Matrix Load.
@@ -2930,7 +2930,7 @@ void SROA_Helper::RewriteForStore(StoreInst *SI) {
       Module *M = SI->getModule();
       for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
         Value *Extract = Builder.CreateExtractValue(Val, i, Val->getName());
-        if (!HLMatrixLower::IsMatrixType(Extract->getType())) {
+        if (!dxilutil::IsHLSLMatrixType(Extract->getType())) {
           Builder.CreateStore(Extract, NewElts[i]);
         } else {
           // Generate Matrix Store.
@@ -3393,7 +3393,7 @@ bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
   if (!Ty->isAggregateType())
     return false;
   // Skip matrix types.
-  if (HLMatrixLower::IsMatrixType(Ty))
+  if (dxilutil::IsHLSLMatrixType(Ty))
     return false;
 
   IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
@@ -3440,7 +3440,7 @@ bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
 
     if (ElTy->isStructTy() &&
         // Skip Matrix type.
-        !HLMatrixLower::IsMatrixType(ElTy)) {
+        !dxilutil::IsHLSLMatrixType(ElTy)) {
       if (!dxilutil::IsHLSLObjectType(ElTy)) {
         // for array of struct
         // split into arrays of struct elements
@@ -3569,7 +3569,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
   if (Ty->isSingleValueType() && !Ty->isVectorTy())
     return false;
   // Skip matrix types.
-  if (HLMatrixLower::IsMatrixType(Ty))
+  if (dxilutil::IsHLSLMatrixType(Ty))
     return false;
 
   Module *M = GV->getParent();
@@ -3638,7 +3638,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
 
     if (ElTy->isStructTy() &&
         // Skip Matrix type.
-        !HLMatrixLower::IsMatrixType(ElTy)) {
+        !dxilutil::IsHLSLMatrixType(ElTy)) {
       // for array of struct
       // split into arrays of struct elements
       StructType *ElST = cast<StructType>(ElTy);
@@ -4202,7 +4202,7 @@ bool SROA_Helper::IsEmptyStructType(Type *Ty, DxilTypeSystem &typeSys) {
     Ty = Ty->getArrayElementType();
 
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
-    if (!HLMatrixLower::IsMatrixType(Ty)) {
+    if (!dxilutil::IsHLSLMatrixType(Ty)) {
       DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
       if (SA && SA->IsEmptyStruct())
         return true;
@@ -4360,7 +4360,7 @@ public:
           continue;
 
         // Check matrix store.
-        if (HLMatrixLower::IsMatrixType(
+        if (dxilutil::IsHLSLMatrixType(
                 GV->getType()->getPointerElementType())) {
           if (CallInst *CI = dyn_cast<CallInst>(user)) {
             if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
@@ -4667,7 +4667,7 @@ static DxilFieldAnnotation &GetEltAnnotation(Type *Ty, unsigned idx, DxilFieldAn
   while (Ty->isArrayTy())
     Ty = Ty->getArrayElementType();
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
-    if (HLMatrixLower::IsMatrixType(Ty))
+    if (dxilutil::IsHLSLMatrixType(Ty))
       return annotation;
     DxilStructAnnotation *SA = dxilTypeSys.GetStructAnnotation(ST);
     if (SA) {
@@ -4735,13 +4735,13 @@ static unsigned AllocateSemanticIndex(
                                             FlatAnnotationList);
     }
     return updatedArgIdx;
-  } else if (Ty->isStructTy() && !HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (Ty->isStructTy() && !dxilutil::IsHLSLMatrixType(Ty)) {
     unsigned fieldsCount = Ty->getStructNumElements();
     for (unsigned i = 0; i < fieldsCount; i++) {
       Type *EltTy = Ty->getStructElementType(i);
       argIdx = AllocateSemanticIndex(EltTy, semIndex, argIdx, endArgIdx,
                                      FlatAnnotationList);
-      if (!(EltTy->isStructTy() && !HLMatrixLower::IsMatrixType(EltTy))) {
+      if (!(EltTy->isStructTy() && !dxilutil::IsHLSLMatrixType(EltTy))) {
         // Update argIdx only when it is a leaf node.
         argIdx++;
       }
@@ -4981,7 +4981,7 @@ CastCopyArrayMultiDimTo1Dim(Value *FromArray, Value *ToArray, Type *CurFromTy,
       Value *Elt = Builder.CreateExtractElement(V, i);
       Builder.CreateStore(Elt, ToPtr);
     }
-  } else if (HLMatrixLower::IsMatrixType(CurFromTy)) {
+  } else if (dxilutil::IsHLSLMatrixType(CurFromTy)) {
     // Copy matrix to array.
     unsigned col, row;
     HLMatrixLower::GetMatrixInfo(CurFromTy, col, row);
@@ -5028,7 +5028,7 @@ CastCopyArray1DimToMultiDim(Value *FromArray, Value *ToArray, Type *CurToTy,
       V = Builder.CreateInsertElement(V, Elt, i);
     }
     Builder.CreateStore(V, ToPtr);
-  } else if (HLMatrixLower::IsMatrixType(CurToTy)) {
+  } else if (dxilutil::IsHLSLMatrixType(CurToTy)) {
     // Copy array to matrix.
     unsigned col, row;
     HLMatrixLower::GetMatrixInfo(CurToTy, col, row);
@@ -5072,7 +5072,7 @@ static void CastCopyOldPtrToNewPtr(Value *OldPtr, Value *NewPtr, HLModule &HLM,
       Value *Elt = Builder.CreateExtractElement(V, i);
       Builder.CreateStore(Elt, EltPtr);
     }
-  } else if (HLMatrixLower::IsMatrixType(OldTy)) {
+  } else if (dxilutil::IsHLSLMatrixType(OldTy)) {
     CopyMatPtrToArrayPtr(OldPtr, NewPtr, /*arrayBaseIdx*/ 0, HLM, Builder,
                          bRowMajor);
   } else if (OldTy->isArrayTy()) {
@@ -5102,7 +5102,7 @@ static void CastCopyNewPtrToOldPtr(Value *NewPtr, Value *OldPtr, HLModule &HLM,
       V = Builder.CreateInsertElement(V, Elt, i);
     }
     Builder.CreateStore(V, OldPtr);
-  } else if (HLMatrixLower::IsMatrixType(OldTy)) {
+  } else if (dxilutil::IsHLSLMatrixType(OldTy)) {
     CopyArrayPtrToMatPtr(NewPtr, /*arrayBaseIdx*/ 0, OldPtr, HLM, Builder,
                          bRowMajor);
   } else if (OldTy->isArrayTy()) {
@@ -5200,7 +5200,7 @@ void SROA_Parameter_HLSL::replaceCastParameter(
     // Must be in param.
     // Store NewParam to OldParam at entry.
     Builder.CreateStore(NewParam, OldParam);
-  } else if (HLMatrixLower::IsMatrixType(OldTy)) {
+  } else if (dxilutil::IsHLSLMatrixType(OldTy)) {
     bool bRowMajor = castRowMajorParamMap.count(NewParam);
     Value *Mat = LoadArrayPtrToMat(NewParam, /*arrayBaseIdx*/ 0, OldTy,
                                    *m_pHLModule, Builder, bRowMajor);
@@ -5923,7 +5923,7 @@ static void LegalizeDxilInputOutputs(Function *F,
 
     // Skip arg which is not a pointer.
     if (!Ty->isPointerTy()) {
-      if (HLMatrixLower::IsMatrixType(Ty)) {
+      if (dxilutil::IsHLSLMatrixType(Ty)) {
         // Replace matrix arg with cast to vec. It will be lowered in
         // DxilGenerationPass.
         isColMajor = paramAnnotation.GetMatrixAnnotation().Orientation ==
@@ -5976,7 +5976,7 @@ static void LegalizeDxilInputOutputs(Function *F,
       bStoreInputToTemp = true;
     }
 
-    if (HLMatrixLower::IsMatrixType(Ty)) {
+    if (dxilutil::IsHLSLMatrixType(Ty)) {
       if (qual == DxilParamInputQual::In)
         bStoreInputToTemp = bLoad;
       else if (qual == DxilParamInputQual::Out)

+ 12 - 12
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -4289,7 +4289,7 @@ static void SimpleTransformForHLDXIR(Instruction *I,
   } break;
   case Instruction::Load: {
     LoadInst *ldInst = cast<LoadInst>(I);
-    DXASSERT(!HLMatrixLower::IsMatrixType(ldInst->getType()),
+    DXASSERT(!dxilutil::IsHLSLMatrixType(ldInst->getType()),
                       "matrix load should use HL LdStMatrix");
     Value *Ptr = ldInst->getPointerOperand();
     if (ConstantExpr *CE = dyn_cast_or_null<ConstantExpr>(Ptr)) {
@@ -4301,7 +4301,7 @@ static void SimpleTransformForHLDXIR(Instruction *I,
   case Instruction::Store: {
     StoreInst *stInst = cast<StoreInst>(I);
     Value *V = stInst->getValueOperand();
-    DXASSERT_LOCALVAR(V, !HLMatrixLower::IsMatrixType(V->getType()),
+    DXASSERT_LOCALVAR(V, !dxilutil::IsHLSLMatrixType(V->getType()),
                       "matrix store should use HL LdStMatrix");
     Value *Ptr = stInst->getPointerOperand();
     if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) {
@@ -5298,7 +5298,7 @@ void CGMSHLSLRuntime::FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Val
         valEltTy->isSingleValueType()) {
       Value *ldVal = Builder.CreateLoad(val);
       FlattenValToInitList(CGF, elts, eltTys, Ty, ldVal);
-    } else if (HLMatrixLower::IsMatrixType(valEltTy)) {
+    } else if (dxilutil::IsHLSLMatrixType(valEltTy)) {
       Value *ldVal = EmitHLSLMatrixLoad(Builder, val, Ty);
       FlattenValToInitList(CGF, elts, eltTys, Ty, ldVal);
     } else {
@@ -5350,7 +5350,7 @@ void CGMSHLSLRuntime::FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Val
       }
     }
   } else {
-    if (HLMatrixLower::IsMatrixType(valTy)) {
+    if (dxilutil::IsHLSLMatrixType(valTy)) {
       unsigned col, row;
       llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(valTy, col, row);
       // All matrix Value should be row major.
@@ -5492,7 +5492,7 @@ static void StoreInitListToDestPtr(Value *DestPtr,
     Result = CGF.EmitToMemory(Result, Type);
     Builder.CreateStore(Result, DestPtr);
     idx += Ty->getVectorNumElements();
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     bool isRowMajor =
         IsRowMajorMatrix(Type, bDefaultRowMajor);
 
@@ -5783,7 +5783,7 @@ static void FlatConstToList(Constant *C, SmallVector<Constant *, 4> &EltValList,
       FlatConstToList(C->getAggregateElement(i), EltValList, Type, Types,
                       bDefaultRowMajor);
     }
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     bool isRowMajor = IsRowMajorMatrix(Type, bDefaultRowMajor);
     // matrix type is struct { vector<Ty, row> [col] };
     // Strip the struct level here.
@@ -6005,7 +6005,7 @@ static Constant *BuildConstInitializer(QualType Type, unsigned &offset,
   } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
     return BuildConstArray(AT, offset, EltValList, Type, Types,
                            bDefaultRowMajor);
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     return BuildConstMatrix(Ty, offset, EltValList, Type, Types,
                             bDefaultRowMajor);
   } else if (StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
@@ -6509,7 +6509,7 @@ void CGMSHLSLRuntime::FlattenAggregatePtrToGepList(
                                  GepList, EltTyList);
 
     idxList.pop_back();
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     // Use matLd/St for matrix.
     unsigned col, row;
     llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
@@ -6671,7 +6671,7 @@ void CGMSHLSLRuntime::EmitHLSLAggregateCopy(
                           PT->getElementType());
 
     idxList.pop_back();
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     // Use matLd/St for matrix.
     Value *srcGEP = CGF.Builder.CreateInBoundsGEP(SrcPtr, idxList);
     Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
@@ -6871,7 +6871,7 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversion(
                                       SrcType, PT->getElementType());
 
     idxList.pop_back();
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     // Use matLd/St for matrix.
     Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
     unsigned row, col;
@@ -7133,7 +7133,7 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
                 outVal->getType(), ToTy));
 
         Value *castVal = CGF.Builder.CreateCast(castOp, outVal, ToTy);
-        if (!HLMatrixLower::IsMatrixType(ToTy))
+        if (!dxilutil::IsHLSLMatrixType(ToTy))
           CGF.Builder.CreateStore(castVal, tmpArgAddr);
         else
           EmitHLSLMatrixStore(CGF, castVal, tmpArgAddr, ParamTy);
@@ -7201,7 +7201,7 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionCopyBack(
 
           castVal = CGF.Builder.CreateCast(castOp, outVal, ToTy);
         }
-        if (!HLMatrixLower::IsMatrixType(ToTy))
+        if (!dxilutil::IsHLSLMatrixType(ToTy))
           CGF.EmitStoreThroughLValue(RValue::get(castVal), argLV);
         else {
           Value *destPtr = argLV.getAddress();

+ 4 - 4
tools/clang/tools/dxcompiler/dxcdisassembler.cpp

@@ -784,9 +784,9 @@ void PrintFieldLayout(llvm::Type *Ty, DxilFieldAnnotation &annotation,
     llvm::Type *EltTy = Ty;
     unsigned arraySize = 0;
     unsigned arrayLevel = 0;
-    if (!HLMatrixLower::IsMatrixType(EltTy) && EltTy->isArrayTy()) {
+    if (!dxilutil::IsHLSLMatrixType(EltTy) && EltTy->isArrayTy()) {
       arraySize = 1;
-      while (!HLMatrixLower::IsMatrixType(EltTy) && EltTy->isArrayTy()) {
+      while (!dxilutil::IsHLSLMatrixType(EltTy) && EltTy->isArrayTy()) {
         arraySize *= EltTy->getArrayNumElements();
         EltTy = EltTy->getArrayElementType();
         arrayLevel++;
@@ -817,7 +817,7 @@ void PrintFieldLayout(llvm::Type *Ty, DxilFieldAnnotation &annotation,
     }
 
     std::string StreamStr;
-    if (!HLMatrixLower::IsMatrixType(EltTy) && EltTy->isStructTy()) {
+    if (!dxilutil::IsHLSLMatrixType(EltTy) && EltTy->isStructTy()) {
       std::string NameTypeStr = annotation.GetFieldName();
       raw_string_ostream Stream(NameTypeStr);
       if (arraySize)
@@ -901,7 +901,7 @@ void PrintStructBufferDefinition(DxilResource *buf,
   OS << comment << "\n";
   llvm::Type *RetTy = buf->GetRetType();
   // Skip none struct type.
-  if (!RetTy->isStructTy() || HLMatrixLower::IsMatrixType(RetTy)) {
+  if (!RetTy->isStructTy() || dxilutil::IsHLSLMatrixType(RetTy)) {
     llvm::Type *Ty = buf->GetGlobalSymbol()->getType()->getPointerElementType();
     // For resource array, use element type.
     if (Ty->isArrayTy())

+ 1 - 0
utils/hct/hctdb.py

@@ -1483,6 +1483,7 @@ class db_dxil(object):
         # UseNewSROA is used by PassManagerBuilder::populateFunctionPassManager, not a pass per se.
         add_pass("sroa", "SROA", "Scalar Replacement Of Aggregates", [
             {'n':'RequiresDomTree', 't':'bool', 'c':1},
+            {'n':'SkipHLSLMat', 't':'bool', 'c':1},
             {'n':'force-ssa-updater', 'i':'ForceSSAUpdater', 't':'bool', 'd':'Force the pass to not use DomTree and mem2reg, insteadforming SSA values through the SSAUpdater infrastructure.'},
             {'n':'sroa-random-shuffle-slices', 'i':'SROARandomShuffleSlices', 't':'bool', 'd':'Enable randomly shuffling the slices to help uncover instability in their order.'},
             {'n':'sroa-strict-inbounds', 'i':'SROAStrictInbounds', 't':'bool', 'd':'Experiment with completely strict handling of inbounds GEPs.'}])