فهرست منبع

Enable sroa for hlsl to remove static indexing array. (#1893)

Xiang Li 6 سال پیش
والد
کامیت
c2e744dcf4

+ 1 - 0
include/dxc/DXIL/DxilUtil.h

@@ -105,6 +105,7 @@ namespace dxilutil {
   // Returns true if type contains HLSL Object type (resource)
   // Returns true if type contains HLSL Object type (resource)
   bool ContainsHLSLObjectType(llvm::Type *Ty);
   bool ContainsHLSLObjectType(llvm::Type *Ty);
   bool IsHLSLObjectType(llvm::Type *Ty);
   bool IsHLSLObjectType(llvm::Type *Ty);
+  bool IsHLSLMatrixType(llvm::Type *Ty);
   bool IsSplat(llvm::ConstantDataVector *cdv);
   bool IsSplat(llvm::ConstantDataVector *cdv);
 }
 }
 
 

+ 0 - 1
include/dxc/HLSL/HLMatrixLowerHelper.h

@@ -27,7 +27,6 @@ class DxilTypeSystem;
 
 
 namespace HLMatrixLower {
 namespace HLMatrixLower {
 // TODO: use type annotation.
 // TODO: use type annotation.
-bool IsMatrixType(llvm::Type *Ty);
 DxilFieldAnnotation *FindAnnotationFromMatUser(llvm::Value *Mat,
 DxilFieldAnnotation *FindAnnotationFromMatUser(llvm::Value *Mat,
                                                DxilTypeSystem &typeSys);
                                                DxilTypeSystem &typeSys);
 // Translate matrix type to vector type.
 // Translate matrix type to vector type.

+ 2 - 1
include/llvm/Transforms/Scalar.h

@@ -94,7 +94,8 @@ FunctionPass *createBitTrackingDCEPass();
 //
 //
 // SROA - Replace aggregates or pieces of aggregates with scalar SSA values.
 // SROA - Replace aggregates or pieces of aggregates with scalar SSA values.
 //
 //
-FunctionPass *createSROAPass(bool RequiresDomTree = true);
+FunctionPass *createSROAPass(bool RequiresDomTree = true,
+                             bool SkipHLSLMat = true);
 
 
 //===----------------------------------------------------------------------===//
 //===----------------------------------------------------------------------===//
 //
 //

+ 14 - 0
lib/DXIL/DxilUtil.cpp

@@ -463,6 +463,20 @@ bool IsHLSLObjectType(llvm::Type *Ty) {
   return false;
   return false;
 }
 }
 
 
+bool IsHLSLMatrixType(Type *Ty) {
+  if (StructType *ST = dyn_cast<StructType>(Ty)) {
+    Type *EltTy = ST->getElementType(0);
+    if (!ST->getName().startswith("class.matrix"))
+      return false;
+
+    bool isVecArray =
+        EltTy->isArrayTy() && EltTy->getArrayElementType()->isVectorTy();
+
+    return isVecArray && EltTy->getArrayNumElements() <= 4;
+  }
+  return false;
+}
+
 bool ContainsHLSLObjectType(llvm::Type *Ty) {
 bool ContainsHLSLObjectType(llvm::Type *Ty) {
   // Unwrap pointer/array
   // Unwrap pointer/array
   while (llvm::isa<llvm::PointerType>(Ty))
   while (llvm::isa<llvm::PointerType>(Ty))

+ 3 - 2
lib/HLSL/DxcOptimizer.cpp

@@ -204,7 +204,7 @@ static ArrayRef<LPCSTR> GetPassArgNames(LPCSTR passName) {
   static const LPCSTR LowerExpectIntrinsicArgs[] = { "likely-branch-weight", "unlikely-branch-weight" };
   static const LPCSTR LowerExpectIntrinsicArgs[] = { "likely-branch-weight", "unlikely-branch-weight" };
   static const LPCSTR MergeFunctionsArgs[] = { "mergefunc-sanity" };
   static const LPCSTR MergeFunctionsArgs[] = { "mergefunc-sanity" };
   static const LPCSTR RewriteSymbolsArgs[] = { "DL", "rewrite-map-file" };
   static const LPCSTR RewriteSymbolsArgs[] = { "DL", "rewrite-map-file" };
-  static const LPCSTR SROAArgs[] = { "RequiresDomTree", "force-ssa-updater", "sroa-random-shuffle-slices", "sroa-strict-inbounds" };
+  static const LPCSTR SROAArgs[] = { "RequiresDomTree", "SkipHLSLMat", "force-ssa-updater", "sroa-random-shuffle-slices", "sroa-strict-inbounds" };
   static const LPCSTR SROA_DTArgs[] = { "Threshold", "StructMemberThreshold", "ArrayElementThreshold", "ScalarLoadThreshold" };
   static const LPCSTR SROA_DTArgs[] = { "Threshold", "StructMemberThreshold", "ArrayElementThreshold", "ScalarLoadThreshold" };
   static const LPCSTR SROA_SSAUpArgs[] = { "Threshold", "StructMemberThreshold", "ArrayElementThreshold", "ScalarLoadThreshold" };
   static const LPCSTR SROA_SSAUpArgs[] = { "Threshold", "StructMemberThreshold", "ArrayElementThreshold", "ScalarLoadThreshold" };
   static const LPCSTR SampleProfileLoaderArgs[] = { "sample-profile-file", "sample-profile-max-propagate-iterations" };
   static const LPCSTR SampleProfileLoaderArgs[] = { "sample-profile-file", "sample-profile-max-propagate-iterations" };
@@ -277,7 +277,7 @@ static ArrayRef<LPCSTR> GetPassArgDescriptions(LPCSTR passName) {
   static const LPCSTR LowerExpectIntrinsicArgs[] = { "Weight of the branch likely to be taken (default = 64)", "Weight of the branch unlikely to be taken (default = 4)" };
   static const LPCSTR LowerExpectIntrinsicArgs[] = { "Weight of the branch likely to be taken (default = 64)", "Weight of the branch unlikely to be taken (default = 4)" };
   static const LPCSTR MergeFunctionsArgs[] = { "How many functions in module could be used for MergeFunctions pass sanity check. '0' disables this check. Works only with '-debug' key." };
   static const LPCSTR MergeFunctionsArgs[] = { "How many functions in module could be used for MergeFunctions pass sanity check. '0' disables this check. Works only with '-debug' key." };
   static const LPCSTR RewriteSymbolsArgs[] = { "None", "None" };
   static const LPCSTR RewriteSymbolsArgs[] = { "None", "None" };
-  static const LPCSTR SROAArgs[] = { "None", "Force the pass to not use DomTree and mem2reg, insteadforming SSA values through the SSAUpdater infrastructure.", "Enable randomly shuffling the slices to help uncover instability in their order.", "Experiment with completely strict handling of inbounds GEPs." };
+  static const LPCSTR SROAArgs[] = { "None", "None", "Force the pass to not use DomTree and mem2reg, insteadforming SSA values through the SSAUpdater infrastructure.", "Enable randomly shuffling the slices to help uncover instability in their order.", "Experiment with completely strict handling of inbounds GEPs." };
   static const LPCSTR SROA_DTArgs[] = { "None", "None", "None", "None" };
   static const LPCSTR SROA_DTArgs[] = { "None", "None", "None", "None" };
   static const LPCSTR SROA_SSAUpArgs[] = { "None", "None", "None", "None" };
   static const LPCSTR SROA_SSAUpArgs[] = { "None", "None", "None", "None" };
   static const LPCSTR SampleProfileLoaderArgs[] = { "None", "None" };
   static const LPCSTR SampleProfileLoaderArgs[] = { "None", "None" };
@@ -342,6 +342,7 @@ static bool IsPassOptionName(StringRef S) {
     ||  S.equals("RequiresDomTree")
     ||  S.equals("RequiresDomTree")
     ||  S.equals("Runtime")
     ||  S.equals("Runtime")
     ||  S.equals("ScalarLoadThreshold")
     ||  S.equals("ScalarLoadThreshold")
+    ||  S.equals("SkipHLSLMat")
     ||  S.equals("StructMemberThreshold")
     ||  S.equals("StructMemberThreshold")
     ||  S.equals("TIRA")
     ||  S.equals("TIRA")
     ||  S.equals("TLIImpl")
     ||  S.equals("TLIImpl")

+ 1 - 1
lib/HLSL/DxilCondenseResources.cpp

@@ -1534,7 +1534,7 @@ Type *UpdateFieldTypeForLegacyLayout(Type *Ty, bool IsCBuf,
       return Ty;
       return Ty;
     else
     else
       return ArrayType::get(UpdatedTy, Ty->getArrayNumElements());
       return ArrayType::get(UpdatedTy, Ty->getArrayNumElements());
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     DXASSERT(annotation.HasMatrixAnnotation(), "must a matrix");
     DXASSERT(annotation.HasMatrixAnnotation(), "must a matrix");
     unsigned rows, cols;
     unsigned rows, cols;
     Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, cols, rows);
     Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, cols, rows);

+ 1 - 1
lib/HLSL/DxilContainerReflection.cpp

@@ -765,7 +765,7 @@ HRESULT CShaderReflectionType::Initialize(
     llvm::Type* elementType = type->getArrayElementType();
     llvm::Type* elementType = type->getArrayElementType();
 
 
     // Note: At this point an HLSL matrix type may appear as an ordinary
     // Note: At this point an HLSL matrix type may appear as an ordinary
-    // array (not wrapped in a `struct`), so `HLMatrixLower::IsMatrixType()`
+    // array (not wrapped in a `struct`), so `dxilutil::IsHLSLMatrixType()`
     // is not sufficient. Instead we need to check the field annotation.
     // is not sufficient. Instead we need to check the field annotation.
     //
     //
     // We might have an array of matrices, though, so we only exit if
     // We might have an array of matrices, though, so we only exit if

+ 1 - 1
lib/HLSL/DxilLinker.cpp

@@ -1027,7 +1027,7 @@ void DxilLinkJob::RunPreparePass(Module &M) {
   PM.add(createDxilDeadFunctionEliminationPass());
   PM.add(createDxilDeadFunctionEliminationPass());
 
 
   // SROA
   // SROA
-  PM.add(createSROAPass(/*RequiresDomTree*/false));
+  PM.add(createSROAPass(/*RequiresDomTree*/false, /*SkipHLSLMat*/false));
 
 
   // Remove MultiDimArray from function call arg.
   // Remove MultiDimArray from function call arg.
   PM.add(createMultiDimArrayToOneDimArrayPass());
   PM.add(createMultiDimArrayToOneDimArrayPass());

+ 28 - 39
lib/HLSL/HLMatrixLowerPass.cpp

@@ -36,20 +36,6 @@ using namespace hlsl::HLMatrixLower;
 namespace hlsl {
 namespace hlsl {
 namespace HLMatrixLower {
 namespace HLMatrixLower {
 
 
-bool IsMatrixType(Type *Ty) {
-  if (StructType *ST = dyn_cast<StructType>(Ty)) {
-    Type *EltTy = ST->getElementType(0);
-    if (!ST->getName().startswith("class.matrix"))
-      return false;
-
-    bool isVecArray = EltTy->isArrayTy() &&
-           EltTy->getArrayElementType()->isVectorTy();
-
-    return isVecArray && EltTy->getArrayNumElements() <= 4;
-  }
-  return false;
-}
-
 // If user is function call, return param annotation to get matrix major.
 // If user is function call, return param annotation to get matrix major.
 DxilFieldAnnotation *FindAnnotationFromMatUser(Value *Mat,
 DxilFieldAnnotation *FindAnnotationFromMatUser(Value *Mat,
                                                DxilTypeSystem &typeSys) {
                                                DxilTypeSystem &typeSys) {
@@ -81,7 +67,7 @@ Type *LowerMatrixType(Type *Ty, bool forMem) {
       params.emplace_back(LowerMatrixType(param));
       params.emplace_back(LowerMatrixType(param));
     }
     }
     return FunctionType::get(RetTy, params, false);
     return FunctionType::get(RetTy, params, false);
-  } else if (IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     unsigned row, col;
     unsigned row, col;
     Type *EltTy = GetMatrixInfo(Ty, col, row);
     Type *EltTy = GetMatrixInfo(Ty, col, row);
     if (forMem && EltTy->isIntegerTy(1))
     if (forMem && EltTy->isIntegerTy(1))
@@ -94,7 +80,7 @@ Type *LowerMatrixType(Type *Ty, bool forMem) {
 
 
 // Translate matrix type to array type.
 // Translate matrix type to array type.
 Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
 Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
-  if (IsMatrixType(Ty)) {
+  if (dxilutil::IsHLSLMatrixType(Ty)) {
     unsigned row, col;
     unsigned row, col;
     Type *EltTy = GetMatrixInfo(Ty, col, row);
     Type *EltTy = GetMatrixInfo(Ty, col, row);
     return ArrayType::get(EltTy, row * col);
     return ArrayType::get(EltTy, row * col);
@@ -105,7 +91,7 @@ Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
 
 
 
 
 Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
 Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
-  DXASSERT(IsMatrixType(Ty), "not matrix type");
+  DXASSERT(dxilutil::IsHLSLMatrixType(Ty), "not matrix type");
   StructType *ST = cast<StructType>(Ty);
   StructType *ST = cast<StructType>(Ty);
   Type *EltTy = ST->getElementType(0);
   Type *EltTy = ST->getElementType(0);
   Type *RowTy = EltTy->getArrayElementType();
   Type *RowTy = EltTy->getArrayElementType();
@@ -122,7 +108,7 @@ bool IsMatrixArrayPointer(llvm::Type *Ty) {
     return false;
     return false;
   while (Ty->isArrayTy())
   while (Ty->isArrayTy())
     Ty = Ty->getArrayElementType();
     Ty = Ty->getArrayElementType();
-  return IsMatrixType(Ty);
+  return dxilutil::IsHLSLMatrixType(Ty);
 }
 }
 Type *LowerMatrixArrayPointer(Type *Ty, bool forMem) {
 Type *LowerMatrixArrayPointer(Type *Ty, bool forMem) {
   unsigned addrSpace = Ty->getPointerAddressSpace();
   unsigned addrSpace = Ty->getPointerAddressSpace();
@@ -441,8 +427,8 @@ Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
   Value *op = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
   Value *op = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
   HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(CI));
   HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(CI));
 
 
-  bool ToMat = IsMatrixType(CI->getType());
-  bool FromMat = IsMatrixType(op->getType());
+  bool ToMat = dxilutil::IsHLSLMatrixType(CI->getType());
+  bool FromMat = dxilutil::IsHLSLMatrixType(op->getType());
   if (ToMat && !FromMat) {
   if (ToMat && !FromMat) {
     // Translate OtherToMat here.
     // Translate OtherToMat here.
     // Rest will translated when replace.
     // Rest will translated when replace.
@@ -514,11 +500,11 @@ Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
 // UDT alloca must be there for library function args
 // UDT alloca must be there for library function args
 static GetElementPtrInst *GetIfMatrixGEPOfUDTAlloca(Value *V) {
 static GetElementPtrInst *GetIfMatrixGEPOfUDTAlloca(Value *V) {
   if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
   if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
-    if (IsMatrixType(GEP->getResultElementType())) {
+    if (dxilutil::IsHLSLMatrixType(GEP->getResultElementType())) {
       Value *ptr = GEP->getPointerOperand();
       Value *ptr = GEP->getPointerOperand();
       if (AllocaInst *AI = dyn_cast<AllocaInst>(ptr)) {
       if (AllocaInst *AI = dyn_cast<AllocaInst>(ptr)) {
         Type *ATy = AI->getAllocatedType();
         Type *ATy = AI->getAllocatedType();
-        if (ATy->isStructTy() && !IsMatrixType(ATy)) {
+        if (ATy->isStructTy() && !dxilutil::IsHLSLMatrixType(ATy)) {
           return GEP;
           return GEP;
         }
         }
       }
       }
@@ -531,7 +517,7 @@ static GetElementPtrInst *GetIfMatrixGEPOfUDTAlloca(Value *V) {
 // none-graphics functions.
 // none-graphics functions.
 static GetElementPtrInst *GetIfMatrixGEPOfUDTArg(Value *V, HLModule &HM) {
 static GetElementPtrInst *GetIfMatrixGEPOfUDTArg(Value *V, HLModule &HM) {
   if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
   if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
-    if (IsMatrixType(GEP->getResultElementType())) {
+    if (dxilutil::IsHLSLMatrixType(GEP->getResultElementType())) {
       Value *ptr = GEP->getPointerOperand();
       Value *ptr = GEP->getPointerOperand();
       if (Argument *Arg = dyn_cast<Argument>(ptr)) {
       if (Argument *Arg = dyn_cast<Argument>(ptr)) {
         if (!HM.IsGraphicsShader(Arg->getParent()))
         if (!HM.IsGraphicsShader(Arg->getParent()))
@@ -654,7 +640,7 @@ Instruction *HLMatrixLowerPass::MatIntrinsicToVec(CallInst *CI) {
   SmallVector<Value *, 4> argList;
   SmallVector<Value *, 4> argList;
   for (Value *arg : CI->arg_operands()) {
   for (Value *arg : CI->arg_operands()) {
     Type *Ty = arg->getType();
     Type *Ty = arg->getType();
-    if (IsMatrixType(Ty)) {
+    if (dxilutil::IsHLSLMatrixType(Ty)) {
       argList.emplace_back(UndefValue::get(LowerMatrixType(Ty)));
       argList.emplace_back(UndefValue::get(LowerMatrixType(Ty)));
     } else
     } else
       argList.emplace_back(arg);
       argList.emplace_back(arg);
@@ -779,12 +765,14 @@ Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
     break;
     break;
   case HLBinaryOpcode::Shl: {
   case HLBinaryOpcode::Shl: {
     Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
     Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-    DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
+    DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
+                      "must be matrix type here");
     Result = BinaryOperator::CreateShl(tmp, tmp);
     Result = BinaryOperator::CreateShl(tmp, tmp);
   } break;
   } break;
   case HLBinaryOpcode::Shr: {
   case HLBinaryOpcode::Shr: {
     Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
     Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-    DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
+    DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
+                      "must be matrix type here");
     Result = BinaryOperator::CreateAShr(tmp, tmp);
     Result = BinaryOperator::CreateAShr(tmp, tmp);
   } break;
   } break;
   case HLBinaryOpcode::LT:
   case HLBinaryOpcode::LT:
@@ -831,7 +819,8 @@ Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
     break;
     break;
   case HLBinaryOpcode::UShr: {
   case HLBinaryOpcode::UShr: {
     Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
     Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
-    DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
+    DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
+                      "must be matrix type here");
     Result = BinaryOperator::CreateLShr(tmp, tmp);
     Result = BinaryOperator::CreateLShr(tmp, tmp);
   } break;
   } break;
   case HLBinaryOpcode::ULT:
   case HLBinaryOpcode::ULT:
@@ -1229,8 +1218,8 @@ void HLMatrixLowerPass::TranslateMul(Value *matVal, Value *vecVal,
   Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
   Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
   Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
   Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
 
 
-  bool LMat = IsMatrixType(LVal->getType());
-  bool RMat = IsMatrixType(RVal->getType());
+  bool LMat = dxilutil::IsHLSLMatrixType(LVal->getType());
+  bool RMat = dxilutil::IsHLSLMatrixType(RVal->getType());
   if (LMat && RMat) {
   if (LMat && RMat) {
     TranslateMatMatMul(matVal, vecVal, mulInst, isSigned);
     TranslateMatMatMul(matVal, vecVal, mulInst, isSigned);
   } else if (LMat) {
   } else if (LMat) {
@@ -1501,8 +1490,8 @@ void HLMatrixLowerPass::TranslateMatCast(Value *matVal,
                           opcode == HLCastOpcode::RowMatrixToColMatrix,
                           opcode == HLCastOpcode::RowMatrixToColMatrix,
                           /*bTranspose*/false);
                           /*bTranspose*/false);
   } else {
   } else {
-    bool ToMat = IsMatrixType(castInst->getType());
-    bool FromMat = IsMatrixType(matVal->getType());
+    bool ToMat = dxilutil::IsHLSLMatrixType(castInst->getType());
+    bool FromMat = dxilutil::IsHLSLMatrixType(matVal->getType());
     if (ToMat && FromMat) {
     if (ToMat && FromMat) {
       TranslateMatMatCast(matVal, vecVal, castInst);
       TranslateMatMatCast(matVal, vecVal, castInst);
     } else if (FromMat)
     } else if (FromMat)
@@ -1946,7 +1935,7 @@ static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
       }
       }
     }
     }
     Type *valEltTy = val->getType()->getPointerElementType();
     Type *valEltTy = val->getType()->getPointerElementType();
-    if (valEltTy->isVectorTy() || HLMatrixLower::IsMatrixType(valEltTy) ||
+    if (valEltTy->isVectorTy() || dxilutil::IsHLSLMatrixType(valEltTy) ||
         valEltTy->isSingleValueType()) {
         valEltTy->isSingleValueType()) {
       Value *ldVal = Builder.CreateLoad(val);
       Value *ldVal = Builder.CreateLoad(val);
       IterateInitList(elts, idx, ldVal, matToVecMap, Builder);
       IterateInitList(elts, idx, ldVal, matToVecMap, Builder);
@@ -1969,7 +1958,7 @@ static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
         }
         }
       }
       }
     }
     }
-  } else if (HLMatrixLower::IsMatrixType(valTy)) {
+  } else if (dxilutil::IsHLSLMatrixType(valTy)) {
     unsigned col, row;
     unsigned col, row;
     HLMatrixLower::GetMatrixInfo(valTy, col, row);
     HLMatrixLower::GetMatrixInfo(valTy, col, row);
     unsigned matSize = col * row;
     unsigned matSize = col * row;
@@ -2489,7 +2478,7 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
   }
   }
 
 
   Type *Ty = GV->getType()->getPointerElementType();
   Type *Ty = GV->getType()->getPointerElementType();
-  if (!HLMatrixLower::IsMatrixType(Ty))
+  if (!dxilutil::IsHLSLMatrixType(Ty))
     return;
     return;
 
 
   bool onlyLdSt = OnlyUsedByMatrixLdSt(GV);
   bool onlyLdSt = OnlyUsedByMatrixLdSt(GV);
@@ -2585,11 +2574,11 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
     BasicBlock *BB = BBI;
     BasicBlock *BB = BBI;
     for (auto II = BB->begin(); II != BB->end(); ) {
     for (auto II = BB->begin(); II != BB->end(); ) {
       Instruction &I = *(II++);
       Instruction &I = *(II++);
-      if (IsMatrixType(I.getType())) {
+      if (dxilutil::IsHLSLMatrixType(I.getType())) {
         lowerToVec(&I);
         lowerToVec(&I);
       } else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
       } else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
         Type *Ty = AI->getAllocatedType();
         Type *Ty = AI->getAllocatedType();
-        if (HLMatrixLower::IsMatrixType(Ty)) {
+        if (dxilutil::IsHLSLMatrixType(Ty)) {
           lowerToVec(&I);
           lowerToVec(&I);
         } else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
         } else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
           lowerToVec(&I);
           lowerToVec(&I);
@@ -2665,7 +2654,7 @@ Type *TryLowerMatTy(Type *Ty) {
   if (HLMatrixLower::IsMatrixArrayPointer(Ty)) {
   if (HLMatrixLower::IsMatrixArrayPointer(Ty)) {
     VecTy = HLMatrixLower::LowerMatrixArrayPointerToOneDimArray(Ty);
     VecTy = HLMatrixLower::LowerMatrixArrayPointerToOneDimArray(Ty);
   } else if (isa<PointerType>(Ty) &&
   } else if (isa<PointerType>(Ty) &&
-             HLMatrixLower::IsMatrixType(Ty->getPointerElementType())) {
+             dxilutil::IsHLSLMatrixType(Ty->getPointerElementType())) {
     VecTy = HLMatrixLower::LowerMatrixTypeToOneDimArray(
     VecTy = HLMatrixLower::LowerMatrixTypeToOneDimArray(
         Ty->getPointerElementType());
         Ty->getPointerElementType());
     VecTy = PointerType::get(VecTy, Ty->getPointerAddressSpace());
     VecTy = PointerType::get(VecTy, Ty->getPointerAddressSpace());
@@ -2727,7 +2716,7 @@ bool MatrixBitcastLowerPass::hasCallUser(Instruction *M) {
     User *U = *(it++);
     User *U = *(it++);
     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
       Type *EltTy = GEP->getType()->getPointerElementType();
       Type *EltTy = GEP->getType()->getPointerElementType();
-      if (HLMatrixLower::IsMatrixType(EltTy)) {
+      if (dxilutil::IsHLSLMatrixType(EltTy)) {
         if (hasCallUser(GEP))
         if (hasCallUser(GEP))
           return true;
           return true;
       } else {
       } else {
@@ -2782,7 +2771,7 @@ void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
     User *U = *(it++);
     User *U = *(it++);
     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
       Type *EltTy = GEP->getType()->getPointerElementType();
       Type *EltTy = GEP->getType()->getPointerElementType();
-      if (HLMatrixLower::IsMatrixType(EltTy)) {
+      if (dxilutil::IsHLSLMatrixType(EltTy)) {
         // Change gep matrixArray, 0, index
         // Change gep matrixArray, 0, index
         // into
         // into
         //   gep oneDimArray, 0, index * matSize
         //   gep oneDimArray, 0, index * matSize

+ 3 - 3
lib/HLSL/HLSignatureLower.cpp

@@ -624,14 +624,14 @@ void replaceDirectInputParameter(Value *param, Function *loadInput,
       newVec = Builder.CreateInsertElement(newVec, input, col);
       newVec = Builder.CreateInsertElement(newVec, input, col);
     }
     }
     param->replaceAllUsesWith(newVec);
     param->replaceAllUsesWith(newVec);
-  } else if (!Ty->isArrayTy() && !HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (!Ty->isArrayTy() && !dxilutil::IsHLSLMatrixType(Ty)) {
     DXASSERT(cols == 1, "only support scalar here");
     DXASSERT(cols == 1, "only support scalar here");
     Value *colIdx = hlslOP->GetU8Const(0);
     Value *colIdx = hlslOP->GetU8Const(0);
     args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
     args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
     Value *input =
     Value *input =
         GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
         GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
     param->replaceAllUsesWith(input);
     param->replaceAllUsesWith(input);
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     Value *colIdx = hlslOP->GetU8Const(0);
     Value *colIdx = hlslOP->GetU8Const(0);
     (void)colIdx;
     (void)colIdx;
     DXASSERT(param->hasOneUse(),
     DXASSERT(param->hasOneUse(),
@@ -784,7 +784,7 @@ void collectInputOutputAccessInfo(
               vectorIdx = GEPIt.getOperand();
               vectorIdx = GEPIt.getOperand();
             }
             }
           }
           }
-          if (HLMatrixLower::IsMatrixType(*GEPIt)) {
+          if (dxilutil::IsHLSLMatrixType(*GEPIt)) {
             unsigned row, col;
             unsigned row, col;
             HLMatrixLower::GetMatrixInfo(*GEPIt, col, row);
             HLMatrixLower::GetMatrixInfo(*GEPIt, col, row);
             Constant *arraySize = ConstantInt::get(idxTy, col);
             Constant *arraySize = ConstantInt::get(idxTy, col);

+ 1 - 4
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -373,14 +373,11 @@ void PassManagerBuilder::populateModulePassManager(
 
 
   // Start of function pass.
   // Start of function pass.
   // Break up aggregate allocas, using SSAUpdater.
   // Break up aggregate allocas, using SSAUpdater.
-  // HLSL Change - don't run SROA. 
-  // HLSL uses special SROA added in addHLSLPasses.
-  if (HLSLHighLevel) { // HLSL Change
   if (UseNewSROA)
   if (UseNewSROA)
     MPM.add(createSROAPass(/*RequiresDomTree*/ false));
     MPM.add(createSROAPass(/*RequiresDomTree*/ false));
   else
   else
     MPM.add(createScalarReplAggregatesPass(-1, false));
     MPM.add(createScalarReplAggregatesPass(-1, false));
-  }
+
   // HLSL Change. MPM.add(createEarlyCSEPass());              // Catch trivial redundancies
   // HLSL Change. MPM.add(createEarlyCSEPass());              // Catch trivial redundancies
   // HLSL Change. MPM.add(createJumpThreadingPass());         // Thread jumps.
   // HLSL Change. MPM.add(createJumpThreadingPass());         // Thread jumps.
   MPM.add(createCorrelatedValuePropagationPass()); // Propagate conditionals
   MPM.add(createCorrelatedValuePropagationPass()); // Propagate conditionals

+ 83 - 15
lib/Transforms/Scalar/SROA.cpp

@@ -222,7 +222,8 @@ namespace {
 class AllocaSlices {
 class AllocaSlices {
 public:
 public:
   /// \brief Construct the slices of a particular alloca.
   /// \brief Construct the slices of a particular alloca.
-  AllocaSlices(const DataLayout &DL, AllocaInst &AI);
+  AllocaSlices(const DataLayout &DL, AllocaInst &AI,
+               const bool SkipHLSLMat); // HLSL Change - not sroa matrix type.
 
 
   /// \brief Test whether a pointer to the allocation escapes our analysis.
   /// \brief Test whether a pointer to the allocation escapes our analysis.
   ///
   ///
@@ -633,6 +634,7 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
   friend class InstVisitor<SliceBuilder>;
   friend class InstVisitor<SliceBuilder>;
   typedef PtrUseVisitor<SliceBuilder> Base;
   typedef PtrUseVisitor<SliceBuilder> Base;
 
 
+  const bool SkipHLSLMat; // HLSL Change - not sroa matrix type.
   const uint64_t AllocSize;
   const uint64_t AllocSize;
   AllocaSlices &AS;
   AllocaSlices &AS;
 
 
@@ -643,8 +645,10 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
   SmallPtrSet<Instruction *, 4> VisitedDeadInsts;
   SmallPtrSet<Instruction *, 4> VisitedDeadInsts;
 
 
 public:
 public:
-  SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &AS)
+  SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &AS,
+               const bool SkipHLSLMat)
       : PtrUseVisitor<SliceBuilder>(DL),
       : PtrUseVisitor<SliceBuilder>(DL),
+        SkipHLSLMat(SkipHLSLMat), // HLSL Change - not sroa matrix type.
         AllocSize(DL.getTypeAllocSize(AI.getAllocatedType())), AS(AS) {}
         AllocSize(DL.getTypeAllocSize(AI.getAllocatedType())), AS(AS) {}
 
 
 private:
 private:
@@ -690,7 +694,24 @@ private:
   void visitBitCastInst(BitCastInst &BC) {
   void visitBitCastInst(BitCastInst &BC) {
     if (BC.use_empty())
     if (BC.use_empty())
       return markAsDead(BC);
       return markAsDead(BC);
-
+    // HLSL Change Begin - not sroa matrix type.
+    if (PointerType *PT = dyn_cast<PointerType>(BC.getType())) {
+      Type *EltTy = PT->getElementType();
+      if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(EltTy)) ||
+          hlsl::dxilutil::IsHLSLObjectType(EltTy)) {
+        AS.PointerEscapingInstr = &BC;
+        return;
+      }
+      if (PointerType *SrcPT = dyn_cast<PointerType>(BC.getSrcTy())) {
+        Type *SrcEltTy = SrcPT->getElementType();
+        if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(SrcEltTy)) ||
+            hlsl::dxilutil::IsHLSLObjectType(SrcEltTy)) {
+          AS.PointerEscapingInstr = &BC;
+          return;
+        }
+      }
+    }
+    // HLSL Change End.
     return Base::visitBitCastInst(BC);
     return Base::visitBitCastInst(BC);
   }
   }
 
 
@@ -751,9 +772,15 @@ private:
   }
   }
 
 
   void visitLoadInst(LoadInst &LI) {
   void visitLoadInst(LoadInst &LI) {
+    // HLSL Change Begin - not sroa matrix type.
+    if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(LI.getType())) ||
+        hlsl::dxilutil::IsHLSLObjectType(LI.getType()))
+      return PI.setEscapedAndAborted(&LI);
+    // HLSL Change End.
     assert((!LI.isSimple() || LI.getType()->isSingleValueType()) &&
     assert((!LI.isSimple() || LI.getType()->isSingleValueType()) &&
            "All simple FCA loads should have been pre-split");
            "All simple FCA loads should have been pre-split");
 
 
+
     if (!IsOffsetKnown)
     if (!IsOffsetKnown)
       return PI.setAborted(&LI);
       return PI.setAborted(&LI);
 
 
@@ -766,6 +793,12 @@ private:
     Value *ValOp = SI.getValueOperand();
     Value *ValOp = SI.getValueOperand();
     if (ValOp == *U)
     if (ValOp == *U)
       return PI.setEscapedAndAborted(&SI);
       return PI.setEscapedAndAborted(&SI);
+    // HLSL Change Begin - not sroa matrix type.
+    if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(ValOp->getType())) ||
+        hlsl::dxilutil::IsHLSLObjectType(ValOp->getType()))
+      return PI.setEscapedAndAborted(&SI);
+    // HLSL Change End.
+
     if (!IsOffsetKnown)
     if (!IsOffsetKnown)
       return PI.setAborted(&SI);
       return PI.setAborted(&SI);
 
 
@@ -1002,13 +1035,15 @@ private:
   void visitInstruction(Instruction &I) { PI.setAborted(&I); }
   void visitInstruction(Instruction &I) { PI.setAborted(&I); }
 };
 };
 
 
-AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI)
+AllocaSlices::AllocaSlices(
+    const DataLayout &DL, AllocaInst &AI,
+    const bool SkipHLSLMat) // HLSL Change - not sroa matrix type.
     :
     :
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
       AI(AI),
       AI(AI),
 #endif
 #endif
       PointerEscapingInstr(nullptr) {
       PointerEscapingInstr(nullptr) {
-  SliceBuilder PB(DL, AI, *this);
+  SliceBuilder PB(DL, AI, *this, SkipHLSLMat);
   SliceBuilder::PtrInfo PtrI = PB.visitPtr(AI);
   SliceBuilder::PtrInfo PtrI = PB.visitPtr(AI);
   if (PtrI.isEscaped() || PtrI.isAborted()) {
   if (PtrI.isEscaped() || PtrI.isAborted()) {
     // FIXME: We should sink the escape vs. abort info into the caller nicely,
     // FIXME: We should sink the escape vs. abort info into the caller nicely,
@@ -1204,6 +1239,7 @@ namespace {
 ///    SSA vector values.
 ///    SSA vector values.
 class SROA : public FunctionPass {
 class SROA : public FunctionPass {
   const bool RequiresDomTree;
   const bool RequiresDomTree;
+  const bool SkipHLSLMat; // HLSL Change - not sroa matrix type.
 
 
   LLVMContext *C;
   LLVMContext *C;
   DominatorTree *DT;
   DominatorTree *DT;
@@ -1252,9 +1288,10 @@ class SROA : public FunctionPass {
   SetVector<SelectInst *, SmallVector<SelectInst *, 2>> SpeculatableSelects;
   SetVector<SelectInst *, SmallVector<SelectInst *, 2>> SpeculatableSelects;
 
 
 public:
 public:
-  SROA(bool RequiresDomTree = true)
-      : FunctionPass(ID), RequiresDomTree(RequiresDomTree), C(nullptr),
-        DT(nullptr) {
+  SROA(bool RequiresDomTree = true, bool SkipHLSLMat = true)
+      : FunctionPass(ID), RequiresDomTree(RequiresDomTree),
+        SkipHLSLMat(SkipHLSLMat), // HLSL Change - not sroa matrix type.
+        C(nullptr), DT(nullptr) {
     initializeSROAPass(*PassRegistry::getPassRegistry());
     initializeSROAPass(*PassRegistry::getPassRegistry());
   }
   }
   bool runOnFunction(Function &F) override;
   bool runOnFunction(Function &F) override;
@@ -1280,8 +1317,8 @@ private:
 
 
 char SROA::ID = 0;
 char SROA::ID = 0;
 
 
-FunctionPass *llvm::createSROAPass(bool RequiresDomTree) {
-  return new SROA(RequiresDomTree);
+FunctionPass *llvm::createSROAPass(bool RequiresDomTree, bool SkipHLSLMat) {
+  return new SROA(RequiresDomTree, SkipHLSLMat);
 }
 }
 
 
 INITIALIZE_PASS_BEGIN(SROA, "sroa", "Scalar Replacement Of Aggregates", false,
 INITIALIZE_PASS_BEGIN(SROA, "sroa", "Scalar Replacement Of Aggregates", false,
@@ -3191,6 +3228,7 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
   friend class llvm::InstVisitor<AggLoadStoreRewriter, bool>;
   friend class llvm::InstVisitor<AggLoadStoreRewriter, bool>;
 
 
   const DataLayout &DL;
   const DataLayout &DL;
+  const bool SkipHLSLMat; // HLSL Change - not sroa matrix type.
 
 
   /// Queue of pointer uses to analyze and potentially rewrite.
   /// Queue of pointer uses to analyze and potentially rewrite.
   SmallVector<Use *, 8> Queue;
   SmallVector<Use *, 8> Queue;
@@ -3203,7 +3241,9 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
   Use *U;
   Use *U;
 
 
 public:
 public:
-  AggLoadStoreRewriter(const DataLayout &DL) : DL(DL) {}
+  AggLoadStoreRewriter(const DataLayout &DL, const bool SkipHLSLMat)
+      // HLSL Change - not sroa matrix type.
+      : DL(DL), SkipHLSLMat(SkipHLSLMat) {}
 
 
   /// Rewrite loads and stores through a pointer and all pointers derived from
   /// Rewrite loads and stores through a pointer and all pointers derived from
   /// it.
   /// it.
@@ -3323,6 +3363,11 @@ private:
     assert(LI.getPointerOperand() == *U);
     assert(LI.getPointerOperand() == *U);
     if (!LI.isSimple() || LI.getType()->isSingleValueType())
     if (!LI.isSimple() || LI.getType()->isSingleValueType())
       return false;
       return false;
+    // HLSL Change Begin - not sroa matrix type.
+    if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(LI.getType())) ||
+        hlsl::dxilutil::IsHLSLObjectType(LI.getType()))
+      return false;
+    // HLSL Change End.
 
 
     // We have an aggregate being loaded, split it apart.
     // We have an aggregate being loaded, split it apart.
     DEBUG(dbgs() << "    original: " << LI << "\n");
     DEBUG(dbgs() << "    original: " << LI << "\n");
@@ -3357,7 +3402,11 @@ private:
     Value *V = SI.getValueOperand();
     Value *V = SI.getValueOperand();
     if (V->getType()->isSingleValueType())
     if (V->getType()->isSingleValueType())
       return false;
       return false;
-
+    // HLSL Change Begin - not sroa matrix type.
+    if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(V->getType())) ||
+        hlsl::dxilutil::IsHLSLObjectType(V->getType()))
+      return false;
+    // HLSL Change End.
     // We have an aggregate being stored, split it apart.
     // We have an aggregate being stored, split it apart.
     DEBUG(dbgs() << "    original: " << SI << "\n");
     DEBUG(dbgs() << "    original: " << SI << "\n");
     StoreOpSplitter Splitter(&SI, *U);
     StoreOpSplitter Splitter(&SI, *U);
@@ -3367,6 +3416,20 @@ private:
   }
   }
 
 
   bool visitBitCastInst(BitCastInst &BC) {
   bool visitBitCastInst(BitCastInst &BC) {
+    // HLSL Change Begin - not sroa matrix type.
+    if (PointerType *PT = dyn_cast<PointerType>(BC.getType())) {
+      Type *EltTy = PT->getElementType();
+      if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(EltTy)) ||
+          hlsl::dxilutil::IsHLSLObjectType(EltTy))
+        return false;
+      if (PointerType *SrcPT = dyn_cast<PointerType>(BC.getSrcTy())) {
+        Type *SrcEltTy = SrcPT->getElementType();
+        if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(SrcEltTy)) ||
+            hlsl::dxilutil::IsHLSLObjectType(SrcEltTy))
+          return false;
+      }
+    }
+    // HLSL Change End.
     enqueueUsers(BC);
     enqueueUsers(BC);
     return false;
     return false;
   }
   }
@@ -4310,7 +4373,12 @@ bool SROA::runOnAlloca(AllocaInst &AI) {
 
 
   // Skip alloca forms that this analysis can't handle.
   // Skip alloca forms that this analysis can't handle.
   if (AI.isArrayAllocation() || !AI.getAllocatedType()->isSized() ||
   if (AI.isArrayAllocation() || !AI.getAllocatedType()->isSized() ||
-      hlsl::dxilutil::IsHLSLObjectType(AI.getAllocatedType()) || // HLSL Change - not sroa resource type.
+      hlsl::dxilutil::IsHLSLObjectType(
+          AI.getAllocatedType()) || // HLSL Change - not sroa resource type.
+      // HLSL Change Begin - not sroa matrix type.
+      (SkipHLSLMat &&
+       hlsl::dxilutil::IsHLSLMatrixType(AI.getAllocatedType())) ||
+      // HLSL Change End.
       DL.getTypeAllocSize(AI.getAllocatedType()) == 0)
       DL.getTypeAllocSize(AI.getAllocatedType()) == 0)
     return false;
     return false;
 
 
@@ -4318,11 +4386,11 @@ bool SROA::runOnAlloca(AllocaInst &AI) {
 
 
   // First, split any FCA loads and stores touching this alloca to promote
   // First, split any FCA loads and stores touching this alloca to promote
   // better splitting and promotion opportunities.
   // better splitting and promotion opportunities.
-  AggLoadStoreRewriter AggRewriter(DL);
+  AggLoadStoreRewriter AggRewriter(DL, SkipHLSLMat);
   Changed |= AggRewriter.rewrite(AI);
   Changed |= AggRewriter.rewrite(AI);
 
 
   // Build the slices using a recursive instruction-visiting builder.
   // Build the slices using a recursive instruction-visiting builder.
-  AllocaSlices AS(DL, AI);
+  AllocaSlices AS(DL, AI, SkipHLSLMat);
   DEBUG(AS.print(dbgs()));
   DEBUG(AS.print(dbgs()));
   if (AS.isEscaped())
   if (AS.isEscaped())
     return Changed;
     return Changed;

+ 26 - 26
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -1639,7 +1639,7 @@ bool SROA_HLSL::performScalarRepl(Function &F, DxilTypeSystem &typeSys) {
         Type *Ty = AI->getAllocatedType();
         Type *Ty = AI->getAllocatedType();
         // Skip empty struct parameters.
         // Skip empty struct parameters.
         if (StructType *ST = dyn_cast<StructType>(Ty)) {
         if (StructType *ST = dyn_cast<StructType>(Ty)) {
-          if (!HLMatrixLower::IsMatrixType(Ty)) {
+          if (!dxilutil::IsHLSLMatrixType(Ty)) {
             DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
             DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
             if (SA && SA->IsEmptyStruct()) {
             if (SA && SA->IsEmptyStruct()) {
               for (User *U : AI->users()) {
               for (User *U : AI->users()) {
@@ -1884,7 +1884,7 @@ void SROA_HLSL::isSafeGEP(GetElementPtrInst *GEPI, uint64_t &Offset,
 
 
   for (;GEPIt != E; ++GEPIt) {
   for (;GEPIt != E; ++GEPIt) {
     Type *Ty = *GEPIt;
     Type *Ty = *GEPIt;
-    if (Ty->isStructTy() && !HLMatrixLower::IsMatrixType(Ty)) {
+    if (Ty->isStructTy() && !dxilutil::IsHLSLMatrixType(Ty)) {
       // Don't go inside struct when mark hasArrayIndexing and hasVectorIndexing.
       // Don't go inside struct when mark hasArrayIndexing and hasVectorIndexing.
       // The following level won't affect scalar repl on the struct.
       // The following level won't affect scalar repl on the struct.
       break;
       break;
@@ -2250,7 +2250,7 @@ static void EltMemCpy(Type *Ty, Value *Dest, Value *Src,
 static bool IsMemCpyTy(Type *Ty, DxilTypeSystem &typeSys) {
 static bool IsMemCpyTy(Type *Ty, DxilTypeSystem &typeSys) {
   if (!Ty->isAggregateType())
   if (!Ty->isAggregateType())
     return false;
     return false;
-  if (HLMatrixLower::IsMatrixType(Ty))
+  if (dxilutil::IsHLSLMatrixType(Ty))
     return false;
     return false;
   if (dxilutil::IsHLSLObjectType(Ty))
   if (dxilutil::IsHLSLObjectType(Ty))
     return false;
     return false;
@@ -2282,7 +2282,7 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
              fieldAnnotation, bEltMemCpy);
              fieldAnnotation, bEltMemCpy);
 
 
     idxList.pop_back();
     idxList.pop_back();
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     // If no fieldAnnotation, use row major as default.
     // If no fieldAnnotation, use row major as default.
     // Only load then store immediately should be fine.
     // Only load then store immediately should be fine.
     bool bRowMajor = true;
     bool bRowMajor = true;
@@ -2389,7 +2389,7 @@ static void SplitPtr(Value *Ptr, // The root value pointer
   }
   }
   
   
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
-    if (!HLMatrixLower::IsMatrixType(Ty) && !dxilutil::IsHLSLObjectType(ST)) {
+    if (!dxilutil::IsHLSLMatrixType(Ty) && !dxilutil::IsHLSLObjectType(ST)) {
       const DxilStructAnnotation* SA = TypeSys.GetStructAnnotation(ST);
       const DxilStructAnnotation* SA = TypeSys.GetStructAnnotation(ST);
 
 
       for (uint32_t i = 0; i < ST->getNumElements(); i++) {
       for (uint32_t i = 0; i < ST->getNumElements(); i++) {
@@ -2425,7 +2425,7 @@ static void SplitPtr(Value *Ptr, // The root value pointer
       ElTy = ElAT->getElementType();
       ElTy = ElAT->getElementType();
     }
     }
 
 
-    if (ElTy->isStructTy() && !HLMatrixLower::IsMatrixType(ElTy)) {
+    if (ElTy->isStructTy() && !dxilutil::IsHLSLMatrixType(ElTy)) {
       DXASSERT(0, "Not support array of struct when split pointers.");
       DXASSERT(0, "Not support array of struct when split pointers.");
       return;
       return;
     }
     }
@@ -2443,7 +2443,7 @@ static unsigned MatchSizeByCheckElementType(Type *Ty, const DataLayout &DL, unsi
   // Size match, return current level.
   // Size match, return current level.
   if (ptrSize == size) {
   if (ptrSize == size) {
     // Not go deeper for matrix.
     // Not go deeper for matrix.
-    if (HLMatrixLower::IsMatrixType(Ty))
+    if (dxilutil::IsHLSLMatrixType(Ty))
       return level;
       return level;
     // For struct, go deeper if size not change.
     // For struct, go deeper if size not change.
     // This will leave memcpy to deeper level when flatten.
     // This will leave memcpy to deeper level when flatten.
@@ -2596,7 +2596,7 @@ void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
   // Try to find fieldAnnotation from user of Dest/Src.
   // Try to find fieldAnnotation from user of Dest/Src.
   if (!fieldAnnotation) {
   if (!fieldAnnotation) {
     Type *EltTy = dxilutil::GetArrayEltTy(DestTy);
     Type *EltTy = dxilutil::GetArrayEltTy(DestTy);
-    if (HLMatrixLower::IsMatrixType(EltTy)) {
+    if (dxilutil::IsHLSLMatrixType(EltTy)) {
       fieldAnnotation = HLMatrixLower::FindAnnotationFromMatUser(Dest, typeSys);
       fieldAnnotation = HLMatrixLower::FindAnnotationFromMatUser(Dest, typeSys);
     }
     }
   }
   }
@@ -2845,7 +2845,7 @@ void SROA_Helper::RewriteForLoad(LoadInst *LI) {
         Value *Ptr = NewElts[i];
         Value *Ptr = NewElts[i];
         Type *Ty = Ptr->getType()->getPointerElementType();
         Type *Ty = Ptr->getType()->getPointerElementType();
         Value *Load = nullptr;
         Value *Load = nullptr;
-        if (!HLMatrixLower::IsMatrixType(Ty))
+        if (!dxilutil::IsHLSLMatrixType(Ty))
           Load = Builder.CreateLoad(Ptr, "load");
           Load = Builder.CreateLoad(Ptr, "load");
         else {
         else {
           // Generate Matrix Load.
           // Generate Matrix Load.
@@ -2930,7 +2930,7 @@ void SROA_Helper::RewriteForStore(StoreInst *SI) {
       Module *M = SI->getModule();
       Module *M = SI->getModule();
       for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
       for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
         Value *Extract = Builder.CreateExtractValue(Val, i, Val->getName());
         Value *Extract = Builder.CreateExtractValue(Val, i, Val->getName());
-        if (!HLMatrixLower::IsMatrixType(Extract->getType())) {
+        if (!dxilutil::IsHLSLMatrixType(Extract->getType())) {
           Builder.CreateStore(Extract, NewElts[i]);
           Builder.CreateStore(Extract, NewElts[i]);
         } else {
         } else {
           // Generate Matrix Store.
           // Generate Matrix Store.
@@ -3393,7 +3393,7 @@ bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
   if (!Ty->isAggregateType())
   if (!Ty->isAggregateType())
     return false;
     return false;
   // Skip matrix types.
   // Skip matrix types.
-  if (HLMatrixLower::IsMatrixType(Ty))
+  if (dxilutil::IsHLSLMatrixType(Ty))
     return false;
     return false;
 
 
   IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
   IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
@@ -3440,7 +3440,7 @@ bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
 
 
     if (ElTy->isStructTy() &&
     if (ElTy->isStructTy() &&
         // Skip Matrix type.
         // Skip Matrix type.
-        !HLMatrixLower::IsMatrixType(ElTy)) {
+        !dxilutil::IsHLSLMatrixType(ElTy)) {
       if (!dxilutil::IsHLSLObjectType(ElTy)) {
       if (!dxilutil::IsHLSLObjectType(ElTy)) {
         // for array of struct
         // for array of struct
         // split into arrays of struct elements
         // split into arrays of struct elements
@@ -3569,7 +3569,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
   if (Ty->isSingleValueType() && !Ty->isVectorTy())
   if (Ty->isSingleValueType() && !Ty->isVectorTy())
     return false;
     return false;
   // Skip matrix types.
   // Skip matrix types.
-  if (HLMatrixLower::IsMatrixType(Ty))
+  if (dxilutil::IsHLSLMatrixType(Ty))
     return false;
     return false;
 
 
   Module *M = GV->getParent();
   Module *M = GV->getParent();
@@ -3638,7 +3638,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
 
 
     if (ElTy->isStructTy() &&
     if (ElTy->isStructTy() &&
         // Skip Matrix type.
         // Skip Matrix type.
-        !HLMatrixLower::IsMatrixType(ElTy)) {
+        !dxilutil::IsHLSLMatrixType(ElTy)) {
       // for array of struct
       // for array of struct
       // split into arrays of struct elements
       // split into arrays of struct elements
       StructType *ElST = cast<StructType>(ElTy);
       StructType *ElST = cast<StructType>(ElTy);
@@ -4202,7 +4202,7 @@ bool SROA_Helper::IsEmptyStructType(Type *Ty, DxilTypeSystem &typeSys) {
     Ty = Ty->getArrayElementType();
     Ty = Ty->getArrayElementType();
 
 
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
-    if (!HLMatrixLower::IsMatrixType(Ty)) {
+    if (!dxilutil::IsHLSLMatrixType(Ty)) {
       DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
       DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
       if (SA && SA->IsEmptyStruct())
       if (SA && SA->IsEmptyStruct())
         return true;
         return true;
@@ -4360,7 +4360,7 @@ public:
           continue;
           continue;
 
 
         // Check matrix store.
         // Check matrix store.
-        if (HLMatrixLower::IsMatrixType(
+        if (dxilutil::IsHLSLMatrixType(
                 GV->getType()->getPointerElementType())) {
                 GV->getType()->getPointerElementType())) {
           if (CallInst *CI = dyn_cast<CallInst>(user)) {
           if (CallInst *CI = dyn_cast<CallInst>(user)) {
             if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
             if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
@@ -4667,7 +4667,7 @@ static DxilFieldAnnotation &GetEltAnnotation(Type *Ty, unsigned idx, DxilFieldAn
   while (Ty->isArrayTy())
   while (Ty->isArrayTy())
     Ty = Ty->getArrayElementType();
     Ty = Ty->getArrayElementType();
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
-    if (HLMatrixLower::IsMatrixType(Ty))
+    if (dxilutil::IsHLSLMatrixType(Ty))
       return annotation;
       return annotation;
     DxilStructAnnotation *SA = dxilTypeSys.GetStructAnnotation(ST);
     DxilStructAnnotation *SA = dxilTypeSys.GetStructAnnotation(ST);
     if (SA) {
     if (SA) {
@@ -4735,13 +4735,13 @@ static unsigned AllocateSemanticIndex(
                                             FlatAnnotationList);
                                             FlatAnnotationList);
     }
     }
     return updatedArgIdx;
     return updatedArgIdx;
-  } else if (Ty->isStructTy() && !HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (Ty->isStructTy() && !dxilutil::IsHLSLMatrixType(Ty)) {
     unsigned fieldsCount = Ty->getStructNumElements();
     unsigned fieldsCount = Ty->getStructNumElements();
     for (unsigned i = 0; i < fieldsCount; i++) {
     for (unsigned i = 0; i < fieldsCount; i++) {
       Type *EltTy = Ty->getStructElementType(i);
       Type *EltTy = Ty->getStructElementType(i);
       argIdx = AllocateSemanticIndex(EltTy, semIndex, argIdx, endArgIdx,
       argIdx = AllocateSemanticIndex(EltTy, semIndex, argIdx, endArgIdx,
                                      FlatAnnotationList);
                                      FlatAnnotationList);
-      if (!(EltTy->isStructTy() && !HLMatrixLower::IsMatrixType(EltTy))) {
+      if (!(EltTy->isStructTy() && !dxilutil::IsHLSLMatrixType(EltTy))) {
         // Update argIdx only when it is a leaf node.
         // Update argIdx only when it is a leaf node.
         argIdx++;
         argIdx++;
       }
       }
@@ -4981,7 +4981,7 @@ CastCopyArrayMultiDimTo1Dim(Value *FromArray, Value *ToArray, Type *CurFromTy,
       Value *Elt = Builder.CreateExtractElement(V, i);
       Value *Elt = Builder.CreateExtractElement(V, i);
       Builder.CreateStore(Elt, ToPtr);
       Builder.CreateStore(Elt, ToPtr);
     }
     }
-  } else if (HLMatrixLower::IsMatrixType(CurFromTy)) {
+  } else if (dxilutil::IsHLSLMatrixType(CurFromTy)) {
     // Copy matrix to array.
     // Copy matrix to array.
     unsigned col, row;
     unsigned col, row;
     HLMatrixLower::GetMatrixInfo(CurFromTy, col, row);
     HLMatrixLower::GetMatrixInfo(CurFromTy, col, row);
@@ -5028,7 +5028,7 @@ CastCopyArray1DimToMultiDim(Value *FromArray, Value *ToArray, Type *CurToTy,
       V = Builder.CreateInsertElement(V, Elt, i);
       V = Builder.CreateInsertElement(V, Elt, i);
     }
     }
     Builder.CreateStore(V, ToPtr);
     Builder.CreateStore(V, ToPtr);
-  } else if (HLMatrixLower::IsMatrixType(CurToTy)) {
+  } else if (dxilutil::IsHLSLMatrixType(CurToTy)) {
     // Copy array to matrix.
     // Copy array to matrix.
     unsigned col, row;
     unsigned col, row;
     HLMatrixLower::GetMatrixInfo(CurToTy, col, row);
     HLMatrixLower::GetMatrixInfo(CurToTy, col, row);
@@ -5072,7 +5072,7 @@ static void CastCopyOldPtrToNewPtr(Value *OldPtr, Value *NewPtr, HLModule &HLM,
       Value *Elt = Builder.CreateExtractElement(V, i);
       Value *Elt = Builder.CreateExtractElement(V, i);
       Builder.CreateStore(Elt, EltPtr);
       Builder.CreateStore(Elt, EltPtr);
     }
     }
-  } else if (HLMatrixLower::IsMatrixType(OldTy)) {
+  } else if (dxilutil::IsHLSLMatrixType(OldTy)) {
     CopyMatPtrToArrayPtr(OldPtr, NewPtr, /*arrayBaseIdx*/ 0, HLM, Builder,
     CopyMatPtrToArrayPtr(OldPtr, NewPtr, /*arrayBaseIdx*/ 0, HLM, Builder,
                          bRowMajor);
                          bRowMajor);
   } else if (OldTy->isArrayTy()) {
   } else if (OldTy->isArrayTy()) {
@@ -5102,7 +5102,7 @@ static void CastCopyNewPtrToOldPtr(Value *NewPtr, Value *OldPtr, HLModule &HLM,
       V = Builder.CreateInsertElement(V, Elt, i);
       V = Builder.CreateInsertElement(V, Elt, i);
     }
     }
     Builder.CreateStore(V, OldPtr);
     Builder.CreateStore(V, OldPtr);
-  } else if (HLMatrixLower::IsMatrixType(OldTy)) {
+  } else if (dxilutil::IsHLSLMatrixType(OldTy)) {
     CopyArrayPtrToMatPtr(NewPtr, /*arrayBaseIdx*/ 0, OldPtr, HLM, Builder,
     CopyArrayPtrToMatPtr(NewPtr, /*arrayBaseIdx*/ 0, OldPtr, HLM, Builder,
                          bRowMajor);
                          bRowMajor);
   } else if (OldTy->isArrayTy()) {
   } else if (OldTy->isArrayTy()) {
@@ -5200,7 +5200,7 @@ void SROA_Parameter_HLSL::replaceCastParameter(
     // Must be in param.
     // Must be in param.
     // Store NewParam to OldParam at entry.
     // Store NewParam to OldParam at entry.
     Builder.CreateStore(NewParam, OldParam);
     Builder.CreateStore(NewParam, OldParam);
-  } else if (HLMatrixLower::IsMatrixType(OldTy)) {
+  } else if (dxilutil::IsHLSLMatrixType(OldTy)) {
     bool bRowMajor = castRowMajorParamMap.count(NewParam);
     bool bRowMajor = castRowMajorParamMap.count(NewParam);
     Value *Mat = LoadArrayPtrToMat(NewParam, /*arrayBaseIdx*/ 0, OldTy,
     Value *Mat = LoadArrayPtrToMat(NewParam, /*arrayBaseIdx*/ 0, OldTy,
                                    *m_pHLModule, Builder, bRowMajor);
                                    *m_pHLModule, Builder, bRowMajor);
@@ -5923,7 +5923,7 @@ static void LegalizeDxilInputOutputs(Function *F,
 
 
     // Skip arg which is not a pointer.
     // Skip arg which is not a pointer.
     if (!Ty->isPointerTy()) {
     if (!Ty->isPointerTy()) {
-      if (HLMatrixLower::IsMatrixType(Ty)) {
+      if (dxilutil::IsHLSLMatrixType(Ty)) {
         // Replace matrix arg with cast to vec. It will be lowered in
         // Replace matrix arg with cast to vec. It will be lowered in
         // DxilGenerationPass.
         // DxilGenerationPass.
         isColMajor = paramAnnotation.GetMatrixAnnotation().Orientation ==
         isColMajor = paramAnnotation.GetMatrixAnnotation().Orientation ==
@@ -5976,7 +5976,7 @@ static void LegalizeDxilInputOutputs(Function *F,
       bStoreInputToTemp = true;
       bStoreInputToTemp = true;
     }
     }
 
 
-    if (HLMatrixLower::IsMatrixType(Ty)) {
+    if (dxilutil::IsHLSLMatrixType(Ty)) {
       if (qual == DxilParamInputQual::In)
       if (qual == DxilParamInputQual::In)
         bStoreInputToTemp = bLoad;
         bStoreInputToTemp = bLoad;
       else if (qual == DxilParamInputQual::Out)
       else if (qual == DxilParamInputQual::Out)

+ 12 - 12
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -4289,7 +4289,7 @@ static void SimpleTransformForHLDXIR(Instruction *I,
   } break;
   } break;
   case Instruction::Load: {
   case Instruction::Load: {
     LoadInst *ldInst = cast<LoadInst>(I);
     LoadInst *ldInst = cast<LoadInst>(I);
-    DXASSERT(!HLMatrixLower::IsMatrixType(ldInst->getType()),
+    DXASSERT(!dxilutil::IsHLSLMatrixType(ldInst->getType()),
                       "matrix load should use HL LdStMatrix");
                       "matrix load should use HL LdStMatrix");
     Value *Ptr = ldInst->getPointerOperand();
     Value *Ptr = ldInst->getPointerOperand();
     if (ConstantExpr *CE = dyn_cast_or_null<ConstantExpr>(Ptr)) {
     if (ConstantExpr *CE = dyn_cast_or_null<ConstantExpr>(Ptr)) {
@@ -4301,7 +4301,7 @@ static void SimpleTransformForHLDXIR(Instruction *I,
   case Instruction::Store: {
   case Instruction::Store: {
     StoreInst *stInst = cast<StoreInst>(I);
     StoreInst *stInst = cast<StoreInst>(I);
     Value *V = stInst->getValueOperand();
     Value *V = stInst->getValueOperand();
-    DXASSERT_LOCALVAR(V, !HLMatrixLower::IsMatrixType(V->getType()),
+    DXASSERT_LOCALVAR(V, !dxilutil::IsHLSLMatrixType(V->getType()),
                       "matrix store should use HL LdStMatrix");
                       "matrix store should use HL LdStMatrix");
     Value *Ptr = stInst->getPointerOperand();
     Value *Ptr = stInst->getPointerOperand();
     if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) {
     if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) {
@@ -5298,7 +5298,7 @@ void CGMSHLSLRuntime::FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Val
         valEltTy->isSingleValueType()) {
         valEltTy->isSingleValueType()) {
       Value *ldVal = Builder.CreateLoad(val);
       Value *ldVal = Builder.CreateLoad(val);
       FlattenValToInitList(CGF, elts, eltTys, Ty, ldVal);
       FlattenValToInitList(CGF, elts, eltTys, Ty, ldVal);
-    } else if (HLMatrixLower::IsMatrixType(valEltTy)) {
+    } else if (dxilutil::IsHLSLMatrixType(valEltTy)) {
       Value *ldVal = EmitHLSLMatrixLoad(Builder, val, Ty);
       Value *ldVal = EmitHLSLMatrixLoad(Builder, val, Ty);
       FlattenValToInitList(CGF, elts, eltTys, Ty, ldVal);
       FlattenValToInitList(CGF, elts, eltTys, Ty, ldVal);
     } else {
     } else {
@@ -5350,7 +5350,7 @@ void CGMSHLSLRuntime::FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Val
       }
       }
     }
     }
   } else {
   } else {
-    if (HLMatrixLower::IsMatrixType(valTy)) {
+    if (dxilutil::IsHLSLMatrixType(valTy)) {
       unsigned col, row;
       unsigned col, row;
       llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(valTy, col, row);
       llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(valTy, col, row);
       // All matrix Value should be row major.
       // All matrix Value should be row major.
@@ -5492,7 +5492,7 @@ static void StoreInitListToDestPtr(Value *DestPtr,
     Result = CGF.EmitToMemory(Result, Type);
     Result = CGF.EmitToMemory(Result, Type);
     Builder.CreateStore(Result, DestPtr);
     Builder.CreateStore(Result, DestPtr);
     idx += Ty->getVectorNumElements();
     idx += Ty->getVectorNumElements();
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     bool isRowMajor =
     bool isRowMajor =
         IsRowMajorMatrix(Type, bDefaultRowMajor);
         IsRowMajorMatrix(Type, bDefaultRowMajor);
 
 
@@ -5783,7 +5783,7 @@ static void FlatConstToList(Constant *C, SmallVector<Constant *, 4> &EltValList,
       FlatConstToList(C->getAggregateElement(i), EltValList, Type, Types,
       FlatConstToList(C->getAggregateElement(i), EltValList, Type, Types,
                       bDefaultRowMajor);
                       bDefaultRowMajor);
     }
     }
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     bool isRowMajor = IsRowMajorMatrix(Type, bDefaultRowMajor);
     bool isRowMajor = IsRowMajorMatrix(Type, bDefaultRowMajor);
     // matrix type is struct { vector<Ty, row> [col] };
     // matrix type is struct { vector<Ty, row> [col] };
     // Strip the struct level here.
     // Strip the struct level here.
@@ -6005,7 +6005,7 @@ static Constant *BuildConstInitializer(QualType Type, unsigned &offset,
   } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
   } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
     return BuildConstArray(AT, offset, EltValList, Type, Types,
     return BuildConstArray(AT, offset, EltValList, Type, Types,
                            bDefaultRowMajor);
                            bDefaultRowMajor);
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     return BuildConstMatrix(Ty, offset, EltValList, Type, Types,
     return BuildConstMatrix(Ty, offset, EltValList, Type, Types,
                             bDefaultRowMajor);
                             bDefaultRowMajor);
   } else if (StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
   } else if (StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
@@ -6509,7 +6509,7 @@ void CGMSHLSLRuntime::FlattenAggregatePtrToGepList(
                                  GepList, EltTyList);
                                  GepList, EltTyList);
 
 
     idxList.pop_back();
     idxList.pop_back();
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     // Use matLd/St for matrix.
     // Use matLd/St for matrix.
     unsigned col, row;
     unsigned col, row;
     llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
     llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
@@ -6671,7 +6671,7 @@ void CGMSHLSLRuntime::EmitHLSLAggregateCopy(
                           PT->getElementType());
                           PT->getElementType());
 
 
     idxList.pop_back();
     idxList.pop_back();
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     // Use matLd/St for matrix.
     // Use matLd/St for matrix.
     Value *srcGEP = CGF.Builder.CreateInBoundsGEP(SrcPtr, idxList);
     Value *srcGEP = CGF.Builder.CreateInBoundsGEP(SrcPtr, idxList);
     Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
     Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
@@ -6871,7 +6871,7 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversion(
                                       SrcType, PT->getElementType());
                                       SrcType, PT->getElementType());
 
 
     idxList.pop_back();
     idxList.pop_back();
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
+  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
     // Use matLd/St for matrix.
     // Use matLd/St for matrix.
     Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
     Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
     unsigned row, col;
     unsigned row, col;
@@ -7133,7 +7133,7 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
                 outVal->getType(), ToTy));
                 outVal->getType(), ToTy));
 
 
         Value *castVal = CGF.Builder.CreateCast(castOp, outVal, ToTy);
         Value *castVal = CGF.Builder.CreateCast(castOp, outVal, ToTy);
-        if (!HLMatrixLower::IsMatrixType(ToTy))
+        if (!dxilutil::IsHLSLMatrixType(ToTy))
           CGF.Builder.CreateStore(castVal, tmpArgAddr);
           CGF.Builder.CreateStore(castVal, tmpArgAddr);
         else
         else
           EmitHLSLMatrixStore(CGF, castVal, tmpArgAddr, ParamTy);
           EmitHLSLMatrixStore(CGF, castVal, tmpArgAddr, ParamTy);
@@ -7201,7 +7201,7 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionCopyBack(
 
 
           castVal = CGF.Builder.CreateCast(castOp, outVal, ToTy);
           castVal = CGF.Builder.CreateCast(castOp, outVal, ToTy);
         }
         }
-        if (!HLMatrixLower::IsMatrixType(ToTy))
+        if (!dxilutil::IsHLSLMatrixType(ToTy))
           CGF.EmitStoreThroughLValue(RValue::get(castVal), argLV);
           CGF.EmitStoreThroughLValue(RValue::get(castVal), argLV);
         else {
         else {
           Value *destPtr = argLV.getAddress();
           Value *destPtr = argLV.getAddress();

+ 4 - 4
tools/clang/tools/dxcompiler/dxcdisassembler.cpp

@@ -784,9 +784,9 @@ void PrintFieldLayout(llvm::Type *Ty, DxilFieldAnnotation &annotation,
     llvm::Type *EltTy = Ty;
     llvm::Type *EltTy = Ty;
     unsigned arraySize = 0;
     unsigned arraySize = 0;
     unsigned arrayLevel = 0;
     unsigned arrayLevel = 0;
-    if (!HLMatrixLower::IsMatrixType(EltTy) && EltTy->isArrayTy()) {
+    if (!dxilutil::IsHLSLMatrixType(EltTy) && EltTy->isArrayTy()) {
       arraySize = 1;
       arraySize = 1;
-      while (!HLMatrixLower::IsMatrixType(EltTy) && EltTy->isArrayTy()) {
+      while (!dxilutil::IsHLSLMatrixType(EltTy) && EltTy->isArrayTy()) {
         arraySize *= EltTy->getArrayNumElements();
         arraySize *= EltTy->getArrayNumElements();
         EltTy = EltTy->getArrayElementType();
         EltTy = EltTy->getArrayElementType();
         arrayLevel++;
         arrayLevel++;
@@ -817,7 +817,7 @@ void PrintFieldLayout(llvm::Type *Ty, DxilFieldAnnotation &annotation,
     }
     }
 
 
     std::string StreamStr;
     std::string StreamStr;
-    if (!HLMatrixLower::IsMatrixType(EltTy) && EltTy->isStructTy()) {
+    if (!dxilutil::IsHLSLMatrixType(EltTy) && EltTy->isStructTy()) {
       std::string NameTypeStr = annotation.GetFieldName();
       std::string NameTypeStr = annotation.GetFieldName();
       raw_string_ostream Stream(NameTypeStr);
       raw_string_ostream Stream(NameTypeStr);
       if (arraySize)
       if (arraySize)
@@ -901,7 +901,7 @@ void PrintStructBufferDefinition(DxilResource *buf,
   OS << comment << "\n";
   OS << comment << "\n";
   llvm::Type *RetTy = buf->GetRetType();
   llvm::Type *RetTy = buf->GetRetType();
   // Skip none struct type.
   // Skip none struct type.
-  if (!RetTy->isStructTy() || HLMatrixLower::IsMatrixType(RetTy)) {
+  if (!RetTy->isStructTy() || dxilutil::IsHLSLMatrixType(RetTy)) {
     llvm::Type *Ty = buf->GetGlobalSymbol()->getType()->getPointerElementType();
     llvm::Type *Ty = buf->GetGlobalSymbol()->getType()->getPointerElementType();
     // For resource array, use element type.
     // For resource array, use element type.
     if (Ty->isArrayTy())
     if (Ty->isArrayTy())

+ 1 - 0
utils/hct/hctdb.py

@@ -1483,6 +1483,7 @@ class db_dxil(object):
         # UseNewSROA is used by PassManagerBuilder::populateFunctionPassManager, not a pass per se.
         # UseNewSROA is used by PassManagerBuilder::populateFunctionPassManager, not a pass per se.
         add_pass("sroa", "SROA", "Scalar Replacement Of Aggregates", [
         add_pass("sroa", "SROA", "Scalar Replacement Of Aggregates", [
             {'n':'RequiresDomTree', 't':'bool', 'c':1},
             {'n':'RequiresDomTree', 't':'bool', 'c':1},
+            {'n':'SkipHLSLMat', 't':'bool', 'c':1},
             {'n':'force-ssa-updater', 'i':'ForceSSAUpdater', 't':'bool', 'd':'Force the pass to not use DomTree and mem2reg, insteadforming SSA values through the SSAUpdater infrastructure.'},
             {'n':'force-ssa-updater', 'i':'ForceSSAUpdater', 't':'bool', 'd':'Force the pass to not use DomTree and mem2reg, insteadforming SSA values through the SSAUpdater infrastructure.'},
             {'n':'sroa-random-shuffle-slices', 'i':'SROARandomShuffleSlices', 't':'bool', 'd':'Enable randomly shuffling the slices to help uncover instability in their order.'},
             {'n':'sroa-random-shuffle-slices', 'i':'SROARandomShuffleSlices', 't':'bool', 'd':'Enable randomly shuffling the slices to help uncover instability in their order.'},
             {'n':'sroa-strict-inbounds', 'i':'SROAStrictInbounds', 't':'bool', 'd':'Experiment with completely strict handling of inbounds GEPs.'}])
             {'n':'sroa-strict-inbounds', 'i':'SROAStrictInbounds', 't':'bool', 'd':'Experiment with completely strict handling of inbounds GEPs.'}])