HLOperationLowerExtension.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLOperationLowerExtension.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  9. #include "dxc/HLSL/HLOperationLowerExtension.h"
  10. #include "dxc/HLSL/DxilModule.h"
  11. #include "dxc/HLSL/DxilOperations.h"
  12. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  13. #include "dxc/HLSL/HLModule.h"
  14. #include "dxc/HLSL/HLOperationLower.h"
  15. #include "dxc/HLSL/HLOperations.h"
  16. #include "dxc/HlslIntrinsicOp.h"
  17. #include "llvm/ADT/StringRef.h"
  18. #include "llvm/IR/IRBuilder.h"
  19. #include "llvm/IR/Instructions.h"
  20. #include "llvm/IR/Module.h"
  21. #include "llvm/Support/raw_os_ostream.h"
  22. using namespace llvm;
  23. using namespace hlsl;
  24. ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
  25. if (strategy.size() < 1)
  26. return Strategy::Unknown;
  27. switch (strategy[0]) {
  28. case 'n': return Strategy::NoTranslation;
  29. case 'r': return Strategy::Replicate;
  30. case 'p': return Strategy::Pack;
  31. default: break;
  32. }
  33. return Strategy::Unknown;
  34. }
  35. llvm::StringRef ExtensionLowering::GetStrategyName(Strategy strategy) {
  36. switch (strategy) {
  37. case Strategy::NoTranslation: return "n";
  38. case Strategy::Replicate: return "r";
  39. case Strategy::Pack: return "p";
  40. default: break;
  41. }
  42. return "?";
  43. }
  44. ExtensionLowering::ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper)
  45. : m_strategy(strategy), m_helper(helper)
  46. {}
  47. ExtensionLowering::ExtensionLowering(StringRef strategy, HLSLExtensionsCodegenHelper *helper)
  48. : ExtensionLowering(GetStrategy(strategy), helper)
  49. {}
  50. llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
  51. switch (m_strategy) {
  52. case Strategy::NoTranslation: return NoTranslation(CI);
  53. case Strategy::Replicate: return Replicate(CI);
  54. case Strategy::Pack: return Pack(CI);
  55. default: break;
  56. }
  57. return Unknown(CI);
  58. }
  59. llvm::Value *ExtensionLowering::Unknown(CallInst *CI) {
  60. assert(false && "unknown translation strategy");
  61. return nullptr;
  62. }
  63. // Interface to describe how to translate types from HL-dxil to dxil.
  64. class FunctionTypeTranslator {
  65. public:
  66. virtual Type *TranslateReturnType(CallInst *CI) = 0;
  67. virtual Type *TranslateArgumentType(Type *OrigArgType) = 0;
  68. };
  69. // Class to create the new function with the translated types for low-level dxil.
  70. class FunctionTranslator {
  71. public:
  72. template <typename TypeTranslator>
  73. static Function *GetLoweredFunction(CallInst *CI, ExtensionLowering &lower) {
  74. TypeTranslator typeTranslator;
  75. FunctionTranslator translator(typeTranslator, lower);
  76. return translator.GetLoweredFunction(CI);
  77. }
  78. private:
  79. FunctionTypeTranslator &m_typeTranslator;
  80. ExtensionLowering &m_lower;
  81. FunctionTranslator(FunctionTypeTranslator &typeTranslator, ExtensionLowering &lower)
  82. : m_typeTranslator(typeTranslator)
  83. , m_lower(lower)
  84. {}
  85. Function *GetLoweredFunction(CallInst *CI) {
  86. // Ge the return type of replicated function.
  87. Type *RetTy = m_typeTranslator.TranslateReturnType(CI);
  88. if (!RetTy)
  89. return nullptr;
  90. // Get the Function type for replicated function.
  91. FunctionType *FTy = GetFunctionType(CI, RetTy);
  92. if (!FTy)
  93. return nullptr;
  94. // Create a new function that will be the replicated call.
  95. AttributeSet attributes = GetAttributeSet(CI);
  96. std::string name = m_lower.GetExtensionName(CI);
  97. return cast<Function>(CI->getModule()->getOrInsertFunction(name, FTy, attributes));
  98. }
  99. FunctionType *GetFunctionType(CallInst *CI, Type *RetTy) {
  100. // Create a new function type with the translated argument.
  101. SmallVector<Type *, 10> ParamTypes;
  102. ParamTypes.reserve(CI->getNumArgOperands());
  103. for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
  104. Type *OrigTy = CI->getArgOperand(i)->getType();
  105. Type *TranslatedTy = m_typeTranslator.TranslateArgumentType(OrigTy);
  106. ParamTypes.push_back(TranslatedTy);
  107. }
  108. const bool IsVarArg = false;
  109. return FunctionType::get(RetTy, ParamTypes, IsVarArg);
  110. }
  111. AttributeSet GetAttributeSet(CallInst *CI) {
  112. Function *F = CI->getCalledFunction();
  113. AttributeSet attributes;
  114. auto copyAttribute = [=, &attributes](Attribute::AttrKind a) {
  115. if (F->hasFnAttribute(a)) {
  116. attributes = attributes.addAttribute(CI->getContext(), AttributeSet::FunctionIndex, a);
  117. }
  118. };
  119. copyAttribute(Attribute::AttrKind::ReadOnly);
  120. copyAttribute(Attribute::AttrKind::ReadNone);
  121. copyAttribute(Attribute::AttrKind::ArgMemOnly);
  122. return attributes;
  123. }
  124. };
  125. ///////////////////////////////////////////////////////////////////////////////
  126. // NoTranslation Lowering.
  127. class NoTranslationTypeTranslator : public FunctionTypeTranslator {
  128. virtual Type *TranslateReturnType(CallInst *CI) override {
  129. return CI->getType();
  130. }
  131. virtual Type *TranslateArgumentType(Type *OrigArgType) override {
  132. return OrigArgType;
  133. }
  134. };
  135. llvm::Value *ExtensionLowering::NoTranslation(CallInst *CI) {
  136. Function *NoTranslationFunction = FunctionTranslator::GetLoweredFunction<NoTranslationTypeTranslator>(CI, *this);
  137. if (!NoTranslationFunction)
  138. return nullptr;
  139. IRBuilder<> builder(CI);
  140. SmallVector<Value *, 8> args(CI->arg_operands().begin(), CI->arg_operands().end());
  141. return builder.CreateCall(NoTranslationFunction, args);
  142. };
  143. ///////////////////////////////////////////////////////////////////////////////
  144. // Replicated Lowering.
  145. enum {
  146. NO_COMMON_VECTOR_SIZE = 0xFFFFFFFF,
  147. };
  148. // Find the vector size that will be used for replication.
  149. // The function call will be replicated once for each element of the vector
  150. // size.
  151. static unsigned GetReplicatedVectorSize(llvm::CallInst *CI) {
  152. unsigned commonVectorSize = NO_COMMON_VECTOR_SIZE;
  153. Type *RetTy = CI->getType();
  154. if (RetTy->isVectorTy())
  155. commonVectorSize = RetTy->getVectorNumElements();
  156. for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
  157. Type *Ty = CI->getArgOperand(i)->getType();
  158. if (Ty->isVectorTy()) {
  159. unsigned vectorSize = Ty->getVectorNumElements();
  160. if (commonVectorSize != NO_COMMON_VECTOR_SIZE && commonVectorSize != vectorSize) {
  161. // Inconsistent vector sizes; need a different strategy.
  162. return NO_COMMON_VECTOR_SIZE;
  163. }
  164. commonVectorSize = vectorSize;
  165. }
  166. }
  167. return commonVectorSize;
  168. }
  169. class ReplicatedFunctionTypeTranslator : public FunctionTypeTranslator {
  170. virtual Type *TranslateReturnType(CallInst *CI) override {
  171. unsigned commonVectorSize = GetReplicatedVectorSize(CI);
  172. if (commonVectorSize == NO_COMMON_VECTOR_SIZE)
  173. return nullptr;
  174. // Result should be vector or void.
  175. Type *RetTy = CI->getType();
  176. if (!RetTy->isVoidTy() && !RetTy->isVectorTy())
  177. return nullptr;
  178. if (RetTy->isVectorTy()) {
  179. RetTy = RetTy->getVectorElementType();
  180. }
  181. return RetTy;
  182. }
  183. virtual Type *TranslateArgumentType(Type *OrigArgType) override {
  184. Type *Ty = OrigArgType;
  185. if (Ty->isVectorTy()) {
  186. Ty = Ty->getVectorElementType();
  187. }
  188. return Ty;
  189. }
  190. };
  191. class ReplicateCall {
  192. public:
  193. ReplicateCall(CallInst *CI, Function &ReplicatedFunction)
  194. : m_CI(CI)
  195. , m_ReplicatedFunction(ReplicatedFunction)
  196. , m_numReplicatedCalls(GetReplicatedVectorSize(CI))
  197. , m_ScalarizeArgIdx()
  198. , m_Args(CI->getNumArgOperands())
  199. , m_ReplicatedCalls(m_numReplicatedCalls)
  200. , m_Builder(CI)
  201. {
  202. assert(m_numReplicatedCalls != NO_COMMON_VECTOR_SIZE);
  203. }
  204. Value *Generate() {
  205. CollectReplicatedArguments();
  206. CreateReplicatedCalls();
  207. Value *retVal = GetReturnValue();
  208. return retVal;
  209. }
  210. private:
  211. CallInst *m_CI;
  212. Function &m_ReplicatedFunction;
  213. unsigned m_numReplicatedCalls;
  214. SmallVector<unsigned, 10> m_ScalarizeArgIdx;
  215. SmallVector<Value *, 10> m_Args;
  216. SmallVector<Value *, 10> m_ReplicatedCalls;
  217. IRBuilder<> m_Builder;
  218. // Collect replicated arguments.
  219. // For non-vector arguments we can add them to the args list directly.
  220. // These args will be shared by each replicated call. For the vector
  221. // arguments we remember the position it will go in the argument list.
  222. // We will fill in the vector args below when we replicate the call
  223. // (once for each vector lane).
  224. void CollectReplicatedArguments() {
  225. for (unsigned i = 0; i < m_CI->getNumArgOperands(); ++i) {
  226. Type *Ty = m_CI->getArgOperand(i)->getType();
  227. if (Ty->isVectorTy()) {
  228. m_ScalarizeArgIdx.push_back(i);
  229. }
  230. else {
  231. m_Args[i] = m_CI->getArgOperand(i);
  232. }
  233. }
  234. }
  235. // Create replicated calls.
  236. // Replicate the call once for each element of the replicated vector size.
  237. void CreateReplicatedCalls() {
  238. for (unsigned vecIdx = 0; vecIdx < m_numReplicatedCalls; vecIdx++) {
  239. for (unsigned i = 0, e = m_ScalarizeArgIdx.size(); i < e; ++i) {
  240. unsigned argIdx = m_ScalarizeArgIdx[i];
  241. Value *arg = m_CI->getArgOperand(argIdx);
  242. m_Args[argIdx] = m_Builder.CreateExtractElement(arg, vecIdx);
  243. }
  244. Value *EltOP = m_Builder.CreateCall(&m_ReplicatedFunction, m_Args);
  245. m_ReplicatedCalls[vecIdx] = EltOP;
  246. }
  247. }
  248. // Get the final replicated value.
  249. // If the function is a void type then return (arbitrarily) the first call.
  250. // We do not return nullptr because that indicates a failure to replicate.
  251. // If the function is a vector type then aggregate all of the replicated
  252. // call values into a new vector.
  253. Value *GetReturnValue() {
  254. if (m_CI->getType()->isVoidTy())
  255. return m_ReplicatedCalls.back();
  256. Value *retVal = llvm::UndefValue::get(m_CI->getType());
  257. for (unsigned i = 0; i < m_ReplicatedCalls.size(); ++i)
  258. retVal = m_Builder.CreateInsertElement(retVal, m_ReplicatedCalls[i], i);
  259. return retVal;
  260. }
  261. };
  262. Value *ExtensionLowering::TranslateReplicating(CallInst *CI, Function *ReplicatedFunction) {
  263. if (!ReplicatedFunction)
  264. return nullptr;
  265. ReplicateCall replicate(CI, *ReplicatedFunction);
  266. return replicate.Generate();
  267. }
  268. Value *ExtensionLowering::Replicate(CallInst *CI) {
  269. Function *ReplicatedFunction = FunctionTranslator::GetLoweredFunction<ReplicatedFunctionTypeTranslator>(CI, *this);
  270. return TranslateReplicating(CI, ReplicatedFunction);
  271. }
  272. ///////////////////////////////////////////////////////////////////////////////
  273. // Packed Lowering.
  274. class PackCall {
  275. public:
  276. PackCall(CallInst *CI, Function &PackedFunction)
  277. : m_CI(CI)
  278. , m_packedFunction(PackedFunction)
  279. , m_builder(CI)
  280. {}
  281. Value *Generate() {
  282. SmallVector<Value *, 10> args;
  283. PackArgs(args);
  284. Value *result = CreateCall(args);
  285. return UnpackResult(result);
  286. }
  287. static StructType *ConvertVectorTypeToStructType(Type *vecTy) {
  288. assert(vecTy->isVectorTy());
  289. Type *elementTy = vecTy->getVectorElementType();
  290. unsigned numElements = vecTy->getVectorNumElements();
  291. SmallVector<Type *, 4> elements;
  292. for (unsigned i = 0; i < numElements; ++i)
  293. elements.push_back(elementTy);
  294. return StructType::get(vecTy->getContext(), elements);
  295. }
  296. private:
  297. CallInst *m_CI;
  298. Function &m_packedFunction;
  299. IRBuilder<> m_builder;
  300. void PackArgs(SmallVectorImpl<Value*> &args) {
  301. args.clear();
  302. for (Value *arg : m_CI->arg_operands()) {
  303. if (arg->getType()->isVectorTy())
  304. arg = PackVectorIntoStruct(m_builder, arg);
  305. args.push_back(arg);
  306. }
  307. }
  308. Value *CreateCall(const SmallVectorImpl<Value*> &args) {
  309. return m_builder.CreateCall(&m_packedFunction, args);
  310. }
  311. Value *UnpackResult(Value *result) {
  312. if (result->getType()->isStructTy()) {
  313. result = PackStructIntoVector(m_builder, result);
  314. }
  315. return result;
  316. }
  317. static VectorType *ConvertStructTypeToVectorType(Type *structTy) {
  318. assert(structTy->isStructTy());
  319. return VectorType::get(structTy->getStructElementType(0), structTy->getStructNumElements());
  320. }
  321. static Value *PackVectorIntoStruct(IRBuilder<> &builder, Value *vec) {
  322. StructType *structTy = ConvertVectorTypeToStructType(vec->getType());
  323. Value *packed = UndefValue::get(structTy);
  324. unsigned numElements = structTy->getStructNumElements();
  325. for (unsigned i = 0; i < numElements; ++i) {
  326. Value *element = builder.CreateExtractElement(vec, i);
  327. packed = builder.CreateInsertValue(packed, element, { i });
  328. }
  329. return packed;
  330. }
  331. static Value *PackStructIntoVector(IRBuilder<> &builder, Value *strukt) {
  332. Type *vecTy = ConvertStructTypeToVectorType(strukt->getType());
  333. Value *packed = UndefValue::get(vecTy);
  334. unsigned numElements = vecTy->getVectorNumElements();
  335. for (unsigned i = 0; i < numElements; ++i) {
  336. Value *element = builder.CreateExtractValue(strukt, i);
  337. packed = builder.CreateInsertElement(packed, element, { i });
  338. }
  339. return packed;
  340. }
  341. };
  342. class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
  343. virtual Type *TranslateReturnType(CallInst *CI) override {
  344. return TranslateIfVector(CI->getType());
  345. }
  346. virtual Type *TranslateArgumentType(Type *OrigArgType) override {
  347. return TranslateIfVector(OrigArgType);
  348. }
  349. Type *TranslateIfVector(Type *ty) {
  350. if (ty->isVectorTy())
  351. ty = PackCall::ConvertVectorTypeToStructType(ty);
  352. return ty;
  353. }
  354. };
  355. Value *ExtensionLowering::Pack(CallInst *CI) {
  356. Function *PackedFunction = FunctionTranslator::GetLoweredFunction<PackedFunctionTypeTranslator>(CI, *this);
  357. if (!PackedFunction)
  358. return nullptr;
  359. PackCall pack(CI, *PackedFunction);
  360. Value *result = pack.Generate();
  361. return result;
  362. }
  363. ///////////////////////////////////////////////////////////////////////////////
  364. // Computing Extension Names.
  365. // Compute the name to use for the intrinsic function call once it is lowered to dxil.
  366. // First checks to see if we have a custom name from the codegen helper and if not
  367. // chooses a default name based on the lowergin strategy.
  368. class ExtensionName {
  369. public:
  370. ExtensionName(CallInst *CI, ExtensionLowering::Strategy strategy, HLSLExtensionsCodegenHelper *helper)
  371. : m_CI(CI)
  372. , m_strategy(strategy)
  373. , m_helper(helper)
  374. {}
  375. std::string Get() {
  376. std::string name;
  377. if (m_helper)
  378. name = GetCustomExtensionName(m_CI, *m_helper);
  379. if (!HasCustomExtensionName(name))
  380. name = GetDefaultCustomExtensionName(m_CI, ExtensionLowering::GetStrategyName(m_strategy));
  381. return name;
  382. }
  383. private:
  384. CallInst *m_CI;
  385. ExtensionLowering::Strategy m_strategy;
  386. HLSLExtensionsCodegenHelper *m_helper;
  387. static std::string GetCustomExtensionName(CallInst *CI, HLSLExtensionsCodegenHelper &helper) {
  388. unsigned opcode = GetHLOpcode(CI);
  389. std::string name = helper.GetIntrinsicName(opcode);
  390. ReplaceOverloadMarkerWithTypeName(name, CI);
  391. return name;
  392. }
  393. static std::string GetDefaultCustomExtensionName(CallInst *CI, StringRef strategyName) {
  394. return (Twine(CI->getCalledFunction()->getName()) + "." + Twine(strategyName)).str();
  395. }
  396. static bool HasCustomExtensionName(const std::string name) {
  397. return name.size() > 0;
  398. }
  399. // Choose the (return value or argument) type that determines the overload type
  400. // for the intrinsic call.
  401. // For now we take the return type as the overload. If the return is void we
  402. // take the first (non-opcode) argument as the overload type. We could extend the
  403. // $o sytnax in the extension name to explicitly specify the overload slot (e.g.
  404. // $o:3 would say the overload type is determined by parameter 3.
  405. static Type *SelectOverloadSlot(CallInst *CI) {
  406. Type *ty = CI->getType();
  407. if (ty->isVoidTy()) {
  408. if (CI->getNumArgOperands() > 1)
  409. ty = CI->getArgOperand(1)->getType(); // First non-opcode argument.
  410. }
  411. return ty;
  412. }
  413. static Type *GetOverloadType(CallInst *CI) {
  414. Type *ty = SelectOverloadSlot(CI);
  415. if (ty->isVectorTy())
  416. ty = ty->getVectorElementType();
  417. return ty;
  418. }
  419. static std::string GetTypeName(Type *ty) {
  420. std::string typeName;
  421. llvm::raw_string_ostream os(typeName);
  422. ty->print(os);
  423. os.flush();
  424. return typeName;
  425. }
  426. static std::string GetOverloadTypeName(CallInst *CI) {
  427. Type *ty = GetOverloadType(CI);
  428. return GetTypeName(ty);
  429. }
  430. // Find the occurence of the overload marker $o and replace it the the overload type name.
  431. static void ReplaceOverloadMarkerWithTypeName(std::string &functionName, CallInst *CI) {
  432. const char *OverloadMarker = "$o";
  433. const size_t OverloadMarkerLength = 2;
  434. size_t pos = functionName.find(OverloadMarker);
  435. if (pos != std::string::npos) {
  436. std::string typeName = GetOverloadTypeName(CI);
  437. functionName.replace(pos, OverloadMarkerLength, typeName);
  438. }
  439. }
  440. };
  441. std::string ExtensionLowering::GetExtensionName(llvm::CallInst *CI) {
  442. ExtensionName name(CI, m_strategy, m_helper);
  443. return name.Get();
  444. }