DxilTargetTransformInfo.cpp 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. //===-- DxilTargetTransformInfo.cpp - DXIL specific TTI pass ----------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //
  10. // \file
  11. // This file implements a TargetTransformInfo analysis pass specific to
  12. // DXIL. Only isSourceOfDivergence is implemented, for use by DivergenceAnalysis.
  13. //
  14. //===----------------------------------------------------------------------===//
  15. #include "DxilTargetTransformInfo.h"
  16. #include "dxc/DXIL/DxilModule.h"
  17. #include "dxc/DXIL/DxilOperations.h"
  18. #include "llvm/CodeGen/BasicTTIImpl.h"
  19. using namespace llvm;
  20. using namespace hlsl;
  21. #define DEBUG_TYPE "DXILtti"
  22. // For BasicTTImpl
  23. cl::opt<unsigned>
  24. llvm::PartialUnrollingThreshold("partial-unrolling-threshold", cl::init(0),
  25. cl::desc("Threshold for partial unrolling"),
  26. cl::Hidden);
  27. DxilTTIImpl::DxilTTIImpl(const TargetMachine *TM, const Function &F,
  28. hlsl::DxilModule &DM, bool ThreadGroup)
  29. : BaseT(TM, F.getParent()->getDataLayout()), m_pHlslOP(DM.GetOP()),
  30. m_isThreadGroup(ThreadGroup) {}
  31. namespace {
  32. bool IsDxilOpSourceOfDivergence(const CallInst *CI, OP *hlslOP,
  33. bool ThreadGroup) {
  34. DXIL::OpCode opcode = hlslOP->GetDxilOpFuncCallInst(CI);
  35. switch (opcode) {
  36. case DXIL::OpCode::AtomicBinOp:
  37. case DXIL::OpCode::AtomicCompareExchange:
  38. case DXIL::OpCode::LoadInput:
  39. case DXIL::OpCode::BufferUpdateCounter:
  40. case DXIL::OpCode::CycleCounterLegacy:
  41. case DXIL::OpCode::DomainLocation:
  42. case DXIL::OpCode::Coverage:
  43. case DXIL::OpCode::EvalCentroid:
  44. case DXIL::OpCode::EvalSampleIndex:
  45. case DXIL::OpCode::EvalSnapped:
  46. case DXIL::OpCode::FlattenedThreadIdInGroup:
  47. case DXIL::OpCode::GSInstanceID:
  48. case DXIL::OpCode::InnerCoverage:
  49. case DXIL::OpCode::LoadOutputControlPoint:
  50. case DXIL::OpCode::LoadPatchConstant:
  51. case DXIL::OpCode::OutputControlPointID:
  52. case DXIL::OpCode::PrimitiveID:
  53. case DXIL::OpCode::RenderTargetGetSampleCount:
  54. case DXIL::OpCode::RenderTargetGetSamplePosition:
  55. case DXIL::OpCode::ThreadId:
  56. case DXIL::OpCode::ThreadIdInGroup:
  57. return true;
  58. case DXIL::OpCode::GroupId:
  59. return !ThreadGroup;
  60. default:
  61. return false;
  62. }
  63. }
  64. }
  65. ///
  66. /// \returns true if the result of the value could potentially be
  67. /// different across dispatch or thread group.
  68. bool DxilTTIImpl::isSourceOfDivergence(const Value *V) const {
  69. if (dyn_cast<Argument>(V))
  70. return true;
  71. // Atomics are divergent because they are executed sequentially: when an
  72. // atomic operation refers to the same address in each thread, then each
  73. // thread after the first sees the value written by the previous thread as
  74. // original value.
  75. if (isa<AtomicRMWInst>(V) || isa<AtomicCmpXchgInst>(V))
  76. return true;
  77. if (const CallInst *CI = dyn_cast<CallInst>(V)) {
  78. // Assume none dxil instrincis function calls are a source of divergence.
  79. if (!m_pHlslOP->IsDxilOpFuncCallInst(CI))
  80. return true;
  81. return IsDxilOpSourceOfDivergence(CI, m_pHlslOP, m_isThreadGroup);
  82. }
  83. return false;
  84. }