|
@@ -17,6 +17,7 @@
|
|
#include "dxc/HlslIntrinsicOp.h"
|
|
#include "dxc/HlslIntrinsicOp.h"
|
|
#include "dxc/Support/Global.h"
|
|
#include "dxc/Support/Global.h"
|
|
#include "dxc/HLSL/DxilOperations.h"
|
|
#include "dxc/HLSL/DxilOperations.h"
|
|
|
|
+#include "dxc/hlsl/DxilTypeSystem.h"
|
|
|
|
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/Module.h"
|
|
#include "llvm/IR/Module.h"
|
|
@@ -48,6 +49,24 @@ bool IsMatrixType(Type *Ty) {
|
|
return false;
|
|
return false;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+// If user is function call, return param annotation to get matrix major.
|
|
|
|
+DxilFieldAnnotation *FindAnnotationFromMatUser(Value *Mat,
|
|
|
|
+ DxilTypeSystem &typeSys) {
|
|
|
|
+ for (User *U : Mat->users()) {
|
|
|
|
+ if (CallInst *CI = dyn_cast<CallInst>(U)) {
|
|
|
|
+ Function *F = CI->getCalledFunction();
|
|
|
|
+ if (DxilFunctionAnnotation *Anno = typeSys.GetFunctionAnnotation(F)) {
|
|
|
|
+ for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
|
|
|
|
+ if (CI->getArgOperand(i) == Mat) {
|
|
|
|
+ return &Anno->GetParameterAnnotation(i);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return nullptr;
|
|
|
|
+}
|
|
|
|
+
|
|
// Translate matrix type to vector type.
|
|
// Translate matrix type to vector type.
|
|
Type *LowerMatrixType(Type *Ty) {
|
|
Type *LowerMatrixType(Type *Ty) {
|
|
// Only translate matrix type and function type which use matrix type.
|
|
// Only translate matrix type and function type which use matrix type.
|
|
@@ -284,7 +303,7 @@ private:
|
|
// Lower users of a matrix type instruction.
|
|
// Lower users of a matrix type instruction.
|
|
void replaceMatWithVec(Value *matVal, Value *vecVal);
|
|
void replaceMatWithVec(Value *matVal, Value *vecVal);
|
|
// Translate user library function call arguments
|
|
// Translate user library function call arguments
|
|
- void castMatrixArgs(Instruction *I);
|
|
|
|
|
|
+ void castMatrixArgs(Value *matVal, Value *vecVal, CallInst *CI);
|
|
// Translate mat inst which need all operands ready.
|
|
// Translate mat inst which need all operands ready.
|
|
void finalMatTranslation(Value *matVal);
|
|
void finalMatTranslation(Value *matVal);
|
|
// Delete dead insts in m_deadInsts.
|
|
// Delete dead insts in m_deadInsts.
|
|
@@ -432,6 +451,21 @@ static GetElementPtrInst *GetIfMatrixGEPOfUDTAlloca(Value *V) {
|
|
return nullptr;
|
|
return nullptr;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+// Return GEP if value is Matrix resulting GEP from UDT argument of
|
|
|
|
+// none-graphics functions.
|
|
|
|
+static GetElementPtrInst *GetIfMatrixGEPOfUDTArg(Value *V, HLModule &HM) {
|
|
|
|
+ if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
|
|
|
|
+ if (IsMatrixType(GEP->getResultElementType())) {
|
|
|
|
+ Value *ptr = GEP->getPointerOperand();
|
|
|
|
+ if (Argument *Arg = dyn_cast<Argument>(ptr)) {
|
|
|
|
+ if (!HM.IsGraphicsShader(Arg->getParent()))
|
|
|
|
+ return GEP;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return nullptr;
|
|
|
|
+}
|
|
|
|
+
|
|
Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
|
|
Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
|
|
IRBuilder<> Builder(CI);
|
|
IRBuilder<> Builder(CI);
|
|
unsigned opcode = GetHLOpcode(CI);
|
|
unsigned opcode = GetHLOpcode(CI);
|
|
@@ -876,7 +910,8 @@ void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
|
|
if (HLModule::HasPreciseAttributeWithMetadata(AI))
|
|
if (HLModule::HasPreciseAttributeWithMetadata(AI))
|
|
HLModule::MarkPreciseAttributeWithMetadata(cast<Instruction>(vecVal));
|
|
HLModule::MarkPreciseAttributeWithMetadata(cast<Instruction>(vecVal));
|
|
|
|
|
|
- } else if (GetIfMatrixGEPOfUDTAlloca(matInst)) {
|
|
|
|
|
|
+ } else if (GetIfMatrixGEPOfUDTAlloca(matInst) ||
|
|
|
|
+ GetIfMatrixGEPOfUDTArg(matInst, *m_pHLModule)) {
|
|
// If GEP from alloca of non-matrix UDT, bitcast
|
|
// If GEP from alloca of non-matrix UDT, bitcast
|
|
IRBuilder<> Builder(matInst->getNextNode());
|
|
IRBuilder<> Builder(matInst->getNextNode());
|
|
vecVal = Builder.CreateBitCast(matInst,
|
|
vecVal = Builder.CreateBitCast(matInst,
|
|
@@ -2157,6 +2192,8 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
case HLOpcodeGroup::HLSubscript: {
|
|
case HLOpcodeGroup::HLSubscript: {
|
|
if (AllocaInst *AI = dyn_cast<AllocaInst>(vecVal))
|
|
if (AllocaInst *AI = dyn_cast<AllocaInst>(vecVal))
|
|
TranslateMatSubscript(matVal, vecVal, useCall);
|
|
TranslateMatSubscript(matVal, vecVal, useCall);
|
|
|
|
+ else if (BitCastInst *BCI = dyn_cast<BitCastInst>(vecVal))
|
|
|
|
+ TranslateMatSubscript(matVal, vecVal, useCall);
|
|
else
|
|
else
|
|
TrivialMatReplace(matVal, vecVal, useCall);
|
|
TrivialMatReplace(matVal, vecVal, useCall);
|
|
|
|
|
|
@@ -2166,7 +2203,7 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
TranslateMatInit(useCall);
|
|
TranslateMatInit(useCall);
|
|
} break;
|
|
} break;
|
|
case HLOpcodeGroup::NotHL: {
|
|
case HLOpcodeGroup::NotHL: {
|
|
- castMatrixArgs(useCall);
|
|
|
|
|
|
+ castMatrixArgs(matVal, vecVal, useCall);
|
|
} break;
|
|
} break;
|
|
}
|
|
}
|
|
} else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
|
|
} else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
|
|
@@ -2187,19 +2224,16 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::castMatrixArgs(Instruction *I) {
|
|
|
|
|
|
+void HLMatrixLowerPass::castMatrixArgs(Value *matVal, Value *vecVal, CallInst *CI) {
|
|
// translate user function parameters as necessary
|
|
// translate user function parameters as necessary
|
|
- for (unsigned i = 0; i < I->getNumOperands(); i++) {
|
|
|
|
- Value *argVal = I->getOperand(i);
|
|
|
|
- Type *argTy = argVal->getType();
|
|
|
|
- if (argTy->isPointerTy())
|
|
|
|
- argTy = argTy->getPointerElementType();
|
|
|
|
- if (argTy->isStructTy() && IsMatrixType(argTy)) {
|
|
|
|
- Value *vecVal = matToVecMap[argVal];
|
|
|
|
- Value *newMatVal = GetMatrixForVec(vecVal, argVal->getType());
|
|
|
|
- if (argVal != newMatVal)
|
|
|
|
- I->setOperand(i, newMatVal);
|
|
|
|
- }
|
|
|
|
|
|
+ Type *Ty = matVal->getType();
|
|
|
|
+ if (Ty->isPointerTy()) {
|
|
|
|
+ IRBuilder<> Builder(CI);
|
|
|
|
+ Value *newMatVal = Builder.CreateBitCast(vecVal, Ty);
|
|
|
|
+ CI->replaceUsesOfWith(matVal, newMatVal);
|
|
|
|
+ } else {
|
|
|
|
+ Value *newMatVal = GetMatrixForVec(vecVal, Ty);
|
|
|
|
+ CI->replaceUsesOfWith(matVal, newMatVal);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -2539,3 +2573,39 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|
|
|
|
|
return;
|
|
return;
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+// Matrix Bitcast lower.
|
|
|
|
+// After linking Lower matrix bitcast patterns like:
|
|
|
|
+// %169 = bitcast [72 x float]* %0 to [6 x %class.matrix.float.4.3]*
|
|
|
|
+// %conv.i = fptoui float %164 to i32
|
|
|
|
+// %arrayidx.i = getelementptr inbounds [6 x %class.matrix.float.4.3], [6 x %class.matrix.float.4.3]* %169, i32 0, i32 %conv.i
|
|
|
|
+// %170 = bitcast %class.matrix.float.4.3* %arrayidx.i to <12 x float>*
|
|
|
|
+
|
|
|
|
+namespace {
|
|
|
|
+class MatrixBitcastLowerPass : public FunctionPass {
|
|
|
|
+
|
|
|
|
+public:
|
|
|
|
+ static char ID; // Pass identification, replacement for typeid
|
|
|
|
+ explicit MatrixBitcastLowerPass() : FunctionPass(ID) {}
|
|
|
|
+
|
|
|
|
+ const char *getPassName() const override { return "Matrix Bitcast lower"; }
|
|
|
|
+ bool runOnFunction(Function &F) override {
|
|
|
|
+ // TODO: remove bitcast on matrix.
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+private:
|
|
|
|
+ void lowerMatrixBitcast(BitCastInst *BCI);
|
|
|
|
+};
|
|
|
|
+
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+void MatrixBitcastLowerPass::lowerMatrixBitcast(BitCastInst *BCI) {
|
|
|
|
+ // to matrix.
|
|
|
|
+ // from matrix.
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+#include "dxc/HLSL/DxilGenerationPass.h"
|
|
|
|
+char MatrixBitcastLowerPass::ID = 0;
|
|
|
|
+FunctionPass *llvm::createMatrixBitcastLowerPass() { return new MatrixBitcastLowerPass(); }
|
|
|
|
+
|
|
|
|
+INITIALIZE_PASS(MatrixBitcastLowerPass, "matrixbitcastlower", "Matrix Bitcast lower", false, false)
|