DxilSimplify.cpp 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. //===-- DxilSimplify.cpp - Fold dxil intrinsics into constants -----===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. // Copyright (C) Microsoft Corporation. All rights reserved.
  9. //
  10. //===----------------------------------------------------------------------===//
  11. //
  12. //
  13. //===----------------------------------------------------------------------===//
  14. // Simplify dxil ops algebraically, e.g. mad(0, a, b) -> b.
  15. #include "llvm/Analysis/InstructionSimplify.h"
  16. #include "llvm/IR/Constants.h"
  17. #include "llvm/IR/Function.h"
  18. #include "llvm/IR/Instruction.h"
  19. #include "llvm/IR/Module.h"
  20. #include "llvm/IR/IRBuilder.h"
  21. #include "dxc/HLSL/DxilModule.h"
  22. #include "dxc/HLSL/DxilOperations.h"
  23. #include "llvm/Analysis/DxilConstantFolding.h"
  24. #include "llvm/Analysis/DxilSimplify.h"
  25. using namespace llvm;
  26. using namespace hlsl;
  27. namespace {
  28. DXIL::OpCode GetOpcode(Value *opArg) {
  29. if (ConstantInt *ci = dyn_cast<ConstantInt>(opArg)) {
  30. uint64_t opcode = ci->getLimitedValue();
  31. if (opcode < static_cast<uint64_t>(OP::OpCode::NumOpCodes)) {
  32. return static_cast<OP::OpCode>(opcode);
  33. }
  34. }
  35. return DXIL::OpCode::NumOpCodes;
  36. }
  37. } // namespace
  38. namespace hlsl {
  39. bool CanSimplify(const llvm::Function *F) {
  40. // Only simplify dxil functions when we have a valid dxil module.
  41. if (!F->getParent()->HasDxilModule()) {
  42. assert(!OP::IsDxilOpFunc(F) && "dx.op function with no dxil module?");
  43. return false;
  44. }
  45. // Lookup opcode class in dxil module. Set default value to invalid class.
  46. OP::OpCodeClass opClass = OP::OpCodeClass::NumOpClasses;
  47. const bool found =
  48. F->getParent()->GetDxilModule().GetOP()->GetOpCodeClass(F, opClass);
  49. // Return true for those dxil operation classes we can simplify.
  50. if (found) {
  51. switch (opClass) {
  52. default:
  53. break;
  54. case OP::OpCodeClass::Tertiary:
  55. return true;
  56. }
  57. }
  58. return false;
  59. }
  60. /// \brief Given a function and set of arguments, see if we can fold the
  61. /// result as dxil operation.
  62. ///
  63. /// If this call could not be simplified returns null.
  64. Value *SimplifyDxilCall(llvm::Function *F, ArrayRef<Value *> Args,
  65. llvm::Instruction *I) {
  66. if (!F->getParent()->HasDxilModule()) {
  67. assert(!OP::IsDxilOpFunc(F) && "dx.op function with no dxil module?");
  68. return nullptr;
  69. }
  70. DxilModule &DM = F->getParent()->GetDxilModule();
  71. // Skip precise.
  72. if (DM.IsPrecise(I))
  73. return nullptr;
  74. // Lookup opcode class in dxil module. Set default value to invalid class.
  75. OP::OpCodeClass opClass = OP::OpCodeClass::NumOpClasses;
  76. const bool found = DM.GetOP()->GetOpCodeClass(F, opClass);
  77. if (!found)
  78. return nullptr;
  79. DXIL::OpCode opcode = GetOpcode(Args[0]);
  80. if (opcode == DXIL::OpCode::NumOpCodes)
  81. return nullptr;
  82. if (CanConstantFoldCallTo(F)) {
  83. bool bAllConstant = true;
  84. SmallVector<Constant *, 4> ConstantArgs;
  85. ConstantArgs.reserve(Args.size());
  86. for (Value *V : Args) {
  87. Constant *C = dyn_cast<Constant>(V);
  88. if (!C) {
  89. bAllConstant = false;
  90. break;
  91. }
  92. ConstantArgs.push_back(C);
  93. }
  94. if (bAllConstant)
  95. return hlsl::ConstantFoldScalarCall(F->getName(), F->getReturnType(),
  96. ConstantArgs);
  97. }
  98. switch (opcode) {
  99. default:
  100. return nullptr;
  101. case DXIL::OpCode::FMad: {
  102. Value *op0 = Args[DXIL::OperandIndex::kTrinarySrc0OpIdx];
  103. Value *op2 = Args[DXIL::OperandIndex::kTrinarySrc2OpIdx];
  104. Constant *zero = ConstantFP::get(op0->getType(), 0);
  105. if (op0 == zero)
  106. return op2;
  107. Value *op1 = Args[DXIL::OperandIndex::kTrinarySrc1OpIdx];
  108. if (op1 == zero)
  109. return op2;
  110. Constant *one = ConstantFP::get(op0->getType(), 1);
  111. if (op0 == one) {
  112. IRBuilder<> Builder(I);
  113. llvm::FastMathFlags FMF;
  114. FMF.setUnsafeAlgebraHLSL();
  115. Builder.SetFastMathFlags(FMF);
  116. return Builder.CreateFAdd(op1, op2);
  117. }
  118. if (op1 == one) {
  119. IRBuilder<> Builder(I);
  120. llvm::FastMathFlags FMF;
  121. FMF.setUnsafeAlgebraHLSL();
  122. Builder.SetFastMathFlags(FMF);
  123. return Builder.CreateFAdd(op0, op2);
  124. }
  125. return nullptr;
  126. } break;
  127. case DXIL::OpCode::IMad:
  128. case DXIL::OpCode::UMad: {
  129. Value *op0 = Args[DXIL::OperandIndex::kTrinarySrc0OpIdx];
  130. Value *op2 = Args[DXIL::OperandIndex::kTrinarySrc2OpIdx];
  131. Constant *zero = ConstantInt::get(op0->getType(), 0);
  132. if (op0 == zero)
  133. return op2;
  134. Value *op1 = Args[DXIL::OperandIndex::kTrinarySrc1OpIdx];
  135. if (op1 == zero)
  136. return op2;
  137. Constant *one = ConstantInt::get(op0->getType(), 1);
  138. if (op0 == one) {
  139. IRBuilder<> Builder(I);
  140. return Builder.CreateAdd(op1, op2);
  141. }
  142. if (op1 == one) {
  143. IRBuilder<> Builder(I);
  144. return Builder.CreateAdd(op0, op2);
  145. }
  146. return nullptr;
  147. } break;
  148. }
  149. }
  150. } // namespace hlsl