Ver código fonte

Merged PR 82: Fix matrix array as parameter for external function.

Fix matrix array as parameter for external function.
Xiang_Li (XBox) 7 anos atrás
pai
commit
4491b6230b

+ 2 - 0
include/dxc/HLSL/DxilGenerationPass.h

@@ -70,6 +70,7 @@ ModulePass *createDxilTranslateRawBuffer();
 ModulePass *createNoPausePassesPass();
 ModulePass *createPausePassesPass();
 ModulePass *createResumePassesPass();
+FunctionPass *createMatrixBitcastLowerPass();
 
 void initializeDxilCondenseResourcesPass(llvm::PassRegistry&);
 void initializeDxilLowerCreateHandleForLibPass(llvm::PassRegistry&);
@@ -98,6 +99,7 @@ void initializeDxilTranslateRawBufferPass(llvm::PassRegistry&);
 void initializeNoPausePassesPass(llvm::PassRegistry&);
 void initializePausePassesPass(llvm::PassRegistry&);
 void initializeResumePassesPass(llvm::PassRegistry&);
+void initializeMatrixBitcastLowerPassPass(llvm::PassRegistry&);
 
 bool AreDxilResourcesDense(llvm::Module *M, hlsl::DxilResourceBase **ppNonDense);
 

+ 5 - 0
include/dxc/HLSL/HLMatrixLowerHelper.h

@@ -22,9 +22,14 @@ namespace llvm {
 
 namespace hlsl {
 
+class DxilFieldAnnotation;
+class DxilTypeSystem;
+
 namespace HLMatrixLower {
 // TODO: use type annotation.
 bool IsMatrixType(llvm::Type *Ty);
+DxilFieldAnnotation *FindAnnotationFromMatUser(llvm::Value *Mat,
+                                               DxilTypeSystem &typeSys);
 // Translate matrix type to vector type.
 llvm::Type *LowerMatrixType(llvm::Type *Ty);
 // TODO: use type annotation.

+ 2 - 0
lib/HLSL/ComputeViewIdState.cpp

@@ -632,6 +632,8 @@ void DxilViewIdState::CollectReachingDeclsRec(Value *pValue, ValueSetType &Reach
     CollectReachingDeclsRec(SelI->getFalseValue(), ReachingDecls, Visited);
   } else if (Argument *pArg = dyn_cast<Argument>(pValue)) {
     ReachingDecls.emplace(pValue);
+  } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(pValue)) {
+    CollectReachingDeclsRec(BCI->getOperand(0), ReachingDecls, Visited);
   } else {
     IFT(DXC_E_GENERAL_INTERNAL_ERROR);
   }

+ 1 - 0
lib/HLSL/DxcOptimizer.cpp

@@ -142,6 +142,7 @@ HRESULT SetupRegistryPassForHLSL() {
     initializeLowerBitSetsPass(Registry);
     initializeLowerExpectIntrinsicPass(Registry);
     initializeLowerStaticGlobalIntoAllocaPass(Registry);
+    initializeMatrixBitcastLowerPassPass(Registry);
     initializeMergeFunctionsPass(Registry);
     initializeMergedLoadStoreMotionPass(Registry);
     initializeMultiDimArrayToOneDimArrayPass(Registry);

+ 3 - 0
lib/HLSL/DxilLinker.cpp

@@ -1035,6 +1035,9 @@ void DxilLinkJob::RunPreparePass(Module &M) {
   // SROA
   PM.add(createSROAPass(/*RequiresDomTree*/false));
 
+  // Lower matrix bitcast.
+  PM.add(createMatrixBitcastLowerPass());
+
   // mem2reg.
   PM.add(createPromoteMemoryToRegisterPass());
 

+ 85 - 15
lib/HLSL/HLMatrixLowerPass.cpp

@@ -17,6 +17,7 @@
 #include "dxc/HlslIntrinsicOp.h"
 #include "dxc/Support/Global.h"
 #include "dxc/HLSL/DxilOperations.h"
+#include "dxc/hlsl/DxilTypeSystem.h"
 
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Module.h"
@@ -48,6 +49,24 @@ bool IsMatrixType(Type *Ty) {
   return false;
 }
 
+// If user is function call, return param annotation to get matrix major.
+DxilFieldAnnotation *FindAnnotationFromMatUser(Value *Mat,
+                                               DxilTypeSystem &typeSys) {
+  for (User *U : Mat->users()) {
+    if (CallInst *CI = dyn_cast<CallInst>(U)) {
+      Function *F = CI->getCalledFunction();
+      if (DxilFunctionAnnotation *Anno = typeSys.GetFunctionAnnotation(F)) {
+        for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
+          if (CI->getArgOperand(i) == Mat) {
+            return &Anno->GetParameterAnnotation(i);
+          }
+        }
+      }
+    }
+  }
+  return nullptr;
+}
+
 // Translate matrix type to vector type.
 Type *LowerMatrixType(Type *Ty) {
   // Only translate matrix type and function type which use matrix type.
@@ -284,7 +303,7 @@ private:
   // Lower users of a matrix type instruction.
   void replaceMatWithVec(Value *matVal, Value *vecVal);
   // Translate user library function call arguments
-  void castMatrixArgs(Instruction *I);
+  void castMatrixArgs(Value *matVal, Value *vecVal, CallInst *CI);
   // Translate mat inst which need all operands ready.
   void finalMatTranslation(Value *matVal);
   // Delete dead insts in m_deadInsts.
@@ -432,6 +451,21 @@ static GetElementPtrInst *GetIfMatrixGEPOfUDTAlloca(Value *V) {
   return nullptr;
 }
 
+// Return GEP if value is Matrix resulting GEP from UDT argument of
+// none-graphics functions.
+static GetElementPtrInst *GetIfMatrixGEPOfUDTArg(Value *V, HLModule &HM) {
+  if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
+    if (IsMatrixType(GEP->getResultElementType())) {
+      Value *ptr = GEP->getPointerOperand();
+      if (Argument *Arg = dyn_cast<Argument>(ptr)) {
+        if (!HM.IsGraphicsShader(Arg->getParent()))
+          return GEP;
+      }
+    }
+  }
+  return nullptr;
+}
+
 Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
   IRBuilder<> Builder(CI);
   unsigned opcode = GetHLOpcode(CI);
@@ -876,7 +910,8 @@ void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
     if (HLModule::HasPreciseAttributeWithMetadata(AI))
       HLModule::MarkPreciseAttributeWithMetadata(cast<Instruction>(vecVal));
 
-  } else if (GetIfMatrixGEPOfUDTAlloca(matInst)) {
+  } else if (GetIfMatrixGEPOfUDTAlloca(matInst) ||
+             GetIfMatrixGEPOfUDTArg(matInst, *m_pHLModule)) {
     // If GEP from alloca of non-matrix UDT, bitcast
     IRBuilder<> Builder(matInst->getNextNode());
     vecVal = Builder.CreateBitCast(matInst,
@@ -2157,6 +2192,8 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
       case HLOpcodeGroup::HLSubscript: {
         if (AllocaInst *AI = dyn_cast<AllocaInst>(vecVal))
           TranslateMatSubscript(matVal, vecVal, useCall);
+        else if (BitCastInst *BCI = dyn_cast<BitCastInst>(vecVal))
+          TranslateMatSubscript(matVal, vecVal, useCall);
         else
           TrivialMatReplace(matVal, vecVal, useCall);
 
@@ -2166,7 +2203,7 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
         TranslateMatInit(useCall);
       } break;
       case HLOpcodeGroup::NotHL: {
-        castMatrixArgs(useCall);
+        castMatrixArgs(matVal, vecVal, useCall);
       } break;
       }
     } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
@@ -2187,19 +2224,16 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
   }
 }
 
-void HLMatrixLowerPass::castMatrixArgs(Instruction *I) {
+void HLMatrixLowerPass::castMatrixArgs(Value *matVal, Value *vecVal, CallInst *CI) {
   // translate user function parameters as necessary
-  for (unsigned i = 0; i < I->getNumOperands(); i++) {
-    Value *argVal = I->getOperand(i);
-    Type *argTy = argVal->getType();
-    if (argTy->isPointerTy())
-      argTy = argTy->getPointerElementType();
-    if (argTy->isStructTy() && IsMatrixType(argTy)) {
-      Value *vecVal = matToVecMap[argVal];
-      Value *newMatVal = GetMatrixForVec(vecVal, argVal->getType());
-      if (argVal != newMatVal)
-        I->setOperand(i, newMatVal);
-    }
+  Type *Ty = matVal->getType();
+  if (Ty->isPointerTy()) {
+    IRBuilder<> Builder(CI);
+    Value *newMatVal = Builder.CreateBitCast(vecVal, Ty);
+    CI->replaceUsesOfWith(matVal, newMatVal);
+  } else {
+    Value *newMatVal = GetMatrixForVec(vecVal, Ty);
+    CI->replaceUsesOfWith(matVal, newMatVal);
   }
 }
 
@@ -2539,3 +2573,39 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
 
   return;
 }
+
+// Matrix Bitcast lower.
+// After linking Lower matrix bitcast patterns like:
+//  %169 = bitcast [72 x float]* %0 to [6 x %class.matrix.float.4.3]*
+//  %conv.i = fptoui float %164 to i32
+//  %arrayidx.i = getelementptr inbounds [6 x %class.matrix.float.4.3], [6 x %class.matrix.float.4.3]* %169, i32 0, i32 %conv.i
+//  %170 = bitcast %class.matrix.float.4.3* %arrayidx.i to <12 x float>*
+
+namespace {
+class MatrixBitcastLowerPass : public FunctionPass {
+
+public:
+  static char ID; // Pass identification, replacement for typeid
+  explicit MatrixBitcastLowerPass() : FunctionPass(ID) {}
+
+  const char *getPassName() const override { return "Matrix Bitcast lower"; }
+  bool runOnFunction(Function &F) override {
+    // TODO: remove bitcast on matrix.
+    return false;
+  }
+private:
+  void lowerMatrixBitcast(BitCastInst *BCI);
+};
+
+}
+
+void MatrixBitcastLowerPass::lowerMatrixBitcast(BitCastInst *BCI) {
+  // to matrix.
+  // from matrix.
+}
+
+#include "dxc/HLSL/DxilGenerationPass.h"
+char MatrixBitcastLowerPass::ID = 0;
+FunctionPass *llvm::createMatrixBitcastLowerPass() { return new MatrixBitcastLowerPass(); }
+
+INITIALIZE_PASS(MatrixBitcastLowerPass, "matrixbitcastlower", "Matrix Bitcast lower", false, false)

+ 13 - 0
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -2566,6 +2566,13 @@ void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
   if (DestTy != SrcTy) {
     return;
   }
+  // Try to find fieldAnnotation from user of Dest/Src.
+  if (!fieldAnnotation) {
+    Type *EltTy = dxilutil::GetArrayEltTy(DestTy);
+    if (HLMatrixLower::IsMatrixType(EltTy)) {
+      fieldAnnotation = HLMatrixLower::FindAnnotationFromMatUser(Dest, typeSys);
+    }
+  }
 
   llvm::SmallVector<llvm::Value *, 16> idxList;
   // split
@@ -6653,6 +6660,8 @@ void DynamicIndexingVectorToArray::ReplaceVectorArrayWithArray(Value *VA, Value
       IRBuilder<> Builder(GEPOp->getContext());
       SmallVector<Value *, 4> idxList(GEPOp->idx_begin(), GEPOp->idx_end());
       ReplaceVecArrayGEP(GEPOp, idxList, A, Builder);
+    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(User)) {
+      BCI->setOperand(0, A);
     } else {
       DXASSERT(0, "Array pointer should only used by GEP");
     }
@@ -6799,6 +6808,10 @@ void MultiDimArrayToOneDimArray::lowerUseWithNewValue(Value *MultiDim, Value *On
     User *U = *(it++);
     if (U->user_empty())
       continue;
+    if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
+      BCI->setOperand(0, OneDim);
+      continue;
+    }
     // Must be GEP.
     GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U);
 

+ 29 - 0
tools/clang/test/CodeGenHLSL/quick-test/lib_mat_entry.hlsl

@@ -0,0 +1,29 @@
+// RUN: %dxc -T lib_6_3  %s | FileCheck %s
+
+
+// CHECK: @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %A, i32 2)
+// CHECK: @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %A, i32 3)
+// CHECK: @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %A, i32 4)
+
+// CHECK: @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %A, i32 5)
+// CHECK: @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %A, i32 6)
+// CHECK: @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %A, i32 7)
+
+
+// CHECK: [[BCI:%.*]] = bitcast [24 x float]* %1 to [2 x %class.matrix.float.4.3]*
+// CHECK: call float @"\01?mat_array_test@@YAMV?$vector@M$03@@0Y01V?$matrix@M$03$02@@@Z"(<4 x float> {{.*}}, <4 x float> {{.*}}, [2 x %class.matrix.float.4.3]* [[BCI]]
+
+float mat_array_test(in float4 inGBuffer0,
+                                  in float4 inGBuffer1,
+                                  float4x3 basisArray[2]);
+
+cbuffer A {
+float4 g0;
+float4 g1;
+float4x3 m[2];
+};
+
+[shader("pixel")]
+float main() : SV_Target {
+  return mat_array_test( g0, g1, m);
+}

+ 1 - 0
utils/hct/hctdb.py

@@ -1466,6 +1466,7 @@ class db_dxil(object):
         add_pass('scalarreplhlsl-ssa', 'SROA_SSAUp_HLSL', 'Scalar Replacement of Aggregates HLSL (SSAUp)', [])
         add_pass('static-global-to-alloca', 'LowerStaticGlobalIntoAlloca', 'Lower static global into Alloca', [])
         add_pass('hlmatrixlower', 'HLMatrixLowerPass', 'HLSL High-Level Matrix Lower', [])
+        add_pass('matrixbitcastlower', 'MatrixBitcastLowerPass', 'Matrix Bitcast lower', [])
         add_pass('dce', 'DCE', 'Dead Code Elimination', [])
         add_pass('die', 'DeadInstElimination', 'Dead Instruction Elimination', [])
         add_pass('globaldce', 'GlobalDCE', 'Dead Global Elimination', [])