Browse Source

Merged PR 84: Implement MatrixBitcastLowerPass for link matrix type.

Implement MatrixBitcastLowerPass for link matrix type.
Xiang_Li (XBox) 7 years ago
parent
commit
ae93061ee8

+ 0 - 2
lib/HLSL/ComputeViewIdState.cpp

@@ -632,8 +632,6 @@ void DxilViewIdState::CollectReachingDeclsRec(Value *pValue, ValueSetType &Reach
     CollectReachingDeclsRec(SelI->getFalseValue(), ReachingDecls, Visited);
   } else if (Argument *pArg = dyn_cast<Argument>(pValue)) {
     ReachingDecls.emplace(pValue);
-  } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(pValue)) {
-    CollectReachingDeclsRec(BCI->getOperand(0), ReachingDecls, Visited);
   } else {
     IFT(DXC_E_GENERAL_INTERNAL_ERROR);
   }

+ 3 - 0
lib/HLSL/DxilLinker.cpp

@@ -1035,6 +1035,9 @@ void DxilLinkJob::RunPreparePass(Module &M) {
   // SROA
   PM.add(createSROAPass(/*RequiresDomTree*/false));
 
+  // Remove MultiDimArray from function call arg.
+  PM.add(createMultiDimArrayToOneDimArrayPass());
+
   // Lower matrix bitcast.
   PM.add(createMatrixBitcastLowerPass());
 

+ 210 - 13
lib/HLSL/HLMatrixLowerPass.cpp

@@ -17,7 +17,8 @@
 #include "dxc/HlslIntrinsicOp.h"
 #include "dxc/Support/Global.h"
 #include "dxc/HLSL/DxilOperations.h"
-#include "dxc/hlsl/DxilTypeSystem.h"
+#include "dxc/HLSL/DxilTypeSystem.h"
+#include "dxc/HLSL/DxilModule.h"
 
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Module.h"
@@ -89,6 +90,18 @@ Type *LowerMatrixType(Type *Ty) {
   }
 }
 
+// Flatten a matrix type into a one-dimensional array of its element type
+// holding row * col elements. Any non-matrix type is handed back unchanged.
+Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
+  if (!IsMatrixType(Ty))
+    return Ty;
+
+  unsigned numCols = 0;
+  unsigned numRows = 0;
+  Type *ElemTy = GetMatrixInfo(Ty, numCols, numRows);
+  return ArrayType::get(ElemTy, numRows * numCols);
+}
+
+
 Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
   DXASSERT(IsMatrixType(Ty), "not matrix type");
   StructType *ST = cast<StructType>(Ty);
@@ -110,6 +123,7 @@ bool IsMatrixArrayPointer(llvm::Type *Ty) {
   return IsMatrixType(Ty);
 }
 Type *LowerMatrixArrayPointer(Type *Ty) {
+  unsigned addrSpace = Ty->getPointerAddressSpace();
   Ty = Ty->getPointerElementType();
   std::vector<unsigned> arraySizeList;
   while (Ty->isArrayTy()) {
@@ -121,9 +135,25 @@ Type *LowerMatrixArrayPointer(Type *Ty) {
   for (auto arraySize = arraySizeList.rbegin();
        arraySize != arraySizeList.rend(); arraySize++)
     Ty = ArrayType::get(Ty, *arraySize);
-  return PointerType::get(Ty, 0);
+  return PointerType::get(Ty, addrSpace);
 }
 
+// Lower a pointer to a (possibly nested) array of matrices into a pointer to
+// a single flat array of the matrix element type. The address space of the
+// incoming pointer is carried over to the result.
+Type *LowerMatrixArrayPointerToOneDimArray(Type *Ty) {
+  const unsigned addrSpace = Ty->getPointerAddressSpace();
+
+  // Peel off every array level, accumulating the total number of matrices.
+  unsigned matCount = 1;
+  Type *CurTy = Ty->getPointerElementType();
+  for (; CurTy->isArrayTy(); CurTy = CurTy->getArrayElementType())
+    matCount *= CurTy->getArrayNumElements();
+
+  // What is left must be the matrix type itself.
+  unsigned numCols = 0;
+  unsigned numRows = 0;
+  Type *ElemTy = GetMatrixInfo(CurTy, numCols, numRows);
+  matCount *= numRows * numCols;
+
+  return PointerType::get(ArrayType::get(ElemTy, matCount), addrSpace);
+}
 Value *BuildVector(Type *EltTy, unsigned size, ArrayRef<llvm::Value *> elts,
                    IRBuilder<> &Builder) {
   Value *Vec = UndefValue::get(VectorType::get(EltTy, size));
@@ -475,7 +505,8 @@ Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
   case HLMatLoadStoreOpcode::ColMatLoad:
   case HLMatLoadStoreOpcode::RowMatLoad: {
     Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
-    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr)) {
+    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr) ||
+        GetIfMatrixGEPOfUDTArg(matPtr, *m_pHLModule)) {
       Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
       result = Builder.CreateLoad(vecPtr);
     } else
@@ -484,7 +515,8 @@ Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
   case HLMatLoadStoreOpcode::ColMatStore:
   case HLMatLoadStoreOpcode::RowMatStore: {
     Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
-    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr)) {
+    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr) ||
+        GetIfMatrixGEPOfUDTArg(matPtr, *m_pHLModule)) {
       Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
       Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
       Value *vecVal =
@@ -2179,7 +2211,8 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
       case HLOpcodeGroup::HLMatLoadStore: {
         DXASSERT(matToVecMap.count(useCall), "must have vec version");
         Value *vecUser = matToVecMap[useCall];
-        if (isa<AllocaInst>(matVal) || GetIfMatrixGEPOfUDTAlloca(matVal)) {
+        if (isa<AllocaInst>(matVal) || GetIfMatrixGEPOfUDTAlloca(matVal) ||
+            GetIfMatrixGEPOfUDTArg(matVal, *m_pHLModule)) {
           // Load Already translated in lowerToVec.
           // Store val operand will be set by the val use.
           // Do nothing here.
@@ -2508,7 +2541,8 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
   // The matrix operands will be undefval for these instructions.
   for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
     BasicBlock *BB = BBI;
-    for (Instruction &I : BB->getInstList()) {
+    for (auto II = BB->begin(); II != BB->end(); ) {
+      Instruction &I = *(II++);
       if (IsMatrixType(I.getType())) {
         lowerToVec(&I);
       } else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
@@ -2531,7 +2565,8 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
           // Lower it here to make sure it is ready before replace.
           lowerToVec(&I);
         }
-      } else if (GetIfMatrixGEPOfUDTAlloca(&I)) {
+      } else if (GetIfMatrixGEPOfUDTAlloca(&I) ||
+                 GetIfMatrixGEPOfUDTArg(&I, *m_pHLModule)) {
         lowerToVec(&I);
       }
     }
@@ -2582,6 +2617,20 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
 //  %170 = bitcast %class.matrix.float.4.3* %arrayidx.i to <12 x float>*
 
 namespace {
+
+// Map a matrix pointer type onto the matching flat-array pointer type.
+// Handles types accepted by HLMatrixLower::IsMatrixArrayPointer as well as
+// plain pointers to a matrix; returns nullptr for anything else.
+Type *TryLowerMatTy(Type *Ty) {
+  if (HLMatrixLower::IsMatrixArrayPointer(Ty))
+    return HLMatrixLower::LowerMatrixArrayPointerToOneDimArray(Ty);
+
+  if (!isa<PointerType>(Ty))
+    return nullptr;
+
+  Type *PointeeTy = Ty->getPointerElementType();
+  if (!HLMatrixLower::IsMatrixType(PointeeTy))
+    return nullptr;
+
+  Type *FlatTy = HLMatrixLower::LowerMatrixTypeToOneDimArray(PointeeTy);
+  return PointerType::get(FlatTy, Ty->getPointerAddressSpace());
+}
+
 class MatrixBitcastLowerPass : public FunctionPass {
 
 public:
@@ -2590,18 +2639,166 @@ public:
 
   const char *getPassName() const override { return "Matrix Bitcast lower"; }
   bool runOnFunction(Function &F) override {
-    // TODO: remove bitcast on matrix.
-    return false;
+    // Collect every bitcast in F whose result type TryLowerMatTy can flatten.
+    // The return value reports whether any such cast was found.
+    bool bUpdated = false;
+    std::unordered_set<BitCastInst*> matCastSet;
+    for (auto blkIt = F.begin(); blkIt != F.end(); ++blkIt) {
+      BasicBlock *BB = blkIt;
+      for (auto iIt = BB->begin(); iIt != BB->end(); ) {
+        Instruction *I = (iIt++);
+        if (BitCastInst *BCI = dyn_cast<BitCastInst>(I)) {
+          // Only the answer (lowerable or not) matters here; the lowered
+          // type itself is not needed.
+          if (TryLowerMatTy(BCI->getType())) {
+            matCastSet.insert(BCI);
+            bUpdated = true;
+          }
+        }
+      }
+    }
+
+    DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
+    // Remove bitcast which has CallInst user.
+    // Only done for library targets, where the cast feeds a function call
+    // argument and therefore has to stay.
+    if (DM.GetShaderModel()->IsLib()) {
+      for (auto it = matCastSet.begin(); it != matCastSet.end();) {
+        // erase() returns the iterator following the removed element.
+        if (hasCallUser(*it))
+          it = matCastSet.erase(it);
+        else
+          ++it;
+      }
+    }
+
+    // Lower matrix first.
+    // NOTE(review): lowerMatrix erases bitcast users it looks through; if one
+    // of those were also in matCastSet this loop would visit a dangling
+    // pointer - confirm that cannot happen.
+    for (BitCastInst *BCI : matCastSet) {
+      lowerMatrix(BCI, BCI->getOperand(0));
+    }
+    return bUpdated;
   }
 private:
-  void lowerMatrixBitcast(BitCastInst *BCI);
+  void lowerMatrix(Instruction *M, Value *A);
+  bool hasCallUser(Instruction *M);
 };
 
 }
 
-void MatrixBitcastLowerPass::lowerMatrixBitcast(BitCastInst *BCI) {
-  // to matrix.
-  // from matrix.
+// Returns true when M, or any GEP / bitcast derived from it, is used by a
+// CallInst. Loads and stores are only validated (they must operate on vector
+// values); any other kind of user trips an assert.
+bool MatrixBitcastLowerPass::hasCallUser(Instruction *M) {
+  // Nothing is modified here, so a plain range over the users is safe.
+  for (User *U : M->users()) {
+    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
+      Type *EltTy = GEP->getType()->getPointerElementType();
+      if (HLMatrixLower::IsMatrixType(EltTy)) {
+        if (hasCallUser(GEP))
+          return true;
+      } else {
+        DXASSERT(0, "invalid GEP for matrix");
+      }
+    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
+      if (hasCallUser(BCI))
+        return true;
+    } else if (isa<LoadInst>(U)) {
+      // The loaded value must be a vector.
+      DXASSERT(isa<VectorType>(U->getType()), "invalid load for matrix");
+    } else if (isa<StoreInst>(U)) {
+      // The stored value must be a vector.
+      DXASSERT(isa<VectorType>(
+                   cast<StoreInst>(U)->getValueOperand()->getType()),
+               "invalid store for matrix");
+    } else if (isa<CallInst>(U)) {
+      return true;
+    } else {
+      DXASSERT(0, "invalid use of matrix");
+    }
+  }
+  return false;
+}
+
+namespace {
+// Build the address of element i of the flattened matrix storage reached
+// through A, emitting any new instructions via Builder.
+Value *CreateEltGEP(Value *A, unsigned i, Value *zeroIdx,
+                    IRBuilder<> &Builder) {
+  GetElementPtrInst *BaseGEP = dyn_cast<GetElementPtrInst>(A);
+  if (!BaseGEP) {
+    // A is not a GEP: index it directly with {0, i}.
+    return Builder.CreateInBoundsGEP(A, {zeroIdx, Builder.getInt32(i)});
+  }
+
+  // A should be gep oneDimArray, 0, index * matSize
+  // Clone it and add the element offset i to its last index operand.
+  Instruction *EltGEP = BaseGEP->clone();
+  const unsigned lastOp = EltGEP->getNumOperands() - 1;
+  Value *NewIdx =
+      Builder.CreateAdd(EltGEP->getOperand(lastOp), Builder.getInt32(i));
+  EltGEP->setOperand(lastOp, NewIdx);
+  Builder.Insert(EltGEP);
+  return EltGEP;
+}
+} // namespace
+
+// Rewrite every user of the matrix-typed pointer M so that it goes through A,
+// the flattened (one-dimensional array) pointer, instead:
+//   - a GEP selecting one matrix is re-issued on A with its last index scaled
+//     by the matrix element count, then its own users are lowered recursively;
+//   - a bitcast is looked through and erased;
+//   - a vector load / store is split into per-element loads / stores.
+// Any other user trips an assert.
+void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
+  for (auto it = M->user_begin(); it != M->user_end();) {
+    // Step the iterator first; the current user may be erased below.
+    User *U = *(it++);
+    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
+      Type *EltTy = GEP->getType()->getPointerElementType();
+      if (HLMatrixLower::IsMatrixType(EltTy)) {
+        // Change gep matrixArray, 0, index
+        // into
+        //   gep oneDimArray, 0, index * matSize
+        IRBuilder<> Builder(GEP);
+        SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
+        DXASSERT(idxList.size() == 2,
+                 "else not one dim matrix array index to matrix");
+        unsigned col = 0;
+        unsigned row = 0;
+        HLMatrixLower::GetMatrixInfo(EltTy, col, row);
+        Value *matSize = Builder.getInt32(col * row);
+        idxList.back() = Builder.CreateMul(idxList.back(), matSize);
+        Value *NewGEP = Builder.CreateGEP(A, idxList);
+        lowerMatrix(GEP, NewGEP);
+        DXASSERT(GEP->user_empty(), "else lower matrix fail");
+        GEP->eraseFromParent();
+      } else {
+        DXASSERT(0, "invalid GEP for matrix");
+      }
+    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
+      lowerMatrix(BCI, A);
+      DXASSERT(BCI->user_empty(), "else lower matrix fail");
+      BCI->eraseFromParent();
+    } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
+      if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
+        // Rebuild the vector from one scalar load per element.
+        IRBuilder<> Builder(LI);
+        Value *zeroIdx = Builder.getInt32(0);
+        unsigned vecSize = Ty->getNumElements();
+        Value *NewVec = UndefValue::get(LI->getType());
+        for (unsigned i = 0; i < vecSize; i++) {
+          Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
+          Value *Elt = Builder.CreateLoad(GEP);
+          NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
+        }
+        LI->replaceAllUsesWith(NewVec);
+        LI->eraseFromParent();
+      } else {
+        DXASSERT(0, "invalid load for matrix");
+      }
+    } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
+      Value *V = ST->getValueOperand();
+      if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
+        // Insert before the store being replaced. LI from the previous
+        // else-if condition is still in scope here, but it is null in this
+        // branch and must not be used as the insert point.
+        IRBuilder<> Builder(ST);
+        Value *zeroIdx = Builder.getInt32(0);
+        unsigned vecSize = Ty->getNumElements();
+        for (unsigned i = 0; i < vecSize; i++) {
+          Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
+          Value *Elt = Builder.CreateExtractElement(V, i);
+          Builder.CreateStore(Elt, GEP);
+        }
+        ST->eraseFromParent();
+      } else {
+        DXASSERT(0, "invalid store for matrix");
+      }
+    } else {
+      DXASSERT(0, "invalid use of matrix");
+    }
+  }
 }
 
 #include "dxc/HLSL/DxilGenerationPass.h"

+ 14 - 0
tools/clang/test/CodeGenHLSL/quick-test/lib_mat_cast.hlsl

@@ -0,0 +1,14 @@
+// RUN: %dxc -T lib_6_3  %s | FileCheck %s
+
+// CHECK: bitcast %class.matrix.float.4.3* {{.*}} to <12 x float>*
+
+float mat_array_test(in float4 in0,
+                                  in float4 in1,
+                                  float4x3 basisArray[2]) // cube map basis.
+{
+uint index = in1.w; // float component converted to uint, used as a dynamic array index
+
+float3 outputs = mul(in0.xyz, // vector * matrix product with the selected basis
+            basisArray[index]);
+return outputs.z + basisArray[in1.y][in1.z].y; // second access: dynamic array index, then dynamic matrix subscript
+}

+ 13 - 0
tools/clang/test/CodeGenHLSL/quick-test/lib_mat_cast2.hlsl

@@ -0,0 +1,13 @@
+// RUN: %dxc -T lib_6_3  %s | FileCheck %s
+
+// CHECK: bitcast %class.matrix.float.4.3* {{.*}} to <12 x float>*
+
+float3 mat_test(in float4 in0,
+                                  in float4 in1,
+                                  inout float4x3 m) // matrix parameter passed as inout
+{
+uint basisIndex = in1.w; // not referenced again in this function
+float3 outputs = mul(in0.xyz,
+            (float3x3)m); // m is cast from float4x3 to float3x3 for the multiply
+return outputs + m[in1.z]; // dynamic subscript into m using in1.z
+}

+ 20 - 0
tools/clang/test/CodeGenHLSL/quick-test/lib_mat_entry2.hlsl

@@ -0,0 +1,20 @@
+// RUN: %dxc -T lib_6_3  %s | FileCheck %s
+
+// CHECK: [[BCI:%.*]] = bitcast <12 x float>* %1 to %class.matrix.float.4.3*
+// CHECK:call <3 x float> @"\01?mat_test@@YA?AV?$vector@M$02@@V?$vector@M$03@@0AIAV?$matrix@M$03$02@@@Z"(<4 x float> {{.*}}, <4 x float> {{.*}}, %class.matrix.float.4.3* {{.*}}[[BCI]])
+
+float3 mat_test(in float4 in0,
+                                  in float4 in1,
+                                  inout float4x3 m);
+
+cbuffer A {
+float4 g0;
+float4 g1;
+float4x3 M;
+};
+
+[shader("pixel")]
+float3 main() : SV_Target {
+  float4x3 m = M; // local copy of the cbuffer matrix M
+  return mat_test( g0, g1, m); // m is bound to mat_test's inout float4x3 parameter
+}

+ 59 - 0
tools/clang/unittests/HLSL/LinkerTest.cpp

@@ -44,6 +44,9 @@ public:
   TEST_METHOD(RunLinkFailReDefine);
   TEST_METHOD(RunLinkGlobalInit);
   TEST_METHOD(RunLinkNoAlloca);
+  TEST_METHOD(RunLinkMatArrayParam);
+  TEST_METHOD(RunLinkMatParam);
+  TEST_METHOD(RunLinkMatParamToLib);
   TEST_METHOD(RunLinkResRet);
   TEST_METHOD(RunLinkToLib);
   TEST_METHOD(RunLinkToLibExport);
@@ -300,6 +303,62 @@ TEST_F(LinkerTest, RunLinkNoAlloca) {
   Link(L"ps_main", L"ps_6_0", pLinker, {libName, libName2}, {}, {"alloca"});
 }
 
+// Compiles lib_mat_entry.hlsl and lib_mat_cast.hlsl as libraries, registers
+// both with one linker and links "main" for ps_6_0, passing the flat float
+// array patterns below to Link.
+TEST_F(LinkerTest, RunLinkMatArrayParam) {
+  CComPtr<IDxcBlob> pEntryBlob;
+  CompileLib(L"..\\CodeGenHLSL\\quick-test\\lib_mat_entry.hlsl", &pEntryBlob);
+  CComPtr<IDxcBlob> pCastBlob;
+  CompileLib(L"..\\CodeGenHLSL\\quick-test\\lib_mat_cast.hlsl", &pCastBlob);
+
+  CComPtr<IDxcLinker> pLinker;
+  CreateLinker(&pLinker);
+
+  // Register the entry library first, then the library providing the callee.
+  LPCWSTR entryLibName = L"ps_main";
+  RegisterDxcModule(entryLibName, pEntryBlob, pLinker);
+
+  LPCWSTR castLibName = L"test";
+  RegisterDxcModule(castLibName, pCastBlob, pLinker);
+
+  Link(L"main", L"ps_6_0", pLinker, {entryLibName, castLibName},
+       {"alloca [24 x float]", "getelementptr [12 x float], [12 x float]*"},
+       {});
+}
+
+// Compiles lib_mat_entry2.hlsl and lib_mat_cast2.hlsl as libraries, registers
+// both with one linker and links "main" for ps_6_0, passing the
+// "alloca [12 x float]" pattern to Link.
+TEST_F(LinkerTest, RunLinkMatParam) {
+  CComPtr<IDxcBlob> pEntryBlob;
+  CompileLib(L"..\\CodeGenHLSL\\quick-test\\lib_mat_entry2.hlsl", &pEntryBlob);
+  CComPtr<IDxcBlob> pCastBlob;
+  CompileLib(L"..\\CodeGenHLSL\\quick-test\\lib_mat_cast2.hlsl", &pCastBlob);
+
+  CComPtr<IDxcLinker> pLinker;
+  CreateLinker(&pLinker);
+
+  // Register the entry library first, then the library providing the callee.
+  LPCWSTR entryLibName = L"ps_main";
+  RegisterDxcModule(entryLibName, pEntryBlob, pLinker);
+
+  LPCWSTR castLibName = L"test";
+  RegisterDxcModule(castLibName, pCastBlob, pLinker);
+
+  Link(L"main", L"ps_6_0", pLinker, {entryLibName, castLibName},
+       {"alloca [12 x float]"},
+       {});
+}
+
+// Compiles lib_mat_entry2.hlsl as a library and links it, on its own, to a
+// lib_6_3 target with an empty entry name.
+TEST_F(LinkerTest, RunLinkMatParamToLib) {
+  CComPtr<IDxcBlob> pEntryBlob;
+  CompileLib(L"..\\CodeGenHLSL\\quick-test\\lib_mat_entry2.hlsl", &pEntryBlob);
+
+  CComPtr<IDxcLinker> pLinker;
+  CreateLinker(&pLinker);
+
+  LPCWSTR entryLibName = L"ps_main";
+  RegisterDxcModule(entryLibName, pEntryBlob, pLinker);
+
+  Link(L"", L"lib_6_3", pLinker, {entryLibName},
+       // The bitcast cannot be removed because the call to the user function
+       // takes it as an argument.
+       {"bitcast <12 x float>* %1 to %class.matrix.float.4.3*"}, {});
+}
+
 TEST_F(LinkerTest, RunLinkResRet) {
   CComPtr<IDxcBlob> pEntryLib;
   CompileLib(L"..\\CodeGenHLSL\\shader-compat-suite\\lib_out_param_res.hlsl", &pEntryLib);