// Listing: HLMatrixSubscriptUseReplacer.h (3.0 KB)
///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// HLMatrixSubscriptUseReplacer.h                                            //
// Copyright (C) Microsoft Corporation. All rights reserved.                 //
// This file is distributed under the University of Illinois Open Source     //
// License. See LICENSE.TXT for details.                                     //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
  9. #pragma once
  10. #include "llvm/ADT/SmallVector.h"
  11. #include "llvm/IR/IRBuilder.h"
  12. #include <vector>
  13. namespace llvm {
  14. class Value;
  15. class AllocaInst;
  16. class CallInst;
  17. class Instruction;
  18. class Function;
  19. } // namespace llvm
  20. namespace hlsl {
  21. // Implements recursive replacement of a matrix subscript's uses,
  22. // from a pointer to a matrix value to a pointer to its lowered vector version,
  23. // whether directly or through GEPs in the case of two-level indexing like mat[i][j].
  24. // This has to handle one or two levels of indices, each of which either
  25. // constant or dynamic: mat[0], mat[i], mat[0][0], mat[i][0], mat[0][j], mat[i][j],
  26. // plus the equivalent element accesses: mat._11, mat._11_12, mat._11_12[0], mat._11_12[i]
  27. class HLMatrixSubscriptUseReplacer {
  28. public:
  29. // The constructor does everything
  30. HLMatrixSubscriptUseReplacer(llvm::CallInst* Call, llvm::Value *LoweredPtr, llvm::Value *TempLoweredMatrix,
  31. llvm::SmallVectorImpl<llvm::Value*> &ElemIndices, bool AllowLoweredPtrGEPs,
  32. std::vector<llvm::Instruction*> &DeadInsts);
  33. private:
  34. void replaceUses(llvm::Instruction* PtrInst, llvm::Value* SubIdxVal);
  35. llvm::Value *tryGetScalarIndex(llvm::Value *SubIdxVal, llvm::IRBuilder<> &Builder);
  36. void cacheLoweredMatrix(bool ForDynamicIndexing, llvm::IRBuilder<> &Builder);
  37. llvm::Value *loadElem(llvm::Value *Idx, llvm::IRBuilder<> &Builder);
  38. void storeElem(llvm::Value *Idx, llvm::Value *Elem, llvm::IRBuilder<> &Builder);
  39. llvm::Value *loadVector(llvm::IRBuilder<> &Builder);
  40. void storeVector(llvm::Value *Vec, llvm::IRBuilder<> &Builder);
  41. void flushLoweredMatrix(llvm::IRBuilder<> &Builder);
  42. private:
  43. llvm::Value *LoweredPtr;
  44. llvm::SmallVectorImpl<llvm::Value*> &ElemIndices;
  45. std::vector<llvm::Instruction*> &DeadInsts;
  46. bool AllowLoweredPtrGEPs = false;
  47. bool HasScalarResult = false;
  48. bool HasDynamicElemIndex = false;
  49. llvm::Type *LoweredTy = nullptr;
  50. // The entire lowered matrix as loaded from LoweredPtr,
  51. // nullptr if we copied it to a temporary array.
  52. llvm::Value *TempLoweredMatrix = nullptr;
  53. // We allocate this if the level 1 indices are not all constants,
  54. // so we can dynamically index the lowered matrix vector.
  55. llvm::AllocaInst *LazyTempElemArrayAlloca = nullptr;
  56. // We'll allocate this lazily if we have a dynamic level 2 index (mat[0][i]),
  57. // so we can dynamically index the level 1 indices.
  58. llvm::AllocaInst *LazyTempElemIndicesArrayAlloca = nullptr;
  59. };
  60. } // namespace hlsl