DxilConstantFolding.cpp 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602
  1. //===-- DxilConstantFolding.cpp - Fold dxil intrinsics into constants -----===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. // Copyright (C) Microsoft Corporation. All rights reserved.
  9. //
  10. //===----------------------------------------------------------------------===//
  11. //
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "llvm/Analysis/DxilConstantFolding.h"
  15. #include "llvm/Analysis/ConstantFolding.h"
  16. #include "llvm/ADT/SmallPtrSet.h"
  17. #include "llvm/ADT/SmallVector.h"
  18. #include "llvm/ADT/StringMap.h"
  19. #include "llvm/Analysis/TargetLibraryInfo.h"
  20. #include "llvm/Analysis/ValueTracking.h"
  21. #include "llvm/Config/config.h"
  22. #include "llvm/IR/Constants.h"
  23. #include "llvm/IR/DataLayout.h"
  24. #include "llvm/IR/DerivedTypes.h"
  25. #include "llvm/IR/Function.h"
  26. #include "llvm/IR/GetElementPtrTypeIterator.h"
  27. #include "llvm/IR/GlobalVariable.h"
  28. #include "llvm/IR/Instructions.h"
  29. #include "llvm/IR/Intrinsics.h"
  30. #include "llvm/IR/Operator.h"
  31. #include "llvm/Support/ErrorHandling.h"
  32. #include "llvm/Support/MathExtras.h"
  33. #include <cerrno>
  34. #include <cmath>
  35. #include <algorithm>
  36. #include <functional>
  37. #include "dxc/DXIL/DXIL.h"
  38. #include "dxc/HLSL/DxilConvergentName.h"
  39. using namespace llvm;
  40. using namespace hlsl;
  41. namespace {
  42. bool IsConvergentMarker(const Function *F) {
  43. return F->getName().startswith(kConvergentFunctionPrefix);
  44. }
  45. bool IsConvergentMarker(const char *Name) {
  46. StringRef RName = Name;
  47. return RName.startswith(kConvergentFunctionPrefix);
  48. }
  49. } // namespace
  50. // Check if the given function is a dxil intrinsic and if so extract the
  51. // opcode for the instrinsic being called.
  52. static bool GetDxilOpcode(StringRef Name, ArrayRef<Constant *> Operands, OP::OpCode &out) {
  53. if (!OP::IsDxilOpFuncName(Name))
  54. return false;
  55. if (!Operands.size())
  56. return false;
  57. if (ConstantInt *ci = dyn_cast<ConstantInt>(Operands[0])) {
  58. uint64_t opcode = ci->getLimitedValue();
  59. if (opcode < static_cast<uint64_t>(OP::OpCode::NumOpCodes)) {
  60. out = static_cast<OP::OpCode>(opcode);
  61. return true;
  62. }
  63. }
  64. return false;
  65. }
  66. // Typedefs for passing function pointers to evaluate float constants.
  67. typedef double(__cdecl *NativeFPUnaryOp)(double);
  68. typedef std::function<APFloat::opStatus(APFloat&)> APFloatUnaryOp;
  69. /// Currently APFloat versions of these functions do not exist, so we use
  70. /// the host native double versions. Float versions are not called
  71. /// directly but for all these it is true (float)(f((double)arg)) ==
  72. /// f(arg). Long double not supported yet.
  73. ///
  74. /// Calls out to the llvm constant folding function to do the real work.
  75. static Constant *DxilConstantFoldFP(NativeFPUnaryOp NativeFP, ConstantFP *C, Type *Ty) {
  76. double V = llvm::getValueAsDouble(C);
  77. return llvm::ConstantFoldFP(NativeFP, V, Ty);
  78. }
  79. // Constant fold using the provided function on APFloats.
  80. static Constant *HLSLConstantFoldAPFloat(APFloatUnaryOp NativeFP, ConstantFP *C, Type *Ty) {
  81. APFloat APF = C->getValueAPF();
  82. if (NativeFP(APF) != APFloat::opStatus::opOK)
  83. return nullptr;
  84. return ConstantFP::get(Ty->getContext(), APF);
  85. }
  86. // Constant fold a round dxil intrinsic.
  87. static Constant *HLSLConstantFoldRound(APFloat::roundingMode roundingMode, ConstantFP *C, Type *Ty) {
  88. APFloatUnaryOp f = [roundingMode](APFloat &x) { return x.roundToIntegral(roundingMode); };
  89. return HLSLConstantFoldAPFloat(f, C, Ty);
  90. }
  91. namespace {
  92. // Wrapper for call operands that "shifts past" the hlsl intrinsic opcode.
  93. // Also provides accessors that dyn_cast the operand to a constant type.
  94. class DxilIntrinsicOperands {
  95. public:
  96. DxilIntrinsicOperands(ArrayRef<Constant *> RawCallOperands) : m_RawCallOperands(RawCallOperands) {}
  97. Constant * const &operator[](size_t index) const {
  98. return m_RawCallOperands[index + 1];
  99. }
  100. ConstantInt *GetConstantInt(size_t index) const {
  101. return dyn_cast<ConstantInt>(this->operator[](index));
  102. }
  103. ConstantFP *GetConstantFloat(size_t index) const {
  104. return dyn_cast<ConstantFP>(this->operator[](index));
  105. }
  106. size_t Size() const {
  107. return m_RawCallOperands.size() - 1;
  108. }
  109. private:
  110. ArrayRef<Constant *> m_RawCallOperands;
  111. };
  112. }
  113. /// We only fold functions with finite arguments. Folding NaN and inf is
  114. /// likely to be aborted with an exception anyway, and some host libms
  115. /// have known errors raising exceptions.
  116. static bool IsFinite(ConstantFP *C) {
  117. if (C->getValueAPF().isNaN() || C->getValueAPF().isInfinity())
  118. return false;
  119. return true;
  120. }
  121. // Check that the op is non-null and finite.
  122. static bool IsValidOp(ConstantFP *C) {
  123. if (!C || !IsFinite(C))
  124. return false;
  125. return true;
  126. }
  127. // Check that all ops are valid.
  128. static bool AllValidOps(ArrayRef<ConstantFP *> Ops) {
  129. return std::all_of(Ops.begin(), Ops.end(), IsValidOp);
  130. }
  131. // Constant fold unary floating point intrinsics.
  132. static Constant *ConstantFoldUnaryFPIntrinsic(OP::OpCode opcode, Type *Ty, ConstantFP *Op) {
  133. switch (opcode) {
  134. default: break;
  135. case OP::OpCode::FAbs: return DxilConstantFoldFP(fabs, Op, Ty);
  136. case OP::OpCode::Saturate: {
  137. NativeFPUnaryOp f = [](double x) { return std::max(std::min(x, 1.0), 0.0); };
  138. return DxilConstantFoldFP(f, Op, Ty);
  139. }
  140. case OP::OpCode::Cos: return DxilConstantFoldFP(cos, Op, Ty);
  141. case OP::OpCode::Sin: return DxilConstantFoldFP(sin, Op, Ty);
  142. case OP::OpCode::Tan: return DxilConstantFoldFP(tan, Op, Ty);
  143. case OP::OpCode::Acos: return DxilConstantFoldFP(acos, Op, Ty);
  144. case OP::OpCode::Asin: return DxilConstantFoldFP(asin, Op, Ty);
  145. case OP::OpCode::Atan: return DxilConstantFoldFP(atan, Op, Ty);
  146. case OP::OpCode::Hcos: return DxilConstantFoldFP(cosh, Op, Ty);
  147. case OP::OpCode::Hsin: return DxilConstantFoldFP(sinh, Op, Ty);
  148. case OP::OpCode::Htan: return DxilConstantFoldFP(tanh, Op, Ty);
  149. case OP::OpCode::Exp: return DxilConstantFoldFP(exp2, Op, Ty);
  150. case OP::OpCode::Frc: {
  151. NativeFPUnaryOp f = [](double x) { double unused; return fabs(modf(x, &unused)); };
  152. return DxilConstantFoldFP(f, Op, Ty);
  153. }
  154. case OP::OpCode::Log: return DxilConstantFoldFP(log2, Op, Ty);
  155. case OP::OpCode::Sqrt: return DxilConstantFoldFP(sqrt, Op, Ty);
  156. case OP::OpCode::Rsqrt: {
  157. NativeFPUnaryOp f = [](double x) { return 1.0 / sqrt(x); };
  158. return DxilConstantFoldFP(f, Op, Ty);
  159. }
  160. case OP::OpCode::Round_ne: return HLSLConstantFoldRound(APFloat::roundingMode::rmNearestTiesToEven, Op, Ty);
  161. case OP::OpCode::Round_ni: return HLSLConstantFoldRound(APFloat::roundingMode::rmTowardNegative, Op, Ty);
  162. case OP::OpCode::Round_pi: return HLSLConstantFoldRound(APFloat::roundingMode::rmTowardPositive, Op, Ty);
  163. case OP::OpCode::Round_z: return HLSLConstantFoldRound(APFloat::roundingMode::rmTowardZero, Op, Ty);
  164. }
  165. return nullptr;
  166. }
  167. // Constant fold binary floating point intrinsics.
  168. static Constant *ConstantFoldBinaryFPIntrinsic(OP::OpCode opcode, Type *Ty, ConstantFP *Op1, ConstantFP *Op2) {
  169. const APFloat &C1 = Op1->getValueAPF();
  170. const APFloat &C2 = Op2->getValueAPF();
  171. switch (opcode) {
  172. default: break;
  173. case OP::OpCode::FMax: return ConstantFP::get(Ty->getContext(), maxnum(C1, C2));
  174. case OP::OpCode::FMin: return ConstantFP::get(Ty->getContext(), minnum(C1, C2));
  175. }
  176. return nullptr;
  177. }
  178. // Constant fold ternary floating point intrinsics.
  179. static Constant *ConstantFoldTernaryFPIntrinsic(OP::OpCode opcode, Type *Ty, ConstantFP *Op1, ConstantFP *Op2, ConstantFP *Op3) {
  180. const APFloat &C1 = Op1->getValueAPF();
  181. const APFloat &C2 = Op2->getValueAPF();
  182. const APFloat &C3 = Op3->getValueAPF();
  183. APFloat::roundingMode roundingMode = APFloat::rmNearestTiesToEven;
  184. switch (opcode) {
  185. default: break;
  186. case OP::OpCode::FMad: {
  187. APFloat result(C1);
  188. result.multiply(C2, roundingMode);
  189. result.add(C3, roundingMode);
  190. return ConstantFP::get(Ty->getContext(), result);
  191. }
  192. case OP::OpCode::Fma: {
  193. APFloat result(C1);
  194. result.fusedMultiplyAdd(C2, C3, roundingMode);
  195. return ConstantFP::get(Ty->getContext(), result);
  196. }
  197. }
  198. return nullptr;
  199. }
  200. // Compute dot product for arbitrary sized vectors.
  201. static Constant *ComputeDot(Type *Ty, ArrayRef<ConstantFP *> A, ArrayRef<ConstantFP *> B) {
  202. if (A.size() != B.size() || !A.size()) {
  203. assert(false && "invalid call to compute dot");
  204. return nullptr;
  205. }
  206. if (!AllValidOps(A) || !AllValidOps(B))
  207. return nullptr;
  208. APFloat::roundingMode roundingMode = APFloat::roundingMode::rmNearestTiesToEven;
  209. APFloat sum = APFloat::getZero(A[0]->getValueAPF().getSemantics());
  210. for (int i = 0, e = A.size(); i != e; ++i) {
  211. APFloat val(A[i]->getValueAPF());
  212. val.multiply(B[i]->getValueAPF(), roundingMode);
  213. sum.add(val, roundingMode);
  214. }
  215. return ConstantFP::get(Ty->getContext(), sum);
  216. }
  217. // Constant folding for dot2, dot3, and dot4.
  218. static Constant *ConstantFoldDot(OP::OpCode opcode, Type *Ty, const DxilIntrinsicOperands &operands) {
  219. switch (opcode) {
  220. default: break;
  221. case OP::OpCode::Dot2: {
  222. ConstantFP *Ax = operands.GetConstantFloat(0);
  223. ConstantFP *Ay = operands.GetConstantFloat(1);
  224. ConstantFP *Bx = operands.GetConstantFloat(2);
  225. ConstantFP *By = operands.GetConstantFloat(3);
  226. return ComputeDot(Ty, { Ax, Ay }, { Bx, By });
  227. }
  228. case OP::OpCode::Dot3: {
  229. ConstantFP *Ax = operands.GetConstantFloat(0);
  230. ConstantFP *Ay = operands.GetConstantFloat(1);
  231. ConstantFP *Az = operands.GetConstantFloat(2);
  232. ConstantFP *Bx = operands.GetConstantFloat(3);
  233. ConstantFP *By = operands.GetConstantFloat(4);
  234. ConstantFP *Bz = operands.GetConstantFloat(5);
  235. return ComputeDot(Ty, { Ax, Ay, Az }, { Bx, By, Bz });
  236. }
  237. case OP::OpCode::Dot4: {
  238. ConstantFP *Ax = operands.GetConstantFloat(0);
  239. ConstantFP *Ay = operands.GetConstantFloat(1);
  240. ConstantFP *Az = operands.GetConstantFloat(2);
  241. ConstantFP *Aw = operands.GetConstantFloat(3);
  242. ConstantFP *Bx = operands.GetConstantFloat(4);
  243. ConstantFP *By = operands.GetConstantFloat(5);
  244. ConstantFP *Bz = operands.GetConstantFloat(6);
  245. ConstantFP *Bw = operands.GetConstantFloat(7);
  246. return ComputeDot(Ty, { Ax, Ay, Az, Aw }, { Bx, By, Bz, Bw });
  247. }
  248. }
  249. return nullptr;
  250. }
  251. // Constant fold a Bfrev dxil intrinsic.
  252. static Constant *HLSLConstantFoldBfrev(ConstantInt *C, Type *Ty) {
  253. APInt API = C->getValue();
  254. uint64_t result = 0;
  255. if (Ty == Type::getInt32Ty(Ty->getContext())) {
  256. uint32_t val = static_cast<uint32_t>(API.getLimitedValue());
  257. result = llvm::reverseBits(val);
  258. }
  259. else if (Ty == Type::getInt16Ty(Ty->getContext())) {
  260. uint16_t val = static_cast<uint16_t>(API.getLimitedValue());
  261. result = llvm::reverseBits(val);
  262. }
  263. else if (Ty == Type::getInt64Ty(Ty->getContext())) {
  264. uint64_t val = static_cast<uint64_t>(API.getLimitedValue());
  265. result = llvm::reverseBits(val);
  266. }
  267. else {
  268. return nullptr;
  269. }
  270. return ConstantInt::get(Ty, result);
  271. }
  272. // Handle special case for findfirst* bit functions.
  273. // When the position is equal to the bitwidth the value was not found
  274. // and we need to return a result of -1.
  275. static Constant *HLSLConstantFoldFindBit(Type *Ty, unsigned position, unsigned bitwidth) {
  276. if (position == bitwidth)
  277. return ConstantInt::get(Ty, APInt::getAllOnesValue(Ty->getScalarSizeInBits()));
  278. return ConstantInt::get(Ty, position);
  279. }
  280. // Constant fold unary integer intrinsics.
  281. static Constant *ConstantFoldUnaryIntIntrinsic(OP::OpCode opcode, Type *Ty, ConstantInt *Op) {
  282. APInt API = Op->getValue();
  283. switch (opcode) {
  284. default: break;
  285. case OP::OpCode::Bfrev: return HLSLConstantFoldBfrev(Op, Ty);
  286. case OP::OpCode::Countbits: return ConstantInt::get(Ty, API.countPopulation());
  287. case OP::OpCode::FirstbitLo: return HLSLConstantFoldFindBit(Ty, API.countTrailingZeros(), API.getBitWidth());
  288. case OP::OpCode::FirstbitHi: return HLSLConstantFoldFindBit(Ty, API.countLeadingZeros(), API.getBitWidth());
  289. case OP::OpCode::FirstbitSHi: {
  290. if (API.isNegative())
  291. return HLSLConstantFoldFindBit(Ty, API.countLeadingOnes(), API.getBitWidth());
  292. else
  293. return HLSLConstantFoldFindBit(Ty, API.countLeadingZeros(), API.getBitWidth());
  294. }
  295. }
  296. return nullptr;
  297. }
  298. // Constant fold binary integer intrinsics.
  299. static Constant *ConstantFoldBinaryIntIntrinsic(OP::OpCode opcode, Type *Ty, ConstantInt *Op1, ConstantInt *Op2) {
  300. APInt C1 = Op1->getValue();
  301. APInt C2 = Op2->getValue();
  302. switch (opcode) {
  303. default: break;
  304. case OP::OpCode::IMin: {
  305. APInt minVal = C1.slt(C2) ? C1 : C2;
  306. return ConstantInt::get(Ty, minVal);
  307. }
  308. case OP::OpCode::IMax: {
  309. APInt maxVal = C1.sgt(C2) ? C1 : C2;
  310. return ConstantInt::get(Ty, maxVal);
  311. }
  312. case OP::OpCode::UMin: {
  313. APInt minVal = C1.ult(C2) ? C1 : C2;
  314. return ConstantInt::get(Ty, minVal);
  315. }
  316. case OP::OpCode::UMax: {
  317. APInt maxVal = C1.ugt(C2) ? C1 : C2;
  318. return ConstantInt::get(Ty, maxVal);
  319. }
  320. }
  321. return nullptr;
  322. }
  323. // Compute bit field extract for ibfe and ubfe.
  324. // The comptuation for ibfe and ubfe is the same except for the right shift,
  325. // which is an arithemetic shift for ibfe and logical shift for ubfe.
  326. // ubfe: https://msdn.microsoft.com/en-us/library/windows/desktop/hh447243(v=vs.85).aspx
  327. // ibfe: https://msdn.microsoft.com/en-us/library/windows/desktop/hh447243(v=vs.85).aspx
  328. static Constant *ComputeBFE(Type *Ty, APInt width, APInt offset, APInt val, std::function<APInt(APInt, APInt)> shr) {
  329. const APInt bitwidth(width.getBitWidth(), width.getBitWidth());
  330. // Limit width and offset to the bitwidth of the value.
  331. width = width.And(bitwidth-1);
  332. offset = offset.And(bitwidth-1);
  333. if (width == 0) {
  334. return ConstantInt::get(Ty, 0);
  335. }
  336. else if ((width + offset).ult(bitwidth)) {
  337. APInt dest = val.shl(bitwidth - (width + offset));
  338. dest = shr(dest, bitwidth - width);
  339. return ConstantInt::get(Ty, dest);
  340. }
  341. else {
  342. APInt dest = shr(val, offset);
  343. return ConstantInt::get(Ty, dest);
  344. }
  345. }
  346. // Constant fold ternary integer intrinsic.
  347. static Constant *ConstantFoldTernaryIntIntrinsic(OP::OpCode opcode, Type *Ty, ConstantInt *Op1, ConstantInt *Op2, ConstantInt *Op3) {
  348. APInt C1 = Op1->getValue();
  349. APInt C2 = Op2->getValue();
  350. APInt C3 = Op3->getValue();
  351. switch (opcode) {
  352. default: break;
  353. case OP::OpCode::IMad:
  354. case OP::OpCode::UMad: {
  355. // Result is same for signed/unsigned since this is twos complement and we only
  356. // keep the lower half of the multiply.
  357. APInt result = C1 * C2 + C3;
  358. return ConstantInt::get(Ty, result);
  359. }
  360. case OP::OpCode::Ubfe: return ComputeBFE(Ty, C1, C2, C3, [](APInt val, APInt amt) {return val.lshr(amt); });
  361. case OP::OpCode::Ibfe: return ComputeBFE(Ty, C1, C2, C3, [](APInt val, APInt amt) {return val.ashr(amt); });
  362. }
  363. return nullptr;
  364. }
  365. // Constant fold quaternary integer intrinsic.
  366. //
  367. // Currently we only have one quaternary intrinsic: Bfi.
  368. // The Bfi computaion is described here:
  369. // https://msdn.microsoft.com/en-us/library/windows/desktop/hh446837(v=vs.85).aspx
  370. static Constant *ConstantFoldQuaternaryIntInstrinsic(OP::OpCode opcode, Type *Ty, ConstantInt *Op1, ConstantInt *Op2, ConstantInt *Op3, ConstantInt *Op4) {
  371. if (opcode != OP::OpCode::Bfi)
  372. return nullptr;
  373. APInt bitwidth(Op1->getValue().getBitWidth(), Op1->getValue().getBitWidth());
  374. APInt width = Op1->getValue().And(bitwidth-1);
  375. APInt offset = Op2->getValue().And(bitwidth-1);
  376. APInt src = Op3->getValue();
  377. APInt dst = Op4->getValue();
  378. APInt one(bitwidth.getBitWidth(), 1);
  379. APInt allOnes = APInt::getAllOnesValue(bitwidth.getBitWidth());
  380. // bitmask = (((1 << width)-1) << offset) & 0xffffffff
  381. // dest = ((src2 << offset) & bitmask) | (src3 & ~bitmask)
  382. APInt bitmask = (one.shl(width) - 1).shl(offset).And(allOnes);
  383. APInt result = (src.shl(offset).And(bitmask)).Or(dst.And(~bitmask));
  384. return ConstantInt::get(Ty, result);
  385. }
  386. // Top level function to constant fold floating point intrinsics.
  387. static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const DxilIntrinsicOperands &IntrinsicOperands) {
  388. if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
  389. return nullptr;
  390. OP::OpCodeClass opClass = OP::GetOpCodeClass(opcode);
  391. switch (opClass) {
  392. default: break;
  393. case OP::OpCodeClass::Unary: {
  394. assert(IntrinsicOperands.Size() == 1);
  395. ConstantFP *Op = IntrinsicOperands.GetConstantFloat(0);
  396. if (!IsValidOp(Op))
  397. return nullptr;
  398. return ConstantFoldUnaryFPIntrinsic(opcode, Ty, Op);
  399. }
  400. case OP::OpCodeClass::Binary: {
  401. assert(IntrinsicOperands.Size() == 2);
  402. ConstantFP *Op1 = IntrinsicOperands.GetConstantFloat(0);
  403. ConstantFP *Op2 = IntrinsicOperands.GetConstantFloat(1);
  404. if (!IsValidOp(Op1) || !IsValidOp(Op2))
  405. return nullptr;
  406. return ConstantFoldBinaryFPIntrinsic(opcode, Ty, Op1, Op2);
  407. }
  408. case OP::OpCodeClass::Tertiary: {
  409. assert(IntrinsicOperands.Size() == 3);
  410. ConstantFP *Op1 = IntrinsicOperands.GetConstantFloat(0);
  411. ConstantFP *Op2 = IntrinsicOperands.GetConstantFloat(1);
  412. ConstantFP *Op3 = IntrinsicOperands.GetConstantFloat(2);
  413. if (!IsValidOp(Op1) || !IsValidOp(Op2) || !IsValidOp(Op3))
  414. return nullptr;
  415. return ConstantFoldTernaryFPIntrinsic(opcode, Ty, Op1, Op2, Op3);
  416. }
  417. case OP::OpCodeClass::Dot2:
  418. case OP::OpCodeClass::Dot3:
  419. case OP::OpCodeClass::Dot4:
  420. return ConstantFoldDot(opcode, Ty, IntrinsicOperands);
  421. }
  422. return nullptr;
  423. }
  424. // Top level function to constant fold integer intrinsics.
  425. static Constant *ConstantFoldIntIntrinsic(OP::OpCode opcode, Type *Ty, const DxilIntrinsicOperands &IntrinsicOperands) {
  426. if (Ty->getScalarSizeInBits() > (sizeof(int64_t) * CHAR_BIT))
  427. return nullptr;
  428. OP::OpCodeClass opClass = OP::GetOpCodeClass(opcode);
  429. switch (opClass) {
  430. default: break;
  431. case OP::OpCodeClass::Unary:
  432. case OP::OpCodeClass::UnaryBits: {
  433. assert(IntrinsicOperands.Size() == 1);
  434. ConstantInt *Op = IntrinsicOperands.GetConstantInt(0);
  435. if (!Op)
  436. return nullptr;
  437. return ConstantFoldUnaryIntIntrinsic(opcode, Ty, Op);
  438. }
  439. case OP::OpCodeClass::Binary: {
  440. assert(IntrinsicOperands.Size() == 2);
  441. ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
  442. ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
  443. if (!Op1 || !Op2)
  444. return nullptr;
  445. return ConstantFoldBinaryIntIntrinsic(opcode, Ty, Op1, Op2);
  446. }
  447. case OP::OpCodeClass::Tertiary: {
  448. assert(IntrinsicOperands.Size() == 3);
  449. ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
  450. ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
  451. ConstantInt *Op3 = IntrinsicOperands.GetConstantInt(2);
  452. if (!Op1 || !Op2 || !Op3)
  453. return nullptr;
  454. return ConstantFoldTernaryIntIntrinsic(opcode, Ty, Op1, Op2, Op3);
  455. }
  456. case OP::OpCodeClass::Quaternary: {
  457. assert(IntrinsicOperands.Size() == 4);
  458. ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
  459. ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
  460. ConstantInt *Op3 = IntrinsicOperands.GetConstantInt(2);
  461. ConstantInt *Op4 = IntrinsicOperands.GetConstantInt(3);
  462. if (!Op1 || !Op2 || !Op3 || !Op4)
  463. return nullptr;
  464. return ConstantFoldQuaternaryIntInstrinsic(opcode, Ty, Op1, Op2, Op3, Op4);
  465. }
  466. case OP::OpCodeClass::IsHelperLane:
  467. return ConstantInt::get(Ty, (uint64_t)0);
  468. }
  469. return nullptr;
  470. }
  471. // External entry point to constant fold dxil intrinsics.
  472. // Called from the llvm constant folding routine.
  473. Constant *hlsl::ConstantFoldScalarCall(StringRef Name, Type *Ty, ArrayRef<Constant *> RawOperands) {
  474. OP::OpCode opcode;
  475. if (GetDxilOpcode(Name, RawOperands, opcode)) {
  476. DxilIntrinsicOperands IntrinsicOperands(RawOperands);
  477. if (Ty->isFloatingPointTy()) {
  478. return ConstantFoldFPIntrinsic(opcode, Ty, IntrinsicOperands);
  479. }
  480. else if (Ty->isIntegerTy()) {
  481. return ConstantFoldIntIntrinsic(opcode, Ty, IntrinsicOperands);
  482. }
  483. } else if (IsConvergentMarker(Name.data())) {
  484. assert(RawOperands.size() == 1);
  485. if (ConstantInt *C = dyn_cast<ConstantInt>(RawOperands[0]))
  486. return C;
  487. if (ConstantFP *C = dyn_cast<ConstantFP>(RawOperands[0]))
  488. return C;
  489. }
  490. return hlsl::ConstantFoldScalarCallExt(Name, Ty, RawOperands);
  491. }
  492. // External entry point to determine if we can constant fold calls to
  493. // the given function. We have to overestimate the set of functions because
  494. // we only have the function value here instead of the call. We need the
  495. // actual call to get the opcode for the intrinsic.
  496. bool hlsl::CanConstantFoldCallTo(const Function *F) {
  497. // Only constant fold dxil functions when we have a valid dxil module.
  498. if (!F->getParent()->HasDxilModule()) {
  499. assert(!OP::IsDxilOpFunc(F) && "dx.op function with no dxil module?");
  500. return false;
  501. }
  502. if (IsConvergentMarker(F))
  503. return true;
  504. // Lookup opcode class in dxil module. Set default value to invalid class.
  505. OP::OpCodeClass opClass = OP::OpCodeClass::NumOpClasses;
  506. const bool found = F->getParent()->GetDxilModule().GetOP()->GetOpCodeClass(F, opClass);
  507. // Return true for those dxil operation classes we can constant fold.
  508. if (found) {
  509. switch (opClass) {
  510. default: break;
  511. case OP::OpCodeClass::Unary:
  512. case OP::OpCodeClass::UnaryBits:
  513. case OP::OpCodeClass::Binary:
  514. case OP::OpCodeClass::Tertiary:
  515. case OP::OpCodeClass::Quaternary:
  516. case OP::OpCodeClass::Dot2:
  517. case OP::OpCodeClass::Dot3:
  518. case OP::OpCodeClass::Dot4:
  519. return true;
  520. case OP::OpCodeClass::IsHelperLane: {
  521. const hlsl::ShaderModel *pSM =
  522. F->getParent()->GetDxilModule().GetShaderModel();
  523. return !pSM->IsPS() && !pSM->IsLib();
  524. }
  525. }
  526. }
  527. return hlsl::CanConstantFoldCallToExt(F);
  528. }