DxilConstantFolding.cpp 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575
  1. //===-- DxilConstantFolding.cpp - Fold dxil intrinsics into constants -----===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. // Copyright (C) Microsoft Corporation. All rights reserved.
  9. //
  10. //===----------------------------------------------------------------------===//
  11. //
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "llvm/Analysis/DxilConstantFolding.h"
  15. #include "llvm/Analysis/ConstantFolding.h"
  16. #include "llvm/ADT/SmallPtrSet.h"
  17. #include "llvm/ADT/SmallVector.h"
  18. #include "llvm/ADT/StringMap.h"
  19. #include "llvm/Analysis/TargetLibraryInfo.h"
  20. #include "llvm/Analysis/ValueTracking.h"
  21. #include "llvm/Config/config.h"
  22. #include "llvm/IR/Constants.h"
  23. #include "llvm/IR/DataLayout.h"
  24. #include "llvm/IR/DerivedTypes.h"
  25. #include "llvm/IR/Function.h"
  26. #include "llvm/IR/GetElementPtrTypeIterator.h"
  27. #include "llvm/IR/GlobalVariable.h"
  28. #include "llvm/IR/Instructions.h"
  29. #include "llvm/IR/Intrinsics.h"
  30. #include "llvm/IR/Operator.h"
  31. #include "llvm/Support/ErrorHandling.h"
  32. #include "llvm/Support/MathExtras.h"
  33. #include <cerrno>
  34. #include <cmath>
  35. #include <algorithm>
  36. #include <functional>
  37. #include "dxc/HLSL/DXIL.h"
  38. using namespace llvm;
  39. using namespace hlsl;
  40. // Check if the given function is a dxil intrinsic and if so extract the
  41. // opcode for the instrinsic being called.
  42. static bool GetDxilOpcode(StringRef Name, ArrayRef<Constant *> Operands, OP::OpCode &out) {
  43. if (!OP::IsDxilOpFuncName(Name))
  44. return false;
  45. if (!Operands.size())
  46. return false;
  47. if (ConstantInt *ci = dyn_cast<ConstantInt>(Operands[0])) {
  48. uint64_t opcode = ci->getLimitedValue();
  49. if (opcode < static_cast<uint64_t>(OP::OpCode::NumOpCodes)) {
  50. out = static_cast<OP::OpCode>(opcode);
  51. return true;
  52. }
  53. }
  54. return false;
  55. }
  56. // Typedefs for passing function pointers to evaluate float constants.
  57. typedef double(__cdecl *NativeFPUnaryOp)(double);
  58. typedef std::function<APFloat::opStatus(APFloat&)> APFloatUnaryOp;
  59. /// Currently APFloat versions of these functions do not exist, so we use
  60. /// the host native double versions. Float versions are not called
  61. /// directly but for all these it is true (float)(f((double)arg)) ==
  62. /// f(arg). Long double not supported yet.
  63. ///
  64. /// Calls out to the llvm constant folding function to do the real work.
  65. static Constant *DxilConstantFoldFP(NativeFPUnaryOp NativeFP, ConstantFP *C, Type *Ty) {
  66. double V = llvm::getValueAsDouble(C);
  67. return llvm::ConstantFoldFP(NativeFP, V, Ty);
  68. }
  69. // Constant fold using the provided function on APFloats.
  70. static Constant *HLSLConstantFoldAPFloat(APFloatUnaryOp NativeFP, ConstantFP *C, Type *Ty) {
  71. APFloat APF = C->getValueAPF();
  72. if (NativeFP(APF) != APFloat::opStatus::opOK)
  73. return nullptr;
  74. return ConstantFP::get(Ty->getContext(), APF);
  75. }
  76. // Constant fold a round dxil intrinsic.
  77. static Constant *HLSLConstantFoldRound(APFloat::roundingMode roundingMode, ConstantFP *C, Type *Ty) {
  78. APFloatUnaryOp f = [roundingMode](APFloat &x) { return x.roundToIntegral(roundingMode); };
  79. return HLSLConstantFoldAPFloat(f, C, Ty);
  80. }
  81. namespace {
  82. // Wrapper for call operands that "shifts past" the hlsl intrinsic opcode.
  83. // Also provides accessors that dyn_cast the operand to a constant type.
  84. class DxilIntrinsicOperands {
  85. public:
  86. DxilIntrinsicOperands(ArrayRef<Constant *> RawCallOperands) : m_RawCallOperands(RawCallOperands) {}
  87. Constant * const &operator[](size_t index) const {
  88. return m_RawCallOperands[index + 1];
  89. }
  90. ConstantInt *GetConstantInt(size_t index) const {
  91. return dyn_cast<ConstantInt>(this->operator[](index));
  92. }
  93. ConstantFP *GetConstantFloat(size_t index) const {
  94. return dyn_cast<ConstantFP>(this->operator[](index));
  95. }
  96. size_t Size() const {
  97. return m_RawCallOperands.size() - 1;
  98. }
  99. private:
  100. ArrayRef<Constant *> m_RawCallOperands;
  101. };
  102. }
  103. /// We only fold functions with finite arguments. Folding NaN and inf is
  104. /// likely to be aborted with an exception anyway, and some host libms
  105. /// have known errors raising exceptions.
  106. static bool IsFinite(ConstantFP *C) {
  107. if (C->getValueAPF().isNaN() || C->getValueAPF().isInfinity())
  108. return false;
  109. return true;
  110. }
  111. // Check that the op is non-null and finite.
  112. static bool IsValidOp(ConstantFP *C) {
  113. if (!C || !IsFinite(C))
  114. return false;
  115. return true;
  116. }
  117. // Check that all ops are valid.
  118. static bool AllValidOps(ArrayRef<ConstantFP *> Ops) {
  119. return std::all_of(Ops.begin(), Ops.end(), IsValidOp);
  120. }
  121. // Constant fold unary floating point intrinsics.
  122. static Constant *ConstantFoldUnaryFPIntrinsic(OP::OpCode opcode, Type *Ty, ConstantFP *Op) {
  123. switch (opcode) {
  124. default: break;
  125. case OP::OpCode::FAbs: return DxilConstantFoldFP(fabs, Op, Ty);
  126. case OP::OpCode::Saturate: {
  127. NativeFPUnaryOp f = [](double x) { return std::max(std::min(x, 1.0), 0.0); };
  128. return DxilConstantFoldFP(f, Op, Ty);
  129. }
  130. case OP::OpCode::Cos: return DxilConstantFoldFP(cos, Op, Ty);
  131. case OP::OpCode::Sin: return DxilConstantFoldFP(sin, Op, Ty);
  132. case OP::OpCode::Tan: return DxilConstantFoldFP(tan, Op, Ty);
  133. case OP::OpCode::Acos: return DxilConstantFoldFP(acos, Op, Ty);
  134. case OP::OpCode::Asin: return DxilConstantFoldFP(asin, Op, Ty);
  135. case OP::OpCode::Atan: return DxilConstantFoldFP(atan, Op, Ty);
  136. case OP::OpCode::Hcos: return DxilConstantFoldFP(cosh, Op, Ty);
  137. case OP::OpCode::Hsin: return DxilConstantFoldFP(sinh, Op, Ty);
  138. case OP::OpCode::Htan: return DxilConstantFoldFP(tanh, Op, Ty);
  139. case OP::OpCode::Exp: return DxilConstantFoldFP(exp2, Op, Ty);
  140. case OP::OpCode::Frc: {
  141. NativeFPUnaryOp f = [](double x) { double unused; return fabs(modf(x, &unused)); };
  142. return DxilConstantFoldFP(f, Op, Ty);
  143. }
  144. case OP::OpCode::Log: return DxilConstantFoldFP(log2, Op, Ty);
  145. case OP::OpCode::Sqrt: return DxilConstantFoldFP(sqrt, Op, Ty);
  146. case OP::OpCode::Rsqrt: {
  147. NativeFPUnaryOp f = [](double x) { return 1.0 / sqrt(x); };
  148. return DxilConstantFoldFP(f, Op, Ty);
  149. }
  150. case OP::OpCode::Round_ne: return HLSLConstantFoldRound(APFloat::roundingMode::rmNearestTiesToEven, Op, Ty);
  151. case OP::OpCode::Round_ni: return HLSLConstantFoldRound(APFloat::roundingMode::rmTowardNegative, Op, Ty);
  152. case OP::OpCode::Round_pi: return HLSLConstantFoldRound(APFloat::roundingMode::rmTowardPositive, Op, Ty);
  153. case OP::OpCode::Round_z: return HLSLConstantFoldRound(APFloat::roundingMode::rmTowardZero, Op, Ty);
  154. }
  155. return nullptr;
  156. }
  157. // Constant fold binary floating point intrinsics.
  158. static Constant *ConstantFoldBinaryFPIntrinsic(OP::OpCode opcode, Type *Ty, ConstantFP *Op1, ConstantFP *Op2) {
  159. const APFloat &C1 = Op1->getValueAPF();
  160. const APFloat &C2 = Op2->getValueAPF();
  161. switch (opcode) {
  162. default: break;
  163. case OP::OpCode::FMax: return ConstantFP::get(Ty->getContext(), maxnum(C1, C2));
  164. case OP::OpCode::FMin: return ConstantFP::get(Ty->getContext(), minnum(C1, C2));
  165. }
  166. return nullptr;
  167. }
  168. // Constant fold ternary floating point intrinsics.
  169. static Constant *ConstantFoldTernaryFPIntrinsic(OP::OpCode opcode, Type *Ty, ConstantFP *Op1, ConstantFP *Op2, ConstantFP *Op3) {
  170. const APFloat &C1 = Op1->getValueAPF();
  171. const APFloat &C2 = Op2->getValueAPF();
  172. const APFloat &C3 = Op3->getValueAPF();
  173. APFloat::roundingMode roundingMode = APFloat::rmNearestTiesToEven;
  174. switch (opcode) {
  175. default: break;
  176. case OP::OpCode::FMad: {
  177. APFloat result(C1);
  178. result.multiply(C2, roundingMode);
  179. result.add(C3, roundingMode);
  180. return ConstantFP::get(Ty->getContext(), result);
  181. }
  182. case OP::OpCode::Fma: {
  183. APFloat result(C1);
  184. result.fusedMultiplyAdd(C2, C3, roundingMode);
  185. return ConstantFP::get(Ty->getContext(), result);
  186. }
  187. }
  188. return nullptr;
  189. }
  190. // Compute dot product for arbitrary sized vectors.
  191. static Constant *ComputeDot(Type *Ty, ArrayRef<ConstantFP *> A, ArrayRef<ConstantFP *> B) {
  192. if (A.size() != B.size() || !A.size()) {
  193. assert(false && "invalid call to compute dot");
  194. return nullptr;
  195. }
  196. if (!AllValidOps(A) || !AllValidOps(B))
  197. return nullptr;
  198. APFloat::roundingMode roundingMode = APFloat::roundingMode::rmNearestTiesToEven;
  199. APFloat sum = APFloat::getZero(A[0]->getValueAPF().getSemantics());
  200. for (int i = 0, e = A.size(); i != e; ++i) {
  201. APFloat val(A[i]->getValueAPF());
  202. val.multiply(B[i]->getValueAPF(), roundingMode);
  203. sum.add(val, roundingMode);
  204. }
  205. return ConstantFP::get(Ty->getContext(), sum);
  206. }
  207. // Constant folding for dot2, dot3, and dot4.
  208. static Constant *ConstantFoldDot(OP::OpCode opcode, Type *Ty, const DxilIntrinsicOperands &operands) {
  209. switch (opcode) {
  210. default: break;
  211. case OP::OpCode::Dot2: {
  212. ConstantFP *Ax = operands.GetConstantFloat(0);
  213. ConstantFP *Ay = operands.GetConstantFloat(1);
  214. ConstantFP *Bx = operands.GetConstantFloat(2);
  215. ConstantFP *By = operands.GetConstantFloat(3);
  216. return ComputeDot(Ty, { Ax, Ay }, { Bx, By });
  217. }
  218. case OP::OpCode::Dot3: {
  219. ConstantFP *Ax = operands.GetConstantFloat(0);
  220. ConstantFP *Ay = operands.GetConstantFloat(1);
  221. ConstantFP *Az = operands.GetConstantFloat(2);
  222. ConstantFP *Bx = operands.GetConstantFloat(3);
  223. ConstantFP *By = operands.GetConstantFloat(4);
  224. ConstantFP *Bz = operands.GetConstantFloat(5);
  225. return ComputeDot(Ty, { Ax, Ay, Az }, { Bx, By, Bz });
  226. }
  227. case OP::OpCode::Dot4: {
  228. ConstantFP *Ax = operands.GetConstantFloat(0);
  229. ConstantFP *Ay = operands.GetConstantFloat(1);
  230. ConstantFP *Az = operands.GetConstantFloat(2);
  231. ConstantFP *Aw = operands.GetConstantFloat(3);
  232. ConstantFP *Bx = operands.GetConstantFloat(4);
  233. ConstantFP *By = operands.GetConstantFloat(5);
  234. ConstantFP *Bz = operands.GetConstantFloat(6);
  235. ConstantFP *Bw = operands.GetConstantFloat(7);
  236. return ComputeDot(Ty, { Ax, Ay, Az, Aw }, { Bx, By, Bz, Bw });
  237. }
  238. }
  239. return nullptr;
  240. }
  241. // Constant fold a Bfrev dxil intrinsic.
  242. static Constant *HLSLConstantFoldBfrev(ConstantInt *C, Type *Ty) {
  243. APInt API = C->getValue();
  244. uint64_t result = 0;
  245. if (Ty == Type::getInt32Ty(Ty->getContext())) {
  246. uint32_t val = static_cast<uint32_t>(API.getLimitedValue());
  247. result = llvm::reverseBits(val);
  248. }
  249. else if (Ty == Type::getInt16Ty(Ty->getContext())) {
  250. uint16_t val = static_cast<uint16_t>(API.getLimitedValue());
  251. result = llvm::reverseBits(val);
  252. }
  253. else if (Ty == Type::getInt64Ty(Ty->getContext())) {
  254. uint64_t val = static_cast<uint64_t>(API.getLimitedValue());
  255. result = llvm::reverseBits(val);
  256. }
  257. else {
  258. return nullptr;
  259. }
  260. return ConstantInt::get(Ty, result);
  261. }
  262. // Handle special case for findfirst* bit functions.
  263. // When the position is equal to the bitwidth the value was not found
  264. // and we need to return a result of -1.
  265. static Constant *HLSLConstantFoldFindBit(Type *Ty, unsigned position, unsigned bitwidth) {
  266. if (position == bitwidth)
  267. return ConstantInt::get(Ty, APInt::getAllOnesValue(Ty->getScalarSizeInBits()));
  268. return ConstantInt::get(Ty, position);
  269. }
  270. // Constant fold unary integer intrinsics.
  271. static Constant *ConstantFoldUnaryIntIntrinsic(OP::OpCode opcode, Type *Ty, ConstantInt *Op) {
  272. APInt API = Op->getValue();
  273. switch (opcode) {
  274. default: break;
  275. case OP::OpCode::Bfrev: return HLSLConstantFoldBfrev(Op, Ty);
  276. case OP::OpCode::Countbits: return ConstantInt::get(Ty, API.countPopulation());
  277. case OP::OpCode::FirstbitLo: return HLSLConstantFoldFindBit(Ty, API.countTrailingZeros(), API.getBitWidth());
  278. case OP::OpCode::FirstbitHi: return HLSLConstantFoldFindBit(Ty, API.countLeadingZeros(), API.getBitWidth());
  279. case OP::OpCode::FirstbitSHi: {
  280. if (API.isNegative())
  281. return HLSLConstantFoldFindBit(Ty, API.countLeadingOnes(), API.getBitWidth());
  282. else
  283. return HLSLConstantFoldFindBit(Ty, API.countLeadingZeros(), API.getBitWidth());
  284. }
  285. }
  286. return nullptr;
  287. }
  288. // Constant fold binary integer intrinsics.
  289. static Constant *ConstantFoldBinaryIntIntrinsic(OP::OpCode opcode, Type *Ty, ConstantInt *Op1, ConstantInt *Op2) {
  290. APInt C1 = Op1->getValue();
  291. APInt C2 = Op2->getValue();
  292. switch (opcode) {
  293. default: break;
  294. case OP::OpCode::IMin: {
  295. APInt minVal = C1.slt(C2) ? C1 : C2;
  296. return ConstantInt::get(Ty, minVal);
  297. }
  298. case OP::OpCode::IMax: {
  299. APInt maxVal = C1.sgt(C2) ? C1 : C2;
  300. return ConstantInt::get(Ty, maxVal);
  301. }
  302. case OP::OpCode::UMin: {
  303. APInt minVal = C1.ult(C2) ? C1 : C2;
  304. return ConstantInt::get(Ty, minVal);
  305. }
  306. case OP::OpCode::UMax: {
  307. APInt maxVal = C1.ugt(C2) ? C1 : C2;
  308. return ConstantInt::get(Ty, maxVal);
  309. }
  310. }
  311. return nullptr;
  312. }
  313. // Compute bit field extract for ibfe and ubfe.
  314. // The comptuation for ibfe and ubfe is the same except for the right shift,
  315. // which is an arithemetic shift for ibfe and logical shift for ubfe.
  316. // ubfe: https://msdn.microsoft.com/en-us/library/windows/desktop/hh447243(v=vs.85).aspx
  317. // ibfe: https://msdn.microsoft.com/en-us/library/windows/desktop/hh447243(v=vs.85).aspx
  318. static Constant *ComputeBFE(Type *Ty, APInt width, APInt offset, APInt val, std::function<APInt(APInt, APInt)> shr) {
  319. const APInt bitwidth(width.getBitWidth(), width.getBitWidth());
  320. // Limit width and offset to the bitwidth of the value.
  321. width = width.And(bitwidth-1);
  322. offset = offset.And(bitwidth-1);
  323. if (width == 0) {
  324. return ConstantInt::get(Ty, 0);
  325. }
  326. else if ((width + offset).ult(bitwidth)) {
  327. APInt dest = val.shl(bitwidth - (width + offset));
  328. dest = shr(dest, bitwidth - width);
  329. return ConstantInt::get(Ty, dest);
  330. }
  331. else {
  332. APInt dest = shr(val, offset);
  333. return ConstantInt::get(Ty, dest);
  334. }
  335. }
  336. // Constant fold ternary integer intrinsic.
  337. static Constant *ConstantFoldTernaryIntIntrinsic(OP::OpCode opcode, Type *Ty, ConstantInt *Op1, ConstantInt *Op2, ConstantInt *Op3) {
  338. APInt C1 = Op1->getValue();
  339. APInt C2 = Op2->getValue();
  340. APInt C3 = Op3->getValue();
  341. switch (opcode) {
  342. default: break;
  343. case OP::OpCode::IMad:
  344. case OP::OpCode::UMad: {
  345. // Result is same for signed/unsigned since this is twos complement and we only
  346. // keep the lower half of the multiply.
  347. APInt result = C1 * C2 + C3;
  348. return ConstantInt::get(Ty, result);
  349. }
  350. case OP::OpCode::Ubfe: return ComputeBFE(Ty, C1, C2, C3, [](APInt val, APInt amt) {return val.lshr(amt); });
  351. case OP::OpCode::Ibfe: return ComputeBFE(Ty, C1, C2, C3, [](APInt val, APInt amt) {return val.ashr(amt); });
  352. }
  353. return nullptr;
  354. }
  355. // Constant fold quaternary integer intrinsic.
  356. //
  357. // Currently we only have one quaternary intrinsic: Bfi.
  358. // The Bfi computaion is described here:
  359. // https://msdn.microsoft.com/en-us/library/windows/desktop/hh446837(v=vs.85).aspx
  360. static Constant *ConstantFoldQuaternaryIntInstrinsic(OP::OpCode opcode, Type *Ty, ConstantInt *Op1, ConstantInt *Op2, ConstantInt *Op3, ConstantInt *Op4) {
  361. if (opcode != OP::OpCode::Bfi)
  362. return nullptr;
  363. APInt bitwidth(Op1->getValue().getBitWidth(), Op1->getValue().getBitWidth());
  364. APInt width = Op1->getValue().And(bitwidth-1);
  365. APInt offset = Op2->getValue().And(bitwidth-1);
  366. APInt src = Op3->getValue();
  367. APInt dst = Op4->getValue();
  368. APInt one(bitwidth.getBitWidth(), 1);
  369. APInt allOnes = APInt::getAllOnesValue(bitwidth.getBitWidth());
  370. // bitmask = (((1 << width)-1) << offset) & 0xffffffff
  371. // dest = ((src2 << offset) & bitmask) | (src3 & ~bitmask)
  372. APInt bitmask = (one.shl(width) - 1).shl(offset).And(allOnes);
  373. APInt result = (src.shl(offset).And(bitmask)).Or(dst.And(~bitmask));
  374. return ConstantInt::get(Ty, result);
  375. }
  376. // Top level function to constant fold floating point intrinsics.
  377. static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const DxilIntrinsicOperands &IntrinsicOperands) {
  378. if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
  379. return nullptr;
  380. OP::OpCodeClass opClass = OP::GetOpCodeClass(opcode);
  381. switch (opClass) {
  382. default: break;
  383. case OP::OpCodeClass::Unary: {
  384. assert(IntrinsicOperands.Size() == 1);
  385. ConstantFP *Op = IntrinsicOperands.GetConstantFloat(0);
  386. if (!IsValidOp(Op))
  387. return nullptr;
  388. return ConstantFoldUnaryFPIntrinsic(opcode, Ty, Op);
  389. }
  390. case OP::OpCodeClass::Binary: {
  391. assert(IntrinsicOperands.Size() == 2);
  392. ConstantFP *Op1 = IntrinsicOperands.GetConstantFloat(0);
  393. ConstantFP *Op2 = IntrinsicOperands.GetConstantFloat(1);
  394. if (!IsValidOp(Op1) || !IsValidOp(Op2))
  395. return nullptr;
  396. return ConstantFoldBinaryFPIntrinsic(opcode, Ty, Op1, Op2);
  397. }
  398. case OP::OpCodeClass::Tertiary: {
  399. assert(IntrinsicOperands.Size() == 3);
  400. ConstantFP *Op1 = IntrinsicOperands.GetConstantFloat(0);
  401. ConstantFP *Op2 = IntrinsicOperands.GetConstantFloat(1);
  402. ConstantFP *Op3 = IntrinsicOperands.GetConstantFloat(2);
  403. if (!IsValidOp(Op1) || !IsValidOp(Op2) || !IsValidOp(Op3))
  404. return nullptr;
  405. return ConstantFoldTernaryFPIntrinsic(opcode, Ty, Op1, Op2, Op3);
  406. }
  407. case OP::OpCodeClass::Dot2:
  408. case OP::OpCodeClass::Dot3:
  409. case OP::OpCodeClass::Dot4:
  410. return ConstantFoldDot(opcode, Ty, IntrinsicOperands);
  411. }
  412. return nullptr;
  413. }
  414. // Top level function to constant fold integer intrinsics.
  415. static Constant *ConstantFoldIntIntrinsic(OP::OpCode opcode, Type *Ty, const DxilIntrinsicOperands &IntrinsicOperands) {
  416. if (Ty->getScalarSizeInBits() > (sizeof(int64_t) * CHAR_BIT))
  417. return nullptr;
  418. OP::OpCodeClass opClass = OP::GetOpCodeClass(opcode);
  419. switch (opClass) {
  420. default: break;
  421. case OP::OpCodeClass::Unary:
  422. case OP::OpCodeClass::UnaryBits: {
  423. assert(IntrinsicOperands.Size() == 1);
  424. ConstantInt *Op = IntrinsicOperands.GetConstantInt(0);
  425. if (!Op)
  426. return nullptr;
  427. return ConstantFoldUnaryIntIntrinsic(opcode, Ty, Op);
  428. }
  429. case OP::OpCodeClass::Binary: {
  430. assert(IntrinsicOperands.Size() == 2);
  431. ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
  432. ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
  433. if (!Op1 || !Op2)
  434. return nullptr;
  435. return ConstantFoldBinaryIntIntrinsic(opcode, Ty, Op1, Op2);
  436. }
  437. case OP::OpCodeClass::Tertiary: {
  438. assert(IntrinsicOperands.Size() == 3);
  439. ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
  440. ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
  441. ConstantInt *Op3 = IntrinsicOperands.GetConstantInt(2);
  442. if (!Op1 || !Op2 || !Op3)
  443. return nullptr;
  444. return ConstantFoldTernaryIntIntrinsic(opcode, Ty, Op1, Op2, Op3);
  445. }
  446. case OP::OpCodeClass::Quaternary: {
  447. assert(IntrinsicOperands.Size() == 4);
  448. ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
  449. ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
  450. ConstantInt *Op3 = IntrinsicOperands.GetConstantInt(2);
  451. ConstantInt *Op4 = IntrinsicOperands.GetConstantInt(3);
  452. if (!Op1 || !Op2 || !Op3 || !Op4)
  453. return nullptr;
  454. return ConstantFoldQuaternaryIntInstrinsic(opcode, Ty, Op1, Op2, Op3, Op4);
  455. }
  456. }
  457. return nullptr;
  458. }
  459. // External entry point to constant fold dxil intrinsics.
  460. // Called from the llvm constant folding routine.
  461. Constant *hlsl::ConstantFoldScalarCall(StringRef Name, Type *Ty, ArrayRef<Constant *> RawOperands) {
  462. OP::OpCode opcode;
  463. if (GetDxilOpcode(Name, RawOperands, opcode)) {
  464. DxilIntrinsicOperands IntrinsicOperands(RawOperands);
  465. if (Ty->isFloatingPointTy()) {
  466. return ConstantFoldFPIntrinsic(opcode, Ty, IntrinsicOperands);
  467. }
  468. else if (Ty->isIntegerTy()) {
  469. return ConstantFoldIntIntrinsic(opcode, Ty, IntrinsicOperands);
  470. }
  471. }
  472. return hlsl::ConstantFoldScalarCallExt(Name, Ty, RawOperands);
  473. }
  474. // External entry point to determine if we can constant fold calls to
  475. // the given function. We have to overestimate the set of functions because
  476. // we only have the function value here instead of the call. We need the
  477. // actual call to get the opcode for the intrinsic.
  478. bool hlsl::CanConstantFoldCallTo(const Function *F) {
  479. // Only constant fold dxil functions when we have a valid dxil module.
  480. if (!F->getParent()->HasDxilModule()) {
  481. assert(!OP::IsDxilOpFunc(F) && "dx.op function with no dxil module?");
  482. return false;
  483. }
  484. // Lookup opcode class in dxil module. Set default value to invalid class.
  485. OP::OpCodeClass opClass = OP::OpCodeClass::NumOpClasses;
  486. const bool found = F->getParent()->GetDxilModule().GetOP()->GetOpCodeClass(F, opClass);
  487. // Return true for those dxil operation classes we can constant fold.
  488. if (found) {
  489. switch (opClass) {
  490. default: break;
  491. case OP::OpCodeClass::Unary:
  492. case OP::OpCodeClass::UnaryBits:
  493. case OP::OpCodeClass::Binary:
  494. case OP::OpCodeClass::Tertiary:
  495. case OP::OpCodeClass::Quaternary:
  496. case OP::OpCodeClass::Dot2:
  497. case OP::OpCodeClass::Dot3:
  498. case OP::OpCodeClass::Dot4:
  499. return true;
  500. }
  501. }
  502. return hlsl::CanConstantFoldCallToExt(F);
  503. }