HLOperationLowerExtension.cpp 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLOperationLowerExtension.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  9. #include "dxc/HLSL/HLOperationLowerExtension.h"
  10. #include "dxc/DXIL/DxilModule.h"
  11. #include "dxc/DXIL/DxilOperations.h"
  12. #include "dxc/HLSL/HLModule.h"
  13. #include "dxc/HLSL/HLOperationLower.h"
  14. #include "dxc/HLSL/HLOperations.h"
  15. #include "dxc/HlslIntrinsicOp.h"
  16. #include "llvm/ADT/StringRef.h"
  17. #include "llvm/IR/IRBuilder.h"
  18. #include "llvm/IR/Instructions.h"
  19. #include "llvm/IR/Module.h"
  20. #include "llvm/Support/raw_os_ostream.h"
  21. #include "llvm/Support/YAMLParser.h"
  22. #include "llvm/Support/SourceMgr.h"
  23. #include "llvm/ADT/SmallString.h"
  24. using namespace llvm;
  25. using namespace hlsl;
  26. LLVM_ATTRIBUTE_NORETURN static void ThrowExtensionError(StringRef Details)
  27. {
  28. std::string Msg = (Twine("Error in dxc extension api: ") + Details).str();
  29. throw hlsl::Exception(DXC_E_EXTENSION_ERROR, Msg);
  30. }
  31. // The lowering strategy format is a string that matches the following regex:
  32. //
  33. // [a-z](:(?P<ExtraStrategyInfo>.+))?$
  34. //
  35. // The first character indicates the strategy with an optional : followed by
  36. // additional lowering information specific to that strategy.
  37. //
  38. ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
  39. if (strategy.size() < 1)
  40. return Strategy::Unknown;
  41. switch (strategy[0]) {
  42. case 'n': return Strategy::NoTranslation;
  43. case 'r': return Strategy::Replicate;
  44. case 'p': return Strategy::Pack;
  45. case 'm': return Strategy::Resource;
  46. case 'd': return Strategy::Dxil;
  47. default: break;
  48. }
  49. return Strategy::Unknown;
  50. }
  51. llvm::StringRef ExtensionLowering::GetStrategyName(Strategy strategy) {
  52. switch (strategy) {
  53. case Strategy::NoTranslation: return "n";
  54. case Strategy::Replicate: return "r";
  55. case Strategy::Pack: return "p";
  56. case Strategy::Resource: return "m"; // m for resource method
  57. case Strategy::Dxil: return "d";
  58. default: break;
  59. }
  60. return "?";
  61. }
  62. static std::string ParseExtraStrategyInfo(StringRef strategy)
  63. {
  64. std::pair<StringRef, StringRef> SplitInfo = strategy.split(":");
  65. return SplitInfo.second;
  66. }
  67. ExtensionLowering::ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp, HLResourceLookup &hlResourceLookup)
  68. : m_strategy(strategy), m_helper(helper), m_hlslOp(hlslOp), m_hlResourceLookup(hlResourceLookup)
  69. {}
  70. ExtensionLowering::ExtensionLowering(StringRef strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp, HLResourceLookup &hlResourceLookup)
  71. : ExtensionLowering(GetStrategy(strategy), helper, hlslOp, hlResourceLookup)
  72. {
  73. m_extraStrategyInfo = ParseExtraStrategyInfo(strategy);
  74. }
  75. llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
  76. switch (m_strategy) {
  77. case Strategy::NoTranslation: return NoTranslation(CI);
  78. case Strategy::Replicate: return Replicate(CI);
  79. case Strategy::Pack: return Pack(CI);
  80. case Strategy::Resource: return Resource(CI);
  81. case Strategy::Dxil: return Dxil(CI);
  82. default: break;
  83. }
  84. return Unknown(CI);
  85. }
  86. llvm::Value *ExtensionLowering::Unknown(CallInst *CI) {
  87. assert(false && "unknown translation strategy");
  88. return nullptr;
  89. }
  90. // Interface to describe how to translate types from HL-dxil to dxil.
  91. class FunctionTypeTranslator {
  92. public:
  93. // Arguments can be exploded into multiple copies of the same type.
  94. // For example a <2 x i32> could become { i32, 2 } if the vector
  95. // is expanded in place or { i32, 1 } if the call is replicated.
  96. struct ArgumentType {
  97. Type *type;
  98. int count;
  99. ArgumentType(Type *ty, int cnt = 1) : type(ty), count(cnt) {}
  100. };
  101. virtual ~FunctionTypeTranslator() {}
  102. virtual Type *TranslateReturnType(CallInst *CI) = 0;
  103. virtual ArgumentType TranslateArgumentType(Value *OrigArg) = 0;
  104. };
  105. // Class to create the new function with the translated types for low-level dxil.
  106. class FunctionTranslator {
  107. public:
  108. template <typename TypeTranslator>
  109. static Function *GetLoweredFunction(CallInst *CI, ExtensionLowering &lower) {
  110. TypeTranslator typeTranslator;
  111. return GetLoweredFunction(typeTranslator, CI, lower);
  112. }
  113. static Function *GetLoweredFunction(FunctionTypeTranslator &typeTranslator, CallInst *CI, ExtensionLowering &lower) {
  114. FunctionTranslator translator(typeTranslator, lower);
  115. return translator.GetLoweredFunction(CI);
  116. }
  117. virtual ~FunctionTranslator() {}
  118. protected:
  119. FunctionTypeTranslator &m_typeTranslator;
  120. ExtensionLowering &m_lower;
  121. FunctionTranslator(FunctionTypeTranslator &typeTranslator, ExtensionLowering &lower)
  122. : m_typeTranslator(typeTranslator)
  123. , m_lower(lower)
  124. {}
  125. Function *GetLoweredFunction(CallInst *CI) {
  126. // Ge the return type of replicated function.
  127. Type *RetTy = m_typeTranslator.TranslateReturnType(CI);
  128. if (!RetTy)
  129. return nullptr;
  130. // Get the Function type for replicated function.
  131. FunctionType *FTy = GetFunctionType(CI, RetTy);
  132. if (!FTy)
  133. return nullptr;
  134. // Create a new function that will be the replicated call.
  135. AttributeSet attributes = GetAttributeSet(CI);
  136. std::string name = m_lower.GetExtensionName(CI);
  137. return cast<Function>(CI->getModule()->getOrInsertFunction(name, FTy, attributes));
  138. }
  139. virtual FunctionType *GetFunctionType(CallInst *CI, Type *RetTy) {
  140. // Create a new function type with the translated argument.
  141. SmallVector<Type *, 10> ParamTypes;
  142. ParamTypes.reserve(CI->getNumArgOperands());
  143. for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
  144. Value *OrigArg = CI->getArgOperand(i);
  145. FunctionTypeTranslator::ArgumentType newArgType = m_typeTranslator.TranslateArgumentType(OrigArg);
  146. for (int i = 0; i < newArgType.count; ++i) {
  147. ParamTypes.push_back(newArgType.type);
  148. }
  149. }
  150. const bool IsVarArg = false;
  151. return FunctionType::get(RetTy, ParamTypes, IsVarArg);
  152. }
  153. AttributeSet GetAttributeSet(CallInst *CI) {
  154. Function *F = CI->getCalledFunction();
  155. AttributeSet attributes;
  156. auto copyAttribute = [=, &attributes](Attribute::AttrKind a) {
  157. if (F->hasFnAttribute(a)) {
  158. attributes = attributes.addAttribute(CI->getContext(), AttributeSet::FunctionIndex, a);
  159. }
  160. };
  161. copyAttribute(Attribute::AttrKind::ReadOnly);
  162. copyAttribute(Attribute::AttrKind::ReadNone);
  163. copyAttribute(Attribute::AttrKind::ArgMemOnly);
  164. return attributes;
  165. }
  166. };
  167. ///////////////////////////////////////////////////////////////////////////////
  168. // NoTranslation Lowering.
  169. class NoTranslationTypeTranslator : public FunctionTypeTranslator {
  170. virtual Type *TranslateReturnType(CallInst *CI) override {
  171. return CI->getType();
  172. }
  173. virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
  174. return ArgumentType(OrigArg->getType());
  175. }
  176. };
  177. llvm::Value *ExtensionLowering::NoTranslation(CallInst *CI) {
  178. Function *NoTranslationFunction = FunctionTranslator::GetLoweredFunction<NoTranslationTypeTranslator>(CI, *this);
  179. if (!NoTranslationFunction)
  180. return nullptr;
  181. IRBuilder<> builder(CI);
  182. SmallVector<Value *, 8> args(CI->arg_operands().begin(), CI->arg_operands().end());
  183. return builder.CreateCall(NoTranslationFunction, args);
  184. }
  185. ///////////////////////////////////////////////////////////////////////////////
  186. // Replicated Lowering.
  187. enum {
  188. NO_COMMON_VECTOR_SIZE = 0x0,
  189. };
  190. // Find the vector size that will be used for replication.
  191. // The function call will be replicated once for each element of the vector
  192. // size.
  193. static unsigned GetReplicatedVectorSize(llvm::CallInst *CI) {
  194. unsigned commonVectorSize = NO_COMMON_VECTOR_SIZE;
  195. Type *RetTy = CI->getType();
  196. if (RetTy->isVectorTy())
  197. commonVectorSize = RetTy->getVectorNumElements();
  198. for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
  199. Type *Ty = CI->getArgOperand(i)->getType();
  200. if (Ty->isVectorTy()) {
  201. unsigned vectorSize = Ty->getVectorNumElements();
  202. if (commonVectorSize != NO_COMMON_VECTOR_SIZE && commonVectorSize != vectorSize) {
  203. // Inconsistent vector sizes; need a different strategy.
  204. return NO_COMMON_VECTOR_SIZE;
  205. }
  206. commonVectorSize = vectorSize;
  207. }
  208. }
  209. return commonVectorSize;
  210. }
  211. class ReplicatedFunctionTypeTranslator : public FunctionTypeTranslator {
  212. virtual Type *TranslateReturnType(CallInst *CI) override {
  213. unsigned commonVectorSize = GetReplicatedVectorSize(CI);
  214. if (commonVectorSize == NO_COMMON_VECTOR_SIZE)
  215. return nullptr;
  216. // Result should be vector or void.
  217. Type *RetTy = CI->getType();
  218. if (!RetTy->isVoidTy() && !RetTy->isVectorTy())
  219. return nullptr;
  220. if (RetTy->isVectorTy()) {
  221. RetTy = RetTy->getVectorElementType();
  222. }
  223. return RetTy;
  224. }
  225. virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
  226. Type *Ty = OrigArg->getType();
  227. if (Ty->isVectorTy()) {
  228. Ty = Ty->getVectorElementType();
  229. }
  230. return ArgumentType(Ty);
  231. }
  232. };
  233. class ReplicateCall {
  234. public:
  235. ReplicateCall(CallInst *CI, Function &ReplicatedFunction)
  236. : m_CI(CI)
  237. , m_ReplicatedFunction(ReplicatedFunction)
  238. , m_numReplicatedCalls(GetReplicatedVectorSize(CI))
  239. , m_ScalarizeArgIdx()
  240. , m_Args(CI->getNumArgOperands())
  241. , m_ReplicatedCalls(m_numReplicatedCalls)
  242. , m_Builder(CI)
  243. {
  244. assert(m_numReplicatedCalls != NO_COMMON_VECTOR_SIZE);
  245. }
  246. Value *Generate() {
  247. CollectReplicatedArguments();
  248. CreateReplicatedCalls();
  249. Value *retVal = GetReturnValue();
  250. return retVal;
  251. }
  252. private:
  253. CallInst *m_CI;
  254. Function &m_ReplicatedFunction;
  255. unsigned m_numReplicatedCalls;
  256. SmallVector<unsigned, 10> m_ScalarizeArgIdx;
  257. SmallVector<Value *, 10> m_Args;
  258. SmallVector<Value *, 10> m_ReplicatedCalls;
  259. IRBuilder<> m_Builder;
  260. // Collect replicated arguments.
  261. // For non-vector arguments we can add them to the args list directly.
  262. // These args will be shared by each replicated call. For the vector
  263. // arguments we remember the position it will go in the argument list.
  264. // We will fill in the vector args below when we replicate the call
  265. // (once for each vector lane).
  266. void CollectReplicatedArguments() {
  267. for (unsigned i = 0; i < m_CI->getNumArgOperands(); ++i) {
  268. Type *Ty = m_CI->getArgOperand(i)->getType();
  269. if (Ty->isVectorTy()) {
  270. m_ScalarizeArgIdx.push_back(i);
  271. }
  272. else {
  273. m_Args[i] = m_CI->getArgOperand(i);
  274. }
  275. }
  276. }
  277. // Create replicated calls.
  278. // Replicate the call once for each element of the replicated vector size.
  279. void CreateReplicatedCalls() {
  280. for (unsigned vecIdx = 0; vecIdx < m_numReplicatedCalls; vecIdx++) {
  281. for (unsigned i = 0, e = m_ScalarizeArgIdx.size(); i < e; ++i) {
  282. unsigned argIdx = m_ScalarizeArgIdx[i];
  283. Value *arg = m_CI->getArgOperand(argIdx);
  284. m_Args[argIdx] = m_Builder.CreateExtractElement(arg, vecIdx);
  285. }
  286. Value *EltOP = m_Builder.CreateCall(&m_ReplicatedFunction, m_Args);
  287. m_ReplicatedCalls[vecIdx] = EltOP;
  288. }
  289. }
  290. // Get the final replicated value.
  291. // If the function is a void type then return (arbitrarily) the first call.
  292. // We do not return nullptr because that indicates a failure to replicate.
  293. // If the function is a vector type then aggregate all of the replicated
  294. // call values into a new vector.
  295. Value *GetReturnValue() {
  296. if (m_CI->getType()->isVoidTy())
  297. return m_ReplicatedCalls.back();
  298. Value *retVal = llvm::UndefValue::get(m_CI->getType());
  299. for (unsigned i = 0; i < m_ReplicatedCalls.size(); ++i)
  300. retVal = m_Builder.CreateInsertElement(retVal, m_ReplicatedCalls[i], i);
  301. return retVal;
  302. }
  303. };
  304. // Translate the HL call by replicating the call for each vector element.
  305. //
  306. // For example,
  307. //
  308. // <2xi32> %r = call @ext.foo(i32 %op, <2xi32> %v)
  309. // ==>
  310. // %r.1 = call @ext.foo.s(i32 %op, i32 %v.1)
  311. // %r.2 = call @ext.foo.s(i32 %op, i32 %v.2)
  312. // <2xi32> %r.v.1 = insertelement %r.1, 0, <2xi32> undef
  313. // <2xi32> %r.v.2 = insertelement %r.2, 1, %r.v.1
  314. //
  315. // You can then RAWU %r with %r.v.2. The RAWU is not done by the translate function.
  316. Value *ExtensionLowering::Replicate(CallInst *CI) {
  317. Function *ReplicatedFunction = FunctionTranslator::GetLoweredFunction<ReplicatedFunctionTypeTranslator>(CI, *this);
  318. if (!ReplicatedFunction)
  319. return NoTranslation(CI);
  320. ReplicateCall replicate(CI, *ReplicatedFunction);
  321. return replicate.Generate();
  322. }
  323. ///////////////////////////////////////////////////////////////////////////////
  324. // Packed Lowering.
  325. class PackCall {
  326. public:
  327. PackCall(CallInst *CI, Function &PackedFunction)
  328. : m_CI(CI)
  329. , m_packedFunction(PackedFunction)
  330. , m_builder(CI)
  331. {}
  332. Value *Generate() {
  333. SmallVector<Value *, 10> args;
  334. PackArgs(args);
  335. Value *result = CreateCall(args);
  336. return UnpackResult(result);
  337. }
  338. static StructType *ConvertVectorTypeToStructType(Type *vecTy) {
  339. assert(vecTy->isVectorTy());
  340. Type *elementTy = vecTy->getVectorElementType();
  341. unsigned numElements = vecTy->getVectorNumElements();
  342. SmallVector<Type *, 4> elements;
  343. for (unsigned i = 0; i < numElements; ++i)
  344. elements.push_back(elementTy);
  345. return StructType::get(vecTy->getContext(), elements);
  346. }
  347. private:
  348. CallInst *m_CI;
  349. Function &m_packedFunction;
  350. IRBuilder<> m_builder;
  351. void PackArgs(SmallVectorImpl<Value*> &args) {
  352. args.clear();
  353. for (Value *arg : m_CI->arg_operands()) {
  354. if (arg->getType()->isVectorTy())
  355. arg = PackVectorIntoStruct(m_builder, arg);
  356. args.push_back(arg);
  357. }
  358. }
  359. Value *CreateCall(const SmallVectorImpl<Value*> &args) {
  360. return m_builder.CreateCall(&m_packedFunction, args);
  361. }
  362. Value *UnpackResult(Value *result) {
  363. if (result->getType()->isStructTy()) {
  364. result = PackStructIntoVector(m_builder, result);
  365. }
  366. return result;
  367. }
  368. static VectorType *ConvertStructTypeToVectorType(Type *structTy) {
  369. assert(structTy->isStructTy());
  370. return VectorType::get(structTy->getStructElementType(0), structTy->getStructNumElements());
  371. }
  372. static Value *PackVectorIntoStruct(IRBuilder<> &builder, Value *vec) {
  373. StructType *structTy = ConvertVectorTypeToStructType(vec->getType());
  374. Value *packed = UndefValue::get(structTy);
  375. unsigned numElements = structTy->getStructNumElements();
  376. for (unsigned i = 0; i < numElements; ++i) {
  377. Value *element = builder.CreateExtractElement(vec, i);
  378. packed = builder.CreateInsertValue(packed, element, { i });
  379. }
  380. return packed;
  381. }
  382. static Value *PackStructIntoVector(IRBuilder<> &builder, Value *strukt) {
  383. Type *vecTy = ConvertStructTypeToVectorType(strukt->getType());
  384. Value *packed = UndefValue::get(vecTy);
  385. unsigned numElements = vecTy->getVectorNumElements();
  386. for (unsigned i = 0; i < numElements; ++i) {
  387. Value *element = builder.CreateExtractValue(strukt, i);
  388. packed = builder.CreateInsertElement(packed, element, i);
  389. }
  390. return packed;
  391. }
  392. };
  393. class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
  394. virtual Type *TranslateReturnType(CallInst *CI) override {
  395. return TranslateIfVector(CI->getType());
  396. }
  397. virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
  398. return ArgumentType(TranslateIfVector(OrigArg->getType()));
  399. }
  400. Type *TranslateIfVector(Type *ty) {
  401. if (ty->isVectorTy())
  402. ty = PackCall::ConvertVectorTypeToStructType(ty);
  403. return ty;
  404. }
  405. };
  406. Value *ExtensionLowering::Pack(CallInst *CI) {
  407. Function *PackedFunction = FunctionTranslator::GetLoweredFunction<PackedFunctionTypeTranslator>(CI, *this);
  408. if (!PackedFunction)
  409. return NoTranslation(CI);
  410. PackCall pack(CI, *PackedFunction);
  411. Value *result = pack.Generate();
  412. return result;
  413. }
  414. ///////////////////////////////////////////////////////////////////////////////
  415. // Resource Lowering.
  416. // Modify a call to a resouce method. Makes the following transformation:
  417. //
  418. // 1. Convert non-void return value to dx.types.ResRet.
  419. // 2. Expand vectors in place as separate arguments.
  420. //
  421. // Example
  422. // -----------------------------------------------------------------------------
  423. //
  424. // %0 = call <2 x float> MyBufferOp(i32 138, %class.Buffer %3, <2 x i32> <1 , 2> )
  425. // %r = call %dx.types.ResRet.f32 MyBufferOp(i32 138, %dx.types.Handle %buf, i32 1, i32 2 )
  426. // %x = extractvalue %r, 0
  427. // %y = extractvalue %r, 1
  428. // %v = <2 x float> undef
  429. // %v.1 = insertelement %v, %x, 0
  430. // %v.2 = insertelement %v.1, %y, 1
  431. class ResourceMethodCall {
  432. public:
  433. ResourceMethodCall(CallInst *CI)
  434. : m_CI(CI)
  435. , m_builder(CI)
  436. { }
  437. virtual ~ResourceMethodCall() {}
  438. virtual Value *Generate(Function *explodedFunction) {
  439. SmallVector<Value *, 16> args;
  440. ExplodeArgs(args);
  441. Value *result = CreateCall(explodedFunction, args);
  442. result = ConvertResult(result);
  443. return result;
  444. }
  445. protected:
  446. CallInst *m_CI;
  447. IRBuilder<> m_builder;
  448. void ExplodeArgs(SmallVectorImpl<Value*> &args) {
  449. for (Value *arg : m_CI->arg_operands()) {
  450. // vector arg: <N x ty> -> ty, ty, ..., ty (N times)
  451. if (arg->getType()->isVectorTy()) {
  452. for (unsigned i = 0; i < arg->getType()->getVectorNumElements(); i++) {
  453. Value *xarg = m_builder.CreateExtractElement(arg, i);
  454. args.push_back(xarg);
  455. }
  456. }
  457. // any other value: arg -> arg
  458. else {
  459. args.push_back(arg);
  460. }
  461. }
  462. }
  463. Value *CreateCall(Function *explodedFunction, ArrayRef<Value*> args) {
  464. return m_builder.CreateCall(explodedFunction, args);
  465. }
  466. Value *ConvertResult(Value *result) {
  467. Type *origRetTy = m_CI->getType();
  468. if (origRetTy->isVoidTy())
  469. return ConvertVoidResult(result);
  470. else if (origRetTy->isVectorTy())
  471. return ConvertVectorResult(origRetTy, result);
  472. else
  473. return ConvertScalarResult(origRetTy, result);
  474. }
  475. // Void result does not need any conversion.
  476. Value *ConvertVoidResult(Value *result) {
  477. return result;
  478. }
  479. // Vector result will be populated with the elements from the resource return.
  480. Value *ConvertVectorResult(Type *origRetTy, Value *result) {
  481. Type *resourceRetTy = result->getType();
  482. assert(origRetTy->isVectorTy());
  483. assert(resourceRetTy->isStructTy() && "expected resource return type to be a struct");
  484. const unsigned vectorSize = origRetTy->getVectorNumElements();
  485. const unsigned structSize = resourceRetTy->getStructNumElements();
  486. const unsigned size = std::min(vectorSize, structSize);
  487. assert(vectorSize < structSize);
  488. // Copy resource struct elements to vector.
  489. Value *vector = UndefValue::get(origRetTy);
  490. for (unsigned i = 0; i < size; ++i) {
  491. Value *element = m_builder.CreateExtractValue(result, { i });
  492. vector = m_builder.CreateInsertElement(vector, element, i);
  493. }
  494. return vector;
  495. }
  496. // Scalar result will be populated with the first element of the resource return.
  497. Value *ConvertScalarResult(Type *origRetTy, Value *result) {
  498. assert(origRetTy->isSingleValueType());
  499. return m_builder.CreateExtractValue(result, { 0 });
  500. }
  501. };
  502. // Translate function return and argument types for resource method lowering.
  503. class ResourceFunctionTypeTranslator : public FunctionTypeTranslator {
  504. public:
  505. ResourceFunctionTypeTranslator(OP &hlslOp) : m_hlslOp(hlslOp) {}
  506. // Translate return type as follows:
  507. //
  508. // void -> void
  509. // <N x ty> -> dx.types.ResRet.ty
  510. // ty -> dx.types.ResRet.ty
  511. virtual Type *TranslateReturnType(CallInst *CI) override {
  512. Type *RetTy = CI->getType();
  513. if (RetTy->isVoidTy())
  514. return RetTy;
  515. else if (RetTy->isVectorTy())
  516. RetTy = RetTy->getVectorElementType();
  517. return m_hlslOp.GetResRetType(RetTy);
  518. }
  519. // Translate argument type as follows:
  520. //
  521. // resource -> dx.types.Handle
  522. // <N x ty> -> { ty, N }
  523. // ty -> { ty, 1 }
  524. virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
  525. int count = 1;
  526. Type *ty = OrigArg->getType();
  527. if (ty->isVectorTy()) {
  528. count = ty->getVectorNumElements();
  529. ty = ty->getVectorElementType();
  530. }
  531. return ArgumentType(ty, count);
  532. }
  533. private:
  534. OP& m_hlslOp;
  535. };
  536. Value *ExtensionLowering::Resource(CallInst *CI) {
  537. // Extra strategy info overrides the default lowering for resource methods.
  538. if (!m_extraStrategyInfo.empty())
  539. {
  540. return CustomResource(CI);
  541. }
  542. ResourceFunctionTypeTranslator resourceTypeTranslator(m_hlslOp);
  543. Function *resourceFunction = FunctionTranslator::GetLoweredFunction(resourceTypeTranslator, CI, *this);
  544. if (!resourceFunction)
  545. return NoTranslation(CI);
  546. ResourceMethodCall explode(CI);
  547. Value *result = explode.Generate(resourceFunction);
  548. return result;
  549. }
  550. // This class handles the core logic for custom lowering of resource
  551. // method intrinsics. The goal is to allow resource extension intrinsics
  552. // to be handled the same way as the core hlsl resource intrinsics.
  553. //
  554. // Specifically, we want to support:
  555. //
  556. // 1. Multiple hlsl overloads map to a single dxil intrinsic
  557. // 2. The hlsl overloads can take different parameters for a given resource type
  558. // 3. The hlsl overloads are not consistent across different resource types
  559. //
  560. // To achieve these goals we need a more complex mechanism for describing how
  561. // to translate the high-level arguments to arguments for a dxil function.
  562. // The custom lowering info describes this lowering using the following format.
  563. //
  564. // [Custom Lowering Info Format]
// A json string encoding a map where each key is either a specific resource type or
// the keyword "default" to be used for any other resource. The value is a
// custom-format string encoding how high-level arguments are mapped to
// dxil intrinsic arguments.
  569. //
  570. // [Argument Translation Format]
  571. // A comma separated string where the number of fields is exactly equal to the number
  572. // of parameters in the target dxil intrinsic. Each field describes how to generate
  573. // the argument for that dxil intrinsic parameter. It has the following format where
  574. // the hl_arg_index is mandatory, but the other two parts are optional.
  575. //
  576. // <hl_arg_index>.<vector_index>:<optional_type_info>
  577. //
  578. // The format is precisely described by the following regular expression:
  579. //
// (?P<hl_arg_index>[-0-9]+)(\.(?P<vector_index>[-0-9]+))?(:(?P<optional_type_info>\?i32|\?i16|\?i8|\?float|\?half))?$
  581. //
  582. // Example
  583. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  584. // Say we want to define the MyTextureOp extension with the following overloads:
  585. //
  586. // Texture1D
  587. // MyTextureOp(uint addr, uint offset)
  588. // MyTextureOp(uint addr, uint offset, uint val)
  589. //
  590. // Texture2D
  591. // MyTextureOp(uint2 addr, uint2 val)
  592. //
  593. // And a dxil intrinsic defined as follows
  594. // @MyTextureOp(i32 opcode, %dx.types.Handle handle, i32 addr0, i32 addr1, i32 offset, i32 val0, i32 val1)
  595. //
  596. // Then we would define the lowering info json as follows
  597. //
// {
//    "default"   : "0, 1, 2.0, 2.1, 3     , 4.0:?i32, 4.1:?i32",
//    "Texture2D" : "0, 1, 2.0, 2.1, -1:?i32, 3.0     , 3.1"
// }
  602. //
  603. //
  604. // This would produce the following lowerings (assuming the MyTextureOp opcode is 17)
  605. //
  606. // hlsl: Texture1D.MyTextureOp(a, b)
  607. // hl: @MyTextureOp(17, handle, a, b)
  608. // dxil: @MyTextureOp(17, handle, a, undef, b, undef, undef)
  609. //
  610. // hlsl: Texture1D.MyTextureOp(a, b, c)
  611. // hl: @MyTextureOp(17, handle, a, b, c)
  612. // dxil: @MyTextureOp(17, handle, a, undef, b, c, undef)
  613. //
  614. // hlsl: Texture2D.MyTextureOp(a, c)
  615. // hl: @MyTextureOp(17, handle, a, c)
  616. // dxil: @MyTextureOp(17, handle, a.x, a.y, undef, c.x, c.y)
  617. //
  618. //
  619. class CustomResourceLowering
  620. {
  621. public:
  622. CustomResourceLowering(StringRef LoweringInfo, CallInst *CI, HLResourceLookup &ResourceLookup)
  623. {
  624. // Parse lowering info json format.
  625. std::map<ResourceKindName, std::vector<DxilArgInfo>> LoweringInfoMap =
  626. ParseLoweringInfo(LoweringInfo, CI->getContext());
  627. // Lookup resource kind based on handle (first arg after hl opcode)
  628. enum {RESOURCE_HANDLE_ARG=1};
  629. const char *pName = nullptr;
  630. if (!ResourceLookup.GetResourceKindName(CI->getArgOperand(RESOURCE_HANDLE_ARG), &pName))
  631. {
  632. ThrowExtensionError("Failed to find resource from handle");
  633. }
  634. std::string Name(pName);
  635. // Select lowering info to use based on resource kind.
  636. const char *DefaultInfoName = "default";
  637. std::vector<DxilArgInfo> *pArgInfo = nullptr;
  638. if (LoweringInfoMap.count(Name))
  639. {
  640. pArgInfo = &LoweringInfoMap.at(Name);
  641. }
  642. else if (LoweringInfoMap.count(DefaultInfoName))
  643. {
  644. pArgInfo = &LoweringInfoMap.at(DefaultInfoName);
  645. }
  646. else
  647. {
  648. ThrowExtensionError("Unable to find lowering info for resource");
  649. }
  650. GenerateLoweredArgs(CI, *pArgInfo);
  651. }
  652. const std::vector<Value *> &GetLoweredArgs() const
  653. {
  654. return m_LoweredArgs;
  655. }
  656. private:
  657. struct OptionalTypeSpec
  658. {
  659. const char* TypeName;
  660. Type *LLVMType;
  661. };
  662. // These are the supported optional types for generating dxil parameters
  663. // that have no matching argument in the high-level intrinsic overload.
  664. // See [Argument Translation Format] for details.
  665. void InitOptionalTypes(LLVMContext &Ctx)
  666. {
  667. // Table of supported optional types.
  668. // Keep in sync with m_OptionalTypes small vector size to avoid
  669. // dynamic allocation.
  670. OptionalTypeSpec OptionalTypes[] = {
  671. {"?i32", Type::getInt32Ty(Ctx)},
  672. {"?float", Type::getFloatTy(Ctx)},
  673. {"?half", Type::getHalfTy(Ctx)},
  674. {"?i8", Type::getInt8Ty(Ctx)},
  675. {"?i16", Type::getInt16Ty(Ctx)},
  676. };
  677. DXASSERT(m_OptionalTypes.empty(), "Init should only be called once");
  678. m_OptionalTypes.clear();
  679. m_OptionalTypes.reserve(_countof(OptionalTypes));
  680. for (const OptionalTypeSpec &T : OptionalTypes)
  681. {
  682. m_OptionalTypes.push_back(T);
  683. }
  684. }
  685. Type *ParseOptionalType(StringRef OptionalTypeInfo)
  686. {
  687. if (OptionalTypeInfo.empty())
  688. {
  689. return nullptr;
  690. }
  691. for (OptionalTypeSpec &O : m_OptionalTypes)
  692. {
  693. if (OptionalTypeInfo == O.TypeName)
  694. {
  695. return O.LLVMType;
  696. }
  697. }
  698. ThrowExtensionError("Failed to parse optional type");
  699. }
  700. // Mapping from high level function arg to dxil function arg.
  701. //
  702. // The `HighLevelArgIndex` is the index of the function argument to
  703. // which this dxil argument maps.
  704. //
  705. // If `HasVectorIndex` is true then the `VectorIndex` contains the
  706. // index of the element in the vector pointed to by HighLevelArgIndex.
  707. //
  708. // The `OptionalType` is used to specify types for arguments that are not
  709. // present in all overloads of the high level function. This lets us
  710. // map multiple high level functions to a single dxil extension intrinsic.
  711. //
  712. struct DxilArgInfo
  713. {
  714. unsigned HighLevelArgIndex = 0;
  715. unsigned VectorIndex = 0;
  716. bool HasVectorIndex = false;
  717. Type *OptionalType = nullptr;
  718. };
  719. typedef std::string ResourceKindName;
  720. // Convert the lowering info to a machine-friendly format.
  721. // Note that we use the YAML parser to parse the JSON since JSON
  722. // is a subset of YAML (and this llvm has no JSON parser).
  723. //
  724. // See [Custom Lowering Info Format] for details.
  725. std::map<ResourceKindName, std::vector<DxilArgInfo>> ParseLoweringInfo(StringRef LoweringInfo, LLVMContext &Ctx)
  726. {
  727. InitOptionalTypes(Ctx);
  728. std::map<ResourceKindName, std::vector<DxilArgInfo>> LoweringInfoMap;
  729. SourceMgr SM;
  730. yaml::Stream YAMLStream(LoweringInfo, SM);
  731. // Make sure we have a valid json input.
  732. llvm::yaml::document_iterator I = YAMLStream.begin();
  733. if (I == YAMLStream.end()) {
  734. ThrowExtensionError("Found empty resource lowering JSON.");
  735. }
  736. llvm::yaml::Node *Root = I->getRoot();
  737. if (!Root) {
  738. ThrowExtensionError("Error parsing resource lowering JSON.");
  739. }
  740. // Parse the top level map object.
  741. llvm::yaml::MappingNode *Object = dyn_cast<llvm::yaml::MappingNode>(Root);
  742. if (!Object) {
  743. ThrowExtensionError("Expected map in top level of resource lowering JSON.");
  744. }
  745. // Parse all key/value pairs from the map.
  746. for (llvm::yaml::MappingNode::iterator KVI = Object->begin(),
  747. KVE = Object->end();
  748. KVI != KVE; ++KVI)
  749. {
  750. // Parse key.
  751. llvm::yaml::ScalarNode *KeyString =
  752. dyn_cast_or_null<llvm::yaml::ScalarNode>((*KVI).getKey());
  753. if (!KeyString) {
  754. ThrowExtensionError("Expected string as key in resource lowering info JSON map.");
  755. }
  756. SmallString<32> KeyStorage;
  757. StringRef Key = KeyString->getValue(KeyStorage);
  758. // Parse value.
  759. llvm::yaml::ScalarNode *ValueString =
  760. dyn_cast_or_null<llvm::yaml::ScalarNode>((*KVI).getValue());
  761. if (!ValueString) {
  762. ThrowExtensionError("Expected string as value in resource lowering info JSON map.");
  763. }
  764. SmallString<128> ValueStorage;
  765. StringRef Value = ValueString->getValue(ValueStorage);
  766. // Parse dxil arg info from value.
  767. LoweringInfoMap[Key] = ParseDxilArgInfo(Value, Ctx);
  768. }
  769. return LoweringInfoMap;
  770. }
  771. // Parse the dxail argument translation info.
  772. // See [Argument Translation Format] for details.
  773. std::vector<DxilArgInfo> ParseDxilArgInfo(StringRef ArgSpec, LLVMContext &Ctx)
  774. {
  775. std::vector<DxilArgInfo> Args;
  776. SmallVector<StringRef, 14> Splits;
  777. ArgSpec.split(Splits, ",");
  778. for (const StringRef Split : Splits)
  779. {
  780. StringRef Field = Split.trim();
  781. StringRef HighLevelArgInfo;
  782. StringRef OptionalTypeInfo;
  783. std::tie(HighLevelArgInfo, OptionalTypeInfo) = Field.split(":");
  784. Type *OptionalType = ParseOptionalType(OptionalTypeInfo);
  785. StringRef HighLevelArgIndex;
  786. StringRef VectorIndex;
  787. std::tie(HighLevelArgIndex, VectorIndex) = HighLevelArgInfo.split(".");
  788. // Parse the arg and vector index.
  789. // Parse the values as signed integers, but store them as unsigned values to
  790. // allows using -1 as a shorthand for the max value.
  791. DxilArgInfo ArgInfo;
  792. ArgInfo.HighLevelArgIndex = static_cast<unsigned>(std::stoi(HighLevelArgIndex));
  793. if (!VectorIndex.empty())
  794. {
  795. ArgInfo.HasVectorIndex = true;
  796. ArgInfo.VectorIndex = static_cast<unsigned>(std::stoi(VectorIndex));
  797. }
  798. ArgInfo.OptionalType = OptionalType;
  799. Args.push_back(ArgInfo);
  800. }
  801. return Args;
  802. }
  803. // Create the dxil args based on custom lowering info.
  804. void GenerateLoweredArgs(CallInst *CI, const std::vector<DxilArgInfo> &ArgInfoRecords)
  805. {
  806. IRBuilder<> builder(CI);
  807. for (const DxilArgInfo &ArgInfo : ArgInfoRecords)
  808. {
  809. // Check to see if we have the corresponding high-level arg in the overload for this call.
  810. if (ArgInfo.HighLevelArgIndex < CI->getNumArgOperands())
  811. {
  812. Value *Arg = CI->getArgOperand(ArgInfo.HighLevelArgIndex);
  813. if (ArgInfo.HasVectorIndex)
  814. {
  815. // We expect a vector type here, but we handle one special case if not.
  816. if (Arg->getType()->isVectorTy())
  817. {
  818. // We allow multiple high-level overloads to map to a single dxil extension function.
  819. // If the vector index is invalid for this specific overload then use an undef
  820. // value as a replacement.
  821. if (ArgInfo.VectorIndex < Arg->getType()->getVectorNumElements())
  822. {
  823. Arg = builder.CreateExtractElement(Arg, ArgInfo.VectorIndex);
  824. }
  825. else
  826. {
  827. Arg = UndefValue::get(Arg->getType()->getVectorElementType());
  828. }
  829. }
  830. else
  831. {
  832. // If it is a non-vector type then we replace non-zero vector index with
  833. // undef. This is to handle hlsl intrinsic overloading rules that allow
  834. // scalars in place of single-element vectors. We assume here that a non-vector
  835. // means that a single element vector was already scalarized.
  836. //
  837. if (ArgInfo.VectorIndex > 0)
  838. {
  839. Arg = UndefValue::get(Arg->getType());
  840. }
  841. }
  842. }
  843. m_LoweredArgs.push_back(Arg);
  844. }
  845. else if (ArgInfo.OptionalType)
  846. {
  847. // If there was no matching high-level arg then we look for the optional
  848. // arg type specified by the lowering info.
  849. m_LoweredArgs.push_back(UndefValue::get(ArgInfo.OptionalType));
  850. }
  851. else
  852. {
  853. // No way to know how to generate the correc type for this dxil arg.
  854. ThrowExtensionError("Unable to map high-level arg to dxil arg");
  855. }
  856. }
  857. }
  858. std::vector<Value *> m_LoweredArgs;
  859. SmallVector<OptionalTypeSpec, 5> m_OptionalTypes;
  860. };
  861. // Boilerplate to reuse exising logic as much as possible.
  862. // We just want to overload GetFunctionType here.
  863. class CustomResourceFunctionTranslator : public FunctionTranslator {
  864. public:
  865. static Function *GetLoweredFunction(
  866. const CustomResourceLowering &CustomLowering,
  867. ResourceFunctionTypeTranslator &typeTranslator,
  868. CallInst *CI,
  869. ExtensionLowering &lower
  870. )
  871. {
  872. CustomResourceFunctionTranslator T(CustomLowering, typeTranslator, lower);
  873. return T.FunctionTranslator::GetLoweredFunction(CI);
  874. }
  875. private:
  876. CustomResourceFunctionTranslator(
  877. const CustomResourceLowering &CustomLowering,
  878. ResourceFunctionTypeTranslator &typeTranslator,
  879. ExtensionLowering &lower
  880. )
  881. : FunctionTranslator(typeTranslator, lower)
  882. , m_CustomLowering(CustomLowering)
  883. {
  884. }
  885. virtual FunctionType *GetFunctionType(CallInst *CI, Type *RetTy) override {
  886. SmallVector<Type *, 16> ParamTypes;
  887. for (Value *V : m_CustomLowering.GetLoweredArgs())
  888. {
  889. ParamTypes.push_back(V->getType());
  890. }
  891. const bool IsVarArg = false;
  892. return FunctionType::get(RetTy, ParamTypes, IsVarArg);
  893. }
  894. private:
  895. const CustomResourceLowering &m_CustomLowering;
  896. };
  897. // Boilerplate to reuse exising logic as much as possible.
  898. // We just want to overload Generate here.
  899. class CustomResourceMethodCall : public ResourceMethodCall
  900. {
  901. public:
  902. CustomResourceMethodCall(CallInst *CI, const CustomResourceLowering &CustomLowering)
  903. : ResourceMethodCall(CI)
  904. , m_CustomLowering(CustomLowering)
  905. {}
  906. virtual Value *Generate(Function *loweredFunction) override {
  907. Value *result = CreateCall(loweredFunction, m_CustomLowering.GetLoweredArgs());
  908. result = ConvertResult(result);
  909. return result;
  910. }
  911. private:
  912. const CustomResourceLowering &m_CustomLowering;
  913. };
  914. // Support custom lowering logic for resource functions.
  915. Value *ExtensionLowering::CustomResource(CallInst *CI) {
  916. CustomResourceLowering CustomLowering(m_extraStrategyInfo, CI, m_hlResourceLookup);
  917. ResourceFunctionTypeTranslator ResourceTypeTranslator(m_hlslOp);
  918. Function *ResourceFunction = CustomResourceFunctionTranslator::GetLoweredFunction(
  919. CustomLowering,
  920. ResourceTypeTranslator,
  921. CI,
  922. *this
  923. );
  924. if (!ResourceFunction)
  925. return NoTranslation(CI);
  926. CustomResourceMethodCall custom(CI, CustomLowering);
  927. Value *Result = custom.Generate(ResourceFunction);
  928. return Result;
  929. }
  930. ///////////////////////////////////////////////////////////////////////////////
  931. // Dxil Lowering.
  932. Value *ExtensionLowering::Dxil(CallInst *CI) {
  933. // Map the extension opcode to the corresponding dxil opcode.
  934. unsigned extOpcode = GetHLOpcode(CI);
  935. OP::OpCode dxilOpcode;
  936. if (!m_helper->GetDxilOpcode(extOpcode, dxilOpcode))
  937. return nullptr;
  938. // Find the dxil function based on the overload type.
  939. Type *overloadTy = OP::GetOverloadType(dxilOpcode, CI->getCalledFunction());
  940. Function *F = m_hlslOp.GetOpFunc(dxilOpcode, overloadTy->getScalarType());
  941. // Update the opcode in the original call so we can just copy it below.
  942. // We are about to delete this call anyway.
  943. CI->setOperand(0, m_hlslOp.GetI32Const(static_cast<unsigned>(dxilOpcode)));
  944. // Create the new call.
  945. Value *result = nullptr;
  946. if (overloadTy->isVectorTy()) {
  947. ReplicateCall replicate(CI, *F);
  948. result = replicate.Generate();
  949. }
  950. else {
  951. IRBuilder<> builder(CI);
  952. SmallVector<Value *, 8> args(CI->arg_operands().begin(), CI->arg_operands().end());
  953. result = builder.CreateCall(F, args);
  954. }
  955. return result;
  956. }
  957. ///////////////////////////////////////////////////////////////////////////////
  958. // Computing Extension Names.
  959. // Compute the name to use for the intrinsic function call once it is lowered to dxil.
  960. // First checks to see if we have a custom name from the codegen helper and if not
  961. // chooses a default name based on the lowergin strategy.
  962. class ExtensionName {
  963. public:
  964. ExtensionName(CallInst *CI, ExtensionLowering::Strategy strategy, HLSLExtensionsCodegenHelper *helper)
  965. : m_CI(CI)
  966. , m_strategy(strategy)
  967. , m_helper(helper)
  968. {}
  969. std::string Get() {
  970. std::string name;
  971. if (m_helper)
  972. name = GetCustomExtensionName(m_CI, *m_helper);
  973. if (!HasCustomExtensionName(name))
  974. name = GetDefaultCustomExtensionName(m_CI, ExtensionLowering::GetStrategyName(m_strategy));
  975. return name;
  976. }
  977. private:
  978. CallInst *m_CI;
  979. ExtensionLowering::Strategy m_strategy;
  980. HLSLExtensionsCodegenHelper *m_helper;
  981. static std::string GetCustomExtensionName(CallInst *CI, HLSLExtensionsCodegenHelper &helper) {
  982. unsigned opcode = GetHLOpcode(CI);
  983. std::string name = helper.GetIntrinsicName(opcode);
  984. ReplaceOverloadMarkerWithTypeName(name, CI);
  985. return name;
  986. }
  987. static std::string GetDefaultCustomExtensionName(CallInst *CI, StringRef strategyName) {
  988. return (Twine(CI->getCalledFunction()->getName()) + "." + Twine(strategyName)).str();
  989. }
  990. static bool HasCustomExtensionName(const std::string name) {
  991. return name.size() > 0;
  992. }
  993. typedef unsigned OverloadArgIndex;
  994. static constexpr OverloadArgIndex DefaultOverloadIndex = std::numeric_limits<OverloadArgIndex>::max();
  995. // Choose the (return value or argument) type that determines the overload type
  996. // for the intrinsic call.
  997. // If the overload arg index was explicitly specified (see ParseOverloadArgIndex)
  998. // then we use that arg to pick the overload name. Otherwise we pick a default
  999. // where we take the return type as the overload. If the return is void we
  1000. // take the first (non-opcode) argument as the overload type.
  1001. static Type *SelectOverloadSlot(CallInst *CI, OverloadArgIndex ArgIndex) {
  1002. if (ArgIndex != DefaultOverloadIndex)
  1003. {
  1004. return CI->getArgOperand(ArgIndex)->getType();
  1005. }
  1006. Type *ty = CI->getType();
  1007. if (ty->isVoidTy()) {
  1008. if (CI->getNumArgOperands() > 1)
  1009. ty = CI->getArgOperand(1)->getType(); // First non-opcode argument.
  1010. }
  1011. return ty;
  1012. }
  1013. static Type *GetOverloadType(CallInst *CI, OverloadArgIndex ArgIndex) {
  1014. Type *ty = SelectOverloadSlot(CI, ArgIndex);
  1015. if (ty->isVectorTy())
  1016. ty = ty->getVectorElementType();
  1017. return ty;
  1018. }
  1019. static std::string GetTypeName(Type *ty) {
  1020. std::string typeName;
  1021. llvm::raw_string_ostream os(typeName);
  1022. ty->print(os);
  1023. os.flush();
  1024. return typeName;
  1025. }
  1026. static std::string GetOverloadTypeName(CallInst *CI, OverloadArgIndex ArgIndex) {
  1027. Type *ty = GetOverloadType(CI, ArgIndex);
  1028. return GetTypeName(ty);
  1029. }
  1030. // Parse the arg index out of the overload marker (if any).
  1031. //
  1032. // The function names use a $o to indicate that the function is overloaded
  1033. // and we should replace $o with the overload type. The extension name can
  1034. // explicitly set which arg to use for the overload type by adding a colon
  1035. // and a number after the $o (e.g. $o:3 would say the overload type is
  1036. // determined by parameter 3).
  1037. //
  1038. // If we find an arg index after the overload marker we update the size
  1039. // of the marker to include the full parsed string size so that it can
  1040. // be replaced with the selected overload type.
  1041. //
  1042. static OverloadArgIndex ParseOverloadArgIndex(
  1043. const std::string& functionName,
  1044. size_t OverloadMarkerStartIndex,
  1045. size_t *pOverloadMarkerSize)
  1046. {
  1047. assert(OverloadMarkerStartIndex != std::string::npos);
  1048. size_t StartIndex = OverloadMarkerStartIndex + *pOverloadMarkerSize;
  1049. // Check if we have anything after the overload marker to parse.
  1050. if (StartIndex >= functionName.size())
  1051. {
  1052. return DefaultOverloadIndex;
  1053. }
  1054. // Does it start with a ':' ?
  1055. if (functionName[StartIndex] != ':')
  1056. {
  1057. return DefaultOverloadIndex;
  1058. }
  1059. // Skip past the :
  1060. ++StartIndex;
  1061. // Collect all the digits.
  1062. std::string Digits;
  1063. Digits.reserve(functionName.size() - StartIndex);
  1064. for (size_t i = StartIndex; i < functionName.size(); ++i)
  1065. {
  1066. char c = functionName[i];
  1067. if (!isdigit(c))
  1068. {
  1069. break;
  1070. }
  1071. Digits.push_back(c);
  1072. }
  1073. if (Digits.empty())
  1074. {
  1075. return DefaultOverloadIndex;
  1076. }
  1077. *pOverloadMarkerSize = *pOverloadMarkerSize + std::strlen(":") + Digits.size();
  1078. return std::stoi(Digits);
  1079. }
  1080. // Find the occurence of the overload marker $o and replace it the the overload type name.
  1081. static void ReplaceOverloadMarkerWithTypeName(std::string &functionName, CallInst *CI) {
  1082. const char *OverloadMarker = "$o";
  1083. size_t OverloadMarkerLength = 2;
  1084. size_t pos = functionName.find(OverloadMarker);
  1085. if (pos != std::string::npos) {
  1086. OverloadArgIndex ArgIndex = ParseOverloadArgIndex(functionName, pos, &OverloadMarkerLength);
  1087. std::string typeName = GetOverloadTypeName(CI, ArgIndex);
  1088. functionName.replace(pos, OverloadMarkerLength, typeName);
  1089. }
  1090. }
  1091. };
  1092. std::string ExtensionLowering::GetExtensionName(llvm::CallInst *CI) {
  1093. ExtensionName name(CI, m_strategy, m_helper);
  1094. return name.Get();
  1095. }