|
@@ -9,8 +9,9 @@
|
|
|
// //
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
-#include "dxc/HLSL/HLMatrixLowerHelper.h"
|
|
|
#include "dxc/HLSL/HLMatrixLowerPass.h"
|
|
|
+#include "dxc/HLSL/HLMatrixLowerHelper.h"
|
|
|
+#include "dxc/HLSL/HLMatrixType.h"
|
|
|
#include "dxc/HLSL/HLOperations.h"
|
|
|
#include "dxc/HLSL/HLModule.h"
|
|
|
#include "dxc/DXIL/DxilUtil.h"
|
|
@@ -19,6 +20,7 @@
|
|
|
#include "dxc/DXIL/DxilOperations.h"
|
|
|
#include "dxc/DXIL/DxilTypeSystem.h"
|
|
|
#include "dxc/DXIL/DxilModule.h"
|
|
|
+#include "HLMatrixSubscriptUseReplacer.h"
|
|
|
|
|
|
#include "llvm/IR/IRBuilder.h"
|
|
|
#include "llvm/IR/Module.h"
|
|
@@ -36,356 +38,146 @@ using namespace hlsl::HLMatrixLower;
|
|
|
namespace hlsl {
|
|
|
namespace HLMatrixLower {
|
|
|
|
|
|
-// If user is function call, return param annotation to get matrix major.
|
|
|
-DxilFieldAnnotation *FindAnnotationFromMatUser(Value *Mat,
|
|
|
- DxilTypeSystem &typeSys) {
|
|
|
- for (User *U : Mat->users()) {
|
|
|
- if (CallInst *CI = dyn_cast<CallInst>(U)) {
|
|
|
- Function *F = CI->getCalledFunction();
|
|
|
- if (DxilFunctionAnnotation *Anno = typeSys.GetFunctionAnnotation(F)) {
|
|
|
- for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
|
|
|
- if (CI->getArgOperand(i) == Mat) {
|
|
|
- return &Anno->GetParameterAnnotation(i);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- return nullptr;
|
|
|
-}
|
|
|
-
|
|
|
-// Translate matrix type to vector type.
|
|
|
-Type *LowerMatrixType(Type *Ty, bool forMem) {
|
|
|
- // Only translate matrix type and function type which use matrix type.
|
|
|
- // Not translate struct has matrix or matrix pointer.
|
|
|
- // Struct should be flattened before.
|
|
|
- // Pointer could cover by matldst which use vector as value type.
|
|
|
- if (FunctionType *FT = dyn_cast<FunctionType>(Ty)) {
|
|
|
- Type *RetTy = LowerMatrixType(FT->getReturnType());
|
|
|
- SmallVector<Type *, 4> params;
|
|
|
- for (Type *param : FT->params()) {
|
|
|
- params.emplace_back(LowerMatrixType(param));
|
|
|
- }
|
|
|
- return FunctionType::get(RetTy, params, false);
|
|
|
- } else if (dxilutil::IsHLSLMatrixType(Ty)) {
|
|
|
- unsigned row, col;
|
|
|
- Type *EltTy = GetMatrixInfo(Ty, col, row);
|
|
|
- if (forMem && EltTy->isIntegerTy(1))
|
|
|
- EltTy = Type::getInt32Ty(Ty->getContext());
|
|
|
- return VectorType::get(EltTy, row * col);
|
|
|
- } else {
|
|
|
- return Ty;
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-// Translate matrix type to array type.
|
|
|
-Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
|
|
|
- if (dxilutil::IsHLSLMatrixType(Ty)) {
|
|
|
- unsigned row, col;
|
|
|
- Type *EltTy = GetMatrixInfo(Ty, col, row);
|
|
|
- return ArrayType::get(EltTy, row * col);
|
|
|
- } else {
|
|
|
- return Ty;
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-
|
|
|
-Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
|
|
|
- DXASSERT(dxilutil::IsHLSLMatrixType(Ty), "not matrix type");
|
|
|
- StructType *ST = cast<StructType>(Ty);
|
|
|
- Type *EltTy = ST->getElementType(0);
|
|
|
- Type *RowTy = EltTy->getArrayElementType();
|
|
|
- row = EltTy->getArrayNumElements();
|
|
|
- col = RowTy->getVectorNumElements();
|
|
|
- return RowTy->getVectorElementType();
|
|
|
-}
|
|
|
-
|
|
|
-bool IsMatrixArrayPointer(llvm::Type *Ty) {
|
|
|
- if (!Ty->isPointerTy())
|
|
|
- return false;
|
|
|
- Ty = Ty->getPointerElementType();
|
|
|
- if (!Ty->isArrayTy())
|
|
|
- return false;
|
|
|
- while (Ty->isArrayTy())
|
|
|
- Ty = Ty->getArrayElementType();
|
|
|
- return dxilutil::IsHLSLMatrixType(Ty);
|
|
|
-}
|
|
|
-Type *LowerMatrixArrayPointer(Type *Ty, bool forMem) {
|
|
|
- unsigned addrSpace = Ty->getPointerAddressSpace();
|
|
|
- Ty = Ty->getPointerElementType();
|
|
|
- std::vector<unsigned> arraySizeList;
|
|
|
- while (Ty->isArrayTy()) {
|
|
|
- arraySizeList.push_back(Ty->getArrayNumElements());
|
|
|
- Ty = Ty->getArrayElementType();
|
|
|
- }
|
|
|
- Ty = LowerMatrixType(Ty, forMem);
|
|
|
-
|
|
|
- for (auto arraySize = arraySizeList.rbegin();
|
|
|
- arraySize != arraySizeList.rend(); arraySize++)
|
|
|
- Ty = ArrayType::get(Ty, *arraySize);
|
|
|
- return PointerType::get(Ty, addrSpace);
|
|
|
-}
|
|
|
-
|
|
|
-Type *LowerMatrixArrayPointerToOneDimArray(Type *Ty) {
|
|
|
- unsigned addrSpace = Ty->getPointerAddressSpace();
|
|
|
- Ty = Ty->getPointerElementType();
|
|
|
-
|
|
|
- unsigned arraySize = 1;
|
|
|
- while (Ty->isArrayTy()) {
|
|
|
- arraySize *= Ty->getArrayNumElements();
|
|
|
- Ty = Ty->getArrayElementType();
|
|
|
- }
|
|
|
- unsigned row, col;
|
|
|
- Type *EltTy = GetMatrixInfo(Ty, col, row);
|
|
|
- arraySize *= row*col;
|
|
|
-
|
|
|
- Ty = ArrayType::get(EltTy, arraySize);
|
|
|
- return PointerType::get(Ty, addrSpace);
|
|
|
-}
|
|
|
-Value *BuildVector(Type *EltTy, unsigned size, ArrayRef<llvm::Value *> elts,
|
|
|
- IRBuilder<> &Builder) {
|
|
|
- Value *Vec = UndefValue::get(VectorType::get(EltTy, size));
|
|
|
- for (unsigned i = 0; i < size; i++)
|
|
|
+Value *BuildVector(Type *EltTy, ArrayRef<llvm::Value *> elts, IRBuilder<> &Builder) {
|
|
|
+ Value *Vec = UndefValue::get(VectorType::get(EltTy, static_cast<unsigned>(elts.size())));
|
|
|
+ for (unsigned i = 0; i < elts.size(); i++)
|
|
|
Vec = Builder.CreateInsertElement(Vec, elts[i], i);
|
|
|
return Vec;
|
|
|
}
|
|
|
|
|
|
-llvm::Value *VecMatrixMemToReg(llvm::Value *VecVal, llvm::Type *MatType,
|
|
|
- llvm::IRBuilder<> &Builder)
|
|
|
-{
|
|
|
- llvm::Type *VecMatRegTy = HLMatrixLower::LowerMatrixType(MatType, /*forMem*/false);
|
|
|
- if (VecVal->getType() == VecMatRegTy) {
|
|
|
- return VecVal;
|
|
|
- }
|
|
|
-
|
|
|
- DXASSERT(VecMatRegTy->getVectorElementType()->isIntegerTy(1),
|
|
|
- "Vector matrix mem to reg type mismatch should only happen for bools.");
|
|
|
- llvm::Type *VecMatMemTy = HLMatrixLower::LowerMatrixType(MatType, /*forMem*/true);
|
|
|
- return Builder.CreateICmpNE(VecVal, Constant::getNullValue(VecMatMemTy));
|
|
|
-}
|
|
|
-
|
|
|
-llvm::Value *VecMatrixRegToMem(llvm::Value* VecVal, llvm::Type *MatType,
|
|
|
- llvm::IRBuilder<> &Builder)
|
|
|
-{
|
|
|
- llvm::Type *VecMatMemTy = HLMatrixLower::LowerMatrixType(MatType, /*forMem*/true);
|
|
|
- if (VecVal->getType() == VecMatMemTy) {
|
|
|
- return VecVal;
|
|
|
- }
|
|
|
-
|
|
|
- DXASSERT(VecVal->getType()->getVectorElementType()->isIntegerTy(1),
|
|
|
- "Vector matrix reg to mem type mismatch should only happen for bools.");
|
|
|
- return Builder.CreateZExt(VecVal, VecMatMemTy);
|
|
|
-}
|
|
|
-
|
|
|
-llvm::Instruction *CreateVecMatrixLoad(
|
|
|
- llvm::Value *VecPtr, llvm::Type *MatType, llvm::IRBuilder<> &Builder)
|
|
|
-{
|
|
|
- llvm::Instruction *VecVal = Builder.CreateLoad(VecPtr);
|
|
|
- return cast<llvm::Instruction>(VecMatrixMemToReg(VecVal, MatType, Builder));
|
|
|
-}
|
|
|
-
|
|
|
-llvm::Instruction *CreateVecMatrixStore(llvm::Value* VecVal, llvm::Value *VecPtr,
|
|
|
- llvm::Type *MatType, llvm::IRBuilder<> &Builder)
|
|
|
-{
|
|
|
- llvm::Type *VecMatMemTy = HLMatrixLower::LowerMatrixType(MatType, /*forMem*/true);
|
|
|
- if (VecVal->getType() == VecMatMemTy) {
|
|
|
- return Builder.CreateStore(VecVal, VecPtr);
|
|
|
- }
|
|
|
-
|
|
|
- // We need to convert to the memory representation, and we want to return
|
|
|
- // the conversion instruction rather than the store since that's what
|
|
|
- // accepts the register-typed i1 values.
|
|
|
-
|
|
|
- // Do not use VecMatrixRegToMem as it may constant fold the conversion
|
|
|
- // instruction, which is what we want to return.
|
|
|
- DXASSERT(VecVal->getType()->getVectorElementType()->isIntegerTy(1),
|
|
|
- "Vector matrix reg to mem type mismatch should only happen for bools.");
|
|
|
-
|
|
|
- llvm::Instruction *ConvInst = Builder.Insert(new ZExtInst(VecVal, VecMatMemTy));
|
|
|
- Builder.CreateStore(ConvInst, VecPtr);
|
|
|
- return ConvInst;
|
|
|
-}
|
|
|
-
|
|
|
-Value *LowerGEPOnMatIndexListToIndex(
|
|
|
- llvm::GetElementPtrInst *GEP, ArrayRef<Value *> IdxList) {
|
|
|
- IRBuilder<> Builder(GEP);
|
|
|
- Value *zero = Builder.getInt32(0);
|
|
|
- DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
|
|
|
- Value *baseIdx = (GEP->idx_begin())->get();
|
|
|
- DXASSERT_LOCALVAR(baseIdx, baseIdx == zero, "base index must be 0");
|
|
|
- Value *Idx = (GEP->idx_begin() + 1)->get();
|
|
|
-
|
|
|
- if (ConstantInt *immIdx = dyn_cast<ConstantInt>(Idx)) {
|
|
|
- return IdxList[immIdx->getSExtValue()];
|
|
|
- } else {
|
|
|
- IRBuilder<> AllocaBuilder(
|
|
|
- GEP->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
|
|
|
- unsigned size = IdxList.size();
|
|
|
- // Store idxList to temp array.
|
|
|
- ArrayType *AT = ArrayType::get(IdxList[0]->getType(), size);
|
|
|
- Value *tempArray = AllocaBuilder.CreateAlloca(AT);
|
|
|
-
|
|
|
- for (unsigned i = 0; i < size; i++) {
|
|
|
- Value *EltPtr = Builder.CreateGEP(tempArray, {zero, Builder.getInt32(i)});
|
|
|
- Builder.CreateStore(IdxList[i], EltPtr);
|
|
|
- }
|
|
|
- // Load the idx.
|
|
|
- Value *GEPOffset = Builder.CreateGEP(tempArray, {zero, Idx});
|
|
|
- return Builder.CreateLoad(GEPOffset);
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-
|
|
|
-unsigned GetColMajorIdx(unsigned r, unsigned c, unsigned row) {
|
|
|
- return c * row + r;
|
|
|
-}
|
|
|
-unsigned GetRowMajorIdx(unsigned r, unsigned c, unsigned col) {
|
|
|
- return r * col + c;
|
|
|
-}
|
|
|
-
|
|
|
} // namespace HLMatrixLower
|
|
|
} // namespace hlsl
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
-class HLMatrixLowerPass : public ModulePass {
|
|
|
+// Creates and manages a set of temporary overloaded functions keyed on the function type,
|
|
|
+// and which should be destroyed when the pool gets out of scope.
|
|
|
+class TempOverloadPool {
|
|
|
+public:
|
|
|
+ TempOverloadPool(llvm::Module &Module, const char* BaseName)
|
|
|
+ : Module(Module), BaseName(BaseName) {}
|
|
|
+ ~TempOverloadPool() { clear(); }
|
|
|
+
|
|
|
+ Function *get(FunctionType *Ty);
|
|
|
+ bool contains(FunctionType *Ty) const { return Funcs.count(Ty) != 0; }
|
|
|
+ bool contains(Function *Func) const;
|
|
|
+ void clear();
|
|
|
|
|
|
+private:
|
|
|
+ llvm::Module &Module;
|
|
|
+ const char* BaseName;
|
|
|
+ llvm::DenseMap<FunctionType*, Function*> Funcs;
|
|
|
+};
|
|
|
+
|
|
|
+Function *TempOverloadPool::get(FunctionType *Ty) {
|
|
|
+ auto It = Funcs.find(Ty);
|
|
|
+ if (It != Funcs.end()) return It->second;
|
|
|
+
|
|
|
+ std::string MangledName;
|
|
|
+ raw_string_ostream MangledNameStream(MangledName);
|
|
|
+ MangledNameStream << BaseName;
|
|
|
+ MangledNameStream << '.';
|
|
|
+ Ty->print(MangledNameStream);
|
|
|
+ MangledNameStream.flush();
|
|
|
+
|
|
|
+ Function* Func = cast<Function>(Module.getOrInsertFunction(MangledName, Ty));
|
|
|
+ Funcs.insert(std::make_pair(Ty, Func));
|
|
|
+ return Func;
|
|
|
+}
|
|
|
+
|
|
|
+bool TempOverloadPool::contains(Function *Func) const {
|
|
|
+ auto It = Funcs.find(Func->getFunctionType());
|
|
|
+ return It != Funcs.end() && It->second == Func;
|
|
|
+}
|
|
|
+
|
|
|
+void TempOverloadPool::clear() {
|
|
|
+ for (auto Entry : Funcs) {
|
|
|
+ DXASSERT(Entry.second->use_empty(), "Temporary function still used during pool destruction.");
|
|
|
+ Entry.second->removeFromParent();
|
|
|
+ }
|
|
|
+ Funcs.clear();
|
|
|
+}
|
|
|
+
|
|
|
+// High-level matrix lowering pass.
|
|
|
+//
|
|
|
+// This pass converts matrices to their lowered vector representations,
|
|
|
+// including global variables, local variables and operations,
|
|
|
+// but not function signatures (arguments and return types) - left to HLSignatureLower and HLMatrixBitcastLower,
|
|
|
+// nor matrices obtained from resources or constant - left to HLOperationLower.
|
|
|
+//
|
|
|
+// Algorithm overview:
|
|
|
+// 1. Find all matrix and matrix array global variables and lower them to vectors.
|
|
|
+// Walk any GEPs and insert vec-to-mat translation stubs so that consuming
|
|
|
+// instructions keep dealing with matrix types for the moment.
|
|
|
+// 2. For each function
|
|
|
+// 2a. Lower all matrix and matrix array allocas, just like global variables.
|
|
|
+// 2b. Lower all other instructions producing or consuming matrices
|
|
|
+//
|
|
|
+// Conversion stubs are used to allow converting instructions in isolation,
|
|
|
+// and in an order-independent manner:
|
|
|
+//
|
|
|
+// Initial: MatInst1(MatInst2(MatInst3))
|
|
|
+// After lowering MatInst2: MatInst1(VecToMat(VecInst2(MatToVec(MatInst3))))
|
|
|
+// After lowering MatInst1: VecInst1(VecInst2(MatToVec(MatInst3)))
|
|
|
+// After lowering MatInst3: VecInst1(VecInst2(VecInst3))
|
|
|
+class HLMatrixLowerPass : public ModulePass {
|
|
|
public:
|
|
|
static char ID; // Pass identification, replacement for typeid
|
|
|
explicit HLMatrixLowerPass() : ModulePass(ID) {}
|
|
|
|
|
|
const char *getPassName() const override { return "HL matrix lower"; }
|
|
|
+ bool runOnModule(Module &M) override;
|
|
|
|
|
|
- bool runOnModule(Module &M) override {
|
|
|
- m_pModule = &M;
|
|
|
- m_pHLModule = &m_pModule->GetOrCreateHLModule();
|
|
|
- // Load up debug information, to cross-reference values and the instructions
|
|
|
- // used to load them.
|
|
|
- m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
|
|
|
-
|
|
|
- for (Function &F : M.functions()) {
|
|
|
-
|
|
|
- if (F.isDeclaration())
|
|
|
- continue;
|
|
|
- runOnFunction(F);
|
|
|
- }
|
|
|
- std::vector<GlobalVariable*> staticGVs;
|
|
|
- for (GlobalVariable &GV : M.globals()) {
|
|
|
- if (dxilutil::IsStaticGlobal(&GV) ||
|
|
|
- dxilutil::IsSharedMemoryGlobal(&GV)) {
|
|
|
- staticGVs.emplace_back(&GV);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- for (GlobalVariable *GV : staticGVs)
|
|
|
- runOnGlobal(GV);
|
|
|
-
|
|
|
- return true;
|
|
|
- }
|
|
|
+private:
|
|
|
+ void runOnFunction(Function &Func);
|
|
|
+ void addToDeadInsts(Instruction *Inst) { m_deadInsts.emplace_back(Inst); }
|
|
|
+ void deleteDeadInsts();
|
|
|
+
|
|
|
+ void getMatrixAllocasAndOtherInsts(Function &Func,
|
|
|
+ std::vector<AllocaInst*> &MatAllocas, std::vector<Instruction*> &MatInsts);
|
|
|
+ Value *getLoweredByValOperand(Value *Val, IRBuilder<> &Builder, bool DiscardStub = false);
|
|
|
+ Value *tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Builder, bool DiscardStub = false);
|
|
|
+ Value *bitCastValue(Value *SrcVal, Type* DstTy, bool DstTyAlloca, IRBuilder<> &Builder);
|
|
|
+ void replaceAllUsesByLoweredValue(Instruction *MatInst, Value *VecVal);
|
|
|
+ void replaceAllVariableUses(Value* MatPtr, Value* LoweredPtr);
|
|
|
+ void replaceAllVariableUses(SmallVectorImpl<Value*> &GEPIdxStack, Value *StackTopPtr, Value* LoweredPtr);
|
|
|
+
|
|
|
+ void lowerGlobal(GlobalVariable *Global);
|
|
|
+ Constant *lowerConstInitVal(Constant *Val);
|
|
|
+ AllocaInst *lowerAlloca(AllocaInst *MatAlloca);
|
|
|
+ void lowerInstruction(Instruction* Inst);
|
|
|
+ void lowerReturn(ReturnInst* Return);
|
|
|
+ Value *lowerCall(CallInst *Call);
|
|
|
+ Value *lowerNonHLCall(CallInst *Call);
|
|
|
+ Value *lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeGroup);
|
|
|
+ Value *lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode);
|
|
|
+ Value *lowerHLMulIntrinsic(Value* Lhs, Value *Rhs, bool Unsigned, IRBuilder<> &Builder);
|
|
|
+ Value *lowerHLTransposeIntrinsic(Value *MatVal, IRBuilder<> &Builder);
|
|
|
+ Value *lowerHLDeterminantIntrinsic(Value *MatVal, IRBuilder<> &Builder);
|
|
|
+ Value *lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode, IRBuilder<> &Builder);
|
|
|
+ Value *lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode, IRBuilder<> &Builder);
|
|
|
+ Value *lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode);
|
|
|
+ Value *lowerHLLoad(Value *MatPtr, bool RowMajor, IRBuilder<> &Builder);
|
|
|
+ Value *lowerHLStore(Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder);
|
|
|
+ Value *lowerHLCast(Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder);
|
|
|
+ Value *lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
|
|
|
+ Value *lowerHLMatElementSubscript(CallInst *Call, bool RowMajor);
|
|
|
+ Value *lowerHLMatSubscript(CallInst *Call, bool RowMajor);
|
|
|
+ void lowerHLMatSubscript(CallInst *Call, Value *MatPtr, SmallVectorImpl<Value*> &ElemIndices);
|
|
|
+ Value *lowerHLMatResourceSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
|
|
|
+ Value *lowerHLInit(CallInst *Call);
|
|
|
+ Value *lowerHLSelect(CallInst *Call);
|
|
|
|
|
|
private:
|
|
|
Module *m_pModule;
|
|
|
HLModule *m_pHLModule;
|
|
|
bool m_HasDbgInfo;
|
|
|
+
|
|
|
+ // Pools for the translation stubs
|
|
|
+ TempOverloadPool *m_matToVecStubs = nullptr;
|
|
|
+ TempOverloadPool *m_vecToMatStubs = nullptr;
|
|
|
+
|
|
|
std::vector<Instruction *> m_deadInsts;
|
|
|
- // For instruction like matrix array init.
|
|
|
- // May use more than 1 matrix alloca inst.
|
|
|
- // This set is here to avoid put it into deadInsts more than once.
|
|
|
- std::unordered_set<Instruction *> m_inDeadInstsSet;
|
|
|
- // For most matrix insturction users, it will only have one matrix use.
|
|
|
- // Use vector so save deadInsts because vector is cheap.
|
|
|
- void AddToDeadInsts(Instruction *I) { m_deadInsts.emplace_back(I); }
|
|
|
- // In case instruction has more than one matrix use.
|
|
|
- // Use AddToDeadInstsWithDups to make sure it's not add to deadInsts more than once.
|
|
|
- void AddToDeadInstsWithDups(Instruction *I) {
|
|
|
- if (m_inDeadInstsSet.count(I) == 0) {
|
|
|
- // Only add to deadInsts when it's not inside m_inDeadInstsSet.
|
|
|
- m_inDeadInstsSet.insert(I);
|
|
|
- AddToDeadInsts(I);
|
|
|
- }
|
|
|
- }
|
|
|
- void runOnFunction(Function &F);
|
|
|
- void runOnGlobal(GlobalVariable *GV);
|
|
|
- void runOnGlobalMatrixArray(GlobalVariable *GV);
|
|
|
- Instruction *MatCastToVec(CallInst *CI);
|
|
|
- Instruction *MatLdStToVec(CallInst *CI);
|
|
|
- Instruction *MatSubscriptToVec(CallInst *CI);
|
|
|
- Instruction *MatFrExpToVec(CallInst *CI);
|
|
|
- Instruction *MatIntrinsicToVec(CallInst *CI);
|
|
|
- Instruction *TrivialMatUnOpToVec(CallInst *CI);
|
|
|
- // Replace matVal with vecVal on matUseInst.
|
|
|
- void TrivialMatUnOpReplace(Value *matVal, Value *vecVal,
|
|
|
- CallInst *matUseInst);
|
|
|
- Instruction *TrivialMatBinOpToVec(CallInst *CI);
|
|
|
- // Replace matVal with vecVal on matUseInst.
|
|
|
- void TrivialMatBinOpReplace(Value *matVal, Value *vecVal,
|
|
|
- CallInst *matUseInst);
|
|
|
- // Replace matVal with vecVal on mulInst.
|
|
|
- void TranslateMatMatMul(Value *matVal, Value *vecVal,
|
|
|
- CallInst *mulInst, bool isSigned);
|
|
|
- void TranslateMatVecMul(Value *matVal, Value *vecVal,
|
|
|
- CallInst *mulInst, bool isSigned);
|
|
|
- void TranslateVecMatMul(Value *matVal, Value *vecVal,
|
|
|
- CallInst *mulInst, bool isSigned);
|
|
|
- void TranslateMul(Value *matVal, Value *vecVal, CallInst *mulInst,
|
|
|
- bool isSigned);
|
|
|
- // Replace matVal with vecVal on transposeInst.
|
|
|
- void TranslateMatTranspose(Value *matVal, Value *vecVal,
|
|
|
- CallInst *transposeInst);
|
|
|
- void TranslateMatDeterminant(Value *matVal, Value *vecVal,
|
|
|
- CallInst *determinantInst);
|
|
|
- void MatIntrinsicReplace(Value *matVal, Value *vecVal,
|
|
|
- CallInst *matUseInst);
|
|
|
- // Replace matVal with vecVal on castInst.
|
|
|
- void TranslateMatMatCast(Value *matVal, Value *vecVal,
|
|
|
- CallInst *castInst);
|
|
|
- void TranslateMatToOtherCast(Value *matVal, Value *vecVal,
|
|
|
- CallInst *castInst);
|
|
|
- void TranslateMatCast(Value *matVal, Value *vecVal,
|
|
|
- CallInst *castInst);
|
|
|
- void TranslateMatMajorCast(Value *matVal, Value *vecVal,
|
|
|
- CallInst *castInst, bool rowToCol, bool transpose);
|
|
|
- // Replace matVal with vecVal in matSubscript
|
|
|
- void TranslateMatSubscript(Value *matVal, Value *vecVal,
|
|
|
- CallInst *matSubInst);
|
|
|
- // Replace matInitInst using matToVecMap
|
|
|
- void TranslateMatInit(CallInst *matInitInst);
|
|
|
- // Replace matSelectInst using matToVecMap
|
|
|
- void TranslateMatSelect(CallInst *matSelectInst);
|
|
|
- // Replace matVal with vecVal on matInitInst.
|
|
|
- void TranslateMatArrayGEP(Value *matVal, Value *vecVal,
|
|
|
- GetElementPtrInst *matGEP);
|
|
|
- void TranslateMatLoadStoreOnGlobal(Value *matGlobal, ArrayRef<Value *>vecGlobals,
|
|
|
- CallInst *matLdStInst);
|
|
|
- void TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
|
|
|
- CallInst *matLdStInst);
|
|
|
- void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
|
|
|
- void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
|
|
|
-
|
|
|
- // Get new matrix value corresponding to vecVal
|
|
|
- Value *GetMatrixForVec(Value *vecVal, Type *matTy);
|
|
|
-
|
|
|
- // Translate library function input/output to preserve function signatures
|
|
|
- void TranslateArgForLibFunc(CallInst *CI);
|
|
|
- void TranslateArgsForLibFunc(Function &F);
|
|
|
-
|
|
|
- // Replace matVal with vecVal on matUseInst.
|
|
|
- void TrivialMatReplace(Value *matVal, Value *vecVal,
|
|
|
- CallInst *matUseInst);
|
|
|
- // Lower a matrix type instruction to a vector type instruction.
|
|
|
- void lowerToVec(Instruction *matInst);
|
|
|
- // Lower users of a matrix type instruction.
|
|
|
- void replaceMatWithVec(Value *matVal, Value *vecVal);
|
|
|
- // Translate user library function call arguments
|
|
|
- void castMatrixArgs(Value *matVal, Value *vecVal, CallInst *CI);
|
|
|
- // Translate mat inst which need all operands ready.
|
|
|
- void finalMatTranslation(Value *matVal);
|
|
|
- // Delete dead insts in m_deadInsts.
|
|
|
- void DeleteDeadInsts();
|
|
|
- // Map from matrix value to its vector version.
|
|
|
- DenseMap<Value *, Value *> matToVecMap;
|
|
|
- // Map from new vector version to matrix version needed by user call or return.
|
|
|
- DenseMap<Value *, Value *> vecToMatMap;
|
|
|
};
|
|
|
}
|
|
|
|
|
@@ -395,2445 +187,1349 @@ ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }
|
|
|
|
|
|
INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
|
|
|
|
|
|
-static Instruction *CreateTypeCast(HLCastOpcode castOp, Type *toTy, Value *src,
|
|
|
- IRBuilder<> Builder) {
|
|
|
- Type *srcTy = src->getType();
|
|
|
-
|
|
|
- // Conversions between equivalent types are no-ops,
|
|
|
- // even between signed/unsigned variants.
|
|
|
- if (srcTy == toTy) return cast<Instruction>(src);
|
|
|
+bool HLMatrixLowerPass::runOnModule(Module &M) {
|
|
|
+ TempOverloadPool matToVecStubs(M, "hlmatrixlower.mat2vec");
|
|
|
+ TempOverloadPool vecToMatStubs(M, "hlmatrixlower.vec2mat");
|
|
|
+
|
|
|
+ m_pModule = &M;
|
|
|
+ m_pHLModule = &m_pModule->GetOrCreateHLModule();
|
|
|
+ // Load up debug information, to cross-reference values and the instructions
|
|
|
+ // used to load them.
|
|
|
+ m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
|
|
|
+ m_matToVecStubs = &matToVecStubs;
|
|
|
+ m_vecToMatStubs = &vecToMatStubs;
|
|
|
+
|
|
|
+ // First, lower static global variables.
|
|
|
+ // We need to accumulate them locally because we'll be creating new ones as we lower them.
|
|
|
+ std::vector<GlobalVariable*> Globals;
|
|
|
+ for (GlobalVariable &Global : M.globals()) {
|
|
|
+ if ((dxilutil::IsStaticGlobal(&Global) || dxilutil::IsSharedMemoryGlobal(&Global))
|
|
|
+ && HLMatrixType::isMatrixPtrOrArrayPtr(Global.getType())) {
|
|
|
+ Globals.emplace_back(&Global);
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- bool fromUnsigned = castOp == HLCastOpcode::FromUnsignedCast ||
|
|
|
- castOp == HLCastOpcode::UnsignedUnsignedCast;
|
|
|
- bool toUnsigned = castOp == HLCastOpcode::ToUnsignedCast ||
|
|
|
- castOp == HLCastOpcode::UnsignedUnsignedCast;
|
|
|
+ for (GlobalVariable *Global : Globals)
|
|
|
+ lowerGlobal(Global);
|
|
|
|
|
|
- // Conversions to bools are comparisons
|
|
|
- if (toTy->getScalarSizeInBits() == 1) {
|
|
|
- // fcmp une is what regular clang uses in C++ for (bool)f;
|
|
|
- return cast<Instruction>(srcTy->isIntOrIntVectorTy()
|
|
|
- ? Builder.CreateICmpNE(src, llvm::Constant::getNullValue(srcTy), "tobool")
|
|
|
- : Builder.CreateFCmpUNE(src, llvm::Constant::getNullValue(srcTy), "tobool"));
|
|
|
+ for (Function &F : M.functions()) {
|
|
|
+ if (F.isDeclaration()) continue;
|
|
|
+ runOnFunction(F);
|
|
|
}
|
|
|
|
|
|
- // Cast necessary
|
|
|
- auto CastOp = static_cast<Instruction::CastOps>(HLModule::GetNumericCastOp(
|
|
|
- srcTy, fromUnsigned, toTy, toUnsigned));
|
|
|
- return cast<Instruction>(Builder.CreateCast(CastOp, src, toTy));
|
|
|
+ m_pModule = nullptr;
|
|
|
+ m_pHLModule = nullptr;
|
|
|
+ m_matToVecStubs = nullptr;
|
|
|
+ m_vecToMatStubs = nullptr;
|
|
|
+
|
|
|
+ // If you hit an assert during TempOverloadPool destruction,
|
|
|
+ // it means that either a matrix producer was lowered,
|
|
|
+ // causing a translation stub to be created,
|
|
|
+ // but the consumer of that matrix was never (properly) lowered.
|
|
|
+ // Or the opposite: a matrix consumer was lowered and not its producer.
|
|
|
+
|
|
|
+ return true;
|
|
|
}
|
|
|
|
|
|
-Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
|
|
|
- IRBuilder<> Builder(CI);
|
|
|
- Value *op = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
|
|
|
- HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(CI));
|
|
|
-
|
|
|
- bool ToMat = dxilutil::IsHLSLMatrixType(CI->getType());
|
|
|
- bool FromMat = dxilutil::IsHLSLMatrixType(op->getType());
|
|
|
- if (ToMat && !FromMat) {
|
|
|
- // Translate OtherToMat here.
|
|
|
- // Rest will translated when replace.
|
|
|
- unsigned col, row;
|
|
|
- Type *EltTy = GetMatrixInfo(CI->getType(), col, row);
|
|
|
- unsigned toSize = col * row;
|
|
|
- Instruction *sizeCast = nullptr;
|
|
|
- Type *FromTy = op->getType();
|
|
|
- Type *I32Ty = IntegerType::get(FromTy->getContext(), 32);
|
|
|
- if (FromTy->isVectorTy()) {
|
|
|
- std::vector<Constant *> MaskVec(toSize);
|
|
|
- for (size_t i = 0; i != toSize; ++i)
|
|
|
- MaskVec[i] = ConstantInt::get(I32Ty, i);
|
|
|
-
|
|
|
- Value *castMask = ConstantVector::get(MaskVec);
|
|
|
-
|
|
|
- sizeCast = new ShuffleVectorInst(op, op, castMask);
|
|
|
- Builder.Insert(sizeCast);
|
|
|
-
|
|
|
- } else {
|
|
|
- op = Builder.CreateInsertElement(
|
|
|
- UndefValue::get(VectorType::get(FromTy, 1)), op, (uint64_t)0);
|
|
|
- Constant *zero = ConstantInt::get(I32Ty, 0);
|
|
|
- std::vector<Constant *> MaskVec(toSize, zero);
|
|
|
- Value *castMask = ConstantVector::get(MaskVec);
|
|
|
-
|
|
|
- sizeCast = new ShuffleVectorInst(op, op, castMask);
|
|
|
- Builder.Insert(sizeCast);
|
|
|
- }
|
|
|
- Instruction *typeCast = sizeCast;
|
|
|
- if (EltTy != FromTy->getScalarType()) {
|
|
|
- typeCast = CreateTypeCast(opcode, VectorType::get(EltTy, toSize),
|
|
|
- sizeCast, Builder);
|
|
|
- }
|
|
|
- return typeCast;
|
|
|
- } else if (FromMat && ToMat) {
|
|
|
- if (isa<Argument>(op)) {
|
|
|
- // Cast From mat to mat for arugment.
|
|
|
- IRBuilder<> Builder(CI);
|
|
|
-
|
|
|
- // Here only lower the return type to vector.
|
|
|
- Type *RetTy = LowerMatrixType(CI->getType());
|
|
|
- SmallVector<Type *, 4> params;
|
|
|
- for (Value *operand : CI->arg_operands()) {
|
|
|
- params.emplace_back(operand->getType());
|
|
|
+void HLMatrixLowerPass::runOnFunction(Function &Func) {
|
|
|
+ // Skip hl function definition (like createhandle)
|
|
|
+ if (hlsl::GetHLOpcodeGroupByName(&Func) != HLOpcodeGroup::NotHL)
|
|
|
+ return;
|
|
|
+
|
|
|
+ // Save the matrix instructions first since the translation process
|
|
|
+ // will temporarily create other instructions consuming/producing matrix types.
|
|
|
+ std::vector<AllocaInst*> MatAllocas;
|
|
|
+ std::vector<Instruction*> MatInsts;
|
|
|
+ getMatrixAllocasAndOtherInsts(Func, MatAllocas, MatInsts);
|
|
|
+
|
|
|
+ // First lower all allocas and take care of their GEP chains
|
|
|
+ for (AllocaInst* MatAlloca : MatAllocas) {
|
|
|
+ AllocaInst* LoweredAlloca = lowerAlloca(MatAlloca);
|
|
|
+ replaceAllVariableUses(MatAlloca, LoweredAlloca);
|
|
|
+ addToDeadInsts(MatAlloca);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Now lower all other matrix instructions
|
|
|
+ for (Instruction *MatInst : MatInsts)
|
|
|
+ lowerInstruction(MatInst);
|
|
|
+
|
|
|
+ deleteDeadInsts();
|
|
|
+}
|
|
|
+
|
|
|
+void HLMatrixLowerPass::deleteDeadInsts() {
|
|
|
+ while (!m_deadInsts.empty()) {
|
|
|
+ Instruction *Inst = m_deadInsts.back();
|
|
|
+ m_deadInsts.pop_back();
|
|
|
+
|
|
|
+ DXASSERT_NOMSG(Inst->use_empty());
|
|
|
+ for (Value *Operand : Inst->operand_values()) {
|
|
|
+ Instruction *OperandInst = dyn_cast<Instruction>(Operand);
|
|
|
+ if (OperandInst && ++OperandInst->user_begin() == OperandInst->user_end()) {
|
|
|
+ // We were its only user, erase recursively.
|
|
|
+ // This will get rid of translation stubs:
|
|
|
+ // Original: MatConsumer(MatProducer)
|
|
|
+ // Producer lowered: MatConsumer(VecToMat(VecProducer)), MatProducer dead
|
|
|
+ // Consumer lowered: VecConsumer(VecProducer)), MatConsumer(VecToMat) dead
|
|
|
+ // Only by recursing on MatConsumer's operand do we delete the VecToMat stub.
|
|
|
+ DXASSERT_NOMSG(*OperandInst->user_begin() == Inst);
|
|
|
+ m_deadInsts.emplace_back(OperandInst);
|
|
|
}
|
|
|
+ }
|
|
|
+
|
|
|
+ Inst->eraseFromParent();
|
|
|
+ }
|
|
|
+}
|
|
|
|
|
|
- Type *FT = FunctionType::get(RetTy, params, false);
|
|
|
+// Find all instructions consuming or producing matrices,
|
|
|
+// directly or through pointers/arrays.
|
|
|
+void HLMatrixLowerPass::getMatrixAllocasAndOtherInsts(Function &Func,
|
|
|
+ std::vector<AllocaInst*> &MatAllocas, std::vector<Instruction*> &MatInsts){
|
|
|
+ for (BasicBlock &BasicBlock : Func) {
|
|
|
+ for (Instruction &Inst : BasicBlock) {
|
|
|
+ // Don't lower GEPs directly, we'll handle them as we lower the root pointer,
|
|
|
+ // typically a global variable or alloca.
|
|
|
+ if (isa<GetElementPtrInst>(&Inst)) continue;
|
|
|
|
|
|
- HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
- unsigned opcode = GetHLOpcode(CI);
|
|
|
+ if (AllocaInst *Alloca = dyn_cast<AllocaInst>(&Inst)) {
|
|
|
+ if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Alloca->getType())) {
|
|
|
+ MatAllocas.emplace_back(Alloca);
|
|
|
+ }
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (CallInst *Call = dyn_cast<CallInst>(&Inst)) {
|
|
|
+ // Lowering of global variables will have introduced
|
|
|
+ // vec-to-mat translation stubs, which we deal with indirectly,
|
|
|
+ // as we lower the instructions consuming them.
|
|
|
+ if (m_vecToMatStubs->contains(Call->getCalledFunction()))
|
|
|
+ continue;
|
|
|
+
|
|
|
+ // Mat-to-vec stubs should only be introduced during instruction lowering.
|
|
|
+ // Globals lowering won't introduce any because their only operand is
|
|
|
+ // their initializer, which we can fully lower without stubbing since it is constant.
|
|
|
+ DXASSERT(!m_matToVecStubs->contains(Call->getCalledFunction()),
|
|
|
+ "Unexpected mat-to-vec stubbing before function instruction lowering.");
|
|
|
+
|
|
|
+ // Match matrix producers
|
|
|
+ if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Inst.getType())) {
|
|
|
+ MatInsts.emplace_back(Call);
|
|
|
+ continue;
|
|
|
+ }
|
|
|
|
|
|
- Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT),
|
|
|
- group, opcode);
|
|
|
+ // Match matrix consumers
|
|
|
+ for (Value *Operand : Inst.operand_values()) {
|
|
|
+ if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Operand->getType())) {
|
|
|
+ MatInsts.emplace_back(Call);
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ continue;
|
|
|
+ }
|
|
|
|
|
|
- SmallVector<Value *, 4> argList;
|
|
|
- for (Value *arg : CI->arg_operands()) {
|
|
|
- argList.emplace_back(arg);
|
|
|
+ if (ReturnInst *Return = dyn_cast<ReturnInst>(&Inst)) {
|
|
|
+ Value *ReturnValue = Return->getReturnValue();
|
|
|
+ if (ReturnValue != nullptr && HLMatrixType::isMatrixOrPtrOrArrayPtr(ReturnValue->getType()))
|
|
|
+ MatInsts.emplace_back(Return);
|
|
|
+ continue;
|
|
|
}
|
|
|
|
|
|
- return Builder.CreateCall(vecF, argList);
|
|
|
+ // Nothing else should produce or consume matrices
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
- return MatIntrinsicToVec(CI);
|
|
|
}
|
|
|
|
|
|
-// Return GEP if value is Matrix resulting GEP from UDT alloca
|
|
|
-// UDT alloca must be there for library function args
|
|
|
-static GetElementPtrInst *GetIfMatrixGEPOfUDTAlloca(Value *V) {
|
|
|
- if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
|
|
|
- if (dxilutil::IsHLSLMatrixType(GEP->getResultElementType())) {
|
|
|
- Value *ptr = GEP->getPointerOperand();
|
|
|
- if (AllocaInst *AI = dyn_cast<AllocaInst>(ptr)) {
|
|
|
- Type *ATy = AI->getAllocatedType();
|
|
|
- if (ATy->isStructTy() && !dxilutil::IsHLSLMatrixType(ATy)) {
|
|
|
- return GEP;
|
|
|
- }
|
|
|
+// Gets the matrix-lowered representation of a value, potentially adding a translation stub.
|
|
|
+// DiscardStub causes any vec-to-mat translation stubs to be deleted,
|
|
|
+// it should be true only if the original instruction will be modified and kept alive.
|
|
|
+// If a new instruction is created and the original marked as dead,
|
|
|
+// then the remove dead instructions pass will take care of removing the stub.
|
|
|
+Value* HLMatrixLowerPass::getLoweredByValOperand(Value *Val, IRBuilder<> &Builder, bool DiscardStub) {
|
|
|
+ Type *Ty = Val->getType();
|
|
|
+
|
|
|
+ // We're only lowering byval matrices.
|
|
|
+ // Since structs and arrays are always accessed by pointer,
|
|
|
+ // we do not need to worry about a matrix being hidden inside a more complex type.
|
|
|
+ DXASSERT(!Ty->isPointerTy(), "Value cannot be a pointer.");
|
|
|
+ HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty);
|
|
|
+ if (!MatTy) return Val;
|
|
|
+
|
|
|
+ Type *LoweredTy = MatTy.getLoweredVectorTypeForReg();
|
|
|
+
|
|
|
+ // Check if the value is already a vec-to-mat translation stub
|
|
|
+ if (CallInst *Call = dyn_cast<CallInst>(Val)) {
|
|
|
+ if (m_vecToMatStubs->contains(Call->getCalledFunction())) {
|
|
|
+ if (DiscardStub && Call->getNumUses() == 1) {
|
|
|
+ Call->use_begin()->set(UndefValue::get(Call->getType()));
|
|
|
+ addToDeadInsts(Call);
|
|
|
}
|
|
|
+
|
|
|
+ Value *LoweredVal = Call->getArgOperand(0);
|
|
|
+ DXASSERT(LoweredVal->getType() == LoweredTy, "Unexpected already-lowered value type.");
|
|
|
+ return LoweredVal;
|
|
|
}
|
|
|
}
|
|
|
- return nullptr;
|
|
|
+
|
|
|
+ // Return a mat-to-vec translation stub
|
|
|
+ FunctionType *TranslationStubTy = FunctionType::get(LoweredTy, { Ty }, /* isVarArg */ false);
|
|
|
+ Function *TranslationStub = m_matToVecStubs->get(TranslationStubTy);
|
|
|
+ return Builder.CreateCall(TranslationStub, { Val });
|
|
|
}
|
|
|
|
|
|
-// Return GEP if value is Matrix resulting GEP from UDT argument of
|
|
|
-// none-graphics functions.
|
|
|
-static GetElementPtrInst *GetIfMatrixGEPOfUDTArg(Value *V, HLModule &HM) {
|
|
|
- if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
|
|
|
- if (dxilutil::IsHLSLMatrixType(GEP->getResultElementType())) {
|
|
|
- Value *ptr = GEP->getPointerOperand();
|
|
|
- if (Argument *Arg = dyn_cast<Argument>(ptr)) {
|
|
|
- if (!HM.IsGraphicsShader(Arg->getParent()))
|
|
|
- return GEP;
|
|
|
+// Attempts to retrieve the lowered vector pointer equivalent to a matrix pointer.
|
|
|
+// Returns nullptr if the pointed-to matrix lives in memory that cannot be lowered at this time,
|
|
|
+// for example a buffer or shader inputs/outputs, which are lowered during signature lowering.
|
|
|
+Value *HLMatrixLowerPass::tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Builder, bool DiscardStub) {
|
|
|
+ if (!HLMatrixType::isMatrixPtrOrArrayPtr(Ptr->getType()))
|
|
|
+ return nullptr;
|
|
|
+
|
|
|
+ // Matrix pointers can only be derived from Allocas, GlobalVariables or resource accesses.
|
|
|
+ // The first two cases are what this pass must be able to lower, and we should already
|
|
|
+ // have replaced their uses by vector to matrix pointer translation stubs.
|
|
|
+ if (CallInst *Call = dyn_cast<CallInst>(Ptr)) {
|
|
|
+ if (m_vecToMatStubs->contains(Call->getCalledFunction())) {
|
|
|
+ if (DiscardStub && Call->getNumUses() == 1) {
|
|
|
+ Call->use_begin()->set(UndefValue::get(Call->getType()));
|
|
|
+ addToDeadInsts(Call);
|
|
|
}
|
|
|
+ return Call->getArgOperand(0);
|
|
|
}
|
|
|
}
|
|
|
- return nullptr;
|
|
|
-}
|
|
|
|
|
|
-Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
|
|
|
- IRBuilder<> Builder(CI);
|
|
|
- unsigned opcode = GetHLOpcode(CI);
|
|
|
- HLMatLoadStoreOpcode matOpcode = static_cast<HLMatLoadStoreOpcode>(opcode);
|
|
|
- Instruction *result = nullptr;
|
|
|
- switch (matOpcode) {
|
|
|
- case HLMatLoadStoreOpcode::ColMatLoad:
|
|
|
- case HLMatLoadStoreOpcode::RowMatLoad: {
|
|
|
- Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
|
|
|
- if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr) ||
|
|
|
- GetIfMatrixGEPOfUDTArg(matPtr, *m_pHLModule)) {
|
|
|
- Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
|
|
|
- result = CreateVecMatrixLoad(vecPtr, matPtr->getType()->getPointerElementType(), Builder);
|
|
|
- } else
|
|
|
- result = MatIntrinsicToVec(CI);
|
|
|
- } break;
|
|
|
- case HLMatLoadStoreOpcode::ColMatStore:
|
|
|
- case HLMatLoadStoreOpcode::RowMatStore: {
|
|
|
- Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
|
|
|
- if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr) ||
|
|
|
- GetIfMatrixGEPOfUDTArg(matPtr, *m_pHLModule)) {
|
|
|
- Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
|
|
|
- Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
|
|
|
- Value *vecVal = UndefValue::get(HLMatrixLower::LowerMatrixType(matVal->getType()));
|
|
|
- result = CreateVecMatrixStore(vecVal, vecPtr, matVal->getType(), Builder);
|
|
|
- } else
|
|
|
- result = MatIntrinsicToVec(CI);
|
|
|
- } break;
|
|
|
+ // There's one more case to handle.
|
|
|
+ // When compiling shader libraries, signatures won't have been lowered yet.
|
|
|
+ // So we can have a matrix in a struct as an argument,
|
|
|
+ // or an alloca'd struct holding the return value of a call and containing a matrix.
|
|
|
+ Value *RootPtr = Ptr;
|
|
|
+ while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
|
|
|
+ RootPtr = GEP->getPointerOperand();
|
|
|
+
|
|
|
+ Argument *Arg = dyn_cast<Argument>(RootPtr);
|
|
|
+ bool IsNonShaderArg = Arg != nullptr && !m_pHLModule->IsGraphicsShader(Arg->getParent());
|
|
|
+ if (IsNonShaderArg || isa<AllocaInst>(RootPtr)) {
|
|
|
+ // Bitcast the matrix pointer to its lowered equivalent.
|
|
|
+ // The HLMatrixBitcast pass will take care of this later.
|
|
|
+ return Builder.CreateBitCast(Ptr, HLMatrixType::getLoweredType(Ptr->getType()));
|
|
|
}
|
|
|
- return result;
|
|
|
+
|
|
|
+ // The pointer must be derived from a resource, we don't handle it in this pass.
|
|
|
+ return nullptr;
|
|
|
}
|
|
|
|
|
|
-Instruction *HLMatrixLowerPass::MatSubscriptToVec(CallInst *CI) {
|
|
|
- IRBuilder<> Builder(CI);
|
|
|
- Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
|
|
|
- if (isa<AllocaInst>(matPtr)) {
|
|
|
- // Here just create a new matSub call which use vec ptr.
|
|
|
- // Later in TranslateMatSubscript will do the real translation.
|
|
|
- std::vector<Value *> args(CI->getNumArgOperands());
|
|
|
- for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
|
|
|
- args[i] = CI->getArgOperand(i);
|
|
|
- }
|
|
|
- // Change mat ptr into vec ptr.
|
|
|
- args[HLOperandIndex::kMatSubscriptMatOpIdx] =
|
|
|
- matToVecMap[cast<Instruction>(matPtr)];
|
|
|
- std::vector<Type *> paramTyList(CI->getNumArgOperands());
|
|
|
- for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
|
|
|
- paramTyList[i] = args[i]->getType();
|
|
|
- }
|
|
|
+// Bitcasts a value from matrix to vector or vice-versa.
|
|
|
+// This is used to convert to/from arguments/return values since we don't
|
|
|
+// lower signatures in this pass. The later HLMatrixBitcastLower pass fixes this.
|
|
|
+Value *HLMatrixLowerPass::bitCastValue(Value *SrcVal, Type* DstTy, bool DstTyAlloca, IRBuilder<> &Builder) {
|
|
|
+ Type *SrcTy = SrcVal->getType();
|
|
|
+ DXASSERT_NOMSG(!SrcTy->isPointerTy());
|
|
|
|
|
|
- FunctionType *funcTy = FunctionType::get(CI->getType(), paramTyList, false);
|
|
|
- unsigned opcode = GetHLOpcode(CI);
|
|
|
- Function *opFunc = GetOrCreateHLFunction(*m_pModule, funcTy, HLOpcodeGroup::HLSubscript, opcode);
|
|
|
- return Builder.CreateCall(opFunc, args);
|
|
|
- } else
|
|
|
- return MatIntrinsicToVec(CI);
|
|
|
+ // We store and load from a temporary alloca, bitcasting either on the store pointer
|
|
|
+ // or on the load pointer.
|
|
|
+ IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
|
|
|
+ Value *Alloca = AllocaBuilder.CreateAlloca(DstTyAlloca ? DstTy : SrcTy);
|
|
|
+ Value *BitCastedAlloca = Builder.CreateBitCast(Alloca, (DstTyAlloca ? SrcTy : DstTy)->getPointerTo());
|
|
|
+ Builder.CreateStore(SrcVal, DstTyAlloca ? BitCastedAlloca : Alloca);
|
|
|
+ return Builder.CreateLoad(DstTyAlloca ? Alloca : BitCastedAlloca);
|
|
|
}
|
|
|
|
|
|
-Instruction *HLMatrixLowerPass::MatFrExpToVec(CallInst *CI) {
|
|
|
- IRBuilder<> Builder(CI);
|
|
|
- FunctionType *FT = CI->getCalledFunction()->getFunctionType();
|
|
|
- Type *RetTy = LowerMatrixType(FT->getReturnType());
|
|
|
- SmallVector<Type *, 4> params;
|
|
|
- for (Type *param : FT->params()) {
|
|
|
- if (!param->isPointerTy()) {
|
|
|
- params.emplace_back(LowerMatrixType(param));
|
|
|
- } else {
|
|
|
- // Lower pointer type for frexp.
|
|
|
- Type *EltTy = LowerMatrixType(param->getPointerElementType());
|
|
|
- params.emplace_back(
|
|
|
- PointerType::get(EltTy, param->getPointerAddressSpace()));
|
|
|
- }
|
|
|
- }
|
|
|
+// Replaces all uses of a matrix value by its lowered vector form,
|
|
|
+// inserting translation stubs for users which still expect a matrix value.
|
|
|
+void HLMatrixLowerPass::replaceAllUsesByLoweredValue(Instruction* MatInst, Value* VecVal) {
|
|
|
+ if (VecVal == nullptr || VecVal == MatInst) return;
|
|
|
|
|
|
- Type *VecFT = FunctionType::get(RetTy, params, false);
|
|
|
+ DXASSERT(HLMatrixType::getLoweredType(MatInst->getType()) == VecVal->getType(),
|
|
|
+ "Unexpected lowered value type.");
|
|
|
|
|
|
- HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
- Function *vecF =
|
|
|
- GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(VecFT), group,
|
|
|
- static_cast<unsigned>(IntrinsicOp::IOP_frexp));
|
|
|
+ Instruction *VecToMatStub = nullptr;
|
|
|
|
|
|
- SmallVector<Value *, 4> argList;
|
|
|
- auto paramTyIt = params.begin();
|
|
|
- for (Value *arg : CI->arg_operands()) {
|
|
|
- Type *Ty = arg->getType();
|
|
|
- Type *ParamTy = *(paramTyIt++);
|
|
|
+ while (!MatInst->use_empty()) {
|
|
|
+ Use &ValUse = *MatInst->use_begin();
|
|
|
|
|
|
- if (Ty != ParamTy)
|
|
|
- argList.emplace_back(UndefValue::get(ParamTy));
|
|
|
- else
|
|
|
- argList.emplace_back(arg);
|
|
|
- }
|
|
|
+ // Handle non-matrix cases, just point to the new value.
|
|
|
+ if (MatInst->getType() == VecVal->getType()) {
|
|
|
+ ValUse.set(VecVal);
|
|
|
+ continue;
|
|
|
+ }
|
|
|
|
|
|
- return Builder.CreateCall(vecF, argList);
|
|
|
-}
|
|
|
+ // If the user is already a matrix-to-vector translation stub,
|
|
|
+ // we can now replace it by the proper vector value.
|
|
|
+ if (CallInst *Call = dyn_cast<CallInst>(ValUse.getUser())) {
|
|
|
+ if (m_matToVecStubs->contains(Call->getCalledFunction())) {
|
|
|
+ Call->replaceAllUsesWith(VecVal);
|
|
|
+ ValUse.set(UndefValue::get(MatInst->getType()));
|
|
|
+ addToDeadInsts(Call);
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
-Instruction *HLMatrixLowerPass::MatIntrinsicToVec(CallInst *CI) {
|
|
|
- IRBuilder<> Builder(CI);
|
|
|
- unsigned opcode = GetHLOpcode(CI);
|
|
|
+ // Otherwise, the user should point to a vector-to-matrix translation
|
|
|
+ // stub of the new vector value.
|
|
|
+ if (VecToMatStub == nullptr) {
|
|
|
+ FunctionType *TranslationStubTy = FunctionType::get(
|
|
|
+ MatInst->getType(), { VecVal->getType() }, /* isVarArg */ false);
|
|
|
+ Function *TranslationStub = m_vecToMatStubs->get(TranslationStubTy);
|
|
|
|
|
|
- if (opcode == static_cast<unsigned>(IntrinsicOp::IOP_frexp))
|
|
|
- return MatFrExpToVec(CI);
|
|
|
+ Instruction *PrevInst = dyn_cast<Instruction>(VecVal);
|
|
|
+ if (PrevInst == nullptr) PrevInst = MatInst;
|
|
|
|
|
|
- Type *FT = LowerMatrixType(CI->getCalledFunction()->getFunctionType());
|
|
|
+ IRBuilder<> Builder(dxilutil::SkipAllocas(PrevInst->getNextNode()));
|
|
|
+ VecToMatStub = Builder.CreateCall(TranslationStub, { VecVal });
|
|
|
+ }
|
|
|
|
|
|
- HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
+ ValUse.set(VecToMatStub);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Replaces all uses of a matrix or matrix array alloca or global variable by its lowered equivalent.
|
|
|
+// This doesn't lower the users, but will insert a translation stub from the lowered value pointer
|
|
|
+// back to the matrix value pointer, and recreate any GEPs around the new pointer.
|
|
|
+// Before: User(GEP(MatrixArrayAlloca))
|
|
|
+// After: User(VecToMatPtrStub(GEP'(VectorArrayAlloca)))
|
|
|
+void HLMatrixLowerPass::replaceAllVariableUses(Value* MatPtr, Value* LoweredPtr) {
|
|
|
+ DXASSERT_NOMSG(HLMatrixType::isMatrixPtrOrArrayPtr(MatPtr->getType()));
|
|
|
+ DXASSERT_NOMSG(LoweredPtr->getType() == HLMatrixType::getLoweredType(MatPtr->getType()));
|
|
|
+
|
|
|
+ SmallVector<Value*, 4> GEPIdxStack;
|
|
|
+ GEPIdxStack.emplace_back(ConstantInt::get(Type::getInt32Ty(MatPtr->getContext()), 0));
|
|
|
+ replaceAllVariableUses(GEPIdxStack, MatPtr, LoweredPtr);
|
|
|
+}
|
|
|
+
|
|
|
+void HLMatrixLowerPass::replaceAllVariableUses(
|
|
|
+ SmallVectorImpl<Value*> &GEPIdxStack, Value *StackTopPtr, Value* LoweredPtr) {
|
|
|
+ while (!StackTopPtr->use_empty()) {
|
|
|
+ llvm::Use &Use = *StackTopPtr->use_begin();
|
|
|
+ if (GEPOperator *GEP = dyn_cast<GEPOperator>(Use.getUser())) {
|
|
|
+ DXASSERT(GEP->getNumIndices() >= 1, "Unexpected degenerate GEP.");
|
|
|
+ DXASSERT(cast<ConstantInt>(*GEP->idx_begin())->isZero(), "Unexpected non-zero first GEP index.");
|
|
|
+
|
|
|
+ // Recurse in GEP to find actual users
|
|
|
+ for (auto It = GEP->idx_begin() + 1; It != GEP->idx_end(); ++It)
|
|
|
+ GEPIdxStack.emplace_back(*It);
|
|
|
+ replaceAllVariableUses(GEPIdxStack, GEP, LoweredPtr);
|
|
|
+ GEPIdxStack.erase(GEPIdxStack.end() - (GEP->getNumIndices() - 1), GEPIdxStack.end());
|
|
|
+
|
|
|
+ // Discard the GEP
|
|
|
+ DXASSERT_NOMSG(GEP->use_empty());
|
|
|
+ if (GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(GEP)) {
|
|
|
+ Use.set(UndefValue::get(Use->getType()));
|
|
|
+ addToDeadInsts(GEPInst);
|
|
|
+ } else {
|
|
|
+ // constant GEP
|
|
|
+ cast<Constant>(GEP)->destroyConstant();
|
|
|
+ }
|
|
|
+ continue;
|
|
|
+ }
|
|
|
|
|
|
- Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT), group, opcode);
|
|
|
+ if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Use.getUser())) {
|
|
|
+ DXASSERT(CE->getOpcode() == Instruction::AddrSpaceCast,
|
|
|
+ "Unexpected constant user");
|
|
|
+ replaceAllVariableUses(GEPIdxStack, CE, LoweredPtr);
|
|
|
+ DXASSERT_NOMSG(CE->use_empty());
|
|
|
+ CE->destroyConstant();
|
|
|
+ continue;
|
|
|
+ }
|
|
|
|
|
|
- SmallVector<Value *, 4> argList;
|
|
|
- for (Value *arg : CI->arg_operands()) {
|
|
|
- Type *Ty = arg->getType();
|
|
|
- if (dxilutil::IsHLSLMatrixType(Ty)) {
|
|
|
- argList.emplace_back(UndefValue::get(LowerMatrixType(Ty)));
|
|
|
- } else
|
|
|
- argList.emplace_back(arg);
|
|
|
- }
|
|
|
+ if (AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(Use.getUser())) {
|
|
|
+ replaceAllVariableUses(GEPIdxStack, CI, LoweredPtr);
|
|
|
+ Use.set(UndefValue::get(Use->getType()));
|
|
|
+ addToDeadInsts(CI);
|
|
|
+ continue;
|
|
|
+ }
|
|
|
|
|
|
- return Builder.CreateCall(vecF, argList);
|
|
|
-}
|
|
|
+ // Recreate the same GEP sequence, if any, on the lowered pointer
|
|
|
+ IRBuilder<> Builder(cast<Instruction>(Use.getUser()));
|
|
|
+ Value *LoweredStackTopPtr = GEPIdxStack.size() == 1
|
|
|
+ ? LoweredPtr : Builder.CreateGEP(LoweredPtr, GEPIdxStack);
|
|
|
|
|
|
-Instruction *HLMatrixLowerPass::TrivialMatUnOpToVec(CallInst *CI) {
|
|
|
- Type *ResultTy = LowerMatrixType(CI->getType());
|
|
|
- UndefValue *tmp = UndefValue::get(ResultTy);
|
|
|
- IRBuilder<> Builder(CI);
|
|
|
- HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(CI));
|
|
|
- bool isFloat = ResultTy->getVectorElementType()->isFloatingPointTy();
|
|
|
-
|
|
|
- Constant *one = isFloat
|
|
|
- ? ConstantFP::get(ResultTy->getVectorElementType(), 1)
|
|
|
- : ConstantInt::get(ResultTy->getVectorElementType(), 1);
|
|
|
- Constant *oneVec = ConstantVector::getSplat(ResultTy->getVectorNumElements(), one);
|
|
|
-
|
|
|
- Instruction *Result = nullptr;
|
|
|
- switch (opcode) {
|
|
|
- case HLUnaryOpcode::Plus: {
|
|
|
- // This is actually a no-op, but the structure of the code here requires
|
|
|
- // that we create an instruction.
|
|
|
- Constant *zero = Constant::getNullValue(ResultTy);
|
|
|
- if (isFloat)
|
|
|
- Result = BinaryOperator::CreateFAdd(tmp, zero);
|
|
|
- else
|
|
|
- Result = BinaryOperator::CreateAdd(tmp, zero);
|
|
|
- } break;
|
|
|
- case HLUnaryOpcode::Minus: {
|
|
|
- Constant *zero = Constant::getNullValue(ResultTy);
|
|
|
- if (isFloat)
|
|
|
- Result = BinaryOperator::CreateFSub(zero, tmp);
|
|
|
- else
|
|
|
- Result = BinaryOperator::CreateSub(zero, tmp);
|
|
|
- } break;
|
|
|
- case HLUnaryOpcode::LNot: {
|
|
|
- Constant *zero = Constant::getNullValue(ResultTy);
|
|
|
- if (isFloat)
|
|
|
- Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_UEQ, tmp, zero);
|
|
|
- else
|
|
|
- Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, zero);
|
|
|
- } break;
|
|
|
- case HLUnaryOpcode::Not: {
|
|
|
- Constant *allOneBits = Constant::getAllOnesValue(ResultTy);
|
|
|
- Result = BinaryOperator::CreateXor(tmp, allOneBits);
|
|
|
- } break;
|
|
|
- case HLUnaryOpcode::PostInc:
|
|
|
- case HLUnaryOpcode::PreInc:
|
|
|
- if (isFloat)
|
|
|
- Result = BinaryOperator::CreateFAdd(tmp, oneVec);
|
|
|
- else
|
|
|
- Result = BinaryOperator::CreateAdd(tmp, oneVec);
|
|
|
- break;
|
|
|
- case HLUnaryOpcode::PostDec:
|
|
|
- case HLUnaryOpcode::PreDec:
|
|
|
- if (isFloat)
|
|
|
- Result = BinaryOperator::CreateFSub(tmp, oneVec);
|
|
|
- else
|
|
|
- Result = BinaryOperator::CreateSub(tmp, oneVec);
|
|
|
- break;
|
|
|
- default:
|
|
|
- DXASSERT(0, "not implement");
|
|
|
- return nullptr;
|
|
|
+ // Generate a stub translating the vector pointer back to a matrix pointer,
|
|
|
+ // such that consuming instructions are unaffected.
|
|
|
+ FunctionType *TranslationStubTy = FunctionType::get(
|
|
|
+ StackTopPtr->getType(), { LoweredStackTopPtr->getType() }, /* isVarArg */ false);
|
|
|
+ Function *TranslationStub = m_vecToMatStubs->get(TranslationStubTy);
|
|
|
+ Use.set(Builder.CreateCall(TranslationStub, { LoweredStackTopPtr }));
|
|
|
}
|
|
|
- Builder.Insert(Result);
|
|
|
- return Result;
|
|
|
}
|
|
|
|
|
|
-Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
|
|
|
- Type *ResultTy = LowerMatrixType(CI->getType());
|
|
|
- IRBuilder<> Builder(CI);
|
|
|
- HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(CI));
|
|
|
- Type *OpTy = LowerMatrixType(
|
|
|
- CI->getOperand(HLOperandIndex::kBinaryOpSrc0Idx)->getType());
|
|
|
- UndefValue *tmp = UndefValue::get(OpTy);
|
|
|
- bool isFloat = OpTy->getVectorElementType()->isFloatingPointTy();
|
|
|
+void HLMatrixLowerPass::lowerGlobal(GlobalVariable *Global) {
|
|
|
+ if (Global->user_empty()) return;
|
|
|
|
|
|
- Instruction *Result = nullptr;
|
|
|
+ PointerType *LoweredPtrTy = cast<PointerType>(HLMatrixType::getLoweredType(Global->getType()));
|
|
|
+ DXASSERT_NOMSG(LoweredPtrTy != Global->getType());
|
|
|
|
|
|
- switch (opcode) {
|
|
|
- case HLBinaryOpcode::Add:
|
|
|
- if (isFloat)
|
|
|
- Result = BinaryOperator::CreateFAdd(tmp, tmp);
|
|
|
- else
|
|
|
- Result = BinaryOperator::CreateAdd(tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::Sub:
|
|
|
- if (isFloat)
|
|
|
- Result = BinaryOperator::CreateFSub(tmp, tmp);
|
|
|
- else
|
|
|
- Result = BinaryOperator::CreateSub(tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::Mul:
|
|
|
- if (isFloat)
|
|
|
- Result = BinaryOperator::CreateFMul(tmp, tmp);
|
|
|
- else
|
|
|
- Result = BinaryOperator::CreateMul(tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::Div:
|
|
|
- if (isFloat)
|
|
|
- Result = BinaryOperator::CreateFDiv(tmp, tmp);
|
|
|
- else
|
|
|
- Result = BinaryOperator::CreateSDiv(tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::Rem:
|
|
|
- if (isFloat)
|
|
|
- Result = BinaryOperator::CreateFRem(tmp, tmp);
|
|
|
- else
|
|
|
- Result = BinaryOperator::CreateSRem(tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::And:
|
|
|
- Result = BinaryOperator::CreateAnd(tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::Or:
|
|
|
- Result = BinaryOperator::CreateOr(tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::Xor:
|
|
|
- Result = BinaryOperator::CreateXor(tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::Shl: {
|
|
|
- Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
|
|
|
- DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
|
|
|
- "must be matrix type here");
|
|
|
- Result = BinaryOperator::CreateShl(tmp, tmp);
|
|
|
- } break;
|
|
|
- case HLBinaryOpcode::Shr: {
|
|
|
- Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
|
|
|
- DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
|
|
|
- "must be matrix type here");
|
|
|
- Result = BinaryOperator::CreateAShr(tmp, tmp);
|
|
|
- } break;
|
|
|
- case HLBinaryOpcode::LT:
|
|
|
- if (isFloat)
|
|
|
- Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLT, tmp, tmp);
|
|
|
- else
|
|
|
- Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLT, tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::GT:
|
|
|
- if (isFloat)
|
|
|
- Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGT, tmp, tmp);
|
|
|
- else
|
|
|
- Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT, tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::LE:
|
|
|
- if (isFloat)
|
|
|
- Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLE, tmp, tmp);
|
|
|
- else
|
|
|
- Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLE, tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::GE:
|
|
|
- if (isFloat)
|
|
|
- Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGE, tmp, tmp);
|
|
|
- else
|
|
|
- Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGE, tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::EQ:
|
|
|
- if (isFloat)
|
|
|
- Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, tmp);
|
|
|
- else
|
|
|
- Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::NE:
|
|
|
- if (isFloat)
|
|
|
- Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, tmp);
|
|
|
- else
|
|
|
- Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::UDiv:
|
|
|
- Result = BinaryOperator::CreateUDiv(tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::URem:
|
|
|
- Result = BinaryOperator::CreateURem(tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::UShr: {
|
|
|
- Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
|
|
|
- DXASSERT_LOCALVAR(op1, dxilutil::IsHLSLMatrixType(op1->getType()),
|
|
|
- "must be matrix type here");
|
|
|
- Result = BinaryOperator::CreateLShr(tmp, tmp);
|
|
|
- } break;
|
|
|
- case HLBinaryOpcode::ULT:
|
|
|
- Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::UGT:
|
|
|
- Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::ULE:
|
|
|
- Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE, tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::UGE:
|
|
|
- Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGE, tmp, tmp);
|
|
|
- break;
|
|
|
- case HLBinaryOpcode::LAnd:
|
|
|
- case HLBinaryOpcode::LOr: {
|
|
|
- Value *vecZero = Constant::getNullValue(ResultTy);
|
|
|
- Instruction *cmpL;
|
|
|
- if (isFloat)
|
|
|
- cmpL = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, vecZero);
|
|
|
- else
|
|
|
- cmpL = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, vecZero);
|
|
|
- Builder.Insert(cmpL);
|
|
|
-
|
|
|
- Instruction *cmpR;
|
|
|
- if (isFloat)
|
|
|
- cmpR =
|
|
|
- CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, vecZero);
|
|
|
- else
|
|
|
- cmpR = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, vecZero);
|
|
|
- Builder.Insert(cmpR);
|
|
|
-
|
|
|
- // How to map l, r back? Need check opcode
|
|
|
- if (opcode == HLBinaryOpcode::LOr)
|
|
|
- Result = BinaryOperator::CreateOr(cmpL, cmpR);
|
|
|
- else
|
|
|
- Result = BinaryOperator::CreateAnd(cmpL, cmpR);
|
|
|
- break;
|
|
|
- }
|
|
|
- default:
|
|
|
- DXASSERT(0, "not implement");
|
|
|
- return nullptr;
|
|
|
- }
|
|
|
- Builder.Insert(Result);
|
|
|
- return Result;
|
|
|
-}
|
|
|
+ Constant *LoweredInitVal = Global->hasInitializer()
|
|
|
+ ? lowerConstInitVal(Global->getInitializer()) : nullptr;
|
|
|
+ GlobalVariable *LoweredGlobal = new GlobalVariable(*m_pModule, LoweredPtrTy->getElementType(),
|
|
|
+ Global->isConstant(), Global->getLinkage(), LoweredInitVal,
|
|
|
+ Global->getName() + ".v", /*InsertBefore*/ nullptr, Global->getThreadLocalMode(),
|
|
|
+ Global->getType()->getAddressSpace());
|
|
|
|
|
|
-// Create BitCast if ptr, otherwise, create alloca of new type, write to bitcast of alloca, and return load from alloca
|
|
|
-// If bOrigAllocaTy is true: create alloca of old type instead, write to alloca, and return load from bitcast of alloca
|
|
|
-static Instruction *BitCastValueOrPtr(Value* V, Instruction *Insert, Type *Ty, bool bOrigAllocaTy = false, const Twine &Name = "") {
|
|
|
- IRBuilder<> Builder(Insert);
|
|
|
- if (Ty->isPointerTy()) {
|
|
|
- // If pointer, we can bitcast directly
|
|
|
- return cast<Instruction>(Builder.CreateBitCast(V, Ty, Name));
|
|
|
- } else {
|
|
|
- // If value, we have to alloca, store to bitcast ptr, and load
|
|
|
- IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Insert));
|
|
|
- Type *allocaTy = bOrigAllocaTy ? V->getType() : Ty;
|
|
|
- Type *otherTy = bOrigAllocaTy ? Ty : V->getType();
|
|
|
- Instruction *allocaInst = AllocaBuilder.CreateAlloca(allocaTy);
|
|
|
- Instruction *bitCast = cast<Instruction>(Builder.CreateBitCast(allocaInst, otherTy->getPointerTo()));
|
|
|
- Builder.CreateStore(V, bOrigAllocaTy ? allocaInst : bitCast);
|
|
|
- return Builder.CreateLoad(bOrigAllocaTy ? bitCast : allocaInst, Name);
|
|
|
+ // Add debug info.
|
|
|
+ if (m_HasDbgInfo) {
|
|
|
+ DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
|
|
|
+ HLModule::UpdateGlobalVariableDebugInfo(Global, Finder, LoweredGlobal);
|
|
|
}
|
|
|
+
|
|
|
+ replaceAllVariableUses(Global, LoweredGlobal);
|
|
|
+ Global->removeDeadConstantUsers();
|
|
|
+ Global->eraseFromParent();
|
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
|
|
|
- Value *vecVal = nullptr;
|
|
|
-
|
|
|
- if (CallInst *CI = dyn_cast<CallInst>(matInst)) {
|
|
|
- hlsl::HLOpcodeGroup group =
|
|
|
- hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
- switch (group) {
|
|
|
- case HLOpcodeGroup::HLIntrinsic: {
|
|
|
- vecVal = MatIntrinsicToVec(CI);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLSelect: {
|
|
|
- vecVal = MatIntrinsicToVec(CI);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLBinOp: {
|
|
|
- vecVal = TrivialMatBinOpToVec(CI);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLUnOp: {
|
|
|
- vecVal = TrivialMatUnOpToVec(CI);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLCast: {
|
|
|
- vecVal = MatCastToVec(CI);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLInit: {
|
|
|
- vecVal = MatIntrinsicToVec(CI);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLMatLoadStore: {
|
|
|
- vecVal = MatLdStToVec(CI);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLSubscript: {
|
|
|
- vecVal = MatSubscriptToVec(CI);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::NotHL: {
|
|
|
- // Translate user function return
|
|
|
- vecVal = BitCastValueOrPtr( matInst,
|
|
|
- matInst->getNextNode(),
|
|
|
- HLMatrixLower::LowerMatrixType(matInst->getType()),
|
|
|
- /*bOrigAllocaTy*/ false,
|
|
|
- matInst->getName());
|
|
|
- // matrix equivalent of this new vector will be the original, retained user call
|
|
|
- vecToMatMap[vecVal] = matInst;
|
|
|
- } break;
|
|
|
- default:
|
|
|
- DXASSERT(0, "invalid inst");
|
|
|
- }
|
|
|
- } else if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
|
|
|
- Type *Ty = AI->getAllocatedType();
|
|
|
- Type *matTy = Ty;
|
|
|
-
|
|
|
- IRBuilder<> AllocaBuilder(AI);
|
|
|
- if (Ty->isArrayTy()) {
|
|
|
- Type *vecTy = HLMatrixLower::LowerMatrixArrayPointer(AI->getType(), /*forMem*/ true);
|
|
|
- vecTy = vecTy->getPointerElementType();
|
|
|
- vecVal = AllocaBuilder.CreateAlloca(vecTy, nullptr, AI->getName());
|
|
|
- } else {
|
|
|
- Type *vecTy = HLMatrixLower::LowerMatrixType(matTy, /*forMem*/ true);
|
|
|
- vecVal = AllocaBuilder.CreateAlloca(vecTy, nullptr, AI->getName());
|
|
|
- }
|
|
|
- // Update debug info.
|
|
|
- DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(AI);
|
|
|
- if (DDI) {
|
|
|
- LLVMContext &Context = AI->getContext();
|
|
|
- Value *DDIVar = MetadataAsValue::get(Context, DDI->getRawVariable());
|
|
|
- Value *DDIExp = MetadataAsValue::get(Context, DDI->getRawExpression());
|
|
|
- Value *VMD = MetadataAsValue::get(Context, ValueAsMetadata::get(vecVal));
|
|
|
- IRBuilder<> debugBuilder(DDI);
|
|
|
- debugBuilder.CreateCall(DDI->getCalledFunction(), {VMD, DDIVar, DDIExp});
|
|
|
+Constant *HLMatrixLowerPass::lowerConstInitVal(Constant *Val) {
|
|
|
+ Type *Ty = Val->getType();
|
|
|
+
|
|
|
+ // If it's an array of matrices, recurse for each element or nested array
|
|
|
+ if (ArrayType *ArrayTy = dyn_cast<ArrayType>(Ty)) {
|
|
|
+ SmallVector<Constant*, 4> LoweredElems;
|
|
|
+ unsigned NumElems = ArrayTy->getNumElements();
|
|
|
+ LoweredElems.reserve(NumElems);
|
|
|
+ for (unsigned ElemIdx = 0; ElemIdx < NumElems; ++ElemIdx) {
|
|
|
+ Constant *ArrayElem = Val->getAggregateElement(ElemIdx);
|
|
|
+ LoweredElems.emplace_back(lowerConstInitVal(ArrayElem));
|
|
|
}
|
|
|
|
|
|
- if (HLModule::HasPreciseAttributeWithMetadata(AI))
|
|
|
- HLModule::MarkPreciseAttributeWithMetadata(cast<Instruction>(vecVal));
|
|
|
-
|
|
|
- } else if (GetIfMatrixGEPOfUDTAlloca(matInst) ||
|
|
|
- GetIfMatrixGEPOfUDTArg(matInst, *m_pHLModule)) {
|
|
|
- // If GEP from alloca of non-matrix UDT, bitcast
|
|
|
- IRBuilder<> Builder(matInst->getNextNode());
|
|
|
- vecVal = Builder.CreateBitCast(matInst,
|
|
|
- HLMatrixLower::LowerMatrixType(
|
|
|
- matInst->getType()->getPointerElementType() )->getPointerTo());
|
|
|
- // matrix equivalent of this new vector will be the original, retained GEP
|
|
|
- vecToMatMap[vecVal] = matInst;
|
|
|
- } else {
|
|
|
- DXASSERT(0, "invalid inst");
|
|
|
+ Type *LoweredElemTy = HLMatrixType::getLoweredType(ArrayTy->getElementType());
|
|
|
+ ArrayType *LoweredArrayTy = ArrayType::get(LoweredElemTy, NumElems);
|
|
|
+ return ConstantArray::get(LoweredArrayTy, LoweredElems);
|
|
|
}
|
|
|
- if (vecVal) {
|
|
|
- matToVecMap[matInst] = vecVal;
|
|
|
- }
|
|
|
-}
|
|
|
|
|
|
-// Replace matInst with vecVal on matUseInst.
|
|
|
-void HLMatrixLowerPass::TrivialMatUnOpReplace(Value *matVal,
|
|
|
- Value *vecVal,
|
|
|
- CallInst *matUseInst) {
|
|
|
- (void)(matVal); // Unused
|
|
|
- HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(matUseInst));
|
|
|
- Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
|
|
|
- switch (opcode) {
|
|
|
- case HLUnaryOpcode::Plus: // add(x, 0)
|
|
|
- // Ideally we'd get completely rid of the instruction for +mat,
|
|
|
- // but matToVecMap needs to point to some instruction.
|
|
|
- case HLUnaryOpcode::Not: // xor(x, -1)
|
|
|
- case HLUnaryOpcode::LNot: // cmpeq(x, 0)
|
|
|
- case HLUnaryOpcode::PostInc:
|
|
|
- case HLUnaryOpcode::PreInc:
|
|
|
- case HLUnaryOpcode::PostDec:
|
|
|
- case HLUnaryOpcode::PreDec:
|
|
|
- vecUseInst->setOperand(0, vecVal);
|
|
|
- break;
|
|
|
- case HLUnaryOpcode::Minus: // sub(0, x)
|
|
|
- vecUseInst->setOperand(1, vecVal);
|
|
|
- break;
|
|
|
- case HLUnaryOpcode::Invalid:
|
|
|
- case HLUnaryOpcode::NumOfUO:
|
|
|
- DXASSERT(false, "Unexpected HL unary opcode.");
|
|
|
- break;
|
|
|
- }
|
|
|
-}
|
|
|
+ // Otherwise it's a matrix, lower it to a vector
|
|
|
+ HLMatrixType MatTy = HLMatrixType::cast(Ty);
|
|
|
+ DXASSERT_NOMSG(isa<StructType>(Ty));
|
|
|
+ Constant *RowArrayVal = Val->getAggregateElement((unsigned)0);
|
|
|
|
|
|
-// Replace matInst with vecVal on matUseInst.
|
|
|
-void HLMatrixLowerPass::TrivialMatBinOpReplace(Value *matVal,
|
|
|
- Value *vecVal,
|
|
|
- CallInst *matUseInst) {
|
|
|
- HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(matUseInst));
|
|
|
- Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
|
|
|
-
|
|
|
- if (opcode != HLBinaryOpcode::LAnd && opcode != HLBinaryOpcode::LOr) {
|
|
|
- if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) == matVal)
|
|
|
- vecUseInst->setOperand(0, vecVal);
|
|
|
- if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) == matVal)
|
|
|
- vecUseInst->setOperand(1, vecVal);
|
|
|
- } else {
|
|
|
- if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) ==
|
|
|
- matVal) {
|
|
|
- Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(0));
|
|
|
- vecCmp->setOperand(0, vecVal);
|
|
|
- }
|
|
|
- if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) ==
|
|
|
- matVal) {
|
|
|
- Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(1));
|
|
|
- vecCmp->setOperand(0, vecVal);
|
|
|
+ // Original initializer should have been produced in row/column-major order
|
|
|
+ // depending on the qualifiers of the target variable, so preserve the order.
|
|
|
+ SmallVector<Constant*, 16> MatElems;
|
|
|
+ for (unsigned RowIdx = 0; RowIdx < MatTy.getNumRows(); ++RowIdx) {
|
|
|
+ Constant *RowVal = RowArrayVal->getAggregateElement(RowIdx);
|
|
|
+ for (unsigned ColIdx = 0; ColIdx < MatTy.getNumColumns(); ++ColIdx) {
|
|
|
+ MatElems.emplace_back(RowVal->getAggregateElement(ColIdx));
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ Constant *Vec = ConstantVector::get(MatElems);
|
|
|
+
|
|
|
+ // Matrix elements are always in register representation,
|
|
|
+ // but the lowered global variable is of vector type in
|
|
|
+ // its memory representation, so we must convert here.
|
|
|
+
|
|
|
+ // This will produce a constant so we can use an IRBuilder without a valid insertion point.
|
|
|
+ IRBuilder<> DummyBuilder(Val->getContext());
|
|
|
+ return cast<Constant>(MatTy.emitLoweredRegToMem(Vec, DummyBuilder));
|
|
|
}
|
|
|
|
|
|
-static Function *GetOrCreateMadIntrinsic(Type *Ty, Type *opcodeTy, IntrinsicOp madOp, Module &M) {
|
|
|
- llvm::FunctionType *MadFuncTy =
|
|
|
- llvm::FunctionType::get(Ty, { opcodeTy, Ty, Ty, Ty}, false);
|
|
|
+AllocaInst *HLMatrixLowerPass::lowerAlloca(AllocaInst *MatAlloca) {
|
|
|
+ PointerType *LoweredAllocaTy = cast<PointerType>(HLMatrixType::getLoweredType(MatAlloca->getType()));
|
|
|
|
|
|
- Function *MAD =
|
|
|
- GetOrCreateHLFunction(M, MadFuncTy, HLOpcodeGroup::HLIntrinsic,
|
|
|
- (unsigned)madOp);
|
|
|
- return MAD;
|
|
|
+ IRBuilder<> Builder(MatAlloca);
|
|
|
+ AllocaInst *LoweredAlloca = Builder.CreateAlloca(
|
|
|
+ LoweredAllocaTy->getElementType(), nullptr, MatAlloca->getName());
|
|
|
+
|
|
|
+ // Update debug info.
|
|
|
+ if (DbgDeclareInst *DbgDeclare = llvm::FindAllocaDbgDeclare(MatAlloca)) {
|
|
|
+ LLVMContext &Context = MatAlloca->getContext();
|
|
|
+ Value *DbgDeclareVar = MetadataAsValue::get(Context, DbgDeclare->getRawVariable());
|
|
|
+ Value *DbgDeclareExpr = MetadataAsValue::get(Context, DbgDeclare->getRawExpression());
|
|
|
+ Value *ValueMetadata = MetadataAsValue::get(Context, ValueAsMetadata::get(LoweredAlloca));
|
|
|
+ IRBuilder<> DebugBuilder(DbgDeclare);
|
|
|
+ DebugBuilder.CreateCall(DbgDeclare->getCalledFunction(), { ValueMetadata, DbgDeclareVar, DbgDeclareExpr });
|
|
|
+ }
|
|
|
+
|
|
|
+ if (HLModule::HasPreciseAttributeWithMetadata(MatAlloca))
|
|
|
+ HLModule::MarkPreciseAttributeWithMetadata(LoweredAlloca);
|
|
|
+
|
|
|
+ replaceAllVariableUses(MatAlloca, LoweredAlloca);
|
|
|
+
|
|
|
+ return LoweredAlloca;
|
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatMatMul(Value *matVal,
|
|
|
- Value *vecVal,
|
|
|
- CallInst *mulInst, bool isSigned) {
|
|
|
- (void)(matVal); // Unused; retrieved from matToVecMap directly
|
|
|
- DXASSERT(matToVecMap.count(mulInst), "must have vec version");
|
|
|
- Instruction *vecUseInst = cast<Instruction>(matToVecMap[mulInst]);
|
|
|
- // Already translated.
|
|
|
- if (!isa<CallInst>(vecUseInst))
|
|
|
- return;
|
|
|
- Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
|
|
|
- Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
|
|
|
-
|
|
|
- unsigned col, row;
|
|
|
- Type *EltTy = GetMatrixInfo(LVal->getType(), col, row);
|
|
|
- unsigned rCol, rRow;
|
|
|
- GetMatrixInfo(RVal->getType(), rCol, rRow);
|
|
|
- DXASSERT_NOMSG(col == rRow);
|
|
|
-
|
|
|
- bool isFloat = EltTy->isFloatingPointTy();
|
|
|
-
|
|
|
- Value *retVal = llvm::UndefValue::get(LowerMatrixType(mulInst->getType()));
|
|
|
- IRBuilder<> Builder(vecUseInst);
|
|
|
-
|
|
|
- Value *lMat = matToVecMap[cast<Instruction>(LVal)];
|
|
|
- Value *rMat = matToVecMap[cast<Instruction>(RVal)];
|
|
|
-
|
|
|
- auto CreateOneEltMul = [&](unsigned r, unsigned lc, unsigned c) -> Value * {
|
|
|
- unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
|
|
|
- unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
|
|
|
- Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
|
|
|
- Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
|
|
|
- return isFloat ? Builder.CreateFMul(lMatElt, rMatElt)
|
|
|
- : Builder.CreateMul(lMatElt, rMatElt);
|
|
|
- };
|
|
|
-
|
|
|
- IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
|
|
|
- Type *opcodeTy = Builder.getInt32Ty();
|
|
|
- Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
|
|
|
- *m_pHLModule->GetModule());
|
|
|
- Value *madOpArg = Builder.getInt32((unsigned)madOp);
|
|
|
-
|
|
|
- auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
|
|
|
- Value *acc) -> Value * {
|
|
|
- unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
|
|
|
- unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
|
|
|
- Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
|
|
|
- Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
|
|
|
- return Builder.CreateCall(Mad, {madOpArg, lMatElt, rMatElt, acc});
|
|
|
- };
|
|
|
-
|
|
|
- for (unsigned r = 0; r < row; r++) {
|
|
|
- for (unsigned c = 0; c < rCol; c++) {
|
|
|
- unsigned lc = 0;
|
|
|
- Value *tmpVal = CreateOneEltMul(r, lc, c);
|
|
|
-
|
|
|
- for (lc = 1; lc < col; lc++) {
|
|
|
- tmpVal = CreateOneEltMad(r, lc, c, tmpVal);
|
|
|
- }
|
|
|
- unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, rCol);
|
|
|
- retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
|
|
|
+void HLMatrixLowerPass::lowerInstruction(Instruction* Inst) {
|
|
|
+ if (CallInst *Call = dyn_cast<CallInst>(Inst)) {
|
|
|
+ Value *LoweredValue = lowerCall(Call);
|
|
|
+
|
|
|
+ // lowerCall returns the lowered value iff we should discard
|
|
|
+ // the original matrix instruction and replace all of its uses
|
|
|
+ // by the lowered value. It returns nullptr to opt-out of this.
|
|
|
+ if (LoweredValue != nullptr) {
|
|
|
+ replaceAllUsesByLoweredValue(Call, LoweredValue);
|
|
|
+ addToDeadInsts(Inst);
|
|
|
}
|
|
|
}
|
|
|
+ else if (ReturnInst *Return = dyn_cast<ReturnInst>(Inst)) {
|
|
|
+ lowerReturn(Return);
|
|
|
+ }
|
|
|
+ else
|
|
|
+ llvm_unreachable("Unexpected matrix instruction type.");
|
|
|
+}
|
|
|
+
|
|
|
+void HLMatrixLowerPass::lowerReturn(ReturnInst* Return) {
|
|
|
+ Value *RetVal = Return->getReturnValue();
|
|
|
+ Type *RetTy = RetVal->getType();
|
|
|
+ DXASSERT_LOCALVAR(RetTy, !RetTy->isPointerTy(), "Unexpected matrix returned by pointer.");
|
|
|
+
|
|
|
+ IRBuilder<> Builder(Return);
|
|
|
+ Value *LoweredRetVal = getLoweredByValOperand(RetVal, Builder, /* DiscardStub */ true);
|
|
|
+
|
|
|
+ // Since we're not lowering the signature, we can't return the lowered value directly,
|
|
|
+ // so insert a bitcast, which HLMatrixBitcastLower knows how to eliminate.
|
|
|
+ Value *BitCastedRetVal = bitCastValue(LoweredRetVal, RetVal->getType(), /* DstTyAlloca */ false, Builder);
|
|
|
+ Return->setOperand(0, BitCastedRetVal);
|
|
|
+}
|
|
|
+
|
|
|
+Value *HLMatrixLowerPass::lowerCall(CallInst *Call) {
|
|
|
+ HLOpcodeGroup OpcodeGroup = GetHLOpcodeGroupByName(Call->getCalledFunction());
|
|
|
+ return OpcodeGroup == HLOpcodeGroup::NotHL
|
|
|
+ ? lowerNonHLCall(Call) : lowerHLOperation(Call, OpcodeGroup);
|
|
|
+}
|
|
|
+
|
|
|
+Value *HLMatrixLowerPass::lowerNonHLCall(CallInst *Call) {
|
|
|
+ // First, handle any operand of matrix-derived type
|
|
|
+ // We don't lower the callee's signature in this pass,
|
|
|
+ // so, for any matrix-typed parameter, we create a bitcast from the
|
|
|
+ // lowered vector back to the matrix type, which the later HLMatrixBitcastLower
|
|
|
+ // pass knows how to eliminate.
|
|
|
+ IRBuilder<> PreCallBuilder(Call);
|
|
|
+ unsigned NumArgs = Call->getNumArgOperands();
|
|
|
+ for (unsigned ArgIdx = 0; ArgIdx < NumArgs; ++ArgIdx) {
|
|
|
+ Use &ArgUse = Call->getArgOperandUse(ArgIdx);
|
|
|
+ if (ArgUse->getType()->isPointerTy()) {
|
|
|
+ // Byref arg
|
|
|
+ Value *LoweredArg = tryGetLoweredPtrOperand(ArgUse.get(), PreCallBuilder, /* DiscardStub */ true);
|
|
|
+ if (LoweredArg != nullptr) {
|
|
|
+ // Pointer to a matrix we've lowered, insert a bitcast back to matrix pointer type.
|
|
|
+ Value *BitCastedArg = PreCallBuilder.CreateBitCast(LoweredArg, ArgUse->getType());
|
|
|
+ ArgUse.set(BitCastedArg);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ // Byvalue arg
|
|
|
+ Value *LoweredArg = getLoweredByValOperand(ArgUse.get(), PreCallBuilder, /* DiscardStub */ true);
|
|
|
+ if (LoweredArg == ArgUse.get()) continue;
|
|
|
|
|
|
- Instruction *matmatMul = cast<Instruction>(retVal);
|
|
|
- // Replace vec transpose function call with shuf.
|
|
|
- vecUseInst->replaceAllUsesWith(matmatMul);
|
|
|
- AddToDeadInsts(vecUseInst);
|
|
|
- matToVecMap[mulInst] = matmatMul;
|
|
|
-}
|
|
|
+ Value *BitCastedArg = bitCastValue(LoweredArg, ArgUse->getType(), /* DstTyAlloca */ false, PreCallBuilder);
|
|
|
+ ArgUse.set(BitCastedArg);
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatVecMul(Value *matVal,
|
|
|
- Value *vecVal,
|
|
|
- CallInst *mulInst, bool isSigned) {
|
|
|
- // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
|
|
|
- Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
|
|
|
+ // Now check the return type
|
|
|
+ HLMatrixType RetMatTy = HLMatrixType::dyn_cast(Call->getType());
|
|
|
+ if (!RetMatTy) {
|
|
|
+ DXASSERT(!HLMatrixType::isMatrixPtrOrArrayPtr(Call->getType()),
|
|
|
+ "Unexpected user call returning a matrix by pointer.");
|
|
|
+ // Nothing to replace, other instructions can consume a non-matrix return type.
|
|
|
+ return nullptr;
|
|
|
+ }
|
|
|
|
|
|
- unsigned col, row;
|
|
|
- Type *EltTy = GetMatrixInfo(matVal->getType(), col, row);
|
|
|
- DXASSERT_NOMSG(RVal->getType()->getVectorNumElements() == col);
|
|
|
+ // The callee returns a matrix, and we don't lower signatures in this pass.
|
|
|
+ // We perform a sketchy bitcast to the lowered register-representation type,
|
|
|
+ // which the later HLMatrixBitcastLower pass knows how to eliminate.
|
|
|
+ IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Call));
|
|
|
+ Value *LoweredAlloca = AllocaBuilder.CreateAlloca(RetMatTy.getLoweredVectorTypeForReg());
|
|
|
+
|
|
|
+ IRBuilder<> PostCallBuilder(Call->getNextNode());
|
|
|
+ Value *BitCastedAlloca = PostCallBuilder.CreateBitCast(LoweredAlloca, Call->getType()->getPointerTo());
|
|
|
+
|
|
|
+ // This is slightly tricky
|
|
|
+ // We want to replace all uses of the matrix-returning call by the bitcasted value,
|
|
|
+ // but the store to the bitcasted pointer itself is a use of that matrix,
|
|
|
+ // so we need to create the load, replace the uses, and then insert the store.
|
|
|
+ LoadInst *LoweredVal = PostCallBuilder.CreateLoad(LoweredAlloca);
|
|
|
+ replaceAllUsesByLoweredValue(Call, LoweredVal);
|
|
|
+
|
|
|
+ // Now we can insert the store. Make sure to do so before the load.
|
|
|
+ PostCallBuilder.SetInsertPoint(LoweredVal);
|
|
|
+ PostCallBuilder.CreateStore(Call, BitCastedAlloca);
|
|
|
+
|
|
|
+ // Return nullptr since we did our own uses replacement and we don't want
|
|
|
+ // the matrix instruction to be marked as dead since we're still using it.
|
|
|
+ return nullptr;
|
|
|
+}
|
|
|
|
|
|
- bool isFloat = EltTy->isFloatingPointTy();
|
|
|
+Value *HLMatrixLowerPass::lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeGroup) {
|
|
|
+ IRBuilder<> Builder(Call);
|
|
|
+ switch (OpcodeGroup) {
|
|
|
+ case HLOpcodeGroup::HLIntrinsic:
|
|
|
+ return lowerHLIntrinsic(Call, static_cast<IntrinsicOp>(GetHLOpcode(Call)));
|
|
|
|
|
|
- Value *retVal = llvm::UndefValue::get(mulInst->getType());
|
|
|
- IRBuilder<> Builder(mulInst);
|
|
|
+ case HLOpcodeGroup::HLBinOp:
|
|
|
+ return lowerHLBinaryOperation(
|
|
|
+ Call->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
|
|
|
+ Call->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
|
|
|
+ static_cast<HLBinaryOpcode>(GetHLOpcode(Call)), Builder);
|
|
|
|
|
|
- Value *vec = RVal;
|
|
|
- Value *mat = vecVal; // vec version of matInst;
|
|
|
+ case HLOpcodeGroup::HLUnOp:
|
|
|
+ return lowerHLUnaryOperation(
|
|
|
+ Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx),
|
|
|
+ static_cast<HLUnaryOpcode>(GetHLOpcode(Call)), Builder);
|
|
|
|
|
|
- IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
|
|
|
- Type *opcodeTy = Builder.getInt32Ty();
|
|
|
- Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
|
|
|
- *m_pHLModule->GetModule());
|
|
|
- Value *madOpArg = Builder.getInt32((unsigned)madOp);
|
|
|
+ case HLOpcodeGroup::HLMatLoadStore:
|
|
|
+ return lowerHLLoadStore(Call, static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(Call)));
|
|
|
|
|
|
- auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
|
|
|
- Value *vecElt = Builder.CreateExtractElement(vec, c);
|
|
|
- uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
|
|
|
- Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
|
- return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
|
|
|
- };
|
|
|
+ case HLOpcodeGroup::HLCast:
|
|
|
+ return lowerHLCast(
|
|
|
+ Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Call->getType(),
|
|
|
+ static_cast<HLCastOpcode>(GetHLOpcode(Call)), Builder);
|
|
|
|
|
|
- for (unsigned r = 0; r < row; r++) {
|
|
|
- unsigned c = 0;
|
|
|
- Value *vecElt = Builder.CreateExtractElement(vec, c);
|
|
|
- uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
|
|
|
- Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
|
+ case HLOpcodeGroup::HLSubscript:
|
|
|
+ return lowerHLSubscript(Call, static_cast<HLSubscriptOpcode>(GetHLOpcode(Call)));
|
|
|
|
|
|
- Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
|
|
|
- : Builder.CreateMul(vecElt, matElt);
|
|
|
+ case HLOpcodeGroup::HLInit:
|
|
|
+ return lowerHLInit(Call);
|
|
|
|
|
|
- for (c = 1; c < col; c++) {
|
|
|
- tmpVal = CreateOneEltMad(r, c, tmpVal);
|
|
|
- }
|
|
|
+ case HLOpcodeGroup::HLSelect:
|
|
|
+ return lowerHLSelect(Call);
|
|
|
|
|
|
- retVal = Builder.CreateInsertElement(retVal, tmpVal, r);
|
|
|
+ default:
|
|
|
+ llvm_unreachable("Unexpected matrix opcode");
|
|
|
}
|
|
|
-
|
|
|
- mulInst->replaceAllUsesWith(retVal);
|
|
|
- AddToDeadInsts(mulInst);
|
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateVecMatMul(Value *matVal,
|
|
|
- Value *vecVal,
|
|
|
- CallInst *mulInst, bool isSigned) {
|
|
|
- Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
|
|
|
- // matVal should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
|
|
|
- Value *RVal = vecVal;
|
|
|
-
|
|
|
- unsigned col, row;
|
|
|
- Type *EltTy = GetMatrixInfo(matVal->getType(), col, row);
|
|
|
- DXASSERT_NOMSG(LVal->getType()->getVectorNumElements() == row);
|
|
|
-
|
|
|
- bool isFloat = EltTy->isFloatingPointTy();
|
|
|
-
|
|
|
- Value *retVal = llvm::UndefValue::get(mulInst->getType());
|
|
|
- IRBuilder<> Builder(mulInst);
|
|
|
-
|
|
|
- Value *vec = LVal;
|
|
|
- Value *mat = RVal;
|
|
|
-
|
|
|
- IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
|
|
|
- Type *opcodeTy = Builder.getInt32Ty();
|
|
|
- Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
|
|
|
- *m_pHLModule->GetModule());
|
|
|
- Value *madOpArg = Builder.getInt32((unsigned)madOp);
|
|
|
-
|
|
|
- auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
|
|
|
- Value *vecElt = Builder.CreateExtractElement(vec, r);
|
|
|
- uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
|
|
|
- Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
|
- return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
|
|
|
- };
|
|
|
-
|
|
|
- for (unsigned c = 0; c < col; c++) {
|
|
|
- unsigned r = 0;
|
|
|
- Value *vecElt = Builder.CreateExtractElement(vec, r);
|
|
|
- uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
|
|
|
- Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
|
-
|
|
|
- Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
|
|
|
- : Builder.CreateMul(vecElt, matElt);
|
|
|
-
|
|
|
- for (r = 1; r < row; r++) {
|
|
|
- tmpVal = CreateOneEltMad(r, c, tmpVal);
|
|
|
- }
|
|
|
+static Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
|
|
|
+ Type *RetTy, ArrayRef<Value*> Args, IRBuilder<> &Builder) {
|
|
|
+ SmallVector<Type*, 4> ArgTys;
|
|
|
+ ArgTys.reserve(Args.size());
|
|
|
+ for (Value *Arg : Args)
|
|
|
+ ArgTys.emplace_back(Arg->getType());
|
|
|
|
|
|
- retVal = Builder.CreateInsertElement(retVal, tmpVal, c);
|
|
|
- }
|
|
|
+ FunctionType *FuncTy = FunctionType::get(RetTy, ArgTys, /* isVarArg */ false);
|
|
|
+ Function *Func = GetOrCreateHLFunction(Module, FuncTy, OpcodeGroup, Opcode);
|
|
|
|
|
|
- mulInst->replaceAllUsesWith(retVal);
|
|
|
- AddToDeadInsts(mulInst);
|
|
|
+ return Builder.CreateCall(Func, Args);
|
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMul(Value *matVal, Value *vecVal,
|
|
|
- CallInst *mulInst, bool isSigned) {
|
|
|
- Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
|
|
|
- Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
|
|
|
-
|
|
|
- bool LMat = dxilutil::IsHLSLMatrixType(LVal->getType());
|
|
|
- bool RMat = dxilutil::IsHLSLMatrixType(RVal->getType());
|
|
|
- if (LMat && RMat) {
|
|
|
- TranslateMatMatMul(matVal, vecVal, mulInst, isSigned);
|
|
|
- } else if (LMat) {
|
|
|
- TranslateMatVecMul(matVal, vecVal, mulInst, isSigned);
|
|
|
- } else {
|
|
|
- TranslateVecMatMul(matVal, vecVal, mulInst, isSigned);
|
|
|
- }
|
|
|
-}
|
|
|
+Value *HLMatrixLowerPass::lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode) {
|
|
|
+ IRBuilder<> Builder(Call);
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatTranspose(Value *matVal,
|
|
|
- Value *vecVal,
|
|
|
- CallInst *transposeInst) {
|
|
|
- // Matrix value is row major, transpose is cast it to col major.
|
|
|
- TranslateMatMajorCast(matVal, vecVal, transposeInst,
|
|
|
- /*bRowToCol*/ true, /*bTranspose*/ true);
|
|
|
-}
|
|
|
+ // See if this is a matrix-specific intrinsic which we should expand here
|
|
|
+ switch (Opcode) {
|
|
|
+ case IntrinsicOp::IOP_umul:
|
|
|
+ case IntrinsicOp::IOP_mul:
|
|
|
+ return lowerHLMulIntrinsic(
|
|
|
+ Call->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
|
|
|
+ Call->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
|
|
|
+ /* Unsigned */ Opcode == IntrinsicOp::IOP_umul, Builder);
|
|
|
+ case IntrinsicOp::IOP_transpose:
|
|
|
+ return lowerHLTransposeIntrinsic(Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Builder);
|
|
|
+ case IntrinsicOp::IOP_determinant:
|
|
|
+ return lowerHLDeterminantIntrinsic(Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Builder);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Delegate to a lowered intrinsic call
|
|
|
+ SmallVector<Value*, 4> LoweredArgs;
|
|
|
+ LoweredArgs.reserve(Call->getNumArgOperands());
|
|
|
+ for (Value *Arg : Call->arg_operands()) {
|
|
|
+ if (Arg->getType()->isPointerTy()) {
|
|
|
+ // ByRef parameter (for example, frexp's second parameter)
|
|
|
+ // If the argument points to a lowered matrix variable, replace it here,
|
|
|
+ // otherwise preserve the matrix type and let further passes handle the lowering.
|
|
|
+ Value *LoweredArg = tryGetLoweredPtrOperand(Arg, Builder);
|
|
|
+ if (LoweredArg == nullptr) LoweredArg = Arg;
|
|
|
+ LoweredArgs.emplace_back(LoweredArg);
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ LoweredArgs.emplace_back(getLoweredByValOperand(Arg, Builder));
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
-static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
|
|
|
- IRBuilder<> &Builder) {
|
|
|
- Value *mul0 = Builder.CreateFMul(m00, m11);
|
|
|
- Value *mul1 = Builder.CreateFMul(m01, m10);
|
|
|
- return Builder.CreateFSub(mul0, mul1);
|
|
|
+ Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
|
|
|
+ return callHLFunction(*m_pModule, HLOpcodeGroup::HLIntrinsic, static_cast<unsigned>(Opcode),
|
|
|
+ LoweredRetTy, LoweredArgs, Builder);
|
|
|
}
|
|
|
|
|
|
-static Value *Determinant3x3(Value *m00, Value *m01, Value *m02,
|
|
|
- Value *m10, Value *m11, Value *m12,
|
|
|
- Value *m20, Value *m21, Value *m22,
|
|
|
- IRBuilder<> &Builder) {
|
|
|
- Value *deter00 = Determinant2x2(m11, m12, m21, m22, Builder);
|
|
|
- Value *deter01 = Determinant2x2(m10, m12, m20, m22, Builder);
|
|
|
- Value *deter02 = Determinant2x2(m10, m11, m20, m21, Builder);
|
|
|
- deter00 = Builder.CreateFMul(m00, deter00);
|
|
|
- deter01 = Builder.CreateFMul(m01, deter01);
|
|
|
- deter02 = Builder.CreateFMul(m02, deter02);
|
|
|
- Value *result = Builder.CreateFSub(deter00, deter01);
|
|
|
- result = Builder.CreateFAdd(result, deter02);
|
|
|
- return result;
|
|
|
-}
|
|
|
+Value *HLMatrixLowerPass::lowerHLMulIntrinsic(Value* Lhs, Value *Rhs,
|
|
|
+ bool Unsigned, IRBuilder<> &Builder) {
|
|
|
+ HLMatrixType LhsMatTy = HLMatrixType::dyn_cast(Lhs->getType());
|
|
|
+ HLMatrixType RhsMatTy = HLMatrixType::dyn_cast(Rhs->getType());
|
|
|
+ Value* LoweredLhs = getLoweredByValOperand(Lhs, Builder);
|
|
|
+ Value* LoweredRhs = getLoweredByValOperand(Rhs, Builder);
|
|
|
|
|
|
-static Value *Determinant4x4(Value *m00, Value *m01, Value *m02, Value *m03,
|
|
|
- Value *m10, Value *m11, Value *m12, Value *m13,
|
|
|
- Value *m20, Value *m21, Value *m22, Value *m23,
|
|
|
- Value *m30, Value *m31, Value *m32, Value *m33,
|
|
|
- IRBuilder<> &Builder) {
|
|
|
- Value *deter00 = Determinant3x3(m11, m12, m13, m21, m22, m23, m31, m32, m33, Builder);
|
|
|
- Value *deter01 = Determinant3x3(m10, m12, m13, m20, m22, m23, m30, m32, m33, Builder);
|
|
|
- Value *deter02 = Determinant3x3(m10, m11, m13, m20, m21, m23, m30, m31, m33, Builder);
|
|
|
- Value *deter03 = Determinant3x3(m10, m11, m12, m20, m21, m22, m30, m31, m32, Builder);
|
|
|
- deter00 = Builder.CreateFMul(m00, deter00);
|
|
|
- deter01 = Builder.CreateFMul(m01, deter01);
|
|
|
- deter02 = Builder.CreateFMul(m02, deter02);
|
|
|
- deter03 = Builder.CreateFMul(m03, deter03);
|
|
|
- Value *result = Builder.CreateFSub(deter00, deter01);
|
|
|
- result = Builder.CreateFAdd(result, deter02);
|
|
|
- result = Builder.CreateFSub(result, deter03);
|
|
|
- return result;
|
|
|
-}
|
|
|
+ DXASSERT(LoweredLhs->getType()->getScalarType() == LoweredRhs->getType()->getScalarType(),
|
|
|
+ "Unexpected element type mismatch in mul intrinsic.");
|
|
|
+ DXASSERT(cast<VectorType>(LoweredLhs->getType()) && cast<VectorType>(LoweredLhs->getType()),
|
|
|
+ "Unexpected scalar in lowered matrix mul intrinsic operands.");
|
|
|
|
|
|
+ Type* ElemTy = LoweredLhs->getType()->getScalarType();
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatDeterminant(Value *matVal, Value *vecVal,
|
|
|
- CallInst *determinantInst) {
|
|
|
- unsigned row, col;
|
|
|
- GetMatrixInfo(matVal->getType(), col, row);
|
|
|
- IRBuilder<> Builder(determinantInst);
|
|
|
- // when row == 1, result is vecVal.
|
|
|
- Value *Result = vecVal;
|
|
|
- if (row == 2) {
|
|
|
- Value *m00 = Builder.CreateExtractElement(vecVal, (uint64_t)0);
|
|
|
- Value *m01 = Builder.CreateExtractElement(vecVal, 1);
|
|
|
- Value *m10 = Builder.CreateExtractElement(vecVal, 2);
|
|
|
- Value *m11 = Builder.CreateExtractElement(vecVal, 3);
|
|
|
- Result = Determinant2x2(m00, m01, m10, m11, Builder);
|
|
|
+ // Figure out the dimensions of each side
|
|
|
+ unsigned LhsNumRows, LhsNumCols, RhsNumRows, RhsNumCols;
|
|
|
+ if (LhsMatTy && RhsMatTy) {
|
|
|
+ LhsNumRows = LhsMatTy.getNumRows();
|
|
|
+ LhsNumCols = LhsMatTy.getNumColumns();
|
|
|
+ RhsNumRows = RhsMatTy.getNumRows();
|
|
|
+ RhsNumCols = RhsMatTy.getNumColumns();
|
|
|
}
|
|
|
- else if (row == 3) {
|
|
|
- Value *m00 = Builder.CreateExtractElement(vecVal, (uint64_t)0);
|
|
|
- Value *m01 = Builder.CreateExtractElement(vecVal, 1);
|
|
|
- Value *m02 = Builder.CreateExtractElement(vecVal, 2);
|
|
|
- Value *m10 = Builder.CreateExtractElement(vecVal, 3);
|
|
|
- Value *m11 = Builder.CreateExtractElement(vecVal, 4);
|
|
|
- Value *m12 = Builder.CreateExtractElement(vecVal, 5);
|
|
|
- Value *m20 = Builder.CreateExtractElement(vecVal, 6);
|
|
|
- Value *m21 = Builder.CreateExtractElement(vecVal, 7);
|
|
|
- Value *m22 = Builder.CreateExtractElement(vecVal, 8);
|
|
|
- Result = Determinant3x3(m00, m01, m02,
|
|
|
- m10, m11, m12,
|
|
|
- m20, m21, m22, Builder);
|
|
|
+ else if (LhsMatTy) {
|
|
|
+ LhsNumRows = LhsMatTy.getNumRows();
|
|
|
+ LhsNumCols = LhsMatTy.getNumColumns();
|
|
|
+ RhsNumRows = LoweredRhs->getType()->getVectorNumElements();
|
|
|
+ RhsNumCols = 1;
|
|
|
}
|
|
|
- else if (row == 4) {
|
|
|
- Value *m00 = Builder.CreateExtractElement(vecVal, (uint64_t)0);
|
|
|
- Value *m01 = Builder.CreateExtractElement(vecVal, 1);
|
|
|
- Value *m02 = Builder.CreateExtractElement(vecVal, 2);
|
|
|
- Value *m03 = Builder.CreateExtractElement(vecVal, 3);
|
|
|
-
|
|
|
- Value *m10 = Builder.CreateExtractElement(vecVal, 4);
|
|
|
- Value *m11 = Builder.CreateExtractElement(vecVal, 5);
|
|
|
- Value *m12 = Builder.CreateExtractElement(vecVal, 6);
|
|
|
- Value *m13 = Builder.CreateExtractElement(vecVal, 7);
|
|
|
-
|
|
|
- Value *m20 = Builder.CreateExtractElement(vecVal, 8);
|
|
|
- Value *m21 = Builder.CreateExtractElement(vecVal, 9);
|
|
|
- Value *m22 = Builder.CreateExtractElement(vecVal, 10);
|
|
|
- Value *m23 = Builder.CreateExtractElement(vecVal, 11);
|
|
|
-
|
|
|
- Value *m30 = Builder.CreateExtractElement(vecVal, 12);
|
|
|
- Value *m31 = Builder.CreateExtractElement(vecVal, 13);
|
|
|
- Value *m32 = Builder.CreateExtractElement(vecVal, 14);
|
|
|
- Value *m33 = Builder.CreateExtractElement(vecVal, 15);
|
|
|
-
|
|
|
- Result = Determinant4x4(m00, m01, m02, m03,
|
|
|
- m10, m11, m12, m13,
|
|
|
- m20, m21, m22, m23,
|
|
|
- m30, m31, m32, m33,
|
|
|
- Builder);
|
|
|
- } else {
|
|
|
- DXASSERT(row == 1, "invalid matrix type");
|
|
|
- Result = Builder.CreateExtractElement(Result, (uint64_t)0);
|
|
|
+ else if (RhsMatTy) {
|
|
|
+ LhsNumRows = 1;
|
|
|
+ LhsNumCols = LoweredLhs->getType()->getVectorNumElements();
|
|
|
+ RhsNumRows = RhsMatTy.getNumRows();
|
|
|
+ RhsNumCols = RhsMatTy.getNumColumns();
|
|
|
}
|
|
|
- determinantInst->replaceAllUsesWith(Result);
|
|
|
- AddToDeadInsts(determinantInst);
|
|
|
+ else {
|
|
|
+ llvm_unreachable("mul intrinsic was identified as a matrix operation but neither operand is a matrix.");
|
|
|
+ }
|
|
|
+
|
|
|
+ DXASSERT(LhsNumCols == RhsNumRows, "Matrix mul intrinsic operands dimensions mismatch.");
|
|
|
+ HLMatrixType ResultMatTy(ElemTy, LhsNumRows, RhsNumCols);
|
|
|
+ unsigned AccCount = LhsNumCols;
|
|
|
+
|
|
|
+ // Get the multiply-and-add intrinsic function, we'll need it
|
|
|
+ IntrinsicOp MadOpcode = Unsigned ? IntrinsicOp::IOP_umad : IntrinsicOp::IOP_mad;
|
|
|
+ FunctionType *MadFuncTy = FunctionType::get(ElemTy, { Builder.getInt32Ty(), ElemTy, ElemTy, ElemTy }, false);
|
|
|
+ Function *MadFunc = GetOrCreateHLFunction(*m_pModule, MadFuncTy, HLOpcodeGroup::HLIntrinsic, (unsigned)MadOpcode);
|
|
|
+ Constant *MadOpcodeVal = Builder.getInt32((unsigned)MadOpcode);
|
|
|
+
|
|
|
+ // Perform the multiplication!
|
|
|
+ Value *Result = UndefValue::get(VectorType::get(ElemTy, LhsNumRows * RhsNumCols));
|
|
|
+ for (unsigned ResultRowIdx = 0; ResultRowIdx < ResultMatTy.getNumRows(); ++ResultRowIdx) {
|
|
|
+ for (unsigned ResultColIdx = 0; ResultColIdx < ResultMatTy.getNumColumns(); ++ResultColIdx) {
|
|
|
+ unsigned ResultElemIdx = ResultMatTy.getRowMajorIndex(ResultRowIdx, ResultColIdx);
|
|
|
+ Value *ResultElem = nullptr;
|
|
|
+
|
|
|
+ for (unsigned AccIdx = 0; AccIdx < AccCount; ++AccIdx) {
|
|
|
+ unsigned LhsElemIdx = HLMatrixType::getRowMajorIndex(ResultRowIdx, AccIdx, LhsNumRows, LhsNumCols);
|
|
|
+ unsigned RhsElemIdx = HLMatrixType::getRowMajorIndex(AccIdx, ResultColIdx, RhsNumRows, RhsNumCols);
|
|
|
+ Value* LhsElem = Builder.CreateExtractElement(LoweredLhs, static_cast<uint64_t>(LhsElemIdx));
|
|
|
+ Value* RhsElem = Builder.CreateExtractElement(LoweredRhs, static_cast<uint64_t>(RhsElemIdx));
|
|
|
+ if (ResultElem == nullptr) {
|
|
|
+ ResultElem = ElemTy->isFloatingPointTy()
|
|
|
+ ? Builder.CreateFMul(LhsElem, RhsElem)
|
|
|
+ : Builder.CreateMul(LhsElem, RhsElem);
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ ResultElem = Builder.CreateCall(MadFunc, { MadOpcodeVal, LhsElem, RhsElem, ResultElem });
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ Result = Builder.CreateInsertElement(Result, ResultElem, static_cast<uint64_t>(ResultElemIdx));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return Result;
|
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::TrivialMatReplace(Value *matVal,
|
|
|
- Value *vecVal,
|
|
|
- CallInst *matUseInst) {
|
|
|
- CallInst *vecUseInst = cast<CallInst>(matToVecMap[matUseInst]);
|
|
|
+Value *HLMatrixLowerPass::lowerHLTransposeIntrinsic(Value* MatVal, IRBuilder<> &Builder) {
|
|
|
+ HLMatrixType MatTy = HLMatrixType::cast(MatVal->getType());
|
|
|
+ Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
|
|
|
+ return MatTy.emitLoweredVectorRowToCol(LoweredVal, Builder);
|
|
|
+}
|
|
|
|
|
|
- for (unsigned i = 0; i < matUseInst->getNumArgOperands(); i++)
|
|
|
- if (matUseInst->getArgOperand(i) == matVal) {
|
|
|
- vecUseInst->setArgOperand(i, vecVal);
|
|
|
- }
|
|
|
+static Value *determinant2x2(Value *M00, Value *M01, Value *M10, Value *M11, IRBuilder<> &Builder) {
|
|
|
+ Value *Mul0 = Builder.CreateFMul(M00, M11);
|
|
|
+ Value *Mul1 = Builder.CreateFMul(M01, M10);
|
|
|
+ return Builder.CreateFSub(Mul0, Mul1);
|
|
|
}
|
|
|
|
|
|
-static Instruction *CreateTransposeShuffle(IRBuilder<> &Builder, Value *vecVal, unsigned toRows, unsigned toCols) {
|
|
|
- SmallVector<int, 16> castMask(toCols * toRows);
|
|
|
- unsigned idx = 0;
|
|
|
- for (unsigned r = 0; r < toRows; r++)
|
|
|
- for (unsigned c = 0; c < toCols; c++)
|
|
|
- castMask[idx++] = c * toRows + r;
|
|
|
- return cast<Instruction>(
|
|
|
- Builder.CreateShuffleVector(vecVal, vecVal, castMask));
|
|
|
+static Value *determinant3x3(Value *M00, Value *M01, Value *M02,
|
|
|
+ Value *M10, Value *M11, Value *M12,
|
|
|
+ Value *M20, Value *M21, Value *M22,
|
|
|
+ IRBuilder<> &Builder) {
|
|
|
+ Value *Det00 = determinant2x2(M11, M12, M21, M22, Builder);
|
|
|
+ Value *Det01 = determinant2x2(M10, M12, M20, M22, Builder);
|
|
|
+ Value *Det02 = determinant2x2(M10, M11, M20, M21, Builder);
|
|
|
+ Det00 = Builder.CreateFMul(M00, Det00);
|
|
|
+ Det01 = Builder.CreateFMul(M01, Det01);
|
|
|
+ Det02 = Builder.CreateFMul(M02, Det02);
|
|
|
+ Value *Result = Builder.CreateFSub(Det00, Det01);
|
|
|
+ Result = Builder.CreateFAdd(Result, Det02);
|
|
|
+ return Result;
|
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal,
|
|
|
- Value *vecVal,
|
|
|
- CallInst *castInst,
|
|
|
- bool bRowToCol,
|
|
|
- bool bTranspose) {
|
|
|
- unsigned col, row;
|
|
|
- if (!bTranspose) {
|
|
|
- GetMatrixInfo(castInst->getType(), col, row);
|
|
|
- DXASSERT(castInst->getType() == matVal->getType(), "type must match");
|
|
|
- } else {
|
|
|
- unsigned castCol, castRow;
|
|
|
- Type *castTy = GetMatrixInfo(castInst->getType(), castCol, castRow);
|
|
|
- unsigned srcCol, srcRow;
|
|
|
- Type *srcTy = GetMatrixInfo(matVal->getType(), srcCol, srcRow);
|
|
|
- DXASSERT_LOCALVAR((castTy == srcTy), srcTy == castTy, "type must match");
|
|
|
- DXASSERT(castCol == srcRow && castRow == srcCol, "col row must match");
|
|
|
- col = srcCol;
|
|
|
- row = srcRow;
|
|
|
- }
|
|
|
+static Value *determinant4x4(Value *M00, Value *M01, Value *M02, Value *M03,
|
|
|
+ Value *M10, Value *M11, Value *M12, Value *M13,
|
|
|
+ Value *M20, Value *M21, Value *M22, Value *M23,
|
|
|
+ Value *M30, Value *M31, Value *M32, Value *M33,
|
|
|
+ IRBuilder<> &Builder) {
|
|
|
+ Value *Det00 = determinant3x3(M11, M12, M13, M21, M22, M23, M31, M32, M33, Builder);
|
|
|
+ Value *Det01 = determinant3x3(M10, M12, M13, M20, M22, M23, M30, M32, M33, Builder);
|
|
|
+ Value *Det02 = determinant3x3(M10, M11, M13, M20, M21, M23, M30, M31, M33, Builder);
|
|
|
+ Value *Det03 = determinant3x3(M10, M11, M12, M20, M21, M22, M30, M31, M32, Builder);
|
|
|
+ Det00 = Builder.CreateFMul(M00, Det00);
|
|
|
+ Det01 = Builder.CreateFMul(M01, Det01);
|
|
|
+ Det02 = Builder.CreateFMul(M02, Det02);
|
|
|
+ Det03 = Builder.CreateFMul(M03, Det03);
|
|
|
+ Value *Result = Builder.CreateFSub(Det00, Det01);
|
|
|
+ Result = Builder.CreateFAdd(Result, Det02);
|
|
|
+ Result = Builder.CreateFSub(Result, Det03);
|
|
|
+ return Result;
|
|
|
+}
|
|
|
|
|
|
- DXASSERT(matToVecMap.count(castInst), "must have vec version");
|
|
|
- Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
|
|
|
- // Create before vecUseInst to prevent instructions being inserted after uses.
|
|
|
- IRBuilder<> Builder(vecUseInst);
|
|
|
+Value *HLMatrixLowerPass::lowerHLDeterminantIntrinsic(Value* MatVal, IRBuilder<> &Builder) {
|
|
|
+ HLMatrixType MatTy = HLMatrixType::cast(MatVal->getType());
|
|
|
+ DXASSERT_NOMSG(MatTy.getNumColumns() == MatTy.getNumRows());
|
|
|
|
|
|
- if (bRowToCol)
|
|
|
- std::swap(row, col);
|
|
|
- Instruction *vecCast = CreateTransposeShuffle(Builder, vecVal, row, col);
|
|
|
+ Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
|
|
|
|
|
|
- // Replace vec cast function call with vecCast.
|
|
|
- vecUseInst->replaceAllUsesWith(vecCast);
|
|
|
- AddToDeadInsts(vecUseInst);
|
|
|
- matToVecMap[castInst] = vecCast;
|
|
|
-}
|
|
|
+ // Extract all matrix elements
|
|
|
+ SmallVector<Value*, 16> Elems;
|
|
|
+ for (unsigned ElemIdx = 0; ElemIdx < MatTy.getNumElements(); ++ElemIdx)
|
|
|
+ Elems.emplace_back(Builder.CreateExtractElement(LoweredVal, static_cast<uint64_t>(ElemIdx)));
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatMatCast(Value *matVal,
|
|
|
- Value *vecVal,
|
|
|
- CallInst *castInst) {
|
|
|
- unsigned toCol, toRow;
|
|
|
- Type *ToEltTy = GetMatrixInfo(castInst->getType(), toCol, toRow);
|
|
|
- unsigned fromCol, fromRow;
|
|
|
- Type *FromEltTy = GetMatrixInfo(matVal->getType(), fromCol, fromRow);
|
|
|
- unsigned fromSize = fromCol * fromRow;
|
|
|
- unsigned toSize = toCol * toRow;
|
|
|
- DXASSERT(fromSize >= toSize, "cannot extend matrix");
|
|
|
- DXASSERT(matToVecMap.count(castInst), "must have vec version");
|
|
|
- Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
|
|
|
-
|
|
|
- IRBuilder<> Builder(vecUseInst);
|
|
|
- Instruction *vecCast = nullptr;
|
|
|
-
|
|
|
- HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
|
|
|
-
|
|
|
- if (fromSize == toSize) {
|
|
|
- vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), vecVal,
|
|
|
- Builder);
|
|
|
- } else {
|
|
|
- // shuf first
|
|
|
- std::vector<int> castMask(toCol * toRow);
|
|
|
- unsigned idx = 0;
|
|
|
- for (unsigned r = 0; r < toRow; r++)
|
|
|
- for (unsigned c = 0; c < toCol; c++) {
|
|
|
- unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, fromCol);
|
|
|
- castMask[idx++] = matIdx;
|
|
|
- }
|
|
|
+ // Delegate to appropriate determinant function
|
|
|
+ switch (MatTy.getNumColumns()) {
|
|
|
+ case 1:
|
|
|
+ return Elems[0];
|
|
|
|
|
|
- Instruction *shuf = cast<Instruction>(
|
|
|
- Builder.CreateShuffleVector(vecVal, vecVal, castMask));
|
|
|
+ case 2:
|
|
|
+ return determinant2x2(
|
|
|
+ Elems[0], Elems[1],
|
|
|
+ Elems[2], Elems[3],
|
|
|
+ Builder);
|
|
|
|
|
|
- if (ToEltTy != FromEltTy)
|
|
|
- vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), shuf,
|
|
|
- Builder);
|
|
|
- else
|
|
|
- vecCast = shuf;
|
|
|
- }
|
|
|
- // Replace vec cast function call with vecCast.
|
|
|
- vecUseInst->replaceAllUsesWith(vecCast);
|
|
|
- AddToDeadInsts(vecUseInst);
|
|
|
- matToVecMap[castInst] = vecCast;
|
|
|
-}
|
|
|
+ case 3:
|
|
|
+ return determinant3x3(
|
|
|
+ Elems[0], Elems[1], Elems[2],
|
|
|
+ Elems[3], Elems[4], Elems[5],
|
|
|
+ Elems[6], Elems[7], Elems[8],
|
|
|
+ Builder);
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatToOtherCast(Value *matVal,
|
|
|
- Value *vecVal,
|
|
|
- CallInst *castInst) {
|
|
|
- unsigned col, row;
|
|
|
- Type *EltTy = GetMatrixInfo(matVal->getType(), col, row);
|
|
|
- unsigned fromSize = col * row;
|
|
|
-
|
|
|
- IRBuilder<> Builder(castInst);
|
|
|
- Value *sizeCast = nullptr;
|
|
|
-
|
|
|
- HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
|
|
|
-
|
|
|
- Type *ToTy = castInst->getType();
|
|
|
- if (ToTy->isVectorTy()) {
|
|
|
- unsigned toSize = ToTy->getVectorNumElements();
|
|
|
- if (fromSize != toSize) {
|
|
|
- std::vector<int> castMask(fromSize);
|
|
|
- for (unsigned c = 0; c < toSize; c++)
|
|
|
- castMask[c] = c;
|
|
|
-
|
|
|
- sizeCast = Builder.CreateShuffleVector(vecVal, vecVal, castMask);
|
|
|
- } else
|
|
|
- sizeCast = vecVal;
|
|
|
- } else {
|
|
|
- DXASSERT(ToTy->isSingleValueType(), "must scalar here");
|
|
|
- sizeCast = Builder.CreateExtractElement(vecVal, (uint64_t)0);
|
|
|
- }
|
|
|
+ case 4:
|
|
|
+ return determinant4x4(
|
|
|
+ Elems[0], Elems[1], Elems[2], Elems[3],
|
|
|
+ Elems[4], Elems[5], Elems[6], Elems[7],
|
|
|
+ Elems[8], Elems[9], Elems[10], Elems[11],
|
|
|
+ Elems[12], Elems[13], Elems[14], Elems[15],
|
|
|
+ Builder);
|
|
|
|
|
|
- Value *typeCast = sizeCast;
|
|
|
- if (EltTy != ToTy->getScalarType()) {
|
|
|
- typeCast = CreateTypeCast(opcode, ToTy, typeCast, Builder);
|
|
|
+ default:
|
|
|
+ llvm_unreachable("Unexpected matrix dimensions.");
|
|
|
}
|
|
|
- // Replace cast function call with typeCast.
|
|
|
- castInst->replaceAllUsesWith(typeCast);
|
|
|
- AddToDeadInsts(castInst);
|
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatCast(Value *matVal,
|
|
|
- Value *vecVal,
|
|
|
- CallInst *castInst) {
|
|
|
- HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
|
|
|
- if (opcode == HLCastOpcode::ColMatrixToRowMatrix ||
|
|
|
- opcode == HLCastOpcode::RowMatrixToColMatrix) {
|
|
|
- TranslateMatMajorCast(matVal, vecVal, castInst,
|
|
|
- opcode == HLCastOpcode::RowMatrixToColMatrix,
|
|
|
- /*bTranspose*/false);
|
|
|
- } else {
|
|
|
- bool ToMat = dxilutil::IsHLSLMatrixType(castInst->getType());
|
|
|
- bool FromMat = dxilutil::IsHLSLMatrixType(matVal->getType());
|
|
|
- if (ToMat && FromMat) {
|
|
|
- TranslateMatMatCast(matVal, vecVal, castInst);
|
|
|
- } else if (FromMat)
|
|
|
- TranslateMatToOtherCast(matVal, vecVal, castInst);
|
|
|
+Value *HLMatrixLowerPass::lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode, IRBuilder<> &Builder) {
|
|
|
+ Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
|
|
|
+ VectorType *VecTy = cast<VectorType>(LoweredVal->getType());
|
|
|
+ bool IsFloat = VecTy->getElementType()->isFloatingPointTy();
|
|
|
+
|
|
|
+ switch (Opcode) {
|
|
|
+ case HLUnaryOpcode::Plus: return LoweredVal; // No-op
|
|
|
+
|
|
|
+ case HLUnaryOpcode::Minus:
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFSub(Constant::getNullValue(VecTy), LoweredVal)
|
|
|
+ : Builder.CreateSub(Constant::getNullValue(VecTy), LoweredVal);
|
|
|
+
|
|
|
+ case HLUnaryOpcode::LNot:
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFCmp(CmpInst::FCMP_UEQ, LoweredVal, Constant::getNullValue(VecTy))
|
|
|
+ : Builder.CreateICmp(CmpInst::ICMP_EQ, LoweredVal, Constant::getNullValue(VecTy));
|
|
|
+
|
|
|
+ case HLUnaryOpcode::Not:
|
|
|
+ return Builder.CreateXor(LoweredVal, Constant::getAllOnesValue(VecTy));
|
|
|
+
|
|
|
+ case HLUnaryOpcode::PostInc:
|
|
|
+ case HLUnaryOpcode::PreInc:
|
|
|
+ case HLUnaryOpcode::PostDec:
|
|
|
+ case HLUnaryOpcode::PreDec: {
|
|
|
+ Constant *ScalarOne = IsFloat
|
|
|
+ ? ConstantFP::get(VecTy->getElementType(), 1)
|
|
|
+ : ConstantInt::get(VecTy->getElementType(), 1);
|
|
|
+ Constant *VecOne = ConstantVector::getSplat(VecTy->getNumElements(), ScalarOne);
|
|
|
+ // BUGBUG: This implementation has incorrect semantics (GitHub #1780)
|
|
|
+ if (Opcode == HLUnaryOpcode::PostInc || Opcode == HLUnaryOpcode::PreInc) {
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFAdd(LoweredVal, VecOne)
|
|
|
+ : Builder.CreateAdd(LoweredVal, VecOne);
|
|
|
+ }
|
|
|
else {
|
|
|
- DXASSERT(0, "Not translate as user of matInst");
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFSub(LoweredVal, VecOne)
|
|
|
+ : Builder.CreateSub(LoweredVal, VecOne);
|
|
|
}
|
|
|
}
|
|
|
-}
|
|
|
-
|
|
|
-void HLMatrixLowerPass::MatIntrinsicReplace(Value *matVal,
|
|
|
- Value *vecVal,
|
|
|
- CallInst *matUseInst) {
|
|
|
- IRBuilder<> Builder(matUseInst);
|
|
|
- IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(matUseInst));
|
|
|
- switch (opcode) {
|
|
|
- case IntrinsicOp::IOP_umul:
|
|
|
- TranslateMul(matVal, vecVal, matUseInst, /*isSigned*/false);
|
|
|
- break;
|
|
|
- case IntrinsicOp::IOP_mul:
|
|
|
- TranslateMul(matVal, vecVal, matUseInst, /*isSigned*/true);
|
|
|
- break;
|
|
|
- case IntrinsicOp::IOP_transpose:
|
|
|
- TranslateMatTranspose(matVal, vecVal, matUseInst);
|
|
|
- break;
|
|
|
- case IntrinsicOp::IOP_determinant:
|
|
|
- TranslateMatDeterminant(matVal, vecVal, matUseInst);
|
|
|
- break;
|
|
|
default:
|
|
|
- CallInst *useInst = matUseInst;
|
|
|
- if (matToVecMap.count(matUseInst))
|
|
|
- useInst = cast<CallInst>(matToVecMap[matUseInst]);
|
|
|
- for (unsigned i = 0; i < useInst->getNumArgOperands(); i++) {
|
|
|
- if (matUseInst->getArgOperand(i) == matVal)
|
|
|
- useInst->setArgOperand(i, vecVal);
|
|
|
- }
|
|
|
- break;
|
|
|
+ llvm_unreachable("Unsupported unary matrix operator");
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatSubscript(Value *matVal, Value *vecVal,
|
|
|
- CallInst *matSubInst) {
|
|
|
- unsigned opcode = GetHLOpcode(matSubInst);
|
|
|
- HLSubscriptOpcode matOpcode = static_cast<HLSubscriptOpcode>(opcode);
|
|
|
- assert(matOpcode != HLSubscriptOpcode::DefaultSubscript &&
|
|
|
- "matrix don't use default subscript");
|
|
|
-
|
|
|
- Type *matType = matVal->getType()->getPointerElementType();
|
|
|
- unsigned col, row;
|
|
|
- Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
|
|
|
-
|
|
|
- bool isElement = (matOpcode == HLSubscriptOpcode::ColMatElement) |
|
|
|
- (matOpcode == HLSubscriptOpcode::RowMatElement);
|
|
|
- Value *mask =
|
|
|
- matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
|
|
|
-
|
|
|
- if (isElement) {
|
|
|
- Type *resultType = matSubInst->getType()->getPointerElementType();
|
|
|
- unsigned resultSize = 1;
|
|
|
- if (resultType->isVectorTy())
|
|
|
- resultSize = resultType->getVectorNumElements();
|
|
|
-
|
|
|
- std::vector<int> shufMask(resultSize);
|
|
|
- Constant *EltIdxs = cast<Constant>(mask);
|
|
|
- for (unsigned i = 0; i < resultSize; i++) {
|
|
|
- shufMask[i] =
|
|
|
- cast<ConstantInt>(EltIdxs->getAggregateElement(i))->getLimitedValue();
|
|
|
- }
|
|
|
+Value *HLMatrixLowerPass::lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode, IRBuilder<> &Builder) {
|
|
|
+ Value *LoweredLhs = getLoweredByValOperand(Lhs, Builder);
|
|
|
+ Value *LoweredRhs = getLoweredByValOperand(Rhs, Builder);
|
|
|
|
|
|
- for (Value::use_iterator CallUI = matSubInst->use_begin(),
|
|
|
- CallE = matSubInst->use_end();
|
|
|
- CallUI != CallE;) {
|
|
|
- Use &CallUse = *CallUI++;
|
|
|
- Instruction *CallUser = cast<Instruction>(CallUse.getUser());
|
|
|
- IRBuilder<> Builder(CallUser);
|
|
|
- Value *vecLd = Builder.CreateLoad(vecVal);
|
|
|
- if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
|
|
|
- if (resultSize > 1) {
|
|
|
- Value *shuf = Builder.CreateShuffleVector(vecLd, vecLd, shufMask);
|
|
|
- ld->replaceAllUsesWith(shuf);
|
|
|
- } else {
|
|
|
- Value *elt = Builder.CreateExtractElement(vecLd, shufMask[0]);
|
|
|
- ld->replaceAllUsesWith(elt);
|
|
|
- }
|
|
|
- } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
|
|
|
- Value *val = st->getValueOperand();
|
|
|
- if (resultSize > 1) {
|
|
|
- for (unsigned i = 0; i < shufMask.size(); i++) {
|
|
|
- unsigned idx = shufMask[i];
|
|
|
- Value *valElt = Builder.CreateExtractElement(val, i);
|
|
|
- vecLd = Builder.CreateInsertElement(vecLd, valElt, idx);
|
|
|
- }
|
|
|
- Builder.CreateStore(vecLd, vecVal);
|
|
|
- } else {
|
|
|
- vecLd = Builder.CreateInsertElement(vecLd, val, shufMask[0]);
|
|
|
- Builder.CreateStore(vecLd, vecVal);
|
|
|
- }
|
|
|
- } else {
|
|
|
- DXASSERT(0, "matrix element should only used by load/store.");
|
|
|
- }
|
|
|
- AddToDeadInsts(CallUser);
|
|
|
- }
|
|
|
- } else {
|
|
|
- // Subscript.
|
|
|
- // Return a row.
|
|
|
- // Use insertElement and extractElement.
|
|
|
- ArrayType *AT = ArrayType::get(EltTy, col*row);
|
|
|
-
|
|
|
- IRBuilder<> AllocaBuilder(
|
|
|
- matSubInst->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
|
|
|
- Value *tempArray = AllocaBuilder.CreateAlloca(AT);
|
|
|
- Value *zero = AllocaBuilder.getInt32(0);
|
|
|
- bool isDynamicIndexing = !isa<ConstantInt>(mask);
|
|
|
- SmallVector<Value *, 4> idxList;
|
|
|
- for (unsigned i = 0; i < col; i++) {
|
|
|
- idxList.emplace_back(
|
|
|
- matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + i));
|
|
|
- }
|
|
|
+ DXASSERT(LoweredLhs->getType()->isVectorTy() && LoweredRhs->getType()->isVectorTy(),
|
|
|
+ "Expected lowered binary operation operands to be vectors");
|
|
|
+ DXASSERT(LoweredLhs->getType() == LoweredRhs->getType(),
|
|
|
+ "Expected lowered binary operation operands to have matching types.");
|
|
|
|
|
|
- for (Value::use_iterator CallUI = matSubInst->use_begin(),
|
|
|
- CallE = matSubInst->use_end();
|
|
|
- CallUI != CallE;) {
|
|
|
- Use &CallUse = *CallUI++;
|
|
|
- Instruction *CallUser = cast<Instruction>(CallUse.getUser());
|
|
|
- IRBuilder<> Builder(CallUser);
|
|
|
- Value *vecLd = Builder.CreateLoad(vecVal);
|
|
|
- if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
|
|
|
- Value *sub = UndefValue::get(ld->getType());
|
|
|
- if (!isDynamicIndexing) {
|
|
|
- for (unsigned i = 0; i < col; i++) {
|
|
|
- Value *matIdx = idxList[i];
|
|
|
- Value *valElt = Builder.CreateExtractElement(vecLd, matIdx);
|
|
|
- sub = Builder.CreateInsertElement(sub, valElt, i);
|
|
|
- }
|
|
|
- } else {
|
|
|
- // Copy vec to array.
|
|
|
- for (unsigned int i = 0; i < row*col; i++) {
|
|
|
- Value *Elt =
|
|
|
- Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
|
|
|
- Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
|
|
|
- {zero, Builder.getInt32(i)});
|
|
|
- Builder.CreateStore(Elt, Ptr);
|
|
|
- }
|
|
|
- for (unsigned i = 0; i < col; i++) {
|
|
|
- Value *matIdx = idxList[i];
|
|
|
- Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
|
|
|
- Value *valElt = Builder.CreateLoad(Ptr);
|
|
|
- sub = Builder.CreateInsertElement(sub, valElt, i);
|
|
|
- }
|
|
|
- }
|
|
|
- ld->replaceAllUsesWith(sub);
|
|
|
- } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
|
|
|
- Value *val = st->getValueOperand();
|
|
|
- if (!isDynamicIndexing) {
|
|
|
- for (unsigned i = 0; i < col; i++) {
|
|
|
- Value *matIdx = idxList[i];
|
|
|
- Value *valElt = Builder.CreateExtractElement(val, i);
|
|
|
- vecLd = Builder.CreateInsertElement(vecLd, valElt, matIdx);
|
|
|
- }
|
|
|
- } else {
|
|
|
- // Copy vec to array.
|
|
|
- for (unsigned int i = 0; i < row * col; i++) {
|
|
|
- Value *Elt =
|
|
|
- Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
|
|
|
- Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
|
|
|
- {zero, Builder.getInt32(i)});
|
|
|
- Builder.CreateStore(Elt, Ptr);
|
|
|
- }
|
|
|
- // Update array.
|
|
|
- for (unsigned i = 0; i < col; i++) {
|
|
|
- Value *matIdx = idxList[i];
|
|
|
- Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
|
|
|
- Value *valElt = Builder.CreateExtractElement(val, i);
|
|
|
- Builder.CreateStore(valElt, Ptr);
|
|
|
- }
|
|
|
- // Copy array to vec.
|
|
|
- for (unsigned int i = 0; i < row * col; i++) {
|
|
|
- Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
|
|
|
- {zero, Builder.getInt32(i)});
|
|
|
- Value *Elt = Builder.CreateLoad(Ptr);
|
|
|
- vecLd = Builder.CreateInsertElement(vecLd, Elt, i);
|
|
|
- }
|
|
|
- }
|
|
|
- Builder.CreateStore(vecLd, vecVal);
|
|
|
- } else if (GetElementPtrInst *GEP =
|
|
|
- dyn_cast<GetElementPtrInst>(CallUser)) {
|
|
|
- Value *GEPOffset = HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
|
|
|
- Value *NewGEP = Builder.CreateGEP(vecVal, {zero, GEPOffset});
|
|
|
- GEP->replaceAllUsesWith(NewGEP);
|
|
|
- } else {
|
|
|
- DXASSERT(0, "matrix subscript should only used by load/store.");
|
|
|
- }
|
|
|
- AddToDeadInsts(CallUser);
|
|
|
- }
|
|
|
- }
|
|
|
- // Check vec version.
|
|
|
- DXASSERT(matToVecMap.count(matSubInst) == 0, "should not have vec version");
|
|
|
- // All the user should have been removed.
|
|
|
- matSubInst->replaceAllUsesWith(UndefValue::get(matSubInst->getType()));
|
|
|
- AddToDeadInsts(matSubInst);
|
|
|
-}
|
|
|
+ bool IsFloat = LoweredLhs->getType()->getVectorElementType()->isFloatingPointTy();
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
|
|
|
- Value *matGlobal, ArrayRef<Value *> vecGlobals,
|
|
|
- CallInst *matLdStInst) {
|
|
|
- // No dynamic indexing on matrix, flatten matrix to scalars.
|
|
|
- // vecGlobals already in correct major.
|
|
|
- Type *matType = matGlobal->getType()->getPointerElementType();
|
|
|
- unsigned col, row;
|
|
|
- HLMatrixLower::GetMatrixInfo(matType, col, row);
|
|
|
- Type *vecType = HLMatrixLower::LowerMatrixType(matType);
|
|
|
+ switch (Opcode) {
|
|
|
+ case HLBinaryOpcode::Add:
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFAdd(LoweredLhs, LoweredRhs)
|
|
|
+ : Builder.CreateAdd(LoweredLhs, LoweredRhs);
|
|
|
|
|
|
- IRBuilder<> Builder(matLdStInst);
|
|
|
+ case HLBinaryOpcode::Sub:
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFSub(LoweredLhs, LoweredRhs)
|
|
|
+ : Builder.CreateSub(LoweredLhs, LoweredRhs);
|
|
|
|
|
|
- HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
|
|
|
- switch (opcode) {
|
|
|
- case HLMatLoadStoreOpcode::ColMatLoad:
|
|
|
- case HLMatLoadStoreOpcode::RowMatLoad: {
|
|
|
- Value *Result = UndefValue::get(vecType);
|
|
|
- for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
|
|
|
- Value *Elt = Builder.CreateLoad(vecGlobals[matIdx]);
|
|
|
- Result = Builder.CreateInsertElement(Result, Elt, matIdx);
|
|
|
- }
|
|
|
- matLdStInst->replaceAllUsesWith(Result);
|
|
|
- } break;
|
|
|
- case HLMatLoadStoreOpcode::ColMatStore:
|
|
|
- case HLMatLoadStoreOpcode::RowMatStore: {
|
|
|
- Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
|
|
|
- for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
|
|
|
- Value *Elt = Builder.CreateExtractElement(Val, matIdx);
|
|
|
- Builder.CreateStore(Elt, vecGlobals[matIdx]);
|
|
|
- }
|
|
|
- } break;
|
|
|
- }
|
|
|
-}
|
|
|
+ case HLBinaryOpcode::Mul:
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFMul(LoweredLhs, LoweredRhs)
|
|
|
+ : Builder.CreateMul(LoweredLhs, LoweredRhs);
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
|
|
|
- GlobalVariable *scalarArrayGlobal,
|
|
|
- CallInst *matLdStInst) {
|
|
|
- // vecGlobals already in correct major.
|
|
|
- HLMatLoadStoreOpcode opcode =
|
|
|
- static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
|
|
|
- switch (opcode) {
|
|
|
- case HLMatLoadStoreOpcode::ColMatLoad:
|
|
|
- case HLMatLoadStoreOpcode::RowMatLoad: {
|
|
|
- IRBuilder<> Builder(matLdStInst);
|
|
|
- Type *matTy = matGlobal->getType()->getPointerElementType();
|
|
|
- unsigned col, row;
|
|
|
- Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
|
|
|
- Value *zeroIdx = Builder.getInt32(0);
|
|
|
-
|
|
|
- std::vector<Value *> matElts(col * row);
|
|
|
-
|
|
|
- for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
|
|
|
- Value *GEP = Builder.CreateInBoundsGEP(
|
|
|
- scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
|
|
|
- matElts[matIdx] = Builder.CreateLoad(GEP);
|
|
|
- }
|
|
|
+ case HLBinaryOpcode::Div:
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFDiv(LoweredLhs, LoweredRhs)
|
|
|
+ : Builder.CreateSDiv(LoweredLhs, LoweredRhs);
|
|
|
|
|
|
- Value *newVec =
|
|
|
- HLMatrixLower::BuildVector(EltTy, col * row, matElts, Builder);
|
|
|
- matLdStInst->replaceAllUsesWith(newVec);
|
|
|
- matLdStInst->eraseFromParent();
|
|
|
- } break;
|
|
|
- case HLMatLoadStoreOpcode::ColMatStore:
|
|
|
- case HLMatLoadStoreOpcode::RowMatStore: {
|
|
|
- Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
|
|
|
-
|
|
|
- IRBuilder<> Builder(matLdStInst);
|
|
|
- Type *matTy = matGlobal->getType()->getPointerElementType();
|
|
|
- unsigned col, row;
|
|
|
- HLMatrixLower::GetMatrixInfo(matTy, col, row);
|
|
|
- Value *zeroIdx = Builder.getInt32(0);
|
|
|
-
|
|
|
- std::vector<Value *> matElts(col * row);
|
|
|
-
|
|
|
- for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
|
|
|
- Value *GEP = Builder.CreateInBoundsGEP(
|
|
|
- scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
|
|
|
- Value *Elt = Builder.CreateExtractElement(Val, matIdx);
|
|
|
- Builder.CreateStore(Elt, GEP);
|
|
|
- }
|
|
|
+ case HLBinaryOpcode::Rem:
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFRem(LoweredLhs, LoweredRhs)
|
|
|
+ : Builder.CreateSRem(LoweredLhs, LoweredRhs);
|
|
|
|
|
|
- matLdStInst->eraseFromParent();
|
|
|
- } break;
|
|
|
- }
|
|
|
-}
|
|
|
-void HLMatrixLowerPass::TranslateMatSubscriptOnGlobalPtr(
|
|
|
- CallInst *matSubInst, Value *vecPtr) {
|
|
|
- Value *basePtr =
|
|
|
- matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
|
|
|
- Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
|
|
|
- IRBuilder<> subBuilder(matSubInst);
|
|
|
- Value *zeroIdx = subBuilder.getInt32(0);
|
|
|
-
|
|
|
- HLSubscriptOpcode opcode =
|
|
|
- static_cast<HLSubscriptOpcode>(GetHLOpcode(matSubInst));
|
|
|
-
|
|
|
- Type *matTy = basePtr->getType()->getPointerElementType();
|
|
|
- unsigned col, row;
|
|
|
- HLMatrixLower::GetMatrixInfo(matTy, col, row);
|
|
|
-
|
|
|
- std::vector<Value *> idxList;
|
|
|
- switch (opcode) {
|
|
|
- case HLSubscriptOpcode::ColMatSubscript:
|
|
|
- case HLSubscriptOpcode::RowMatSubscript: {
|
|
|
- // Just use index created in EmitHLSLMatrixSubscript.
|
|
|
- for (unsigned c = 0; c < col; c++) {
|
|
|
- Value *matIdx =
|
|
|
- matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + c);
|
|
|
- idxList.emplace_back(matIdx);
|
|
|
- }
|
|
|
- } break;
|
|
|
- case HLSubscriptOpcode::RowMatElement:
|
|
|
- case HLSubscriptOpcode::ColMatElement: {
|
|
|
- Type *resultType = matSubInst->getType()->getPointerElementType();
|
|
|
- unsigned resultSize = 1;
|
|
|
- if (resultType->isVectorTy())
|
|
|
- resultSize = resultType->getVectorNumElements();
|
|
|
- // Just use index created in EmitHLSLMatrixElement.
|
|
|
- Constant *EltIdxs = cast<Constant>(idx);
|
|
|
- for (unsigned i = 0; i < resultSize; i++) {
|
|
|
- Value *matIdx = EltIdxs->getAggregateElement(i);
|
|
|
- idxList.emplace_back(matIdx);
|
|
|
- }
|
|
|
- } break;
|
|
|
- default:
|
|
|
- DXASSERT(0, "invalid operation");
|
|
|
- break;
|
|
|
- }
|
|
|
+ case HLBinaryOpcode::And:
|
|
|
+ return Builder.CreateAnd(LoweredLhs, LoweredRhs);
|
|
|
|
|
|
- // Cannot generate vector pointer
|
|
|
- // Replace all uses with scalar pointers.
|
|
|
- if (!matSubInst->getType()->getPointerElementType()->isVectorTy()) {
|
|
|
- DXASSERT(idxList.size() == 1, "Expected a single matrix element index if the result is not a vector");
|
|
|
- Value *Ptr =
|
|
|
- subBuilder.CreateInBoundsGEP(vecPtr, { zeroIdx, idxList[0] });
|
|
|
- matSubInst->replaceAllUsesWith(Ptr);
|
|
|
- } else {
|
|
|
- // Split the use of CI with Ptrs.
|
|
|
- for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
|
|
|
- Instruction *subsUser = cast<Instruction>(*(U++));
|
|
|
- IRBuilder<> userBuilder(subsUser);
|
|
|
- if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
|
|
|
- Value *IndexPtr =
|
|
|
- HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
|
|
|
- Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
|
|
|
- {zeroIdx, IndexPtr});
|
|
|
- for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
|
|
|
- Instruction *gepUser = cast<Instruction>(*(gepU++));
|
|
|
- IRBuilder<> gepUserBuilder(gepUser);
|
|
|
- if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
|
|
|
- Value *subData = stUser->getValueOperand();
|
|
|
- gepUserBuilder.CreateStore(subData, Ptr);
|
|
|
- stUser->eraseFromParent();
|
|
|
- } else if (LoadInst *ldUser = dyn_cast<LoadInst>(gepUser)) {
|
|
|
- Value *subData = gepUserBuilder.CreateLoad(Ptr);
|
|
|
- ldUser->replaceAllUsesWith(subData);
|
|
|
- ldUser->eraseFromParent();
|
|
|
- } else {
|
|
|
- AddrSpaceCastInst *Cast = cast<AddrSpaceCastInst>(gepUser);
|
|
|
- Cast->setOperand(0, Ptr);
|
|
|
- }
|
|
|
- }
|
|
|
- GEP->eraseFromParent();
|
|
|
- } else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
|
|
|
- Value *val = stUser->getValueOperand();
|
|
|
- for (unsigned i = 0; i < idxList.size(); i++) {
|
|
|
- Value *Elt = userBuilder.CreateExtractElement(val, i);
|
|
|
- Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
|
|
|
- {zeroIdx, idxList[i]});
|
|
|
- userBuilder.CreateStore(Elt, Ptr);
|
|
|
- }
|
|
|
- stUser->eraseFromParent();
|
|
|
- } else {
|
|
|
+ case HLBinaryOpcode::Or:
|
|
|
+ return Builder.CreateOr(LoweredLhs, LoweredRhs);
|
|
|
|
|
|
- Value *ldVal =
|
|
|
- UndefValue::get(matSubInst->getType()->getPointerElementType());
|
|
|
- for (unsigned i = 0; i < idxList.size(); i++) {
|
|
|
- Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
|
|
|
- {zeroIdx, idxList[i]});
|
|
|
- Value *Elt = userBuilder.CreateLoad(Ptr);
|
|
|
- ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
|
|
|
- }
|
|
|
- // Must be load here.
|
|
|
- LoadInst *ldUser = cast<LoadInst>(subsUser);
|
|
|
- ldUser->replaceAllUsesWith(ldVal);
|
|
|
- ldUser->eraseFromParent();
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- matSubInst->eraseFromParent();
|
|
|
-}
|
|
|
+ case HLBinaryOpcode::Xor:
|
|
|
+ return Builder.CreateXor(LoweredLhs, LoweredRhs);
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobalPtr(
|
|
|
- CallInst *matLdStInst, Value *vecPtr) {
|
|
|
- // Just translate into vector here.
|
|
|
- // DynamicIndexingVectorToArray will change it to scalar array.
|
|
|
- IRBuilder<> Builder(matLdStInst);
|
|
|
- unsigned opcode = hlsl::GetHLOpcode(matLdStInst);
|
|
|
- HLMatLoadStoreOpcode matLdStOp = static_cast<HLMatLoadStoreOpcode>(opcode);
|
|
|
- switch (matLdStOp) {
|
|
|
- case HLMatLoadStoreOpcode::ColMatLoad:
|
|
|
- case HLMatLoadStoreOpcode::RowMatLoad: {
|
|
|
- // Load as vector.
|
|
|
- Value *newLoad = Builder.CreateLoad(vecPtr);
|
|
|
+ case HLBinaryOpcode::Shl:
|
|
|
+ return Builder.CreateShl(LoweredLhs, LoweredRhs);
|
|
|
|
|
|
- matLdStInst->replaceAllUsesWith(newLoad);
|
|
|
- matLdStInst->eraseFromParent();
|
|
|
- } break;
|
|
|
- case HLMatLoadStoreOpcode::ColMatStore:
|
|
|
- case HLMatLoadStoreOpcode::RowMatStore: {
|
|
|
- // Change value to vector array, then store.
|
|
|
- Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
|
|
|
-
|
|
|
- Value *vecArrayGep = vecPtr;
|
|
|
- Builder.CreateStore(Val, vecArrayGep);
|
|
|
- matLdStInst->eraseFromParent();
|
|
|
- } break;
|
|
|
- default:
|
|
|
- DXASSERT(0, "invalid operation");
|
|
|
- break;
|
|
|
- }
|
|
|
-}
|
|
|
+ case HLBinaryOpcode::Shr:
|
|
|
+ return Builder.CreateAShr(LoweredLhs, LoweredRhs);
|
|
|
|
|
|
-// Flatten values inside init list to scalar elements.
|
|
|
-static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
|
|
|
- Value *val,
|
|
|
- DenseMap<Value *, Value *> &matToVecMap,
|
|
|
- IRBuilder<> &Builder) {
|
|
|
- Type *valTy = val->getType();
|
|
|
-
|
|
|
- if (valTy->isPointerTy()) {
|
|
|
- if (HLMatrixLower::IsMatrixArrayPointer(valTy)) {
|
|
|
- if (matToVecMap.count(cast<Instruction>(val))) {
|
|
|
- val = matToVecMap[cast<Instruction>(val)];
|
|
|
- } else {
|
|
|
- // Convert to vec array with bitcast.
|
|
|
- Type *vecArrayPtrTy = HLMatrixLower::LowerMatrixArrayPointer(valTy);
|
|
|
- val = Builder.CreateBitCast(val, vecArrayPtrTy);
|
|
|
- }
|
|
|
- }
|
|
|
- Type *valEltTy = val->getType()->getPointerElementType();
|
|
|
- if (valEltTy->isVectorTy() || dxilutil::IsHLSLMatrixType(valEltTy) ||
|
|
|
- valEltTy->isSingleValueType()) {
|
|
|
- Value *ldVal = Builder.CreateLoad(val);
|
|
|
- IterateInitList(elts, idx, ldVal, matToVecMap, Builder);
|
|
|
- } else {
|
|
|
- Type *i32Ty = Type::getInt32Ty(valTy->getContext());
|
|
|
- Value *zero = ConstantInt::get(i32Ty, 0);
|
|
|
- if (ArrayType *AT = dyn_cast<ArrayType>(valEltTy)) {
|
|
|
- for (unsigned i = 0; i < AT->getArrayNumElements(); i++) {
|
|
|
- Value *gepIdx = ConstantInt::get(i32Ty, i);
|
|
|
- Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
|
|
|
- IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
|
|
|
- }
|
|
|
- } else {
|
|
|
- // Struct.
|
|
|
- StructType *ST = cast<StructType>(valEltTy);
|
|
|
- for (unsigned i = 0; i < ST->getNumElements(); i++) {
|
|
|
- Value *gepIdx = ConstantInt::get(i32Ty, i);
|
|
|
- Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
|
|
|
- IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- } else if (dxilutil::IsHLSLMatrixType(valTy)) {
|
|
|
- unsigned col, row;
|
|
|
- HLMatrixLower::GetMatrixInfo(valTy, col, row);
|
|
|
- unsigned matSize = col * row;
|
|
|
- val = matToVecMap[cast<Instruction>(val)];
|
|
|
- // temp matrix all row major
|
|
|
- for (unsigned i = 0; i < matSize; i++) {
|
|
|
- Value *Elt = Builder.CreateExtractElement(val, i);
|
|
|
- elts[idx + i] = Elt;
|
|
|
- }
|
|
|
- idx += matSize;
|
|
|
- } else {
|
|
|
- if (valTy->isVectorTy()) {
|
|
|
- unsigned vecSize = valTy->getVectorNumElements();
|
|
|
- for (unsigned i = 0; i < vecSize; i++) {
|
|
|
- Value *Elt = Builder.CreateExtractElement(val, i);
|
|
|
- elts[idx + i] = Elt;
|
|
|
- }
|
|
|
- idx += vecSize;
|
|
|
- } else {
|
|
|
- DXASSERT(valTy->isSingleValueType(), "must be single value type here");
|
|
|
- elts[idx++] = val;
|
|
|
- }
|
|
|
- }
|
|
|
-}
|
|
|
+ case HLBinaryOpcode::LT:
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFCmp(CmpInst::FCMP_OLT, LoweredLhs, LoweredRhs)
|
|
|
+ : Builder.CreateICmp(CmpInst::ICMP_SLT, LoweredLhs, LoweredRhs);
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
|
|
|
- // Array matrix init will be translated in TranslateMatArrayInitReplace.
|
|
|
- if (matInitInst->getType()->isVoidTy())
|
|
|
- return;
|
|
|
+ case HLBinaryOpcode::GT:
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFCmp(CmpInst::FCMP_OGT, LoweredLhs, LoweredRhs)
|
|
|
+ : Builder.CreateICmp(CmpInst::ICMP_SGT, LoweredLhs, LoweredRhs);
|
|
|
|
|
|
- DXASSERT(matToVecMap.count(matInitInst), "must have vec version");
|
|
|
- Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
|
|
|
- IRBuilder<> Builder(vecUseInst);
|
|
|
- unsigned col, row;
|
|
|
- Type *EltTy = GetMatrixInfo(matInitInst->getType(), col, row);
|
|
|
-
|
|
|
- Type *vecTy = VectorType::get(EltTy, col * row);
|
|
|
- unsigned vecSize = vecTy->getVectorNumElements();
|
|
|
- unsigned idx = 0;
|
|
|
- std::vector<Value *> elts(vecSize);
|
|
|
- // Skip opcode arg.
|
|
|
- for (unsigned i = 1; i < matInitInst->getNumArgOperands(); i++) {
|
|
|
- Value *val = matInitInst->getArgOperand(i);
|
|
|
-
|
|
|
- IterateInitList(elts, idx, val, matToVecMap, Builder);
|
|
|
- }
|
|
|
+ case HLBinaryOpcode::LE:
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFCmp(CmpInst::FCMP_OLE, LoweredLhs, LoweredRhs)
|
|
|
+ : Builder.CreateICmp(CmpInst::ICMP_SLE, LoweredLhs, LoweredRhs);
|
|
|
|
|
|
- Value *newInit = UndefValue::get(vecTy);
|
|
|
- // InitList is row major, the result is row major too.
|
|
|
- for (unsigned i=0;i< col * row;i++) {
|
|
|
- Constant *vecIdx = Builder.getInt32(i);
|
|
|
- newInit = InsertElementInst::Create(newInit, elts[i], vecIdx);
|
|
|
- Builder.Insert(cast<Instruction>(newInit));
|
|
|
- }
|
|
|
+ case HLBinaryOpcode::GE:
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFCmp(CmpInst::FCMP_OGE, LoweredLhs, LoweredRhs)
|
|
|
+ : Builder.CreateICmp(CmpInst::ICMP_SGE, LoweredLhs, LoweredRhs);
|
|
|
|
|
|
- // Replace matInit function call with matInitInst.
|
|
|
- vecUseInst->replaceAllUsesWith(newInit);
|
|
|
- AddToDeadInsts(vecUseInst);
|
|
|
- matToVecMap[matInitInst] = newInit;
|
|
|
-}
|
|
|
+ case HLBinaryOpcode::EQ:
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFCmp(CmpInst::FCMP_OEQ, LoweredLhs, LoweredRhs)
|
|
|
+ : Builder.CreateICmp(CmpInst::ICMP_EQ, LoweredLhs, LoweredRhs);
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatSelect(CallInst *matSelectInst) {
|
|
|
- unsigned col, row;
|
|
|
- Type *EltTy = GetMatrixInfo(matSelectInst->getType(), col, row);
|
|
|
+ case HLBinaryOpcode::NE:
|
|
|
+ return IsFloat
|
|
|
+ ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredLhs, LoweredRhs)
|
|
|
+ : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredLhs, LoweredRhs);
|
|
|
|
|
|
- Type *vecTy = VectorType::get(EltTy, col * row);
|
|
|
- unsigned vecSize = vecTy->getVectorNumElements();
|
|
|
+ case HLBinaryOpcode::UDiv:
|
|
|
+ return Builder.CreateUDiv(LoweredLhs, LoweredRhs);
|
|
|
|
|
|
- CallInst *vecUseInst = cast<CallInst>(matToVecMap[matSelectInst]);
|
|
|
- Instruction *LHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx));
|
|
|
- Instruction *RHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx));
|
|
|
+ case HLBinaryOpcode::URem:
|
|
|
+ return Builder.CreateURem(LoweredLhs, LoweredRhs);
|
|
|
|
|
|
- IRBuilder<> Builder(vecUseInst);
|
|
|
+ case HLBinaryOpcode::UShr:
|
|
|
+ return Builder.CreateLShr(LoweredLhs, LoweredRhs);
|
|
|
|
|
|
- Value *Cond = vecUseInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
|
|
|
- bool isVecCond = Cond->getType()->isVectorTy();
|
|
|
- if (isVecCond) {
|
|
|
- Instruction *MatCond = cast<Instruction>(
|
|
|
- matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx));
|
|
|
- DXASSERT_NOMSG(matToVecMap.count(MatCond));
|
|
|
- Cond = matToVecMap[MatCond];
|
|
|
+ case HLBinaryOpcode::ULT:
|
|
|
+ return Builder.CreateICmp(CmpInst::ICMP_ULT, LoweredLhs, LoweredRhs);
|
|
|
+
|
|
|
+ case HLBinaryOpcode::UGT:
|
|
|
+ return Builder.CreateICmp(CmpInst::ICMP_UGT, LoweredLhs, LoweredRhs);
|
|
|
+
|
|
|
+ case HLBinaryOpcode::ULE:
|
|
|
+ return Builder.CreateICmp(CmpInst::ICMP_ULE, LoweredLhs, LoweredRhs);
|
|
|
+
|
|
|
+ case HLBinaryOpcode::UGE:
|
|
|
+ return Builder.CreateICmp(CmpInst::ICMP_UGE, LoweredLhs, LoweredRhs);
|
|
|
+
|
|
|
+ case HLBinaryOpcode::LAnd:
|
|
|
+ case HLBinaryOpcode::LOr: {
|
|
|
+ Value *Zero = Constant::getNullValue(LoweredLhs->getType());
|
|
|
+ Value *LhsCmp = IsFloat
|
|
|
+ ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredLhs, Zero)
|
|
|
+ : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredLhs, Zero);
|
|
|
+ Value *RhsCmp = IsFloat
|
|
|
+ ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredRhs, Zero)
|
|
|
+ : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredRhs, Zero);
|
|
|
+ return Opcode == HLBinaryOpcode::LOr
|
|
|
+ ? Builder.CreateOr(LhsCmp, RhsCmp)
|
|
|
+ : Builder.CreateAnd(LhsCmp, RhsCmp);
|
|
|
}
|
|
|
- DXASSERT_NOMSG(matToVecMap.count(LHS));
|
|
|
- Value *VLHS = matToVecMap[LHS];
|
|
|
- DXASSERT_NOMSG(matToVecMap.count(RHS));
|
|
|
- Value *VRHS = matToVecMap[RHS];
|
|
|
-
|
|
|
- Value *VecSelect = UndefValue::get(vecTy);
|
|
|
- for (unsigned i = 0; i < vecSize; i++) {
|
|
|
- llvm::Value *EltCond = Cond;
|
|
|
- if (isVecCond)
|
|
|
- EltCond = Builder.CreateExtractElement(Cond, i);
|
|
|
- llvm::Value *EltL = Builder.CreateExtractElement(VLHS, i);
|
|
|
- llvm::Value *EltR = Builder.CreateExtractElement(VRHS, i);
|
|
|
- llvm::Value *EltSelect = Builder.CreateSelect(EltCond, EltL, EltR);
|
|
|
- VecSelect = Builder.CreateInsertElement(VecSelect, EltSelect, i);
|
|
|
+ default:
|
|
|
+ llvm_unreachable("Unsupported binary matrix operator");
|
|
|
}
|
|
|
- AddToDeadInsts(vecUseInst);
|
|
|
- vecUseInst->replaceAllUsesWith(VecSelect);
|
|
|
- matToVecMap[matSelectInst] = VecSelect;
|
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
|
|
|
- Value *vecVal,
|
|
|
- GetElementPtrInst *matGEP) {
|
|
|
- SmallVector<Value *, 4> idxList(matGEP->idx_begin(), matGEP->idx_end());
|
|
|
-
|
|
|
- IRBuilder<> GEPBuilder(matGEP);
|
|
|
- Value *newGEP = GEPBuilder.CreateInBoundsGEP(vecVal, idxList);
|
|
|
- // Only used by mat subscript and mat ld/st.
|
|
|
- for (Value::user_iterator user = matGEP->user_begin();
|
|
|
- user != matGEP->user_end();) {
|
|
|
- Instruction *useInst = cast<Instruction>(*(user++));
|
|
|
- IRBuilder<> Builder(useInst);
|
|
|
- // Skip return here.
|
|
|
- if (isa<ReturnInst>(useInst))
|
|
|
- continue;
|
|
|
- if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
|
|
|
- // Function call.
|
|
|
- hlsl::HLOpcodeGroup group =
|
|
|
- hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
|
|
|
- switch (group) {
|
|
|
- case HLOpcodeGroup::HLMatLoadStore: {
|
|
|
- unsigned opcode = GetHLOpcode(useCall);
|
|
|
- HLMatLoadStoreOpcode matOpcode =
|
|
|
- static_cast<HLMatLoadStoreOpcode>(opcode);
|
|
|
- switch (matOpcode) {
|
|
|
- case HLMatLoadStoreOpcode::ColMatLoad:
|
|
|
- case HLMatLoadStoreOpcode::RowMatLoad: {
|
|
|
- // Skip the vector version.
|
|
|
- if (useCall->getType()->isVectorTy())
|
|
|
- continue;
|
|
|
- Type *matTy = useCall->getType();
|
|
|
- Value *newLd = CreateVecMatrixLoad(newGEP, matTy, Builder);
|
|
|
- DXASSERT(matToVecMap.count(useCall), "must have vec version");
|
|
|
- Value *oldLd = matToVecMap[useCall];
|
|
|
- // Delete the oldLd.
|
|
|
- AddToDeadInsts(cast<Instruction>(oldLd));
|
|
|
- oldLd->replaceAllUsesWith(newLd);
|
|
|
- matToVecMap[useCall] = newLd;
|
|
|
- } break;
|
|
|
- case HLMatLoadStoreOpcode::ColMatStore:
|
|
|
- case HLMatLoadStoreOpcode::RowMatStore: {
|
|
|
- Value *vecPtr = newGEP;
|
|
|
-
|
|
|
- Value *matVal = useCall->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
|
|
|
- // Skip the vector version.
|
|
|
- if (matVal->getType()->isVectorTy()) {
|
|
|
- AddToDeadInsts(useCall);
|
|
|
- continue;
|
|
|
- }
|
|
|
+Value *HLMatrixLowerPass::lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode) {
|
|
|
+ IRBuilder<> Builder(Call);
|
|
|
+ switch (Opcode) {
|
|
|
+ case HLMatLoadStoreOpcode::RowMatLoad:
|
|
|
+ case HLMatLoadStoreOpcode::ColMatLoad:
|
|
|
+ return lowerHLLoad(Call->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx),
|
|
|
+ /* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatLoad, Builder);
|
|
|
|
|
|
- Instruction *matInst = cast<Instruction>(matVal);
|
|
|
+ case HLMatLoadStoreOpcode::RowMatStore:
|
|
|
+ case HLMatLoadStoreOpcode::ColMatStore:
|
|
|
+ return lowerHLStore(
|
|
|
+ Call->getArgOperand(HLOperandIndex::kMatStoreValOpIdx),
|
|
|
+ Call->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx),
|
|
|
+ /* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatStore,
|
|
|
+ /* Return */ !Call->getType()->isVoidTy(), Builder);
|
|
|
|
|
|
- DXASSERT(matToVecMap.count(matInst), "must have vec version");
|
|
|
- Value *vecVal = matToVecMap[matInst];
|
|
|
- CreateVecMatrixStore(vecVal, vecPtr, matVal->getType(), Builder);
|
|
|
- } break;
|
|
|
- }
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLSubscript: {
|
|
|
- TranslateMatSubscript(matGEP, newGEP, useCall);
|
|
|
- } break;
|
|
|
- default:
|
|
|
- DXASSERT(0, "invalid operation");
|
|
|
- break;
|
|
|
- }
|
|
|
- } else if (dyn_cast<BitCastInst>(useInst)) {
|
|
|
- // Just replace the src with vec version.
|
|
|
- useInst->setOperand(0, newGEP);
|
|
|
- } else {
|
|
|
- // Must be GEP.
|
|
|
- GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
|
|
|
- TranslateMatArrayGEP(matGEP, cast<Instruction>(newGEP), GEP);
|
|
|
- }
|
|
|
+ default:
|
|
|
+ llvm_unreachable("Unsupported matrix load/store operation");
|
|
|
}
|
|
|
- AddToDeadInsts(matGEP);
|
|
|
}
|
|
|
|
|
|
-Value *HLMatrixLowerPass::GetMatrixForVec(Value *vecVal, Type *matTy) {
|
|
|
- Value *newMatVal = nullptr;
|
|
|
- if (vecToMatMap.count(vecVal)) {
|
|
|
- newMatVal = vecToMatMap[vecVal];
|
|
|
- } else {
|
|
|
- // create conversion instructions if necessary, caching result for subsequent replacements.
|
|
|
- // do so right after the vecVal def so it's available to all potential uses.
|
|
|
- newMatVal = BitCastValueOrPtr(vecVal,
|
|
|
- cast<Instruction>(vecVal)->getNextNode(), // vecVal must be instruction
|
|
|
- matTy,
|
|
|
- /*bOrigAllocaTy*/true,
|
|
|
- vecVal->getName());
|
|
|
- vecToMatMap[vecVal] = newMatVal;
|
|
|
- }
|
|
|
- return newMatVal;
|
|
|
-}
|
|
|
+Value *HLMatrixLowerPass::lowerHLLoad(Value *MatPtr, bool RowMajor, IRBuilder<> &Builder) {
|
|
|
+ HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
|
|
|
|
|
|
-void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
|
- Value *vecVal) {
|
|
|
- for (Value::user_iterator user = matVal->user_begin();
|
|
|
- user != matVal->user_end();) {
|
|
|
- Instruction *useInst = cast<Instruction>(*(user++));
|
|
|
- // User must be function call.
|
|
|
- if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
|
|
|
- hlsl::HLOpcodeGroup group =
|
|
|
- hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
|
|
|
- switch (group) {
|
|
|
- case HLOpcodeGroup::HLIntrinsic: {
|
|
|
- if (CallInst *matCI = dyn_cast<CallInst>(matVal)) {
|
|
|
- MatIntrinsicReplace(matCI, vecVal, useCall);
|
|
|
- } else {
|
|
|
- IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(useCall));
|
|
|
- if (opcode == IntrinsicOp::MOP_Append) {
|
|
|
- // Replace matrix with vector representation and update intrinsic signature
|
|
|
- // We don't care about matrix orientation here, since that will need to be
|
|
|
- // taken into account anyways when generating the store output calls.
|
|
|
- SmallVector<Value *, 4> flatArgs;
|
|
|
- SmallVector<Type *, 4> flatParamTys;
|
|
|
- for (Value *arg : useCall->arg_operands()) {
|
|
|
- Value *flagArg = arg == matVal ? vecVal : arg;
|
|
|
- flatArgs.emplace_back(arg == matVal ? vecVal : arg);
|
|
|
- flatParamTys.emplace_back(flagArg->getType());
|
|
|
- }
|
|
|
-
|
|
|
- // Don't need flat return type for Append.
|
|
|
- FunctionType *flatFuncTy =
|
|
|
- FunctionType::get(useInst->getType(), flatParamTys, false);
|
|
|
- Function *flatF = GetOrCreateHLFunction(*m_pModule, flatFuncTy, group, static_cast<unsigned int>(opcode));
|
|
|
-
|
|
|
- // Append returns void, so the old call should have no users
|
|
|
- DXASSERT(useInst->getType()->isVoidTy(), "Unexpected MOP_Append intrinsic return type");
|
|
|
- DXASSERT(useInst->use_empty(), "Unexpected users of MOP_Append intrinsic return value");
|
|
|
- IRBuilder<> Builder(useCall);
|
|
|
- Builder.CreateCall(flatF, flatArgs);
|
|
|
- AddToDeadInsts(useCall);
|
|
|
- }
|
|
|
- else if (opcode == IntrinsicOp::IOP_frexp) {
|
|
|
- // NOTE: because out param use copy out semantic, so the operand of
|
|
|
- // out must be temp alloca.
|
|
|
- DXASSERT(isa<AllocaInst>(matVal), "else invalid mat ptr for frexp");
|
|
|
- auto it = matToVecMap.find(useCall);
|
|
|
- DXASSERT(it != matToVecMap.end(),
|
|
|
- "else fail to create vec version of useCall");
|
|
|
- CallInst *vecUseInst = cast<CallInst>(it->second);
|
|
|
-
|
|
|
- for (unsigned i = 0; i < vecUseInst->getNumArgOperands(); i++) {
|
|
|
- if (useCall->getArgOperand(i) == matVal) {
|
|
|
- vecUseInst->setArgOperand(i, vecVal);
|
|
|
- }
|
|
|
- }
|
|
|
- } else {
|
|
|
- DXASSERT(false, "Unexpected matrix user intrinsic.");
|
|
|
- }
|
|
|
- }
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLSelect: {
|
|
|
- MatIntrinsicReplace(matVal, vecVal, useCall);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLBinOp: {
|
|
|
- TrivialMatBinOpReplace(matVal, vecVal, useCall);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLUnOp: {
|
|
|
- TrivialMatUnOpReplace(matVal, vecVal, useCall);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLCast: {
|
|
|
- TranslateMatCast(matVal, vecVal, useCall);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLMatLoadStore: {
|
|
|
- DXASSERT(matToVecMap.count(useCall), "must have vec version");
|
|
|
- Value *vecUser = matToVecMap[useCall];
|
|
|
- if (isa<AllocaInst>(matVal) || GetIfMatrixGEPOfUDTAlloca(matVal) ||
|
|
|
- GetIfMatrixGEPOfUDTArg(matVal, *m_pHLModule)) {
|
|
|
- // Load Already translated in lowerToVec.
|
|
|
- // Store val operand will be set by the val use.
|
|
|
- // Do nothing here.
|
|
|
- } else if (StoreInst *stInst = dyn_cast<StoreInst>(vecUser)) {
|
|
|
- DXASSERT(vecVal->getType() == stInst->getValueOperand()->getType(),
|
|
|
- "Mismatched vector matrix store value types.");
|
|
|
- stInst->setOperand(0, vecVal);
|
|
|
- } else if (ZExtInst *zextInst = dyn_cast<ZExtInst>(vecUser)) {
|
|
|
- // This happens when storing bool matrices,
|
|
|
- // which must first undergo conversion from i1's to i32's.
|
|
|
- DXASSERT(vecVal->getType() == zextInst->getOperand(0)->getType(),
|
|
|
- "Mismatched vector matrix store value types.");
|
|
|
- zextInst->setOperand(0, vecVal);
|
|
|
- } else
|
|
|
- TrivialMatReplace(matVal, vecVal, useCall);
|
|
|
-
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLSubscript: {
|
|
|
- if (AllocaInst *AI = dyn_cast<AllocaInst>(vecVal))
|
|
|
- TranslateMatSubscript(matVal, vecVal, useCall);
|
|
|
- else if (BitCastInst *BCI = dyn_cast<BitCastInst>(vecVal))
|
|
|
- TranslateMatSubscript(matVal, vecVal, useCall);
|
|
|
- else
|
|
|
- TrivialMatReplace(matVal, vecVal, useCall);
|
|
|
-
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLInit: {
|
|
|
- DXASSERT(!isa<AllocaInst>(matVal), "array of matrix init should lowered in StoreInitListToDestPtr at CGHLSLMS.cpp");
|
|
|
- TranslateMatInit(useCall);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::NotHL: {
|
|
|
- castMatrixArgs(matVal, vecVal, useCall);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLExtIntrinsic:
|
|
|
- case HLOpcodeGroup::HLCreateHandle:
|
|
|
- case HLOpcodeGroup::NumOfHLOps:
|
|
|
- // No vector equivalents for these ops.
|
|
|
- break;
|
|
|
- }
|
|
|
- } else if (dyn_cast<BitCastInst>(useInst)) {
|
|
|
- // Just replace the src with vec version.
|
|
|
- if (useInst != vecVal)
|
|
|
- useInst->setOperand(0, vecVal);
|
|
|
- } else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
|
|
|
- Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
|
|
|
- RI->setOperand(0, newMatVal);
|
|
|
- } else if (isa<StoreInst>(useInst)) {
|
|
|
- DXASSERT(vecToMatMap.count(vecVal) && vecToMatMap[vecVal] == matVal, "matrix store should only be used with preserved matrix values");
|
|
|
- } else {
|
|
|
- // Must be GEP on mat array alloca.
|
|
|
- GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
|
|
|
- AllocaInst *AI = cast<AllocaInst>(matVal);
|
|
|
- TranslateMatArrayGEP(AI, vecVal, GEP);
|
|
|
- }
|
|
|
+ Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder);
|
|
|
+ if (LoweredPtr == nullptr) {
|
|
|
+ // Can't lower this here, defer to HL signature lower
|
|
|
+ HLMatLoadStoreOpcode Opcode = RowMajor ? HLMatLoadStoreOpcode::RowMatLoad : HLMatLoadStoreOpcode::ColMatLoad;
|
|
|
+ return callHLFunction(
|
|
|
+ *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
|
|
|
+ MatTy.getLoweredVectorTypeForReg(), { Builder.getInt32((uint32_t)Opcode), MatPtr }, Builder);
|
|
|
}
|
|
|
-}
|
|
|
|
|
|
-void HLMatrixLowerPass::castMatrixArgs(Value *matVal, Value *vecVal, CallInst *CI) {
|
|
|
- // translate user function parameters as necessary
|
|
|
- Type *Ty = matVal->getType();
|
|
|
- if (Ty->isPointerTy()) {
|
|
|
- IRBuilder<> Builder(CI);
|
|
|
- Value *newMatVal = Builder.CreateBitCast(vecVal, Ty);
|
|
|
- CI->replaceUsesOfWith(matVal, newMatVal);
|
|
|
- } else {
|
|
|
- Value *newMatVal = GetMatrixForVec(vecVal, Ty);
|
|
|
- CI->replaceUsesOfWith(matVal, newMatVal);
|
|
|
- }
|
|
|
+ return MatTy.emitLoweredLoad(LoweredPtr, Builder);
|
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::finalMatTranslation(Value *matVal) {
|
|
|
- // Translate matInit.
|
|
|
- if (CallInst *CI = dyn_cast<CallInst>(matVal)) {
|
|
|
- hlsl::HLOpcodeGroup group =
|
|
|
- hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
- switch (group) {
|
|
|
- case HLOpcodeGroup::HLInit: {
|
|
|
- TranslateMatInit(CI);
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLSelect: {
|
|
|
- TranslateMatSelect(CI);
|
|
|
- } break;
|
|
|
- default:
|
|
|
- // Skip group already translated.
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
-}
|
|
|
+Value *HLMatrixLowerPass::lowerHLStore(Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder) {
|
|
|
+ DXASSERT(MatVal->getType() == MatPtr->getType()->getPointerElementType(),
|
|
|
+ "Matrix store value/pointer type mismatch.");
|
|
|
|
|
|
-void HLMatrixLowerPass::DeleteDeadInsts() {
|
|
|
- // Delete the matrix version insts.
|
|
|
- for (Instruction *deadInst : m_deadInsts) {
|
|
|
- // Replace with undef and remove it.
|
|
|
- deadInst->replaceAllUsesWith(UndefValue::get(deadInst->getType()));
|
|
|
- deadInst->eraseFromParent();
|
|
|
+ Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder);
|
|
|
+ Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
|
|
|
+ if (LoweredPtr == nullptr) {
|
|
|
+ // Can't lower the pointer here, defer to HL signature lower
|
|
|
+ HLMatLoadStoreOpcode Opcode = RowMajor ? HLMatLoadStoreOpcode::RowMatStore : HLMatLoadStoreOpcode::ColMatStore;
|
|
|
+ return callHLFunction(
|
|
|
+ *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
|
|
|
+ Return ? LoweredVal->getType() : Builder.getVoidTy(),
|
|
|
+ { Builder.getInt32((uint32_t)Opcode), MatPtr, LoweredVal }, Builder);
|
|
|
}
|
|
|
- m_deadInsts.clear();
|
|
|
- m_inDeadInstsSet.clear();
|
|
|
+
|
|
|
+ HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
|
|
|
+ StoreInst *LoweredStore = MatTy.emitLoweredStore(LoweredVal, LoweredPtr, Builder);
|
|
|
+
|
|
|
+ // If the intrinsic returned a value, return the stored lowered value
|
|
|
+ return Return ? LoweredVal : LoweredStore;
|
|
|
}
|
|
|
|
|
|
-static bool OnlyUsedByMatrixLdSt(Value *V) {
|
|
|
- bool onlyLdSt = true;
|
|
|
- for (User *user : V->users()) {
|
|
|
- if (isa<Constant>(user) && user->use_empty())
|
|
|
- continue;
|
|
|
+static Value *convertScalarOrVector(Value *SrcVal, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> Builder) {
|
|
|
+ DXASSERT(SrcVal->getType()->isVectorTy() == DstTy->isVectorTy(),
|
|
|
+ "Scalar/vector type mismatch in numerical conversion.");
|
|
|
+ Type *SrcTy = SrcVal->getType();
|
|
|
|
|
|
- CallInst *CI = cast<CallInst>(user);
|
|
|
- if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
|
|
|
- HLOpcodeGroup::HLMatLoadStore)
|
|
|
- continue;
|
|
|
+ // Conversions between equivalent types are no-ops,
|
|
|
+ // even between signed/unsigned variants.
|
|
|
+ if (SrcTy == DstTy) return SrcVal;
|
|
|
|
|
|
- onlyLdSt = false;
|
|
|
- break;
|
|
|
+ // Conversions to bools are comparisons
|
|
|
+ if (DstTy->getScalarSizeInBits() == 1) {
|
|
|
+ // fcmp une is what regular clang uses in C++ for (bool)f;
|
|
|
+ return cast<Instruction>(SrcTy->isIntOrIntVectorTy()
|
|
|
+ ? Builder.CreateICmpNE(SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool")
|
|
|
+ : Builder.CreateFCmpUNE(SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool"));
|
|
|
}
|
|
|
- return onlyLdSt;
|
|
|
-}
|
|
|
|
|
|
-static Constant *LowerMatrixArrayConst(Constant *MA, Type *ResultTy) {
|
|
|
- if (ArrayType *AT = dyn_cast<ArrayType>(ResultTy)) {
|
|
|
- std::vector<Constant *> Elts;
|
|
|
- Type *EltResultTy = AT->getElementType();
|
|
|
- for (unsigned i = 0; i < AT->getNumElements(); i++) {
|
|
|
- Constant *Elt =
|
|
|
- LowerMatrixArrayConst(MA->getAggregateElement(i), EltResultTy);
|
|
|
- Elts.emplace_back(Elt);
|
|
|
- }
|
|
|
- return ConstantArray::get(AT, Elts);
|
|
|
- } else {
|
|
|
- // Cast float[row][col] -> float< row * col>.
|
|
|
- // Get float[row][col] from the struct.
|
|
|
- Constant *rows = MA->getAggregateElement((unsigned)0);
|
|
|
- ArrayType *RowAT = cast<ArrayType>(rows->getType());
|
|
|
- std::vector<Constant *> Elts;
|
|
|
- for (unsigned r=0;r<RowAT->getArrayNumElements();r++) {
|
|
|
- Constant *row = rows->getAggregateElement(r);
|
|
|
- VectorType *VT = cast<VectorType>(row->getType());
|
|
|
- for (unsigned c = 0; c < VT->getVectorNumElements(); c++) {
|
|
|
- Elts.emplace_back(row->getAggregateElement(c));
|
|
|
- }
|
|
|
+ // Cast necessary
|
|
|
+ bool SrcIsUnsigned = Opcode == HLCastOpcode::FromUnsignedCast ||
|
|
|
+ Opcode == HLCastOpcode::UnsignedUnsignedCast;
|
|
|
+ bool DstIsUnsigned = Opcode == HLCastOpcode::ToUnsignedCast ||
|
|
|
+ Opcode == HLCastOpcode::UnsignedUnsignedCast;
|
|
|
+ auto CastOp = static_cast<Instruction::CastOps>(HLModule::GetNumericCastOp(
|
|
|
+ SrcTy, SrcIsUnsigned, DstTy, DstIsUnsigned));
|
|
|
+ return cast<Instruction>(Builder.CreateCast(CastOp, SrcVal, DstTy));
|
|
|
+}
|
|
|
+
|
|
|
+Value *HLMatrixLowerPass::lowerHLCast(Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder) {
|
|
|
+ // The opcode really doesn't mean much here, the types involved are what drive most of the casting.
|
|
|
+ DXASSERT(Opcode != HLCastOpcode::HandleToResCast, "Unexpected matrix cast opcode.");
|
|
|
+
|
|
|
+ if (dxilutil::IsIntegerOrFloatingPointType(Src->getType())) {
|
|
|
+ // Scalar to matrix splat
|
|
|
+ HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
|
|
|
+
|
|
|
+ // Apply element conversion
|
|
|
+ Value *Result = convertScalarOrVector(Src,
|
|
|
+ MatDstTy.getElementTypeForReg(), Opcode, Builder);
|
|
|
+
|
|
|
+ // Splat to a vector
|
|
|
+ Result = Builder.CreateInsertElement(
|
|
|
+ UndefValue::get(VectorType::get(Result->getType(), 1)),
|
|
|
+ Result, static_cast<uint64_t>(0));
|
|
|
+ return Builder.CreateShuffleVector(Result, Result,
|
|
|
+ ConstantVector::getSplat(MatDstTy.getNumElements(), Builder.getInt32(0)));
|
|
|
+ }
|
|
|
+ else if (VectorType *SrcVecTy = dyn_cast<VectorType>(Src->getType())) {
|
|
|
+ // Vector to matrix
|
|
|
+ HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
|
|
|
+ Value *Result = Src;
|
|
|
+
|
|
|
+ // We might need to truncate
|
|
|
+ if (MatDstTy.getNumElements() < SrcVecTy->getNumElements()) {
|
|
|
+ SmallVector<int, 4> ShuffleIndices;
|
|
|
+ for (unsigned Idx = 0; Idx < MatDstTy.getNumElements(); ++Idx)
|
|
|
+ ShuffleIndices.emplace_back(static_cast<int>(Idx));
|
|
|
+ Result = Builder.CreateShuffleVector(Src, Src, ShuffleIndices);
|
|
|
}
|
|
|
- return ConstantVector::get(Elts);
|
|
|
- }
|
|
|
-}
|
|
|
|
|
|
-void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
|
|
|
- // Lower to array of vector array like float[row * col].
|
|
|
- // It's follow the major of decl.
|
|
|
- // DynamicIndexingVectorToArray will change it to scalar array.
|
|
|
- Type *Ty = GV->getType()->getPointerElementType();
|
|
|
- std::vector<unsigned> arraySizeList;
|
|
|
- while (Ty->isArrayTy()) {
|
|
|
- arraySizeList.push_back(Ty->getArrayNumElements());
|
|
|
- Ty = Ty->getArrayElementType();
|
|
|
- }
|
|
|
- unsigned row, col;
|
|
|
- Type *EltTy = GetMatrixInfo(Ty, col, row);
|
|
|
- Ty = VectorType::get(EltTy, col * row);
|
|
|
-
|
|
|
- for (auto arraySize = arraySizeList.rbegin();
|
|
|
- arraySize != arraySizeList.rend(); arraySize++)
|
|
|
- Ty = ArrayType::get(Ty, *arraySize);
|
|
|
-
|
|
|
- Type *VecArrayTy = Ty;
|
|
|
- Constant *InitVal = nullptr;
|
|
|
- if (GV->hasInitializer()) {
|
|
|
- Constant *OldInitVal = GV->getInitializer();
|
|
|
- InitVal = isa<UndefValue>(OldInitVal)
|
|
|
- ? UndefValue::get(VecArrayTy)
|
|
|
- : LowerMatrixArrayConst(OldInitVal, cast<ArrayType>(VecArrayTy));
|
|
|
+ // Apply element conversion
|
|
|
+ return convertScalarOrVector(Result,
|
|
|
+ MatDstTy.getLoweredVectorTypeForReg(), Opcode, Builder);
|
|
|
}
|
|
|
|
|
|
- bool isConst = GV->isConstant();
|
|
|
- GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
|
|
|
- unsigned AddressSpace = GV->getType()->getAddressSpace();
|
|
|
- GlobalValue::LinkageTypes linkage = GV->getLinkage();
|
|
|
+ // Source must now be a matrix
|
|
|
+ HLMatrixType MatSrcTy = HLMatrixType::cast(Src->getType());
|
|
|
+ VectorType* LoweredSrcTy = MatSrcTy.getLoweredVectorTypeForReg();
|
|
|
|
|
|
- Module *M = GV->getParent();
|
|
|
- GlobalVariable *VecGV =
|
|
|
- new llvm::GlobalVariable(*M, VecArrayTy, /*IsConstant*/ isConst, linkage,
|
|
|
- /*InitVal*/ InitVal, GV->getName() + ".v",
|
|
|
- /*InsertBefore*/ nullptr, TLMode, AddressSpace);
|
|
|
- // Add debug info.
|
|
|
- if (m_HasDbgInfo) {
|
|
|
- DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
|
|
|
- HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, VecGV);
|
|
|
+ Value *LoweredSrc;
|
|
|
+ if (isa<Argument>(Src)) {
|
|
|
+ // Function arguments are lowered in HLSignatureLower.
|
|
|
+ // Initial codegen first generates those cast intrinsics to tell us how to lower them into vectors.
|
|
|
+ // Preserve them, but change the return type to vector.
|
|
|
+ DXASSERT(Opcode == HLCastOpcode::ColMatrixToVecCast || Opcode == HLCastOpcode::RowMatrixToVecCast,
|
|
|
+ "Unexpected cast of matrix argument.");
|
|
|
+ LoweredSrc = callHLFunction(*m_pModule, HLOpcodeGroup::HLCast, static_cast<unsigned>(Opcode),
|
|
|
+ LoweredSrcTy, { Builder.getInt32((uint32_t)Opcode), Src }, Builder);
|
|
|
}
|
|
|
-
|
|
|
- for (User *U : GV->users()) {
|
|
|
- Value *VecGEP = nullptr;
|
|
|
- // Must be GEP or GEPOperator.
|
|
|
- if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
|
|
|
- IRBuilder<> Builder(GEP);
|
|
|
- SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
|
|
|
- VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
|
|
|
- AddToDeadInsts(GEP);
|
|
|
- } else {
|
|
|
- GEPOperator *GEPOP = cast<GEPOperator>(U);
|
|
|
- IRBuilder<> Builder(GV->getContext());
|
|
|
- SmallVector<Value *, 4> idxList(GEPOP->idx_begin(), GEPOP->idx_end());
|
|
|
- VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
|
|
|
+ else {
|
|
|
+ LoweredSrc = getLoweredByValOperand(Src, Builder);
|
|
|
+ }
|
|
|
+ DXASSERT_NOMSG(LoweredSrc->getType() == LoweredSrcTy);
|
|
|
+
|
|
|
+ Value* Result = LoweredSrc;
|
|
|
+ Type* LoweredDstTy = DstTy;
|
|
|
+ if (dxilutil::IsIntegerOrFloatingPointType(DstTy)) {
|
|
|
+ // Matrix to scalar
|
|
|
+ Result = Builder.CreateExtractElement(LoweredSrc, static_cast<uint64_t>(0));
|
|
|
+ }
|
|
|
+ else if (DstTy->isVectorTy()) {
|
|
|
+ // Matrix to vector
|
|
|
+ VectorType *DstVecTy = cast<VectorType>(DstTy);
|
|
|
+ DXASSERT(DstVecTy->getNumElements() <= LoweredSrcTy->getNumElements(),
|
|
|
+ "Cannot cast matrix to a larger vector.");
|
|
|
+
|
|
|
+ // We might have to truncate
|
|
|
+ if (DstTy->getVectorNumElements() < LoweredSrcTy->getNumElements()) {
|
|
|
+ SmallVector<int, 3> ShuffleIndices;
|
|
|
+ for (unsigned Idx = 0; Idx < DstVecTy->getNumElements(); ++Idx)
|
|
|
+ ShuffleIndices.emplace_back(static_cast<int>(Idx));
|
|
|
+ Result = Builder.CreateShuffleVector(Result, Result, ShuffleIndices);
|
|
|
}
|
|
|
-
|
|
|
- for (auto user = U->user_begin(); user != U->user_end();) {
|
|
|
- CallInst *CI = cast<CallInst>(*(user++));
|
|
|
- HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
- if (group == HLOpcodeGroup::HLMatLoadStore) {
|
|
|
- TranslateMatLoadStoreOnGlobalPtr(CI, VecGEP);
|
|
|
- } else if (group == HLOpcodeGroup::HLSubscript) {
|
|
|
- TranslateMatSubscriptOnGlobalPtr(CI, VecGEP);
|
|
|
- } else {
|
|
|
- DXASSERT(0, "invalid operation");
|
|
|
- }
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ // Destination must now be a matrix too
|
|
|
+ HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
|
|
|
+
|
|
|
+ // Apply any changes at the matrix level: orientation changes and truncation
|
|
|
+ if (Opcode == HLCastOpcode::ColMatrixToRowMatrix)
|
|
|
+ Result = MatSrcTy.emitLoweredVectorColToRow(Result, Builder);
|
|
|
+ else if (Opcode == HLCastOpcode::RowMatrixToColMatrix)
|
|
|
+ Result = MatSrcTy.emitLoweredVectorRowToCol(Result, Builder);
|
|
|
+ else if (MatDstTy.getNumRows() != MatSrcTy.getNumRows()
|
|
|
+ || MatDstTy.getNumColumns() != MatSrcTy.getNumColumns()) {
|
|
|
+ // Apply truncation
|
|
|
+ DXASSERT(MatDstTy.getNumRows() <= MatSrcTy.getNumRows()
|
|
|
+ && MatDstTy.getNumColumns() <= MatSrcTy.getNumColumns(),
|
|
|
+ "Unexpected matrix cast between incompatible dimensions.");
|
|
|
+ SmallVector<int, 16> ShuffleIndices;
|
|
|
+ for (unsigned RowIdx = 0; RowIdx < MatDstTy.getNumRows(); ++RowIdx)
|
|
|
+ for (unsigned ColIdx = 0; ColIdx < MatDstTy.getNumColumns(); ++ColIdx)
|
|
|
+ ShuffleIndices.emplace_back(static_cast<int>(MatSrcTy.getRowMajorIndex(RowIdx, ColIdx)));
|
|
|
+ Result = Builder.CreateShuffleVector(Result, Result, ShuffleIndices);
|
|
|
}
|
|
|
+
|
|
|
+ LoweredDstTy = MatDstTy.getLoweredVectorTypeForReg();
|
|
|
+ DXASSERT(Result->getType()->getVectorNumElements() == LoweredDstTy->getVectorNumElements(),
|
|
|
+ "Unexpected matrix src/dst lowered element count mismatch after truncation.");
|
|
|
}
|
|
|
|
|
|
- DeleteDeadInsts();
|
|
|
- GV->removeDeadConstantUsers();
|
|
|
- GV->eraseFromParent();
|
|
|
+ // Apply element conversion
|
|
|
+ return convertScalarOrVector(Result, LoweredDstTy, Opcode, Builder);
|
|
|
}
|
|
|
|
|
|
-static void FlattenMatConst(Constant *M, std::vector<Constant *> &Elts) {
|
|
|
- unsigned row, col;
|
|
|
- Type *EltTy = HLMatrixLower::GetMatrixInfo(M->getType(), col, row);
|
|
|
- if (isa<UndefValue>(M)) {
|
|
|
- Constant *Elt = UndefValue::get(EltTy);
|
|
|
- for (unsigned i=0;i<col*row;i++)
|
|
|
- Elts.emplace_back(Elt);
|
|
|
- } else {
|
|
|
- M = M->getAggregateElement((unsigned)0);
|
|
|
- // Initializer is already in correct major.
|
|
|
- // Just read it here.
|
|
|
- // The type is vector<element, col>[row].
|
|
|
- for (unsigned r = 0; r < row; r++) {
|
|
|
- Constant *C = M->getAggregateElement(r);
|
|
|
- for (unsigned c = 0; c < col; c++) {
|
|
|
- Elts.emplace_back(C->getAggregateElement(c));
|
|
|
- }
|
|
|
- }
|
|
|
+Value *HLMatrixLowerPass::lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode) {
|
|
|
+ switch (Opcode) {
|
|
|
+ case HLSubscriptOpcode::RowMatElement:
|
|
|
+ case HLSubscriptOpcode::ColMatElement:
|
|
|
+ return lowerHLMatElementSubscript(Call,
|
|
|
+ /* RowMajor */ Opcode == HLSubscriptOpcode::RowMatElement);
|
|
|
+
|
|
|
+ case HLSubscriptOpcode::RowMatSubscript:
|
|
|
+ case HLSubscriptOpcode::ColMatSubscript:
|
|
|
+ return lowerHLMatSubscript(Call,
|
|
|
+ /* RowMajor */ Opcode == HLSubscriptOpcode::RowMatSubscript);
|
|
|
+
|
|
|
+ case HLSubscriptOpcode::DefaultSubscript:
|
|
|
+ case HLSubscriptOpcode::CBufferSubscript:
|
|
|
+ // Those get lowered during HLOperationLower,
|
|
|
+ // and the return type must stay unchanged (as a matrix)
|
|
|
+ // to provide the metadata to properly emit the loads.
|
|
|
+ return nullptr;
|
|
|
+
|
|
|
+ default:
|
|
|
+ llvm_unreachable("Unexpected matrix subscript opcode.");
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
|
|
|
- if (HLMatrixLower::IsMatrixArrayPointer(GV->getType())) {
|
|
|
- runOnGlobalMatrixArray(GV);
|
|
|
- return;
|
|
|
- }
|
|
|
+Value *HLMatrixLowerPass::lowerHLMatElementSubscript(CallInst *Call, bool RowMajor) {
|
|
|
+ (void)RowMajor; // It doesn't look like we actually need this?
|
|
|
|
|
|
- Type *Ty = GV->getType()->getPointerElementType();
|
|
|
- if (!dxilutil::IsHLSLMatrixType(Ty))
|
|
|
- return;
|
|
|
+ Value *MatPtr = Call->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
|
|
|
+ Constant *IdxVec = cast<Constant>(Call->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx));
|
|
|
+ VectorType *IdxVecTy = cast<VectorType>(IdxVec->getType());
|
|
|
|
|
|
- bool onlyLdSt = OnlyUsedByMatrixLdSt(GV);
|
|
|
-
|
|
|
- bool isConst = GV->isConstant();
|
|
|
-
|
|
|
- Type *vecTy = HLMatrixLower::LowerMatrixType(Ty);
|
|
|
- Module *M = GV->getParent();
|
|
|
- const DataLayout &DL = M->getDataLayout();
|
|
|
-
|
|
|
- std::vector<Constant *> Elts;
|
|
|
- // Lower to vector or array for scalar matrix.
|
|
|
- // Make it col major so don't need shuffle when load/store.
|
|
|
- FlattenMatConst(GV->getInitializer(), Elts);
|
|
|
-
|
|
|
- if (onlyLdSt) {
|
|
|
- Type *EltTy = vecTy->getVectorElementType();
|
|
|
- unsigned vecSize = vecTy->getVectorNumElements();
|
|
|
- std::vector<Value *> vecGlobals(vecSize);
|
|
|
-
|
|
|
- GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
|
|
|
- unsigned AddressSpace = GV->getType()->getAddressSpace();
|
|
|
- GlobalValue::LinkageTypes linkage = GV->getLinkage();
|
|
|
- unsigned debugOffset = 0;
|
|
|
- unsigned size = DL.getTypeAllocSizeInBits(EltTy);
|
|
|
- unsigned align = DL.getPrefTypeAlignment(EltTy);
|
|
|
- for (int i = 0, e = vecSize; i != e; ++i) {
|
|
|
- Constant *InitVal = Elts[i];
|
|
|
- GlobalVariable *EltGV = new llvm::GlobalVariable(
|
|
|
- *M, EltTy, /*IsConstant*/ isConst, linkage,
|
|
|
- /*InitVal*/ InitVal, GV->getName() + "." + Twine(i),
|
|
|
- /*InsertBefore*/nullptr,
|
|
|
- TLMode, AddressSpace);
|
|
|
- // Add debug info.
|
|
|
- if (m_HasDbgInfo) {
|
|
|
- DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
|
|
|
- HLModule::CreateElementGlobalVariableDebugInfo(
|
|
|
- GV, Finder, EltGV, size, align, debugOffset,
|
|
|
- EltGV->getName().ltrim(GV->getName()));
|
|
|
- debugOffset += size;
|
|
|
- }
|
|
|
- vecGlobals[i] = EltGV;
|
|
|
- }
|
|
|
- for (User *user : GV->users()) {
|
|
|
- if (isa<Constant>(user) && user->use_empty())
|
|
|
- continue;
|
|
|
- CallInst *CI = cast<CallInst>(user);
|
|
|
- TranslateMatLoadStoreOnGlobal(GV, vecGlobals, CI);
|
|
|
- AddToDeadInsts(CI);
|
|
|
- }
|
|
|
- DeleteDeadInsts();
|
|
|
- GV->eraseFromParent();
|
|
|
+ // Get the loaded lowered vector element indices
|
|
|
+ SmallVector<Value*, 4> ElemIndices;
|
|
|
+ ElemIndices.reserve(IdxVecTy->getNumElements());
|
|
|
+ for (unsigned VecIdx = 0; VecIdx < IdxVecTy->getNumElements(); ++VecIdx) {
|
|
|
+ ElemIndices.emplace_back(IdxVec->getAggregateElement(VecIdx));
|
|
|
}
|
|
|
- else {
|
|
|
- // lower to array of scalar here.
|
|
|
- ArrayType *AT = ArrayType::get(vecTy->getVectorElementType(), vecTy->getVectorNumElements());
|
|
|
- Constant *InitVal = ConstantArray::get(AT, Elts);
|
|
|
- GlobalVariable *arrayMat = new llvm::GlobalVariable(
|
|
|
- *M, AT, /*IsConstant*/ false, llvm::GlobalValue::InternalLinkage,
|
|
|
- /*InitVal*/ InitVal, GV->getName());
|
|
|
- // Add debug info.
|
|
|
- if (m_HasDbgInfo) {
|
|
|
- DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
|
|
|
- HLModule::UpdateGlobalVariableDebugInfo(GV, Finder,
|
|
|
- arrayMat);
|
|
|
- }
|
|
|
|
|
|
- for (auto U = GV->user_begin(); U != GV->user_end();) {
|
|
|
- Value *user = *(U++);
|
|
|
- CallInst *CI = cast<CallInst>(user);
|
|
|
- HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
- if (group == HLOpcodeGroup::HLMatLoadStore) {
|
|
|
- TranslateMatLoadStoreOnGlobal(GV, arrayMat, CI);
|
|
|
- }
|
|
|
- else {
|
|
|
- DXASSERT(group == HLOpcodeGroup::HLSubscript, "Must be subscript operation");
|
|
|
- TranslateMatSubscriptOnGlobalPtr(CI, arrayMat);
|
|
|
- }
|
|
|
- }
|
|
|
- GV->removeDeadConstantUsers();
|
|
|
- GV->eraseFromParent();
|
|
|
- }
|
|
|
+ lowerHLMatSubscript(Call, MatPtr, ElemIndices);
|
|
|
+
|
|
|
+  // We did our own replacement of uses; opt out of having the caller do it for us.
|
|
|
+ return nullptr;
|
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|
|
- // Skip hl function definition (like createhandle)
|
|
|
- if (hlsl::GetHLOpcodeGroupByName(&F) != HLOpcodeGroup::NotHL)
|
|
|
- return;
|
|
|
+Value *HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, bool RowMajor) {
|
|
|
+ (void)RowMajor; // It doesn't look like we actually need this?
|
|
|
|
|
|
- // Create vector version of matrix instructions first.
|
|
|
- // The matrix operands will be undefval for these instructions.
|
|
|
- for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
|
|
|
- BasicBlock *BB = BBI;
|
|
|
- for (auto II = BB->begin(); II != BB->end(); ) {
|
|
|
- Instruction &I = *(II++);
|
|
|
- if (dxilutil::IsHLSLMatrixType(I.getType())) {
|
|
|
- lowerToVec(&I);
|
|
|
- } else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
|
|
|
- Type *Ty = AI->getAllocatedType();
|
|
|
- if (dxilutil::IsHLSLMatrixType(Ty)) {
|
|
|
- lowerToVec(&I);
|
|
|
- } else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
|
|
|
- lowerToVec(&I);
|
|
|
- }
|
|
|
- } else if (CallInst *CI = dyn_cast<CallInst>(&I)) {
|
|
|
- HLOpcodeGroup group =
|
|
|
- hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
- if (group == HLOpcodeGroup::HLMatLoadStore) {
|
|
|
- HLMatLoadStoreOpcode opcode =
|
|
|
- static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
|
|
|
- DXASSERT_LOCALVAR(opcode,
|
|
|
- opcode == HLMatLoadStoreOpcode::ColMatStore ||
|
|
|
- opcode == HLMatLoadStoreOpcode::RowMatStore,
|
|
|
- "Must MatStore here, load will go IsMatrixType path");
|
|
|
- // Lower it here to make sure it is ready before replace.
|
|
|
- lowerToVec(&I);
|
|
|
- }
|
|
|
- } else if (GetIfMatrixGEPOfUDTAlloca(&I) ||
|
|
|
- GetIfMatrixGEPOfUDTArg(&I, *m_pHLModule)) {
|
|
|
- lowerToVec(&I);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
+ Value *MatPtr = Call->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
|
|
|
|
|
|
- // Update the use of matrix inst with the vector version.
|
|
|
- for (auto matToVecIter = matToVecMap.begin();
|
|
|
- matToVecIter != matToVecMap.end();) {
|
|
|
- auto matToVec = matToVecIter++;
|
|
|
- replaceMatWithVec(matToVec->first, cast<Instruction>(matToVec->second));
|
|
|
+ // Gather the indices, checking if they are all constant
|
|
|
+ SmallVector<Value*, 4> ElemIndices;
|
|
|
+ for (unsigned Idx = HLOperandIndex::kMatSubscriptSubOpIdx; Idx < Call->getNumArgOperands(); ++Idx) {
|
|
|
+ ElemIndices.emplace_back(Call->getArgOperand(Idx));
|
|
|
}
|
|
|
|
|
|
- // Translate mat inst which require all operands ready.
|
|
|
- for (auto matToVecIter = matToVecMap.begin();
|
|
|
- matToVecIter != matToVecMap.end();) {
|
|
|
- auto matToVec = matToVecIter++;
|
|
|
- if (isa<Instruction>(matToVec->first))
|
|
|
- finalMatTranslation(matToVec->first);
|
|
|
- }
|
|
|
+ lowerHLMatSubscript(Call, MatPtr, ElemIndices);
|
|
|
|
|
|
- // Remove matrix targets of vecToMatMap from matToVecMap before adding the rest to dead insts.
|
|
|
- for (auto &it : vecToMatMap) {
|
|
|
- matToVecMap.erase(it.second);
|
|
|
- }
|
|
|
+  // We did our own replacement of uses; opt out of having the caller do it for us.
|
|
|
+ return nullptr;
|
|
|
+}
|
|
|
|
|
|
- // Delete the matrix version insts.
|
|
|
- for (auto matToVecIter = matToVecMap.begin();
|
|
|
- matToVecIter != matToVecMap.end();) {
|
|
|
- auto matToVec = matToVecIter++;
|
|
|
- // Add to m_deadInsts.
|
|
|
- if (Instruction *matInst = dyn_cast<Instruction>(matToVec->first))
|
|
|
- AddToDeadInsts(matInst);
|
|
|
- }
|
|
|
+void HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, Value *MatPtr, SmallVectorImpl<Value*> &ElemIndices) {
|
|
|
+ DXASSERT_NOMSG(HLMatrixType::isMatrixPtr(MatPtr->getType()));
|
|
|
|
|
|
- DeleteDeadInsts();
|
|
|
+ IRBuilder<> CallBuilder(Call);
|
|
|
+ Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, CallBuilder);
|
|
|
+ if (LoweredPtr == nullptr) return;
|
|
|
+
|
|
|
+ // For global variables, we can GEP directly into the lowered vector pointer.
|
|
|
+ // This is necessary to support group shared memory atomics and the likes.
|
|
|
+ Value *RootPtr = LoweredPtr;
|
|
|
+ while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
|
|
|
+ RootPtr = GEP->getPointerOperand();
|
|
|
+ bool AllowLoweredPtrGEPs = isa<GlobalVariable>(RootPtr);
|
|
|
|
|
|
- matToVecMap.clear();
|
|
|
- vecToMatMap.clear();
|
|
|
+ // Just constructing this does all the work
|
|
|
+ HLMatrixSubscriptUseReplacer UseReplacer(Call, LoweredPtr, ElemIndices, AllowLoweredPtrGEPs, m_deadInsts);
|
|
|
|
|
|
- return;
|
|
|
+ DXASSERT(Call->use_empty(), "Expected all matrix subscript uses to have been replaced.");
|
|
|
+ addToDeadInsts(Call);
|
|
|
}
|
|
|
|
|
|
-// Matrix Bitcast lower.
|
|
|
-// After linking Lower matrix bitcast patterns like:
|
|
|
-// %169 = bitcast [72 x float]* %0 to [6 x %class.matrix.float.4.3]*
|
|
|
-// %conv.i = fptoui float %164 to i32
|
|
|
-// %arrayidx.i = getelementptr inbounds [6 x %class.matrix.float.4.3], [6 x %class.matrix.float.4.3]* %169, i32 0, i32 %conv.i
|
|
|
-// %170 = bitcast %class.matrix.float.4.3* %arrayidx.i to <12 x float>*
|
|
|
+// Lowers StructuredBuffer<matrix>[index] or similar with constant buffers
|
|
|
+Value *HLMatrixLowerPass::lowerHLMatResourceSubscript(CallInst *Call, HLSubscriptOpcode Opcode) {
|
|
|
+ // Just replace the intrinsic by its equivalent with a lowered return type
|
|
|
+ IRBuilder<> Builder(Call);
|
|
|
|
|
|
-namespace {
|
|
|
+ SmallVector<Value*, 4> Args;
|
|
|
+ Args.reserve(Call->getNumArgOperands());
|
|
|
+ for (Value *Arg : Call->arg_operands())
|
|
|
+ Args.emplace_back(Arg);
|
|
|
|
|
|
-Type *TryLowerMatTy(Type *Ty) {
|
|
|
- Type *VecTy = nullptr;
|
|
|
- if (HLMatrixLower::IsMatrixArrayPointer(Ty)) {
|
|
|
- VecTy = HLMatrixLower::LowerMatrixArrayPointerToOneDimArray(Ty);
|
|
|
- } else if (isa<PointerType>(Ty) &&
|
|
|
- dxilutil::IsHLSLMatrixType(Ty->getPointerElementType())) {
|
|
|
- VecTy = HLMatrixLower::LowerMatrixTypeToOneDimArray(
|
|
|
- Ty->getPointerElementType());
|
|
|
- VecTy = PointerType::get(VecTy, Ty->getPointerAddressSpace());
|
|
|
- }
|
|
|
- return VecTy;
|
|
|
+ Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
|
|
|
+ return callHLFunction(*m_pModule, HLOpcodeGroup::HLSubscript, static_cast<unsigned>(Opcode),
|
|
|
+ LoweredRetTy, Args, Builder);
|
|
|
}
|
|
|
|
|
|
-class MatrixBitcastLowerPass : public FunctionPass {
|
|
|
+Value *HLMatrixLowerPass::lowerHLInit(CallInst *Call) {
|
|
|
+ DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix init opcode.");
|
|
|
|
|
|
-public:
|
|
|
- static char ID; // Pass identification, replacement for typeid
|
|
|
- explicit MatrixBitcastLowerPass() : FunctionPass(ID) {}
|
|
|
-
|
|
|
- const char *getPassName() const override { return "Matrix Bitcast lower"; }
|
|
|
- bool runOnFunction(Function &F) override {
|
|
|
- bool bUpdated = false;
|
|
|
- std::unordered_set<BitCastInst*> matCastSet;
|
|
|
- for (auto blkIt = F.begin(); blkIt != F.end(); ++blkIt) {
|
|
|
- BasicBlock *BB = blkIt;
|
|
|
- for (auto iIt = BB->begin(); iIt != BB->end(); ) {
|
|
|
- Instruction *I = (iIt++);
|
|
|
- if (BitCastInst *BCI = dyn_cast<BitCastInst>(I)) {
|
|
|
- // Mutate mat to vec.
|
|
|
- Type *ToTy = BCI->getType();
|
|
|
- if (Type *ToVecTy = TryLowerMatTy(ToTy)) {
|
|
|
- matCastSet.insert(BCI);
|
|
|
- bUpdated = true;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
+ // Figure out the result type
|
|
|
+ HLMatrixType MatTy = HLMatrixType::cast(Call->getType());
|
|
|
+ VectorType *LoweredTy = MatTy.getLoweredVectorTypeForReg();
|
|
|
|
|
|
- DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
|
|
|
- // Remove bitcast which has CallInst user.
|
|
|
- if (DM.GetShaderModel()->IsLib()) {
|
|
|
- for (auto it = matCastSet.begin(); it != matCastSet.end();) {
|
|
|
- BitCastInst *BCI = *(it++);
|
|
|
- if (hasCallUser(BCI)) {
|
|
|
- matCastSet.erase(BCI);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // Lower matrix first.
|
|
|
- for (BitCastInst *BCI : matCastSet) {
|
|
|
- lowerMatrix(BCI, BCI->getOperand(0));
|
|
|
- }
|
|
|
- return bUpdated;
|
|
|
+ // Handle case where produced by EmitHLSLFlatConversion where there's one
|
|
|
+ // vector argument, instead of scalar arguments.
|
|
|
+ if (1 == Call->getNumArgOperands() - HLOperandIndex::kInitFirstArgOpIdx &&
|
|
|
+ Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx)->
|
|
|
+ getType()->isVectorTy()) {
|
|
|
+ Value *LoweredVec = Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
|
|
|
+ DXASSERT(LoweredTy->getNumElements() ==
|
|
|
+ LoweredVec->getType()->getVectorNumElements(),
|
|
|
+ "Invalid matrix init argument vector element count.");
|
|
|
+ return LoweredVec;
|
|
|
}
|
|
|
-private:
|
|
|
- void lowerMatrix(Instruction *M, Value *A);
|
|
|
- bool hasCallUser(Instruction *M);
|
|
|
-};
|
|
|
|
|
|
-}
|
|
|
+ DXASSERT(LoweredTy->getNumElements() == Call->getNumArgOperands() - HLOperandIndex::kInitFirstArgOpIdx,
|
|
|
+ "Invalid matrix init argument count.");
|
|
|
|
|
|
-bool MatrixBitcastLowerPass::hasCallUser(Instruction *M) {
|
|
|
- for (auto it = M->user_begin(); it != M->user_end();) {
|
|
|
- User *U = *(it++);
|
|
|
- if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
|
|
|
- Type *EltTy = GEP->getType()->getPointerElementType();
|
|
|
- if (dxilutil::IsHLSLMatrixType(EltTy)) {
|
|
|
- if (hasCallUser(GEP))
|
|
|
- return true;
|
|
|
- } else {
|
|
|
- DXASSERT(0, "invalid GEP for matrix");
|
|
|
- }
|
|
|
- } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
|
|
|
- if (hasCallUser(BCI))
|
|
|
- return true;
|
|
|
- } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
|
|
|
- if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
|
|
|
- } else {
|
|
|
- DXASSERT(0, "invalid load for matrix");
|
|
|
- }
|
|
|
- } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
|
|
|
- Value *V = ST->getValueOperand();
|
|
|
- if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
|
|
|
- } else {
|
|
|
- DXASSERT(0, "invalid load for matrix");
|
|
|
- }
|
|
|
- } else if (isa<CallInst>(U)) {
|
|
|
- return true;
|
|
|
- } else {
|
|
|
- DXASSERT(0, "invalid use of matrix");
|
|
|
- }
|
|
|
+ // Build the result vector from the init args.
|
|
|
+ // Both the args and the result vector are in row-major order, so no shuffling is necessary.
|
|
|
+ IRBuilder<> Builder(Call);
|
|
|
+ Value *LoweredVec = UndefValue::get(LoweredTy);
|
|
|
+ for (unsigned VecElemIdx = 0; VecElemIdx < LoweredTy->getNumElements(); ++VecElemIdx) {
|
|
|
+ Value *ArgVal = Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx + VecElemIdx);
|
|
|
+ DXASSERT(dxilutil::IsIntegerOrFloatingPointType(ArgVal->getType()),
|
|
|
+ "Expected only scalars in matrix initialization.");
|
|
|
+ LoweredVec = Builder.CreateInsertElement(LoweredVec, ArgVal, static_cast<uint64_t>(VecElemIdx));
|
|
|
}
|
|
|
- return false;
|
|
|
-}
|
|
|
|
|
|
-namespace {
|
|
|
-Value *CreateEltGEP(Value *A, unsigned i, Value *zeroIdx,
|
|
|
- IRBuilder<> &Builder) {
|
|
|
- Value *GEP = nullptr;
|
|
|
- if (GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A)) {
|
|
|
- // A should be gep oneDimArray, 0, index * matSize
|
|
|
- // Here add eltIdx to index * matSize foreach elt.
|
|
|
- Instruction *EltGEP = GEPA->clone();
|
|
|
- unsigned eltIdx = EltGEP->getNumOperands() - 1;
|
|
|
- Value *NewIdx =
|
|
|
- Builder.CreateAdd(EltGEP->getOperand(eltIdx), Builder.getInt32(i));
|
|
|
- EltGEP->setOperand(eltIdx, NewIdx);
|
|
|
- Builder.Insert(EltGEP);
|
|
|
- GEP = EltGEP;
|
|
|
- } else {
|
|
|
- GEP = Builder.CreateInBoundsGEP(A, {zeroIdx, Builder.getInt32(i)});
|
|
|
- }
|
|
|
- return GEP;
|
|
|
+ return LoweredVec;
|
|
|
}
|
|
|
-} // namespace
|
|
|
-
|
|
|
-void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
|
|
|
- for (auto it = M->user_begin(); it != M->user_end();) {
|
|
|
- User *U = *(it++);
|
|
|
- if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
|
|
|
- Type *EltTy = GEP->getType()->getPointerElementType();
|
|
|
- if (dxilutil::IsHLSLMatrixType(EltTy)) {
|
|
|
- // Change gep matrixArray, 0, index
|
|
|
- // into
|
|
|
- // gep oneDimArray, 0, index * matSize
|
|
|
- IRBuilder<> Builder(GEP);
|
|
|
- SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
|
|
|
- DXASSERT(idxList.size() == 2,
|
|
|
- "else not one dim matrix array index to matrix");
|
|
|
- unsigned col = 0;
|
|
|
- unsigned row = 0;
|
|
|
- HLMatrixLower::GetMatrixInfo(EltTy, col, row);
|
|
|
- Value *matSize = Builder.getInt32(col * row);
|
|
|
- idxList.back() = Builder.CreateMul(idxList.back(), matSize);
|
|
|
- Value *NewGEP = Builder.CreateGEP(A, idxList);
|
|
|
- lowerMatrix(GEP, NewGEP);
|
|
|
- DXASSERT(GEP->user_empty(), "else lower matrix fail");
|
|
|
- GEP->eraseFromParent();
|
|
|
- } else {
|
|
|
- DXASSERT(0, "invalid GEP for matrix");
|
|
|
- }
|
|
|
- } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
|
|
|
- lowerMatrix(BCI, A);
|
|
|
- DXASSERT(BCI->user_empty(), "else lower matrix fail");
|
|
|
- BCI->eraseFromParent();
|
|
|
- } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
|
|
|
- if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
|
|
|
- IRBuilder<> Builder(LI);
|
|
|
- Value *zeroIdx = Builder.getInt32(0);
|
|
|
- unsigned vecSize = Ty->getNumElements();
|
|
|
- Value *NewVec = UndefValue::get(LI->getType());
|
|
|
- for (unsigned i = 0; i < vecSize; i++) {
|
|
|
- Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
|
|
|
- Value *Elt = Builder.CreateLoad(GEP);
|
|
|
- NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
|
|
|
- }
|
|
|
- LI->replaceAllUsesWith(NewVec);
|
|
|
- LI->eraseFromParent();
|
|
|
- } else {
|
|
|
- DXASSERT(0, "invalid load for matrix");
|
|
|
- }
|
|
|
- } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
|
|
|
- Value *V = ST->getValueOperand();
|
|
|
- if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
|
|
|
- IRBuilder<> Builder(LI);
|
|
|
- Value *zeroIdx = Builder.getInt32(0);
|
|
|
- unsigned vecSize = Ty->getNumElements();
|
|
|
- for (unsigned i = 0; i < vecSize; i++) {
|
|
|
- Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
|
|
|
- Value *Elt = Builder.CreateExtractElement(V, i);
|
|
|
- Builder.CreateStore(Elt, GEP);
|
|
|
- }
|
|
|
- ST->eraseFromParent();
|
|
|
- } else {
|
|
|
- DXASSERT(0, "invalid load for matrix");
|
|
|
- }
|
|
|
- } else {
|
|
|
- DXASSERT(0, "invalid use of matrix");
|
|
|
- }
|
|
|
+
|
|
|
+Value *HLMatrixLowerPass::lowerHLSelect(CallInst *Call) {
|
|
|
+ DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix init opcode.");
|
|
|
+
|
|
|
+ Value *Cond = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
|
|
|
+ Value *TrueMat = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx);
|
|
|
+ Value *FalseMat = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx);
|
|
|
+
|
|
|
+ DXASSERT(TrueMat->getType() == FalseMat->getType(),
|
|
|
+ "Unexpected type mismatch between matrix ternary operator values.");
|
|
|
+
|
|
|
+#ifndef NDEBUG
|
|
|
+ // Assert that if the condition is a matrix, it matches the dimensions of the values
|
|
|
+ if (HLMatrixType MatCondTy = HLMatrixType::dyn_cast(Cond->getType())) {
|
|
|
+ HLMatrixType ValMatTy = HLMatrixType::cast(TrueMat->getType());
|
|
|
+ DXASSERT(MatCondTy.getNumRows() == ValMatTy.getNumRows()
|
|
|
+ && MatCondTy.getNumColumns() == ValMatTy.getNumColumns(),
|
|
|
+ "Unexpected mismatch between ternary operator condition and value matrix dimensions.");
|
|
|
}
|
|
|
-}
|
|
|
+#endif
|
|
|
|
|
|
-#include "dxc/HLSL/DxilGenerationPass.h"
|
|
|
-char MatrixBitcastLowerPass::ID = 0;
|
|
|
-FunctionPass *llvm::createMatrixBitcastLowerPass() { return new MatrixBitcastLowerPass(); }
|
|
|
+ IRBuilder<> Builder(Call);
|
|
|
+ Value *LoweredCond = getLoweredByValOperand(Cond, Builder);
|
|
|
+ Value *LoweredTrueVec = getLoweredByValOperand(TrueMat, Builder);
|
|
|
+ Value *LoweredFalseVec = getLoweredByValOperand(FalseMat, Builder);
|
|
|
+ Value *Result = UndefValue::get(LoweredTrueVec->getType());
|
|
|
|
|
|
-INITIALIZE_PASS(MatrixBitcastLowerPass, "matrixbitcastlower", "Matrix Bitcast lower", false, false)
|
|
|
+ bool IsScalarCond = !LoweredCond->getType()->isVectorTy();
|
|
|
+
|
|
|
+ unsigned NumElems = Result->getType()->getVectorNumElements();
|
|
|
+ for (uint64_t ElemIdx = 0; ElemIdx < NumElems; ++ElemIdx) {
|
|
|
+ Value *ElemCond = IsScalarCond ? LoweredCond
|
|
|
+ : Builder.CreateExtractElement(LoweredCond, ElemIdx);
|
|
|
+ Value *ElemTrueVal = Builder.CreateExtractElement(LoweredTrueVec, ElemIdx);
|
|
|
+ Value *ElemFalseVal = Builder.CreateExtractElement(LoweredFalseVec, ElemIdx);
|
|
|
+ Value *ResultElem = Builder.CreateSelect(ElemCond, ElemTrueVal, ElemFalseVal);
|
|
|
+ Result = Builder.CreateInsertElement(Result, ResultElem, ElemIdx);
|
|
|
+ }
|
|
|
+
|
|
|
+ return Result;
|
|
|
+}
|