DxilTypeSystem.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilTypeSystem.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
#include "dxc/HLSL/DxilTypeSystem.h"
#include "dxc/HLSL/DxilModule.h"
#include "dxc/HLSL/HLModule.h"
#include "dxc/Support/Global.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
#include <utility>
using namespace llvm;
using std::unique_ptr;
using std::string;
using std::vector;
using std::map;
  21. namespace hlsl {
  22. //------------------------------------------------------------------------------
  23. //
  24. // DxilMatrixAnnotation class methods.
  25. //
  26. DxilMatrixAnnotation::DxilMatrixAnnotation()
  27. : Rows(0)
  28. , Cols(0)
  29. , Orientation(MatrixOrientation::Undefined) {
  30. }
  31. //------------------------------------------------------------------------------
  32. //
  33. // DxilFieldAnnotation class methods.
  34. //
  35. DxilFieldAnnotation::DxilFieldAnnotation()
  36. : m_bPrecise(false)
  37. , m_ResourceAttribute(nullptr)
  38. , m_CBufferOffset(UINT_MAX) {
  39. }
  40. bool DxilFieldAnnotation::IsPrecise() const { return m_bPrecise; }
  41. void DxilFieldAnnotation::SetPrecise(bool b) { m_bPrecise = b; }
  42. bool DxilFieldAnnotation::HasMatrixAnnotation() const { return m_Matrix.Cols != 0; }
  43. const DxilMatrixAnnotation &DxilFieldAnnotation::GetMatrixAnnotation() const { return m_Matrix; }
  44. void DxilFieldAnnotation::SetMatrixAnnotation(const DxilMatrixAnnotation &MA) { m_Matrix = MA; }
  45. bool DxilFieldAnnotation::HasResourceAttribute() const {
  46. return m_ResourceAttribute;
  47. }
  48. llvm::MDNode *DxilFieldAnnotation::GetResourceAttribute() const {
  49. return m_ResourceAttribute;
  50. }
  51. void DxilFieldAnnotation::SetResourceAttribute(llvm::MDNode *MD) {
  52. m_ResourceAttribute = MD;
  53. }
  54. bool DxilFieldAnnotation::HasCBufferOffset() const { return m_CBufferOffset != UINT_MAX; }
  55. unsigned DxilFieldAnnotation::GetCBufferOffset() const { return m_CBufferOffset; }
  56. void DxilFieldAnnotation::SetCBufferOffset(unsigned Offset) { m_CBufferOffset = Offset; }
  57. bool DxilFieldAnnotation::HasCompType() const { return m_CompType.GetKind() != CompType::Kind::Invalid; }
  58. const CompType &DxilFieldAnnotation::GetCompType() const { return m_CompType; }
  59. void DxilFieldAnnotation::SetCompType(CompType::Kind kind) { m_CompType = CompType(kind); }
  60. bool DxilFieldAnnotation::HasSemanticString() const { return !m_Semantic.empty(); }
  61. const std::string &DxilFieldAnnotation::GetSemanticString() const { return m_Semantic; }
  62. llvm::StringRef DxilFieldAnnotation::GetSemanticStringRef() const { return llvm::StringRef(m_Semantic); }
  63. void DxilFieldAnnotation::SetSemanticString(const std::string &SemString) { m_Semantic = SemString; }
  64. bool DxilFieldAnnotation::HasInterpolationMode() const { return !m_InterpMode.IsUndefined(); }
  65. const InterpolationMode &DxilFieldAnnotation::GetInterpolationMode() const { return m_InterpMode; }
  66. void DxilFieldAnnotation::SetInterpolationMode(const InterpolationMode &IM) { m_InterpMode = IM; }
  67. bool DxilFieldAnnotation::HasFieldName() const { return !m_FieldName.empty(); }
  68. const std::string &DxilFieldAnnotation::GetFieldName() const { return m_FieldName; }
  69. void DxilFieldAnnotation::SetFieldName(const std::string &FieldName) { m_FieldName = FieldName; }
  70. //------------------------------------------------------------------------------
  71. //
  72. // DxilStructAnnotation class methods.
  73. //
  74. unsigned DxilStructAnnotation::GetNumFields() const {
  75. return (unsigned)m_FieldAnnotations.size();
  76. }
  77. DxilFieldAnnotation &DxilStructAnnotation::GetFieldAnnotation(unsigned FieldIdx) {
  78. return m_FieldAnnotations[FieldIdx];
  79. }
  80. const DxilFieldAnnotation &DxilStructAnnotation::GetFieldAnnotation(unsigned FieldIdx) const {
  81. return m_FieldAnnotations[FieldIdx];
  82. }
  83. const StructType *DxilStructAnnotation::GetStructType() const {
  84. return m_pStructType;
  85. }
  86. unsigned DxilStructAnnotation::GetCBufferSize() const { return m_CBufferSize; }
  87. void DxilStructAnnotation::SetCBufferSize(unsigned size) { m_CBufferSize = size; }
  88. void DxilStructAnnotation::MarkEmptyStruct() { m_FieldAnnotations.clear(); }
  89. bool DxilStructAnnotation::IsEmptyStruct() { return m_FieldAnnotations.empty(); }
  90. //------------------------------------------------------------------------------
  91. //
  92. // DxilParameterAnnotation class methods.
  93. //
  94. DxilParameterAnnotation::DxilParameterAnnotation()
  95. : m_inputQual(DxilParamInputQual::In), DxilFieldAnnotation() {
  96. }
  97. DxilParamInputQual DxilParameterAnnotation::GetParamInputQual() const {
  98. return m_inputQual;
  99. }
  100. void DxilParameterAnnotation::SetParamInputQual(DxilParamInputQual qual) {
  101. m_inputQual = qual;
  102. }
  103. const std::vector<unsigned> &DxilParameterAnnotation::GetSemanticIndexVec() const {
  104. return m_semanticIndex;
  105. }
  106. void DxilParameterAnnotation::SetSemanticIndexVec(const std::vector<unsigned> &Vec) {
  107. m_semanticIndex = Vec;
  108. }
  109. void DxilParameterAnnotation::AppendSemanticIndex(unsigned SemIdx) {
  110. m_semanticIndex.emplace_back(SemIdx);
  111. }
  112. //------------------------------------------------------------------------------
  113. //
  114. // DxilFunctionAnnotation class methods.
  115. //
  116. unsigned DxilFunctionAnnotation::GetNumParameters() const {
  117. return (unsigned)m_parameterAnnotations.size();
  118. }
  119. DxilParameterAnnotation &DxilFunctionAnnotation::GetParameterAnnotation(unsigned ParamIdx) {
  120. return m_parameterAnnotations[ParamIdx];
  121. }
  122. const DxilParameterAnnotation &DxilFunctionAnnotation::GetParameterAnnotation(unsigned ParamIdx) const {
  123. return m_parameterAnnotations[ParamIdx];
  124. }
  125. DxilParameterAnnotation &DxilFunctionAnnotation::GetRetTypeAnnotation() {
  126. return m_retTypeAnnotation;
  127. }
  128. const DxilParameterAnnotation &DxilFunctionAnnotation::GetRetTypeAnnotation() const {
  129. return m_retTypeAnnotation;
  130. }
  131. const Function *DxilFunctionAnnotation::GetFunction() const {
  132. return m_pFunction;
  133. }
  134. DxilFunctionFPFlag &DxilFunctionAnnotation::GetFlag() {
  135. return m_fpFlag;
  136. }
  137. const DxilFunctionFPFlag &DxilFunctionAnnotation::GetFlag() const {
  138. return m_fpFlag;
  139. }
  140. //------------------------------------------------------------------------------
  141. //
  142. // DxilFunctionFPFlag class methods.
  143. //
  144. void DxilFunctionFPFlag::SetFP32DenormMode(const DXIL::FPDenormMode mode) {
  145. m_flag |= ((uint32_t)mode & kFPDenormMask) << kFPDenormOffset;
  146. }
  147. DXIL::FPDenormMode DxilFunctionFPFlag::GetFP32DenormMode() {
  148. return (DXIL::FPDenormMode)((m_flag >> kFPDenormOffset) & kFPDenormMask);
  149. }
  150. uint32_t DxilFunctionFPFlag::GetFlagValue() {
  151. return m_flag;
  152. }
  153. const uint32_t DxilFunctionFPFlag::GetFlagValue() const {
  154. return m_flag;
  155. }
  156. void DxilFunctionFPFlag::SetFlagValue(const uint32_t flag) {
  157. m_flag = flag;
  158. }
  159. //------------------------------------------------------------------------------
  160. //
  161. // DxilStructAnnotationSystem class methods.
  162. //
  163. DxilTypeSystem::DxilTypeSystem(Module *pModule)
  164. : m_pModule(pModule),
  165. m_LowPrecisionMode(DXIL::LowPrecisionMode::Undefined) {}
  166. DxilStructAnnotation *DxilTypeSystem::AddStructAnnotation(const StructType *pStructType) {
  167. DXASSERT_NOMSG(m_StructAnnotations.find(pStructType) == m_StructAnnotations.end());
  168. DxilStructAnnotation *pA = new DxilStructAnnotation();
  169. m_StructAnnotations[pStructType] = unique_ptr<DxilStructAnnotation>(pA);
  170. pA->m_pStructType = pStructType;
  171. pA->m_FieldAnnotations.resize(pStructType->getNumElements());
  172. return pA;
  173. }
  174. DxilStructAnnotation *DxilTypeSystem::GetStructAnnotation(const StructType *pStructType) {
  175. auto it = m_StructAnnotations.find(pStructType);
  176. if (it != m_StructAnnotations.end()) {
  177. return it->second.get();
  178. } else {
  179. return nullptr;
  180. }
  181. }
  182. const DxilStructAnnotation *
  183. DxilTypeSystem::GetStructAnnotation(const StructType *pStructType) const {
  184. auto it = m_StructAnnotations.find(pStructType);
  185. if (it != m_StructAnnotations.end()) {
  186. return it->second.get();
  187. } else {
  188. return nullptr;
  189. }
  190. }
  191. void DxilTypeSystem::EraseStructAnnotation(const StructType *pStructType) {
  192. DXASSERT_NOMSG(m_StructAnnotations.count(pStructType));
  193. m_StructAnnotations.remove_if([pStructType](
  194. const std::pair<const StructType *, std::unique_ptr<DxilStructAnnotation>>
  195. &I) { return pStructType == I.first; });
  196. }
  197. DxilTypeSystem::StructAnnotationMap &DxilTypeSystem::GetStructAnnotationMap() {
  198. return m_StructAnnotations;
  199. }
  200. DxilFunctionAnnotation *DxilTypeSystem::AddFunctionAnnotation(const Function *pFunction) {
  201. DxilFunctionFPFlag flag;
  202. flag.SetFlagValue(0);
  203. DxilFunctionAnnotation *pA = AddFunctionAnnotationWithFPFlag(pFunction, &flag);
  204. return pA;
  205. }
  206. DxilFunctionAnnotation *DxilTypeSystem::AddFunctionAnnotationWithFPFlag(const Function *pFunction, const DxilFunctionFPFlag *pFlag) {
  207. DXASSERT_NOMSG(m_FunctionAnnotations.find(pFunction) == m_FunctionAnnotations.end());
  208. DxilFunctionAnnotation *pA = new DxilFunctionAnnotation();
  209. m_FunctionAnnotations[pFunction] = unique_ptr<DxilFunctionAnnotation>(pA);
  210. pA->m_pFunction = pFunction;
  211. pA->m_parameterAnnotations.resize(pFunction->getFunctionType()->getNumParams());
  212. pA->GetFlag().SetFlagValue(pFlag->GetFlagValue());
  213. return pA;
  214. }
  215. DxilFunctionAnnotation *DxilTypeSystem::GetFunctionAnnotation(const Function *pFunction) {
  216. auto it = m_FunctionAnnotations.find(pFunction);
  217. if (it != m_FunctionAnnotations.end()) {
  218. return it->second.get();
  219. } else {
  220. return nullptr;
  221. }
  222. }
  223. const DxilFunctionAnnotation *
  224. DxilTypeSystem::GetFunctionAnnotation(const Function *pFunction) const {
  225. auto it = m_FunctionAnnotations.find(pFunction);
  226. if (it != m_FunctionAnnotations.end()) {
  227. return it->second.get();
  228. } else {
  229. return nullptr;
  230. }
  231. }
  232. void DxilTypeSystem::EraseFunctionAnnotation(const Function *pFunction) {
  233. DXASSERT_NOMSG(m_FunctionAnnotations.count(pFunction));
  234. m_FunctionAnnotations.remove_if([pFunction](
  235. const std::pair<const Function *, std::unique_ptr<DxilFunctionAnnotation>>
  236. &I) { return pFunction == I.first; });
  237. }
  238. DxilTypeSystem::FunctionAnnotationMap &DxilTypeSystem::GetFunctionAnnotationMap() {
  239. return m_FunctionAnnotations;
  240. }
  241. StructType *DxilTypeSystem::GetSNormF32Type(unsigned NumComps) {
  242. return GetNormFloatType(CompType::getSNormF32(), NumComps);
  243. }
  244. StructType *DxilTypeSystem::GetUNormF32Type(unsigned NumComps) {
  245. return GetNormFloatType(CompType::getUNormF32(), NumComps);
  246. }
  247. StructType *DxilTypeSystem::GetNormFloatType(CompType CT, unsigned NumComps) {
  248. Type *pCompType = CT.GetLLVMType(m_pModule->getContext());
  249. DXASSERT_NOMSG(pCompType->isFloatTy());
  250. Type *pFieldType = pCompType;
  251. string TypeName;
  252. raw_string_ostream NameStream(TypeName);
  253. if (NumComps > 1) {
  254. (NameStream << "dx.types." << NumComps << "x" << CT.GetName()).flush();
  255. pFieldType = VectorType::get(pFieldType, NumComps);
  256. } else {
  257. (NameStream << "dx.types." << CT.GetName()).flush();
  258. }
  259. StructType *pStructType = m_pModule->getTypeByName(TypeName);
  260. if (pStructType == nullptr) {
  261. pStructType = StructType::create(m_pModule->getContext(), pFieldType, TypeName);
  262. DxilStructAnnotation &TA = *AddStructAnnotation(pStructType);
  263. DxilFieldAnnotation &FA = TA.GetFieldAnnotation(0);
  264. FA.SetCompType(CT.GetKind());
  265. DXASSERT_NOMSG(CT.IsSNorm() || CT.IsUNorm());
  266. }
  267. return pStructType;
  268. }
  269. void DxilTypeSystem::CopyTypeAnnotation(const llvm::Type *Ty,
  270. const DxilTypeSystem &src) {
  271. if (isa<PointerType>(Ty))
  272. Ty = Ty->getPointerElementType();
  273. while (isa<ArrayType>(Ty))
  274. Ty = Ty->getArrayElementType();
  275. // Only struct type has annotation.
  276. if (!isa<StructType>(Ty))
  277. return;
  278. const StructType *ST = cast<StructType>(Ty);
  279. // Already exist.
  280. if (GetStructAnnotation(ST))
  281. return;
  282. if (const DxilStructAnnotation *annot = src.GetStructAnnotation(ST)) {
  283. DxilStructAnnotation *dstAnnot = AddStructAnnotation(ST);
  284. // Copy the annotation.
  285. *dstAnnot = *annot;
  286. // Copy field type annotations.
  287. for (Type *Ty : ST->elements()) {
  288. CopyTypeAnnotation(Ty, src);
  289. }
  290. }
  291. }
  292. void DxilTypeSystem::CopyFunctionAnnotation(const llvm::Function *pDstFunction,
  293. const llvm::Function *pSrcFunction,
  294. const DxilTypeSystem &src) {
  295. const DxilFunctionAnnotation *annot = src.GetFunctionAnnotation(pSrcFunction);
  296. // Don't have annotation.
  297. if (!annot)
  298. return;
  299. // Already exist.
  300. if (GetFunctionAnnotation(pDstFunction))
  301. return;
  302. DxilFunctionAnnotation *dstAnnot = AddFunctionAnnotationWithFPFlag(pDstFunction, &src.GetFunctionAnnotation(pSrcFunction)->GetFlag());
  303. // Copy the annotation.
  304. *dstAnnot = *annot;
  305. // Clone ret type annotation.
  306. CopyTypeAnnotation(pDstFunction->getReturnType(), src);
  307. // Clone param type annotations.
  308. for (const Argument &arg : pDstFunction->args()) {
  309. CopyTypeAnnotation(arg.getType(), src);
  310. }
  311. }
  312. DXIL::SigPointKind SigPointFromInputQual(DxilParamInputQual Q, DXIL::ShaderKind SK, bool isPC) {
  313. DXASSERT(Q != DxilParamInputQual::Inout, "Inout not expected for SigPointFromInputQual");
  314. switch (SK) {
  315. case DXIL::ShaderKind::Vertex:
  316. switch (Q) {
  317. case DxilParamInputQual::In:
  318. return DXIL::SigPointKind::VSIn;
  319. case DxilParamInputQual::Out:
  320. return DXIL::SigPointKind::VSOut;
  321. default:
  322. break;
  323. }
  324. break;
  325. case DXIL::ShaderKind::Hull:
  326. switch (Q) {
  327. case DxilParamInputQual::In:
  328. if (isPC)
  329. return DXIL::SigPointKind::PCIn;
  330. else
  331. return DXIL::SigPointKind::HSIn;
  332. case DxilParamInputQual::Out:
  333. if (isPC)
  334. return DXIL::SigPointKind::PCOut;
  335. else
  336. return DXIL::SigPointKind::HSCPOut;
  337. case DxilParamInputQual::InputPatch:
  338. return DXIL::SigPointKind::HSCPIn;
  339. case DxilParamInputQual::OutputPatch:
  340. return DXIL::SigPointKind::HSCPOut;
  341. default:
  342. break;
  343. }
  344. break;
  345. case DXIL::ShaderKind::Domain:
  346. switch (Q) {
  347. case DxilParamInputQual::In:
  348. return DXIL::SigPointKind::DSIn;
  349. case DxilParamInputQual::Out:
  350. return DXIL::SigPointKind::DSOut;
  351. case DxilParamInputQual::InputPatch:
  352. case DxilParamInputQual::OutputPatch:
  353. return DXIL::SigPointKind::DSCPIn;
  354. default:
  355. break;
  356. }
  357. break;
  358. case DXIL::ShaderKind::Geometry:
  359. switch (Q) {
  360. case DxilParamInputQual::In:
  361. return DXIL::SigPointKind::GSIn;
  362. case DxilParamInputQual::InputPrimitive:
  363. return DXIL::SigPointKind::GSVIn;
  364. case DxilParamInputQual::OutStream0:
  365. case DxilParamInputQual::OutStream1:
  366. case DxilParamInputQual::OutStream2:
  367. case DxilParamInputQual::OutStream3:
  368. return DXIL::SigPointKind::GSOut;
  369. default:
  370. break;
  371. }
  372. break;
  373. case DXIL::ShaderKind::Pixel:
  374. switch (Q) {
  375. case DxilParamInputQual::In:
  376. return DXIL::SigPointKind::PSIn;
  377. case DxilParamInputQual::Out:
  378. return DXIL::SigPointKind::PSOut;
  379. default:
  380. break;
  381. }
  382. break;
  383. case DXIL::ShaderKind::Compute:
  384. switch (Q) {
  385. case DxilParamInputQual::In:
  386. return DXIL::SigPointKind::CSIn;
  387. default:
  388. break;
  389. }
  390. break;
  391. default:
  392. break;
  393. }
  394. return DXIL::SigPointKind::Invalid;
  395. }
  396. bool DxilTypeSystem::UseMinPrecision() {
  397. if (m_LowPrecisionMode == DXIL::LowPrecisionMode::Undefined) {
  398. if (&m_pModule->GetDxilModule()) {
  399. m_LowPrecisionMode = m_pModule->GetDxilModule().m_ShaderFlags.GetUseNativeLowPrecision() ?
  400. DXIL::LowPrecisionMode::UseNativeLowPrecision : DXIL::LowPrecisionMode::UseMinPrecision;
  401. }
  402. else if (&m_pModule->GetHLModule()) {
  403. m_LowPrecisionMode = m_pModule->GetHLModule().GetHLOptions().bUseMinPrecision ?
  404. DXIL::LowPrecisionMode::UseMinPrecision : DXIL::LowPrecisionMode::UseNativeLowPrecision;
  405. }
  406. else {
  407. DXASSERT(false, "otherwise module doesn't contain either HLModule or Dxil Module.");
  408. }
  409. }
  410. return m_LowPrecisionMode == DXIL::LowPrecisionMode::UseMinPrecision;
  411. }
  412. } // namespace hlsl