|  | @@ -17,7 +17,8 @@
 | 
	
		
			
				|  |  |  #include "dxc/HlslIntrinsicOp.h"
 | 
	
		
			
				|  |  |  #include "dxc/Support/Global.h"
 | 
	
		
			
				|  |  |  #include "dxc/HLSL/DxilOperations.h"
 | 
	
		
			
				|  |  | -#include "dxc/hlsl/DxilTypeSystem.h"
 | 
	
		
			
				|  |  | +#include "dxc/HLSL/DxilTypeSystem.h"
 | 
	
		
			
				|  |  | +#include "dxc/HLSL/DxilModule.h"
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  #include "llvm/IR/IRBuilder.h"
 | 
	
		
			
				|  |  |  #include "llvm/IR/Module.h"
 | 
	
	
		
			
				|  | @@ -89,6 +90,18 @@ Type *LowerMatrixType(Type *Ty) {
 | 
	
		
			
				|  |  |    }
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +// Translate matrix type to array type.
 | 
	
		
			
				|  |  | +Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
 | 
	
		
			
				|  |  | +  if (IsMatrixType(Ty)) {
 | 
	
		
			
				|  |  | +    unsigned row, col;
 | 
	
		
			
				|  |  | +    Type *EltTy = GetMatrixInfo(Ty, col, row);
 | 
	
		
			
				|  |  | +    return ArrayType::get(EltTy, row * col);
 | 
	
		
			
				|  |  | +  } else {
 | 
	
		
			
				|  |  | +    return Ty;
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
 | 
	
		
			
				|  |  |    DXASSERT(IsMatrixType(Ty), "not matrix type");
 | 
	
		
			
				|  |  |    StructType *ST = cast<StructType>(Ty);
 | 
	
	
		
			
				|  | @@ -110,6 +123,7 @@ bool IsMatrixArrayPointer(llvm::Type *Ty) {
 | 
	
		
			
				|  |  |    return IsMatrixType(Ty);
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  Type *LowerMatrixArrayPointer(Type *Ty) {
 | 
	
		
			
				|  |  | +  unsigned addrSpace = Ty->getPointerAddressSpace();
 | 
	
		
			
				|  |  |    Ty = Ty->getPointerElementType();
 | 
	
		
			
				|  |  |    std::vector<unsigned> arraySizeList;
 | 
	
		
			
				|  |  |    while (Ty->isArrayTy()) {
 | 
	
	
		
			
				|  | @@ -121,9 +135,25 @@ Type *LowerMatrixArrayPointer(Type *Ty) {
 | 
	
		
			
				|  |  |    for (auto arraySize = arraySizeList.rbegin();
 | 
	
		
			
				|  |  |         arraySize != arraySizeList.rend(); arraySize++)
 | 
	
		
			
				|  |  |      Ty = ArrayType::get(Ty, *arraySize);
 | 
	
		
			
				|  |  | -  return PointerType::get(Ty, 0);
 | 
	
		
			
				|  |  | +  return PointerType::get(Ty, addrSpace);
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +Type *LowerMatrixArrayPointerToOneDimArray(Type *Ty) {
 | 
	
		
			
				|  |  | +  unsigned addrSpace = Ty->getPointerAddressSpace();
 | 
	
		
			
				|  |  | +  Ty = Ty->getPointerElementType();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +  unsigned arraySize = 1;
 | 
	
		
			
				|  |  | +  while (Ty->isArrayTy()) {
 | 
	
		
			
				|  |  | +    arraySize *= Ty->getArrayNumElements();
 | 
	
		
			
				|  |  | +    Ty = Ty->getArrayElementType();
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +  unsigned row, col;
 | 
	
		
			
				|  |  | +  Type *EltTy = GetMatrixInfo(Ty, col, row);
 | 
	
		
			
				|  |  | +  arraySize *= row*col;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +  Ty = ArrayType::get(EltTy, arraySize);
 | 
	
		
			
				|  |  | +  return PointerType::get(Ty, addrSpace);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  |  Value *BuildVector(Type *EltTy, unsigned size, ArrayRef<llvm::Value *> elts,
 | 
	
		
			
				|  |  |                     IRBuilder<> &Builder) {
 | 
	
		
			
				|  |  |    Value *Vec = UndefValue::get(VectorType::get(EltTy, size));
 | 
	
	
		
			
				|  | @@ -475,7 +505,8 @@ Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
 | 
	
		
			
				|  |  |    case HLMatLoadStoreOpcode::ColMatLoad:
 | 
	
		
			
				|  |  |    case HLMatLoadStoreOpcode::RowMatLoad: {
 | 
	
		
			
				|  |  |      Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
 | 
	
		
			
				|  |  | -    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr)) {
 | 
	
		
			
				|  |  | +    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr) ||
 | 
	
		
			
				|  |  | +        GetIfMatrixGEPOfUDTArg(matPtr, *m_pHLModule)) {
 | 
	
		
			
				|  |  |        Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
 | 
	
		
			
				|  |  |        result = Builder.CreateLoad(vecPtr);
 | 
	
		
			
				|  |  |      } else
 | 
	
	
		
			
				|  | @@ -484,7 +515,8 @@ Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
 | 
	
		
			
				|  |  |    case HLMatLoadStoreOpcode::ColMatStore:
 | 
	
		
			
				|  |  |    case HLMatLoadStoreOpcode::RowMatStore: {
 | 
	
		
			
				|  |  |      Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
 | 
	
		
			
				|  |  | -    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr)) {
 | 
	
		
			
				|  |  | +    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr) ||
 | 
	
		
			
				|  |  | +        GetIfMatrixGEPOfUDTArg(matPtr, *m_pHLModule)) {
 | 
	
		
			
				|  |  |        Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
 | 
	
		
			
				|  |  |        Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
 | 
	
		
			
				|  |  |        Value *vecVal =
 | 
	
	
		
			
				|  | @@ -2179,7 +2211,8 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
 | 
	
		
			
				|  |  |        case HLOpcodeGroup::HLMatLoadStore: {
 | 
	
		
			
				|  |  |          DXASSERT(matToVecMap.count(useCall), "must have vec version");
 | 
	
		
			
				|  |  |          Value *vecUser = matToVecMap[useCall];
 | 
	
		
			
				|  |  | -        if (isa<AllocaInst>(matVal) || GetIfMatrixGEPOfUDTAlloca(matVal)) {
 | 
	
		
			
				|  |  | +        if (isa<AllocaInst>(matVal) || GetIfMatrixGEPOfUDTAlloca(matVal) ||
 | 
	
		
			
				|  |  | +            GetIfMatrixGEPOfUDTArg(matVal, *m_pHLModule)) {
 | 
	
		
			
				|  |  |            // Load Already translated in lowerToVec.
 | 
	
		
			
				|  |  |            // Store val operand will be set by the val use.
 | 
	
		
			
				|  |  |            // Do nothing here.
 | 
	
	
		
			
				|  | @@ -2508,7 +2541,8 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
 | 
	
		
			
				|  |  |    // The matrix operands will be undefval for these instructions.
 | 
	
		
			
				|  |  |    for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
 | 
	
		
			
				|  |  |      BasicBlock *BB = BBI;
 | 
	
		
			
				|  |  | -    for (Instruction &I : BB->getInstList()) {
 | 
	
		
			
				|  |  | +    for (auto II = BB->begin(); II != BB->end(); ) {
 | 
	
		
			
				|  |  | +      Instruction &I = *(II++);
 | 
	
		
			
				|  |  |        if (IsMatrixType(I.getType())) {
 | 
	
		
			
				|  |  |          lowerToVec(&I);
 | 
	
		
			
				|  |  |        } else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
 | 
	
	
		
			
				|  | @@ -2531,7 +2565,8 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
 | 
	
		
			
				|  |  |            // Lower it here to make sure it is ready before replace.
 | 
	
		
			
				|  |  |            lowerToVec(&I);
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  | -      } else if (GetIfMatrixGEPOfUDTAlloca(&I)) {
 | 
	
		
			
				|  |  | +      } else if (GetIfMatrixGEPOfUDTAlloca(&I) ||
 | 
	
		
			
				|  |  | +                 GetIfMatrixGEPOfUDTArg(&I, *m_pHLModule)) {
 | 
	
		
			
				|  |  |          lowerToVec(&I);
 | 
	
		
			
				|  |  |        }
 | 
	
		
			
				|  |  |      }
 | 
	
	
		
			
				|  | @@ -2582,6 +2617,20 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
 | 
	
		
			
				|  |  |  //  %170 = bitcast %class.matrix.float.4.3* %arrayidx.i to <12 x float>*
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  namespace {
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +Type *TryLowerMatTy(Type *Ty) {
 | 
	
		
			
				|  |  | +  Type *VecTy = nullptr;
 | 
	
		
			
				|  |  | +  if (HLMatrixLower::IsMatrixArrayPointer(Ty)) {
 | 
	
		
			
				|  |  | +    VecTy = HLMatrixLower::LowerMatrixArrayPointerToOneDimArray(Ty);
 | 
	
		
			
				|  |  | +  } else if (isa<PointerType>(Ty) &&
 | 
	
		
			
				|  |  | +             HLMatrixLower::IsMatrixType(Ty->getPointerElementType())) {
 | 
	
		
			
				|  |  | +    VecTy = HLMatrixLower::LowerMatrixTypeToOneDimArray(
 | 
	
		
			
				|  |  | +        Ty->getPointerElementType());
 | 
	
		
			
				|  |  | +    VecTy = PointerType::get(VecTy, Ty->getPointerAddressSpace());
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +  return VecTy;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class MatrixBitcastLowerPass : public FunctionPass {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  public:
 | 
	
	
		
			
				|  | @@ -2590,18 +2639,166 @@ public:
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |    const char *getPassName() const override { return "Matrix Bitcast lower"; }
 | 
	
		
			
				|  |  |    bool runOnFunction(Function &F) override {
 | 
	
		
			
				|  |  | -    // TODO: remove bitcast on matrix.
 | 
	
		
			
				|  |  | -    return false;
 | 
	
		
			
				|  |  | +    bool bUpdated = false;
 | 
	
		
			
				|  |  | +    std::unordered_set<BitCastInst*> matCastSet;
 | 
	
		
			
				|  |  | +    for (auto blkIt = F.begin(); blkIt != F.end(); ++blkIt) {
 | 
	
		
			
				|  |  | +      BasicBlock *BB = blkIt;
 | 
	
		
			
				|  |  | +      for (auto iIt = BB->begin(); iIt != BB->end(); ) {
 | 
	
		
			
				|  |  | +        Instruction *I = (iIt++);
 | 
	
		
			
				|  |  | +        if (BitCastInst *BCI = dyn_cast<BitCastInst>(I)) {
 | 
	
		
			
				|  |  | +          // Mutate mat to vec.
 | 
	
		
			
				|  |  | +          Type *ToTy = BCI->getType();
 | 
	
		
			
				|  |  | +          if (Type *ToVecTy = TryLowerMatTy(ToTy)) {
 | 
	
		
			
				|  |  | +            matCastSet.insert(BCI);
 | 
	
		
			
				|  |  | +            bUpdated = true;
 | 
	
		
			
				|  |  | +          }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +      }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
 | 
	
		
			
				|  |  | +    // Remove bitcast which has CallInst user.
 | 
	
		
			
				|  |  | +    if (DM.GetShaderModel()->IsLib()) {
 | 
	
		
			
				|  |  | +      for (auto it = matCastSet.begin(); it != matCastSet.end();) {
 | 
	
		
			
				|  |  | +        BitCastInst *BCI = *(it++);
 | 
	
		
			
				|  |  | +        if (hasCallUser(BCI)) {
 | 
	
		
			
				|  |  | +          matCastSet.erase(BCI);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +      }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    // Lower matrix first.
 | 
	
		
			
				|  |  | +    for (BitCastInst *BCI : matCastSet) {
 | 
	
		
			
				|  |  | +      lowerMatrix(BCI, BCI->getOperand(0));
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    return bUpdated;
 | 
	
		
			
				|  |  |    }
 | 
	
		
			
				|  |  |  private:
 | 
	
		
			
				|  |  | -  void lowerMatrixBitcast(BitCastInst *BCI);
 | 
	
		
			
				|  |  | +  void lowerMatrix(Instruction *M, Value *A);
 | 
	
		
			
				|  |  | +  bool hasCallUser(Instruction *M);
 | 
	
		
			
				|  |  |  };
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -void MatrixBitcastLowerPass::lowerMatrixBitcast(BitCastInst *BCI) {
 | 
	
		
			
				|  |  | -  // to matrix.
 | 
	
		
			
				|  |  | -  // from matrix.
 | 
	
		
			
				|  |  | +bool MatrixBitcastLowerPass::hasCallUser(Instruction *M) {
 | 
	
		
			
				|  |  | +  for (auto it = M->user_begin(); it != M->user_end();) {
 | 
	
		
			
				|  |  | +    User *U = *(it++);
 | 
	
		
			
				|  |  | +    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
 | 
	
		
			
				|  |  | +      Type *EltTy = GEP->getType()->getPointerElementType();
 | 
	
		
			
				|  |  | +      if (HLMatrixLower::IsMatrixType(EltTy)) {
 | 
	
		
			
				|  |  | +        if (hasCallUser(GEP))
 | 
	
		
			
				|  |  | +          return true;
 | 
	
		
			
				|  |  | +      } else {
 | 
	
		
			
				|  |  | +        DXASSERT(0, "invalid GEP for matrix");
 | 
	
		
			
				|  |  | +      }
 | 
	
		
			
				|  |  | +    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
 | 
	
		
			
				|  |  | +      if (hasCallUser(BCI))
 | 
	
		
			
				|  |  | +        return true;
 | 
	
		
			
				|  |  | +    } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
 | 
	
		
			
				|  |  | +      if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
 | 
	
		
			
				|  |  | +      } else {
 | 
	
		
			
				|  |  | +        DXASSERT(0, "invalid load for matrix");
 | 
	
		
			
				|  |  | +      }
 | 
	
		
			
				|  |  | +    } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
 | 
	
		
			
				|  |  | +      Value *V = ST->getValueOperand();
 | 
	
		
			
				|  |  | +      if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
 | 
	
		
			
				|  |  | +      } else {
 | 
	
		
			
				|  |  | +        DXASSERT(0, "invalid load for matrix");
 | 
	
		
			
				|  |  | +      }
 | 
	
		
			
				|  |  | +    } else if (isa<CallInst>(U)) {
 | 
	
		
			
				|  |  | +      return true;
 | 
	
		
			
				|  |  | +    } else {
 | 
	
		
			
				|  |  | +      DXASSERT(0, "invalid use of matrix");
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +  return false;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +namespace {
 | 
	
		
			
				|  |  | +Value *CreateEltGEP(Value *A, unsigned i, Value *zeroIdx,
 | 
	
		
			
				|  |  | +                    IRBuilder<> &Builder) {
 | 
	
		
			
				|  |  | +  Value *GEP = nullptr;
 | 
	
		
			
				|  |  | +  if (GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A)) {
 | 
	
		
			
				|  |  | +    // A should be gep oneDimArray, 0, index * matSize
 | 
	
		
			
				|  |  | +    // Here add eltIdx to index * matSize foreach elt.
 | 
	
		
			
				|  |  | +    Instruction *EltGEP = GEPA->clone();
 | 
	
		
			
				|  |  | +    unsigned eltIdx = EltGEP->getNumOperands() - 1;
 | 
	
		
			
				|  |  | +    Value *NewIdx =
 | 
	
		
			
				|  |  | +        Builder.CreateAdd(EltGEP->getOperand(eltIdx), Builder.getInt32(i));
 | 
	
		
			
				|  |  | +    EltGEP->setOperand(eltIdx, NewIdx);
 | 
	
		
			
				|  |  | +    Builder.Insert(EltGEP);
 | 
	
		
			
				|  |  | +    GEP = EltGEP;
 | 
	
		
			
				|  |  | +  } else {
 | 
	
		
			
				|  |  | +    GEP = Builder.CreateInBoundsGEP(A, {zeroIdx, Builder.getInt32(i)});
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  | +  return GEP;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +} // namespace
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
 | 
	
		
			
				|  |  | +  for (auto it = M->user_begin(); it != M->user_end();) {
 | 
	
		
			
				|  |  | +    User *U = *(it++);
 | 
	
		
			
				|  |  | +    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
 | 
	
		
			
				|  |  | +      Type *EltTy = GEP->getType()->getPointerElementType();
 | 
	
		
			
				|  |  | +      if (HLMatrixLower::IsMatrixType(EltTy)) {
 | 
	
		
			
				|  |  | +        // Change gep matrixArray, 0, index
 | 
	
		
			
				|  |  | +        // into
 | 
	
		
			
				|  |  | +        //   gep oneDimArray, 0, index * matSize
 | 
	
		
			
				|  |  | +        IRBuilder<> Builder(GEP);
 | 
	
		
			
				|  |  | +        SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
 | 
	
		
			
				|  |  | +        DXASSERT(idxList.size() == 2,
 | 
	
		
			
				|  |  | +                 "else not one dim matrix array index to matrix");
 | 
	
		
			
				|  |  | +        unsigned col = 0;
 | 
	
		
			
				|  |  | +        unsigned row = 0;
 | 
	
		
			
				|  |  | +        HLMatrixLower::GetMatrixInfo(EltTy, col, row);
 | 
	
		
			
				|  |  | +        Value *matSize = Builder.getInt32(col * row);
 | 
	
		
			
				|  |  | +        idxList.back() = Builder.CreateMul(idxList.back(), matSize);
 | 
	
		
			
				|  |  | +        Value *NewGEP = Builder.CreateGEP(A, idxList);
 | 
	
		
			
				|  |  | +        lowerMatrix(GEP, NewGEP);
 | 
	
		
			
				|  |  | +        DXASSERT(GEP->user_empty(), "else lower matrix fail");
 | 
	
		
			
				|  |  | +        GEP->eraseFromParent();
 | 
	
		
			
				|  |  | +      } else {
 | 
	
		
			
				|  |  | +        DXASSERT(0, "invalid GEP for matrix");
 | 
	
		
			
				|  |  | +      }
 | 
	
		
			
				|  |  | +    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
 | 
	
		
			
				|  |  | +      lowerMatrix(BCI, A);
 | 
	
		
			
				|  |  | +      DXASSERT(BCI->user_empty(), "else lower matrix fail");
 | 
	
		
			
				|  |  | +      BCI->eraseFromParent();
 | 
	
		
			
				|  |  | +    } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
 | 
	
		
			
				|  |  | +      if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
 | 
	
		
			
				|  |  | +        IRBuilder<> Builder(LI);
 | 
	
		
			
				|  |  | +        Value *zeroIdx = Builder.getInt32(0);
 | 
	
		
			
				|  |  | +        unsigned vecSize = Ty->getNumElements();
 | 
	
		
			
				|  |  | +        Value *NewVec = UndefValue::get(LI->getType());
 | 
	
		
			
				|  |  | +        for (unsigned i = 0; i < vecSize; i++) {
 | 
	
		
			
				|  |  | +          Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
 | 
	
		
			
				|  |  | +          Value *Elt = Builder.CreateLoad(GEP);
 | 
	
		
			
				|  |  | +          NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        LI->replaceAllUsesWith(NewVec);
 | 
	
		
			
				|  |  | +        LI->eraseFromParent();
 | 
	
		
			
				|  |  | +      } else {
 | 
	
		
			
				|  |  | +        DXASSERT(0, "invalid load for matrix");
 | 
	
		
			
				|  |  | +      }
 | 
	
		
			
				|  |  | +    } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
 | 
	
		
			
				|  |  | +      Value *V = ST->getValueOperand();
 | 
	
		
			
				|  |  | +      if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
 | 
	
		
			
				|  |  | +        IRBuilder<> Builder(LI);
 | 
	
		
			
				|  |  | +        Value *zeroIdx = Builder.getInt32(0);
 | 
	
		
			
				|  |  | +        unsigned vecSize = Ty->getNumElements();
 | 
	
		
			
				|  |  | +        for (unsigned i = 0; i < vecSize; i++) {
 | 
	
		
			
				|  |  | +          Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
 | 
	
		
			
				|  |  | +          Value *Elt = Builder.CreateExtractElement(V, i);
 | 
	
		
			
				|  |  | +          Builder.CreateStore(Elt, GEP);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        ST->eraseFromParent();
 | 
	
		
			
				|  |  | +      } else {
 | 
	
		
			
				|  |  | +        DXASSERT(0, "invalid load for matrix");
 | 
	
		
			
				|  |  | +      }
 | 
	
		
			
				|  |  | +    } else {
 | 
	
		
			
				|  |  | +      DXASSERT(0, "invalid use of matrix");
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +  }
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  #include "dxc/HLSL/DxilGenerationPass.h"
 |