2
0

DxilSimplify.cpp 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. //===-- DxilSimplify.cpp - Fold dxil intrinsics into constants -----===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. // Copyright (C) Microsoft Corporation. All rights reserved.
  9. //
  10. //===----------------------------------------------------------------------===//
  11. //
  12. //
  13. //===----------------------------------------------------------------------===//
  14. // Simplify dxil operations algebraically, e.g. fold mad(0, a, b) into b.
  15. #include "llvm/Analysis/InstructionSimplify.h"
  16. #include "llvm/IR/Constants.h"
  17. #include "llvm/IR/Function.h"
  18. #include "llvm/IR/Instruction.h"
  19. #include "llvm/IR/Module.h"
  20. #include "llvm/IR/IRBuilder.h"
  21. #include "dxc/DXIL/DxilModule.h"
  22. #include "dxc/DXIL/DxilOperations.h"
  23. #include "llvm/Analysis/DxilConstantFolding.h"
  24. #include "llvm/Analysis/DxilSimplify.h"
  25. using namespace llvm;
  26. using namespace hlsl;
  27. namespace {
  28. DXIL::OpCode GetOpcode(Value *opArg) {
  29. if (ConstantInt *ci = dyn_cast<ConstantInt>(opArg)) {
  30. uint64_t opcode = ci->getLimitedValue();
  31. if (opcode < static_cast<uint64_t>(OP::OpCode::NumOpCodes)) {
  32. return static_cast<OP::OpCode>(opcode);
  33. }
  34. }
  35. return DXIL::OpCode::NumOpCodes;
  36. }
  37. } // namespace
  38. namespace hlsl {
  39. bool CanSimplify(const llvm::Function *F) {
  40. // Only simplify dxil functions when we have a valid dxil module.
  41. if (!F->getParent()->HasDxilModule()) {
  42. assert(!OP::IsDxilOpFunc(F) && "dx.op function with no dxil module?");
  43. return false;
  44. }
  45. if (CanConstantFoldCallTo(F))
  46. return true;
  47. // Lookup opcode class in dxil module. Set default value to invalid class.
  48. OP::OpCodeClass opClass = OP::OpCodeClass::NumOpClasses;
  49. const bool found =
  50. F->getParent()->GetDxilModule().GetOP()->GetOpCodeClass(F, opClass);
  51. // Return true for those dxil operation classes we can simplify.
  52. if (found) {
  53. switch (opClass) {
  54. default:
  55. break;
  56. case OP::OpCodeClass::Tertiary:
  57. return true;
  58. }
  59. }
  60. return false;
  61. }
  62. /// \brief Given a function and set of arguments, see if we can fold the
  63. /// result as dxil operation.
  64. ///
  65. /// If this call could not be simplified returns null.
  66. Value *SimplifyDxilCall(llvm::Function *F, ArrayRef<Value *> Args,
  67. llvm::Instruction *I,
  68. bool MayInsert)
  69. {
  70. if (!F->getParent()->HasDxilModule()) {
  71. assert(!OP::IsDxilOpFunc(F) && "dx.op function with no dxil module?");
  72. return nullptr;
  73. }
  74. DxilModule &DM = F->getParent()->GetDxilModule();
  75. // Skip precise.
  76. if (DM.IsPrecise(I))
  77. return nullptr;
  78. // Lookup opcode class in dxil module. Set default value to invalid class.
  79. OP::OpCodeClass opClass = OP::OpCodeClass::NumOpClasses;
  80. const bool found = DM.GetOP()->GetOpCodeClass(F, opClass);
  81. if (!found)
  82. return nullptr;
  83. DXIL::OpCode opcode = GetOpcode(Args[0]);
  84. if (opcode == DXIL::OpCode::NumOpCodes)
  85. return nullptr;
  86. if (CanConstantFoldCallTo(F)) {
  87. bool bAllConstant = true;
  88. SmallVector<Constant *, 4> ConstantArgs;
  89. ConstantArgs.reserve(Args.size());
  90. for (Value *V : Args) {
  91. Constant *C = dyn_cast<Constant>(V);
  92. if (!C) {
  93. bAllConstant = false;
  94. break;
  95. }
  96. ConstantArgs.push_back(C);
  97. }
  98. if (bAllConstant)
  99. return hlsl::ConstantFoldScalarCall(F->getName(), F->getReturnType(),
  100. ConstantArgs);
  101. }
  102. switch (opcode) {
  103. default:
  104. return nullptr;
  105. case DXIL::OpCode::FMad: {
  106. Value *op0 = Args[DXIL::OperandIndex::kTrinarySrc0OpIdx];
  107. Value *op2 = Args[DXIL::OperandIndex::kTrinarySrc2OpIdx];
  108. Constant *zero = ConstantFP::get(op0->getType(), 0);
  109. if (op0 == zero)
  110. return op2;
  111. Value *op1 = Args[DXIL::OperandIndex::kTrinarySrc1OpIdx];
  112. if (op1 == zero)
  113. return op2;
  114. if (MayInsert) {
  115. Constant *one = ConstantFP::get(op0->getType(), 1);
  116. if (op0 == one) {
  117. IRBuilder<> Builder(I);
  118. llvm::FastMathFlags FMF;
  119. FMF.setUnsafeAlgebraHLSL();
  120. Builder.SetFastMathFlags(FMF);
  121. return Builder.CreateFAdd(op1, op2);
  122. }
  123. if (op1 == one) {
  124. IRBuilder<> Builder(I);
  125. llvm::FastMathFlags FMF;
  126. FMF.setUnsafeAlgebraHLSL();
  127. Builder.SetFastMathFlags(FMF);
  128. return Builder.CreateFAdd(op0, op2);
  129. }
  130. }
  131. return nullptr;
  132. } break;
  133. case DXIL::OpCode::IMad:
  134. case DXIL::OpCode::UMad: {
  135. Value *op0 = Args[DXIL::OperandIndex::kTrinarySrc0OpIdx];
  136. Value *op2 = Args[DXIL::OperandIndex::kTrinarySrc2OpIdx];
  137. Constant *zero = ConstantInt::get(op0->getType(), 0);
  138. if (op0 == zero)
  139. return op2;
  140. Value *op1 = Args[DXIL::OperandIndex::kTrinarySrc1OpIdx];
  141. if (op1 == zero)
  142. return op2;
  143. if (MayInsert) {
  144. Constant *one = ConstantInt::get(op0->getType(), 1);
  145. if (op0 == one) {
  146. IRBuilder<> Builder(I);
  147. return Builder.CreateAdd(op1, op2);
  148. }
  149. if (op1 == one) {
  150. IRBuilder<> Builder(I);
  151. return Builder.CreateAdd(op0, op2);
  152. }
  153. }
  154. return nullptr;
  155. } break;
  156. }
  157. }
  158. } // namespace hlsl