HLOperations.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLOperations.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Implementation of DXIL operations. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/HLOperations.h"
  12. #include "dxc/HlslIntrinsicOp.h"
  13. #include "llvm/IR/Function.h"
  14. #include "llvm/IR/Instructions.h"
  15. #include "llvm/IR/Module.h"
  16. #include "llvm/Support/raw_ostream.h"
  17. using namespace hlsl;
  18. using namespace llvm;
  19. namespace hlsl {
// "dx.hl" is used both as the leading component of HL function names (checked
// by GetHLOpcodeGroupByName) and as the key of the string function attribute
// that stores an extension intrinsic's group name (set in
// GetOrCreateHLFunction, read by GetHLOpcodeGroupNameByAttr).
// HLPrefixStr is kept as an array so that sizeof(HLPrefixStr) -- which also
// counts the terminating NUL -- is usable in this file.
// NOTE(review): namespace-scope const objects have internal linkage unless a
// prior extern declaration exists; presumably HLOperations.h declares
// HLPrefixStr/HLPrefix extern -- confirm before touching these definitions.
const char HLPrefixStr [] = "dx.hl";
const char * const HLPrefix = HLPrefixStr;
// Key of the function attribute that records an HL function's lowering
// strategy (see GetHLLowerStrategy / SetHLLowerStrategy). File-local.
static const char HLLowerStrategyStr[] = "dx.hlls";
static const char * const HLLowerStrategy = HLLowerStrategyStr;
  24. static StringRef HLOpcodeGroupNames[]{
  25. "notHLDXIL", // NotHL,
  26. "<ext>", // HLExtIntrinsic - should always refer through extension
  27. "op", // HLIntrinsic,
  28. "cast", // HLCast,
  29. "init", // HLInit,
  30. "binop", // HLBinOp,
  31. "unop", // HLUnOp,
  32. "subscript", // HLSubscript,
  33. "matldst", // HLMatLoadStore,
  34. "select", // HLSelect,
  35. "createhandle",// HLCreateHandle,
  36. "numOfHLDXIL", // NumOfHLOps
  37. };
  38. static StringRef HLOpcodeGroupFullNames[]{
  39. "notHLDXIL", // NotHL,
  40. "<ext>", // HLExtIntrinsic - should aways refer through extension
  41. "dx.hl.op", // HLIntrinsic,
  42. "dx.hl.cast", // HLCast,
  43. "dx.hl.init", // HLInit,
  44. "dx.hl.binop", // HLBinOp,
  45. "dx.hl.unop", // HLUnOp,
  46. "dx.hl.subscript", // HLSubscript,
  47. "dx.hl.matldst", // HLMatLoadStore,
  48. "dx.hl.select", // HLSelect,
  49. "dx.hl.createhandle", // HLCreateHandle,
  50. "numOfHLDXIL", // NumOfHLOps
  51. };
  52. static HLOpcodeGroup GetHLOpcodeGroupInternal(StringRef group) {
  53. if (!group.empty()) {
  54. switch (group[0]) {
  55. case 'o': // op
  56. return HLOpcodeGroup::HLIntrinsic;
  57. case 'c': // cast
  58. switch (group[1]) {
  59. case 'a': // cast
  60. return HLOpcodeGroup::HLCast;
  61. case 'r': // createhandle
  62. return HLOpcodeGroup::HLCreateHandle;
  63. }
  64. case 'i': // init
  65. return HLOpcodeGroup::HLInit;
  66. case 'b': // binaryOp
  67. return HLOpcodeGroup::HLBinOp;
  68. case 'u': // unaryOp
  69. return HLOpcodeGroup::HLUnOp;
  70. case 's': // subscript
  71. switch (group[1]) {
  72. case 'u':
  73. return HLOpcodeGroup::HLSubscript;
  74. case 'e':
  75. return HLOpcodeGroup::HLSelect;
  76. }
  77. case 'm': // matldst
  78. return HLOpcodeGroup::HLMatLoadStore;
  79. }
  80. }
  81. return HLOpcodeGroup::NotHL;
  82. }
  83. // GetHLOpGroup by function name.
  84. HLOpcodeGroup GetHLOpcodeGroupByName(const Function *F) {
  85. StringRef name = F->getName();
  86. if (!name.startswith(HLPrefix)) {
  87. // This could be an external intrinsic, but this function
  88. // won't recognize those as such. Use GetHLOpcodeGroupByName
  89. // to make that distinction.
  90. return HLOpcodeGroup::NotHL;
  91. }
  92. const unsigned prefixSize = sizeof(HLPrefixStr);
  93. StringRef group = name.substr(prefixSize);
  94. return GetHLOpcodeGroupInternal(group);
  95. }
  96. HLOpcodeGroup GetHLOpcodeGroup(llvm::Function *F) {
  97. llvm::StringRef name = GetHLOpcodeGroupNameByAttr(F);
  98. HLOpcodeGroup result = GetHLOpcodeGroupInternal(name);
  99. if (result == HLOpcodeGroup::NotHL) {
  100. result = name.empty() ? result : HLOpcodeGroup::HLExtIntrinsic;
  101. }
  102. if (result == HLOpcodeGroup::NotHL) {
  103. result = GetHLOpcodeGroupByName(F);
  104. }
  105. return result;
  106. }
  107. llvm::StringRef GetHLOpcodeGroupNameByAttr(llvm::Function *F) {
  108. Attribute groupAttr = F->getFnAttribute(hlsl::HLPrefix);
  109. StringRef group = groupAttr.getValueAsString();
  110. return group;
  111. }
  112. StringRef GetHLOpcodeGroupName(HLOpcodeGroup op) {
  113. switch (op) {
  114. case HLOpcodeGroup::HLCast:
  115. case HLOpcodeGroup::HLInit:
  116. case HLOpcodeGroup::HLBinOp:
  117. case HLOpcodeGroup::HLUnOp:
  118. case HLOpcodeGroup::HLIntrinsic:
  119. case HLOpcodeGroup::HLSubscript:
  120. case HLOpcodeGroup::HLMatLoadStore:
  121. case HLOpcodeGroup::HLSelect:
  122. case HLOpcodeGroup::HLCreateHandle:
  123. return HLOpcodeGroupNames[static_cast<unsigned>(op)];
  124. default:
  125. llvm_unreachable("invalid op");
  126. return "";
  127. }
  128. }
  129. StringRef GetHLOpcodeGroupFullName(HLOpcodeGroup op) {
  130. switch (op) {
  131. case HLOpcodeGroup::HLCast:
  132. case HLOpcodeGroup::HLInit:
  133. case HLOpcodeGroup::HLBinOp:
  134. case HLOpcodeGroup::HLUnOp:
  135. case HLOpcodeGroup::HLIntrinsic:
  136. case HLOpcodeGroup::HLSubscript:
  137. case HLOpcodeGroup::HLMatLoadStore:
  138. case HLOpcodeGroup::HLSelect:
  139. case HLOpcodeGroup::HLCreateHandle:
  140. return HLOpcodeGroupFullNames[static_cast<unsigned>(op)];
  141. default:
  142. llvm_unreachable("invalid op");
  143. return "";
  144. }
  145. }
  146. llvm::StringRef GetHLOpcodeName(HLUnaryOpcode Op) {
  147. switch (Op) {
  148. case HLUnaryOpcode::PostInc: return "++";
  149. case HLUnaryOpcode::PostDec: return "--";
  150. case HLUnaryOpcode::PreInc: return "++";
  151. case HLUnaryOpcode::PreDec: return "--";
  152. case HLUnaryOpcode::Plus: return "+";
  153. case HLUnaryOpcode::Minus: return "-";
  154. case HLUnaryOpcode::Not: return "~";
  155. case HLUnaryOpcode::LNot: return "!";
  156. case HLUnaryOpcode::Invalid:
  157. case HLUnaryOpcode::NumOfUO:
  158. // Invalid Unary Ops
  159. break;
  160. }
  161. llvm_unreachable("Unknown unary operator");
  162. }
  163. llvm::StringRef GetHLOpcodeName(HLBinaryOpcode Op) {
  164. switch (Op) {
  165. case HLBinaryOpcode::Mul: return "*";
  166. case HLBinaryOpcode::UDiv:
  167. case HLBinaryOpcode::Div: return "/";
  168. case HLBinaryOpcode::URem:
  169. case HLBinaryOpcode::Rem: return "%";
  170. case HLBinaryOpcode::Add: return "+";
  171. case HLBinaryOpcode::Sub: return "-";
  172. case HLBinaryOpcode::Shl: return "<<";
  173. case HLBinaryOpcode::UShr:
  174. case HLBinaryOpcode::Shr: return ">>";
  175. case HLBinaryOpcode::ULT:
  176. case HLBinaryOpcode::LT: return "<";
  177. case HLBinaryOpcode::UGT:
  178. case HLBinaryOpcode::GT: return ">";
  179. case HLBinaryOpcode::ULE:
  180. case HLBinaryOpcode::LE: return "<=";
  181. case HLBinaryOpcode::UGE:
  182. case HLBinaryOpcode::GE: return ">=";
  183. case HLBinaryOpcode::EQ: return "==";
  184. case HLBinaryOpcode::NE: return "!=";
  185. case HLBinaryOpcode::And: return "&";
  186. case HLBinaryOpcode::Xor: return "^";
  187. case HLBinaryOpcode::Or: return "|";
  188. case HLBinaryOpcode::LAnd: return "&&";
  189. case HLBinaryOpcode::LOr: return "||";
  190. case HLBinaryOpcode::Invalid:
  191. case HLBinaryOpcode::NumOfBO:
  192. // Invalid Binary Ops
  193. break;
  194. }
  195. llvm_unreachable("Invalid OpCode!");
  196. }
  197. llvm::StringRef GetHLOpcodeName(HLSubscriptOpcode Op) {
  198. switch (Op) {
  199. case HLSubscriptOpcode::DefaultSubscript:
  200. return "[]";
  201. case HLSubscriptOpcode::ColMatSubscript:
  202. return "colMajor[]";
  203. case HLSubscriptOpcode::RowMatSubscript:
  204. return "rowMajor[]";
  205. case HLSubscriptOpcode::ColMatElement:
  206. return "colMajor_m";
  207. case HLSubscriptOpcode::RowMatElement:
  208. return "rowMajor_m";
  209. case HLSubscriptOpcode::DoubleSubscript:
  210. return "[][]";
  211. case HLSubscriptOpcode::CBufferSubscript:
  212. return "cb";
  213. case HLSubscriptOpcode::VectorSubscript:
  214. return "vector[]";
  215. }
  216. return "";
  217. }
  218. llvm::StringRef GetHLOpcodeName(HLCastOpcode Op) {
  219. switch (Op) {
  220. case HLCastOpcode::DefaultCast:
  221. return "default";
  222. case HLCastOpcode::ToUnsignedCast:
  223. return "toUnsigned";
  224. case HLCastOpcode::FromUnsignedCast:
  225. return "fromUnsigned";
  226. case HLCastOpcode::UnsignedUnsignedCast:
  227. return "unsignedUnsigned";
  228. case HLCastOpcode::ColMatrixToVecCast:
  229. return "colMatToVec";
  230. case HLCastOpcode::RowMatrixToVecCast:
  231. return "rowMatToVec";
  232. case HLCastOpcode::ColMatrixToRowMatrix:
  233. return "colMatToRowMat";
  234. case HLCastOpcode::RowMatrixToColMatrix:
  235. return "rowMatToColMat";
  236. case HLCastOpcode::HandleToResCast:
  237. return "handleToRes";
  238. }
  239. return "";
  240. }
  241. llvm::StringRef GetHLOpcodeName(HLMatLoadStoreOpcode Op) {
  242. switch (Op) {
  243. case HLMatLoadStoreOpcode::ColMatLoad:
  244. return "colLoad";
  245. case HLMatLoadStoreOpcode::ColMatStore:
  246. return "colStore";
  247. case HLMatLoadStoreOpcode::RowMatLoad:
  248. return "rowLoad";
  249. case HLMatLoadStoreOpcode::RowMatStore:
  250. return "rowStore";
  251. }
  252. llvm_unreachable("invalid matrix load store operator");
  253. }
  254. StringRef GetHLLowerStrategy(Function *F) {
  255. llvm::Attribute A = F->getFnAttribute(HLLowerStrategy);
  256. llvm::StringRef LowerStrategy = A.getValueAsString();
  257. return LowerStrategy;
  258. }
  259. void SetHLLowerStrategy(Function *F, StringRef S) {
  260. F->addFnAttr(HLLowerStrategy, S);
  261. }
  262. std::string GetHLFullName(HLOpcodeGroup op, unsigned opcode) {
  263. assert(op != HLOpcodeGroup::HLExtIntrinsic && "else table name should be used");
  264. std::string opName = GetHLOpcodeGroupFullName(op).str() + ".";
  265. switch (op) {
  266. case HLOpcodeGroup::HLBinOp: {
  267. HLBinaryOpcode binOp = static_cast<HLBinaryOpcode>(opcode);
  268. return opName + GetHLOpcodeName(binOp).str();
  269. }
  270. case HLOpcodeGroup::HLUnOp: {
  271. HLUnaryOpcode unOp = static_cast<HLUnaryOpcode>(opcode);
  272. return opName + GetHLOpcodeName(unOp).str();
  273. }
  274. case HLOpcodeGroup::HLIntrinsic: {
  275. // intrinsic with same signature will share the funciton now
  276. // The opcode is in arg0.
  277. return opName;
  278. }
  279. case HLOpcodeGroup::HLMatLoadStore: {
  280. HLMatLoadStoreOpcode matOp = static_cast<HLMatLoadStoreOpcode>(opcode);
  281. return opName + GetHLOpcodeName(matOp).str();
  282. }
  283. case HLOpcodeGroup::HLSubscript: {
  284. HLSubscriptOpcode subOp = static_cast<HLSubscriptOpcode>(opcode);
  285. return opName + GetHLOpcodeName(subOp).str();
  286. }
  287. case HLOpcodeGroup::HLCast: {
  288. HLCastOpcode castOp = static_cast<HLCastOpcode>(opcode);
  289. return opName + GetHLOpcodeName(castOp).str();
  290. }
  291. default:
  292. return opName;
  293. }
  294. }
  295. // Get opcode from arg0 of function call.
  296. unsigned GetHLOpcode(const CallInst *CI) {
  297. Value *idArg = CI->getArgOperand(HLOperandIndex::kOpcodeIdx);
  298. Constant *idConst = cast<Constant>(idArg);
  299. return idConst->getUniqueInteger().getLimitedValue();
  300. }
  301. unsigned GetRowMajorOpcode(HLOpcodeGroup group, unsigned opcode) {
  302. switch (group) {
  303. case HLOpcodeGroup::HLMatLoadStore: {
  304. HLMatLoadStoreOpcode matOp = static_cast<HLMatLoadStoreOpcode>(opcode);
  305. switch (matOp) {
  306. case HLMatLoadStoreOpcode::ColMatLoad:
  307. return static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatLoad);
  308. case HLMatLoadStoreOpcode::ColMatStore:
  309. return static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatStore);
  310. default:
  311. return opcode;
  312. }
  313. } break;
  314. case HLOpcodeGroup::HLSubscript: {
  315. HLSubscriptOpcode subOp = static_cast<HLSubscriptOpcode>(opcode);
  316. switch (subOp) {
  317. case HLSubscriptOpcode::ColMatElement:
  318. return static_cast<unsigned>(HLSubscriptOpcode::RowMatElement);
  319. case HLSubscriptOpcode::ColMatSubscript:
  320. return static_cast<unsigned>(HLSubscriptOpcode::RowMatSubscript);
  321. default:
  322. return opcode;
  323. }
  324. } break;
  325. default:
  326. return opcode;
  327. }
  328. }
  329. bool HasUnsignedOpcode(unsigned opcode) {
  330. return HasUnsignedIntrinsicOpcode(static_cast<IntrinsicOp>(opcode));
  331. }
  332. unsigned GetUnsignedOpcode(unsigned opcode) {
  333. return GetUnsignedIntrinsicOpcode(static_cast<IntrinsicOp>(opcode));
  334. }
  335. // For HLBinaryOpcode
  336. bool HasUnsignedOpcode(HLBinaryOpcode opcode) {
  337. switch (opcode) {
  338. case HLBinaryOpcode::Div:
  339. case HLBinaryOpcode::Rem:
  340. case HLBinaryOpcode::Shr:
  341. case HLBinaryOpcode::LT:
  342. case HLBinaryOpcode::GT:
  343. case HLBinaryOpcode::LE:
  344. case HLBinaryOpcode::GE:
  345. return true;
  346. default:
  347. return false;
  348. }
  349. }
  350. HLBinaryOpcode GetUnsignedOpcode(HLBinaryOpcode opcode) {
  351. switch (opcode) {
  352. case HLBinaryOpcode::Div:
  353. return HLBinaryOpcode::UDiv;
  354. case HLBinaryOpcode::Rem:
  355. return HLBinaryOpcode::URem;
  356. case HLBinaryOpcode::Shr:
  357. return HLBinaryOpcode::UShr;
  358. case HLBinaryOpcode::LT:
  359. return HLBinaryOpcode::ULT;
  360. case HLBinaryOpcode::GT:
  361. return HLBinaryOpcode::UGT;
  362. case HLBinaryOpcode::LE:
  363. return HLBinaryOpcode::ULE;
  364. case HLBinaryOpcode::GE:
  365. return HLBinaryOpcode::UGE;
  366. default:
  367. return opcode;
  368. }
  369. }
  370. static void SetHLFunctionAttribute(Function *F, HLOpcodeGroup group,
  371. unsigned opcode) {
  372. switch (group) {
  373. case HLOpcodeGroup::HLUnOp:
  374. case HLOpcodeGroup::HLBinOp:
  375. case HLOpcodeGroup::HLCast:
  376. case HLOpcodeGroup::HLSubscript:
  377. if (!F->hasFnAttribute(Attribute::ReadNone)) {
  378. F->addFnAttr(Attribute::ReadNone);
  379. F->addFnAttr(Attribute::NoUnwind);
  380. }
  381. break;
  382. case HLOpcodeGroup::HLInit:
  383. if (!F->hasFnAttribute(Attribute::ReadNone))
  384. if (!F->getReturnType()->isVoidTy()) {
  385. F->addFnAttr(Attribute::ReadNone);
  386. F->addFnAttr(Attribute::NoUnwind);
  387. }
  388. break;
  389. case HLOpcodeGroup::HLMatLoadStore: {
  390. HLMatLoadStoreOpcode matOp = static_cast<HLMatLoadStoreOpcode>(opcode);
  391. if (matOp == HLMatLoadStoreOpcode::ColMatLoad ||
  392. matOp == HLMatLoadStoreOpcode::RowMatLoad)
  393. if (!F->hasFnAttribute(Attribute::ReadOnly)) {
  394. F->addFnAttr(Attribute::ReadOnly);
  395. F->addFnAttr(Attribute::NoUnwind);
  396. }
  397. } break;
  398. case HLOpcodeGroup::HLCreateHandle: {
  399. F->addFnAttr(Attribute::ReadNone);
  400. F->addFnAttr(Attribute::NoUnwind);
  401. F->addFnAttr(Attribute::NoInline);
  402. F->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
  403. } break;
  404. case HLOpcodeGroup::NotHL:
  405. case HLOpcodeGroup::HLExtIntrinsic:
  406. case HLOpcodeGroup::HLIntrinsic:
  407. case HLOpcodeGroup::HLSelect:
  408. case HLOpcodeGroup::NumOfHLOps:
  409. // No default attributes for these opcodes.
  410. break;
  411. }
  412. }
  413. Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
  414. HLOpcodeGroup group, unsigned opcode) {
  415. return GetOrCreateHLFunction(M, funcTy, group, nullptr, nullptr, opcode);
  416. }
  417. Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
  418. HLOpcodeGroup group, llvm::StringRef *groupName,
  419. llvm::StringRef *fnName, unsigned opcode) {
  420. std::string mangledName;
  421. raw_string_ostream mangledNameStr(mangledName);
  422. if (group == HLOpcodeGroup::HLExtIntrinsic) {
  423. assert(groupName && "else intrinsic should have been rejected");
  424. assert(fnName && "else intrinsic should have been rejected");
  425. mangledNameStr << *groupName;
  426. mangledNameStr << '.';
  427. mangledNameStr << *fnName;
  428. }
  429. else {
  430. mangledNameStr << GetHLFullName(group, opcode);
  431. mangledNameStr << '.';
  432. funcTy->print(mangledNameStr);
  433. }
  434. mangledNameStr.flush();
  435. Function *F = cast<Function>(M.getOrInsertFunction(mangledName, funcTy));
  436. if (group == HLOpcodeGroup::HLExtIntrinsic) {
  437. F->addFnAttr(hlsl::HLPrefix, *groupName);
  438. }
  439. SetHLFunctionAttribute(F, group, opcode);
  440. return F;
  441. }
  442. // HLFunction with body cannot share with HLFunction without body.
  443. // So need add name.
  444. Function *GetOrCreateHLFunctionWithBody(Module &M, FunctionType *funcTy,
  445. HLOpcodeGroup group, unsigned opcode,
  446. StringRef name) {
  447. std::string operatorName = GetHLFullName(group, opcode);
  448. std::string mangledName = operatorName + "." + name.str();
  449. raw_string_ostream mangledNameStr(mangledName);
  450. funcTy->print(mangledNameStr);
  451. mangledNameStr.flush();
  452. Function *F = cast<Function>(M.getOrInsertFunction(mangledName, funcTy));
  453. SetHLFunctionAttribute(F, group, opcode);
  454. F->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
  455. return F;
  456. }
  457. } // namespace hlsl