Преглед на файлове

Eliminate dxilutil::IsHLSLMatrixType in favor of HLMatrixType::isa (#1986)

Tristan Labelle преди 6 години
родител
ревизия
f56fb3a3d0

+ 0 - 1
include/dxc/DXIL/DxilUtil.h

@@ -107,7 +107,6 @@ namespace dxilutil {
   bool ContainsHLSLObjectType(llvm::Type *Ty);
   bool IsHLSLResourceType(llvm::Type *Ty);
   bool IsHLSLObjectType(llvm::Type *Ty);
-  bool IsHLSLMatrixType(llvm::Type *Ty);
   bool IsSplat(llvm::ConstantDataVector *cdv);
 }
 

+ 0 - 14
lib/DXIL/DxilUtil.cpp

@@ -479,20 +479,6 @@ bool IsHLSLObjectType(llvm::Type *Ty) {
   return false;
 }
 
-bool IsHLSLMatrixType(Type *Ty) {
-  if (StructType *ST = dyn_cast<StructType>(Ty)) {
-    Type *EltTy = ST->getElementType(0);
-    if (!ST->getName().startswith("class.matrix"))
-      return false;
-
-    bool isVecArray =
-        EltTy->isArrayTy() && EltTy->getArrayElementType()->isVectorTy();
-
-    return isVecArray && EltTy->getArrayNumElements() <= 4;
-  }
-  return false;
-}
-
 bool IsIntegerOrFloatingPointType(llvm::Type *Ty) {
   return Ty->isIntegerTy() || Ty->isFloatingPointTy();
 }

+ 2 - 1
lib/HLSL/DxilCondenseResources.cpp

@@ -19,6 +19,7 @@
 #include "dxc/HLSL/DxilSpanAllocator.h"
 #include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/DXIL/DxilUtil.h"
+#include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/HLSL/HLModule.h"
 
 #include "llvm/IR/Instructions.h"
@@ -1534,7 +1535,7 @@ Type *UpdateFieldTypeForLegacyLayout(Type *Ty, bool IsCBuf,
       return Ty;
     else
       return ArrayType::get(UpdatedTy, Ty->getArrayNumElements());
-  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
+  } else if (hlsl::HLMatrixType::isa(Ty)) {
     DXASSERT(annotation.HasMatrixAnnotation(), "must a matrix");
     HLMatrixType MatTy = HLMatrixType::cast(Ty);
     unsigned rows = MatTy.getNumRows();

+ 3 - 4
lib/HLSL/HLMatrixBitcastLowerPass.cpp

@@ -67,8 +67,7 @@ Type *TryLowerMatTy(Type *Ty) {
   Type *VecTy = nullptr;
   if (HLMatrixType::isMatrixArrayPtr(Ty)) {
     VecTy = LowerMatrixArrayPointerToOneDimArray(Ty);
-  } else if (isa<PointerType>(Ty) &&
-             dxilutil::IsHLSLMatrixType(Ty->getPointerElementType())) {
+  } else if (isa<PointerType>(Ty) && HLMatrixType::isa(Ty->getPointerElementType())) {
     VecTy = LowerMatrixTypeToOneDimArray(
         Ty->getPointerElementType());
     VecTy = PointerType::get(VecTy, Ty->getPointerAddressSpace());
@@ -130,7 +129,7 @@ bool MatrixBitcastLowerPass::hasCallUser(Instruction *M) {
     User *U = *(it++);
     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
       Type *EltTy = GEP->getType()->getPointerElementType();
-      if (dxilutil::IsHLSLMatrixType(EltTy)) {
+      if (HLMatrixType::isa(EltTy)) {
         if (hasCallUser(GEP))
           return true;
       } else {
@@ -185,7 +184,7 @@ void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
     User *U = *(it++);
     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
       Type *EltTy = GEP->getType()->getPointerElementType();
-      if (dxilutil::IsHLSLMatrixType(EltTy)) {
+      if (HLMatrixType::isa(EltTy)) {
         // Change gep matrixArray, 0, index
         // into
         //   gep oneDimArray, 0, index * matSize

+ 2 - 2
lib/HLSL/HLSignatureLower.cpp

@@ -625,14 +625,14 @@ void replaceDirectInputParameter(Value *param, Function *loadInput,
       newVec = Builder.CreateInsertElement(newVec, input, col);
     }
     param->replaceAllUsesWith(newVec);
-  } else if (!Ty->isArrayTy() && !dxilutil::IsHLSLMatrixType(Ty)) {
+  } else if (!Ty->isArrayTy() && !HLMatrixType::isa(Ty)) {
     DXASSERT(cols == 1, "only support scalar here");
     Value *colIdx = hlslOP->GetU8Const(0);
     args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
     Value *input =
         GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
     param->replaceAllUsesWith(input);
-  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
+  } else if (HLMatrixType::isa(Ty)) {
     if (param->use_empty()) return;
     DXASSERT(param->hasOneUse(),
              "matrix arg should only has one use as matrix to vec");

+ 11 - 10
lib/Transforms/Scalar/SROA.cpp

@@ -56,7 +56,8 @@
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/PromoteMemToReg.h"
 #include "llvm/Transforms/Utils/SSAUpdater.h"
-#include "dxc/DXIL/DxilUtil.h"  // HLSL Change - not sroa resource type.
+#include "dxc/DXIL/DxilUtil.h"  // HLSL Change - don't sroa resource type.
+#include "dxc/HLSL/HLMatrixType.h"  // HLSL Change - don't sroa matrix types.
 
 #if __cplusplus >= 201103L && !defined(NDEBUG)
 // We only use this for a debug check in C++11
@@ -697,14 +698,14 @@ private:
     // HLSL Change Begin - not sroa matrix type.
     if (PointerType *PT = dyn_cast<PointerType>(BC.getType())) {
       Type *EltTy = PT->getElementType();
-      if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(EltTy)) ||
+      if ((SkipHLSLMat && hlsl::HLMatrixType::isa(EltTy)) ||
           hlsl::dxilutil::IsHLSLObjectType(EltTy)) {
         AS.PointerEscapingInstr = &BC;
         return;
       }
       if (PointerType *SrcPT = dyn_cast<PointerType>(BC.getSrcTy())) {
         Type *SrcEltTy = SrcPT->getElementType();
-        if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(SrcEltTy)) ||
+        if ((SkipHLSLMat && hlsl::HLMatrixType::isa(SrcEltTy)) ||
             hlsl::dxilutil::IsHLSLObjectType(SrcEltTy)) {
           AS.PointerEscapingInstr = &BC;
           return;
@@ -773,7 +774,7 @@ private:
 
   void visitLoadInst(LoadInst &LI) {
     // HLSL Change Begin - not sroa matrix type.
-    if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(LI.getType())) ||
+    if ((SkipHLSLMat && hlsl::HLMatrixType::isa(LI.getType())) ||
         hlsl::dxilutil::IsHLSLObjectType(LI.getType()))
       return PI.setEscapedAndAborted(&LI);
     // HLSL Change End.
@@ -794,7 +795,7 @@ private:
     if (ValOp == *U)
       return PI.setEscapedAndAborted(&SI);
     // HLSL Change Begin - not sroa matrix type.
-    if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(ValOp->getType())) ||
+    if ((SkipHLSLMat && hlsl::HLMatrixType::isa(ValOp->getType())) ||
         hlsl::dxilutil::IsHLSLObjectType(ValOp->getType()))
       return PI.setEscapedAndAborted(&SI);
     // HLSL Change End.
@@ -3364,7 +3365,7 @@ private:
     if (!LI.isSimple() || LI.getType()->isSingleValueType())
       return false;
     // HLSL Change Begin - not sroa matrix type.
-    if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(LI.getType())) ||
+    if ((SkipHLSLMat && hlsl::HLMatrixType::isa(LI.getType())) ||
         hlsl::dxilutil::IsHLSLObjectType(LI.getType()))
       return false;
     // HLSL Change End.
@@ -3403,7 +3404,7 @@ private:
     if (V->getType()->isSingleValueType())
       return false;
     // HLSL Change Begin - not sroa matrix type.
-    if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(V->getType())) ||
+    if ((SkipHLSLMat && hlsl::HLMatrixType::isa(V->getType())) ||
         hlsl::dxilutil::IsHLSLObjectType(V->getType()))
       return false;
     // HLSL Change End.
@@ -3419,12 +3420,12 @@ private:
     // HLSL Change Begin - not sroa matrix type.
     if (PointerType *PT = dyn_cast<PointerType>(BC.getType())) {
       Type *EltTy = PT->getElementType();
-      if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(EltTy)) ||
+      if ((SkipHLSLMat && hlsl::HLMatrixType::isa(EltTy)) ||
           hlsl::dxilutil::IsHLSLObjectType(EltTy))
         return false;
       if (PointerType *SrcPT = dyn_cast<PointerType>(BC.getSrcTy())) {
         Type *SrcEltTy = SrcPT->getElementType();
-        if ((SkipHLSLMat && hlsl::dxilutil::IsHLSLMatrixType(SrcEltTy)) ||
+        if ((SkipHLSLMat && hlsl::HLMatrixType::isa(SrcEltTy)) ||
             hlsl::dxilutil::IsHLSLObjectType(SrcEltTy))
           return false;
       }
@@ -4381,7 +4382,7 @@ bool SROA::runOnAlloca(AllocaInst &AI) {
           AI.getAllocatedType()) || // HLSL Change - not sroa resource type.
       // HLSL Change Begin - not sroa matrix type.
       (SkipHLSLMat &&
-       hlsl::dxilutil::IsHLSLMatrixType(AI.getAllocatedType())) ||
+       hlsl::HLMatrixType::isa(AI.getAllocatedType())) ||
       // HLSL Change End.
       DL.getTypeAllocSize(AI.getAllocatedType()) == 0)
     return false;

+ 24 - 24
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -1644,7 +1644,7 @@ bool SROA_HLSL::performScalarRepl(Function &F, DxilTypeSystem &typeSys) {
         Type *Ty = AI->getAllocatedType();
         // Skip empty struct parameters.
         if (StructType *ST = dyn_cast<StructType>(Ty)) {
-          if (!dxilutil::IsHLSLMatrixType(Ty)) {
+          if (!HLMatrixType::isa(Ty)) {
             DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
             if (SA && SA->IsEmptyStruct()) {
               for (User *U : AI->users()) {
@@ -1908,7 +1908,7 @@ void SROA_HLSL::isSafeGEP(GetElementPtrInst *GEPI, uint64_t &Offset,
 
   for (;GEPIt != E; ++GEPIt) {
     Type *Ty = *GEPIt;
-    if (Ty->isStructTy() && !dxilutil::IsHLSLMatrixType(Ty)) {
+    if (Ty->isStructTy() && !HLMatrixType::isa(Ty)) {
       // Don't go inside struct when mark hasArrayIndexing and hasVectorIndexing.
       // The following level won't affect scalar repl on the struct.
       break;
@@ -2274,7 +2274,7 @@ static void EltMemCpy(Type *Ty, Value *Dest, Value *Src,
 static bool IsMemCpyTy(Type *Ty, DxilTypeSystem &typeSys) {
   if (!Ty->isAggregateType())
     return false;
-  if (dxilutil::IsHLSLMatrixType(Ty))
+  if (HLMatrixType::isa(Ty))
     return false;
   if (dxilutil::IsHLSLObjectType(Ty))
     return false;
@@ -2306,7 +2306,7 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
              fieldAnnotation, bEltMemCpy);
 
     idxList.pop_back();
-  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
+  } else if (HLMatrixType::isa(Ty)) {
     // If no fieldAnnotation, use row major as default.
     // Only load then store immediately should be fine.
     bool bRowMajor = true;
@@ -2413,7 +2413,7 @@ static void SplitPtr(Value *Ptr, // The root value pointer
   }
   
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
-    if (!dxilutil::IsHLSLMatrixType(Ty) && !dxilutil::IsHLSLObjectType(ST)) {
+    if (!HLMatrixType::isa(Ty) && !dxilutil::IsHLSLObjectType(ST)) {
       const DxilStructAnnotation* SA = TypeSys.GetStructAnnotation(ST);
 
       for (uint32_t i = 0; i < ST->getNumElements(); i++) {
@@ -2449,7 +2449,7 @@ static void SplitPtr(Value *Ptr, // The root value pointer
       ElTy = ElAT->getElementType();
     }
 
-    if (ElTy->isStructTy() && !dxilutil::IsHLSLMatrixType(ElTy)) {
+    if (ElTy->isStructTy() && !HLMatrixType::isa(ElTy)) {
       DXASSERT(0, "Not support array of struct when split pointers.");
       return;
     }
@@ -2467,7 +2467,7 @@ static unsigned MatchSizeByCheckElementType(Type *Ty, const DataLayout &DL, unsi
   // Size match, return current level.
   if (ptrSize == size) {
     // Do not go deeper for matrix or object.
-    if (dxilutil::IsHLSLMatrixType(Ty) || dxilutil::IsHLSLObjectType(Ty))
+    if (HLMatrixType::isa(Ty) || dxilutil::IsHLSLObjectType(Ty))
       return level;
     // For struct, go deeper if size not change.
     // This will leave memcpy to deeper level when flatten.
@@ -2638,7 +2638,7 @@ void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
   // Try to find fieldAnnotation from user of Dest/Src.
   if (!fieldAnnotation) {
     Type *EltTy = dxilutil::GetArrayEltTy(DestTy);
-    if (dxilutil::IsHLSLMatrixType(EltTy)) {
+    if (HLMatrixType::isa(EltTy)) {
       fieldAnnotation = FindAnnotationFromMatUser(Dest, typeSys);
     }
   }
@@ -2885,7 +2885,7 @@ void SROA_Helper::RewriteForLoad(LoadInst *LI) {
         Value *Ptr = NewElts[i];
         Type *Ty = Ptr->getType()->getPointerElementType();
         Value *Load = nullptr;
-        if (!dxilutil::IsHLSLMatrixType(Ty))
+        if (!HLMatrixType::isa(Ty))
           Load = Builder.CreateLoad(Ptr, "load");
         else {
           // Generate Matrix Load.
@@ -2971,7 +2971,7 @@ void SROA_Helper::RewriteForStore(StoreInst *SI) {
       Module *M = SI->getModule();
       for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
         Value *Extract = Builder.CreateExtractValue(Val, i, Val->getName());
-        if (!dxilutil::IsHLSLMatrixType(Extract->getType())) {
+        if (!HLMatrixType::isa(Extract->getType())) {
           Builder.CreateStore(Extract, NewElts[i]);
         } else {
           // Generate Matrix Store.
@@ -3463,7 +3463,7 @@ bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
   if (!Ty->isAggregateType())
     return false;
   // Skip matrix types.
-  if (dxilutil::IsHLSLMatrixType(Ty))
+  if (HLMatrixType::isa(Ty))
     return false;
 
   IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
@@ -3510,7 +3510,7 @@ bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
 
     if (ElTy->isStructTy() &&
         // Skip Matrix type.
-        !dxilutil::IsHLSLMatrixType(ElTy)) {
+        !HLMatrixType::isa(ElTy)) {
       if (!dxilutil::IsHLSLObjectType(ElTy)) {
         // for array of struct
         // split into arrays of struct elements
@@ -3639,7 +3639,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
   if (Ty->isSingleValueType() && !Ty->isVectorTy())
     return false;
   // Skip matrix types.
-  if (dxilutil::IsHLSLMatrixType(Ty))
+  if (HLMatrixType::isa(Ty))
     return false;
 
   Module *M = GV->getParent();
@@ -3708,7 +3708,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
 
     if (ElTy->isStructTy() &&
         // Skip Matrix type.
-        !dxilutil::IsHLSLMatrixType(ElTy)) {
+        !HLMatrixType::isa(ElTy)) {
       // for array of struct
       // split into arrays of struct elements
       StructType *ElST = cast<StructType>(ElTy);
@@ -4291,7 +4291,7 @@ bool SROA_Helper::IsEmptyStructType(Type *Ty, DxilTypeSystem &typeSys) {
     Ty = Ty->getArrayElementType();
 
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
-    if (!dxilutil::IsHLSLMatrixType(Ty)) {
+    if (!HLMatrixType::isa(Ty)) {
       DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
       if (SA && SA->IsEmptyStruct())
         return true;
@@ -4449,7 +4449,7 @@ public:
           continue;
 
         // Check matrix store.
-        if (dxilutil::IsHLSLMatrixType(
+        if (HLMatrixType::isa(
                 GV->getType()->getPointerElementType())) {
           if (CallInst *CI = dyn_cast<CallInst>(user)) {
             if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
@@ -4762,7 +4762,7 @@ static DxilFieldAnnotation &GetEltAnnotation(Type *Ty, unsigned idx, DxilFieldAn
   while (Ty->isArrayTy())
     Ty = Ty->getArrayElementType();
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
-    if (dxilutil::IsHLSLMatrixType(Ty))
+    if (HLMatrixType::isa(Ty))
       return annotation;
     DxilStructAnnotation *SA = dxilTypeSys.GetStructAnnotation(ST);
     if (SA) {
@@ -4830,13 +4830,13 @@ static unsigned AllocateSemanticIndex(
                                             FlatAnnotationList);
     }
     return updatedArgIdx;
-  } else if (Ty->isStructTy() && !dxilutil::IsHLSLMatrixType(Ty)) {
+  } else if (Ty->isStructTy() && !HLMatrixType::isa(Ty)) {
     unsigned fieldsCount = Ty->getStructNumElements();
     for (unsigned i = 0; i < fieldsCount; i++) {
       Type *EltTy = Ty->getStructElementType(i);
       argIdx = AllocateSemanticIndex(EltTy, semIndex, argIdx, endArgIdx,
                                      FlatAnnotationList);
-      if (!(EltTy->isStructTy() && !dxilutil::IsHLSLMatrixType(EltTy))) {
+      if (!(EltTy->isStructTy() && !HLMatrixType::isa(EltTy))) {
         // Update argIdx only when it is a leaf node.
         argIdx++;
       }
@@ -5159,7 +5159,7 @@ static void CastCopyOldPtrToNewPtr(Value *OldPtr, Value *NewPtr, HLModule &HLM,
       Value *Elt = Builder.CreateExtractElement(V, i);
       Builder.CreateStore(Elt, EltPtr);
     }
-  } else if (dxilutil::IsHLSLMatrixType(OldTy)) {
+  } else if (HLMatrixType::isa(OldTy)) {
     CopyMatPtrToArrayPtr(OldPtr, NewPtr, /*arrayBaseIdx*/ 0, HLM, Builder,
                          bRowMajor);
   } else if (OldTy->isArrayTy()) {
@@ -5189,7 +5189,7 @@ static void CastCopyNewPtrToOldPtr(Value *NewPtr, Value *OldPtr, HLModule &HLM,
       V = Builder.CreateInsertElement(V, Elt, i);
     }
     Builder.CreateStore(V, OldPtr);
-  } else if (dxilutil::IsHLSLMatrixType(OldTy)) {
+  } else if (HLMatrixType::isa(OldTy)) {
     CopyArrayPtrToMatPtr(NewPtr, /*arrayBaseIdx*/ 0, OldPtr, HLM, Builder,
                          bRowMajor);
   } else if (OldTy->isArrayTy()) {
@@ -5287,7 +5287,7 @@ void SROA_Parameter_HLSL::replaceCastParameter(
     // Must be in param.
     // Store NewParam to OldParam at entry.
     Builder.CreateStore(NewParam, OldParam);
-  } else if (dxilutil::IsHLSLMatrixType(OldTy)) {
+  } else if (HLMatrixType::isa(OldTy)) {
     bool bRowMajor = castRowMajorParamMap.count(NewParam);
     Value *Mat = LoadArrayPtrToMat(NewParam, /*arrayBaseIdx*/ 0, OldTy,
                                    *m_pHLModule, Builder, bRowMajor);
@@ -6024,7 +6024,7 @@ static void LegalizeDxilInputOutputs(Function *F,
 
     // Skip arg which is not a pointer.
     if (!Ty->isPointerTy()) {
-      if (dxilutil::IsHLSLMatrixType(Ty)) {
+      if (HLMatrixType::isa(Ty)) {
         // Replace matrix arg with cast to vec. It will be lowered in
         // DxilGenerationPass.
         isColMajor = paramAnnotation.GetMatrixAnnotation().Orientation ==
@@ -6077,7 +6077,7 @@ static void LegalizeDxilInputOutputs(Function *F,
       bStoreInputToTemp = true;
     }
 
-    if (dxilutil::IsHLSLMatrixType(Ty)) {
+    if (HLMatrixType::isa(Ty)) {
       if (qual == DxilParamInputQual::In)
         bStoreInputToTemp = bLoad;
       else if (qual == DxilParamInputQual::Out)

+ 7 - 7
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -4289,7 +4289,7 @@ static void SimpleTransformForHLDXIR(Instruction *I,
   } break;
   case Instruction::Load: {
     LoadInst *ldInst = cast<LoadInst>(I);
-    DXASSERT(!dxilutil::IsHLSLMatrixType(ldInst->getType()),
+    DXASSERT(!HLMatrixType::isa(ldInst->getType()),
                       "matrix load should use HL LdStMatrix");
     Value *Ptr = ldInst->getPointerOperand();
     if (ConstantExpr *CE = dyn_cast_or_null<ConstantExpr>(Ptr)) {
@@ -4301,7 +4301,7 @@ static void SimpleTransformForHLDXIR(Instruction *I,
   case Instruction::Store: {
     StoreInst *stInst = cast<StoreInst>(I);
     Value *V = stInst->getValueOperand();
-    DXASSERT_LOCALVAR(V, !dxilutil::IsHLSLMatrixType(V->getType()),
+    DXASSERT_LOCALVAR(V, !HLMatrixType::isa(V->getType()),
                       "matrix store should use HL LdStMatrix");
     Value *Ptr = stInst->getPointerOperand();
     if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) {
@@ -5280,7 +5280,7 @@ void CGMSHLSLRuntime::FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Val
         valEltTy->isSingleValueType()) {
       Value *ldVal = Builder.CreateLoad(val);
       FlattenValToInitList(CGF, elts, eltTys, Ty, ldVal);
-    } else if (dxilutil::IsHLSLMatrixType(valEltTy)) {
+    } else if (HLMatrixType::isa(valEltTy)) {
       Value *ldVal = EmitHLSLMatrixLoad(Builder, val, Ty);
       FlattenValToInitList(CGF, elts, eltTys, Ty, ldVal);
     } else {
@@ -5767,7 +5767,7 @@ static void FlatConstToList(CodeGenTypes &Types, bool bDefaultRowMajor,
       EltVals.emplace_back(C->getAggregateElement(i));
       EltQualTys.emplace_back(VecElemQualTy);
     }
-  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
+  } else if (HLMatrixType::isa(Ty)) {
     DXASSERT(hlsl::IsHLSLMatType(QualTy), "QualType/Type mismatch!");
     // matrix type is struct { [rowcount x <colcount x T>] };
     // Strip the struct level here.
@@ -6687,7 +6687,7 @@ void CGMSHLSLRuntime::EmitHLSLAggregateCopy(
                           PT->getElementType());
 
     idxList.pop_back();
-  } else if (dxilutil::IsHLSLMatrixType(Ty)) {
+  } else if (HLMatrixType::isa(Ty)) {
     // Use matLd/St for matrix.
     Value *srcGEP = CGF.Builder.CreateInBoundsGEP(SrcPtr, idxList);
     Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
@@ -7181,7 +7181,7 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
         }
 
         llvm::Type *ToTy = tmpArgAddr->getType()->getPointerElementType();
-        if (dxilutil::IsHLSLMatrixType(ToTy)) {
+        if (HLMatrixType::isa(ToTy)) {
           Value *castVal = CGF.Builder.CreateBitCast(outVal, ToTy);
           EmitHLSLMatrixStore(CGF, castVal, tmpArgAddr, ParamTy);
         }
@@ -7250,7 +7250,7 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionCopyBack(
           castVal = ConvertScalarOrVector(CGF,
             outVal, tmpLV.getType(), argLV.getType());
         }
-        if (!dxilutil::IsHLSLMatrixType(ToTy))
+        if (!HLMatrixType::isa(ToTy))
           CGF.EmitStoreThroughLValue(RValue::get(castVal), argLV);
         else {
           Value *destPtr = argLV.getAddress();

+ 4 - 4
tools/clang/tools/dxcompiler/dxcdisassembler.cpp

@@ -784,9 +784,9 @@ void PrintFieldLayout(llvm::Type *Ty, DxilFieldAnnotation &annotation,
     llvm::Type *EltTy = Ty;
     unsigned arraySize = 0;
     unsigned arrayLevel = 0;
-    if (!dxilutil::IsHLSLMatrixType(EltTy) && EltTy->isArrayTy()) {
+    if (!HLMatrixType::isa(EltTy) && EltTy->isArrayTy()) {
       arraySize = 1;
-      while (!dxilutil::IsHLSLMatrixType(EltTy) && EltTy->isArrayTy()) {
+      while (!HLMatrixType::isa(EltTy) && EltTy->isArrayTy()) {
         arraySize *= EltTy->getArrayNumElements();
         EltTy = EltTy->getArrayElementType();
         arrayLevel++;
@@ -816,7 +816,7 @@ void PrintFieldLayout(llvm::Type *Ty, DxilFieldAnnotation &annotation,
     }
 
     std::string StreamStr;
-    if (!dxilutil::IsHLSLMatrixType(EltTy) && EltTy->isStructTy()) {
+    if (!HLMatrixType::isa(EltTy) && EltTy->isStructTy()) {
       std::string NameTypeStr = annotation.GetFieldName();
       raw_string_ostream Stream(NameTypeStr);
       if (arraySize)
@@ -900,7 +900,7 @@ void PrintStructBufferDefinition(DxilResource *buf,
   OS << comment << "\n";
   llvm::Type *RetTy = buf->GetRetType();
   // Skip none struct type.
-  if (!RetTy->isStructTy() || dxilutil::IsHLSLMatrixType(RetTy)) {
+  if (!RetTy->isStructTy() || HLMatrixType::isa(RetTy)) {
     llvm::Type *Ty = buf->GetGlobalSymbol()->getType()->getPointerElementType();
     // For resource array, use element type.
     if (Ty->isArrayTy())