HLOperationLowerExtension.cpp 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLOperationLowerExtension.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  9. #include "dxc/HLSL/HLOperationLowerExtension.h"
  10. #include "dxc/DXIL/DxilModule.h"
  11. #include "dxc/DXIL/DxilOperations.h"
  12. #include "dxc/HLSL/HLModule.h"
  13. #include "dxc/HLSL/HLOperationLower.h"
  14. #include "dxc/HLSL/HLOperations.h"
  15. #include "dxc/HlslIntrinsicOp.h"
  16. #include "llvm/ADT/StringRef.h"
  17. #include "llvm/IR/IRBuilder.h"
  18. #include "llvm/IR/Instructions.h"
  19. #include "llvm/IR/Module.h"
  20. #include "llvm/Support/raw_os_ostream.h"
  21. using namespace llvm;
  22. using namespace hlsl;
  23. ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
  24. if (strategy.size() < 1)
  25. return Strategy::Unknown;
  26. switch (strategy[0]) {
  27. case 'n': return Strategy::NoTranslation;
  28. case 'r': return Strategy::Replicate;
  29. case 'p': return Strategy::Pack;
  30. case 'm': return Strategy::Resource;
  31. case 'd': return Strategy::Dxil;
  32. default: break;
  33. }
  34. return Strategy::Unknown;
  35. }
  36. llvm::StringRef ExtensionLowering::GetStrategyName(Strategy strategy) {
  37. switch (strategy) {
  38. case Strategy::NoTranslation: return "n";
  39. case Strategy::Replicate: return "r";
  40. case Strategy::Pack: return "p";
  41. case Strategy::Resource: return "m"; // m for resource method
  42. case Strategy::Dxil: return "d";
  43. default: break;
  44. }
  45. return "?";
  46. }
  47. ExtensionLowering::ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp)
  48. : m_strategy(strategy), m_helper(helper), m_hlslOp(hlslOp)
  49. {}
  50. ExtensionLowering::ExtensionLowering(StringRef strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp)
  51. : ExtensionLowering(GetStrategy(strategy), helper, hlslOp)
  52. {}
  53. llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
  54. switch (m_strategy) {
  55. case Strategy::NoTranslation: return NoTranslation(CI);
  56. case Strategy::Replicate: return Replicate(CI);
  57. case Strategy::Pack: return Pack(CI);
  58. case Strategy::Resource: return Resource(CI);
  59. case Strategy::Dxil: return Dxil(CI);
  60. default: break;
  61. }
  62. return Unknown(CI);
  63. }
  64. llvm::Value *ExtensionLowering::Unknown(CallInst *CI) {
  65. assert(false && "unknown translation strategy");
  66. return nullptr;
  67. }
  68. // Interface to describe how to translate types from HL-dxil to dxil.
  69. class FunctionTypeTranslator {
  70. public:
  71. // Arguments can be exploded into multiple copies of the same type.
  72. // For example a <2 x i32> could become { i32, 2 } if the vector
  73. // is expanded in place or { i32, 1 } if the call is replicated.
  74. struct ArgumentType {
  75. Type *type;
  76. int count;
  77. ArgumentType(Type *ty, int cnt = 1) : type(ty), count(cnt) {}
  78. };
  79. virtual ~FunctionTypeTranslator() {}
  80. virtual Type *TranslateReturnType(CallInst *CI) = 0;
  81. virtual ArgumentType TranslateArgumentType(Value *OrigArg) = 0;
  82. };
  83. // Class to create the new function with the translated types for low-level dxil.
  84. class FunctionTranslator {
  85. public:
  86. template <typename TypeTranslator>
  87. static Function *GetLoweredFunction(CallInst *CI, ExtensionLowering &lower) {
  88. TypeTranslator typeTranslator;
  89. return GetLoweredFunction(typeTranslator, CI, lower);
  90. }
  91. static Function *GetLoweredFunction(FunctionTypeTranslator &typeTranslator, CallInst *CI, ExtensionLowering &lower) {
  92. FunctionTranslator translator(typeTranslator, lower);
  93. return translator.GetLoweredFunction(CI);
  94. }
  95. private:
  96. FunctionTypeTranslator &m_typeTranslator;
  97. ExtensionLowering &m_lower;
  98. FunctionTranslator(FunctionTypeTranslator &typeTranslator, ExtensionLowering &lower)
  99. : m_typeTranslator(typeTranslator)
  100. , m_lower(lower)
  101. {}
  102. Function *GetLoweredFunction(CallInst *CI) {
  103. // Ge the return type of replicated function.
  104. Type *RetTy = m_typeTranslator.TranslateReturnType(CI);
  105. if (!RetTy)
  106. return nullptr;
  107. // Get the Function type for replicated function.
  108. FunctionType *FTy = GetFunctionType(CI, RetTy);
  109. if (!FTy)
  110. return nullptr;
  111. // Create a new function that will be the replicated call.
  112. AttributeSet attributes = GetAttributeSet(CI);
  113. std::string name = m_lower.GetExtensionName(CI);
  114. return cast<Function>(CI->getModule()->getOrInsertFunction(name, FTy, attributes));
  115. }
  116. FunctionType *GetFunctionType(CallInst *CI, Type *RetTy) {
  117. // Create a new function type with the translated argument.
  118. SmallVector<Type *, 10> ParamTypes;
  119. ParamTypes.reserve(CI->getNumArgOperands());
  120. for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
  121. Value *OrigArg = CI->getArgOperand(i);
  122. FunctionTypeTranslator::ArgumentType newArgType = m_typeTranslator.TranslateArgumentType(OrigArg);
  123. for (int i = 0; i < newArgType.count; ++i) {
  124. ParamTypes.push_back(newArgType.type);
  125. }
  126. }
  127. const bool IsVarArg = false;
  128. return FunctionType::get(RetTy, ParamTypes, IsVarArg);
  129. }
  130. AttributeSet GetAttributeSet(CallInst *CI) {
  131. Function *F = CI->getCalledFunction();
  132. AttributeSet attributes;
  133. auto copyAttribute = [=, &attributes](Attribute::AttrKind a) {
  134. if (F->hasFnAttribute(a)) {
  135. attributes = attributes.addAttribute(CI->getContext(), AttributeSet::FunctionIndex, a);
  136. }
  137. };
  138. copyAttribute(Attribute::AttrKind::ReadOnly);
  139. copyAttribute(Attribute::AttrKind::ReadNone);
  140. copyAttribute(Attribute::AttrKind::ArgMemOnly);
  141. return attributes;
  142. }
  143. };
  144. ///////////////////////////////////////////////////////////////////////////////
  145. // NoTranslation Lowering.
  146. class NoTranslationTypeTranslator : public FunctionTypeTranslator {
  147. virtual Type *TranslateReturnType(CallInst *CI) override {
  148. return CI->getType();
  149. }
  150. virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
  151. return ArgumentType(OrigArg->getType());
  152. }
  153. };
  154. llvm::Value *ExtensionLowering::NoTranslation(CallInst *CI) {
  155. Function *NoTranslationFunction = FunctionTranslator::GetLoweredFunction<NoTranslationTypeTranslator>(CI, *this);
  156. if (!NoTranslationFunction)
  157. return nullptr;
  158. IRBuilder<> builder(CI);
  159. SmallVector<Value *, 8> args(CI->arg_operands().begin(), CI->arg_operands().end());
  160. return builder.CreateCall(NoTranslationFunction, args);
  161. }
  162. ///////////////////////////////////////////////////////////////////////////////
  163. // Replicated Lowering.
  // Sentinel returned by GetReplicatedVectorSize when the call has no vector
  // operands/result, or when their element counts disagree.
  enum { NO_COMMON_VECTOR_SIZE = 0 };
  167. // Find the vector size that will be used for replication.
  168. // The function call will be replicated once for each element of the vector
  169. // size.
  170. static unsigned GetReplicatedVectorSize(llvm::CallInst *CI) {
  171. unsigned commonVectorSize = NO_COMMON_VECTOR_SIZE;
  172. Type *RetTy = CI->getType();
  173. if (RetTy->isVectorTy())
  174. commonVectorSize = RetTy->getVectorNumElements();
  175. for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
  176. Type *Ty = CI->getArgOperand(i)->getType();
  177. if (Ty->isVectorTy()) {
  178. unsigned vectorSize = Ty->getVectorNumElements();
  179. if (commonVectorSize != NO_COMMON_VECTOR_SIZE && commonVectorSize != vectorSize) {
  180. // Inconsistent vector sizes; need a different strategy.
  181. return NO_COMMON_VECTOR_SIZE;
  182. }
  183. commonVectorSize = vectorSize;
  184. }
  185. }
  186. return commonVectorSize;
  187. }
  188. class ReplicatedFunctionTypeTranslator : public FunctionTypeTranslator {
  189. virtual Type *TranslateReturnType(CallInst *CI) override {
  190. unsigned commonVectorSize = GetReplicatedVectorSize(CI);
  191. if (commonVectorSize == NO_COMMON_VECTOR_SIZE)
  192. return nullptr;
  193. // Result should be vector or void.
  194. Type *RetTy = CI->getType();
  195. if (!RetTy->isVoidTy() && !RetTy->isVectorTy())
  196. return nullptr;
  197. if (RetTy->isVectorTy()) {
  198. RetTy = RetTy->getVectorElementType();
  199. }
  200. return RetTy;
  201. }
  202. virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
  203. Type *Ty = OrigArg->getType();
  204. if (Ty->isVectorTy()) {
  205. Ty = Ty->getVectorElementType();
  206. }
  207. return ArgumentType(Ty);
  208. }
  209. };
  210. class ReplicateCall {
  211. public:
  212. ReplicateCall(CallInst *CI, Function &ReplicatedFunction)
  213. : m_CI(CI)
  214. , m_ReplicatedFunction(ReplicatedFunction)
  215. , m_numReplicatedCalls(GetReplicatedVectorSize(CI))
  216. , m_ScalarizeArgIdx()
  217. , m_Args(CI->getNumArgOperands())
  218. , m_ReplicatedCalls(m_numReplicatedCalls)
  219. , m_Builder(CI)
  220. {
  221. assert(m_numReplicatedCalls != NO_COMMON_VECTOR_SIZE);
  222. }
  223. Value *Generate() {
  224. CollectReplicatedArguments();
  225. CreateReplicatedCalls();
  226. Value *retVal = GetReturnValue();
  227. return retVal;
  228. }
  229. private:
  230. CallInst *m_CI;
  231. Function &m_ReplicatedFunction;
  232. unsigned m_numReplicatedCalls;
  233. SmallVector<unsigned, 10> m_ScalarizeArgIdx;
  234. SmallVector<Value *, 10> m_Args;
  235. SmallVector<Value *, 10> m_ReplicatedCalls;
  236. IRBuilder<> m_Builder;
  237. // Collect replicated arguments.
  238. // For non-vector arguments we can add them to the args list directly.
  239. // These args will be shared by each replicated call. For the vector
  240. // arguments we remember the position it will go in the argument list.
  241. // We will fill in the vector args below when we replicate the call
  242. // (once for each vector lane).
  243. void CollectReplicatedArguments() {
  244. for (unsigned i = 0; i < m_CI->getNumArgOperands(); ++i) {
  245. Type *Ty = m_CI->getArgOperand(i)->getType();
  246. if (Ty->isVectorTy()) {
  247. m_ScalarizeArgIdx.push_back(i);
  248. }
  249. else {
  250. m_Args[i] = m_CI->getArgOperand(i);
  251. }
  252. }
  253. }
  254. // Create replicated calls.
  255. // Replicate the call once for each element of the replicated vector size.
  256. void CreateReplicatedCalls() {
  257. for (unsigned vecIdx = 0; vecIdx < m_numReplicatedCalls; vecIdx++) {
  258. for (unsigned i = 0, e = m_ScalarizeArgIdx.size(); i < e; ++i) {
  259. unsigned argIdx = m_ScalarizeArgIdx[i];
  260. Value *arg = m_CI->getArgOperand(argIdx);
  261. m_Args[argIdx] = m_Builder.CreateExtractElement(arg, vecIdx);
  262. }
  263. Value *EltOP = m_Builder.CreateCall(&m_ReplicatedFunction, m_Args);
  264. m_ReplicatedCalls[vecIdx] = EltOP;
  265. }
  266. }
  267. // Get the final replicated value.
  268. // If the function is a void type then return (arbitrarily) the first call.
  269. // We do not return nullptr because that indicates a failure to replicate.
  270. // If the function is a vector type then aggregate all of the replicated
  271. // call values into a new vector.
  272. Value *GetReturnValue() {
  273. if (m_CI->getType()->isVoidTy())
  274. return m_ReplicatedCalls.back();
  275. Value *retVal = llvm::UndefValue::get(m_CI->getType());
  276. for (unsigned i = 0; i < m_ReplicatedCalls.size(); ++i)
  277. retVal = m_Builder.CreateInsertElement(retVal, m_ReplicatedCalls[i], i);
  278. return retVal;
  279. }
  280. };
  281. // Translate the HL call by replicating the call for each vector element.
  282. //
  283. // For example,
  284. //
  285. // <2xi32> %r = call @ext.foo(i32 %op, <2xi32> %v)
  286. // ==>
  287. // %r.1 = call @ext.foo.s(i32 %op, i32 %v.1)
  288. // %r.2 = call @ext.foo.s(i32 %op, i32 %v.2)
  289. // <2xi32> %r.v.1 = insertelement %r.1, 0, <2xi32> undef
  290. // <2xi32> %r.v.2 = insertelement %r.2, 1, %r.v.1
  291. //
  292. // You can then RAWU %r with %r.v.2. The RAWU is not done by the translate function.
  293. Value *ExtensionLowering::Replicate(CallInst *CI) {
  294. Function *ReplicatedFunction = FunctionTranslator::GetLoweredFunction<ReplicatedFunctionTypeTranslator>(CI, *this);
  295. if (!ReplicatedFunction)
  296. return NoTranslation(CI);
  297. ReplicateCall replicate(CI, *ReplicatedFunction);
  298. return replicate.Generate();
  299. }
  300. ///////////////////////////////////////////////////////////////////////////////
  301. // Packed Lowering.
  302. class PackCall {
  303. public:
  304. PackCall(CallInst *CI, Function &PackedFunction)
  305. : m_CI(CI)
  306. , m_packedFunction(PackedFunction)
  307. , m_builder(CI)
  308. {}
  309. Value *Generate() {
  310. SmallVector<Value *, 10> args;
  311. PackArgs(args);
  312. Value *result = CreateCall(args);
  313. return UnpackResult(result);
  314. }
  315. static StructType *ConvertVectorTypeToStructType(Type *vecTy) {
  316. assert(vecTy->isVectorTy());
  317. Type *elementTy = vecTy->getVectorElementType();
  318. unsigned numElements = vecTy->getVectorNumElements();
  319. SmallVector<Type *, 4> elements;
  320. for (unsigned i = 0; i < numElements; ++i)
  321. elements.push_back(elementTy);
  322. return StructType::get(vecTy->getContext(), elements);
  323. }
  324. private:
  325. CallInst *m_CI;
  326. Function &m_packedFunction;
  327. IRBuilder<> m_builder;
  328. void PackArgs(SmallVectorImpl<Value*> &args) {
  329. args.clear();
  330. for (Value *arg : m_CI->arg_operands()) {
  331. if (arg->getType()->isVectorTy())
  332. arg = PackVectorIntoStruct(m_builder, arg);
  333. args.push_back(arg);
  334. }
  335. }
  336. Value *CreateCall(const SmallVectorImpl<Value*> &args) {
  337. return m_builder.CreateCall(&m_packedFunction, args);
  338. }
  339. Value *UnpackResult(Value *result) {
  340. if (result->getType()->isStructTy()) {
  341. result = PackStructIntoVector(m_builder, result);
  342. }
  343. return result;
  344. }
  345. static VectorType *ConvertStructTypeToVectorType(Type *structTy) {
  346. assert(structTy->isStructTy());
  347. return VectorType::get(structTy->getStructElementType(0), structTy->getStructNumElements());
  348. }
  349. static Value *PackVectorIntoStruct(IRBuilder<> &builder, Value *vec) {
  350. StructType *structTy = ConvertVectorTypeToStructType(vec->getType());
  351. Value *packed = UndefValue::get(structTy);
  352. unsigned numElements = structTy->getStructNumElements();
  353. for (unsigned i = 0; i < numElements; ++i) {
  354. Value *element = builder.CreateExtractElement(vec, i);
  355. packed = builder.CreateInsertValue(packed, element, { i });
  356. }
  357. return packed;
  358. }
  359. static Value *PackStructIntoVector(IRBuilder<> &builder, Value *strukt) {
  360. Type *vecTy = ConvertStructTypeToVectorType(strukt->getType());
  361. Value *packed = UndefValue::get(vecTy);
  362. unsigned numElements = vecTy->getVectorNumElements();
  363. for (unsigned i = 0; i < numElements; ++i) {
  364. Value *element = builder.CreateExtractValue(strukt, i);
  365. packed = builder.CreateInsertElement(packed, element, i);
  366. }
  367. return packed;
  368. }
  369. };
  370. class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
  371. virtual Type *TranslateReturnType(CallInst *CI) override {
  372. return TranslateIfVector(CI->getType());
  373. }
  374. virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
  375. return ArgumentType(TranslateIfVector(OrigArg->getType()));
  376. }
  377. Type *TranslateIfVector(Type *ty) {
  378. if (ty->isVectorTy())
  379. ty = PackCall::ConvertVectorTypeToStructType(ty);
  380. return ty;
  381. }
  382. };
  383. Value *ExtensionLowering::Pack(CallInst *CI) {
  384. Function *PackedFunction = FunctionTranslator::GetLoweredFunction<PackedFunctionTypeTranslator>(CI, *this);
  385. if (!PackedFunction)
  386. return NoTranslation(CI);
  387. PackCall pack(CI, *PackedFunction);
  388. Value *result = pack.Generate();
  389. return result;
  390. }
  391. ///////////////////////////////////////////////////////////////////////////////
  392. // Resource Lowering.
  393. // Modify a call to a resouce method. Makes the following transformation:
  394. //
  395. // 1. Convert non-void return value to dx.types.ResRet.
  396. // 2. Expand vectors in place as separate arguments.
  397. //
  398. // Example
  399. // -----------------------------------------------------------------------------
  400. //
  401. // %0 = call <2 x float> MyBufferOp(i32 138, %class.Buffer %3, <2 x i32> <1 , 2> )
  402. // %r = call %dx.types.ResRet.f32 MyBufferOp(i32 138, %dx.types.Handle %buf, i32 1, i32 2 )
  403. // %x = extractvalue %r, 0
  404. // %y = extractvalue %r, 1
  405. // %v = <2 x float> undef
  406. // %v.1 = insertelement %v, %x, 0
  407. // %v.2 = insertelement %v.1, %y, 1
  408. class ResourceMethodCall {
  409. public:
  410. ResourceMethodCall(CallInst *CI, Function &explodedFunction)
  411. : m_CI(CI)
  412. , m_explodedFunction(explodedFunction)
  413. , m_builder(CI)
  414. { }
  415. Value *Generate() {
  416. SmallVector<Value *, 16> args;
  417. ExplodeArgs(args);
  418. Value *result = CreateCall(args);
  419. result = ConvertResult(result);
  420. return result;
  421. }
  422. private:
  423. CallInst *m_CI;
  424. Function &m_explodedFunction;
  425. IRBuilder<> m_builder;
  426. void ExplodeArgs(SmallVectorImpl<Value*> &args) {
  427. for (Value *arg : m_CI->arg_operands()) {
  428. // vector arg: <N x ty> -> ty, ty, ..., ty (N times)
  429. if (arg->getType()->isVectorTy()) {
  430. for (unsigned i = 0; i < arg->getType()->getVectorNumElements(); i++) {
  431. Value *xarg = m_builder.CreateExtractElement(arg, i);
  432. args.push_back(xarg);
  433. }
  434. }
  435. // any other value: arg -> arg
  436. else {
  437. args.push_back(arg);
  438. }
  439. }
  440. }
  441. Value *CreateCall(const SmallVectorImpl<Value*> &args) {
  442. return m_builder.CreateCall(&m_explodedFunction, args);
  443. }
  444. Value *ConvertResult(Value *result) {
  445. Type *origRetTy = m_CI->getType();
  446. if (origRetTy->isVoidTy())
  447. return ConvertVoidResult(result);
  448. else if (origRetTy->isVectorTy())
  449. return ConvertVectorResult(origRetTy, result);
  450. else
  451. return ConvertScalarResult(origRetTy, result);
  452. }
  453. // Void result does not need any conversion.
  454. Value *ConvertVoidResult(Value *result) {
  455. return result;
  456. }
  457. // Vector result will be populated with the elements from the resource return.
  458. Value *ConvertVectorResult(Type *origRetTy, Value *result) {
  459. Type *resourceRetTy = result->getType();
  460. assert(origRetTy->isVectorTy());
  461. assert(resourceRetTy->isStructTy() && "expected resource return type to be a struct");
  462. const unsigned vectorSize = origRetTy->getVectorNumElements();
  463. const unsigned structSize = resourceRetTy->getStructNumElements();
  464. const unsigned size = std::min(vectorSize, structSize);
  465. assert(vectorSize < structSize);
  466. // Copy resource struct elements to vector.
  467. Value *vector = UndefValue::get(origRetTy);
  468. for (unsigned i = 0; i < size; ++i) {
  469. Value *element = m_builder.CreateExtractValue(result, { i });
  470. vector = m_builder.CreateInsertElement(vector, element, i);
  471. }
  472. return vector;
  473. }
  474. // Scalar result will be populated with the first element of the resource return.
  475. Value *ConvertScalarResult(Type *origRetTy, Value *result) {
  476. assert(origRetTy->isSingleValueType());
  477. return m_builder.CreateExtractValue(result, { 0 });
  478. }
  479. };
  480. // Translate function return and argument types for resource method lowering.
  481. class ResourceFunctionTypeTranslator : public FunctionTypeTranslator {
  482. public:
  483. ResourceFunctionTypeTranslator(OP &hlslOp) : m_hlslOp(hlslOp) {}
  484. // Translate return type as follows:
  485. //
  486. // void -> void
  487. // <N x ty> -> dx.types.ResRet.ty
  488. // ty -> dx.types.ResRet.ty
  489. virtual Type *TranslateReturnType(CallInst *CI) override {
  490. Type *RetTy = CI->getType();
  491. if (RetTy->isVoidTy())
  492. return RetTy;
  493. else if (RetTy->isVectorTy())
  494. RetTy = RetTy->getVectorElementType();
  495. return m_hlslOp.GetResRetType(RetTy);
  496. }
  497. // Translate argument type as follows:
  498. //
  499. // resource -> dx.types.Handle
  500. // <N x ty> -> { ty, N }
  501. // ty -> { ty, 1 }
  502. virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
  503. int count = 1;
  504. Type *ty = OrigArg->getType();
  505. if (ty->isVectorTy()) {
  506. count = ty->getVectorNumElements();
  507. ty = ty->getVectorElementType();
  508. }
  509. return ArgumentType(ty, count);
  510. }
  511. private:
  512. OP& m_hlslOp;
  513. };
  514. Value *ExtensionLowering::Resource(CallInst *CI) {
  515. ResourceFunctionTypeTranslator resourceTypeTranslator(m_hlslOp);
  516. Function *resourceFunction = FunctionTranslator::GetLoweredFunction(resourceTypeTranslator, CI, *this);
  517. if (!resourceFunction)
  518. return NoTranslation(CI);
  519. ResourceMethodCall explode(CI, *resourceFunction);
  520. Value *result = explode.Generate();
  521. return result;
  522. }
  523. ///////////////////////////////////////////////////////////////////////////////
  524. // Dxil Lowering.
  525. Value *ExtensionLowering::Dxil(CallInst *CI) {
  526. // Map the extension opcode to the corresponding dxil opcode.
  527. unsigned extOpcode = GetHLOpcode(CI);
  528. OP::OpCode dxilOpcode;
  529. if (!m_helper->GetDxilOpcode(extOpcode, dxilOpcode))
  530. return nullptr;
  531. // Find the dxil function based on the overload type.
  532. Type *overloadTy = m_hlslOp.GetOverloadType(dxilOpcode, CI->getCalledFunction());
  533. Function *F = m_hlslOp.GetOpFunc(dxilOpcode, overloadTy->getScalarType());
  534. // Update the opcode in the original call so we can just copy it below.
  535. // We are about to delete this call anyway.
  536. CI->setOperand(0, m_hlslOp.GetI32Const(static_cast<unsigned>(dxilOpcode)));
  537. // Create the new call.
  538. Value *result = nullptr;
  539. if (overloadTy->isVectorTy()) {
  540. ReplicateCall replicate(CI, *F);
  541. result = replicate.Generate();
  542. }
  543. else {
  544. IRBuilder<> builder(CI);
  545. SmallVector<Value *, 8> args(CI->arg_operands().begin(), CI->arg_operands().end());
  546. result = builder.CreateCall(F, args);
  547. }
  548. return result;
  549. }
  550. ///////////////////////////////////////////////////////////////////////////////
  551. // Computing Extension Names.
  552. // Compute the name to use for the intrinsic function call once it is lowered to dxil.
  553. // First checks to see if we have a custom name from the codegen helper and if not
  554. // chooses a default name based on the lowergin strategy.
  555. class ExtensionName {
  556. public:
  557. ExtensionName(CallInst *CI, ExtensionLowering::Strategy strategy, HLSLExtensionsCodegenHelper *helper)
  558. : m_CI(CI)
  559. , m_strategy(strategy)
  560. , m_helper(helper)
  561. {}
  562. std::string Get() {
  563. std::string name;
  564. if (m_helper)
  565. name = GetCustomExtensionName(m_CI, *m_helper);
  566. if (!HasCustomExtensionName(name))
  567. name = GetDefaultCustomExtensionName(m_CI, ExtensionLowering::GetStrategyName(m_strategy));
  568. return name;
  569. }
  570. private:
  571. CallInst *m_CI;
  572. ExtensionLowering::Strategy m_strategy;
  573. HLSLExtensionsCodegenHelper *m_helper;
  574. static std::string GetCustomExtensionName(CallInst *CI, HLSLExtensionsCodegenHelper &helper) {
  575. unsigned opcode = GetHLOpcode(CI);
  576. std::string name = helper.GetIntrinsicName(opcode);
  577. ReplaceOverloadMarkerWithTypeName(name, CI);
  578. return name;
  579. }
  580. static std::string GetDefaultCustomExtensionName(CallInst *CI, StringRef strategyName) {
  581. return (Twine(CI->getCalledFunction()->getName()) + "." + Twine(strategyName)).str();
  582. }
  583. static bool HasCustomExtensionName(const std::string name) {
  584. return name.size() > 0;
  585. }
  586. // Choose the (return value or argument) type that determines the overload type
  587. // for the intrinsic call.
  588. // For now we take the return type as the overload. If the return is void we
  589. // take the first (non-opcode) argument as the overload type. We could extend the
  590. // $o sytnax in the extension name to explicitly specify the overload slot (e.g.
  591. // $o:3 would say the overload type is determined by parameter 3.
  592. static Type *SelectOverloadSlot(CallInst *CI) {
  593. Type *ty = CI->getType();
  594. if (ty->isVoidTy()) {
  595. if (CI->getNumArgOperands() > 1)
  596. ty = CI->getArgOperand(1)->getType(); // First non-opcode argument.
  597. }
  598. return ty;
  599. }
  600. static Type *GetOverloadType(CallInst *CI) {
  601. Type *ty = SelectOverloadSlot(CI);
  602. if (ty->isVectorTy())
  603. ty = ty->getVectorElementType();
  604. return ty;
  605. }
  606. static std::string GetTypeName(Type *ty) {
  607. std::string typeName;
  608. llvm::raw_string_ostream os(typeName);
  609. ty->print(os);
  610. os.flush();
  611. return typeName;
  612. }
  613. static std::string GetOverloadTypeName(CallInst *CI) {
  614. Type *ty = GetOverloadType(CI);
  615. return GetTypeName(ty);
  616. }
  617. // Find the occurence of the overload marker $o and replace it the the overload type name.
  618. static void ReplaceOverloadMarkerWithTypeName(std::string &functionName, CallInst *CI) {
  619. const char *OverloadMarker = "$o";
  620. const size_t OverloadMarkerLength = 2;
  621. size_t pos = functionName.find(OverloadMarker);
  622. if (pos != std::string::npos) {
  623. std::string typeName = GetOverloadTypeName(CI);
  624. functionName.replace(pos, OverloadMarkerLength, typeName);
  625. }
  626. }
  627. };
  628. std::string ExtensionLowering::GetExtensionName(llvm::CallInst *CI) {
  629. ExtensionName name(CI, m_strategy, m_helper);
  630. return name.Get();
  631. }