HLExpandStoreIntrinsics.cpp 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLExpandStoreIntrinsics.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  9. #include "dxc/Support/Global.h"
  10. #include "dxc/HLSL/HLOperations.h"
  11. #include "dxc/HLSL/HLMatrixType.h"
  12. #include "dxc/HLSL/HLModule.h"
  13. #include "dxc/DXIL/DxilTypeSystem.h"
  14. #include "dxc/HlslIntrinsicOp.h"
  15. #include "llvm/IR/InstIterator.h"
  16. #include "llvm/IR/Instruction.h"
  17. #include "llvm/IR/Instructions.h"
  18. #include "llvm/IR/Function.h"
  19. #include "llvm/IR/Module.h"
  20. #include "llvm/IR/IRBuilder.h"
  21. #include "llvm/IR/Intrinsics.h"
  22. #include "llvm/Transforms/Scalar.h"
  23. using namespace hlsl;
  24. using namespace llvm;
  25. namespace {
  26. // Expands buffer stores of aggregate value types
  27. // into stores of its individual elements,
  28. // before SROA happens and we lose the layout information.
  29. class HLExpandStoreIntrinsics : public FunctionPass {
  30. public:
  31. static char ID;
  32. explicit HLExpandStoreIntrinsics() : FunctionPass(ID) {}
  33. const char *getPassName() const override {
  34. return "Expand HLSL store intrinsics";
  35. }
  36. bool runOnFunction(Function& Func) override;
  37. private:
  38. DxilTypeSystem *m_typeSys;
  39. bool expand(CallInst *StoreCall);
  40. void emitElementStores(CallInst &OriginalCall,
  41. SmallVectorImpl<Value*>& GEPIndicesStack, Type *StackTopTy,
  42. unsigned OffsetFromBase, DxilFieldAnnotation *fieldAnnotation);
  43. };
  44. char HLExpandStoreIntrinsics::ID = 0;
  45. bool HLExpandStoreIntrinsics::runOnFunction(Function& Func) {
  46. bool changed = false;
  47. m_typeSys = &(Func.getParent()->GetHLModule().GetTypeSystem());
  48. for (auto InstIt = inst_begin(Func), InstEnd = inst_end(Func); InstIt != InstEnd;) {
  49. CallInst *Call = dyn_cast<CallInst>(&*(InstIt++));
  50. if (Call == nullptr
  51. || GetHLOpcodeGroup(Call->getCalledFunction()) != HLOpcodeGroup::HLIntrinsic
  52. || static_cast<IntrinsicOp>(GetHLOpcode(Call)) != IntrinsicOp::MOP_Store) {
  53. continue;
  54. }
  55. changed |= expand(Call);
  56. }
  57. return changed;
  58. }
  59. bool HLExpandStoreIntrinsics::expand(CallInst* StoreCall) {
  60. Value *OldStoreValueArg = StoreCall->getArgOperand(HLOperandIndex::kStoreValOpIdx);
  61. Type *OldStoreValueArgTy = OldStoreValueArg->getType();
  62. // Only expand if the value argument is by pointer, which means it's an aggregate.
  63. if (!OldStoreValueArgTy->isPointerTy()) return false;
  64. IRBuilder<> Builder(StoreCall);
  65. SmallVector<Value*, 4> GEPIndicesStack;
  66. GEPIndicesStack.emplace_back(Builder.getInt32(0));
  67. emitElementStores(*StoreCall, GEPIndicesStack, OldStoreValueArgTy->getPointerElementType(), /* OffsetFromBase */ 0, nullptr);
  68. DXASSERT(StoreCall->getType()->isVoidTy() && StoreCall->use_empty(),
  69. "Buffer store intrinsic is expected to return void and hence not have uses.");
  70. StoreCall->eraseFromParent();
  71. return true;
  72. }
  73. void HLExpandStoreIntrinsics::emitElementStores(CallInst &OriginalCall,
  74. SmallVectorImpl<Value*>& GEPIndicesStack, Type *StackTopTy,
  75. unsigned OffsetFromBase, DxilFieldAnnotation* fieldAnnotation) {
  76. llvm::Module &Module = *OriginalCall.getModule();
  77. IRBuilder<> Builder(&OriginalCall);
  78. StructType* StructTy = dyn_cast<StructType>(StackTopTy);
  79. if (StructTy != nullptr && !HLMatrixType::isa(StructTy)) {
  80. const StructLayout* Layout = Module.getDataLayout().getStructLayout(StructTy);
  81. DxilStructAnnotation *SA = m_typeSys->GetStructAnnotation(StructTy);
  82. for (unsigned i = 0; i < StructTy->getNumElements(); ++i) {
  83. Type *ElemTy = StructTy->getElementType(i);
  84. unsigned ElemOffsetFromBase = OffsetFromBase + Layout->getElementOffset(i);
  85. GEPIndicesStack.emplace_back(Builder.getInt32(i));
  86. DxilFieldAnnotation* FA = SA != nullptr ? &(SA->GetFieldAnnotation(i)) : nullptr;
  87. emitElementStores(OriginalCall, GEPIndicesStack, ElemTy, ElemOffsetFromBase, FA);
  88. GEPIndicesStack.pop_back();
  89. }
  90. }
  91. else if (ArrayType *ArrayTy = dyn_cast<ArrayType>(StackTopTy)) {
  92. unsigned ElemSize = (unsigned)Module.getDataLayout().getTypeAllocSize(ArrayTy->getElementType());
  93. for (int i = 0; i < (int)ArrayTy->getNumElements(); ++i) {
  94. unsigned ElemOffsetFromBase = OffsetFromBase + ElemSize * i;
  95. GEPIndicesStack.emplace_back(Builder.getInt32(i));
  96. emitElementStores(OriginalCall, GEPIndicesStack, ArrayTy->getElementType(), ElemOffsetFromBase, fieldAnnotation);
  97. GEPIndicesStack.pop_back();
  98. }
  99. }
  100. else {
  101. // Scalar or vector
  102. Value* OpcodeVal = OriginalCall.getArgOperand(HLOperandIndex::kOpcodeIdx);
  103. Value* BufHandle = OriginalCall.getArgOperand(HLOperandIndex::kHandleOpIdx);
  104. Value* OffsetVal = OriginalCall.getArgOperand(HLOperandIndex::kStoreOffsetOpIdx);
  105. if (OffsetFromBase > 0)
  106. OffsetVal = Builder.CreateAdd(OffsetVal, Builder.getInt32(OffsetFromBase));
  107. Value* AggPtr = OriginalCall.getArgOperand(HLOperandIndex::kStoreValOpIdx);
  108. Value *ElemPtr = Builder.CreateGEP(AggPtr, GEPIndicesStack);
  109. Value* ElemVal = nullptr;
  110. if (HLMatrixType::isa(StackTopTy) && fieldAnnotation &&
  111. fieldAnnotation->HasMatrixAnnotation()) {
  112. // For matrix load, we generate HL intrinsic matldst.colLoad/matldst.rowLoad
  113. // instead of LLVM LoadInst to ensure that it gets lowered properly later
  114. // in HLMatrixLowerPass
  115. bool isRowMajor = fieldAnnotation->GetMatrixAnnotation().Orientation ==
  116. hlsl::MatrixOrientation::RowMajor;
  117. unsigned matLdOpcode =
  118. isRowMajor ? static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatLoad)
  119. : static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad);
  120. // Generate matrix load
  121. FunctionType *MatLdFnType = FunctionType::get(
  122. StackTopTy, {Builder.getInt32Ty(), ElemPtr->getType()},
  123. /* isVarArg */ false);
  124. Function *MatLdFn = GetOrCreateHLFunction(
  125. Module, MatLdFnType, HLOpcodeGroup::HLMatLoadStore, matLdOpcode);
  126. Value *MatLdOpCode = ConstantInt::get(Builder.getInt32Ty(), matLdOpcode);
  127. ElemVal = Builder.CreateCall(MatLdFn, {MatLdOpCode, ElemPtr});
  128. } else {
  129. ElemVal = Builder.CreateLoad(ElemPtr); // We go from memory to memory so no special bool handling needed
  130. }
  131. FunctionType *NewCalleeType = FunctionType::get(Builder.getVoidTy(),
  132. { OpcodeVal->getType(), BufHandle->getType(), OffsetVal->getType(), ElemVal->getType() },
  133. /* isVarArg */ false);
  134. Function *NewCallee = GetOrCreateHLFunction(Module, NewCalleeType,
  135. HLOpcodeGroup::HLIntrinsic, (unsigned)IntrinsicOp::MOP_Store);
  136. Builder.CreateCall(NewCallee, { OpcodeVal, BufHandle, OffsetVal, ElemVal });
  137. }
  138. }
  139. } // namespace
  140. FunctionPass *llvm::createHLExpandStoreIntrinsicsPass() { return new HLExpandStoreIntrinsics(); }
  141. INITIALIZE_PASS(HLExpandStoreIntrinsics, "hl-expand-store-intrinsics",
  142. "Expand HLSL store intrinsics", false, false)