HLOperationLowerExtension.cpp 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLOperationLowerExtension.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  9. #include "dxc/HLSL/HLOperationLowerExtension.h"
  10. #include "dxc/HLSL/DxilModule.h"
  11. #include "dxc/HLSL/DxilOperations.h"
  12. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  13. #include "dxc/HLSL/HLModule.h"
  14. #include "dxc/HLSL/HLOperationLower.h"
  15. #include "dxc/HLSL/HLOperations.h"
  16. #include "dxc/HlslIntrinsicOp.h"
  17. #include "llvm/ADT/StringRef.h"
  18. #include "llvm/IR/IRBuilder.h"
  19. #include "llvm/IR/Instructions.h"
  20. #include "llvm/IR/Module.h"
  21. #include "llvm/Support/raw_os_ostream.h"
  22. using namespace llvm;
  23. using namespace hlsl;
  24. ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
  25. if (strategy.size() < 1)
  26. return Strategy::Unknown;
  27. switch (strategy[0]) {
  28. case 'n': return Strategy::NoTranslation;
  29. case 'r': return Strategy::Replicate;
  30. case 'p': return Strategy::Pack;
  31. case 'm': return Strategy::Resource;
  32. case 'd': return Strategy::Dxil;
  33. default: break;
  34. }
  35. return Strategy::Unknown;
  36. }
  37. llvm::StringRef ExtensionLowering::GetStrategyName(Strategy strategy) {
  38. switch (strategy) {
  39. case Strategy::NoTranslation: return "n";
  40. case Strategy::Replicate: return "r";
  41. case Strategy::Pack: return "p";
  42. case Strategy::Resource: return "m"; // m for resource method
  43. case Strategy::Dxil: return "d";
  44. default: break;
  45. }
  46. return "?";
  47. }
  48. ExtensionLowering::ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp)
  49. : m_strategy(strategy), m_helper(helper), m_hlslOp(hlslOp)
  50. {}
  51. ExtensionLowering::ExtensionLowering(StringRef strategy, HLSLExtensionsCodegenHelper *helper, OP& hlslOp)
  52. : ExtensionLowering(GetStrategy(strategy), helper, hlslOp)
  53. {}
  54. llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
  55. switch (m_strategy) {
  56. case Strategy::NoTranslation: return NoTranslation(CI);
  57. case Strategy::Replicate: return Replicate(CI);
  58. case Strategy::Pack: return Pack(CI);
  59. case Strategy::Resource: return Resource(CI);
  60. case Strategy::Dxil: return Dxil(CI);
  61. default: break;
  62. }
  63. return Unknown(CI);
  64. }
  65. llvm::Value *ExtensionLowering::Unknown(CallInst *CI) {
  66. assert(false && "unknown translation strategy");
  67. return nullptr;
  68. }
  69. // Interface to describe how to translate types from HL-dxil to dxil.
  70. class FunctionTypeTranslator {
  71. public:
  72. // Arguments can be exploded into multiple copies of the same type.
  73. // For example a <2 x i32> could become { i32, 2 } if the vector
  74. // is expanded in place or { i32, 1 } if the call is replicated.
  75. struct ArgumentType {
  76. Type *type;
  77. int count;
  78. ArgumentType(Type *ty, int cnt = 1) : type(ty), count(cnt) {}
  79. };
  80. virtual ~FunctionTypeTranslator() {}
  81. virtual Type *TranslateReturnType(CallInst *CI) = 0;
  82. virtual ArgumentType TranslateArgumentType(Value *OrigArg) = 0;
  83. };
  84. // Class to create the new function with the translated types for low-level dxil.
  85. class FunctionTranslator {
  86. public:
  87. template <typename TypeTranslator>
  88. static Function *GetLoweredFunction(CallInst *CI, ExtensionLowering &lower) {
  89. TypeTranslator typeTranslator;
  90. return GetLoweredFunction(typeTranslator, CI, lower);
  91. }
  92. static Function *GetLoweredFunction(FunctionTypeTranslator &typeTranslator, CallInst *CI, ExtensionLowering &lower) {
  93. FunctionTranslator translator(typeTranslator, lower);
  94. return translator.GetLoweredFunction(CI);
  95. }
  96. private:
  97. FunctionTypeTranslator &m_typeTranslator;
  98. ExtensionLowering &m_lower;
  99. FunctionTranslator(FunctionTypeTranslator &typeTranslator, ExtensionLowering &lower)
  100. : m_typeTranslator(typeTranslator)
  101. , m_lower(lower)
  102. {}
  103. Function *GetLoweredFunction(CallInst *CI) {
  104. // Ge the return type of replicated function.
  105. Type *RetTy = m_typeTranslator.TranslateReturnType(CI);
  106. if (!RetTy)
  107. return nullptr;
  108. // Get the Function type for replicated function.
  109. FunctionType *FTy = GetFunctionType(CI, RetTy);
  110. if (!FTy)
  111. return nullptr;
  112. // Create a new function that will be the replicated call.
  113. AttributeSet attributes = GetAttributeSet(CI);
  114. std::string name = m_lower.GetExtensionName(CI);
  115. return cast<Function>(CI->getModule()->getOrInsertFunction(name, FTy, attributes));
  116. }
  117. FunctionType *GetFunctionType(CallInst *CI, Type *RetTy) {
  118. // Create a new function type with the translated argument.
  119. SmallVector<Type *, 10> ParamTypes;
  120. ParamTypes.reserve(CI->getNumArgOperands());
  121. for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
  122. Value *OrigArg = CI->getArgOperand(i);
  123. FunctionTypeTranslator::ArgumentType newArgType = m_typeTranslator.TranslateArgumentType(OrigArg);
  124. for (int i = 0; i < newArgType.count; ++i) {
  125. ParamTypes.push_back(newArgType.type);
  126. }
  127. }
  128. const bool IsVarArg = false;
  129. return FunctionType::get(RetTy, ParamTypes, IsVarArg);
  130. }
  131. AttributeSet GetAttributeSet(CallInst *CI) {
  132. Function *F = CI->getCalledFunction();
  133. AttributeSet attributes;
  134. auto copyAttribute = [=, &attributes](Attribute::AttrKind a) {
  135. if (F->hasFnAttribute(a)) {
  136. attributes = attributes.addAttribute(CI->getContext(), AttributeSet::FunctionIndex, a);
  137. }
  138. };
  139. copyAttribute(Attribute::AttrKind::ReadOnly);
  140. copyAttribute(Attribute::AttrKind::ReadNone);
  141. copyAttribute(Attribute::AttrKind::ArgMemOnly);
  142. return attributes;
  143. }
  144. };
  145. ///////////////////////////////////////////////////////////////////////////////
  146. // NoTranslation Lowering.
  147. class NoTranslationTypeTranslator : public FunctionTypeTranslator {
  148. virtual Type *TranslateReturnType(CallInst *CI) override {
  149. return CI->getType();
  150. }
  151. virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
  152. return ArgumentType(OrigArg->getType());
  153. }
  154. };
  155. llvm::Value *ExtensionLowering::NoTranslation(CallInst *CI) {
  156. Function *NoTranslationFunction = FunctionTranslator::GetLoweredFunction<NoTranslationTypeTranslator>(CI, *this);
  157. if (!NoTranslationFunction)
  158. return nullptr;
  159. IRBuilder<> builder(CI);
  160. SmallVector<Value *, 8> args(CI->arg_operands().begin(), CI->arg_operands().end());
  161. return builder.CreateCall(NoTranslationFunction, args);
  162. }
  163. ///////////////////////////////////////////////////////////////////////////////
  164. // Replicated Lowering.
  // Sentinel vector size meaning "no single vector size fits this call".
  enum {
    NO_COMMON_VECTOR_SIZE = 0
  };
  168. // Find the vector size that will be used for replication.
  169. // The function call will be replicated once for each element of the vector
  170. // size.
  171. static unsigned GetReplicatedVectorSize(llvm::CallInst *CI) {
  172. unsigned commonVectorSize = NO_COMMON_VECTOR_SIZE;
  173. Type *RetTy = CI->getType();
  174. if (RetTy->isVectorTy())
  175. commonVectorSize = RetTy->getVectorNumElements();
  176. for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
  177. Type *Ty = CI->getArgOperand(i)->getType();
  178. if (Ty->isVectorTy()) {
  179. unsigned vectorSize = Ty->getVectorNumElements();
  180. if (commonVectorSize != NO_COMMON_VECTOR_SIZE && commonVectorSize != vectorSize) {
  181. // Inconsistent vector sizes; need a different strategy.
  182. return NO_COMMON_VECTOR_SIZE;
  183. }
  184. commonVectorSize = vectorSize;
  185. }
  186. }
  187. return commonVectorSize;
  188. }
  189. class ReplicatedFunctionTypeTranslator : public FunctionTypeTranslator {
  190. virtual Type *TranslateReturnType(CallInst *CI) override {
  191. unsigned commonVectorSize = GetReplicatedVectorSize(CI);
  192. if (commonVectorSize == NO_COMMON_VECTOR_SIZE)
  193. return nullptr;
  194. // Result should be vector or void.
  195. Type *RetTy = CI->getType();
  196. if (!RetTy->isVoidTy() && !RetTy->isVectorTy())
  197. return nullptr;
  198. if (RetTy->isVectorTy()) {
  199. RetTy = RetTy->getVectorElementType();
  200. }
  201. return RetTy;
  202. }
  203. virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
  204. Type *Ty = OrigArg->getType();
  205. if (Ty->isVectorTy()) {
  206. Ty = Ty->getVectorElementType();
  207. }
  208. return ArgumentType(Ty);
  209. }
  210. };
  211. class ReplicateCall {
  212. public:
  213. ReplicateCall(CallInst *CI, Function &ReplicatedFunction)
  214. : m_CI(CI)
  215. , m_ReplicatedFunction(ReplicatedFunction)
  216. , m_numReplicatedCalls(GetReplicatedVectorSize(CI))
  217. , m_ScalarizeArgIdx()
  218. , m_Args(CI->getNumArgOperands())
  219. , m_ReplicatedCalls(m_numReplicatedCalls)
  220. , m_Builder(CI)
  221. {
  222. assert(m_numReplicatedCalls != NO_COMMON_VECTOR_SIZE);
  223. }
  224. Value *Generate() {
  225. CollectReplicatedArguments();
  226. CreateReplicatedCalls();
  227. Value *retVal = GetReturnValue();
  228. return retVal;
  229. }
  230. private:
  231. CallInst *m_CI;
  232. Function &m_ReplicatedFunction;
  233. unsigned m_numReplicatedCalls;
  234. SmallVector<unsigned, 10> m_ScalarizeArgIdx;
  235. SmallVector<Value *, 10> m_Args;
  236. SmallVector<Value *, 10> m_ReplicatedCalls;
  237. IRBuilder<> m_Builder;
  238. // Collect replicated arguments.
  239. // For non-vector arguments we can add them to the args list directly.
  240. // These args will be shared by each replicated call. For the vector
  241. // arguments we remember the position it will go in the argument list.
  242. // We will fill in the vector args below when we replicate the call
  243. // (once for each vector lane).
  244. void CollectReplicatedArguments() {
  245. for (unsigned i = 0; i < m_CI->getNumArgOperands(); ++i) {
  246. Type *Ty = m_CI->getArgOperand(i)->getType();
  247. if (Ty->isVectorTy()) {
  248. m_ScalarizeArgIdx.push_back(i);
  249. }
  250. else {
  251. m_Args[i] = m_CI->getArgOperand(i);
  252. }
  253. }
  254. }
  255. // Create replicated calls.
  256. // Replicate the call once for each element of the replicated vector size.
  257. void CreateReplicatedCalls() {
  258. for (unsigned vecIdx = 0; vecIdx < m_numReplicatedCalls; vecIdx++) {
  259. for (unsigned i = 0, e = m_ScalarizeArgIdx.size(); i < e; ++i) {
  260. unsigned argIdx = m_ScalarizeArgIdx[i];
  261. Value *arg = m_CI->getArgOperand(argIdx);
  262. m_Args[argIdx] = m_Builder.CreateExtractElement(arg, vecIdx);
  263. }
  264. Value *EltOP = m_Builder.CreateCall(&m_ReplicatedFunction, m_Args);
  265. m_ReplicatedCalls[vecIdx] = EltOP;
  266. }
  267. }
  268. // Get the final replicated value.
  269. // If the function is a void type then return (arbitrarily) the first call.
  270. // We do not return nullptr because that indicates a failure to replicate.
  271. // If the function is a vector type then aggregate all of the replicated
  272. // call values into a new vector.
  273. Value *GetReturnValue() {
  274. if (m_CI->getType()->isVoidTy())
  275. return m_ReplicatedCalls.back();
  276. Value *retVal = llvm::UndefValue::get(m_CI->getType());
  277. for (unsigned i = 0; i < m_ReplicatedCalls.size(); ++i)
  278. retVal = m_Builder.CreateInsertElement(retVal, m_ReplicatedCalls[i], i);
  279. return retVal;
  280. }
  281. };
  282. // Translate the HL call by replicating the call for each vector element.
  283. //
  284. // For example,
  285. //
  286. // <2xi32> %r = call @ext.foo(i32 %op, <2xi32> %v)
  287. // ==>
  288. // %r.1 = call @ext.foo.s(i32 %op, i32 %v.1)
  289. // %r.2 = call @ext.foo.s(i32 %op, i32 %v.2)
  290. // <2xi32> %r.v.1 = insertelement %r.1, 0, <2xi32> undef
  291. // <2xi32> %r.v.2 = insertelement %r.2, 1, %r.v.1
  292. //
  293. // You can then RAWU %r with %r.v.2. The RAWU is not done by the translate function.
  294. Value *ExtensionLowering::Replicate(CallInst *CI) {
  295. Function *ReplicatedFunction = FunctionTranslator::GetLoweredFunction<ReplicatedFunctionTypeTranslator>(CI, *this);
  296. if (!ReplicatedFunction)
  297. return NoTranslation(CI);
  298. ReplicateCall replicate(CI, *ReplicatedFunction);
  299. return replicate.Generate();
  300. }
  301. ///////////////////////////////////////////////////////////////////////////////
  302. // Packed Lowering.
  303. class PackCall {
  304. public:
  305. PackCall(CallInst *CI, Function &PackedFunction)
  306. : m_CI(CI)
  307. , m_packedFunction(PackedFunction)
  308. , m_builder(CI)
  309. {}
  310. Value *Generate() {
  311. SmallVector<Value *, 10> args;
  312. PackArgs(args);
  313. Value *result = CreateCall(args);
  314. return UnpackResult(result);
  315. }
  316. static StructType *ConvertVectorTypeToStructType(Type *vecTy) {
  317. assert(vecTy->isVectorTy());
  318. Type *elementTy = vecTy->getVectorElementType();
  319. unsigned numElements = vecTy->getVectorNumElements();
  320. SmallVector<Type *, 4> elements;
  321. for (unsigned i = 0; i < numElements; ++i)
  322. elements.push_back(elementTy);
  323. return StructType::get(vecTy->getContext(), elements);
  324. }
  325. private:
  326. CallInst *m_CI;
  327. Function &m_packedFunction;
  328. IRBuilder<> m_builder;
  329. void PackArgs(SmallVectorImpl<Value*> &args) {
  330. args.clear();
  331. for (Value *arg : m_CI->arg_operands()) {
  332. if (arg->getType()->isVectorTy())
  333. arg = PackVectorIntoStruct(m_builder, arg);
  334. args.push_back(arg);
  335. }
  336. }
  337. Value *CreateCall(const SmallVectorImpl<Value*> &args) {
  338. return m_builder.CreateCall(&m_packedFunction, args);
  339. }
  340. Value *UnpackResult(Value *result) {
  341. if (result->getType()->isStructTy()) {
  342. result = PackStructIntoVector(m_builder, result);
  343. }
  344. return result;
  345. }
  346. static VectorType *ConvertStructTypeToVectorType(Type *structTy) {
  347. assert(structTy->isStructTy());
  348. return VectorType::get(structTy->getStructElementType(0), structTy->getStructNumElements());
  349. }
  350. static Value *PackVectorIntoStruct(IRBuilder<> &builder, Value *vec) {
  351. StructType *structTy = ConvertVectorTypeToStructType(vec->getType());
  352. Value *packed = UndefValue::get(structTy);
  353. unsigned numElements = structTy->getStructNumElements();
  354. for (unsigned i = 0; i < numElements; ++i) {
  355. Value *element = builder.CreateExtractElement(vec, i);
  356. packed = builder.CreateInsertValue(packed, element, { i });
  357. }
  358. return packed;
  359. }
  360. static Value *PackStructIntoVector(IRBuilder<> &builder, Value *strukt) {
  361. Type *vecTy = ConvertStructTypeToVectorType(strukt->getType());
  362. Value *packed = UndefValue::get(vecTy);
  363. unsigned numElements = vecTy->getVectorNumElements();
  364. for (unsigned i = 0; i < numElements; ++i) {
  365. Value *element = builder.CreateExtractValue(strukt, i);
  366. packed = builder.CreateInsertElement(packed, element, i);
  367. }
  368. return packed;
  369. }
  370. };
  371. class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
  372. virtual Type *TranslateReturnType(CallInst *CI) override {
  373. return TranslateIfVector(CI->getType());
  374. }
  375. virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
  376. return ArgumentType(TranslateIfVector(OrigArg->getType()));
  377. }
  378. Type *TranslateIfVector(Type *ty) {
  379. if (ty->isVectorTy())
  380. ty = PackCall::ConvertVectorTypeToStructType(ty);
  381. return ty;
  382. }
  383. };
  384. Value *ExtensionLowering::Pack(CallInst *CI) {
  385. Function *PackedFunction = FunctionTranslator::GetLoweredFunction<PackedFunctionTypeTranslator>(CI, *this);
  386. if (!PackedFunction)
  387. return NoTranslation(CI);
  388. PackCall pack(CI, *PackedFunction);
  389. Value *result = pack.Generate();
  390. return result;
  391. }
  392. ///////////////////////////////////////////////////////////////////////////////
  393. // Resource Lowering.
  394. // Modify a call to a resouce method. Makes the following transformation:
  395. //
  396. // 1. Convert non-void return value to dx.types.ResRet.
  397. // 2. Expand vectors in place as separate arguments.
  398. //
  399. // Example
  400. // -----------------------------------------------------------------------------
  401. //
  402. // %0 = call <2 x float> MyBufferOp(i32 138, %class.Buffer %3, <2 x i32> <1 , 2> )
  403. // %r = call %dx.types.ResRet.f32 MyBufferOp(i32 138, %dx.types.Handle %buf, i32 1, i32 2 )
  404. // %x = extractvalue %r, 0
  405. // %y = extractvalue %r, 1
  406. // %v = <2 x float> undef
  407. // %v.1 = insertelement %v, %x, 0
  408. // %v.2 = insertelement %v.1, %y, 1
  409. class ResourceMethodCall {
  410. public:
  411. ResourceMethodCall(CallInst *CI, Function &explodedFunction)
  412. : m_CI(CI)
  413. , m_explodedFunction(explodedFunction)
  414. , m_builder(CI)
  415. { }
  416. Value *Generate() {
  417. SmallVector<Value *, 16> args;
  418. ExplodeArgs(args);
  419. Value *result = CreateCall(args);
  420. result = ConvertResult(result);
  421. return result;
  422. }
  423. private:
  424. CallInst *m_CI;
  425. Function &m_explodedFunction;
  426. IRBuilder<> m_builder;
  427. void ExplodeArgs(SmallVectorImpl<Value*> &args) {
  428. for (Value *arg : m_CI->arg_operands()) {
  429. // vector arg: <N x ty> -> ty, ty, ..., ty (N times)
  430. if (arg->getType()->isVectorTy()) {
  431. for (unsigned i = 0; i < arg->getType()->getVectorNumElements(); i++) {
  432. Value *xarg = m_builder.CreateExtractElement(arg, i);
  433. args.push_back(xarg);
  434. }
  435. }
  436. // any other value: arg -> arg
  437. else {
  438. args.push_back(arg);
  439. }
  440. }
  441. }
  442. Value *CreateCall(const SmallVectorImpl<Value*> &args) {
  443. return m_builder.CreateCall(&m_explodedFunction, args);
  444. }
  445. Value *ConvertResult(Value *result) {
  446. Type *origRetTy = m_CI->getType();
  447. if (origRetTy->isVoidTy())
  448. return ConvertVoidResult(result);
  449. else if (origRetTy->isVectorTy())
  450. return ConvertVectorResult(origRetTy, result);
  451. else
  452. return ConvertScalarResult(origRetTy, result);
  453. }
  454. // Void result does not need any conversion.
  455. Value *ConvertVoidResult(Value *result) {
  456. return result;
  457. }
  458. // Vector result will be populated with the elements from the resource return.
  459. Value *ConvertVectorResult(Type *origRetTy, Value *result) {
  460. Type *resourceRetTy = result->getType();
  461. assert(origRetTy->isVectorTy());
  462. assert(resourceRetTy->isStructTy() && "expected resource return type to be a struct");
  463. const unsigned vectorSize = origRetTy->getVectorNumElements();
  464. const unsigned structSize = resourceRetTy->getStructNumElements();
  465. const unsigned size = std::min(vectorSize, structSize);
  466. assert(vectorSize < structSize);
  467. // Copy resource struct elements to vector.
  468. Value *vector = UndefValue::get(origRetTy);
  469. for (unsigned i = 0; i < size; ++i) {
  470. Value *element = m_builder.CreateExtractValue(result, { i });
  471. vector = m_builder.CreateInsertElement(vector, element, i);
  472. }
  473. return vector;
  474. }
  475. // Scalar result will be populated with the first element of the resource return.
  476. Value *ConvertScalarResult(Type *origRetTy, Value *result) {
  477. assert(origRetTy->isSingleValueType());
  478. return m_builder.CreateExtractValue(result, { 0 });
  479. }
  480. };
  481. // Translate function return and argument types for resource method lowering.
  482. class ResourceFunctionTypeTranslator : public FunctionTypeTranslator {
  483. public:
  484. ResourceFunctionTypeTranslator(OP &hlslOp) : m_hlslOp(hlslOp) {}
  485. // Translate return type as follows:
  486. //
  487. // void -> void
  488. // <N x ty> -> dx.types.ResRet.ty
  489. // ty -> dx.types.ResRet.ty
  490. virtual Type *TranslateReturnType(CallInst *CI) override {
  491. Type *RetTy = CI->getType();
  492. if (RetTy->isVoidTy())
  493. return RetTy;
  494. else if (RetTy->isVectorTy())
  495. RetTy = RetTy->getVectorElementType();
  496. return m_hlslOp.GetResRetType(RetTy);
  497. }
  498. // Translate argument type as follows:
  499. //
  500. // resource -> dx.types.Handle
  501. // <N x ty> -> { ty, N }
  502. // ty -> { ty, 1 }
  503. virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
  504. int count = 1;
  505. Type *ty = OrigArg->getType();
  506. if (ty->isVectorTy()) {
  507. count = ty->getVectorNumElements();
  508. ty = ty->getVectorElementType();
  509. }
  510. return ArgumentType(ty, count);
  511. }
  512. private:
  513. OP& m_hlslOp;
  514. };
  515. Value *ExtensionLowering::Resource(CallInst *CI) {
  516. ResourceFunctionTypeTranslator resourceTypeTranslator(m_hlslOp);
  517. Function *resourceFunction = FunctionTranslator::GetLoweredFunction(resourceTypeTranslator, CI, *this);
  518. if (!resourceFunction)
  519. return NoTranslation(CI);
  520. ResourceMethodCall explode(CI, *resourceFunction);
  521. Value *result = explode.Generate();
  522. return result;
  523. }
  524. ///////////////////////////////////////////////////////////////////////////////
  525. // Dxil Lowering.
  526. Value *ExtensionLowering::Dxil(CallInst *CI) {
  527. // Map the extension opcode to the corresponding dxil opcode.
  528. unsigned extOpcode = GetHLOpcode(CI);
  529. OP::OpCode dxilOpcode;
  530. if (!m_helper->GetDxilOpcode(extOpcode, dxilOpcode))
  531. return nullptr;
  532. // Find the dxil function based on the overload type.
  533. Type *overloadTy = m_hlslOp.GetOverloadType(dxilOpcode, CI->getCalledFunction());
  534. Function *F = m_hlslOp.GetOpFunc(dxilOpcode, overloadTy->getScalarType());
  535. // Update the opcode in the original call so we can just copy it below.
  536. // We are about to delete this call anyway.
  537. CI->setOperand(0, m_hlslOp.GetI32Const(static_cast<unsigned>(dxilOpcode)));
  538. // Create the new call.
  539. Value *result = nullptr;
  540. if (overloadTy->isVectorTy()) {
  541. ReplicateCall replicate(CI, *F);
  542. result = replicate.Generate();
  543. }
  544. else {
  545. IRBuilder<> builder(CI);
  546. SmallVector<Value *, 8> args(CI->arg_operands().begin(), CI->arg_operands().end());
  547. result = builder.CreateCall(F, args);
  548. }
  549. return result;
  550. }
  551. ///////////////////////////////////////////////////////////////////////////////
  552. // Computing Extension Names.
  553. // Compute the name to use for the intrinsic function call once it is lowered to dxil.
  554. // First checks to see if we have a custom name from the codegen helper and if not
  555. // chooses a default name based on the lowergin strategy.
  556. class ExtensionName {
  557. public:
  558. ExtensionName(CallInst *CI, ExtensionLowering::Strategy strategy, HLSLExtensionsCodegenHelper *helper)
  559. : m_CI(CI)
  560. , m_strategy(strategy)
  561. , m_helper(helper)
  562. {}
  563. std::string Get() {
  564. std::string name;
  565. if (m_helper)
  566. name = GetCustomExtensionName(m_CI, *m_helper);
  567. if (!HasCustomExtensionName(name))
  568. name = GetDefaultCustomExtensionName(m_CI, ExtensionLowering::GetStrategyName(m_strategy));
  569. return name;
  570. }
  571. private:
  572. CallInst *m_CI;
  573. ExtensionLowering::Strategy m_strategy;
  574. HLSLExtensionsCodegenHelper *m_helper;
  575. static std::string GetCustomExtensionName(CallInst *CI, HLSLExtensionsCodegenHelper &helper) {
  576. unsigned opcode = GetHLOpcode(CI);
  577. std::string name = helper.GetIntrinsicName(opcode);
  578. ReplaceOverloadMarkerWithTypeName(name, CI);
  579. return name;
  580. }
  581. static std::string GetDefaultCustomExtensionName(CallInst *CI, StringRef strategyName) {
  582. return (Twine(CI->getCalledFunction()->getName()) + "." + Twine(strategyName)).str();
  583. }
  584. static bool HasCustomExtensionName(const std::string name) {
  585. return name.size() > 0;
  586. }
  587. // Choose the (return value or argument) type that determines the overload type
  588. // for the intrinsic call.
  589. // For now we take the return type as the overload. If the return is void we
  590. // take the first (non-opcode) argument as the overload type. We could extend the
  591. // $o sytnax in the extension name to explicitly specify the overload slot (e.g.
  592. // $o:3 would say the overload type is determined by parameter 3.
  593. static Type *SelectOverloadSlot(CallInst *CI) {
  594. Type *ty = CI->getType();
  595. if (ty->isVoidTy()) {
  596. if (CI->getNumArgOperands() > 1)
  597. ty = CI->getArgOperand(1)->getType(); // First non-opcode argument.
  598. }
  599. return ty;
  600. }
  601. static Type *GetOverloadType(CallInst *CI) {
  602. Type *ty = SelectOverloadSlot(CI);
  603. if (ty->isVectorTy())
  604. ty = ty->getVectorElementType();
  605. return ty;
  606. }
  607. static std::string GetTypeName(Type *ty) {
  608. std::string typeName;
  609. llvm::raw_string_ostream os(typeName);
  610. ty->print(os);
  611. os.flush();
  612. return typeName;
  613. }
  614. static std::string GetOverloadTypeName(CallInst *CI) {
  615. Type *ty = GetOverloadType(CI);
  616. return GetTypeName(ty);
  617. }
  618. // Find the occurence of the overload marker $o and replace it the the overload type name.
  619. static void ReplaceOverloadMarkerWithTypeName(std::string &functionName, CallInst *CI) {
  620. const char *OverloadMarker = "$o";
  621. const size_t OverloadMarkerLength = 2;
  622. size_t pos = functionName.find(OverloadMarker);
  623. if (pos != std::string::npos) {
  624. std::string typeName = GetOverloadTypeName(CI);
  625. functionName.replace(pos, OverloadMarkerLength, typeName);
  626. }
  627. }
  628. };
  629. std::string ExtensionLowering::GetExtensionName(llvm::CallInst *CI) {
  630. ExtensionName name(CI, m_strategy, m_helper);
  631. return name.Get();
  632. }