DxilShaderFlags.cpp 14 KB


  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilShaderFlags.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  #include "dxc/HLSL/DxilContainer.h"
#include "dxc/HLSL/DxilModule.h"
#include "dxc/HLSL/DxilShaderFlags.h"
#include "dxc/HLSL/DxilOperations.h"
#include "dxc/HLSL/DxilResource.h"
#include "dxc/Support/Global.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/Casting.h"
#include <cstring>
  19. using namespace hlsl;
  20. using namespace llvm;
  21. ShaderFlags::ShaderFlags():
  22. m_bDisableOptimizations(false)
  23. , m_bDisableMathRefactoring(false)
  24. , m_bEnableDoublePrecision(false)
  25. , m_bForceEarlyDepthStencil(false)
  26. , m_bEnableRawAndStructuredBuffers(false)
  27. , m_bLowPrecisionPresent(false)
  28. , m_bEnableDoubleExtensions(false)
  29. , m_bEnableMSAD(false)
  30. , m_bAllResourcesBound(false)
  31. , m_bViewportAndRTArrayIndex(false)
  32. , m_bInnerCoverage(false)
  33. , m_bStencilRef(false)
  34. , m_bTiledResources(false)
  35. , m_bUAVLoadAdditionalFormats(false)
  36. , m_bLevel9ComparisonFiltering(false)
  37. , m_bCSRawAndStructuredViaShader4X(false)
  38. , m_b64UAVs(false)
  39. , m_UAVsAtEveryStage(false)
  40. , m_bROVS(false)
  41. , m_bWaveOps(false)
  42. , m_bInt64Ops(false)
  43. , m_bViewID(false)
  44. , m_bBarycentrics(false)
  45. , m_bUseNativeLowPrecision(false)
  46. , m_align0(0)
  47. , m_align1(0)
  48. {}
  49. uint64_t ShaderFlags::GetFeatureInfo() const {
  50. uint64_t Flags = 0;
  51. Flags |= m_bEnableDoublePrecision ? hlsl::ShaderFeatureInfo_Doubles : 0;
  52. Flags |= m_bLowPrecisionPresent && !m_bUseNativeLowPrecision ? hlsl::ShaderFeatureInfo_MinimumPrecision: 0;
  53. Flags |= m_bLowPrecisionPresent && m_bUseNativeLowPrecision ? hlsl::ShaderFeatureInfo_NativeLowPrecision : 0;
  54. Flags |= m_bEnableDoubleExtensions ? hlsl::ShaderFeatureInfo_11_1_DoubleExtensions : 0;
  55. Flags |= m_bWaveOps ? hlsl::ShaderFeatureInfo_WaveOps : 0;
  56. Flags |= m_bInt64Ops ? hlsl::ShaderFeatureInfo_Int64Ops : 0;
  57. Flags |= m_bROVS ? hlsl::ShaderFeatureInfo_ROVs : 0;
  58. Flags |= m_bViewportAndRTArrayIndex ? hlsl::ShaderFeatureInfo_ViewportAndRTArrayIndexFromAnyShaderFeedingRasterizer : 0;
  59. Flags |= m_bInnerCoverage ? hlsl::ShaderFeatureInfo_InnerCoverage : 0;
  60. Flags |= m_bStencilRef ? hlsl::ShaderFeatureInfo_StencilRef : 0;
  61. Flags |= m_bTiledResources ? hlsl::ShaderFeatureInfo_TiledResources : 0;
  62. Flags |= m_bEnableMSAD ? hlsl::ShaderFeatureInfo_11_1_ShaderExtensions : 0;
  63. Flags |= m_bCSRawAndStructuredViaShader4X ? hlsl::ShaderFeatureInfo_ComputeShadersPlusRawAndStructuredBuffersViaShader4X : 0;
  64. Flags |= m_UAVsAtEveryStage ? hlsl::ShaderFeatureInfo_UAVsAtEveryStage : 0;
  65. Flags |= m_b64UAVs ? hlsl::ShaderFeatureInfo_64UAVs : 0;
  66. Flags |= m_bLevel9ComparisonFiltering ? hlsl::ShaderFeatureInfo_LEVEL9ComparisonFiltering : 0;
  67. Flags |= m_bUAVLoadAdditionalFormats ? hlsl::ShaderFeatureInfo_TypedUAVLoadAdditionalFormats : 0;
  68. Flags |= m_bViewID ? hlsl::ShaderFeatureInfo_ViewID : 0;
  69. Flags |= m_bBarycentrics ? hlsl::ShaderFeatureInfo_Barycentrics : 0;
  70. return Flags;
  71. }
  72. uint64_t ShaderFlags::GetShaderFlagsRaw() const {
  73. union Cast {
  74. Cast(const ShaderFlags &flags) {
  75. shaderFlags = flags;
  76. }
  77. ShaderFlags shaderFlags;
  78. uint64_t rawData;
  79. };
  80. static_assert(sizeof(uint64_t) == sizeof(ShaderFlags),
  81. "size must match to make sure no undefined bits when cast");
  82. Cast rawCast(*this);
  83. return rawCast.rawData;
  84. }
  85. void ShaderFlags::SetShaderFlagsRaw(uint64_t data) {
  86. union Cast {
  87. Cast(uint64_t data) {
  88. rawData = data;
  89. }
  90. ShaderFlags shaderFlags;
  91. uint64_t rawData;
  92. };
  93. Cast rawCast(data);
  94. *this = rawCast.shaderFlags;
  95. }
  96. uint64_t ShaderFlags::GetShaderFlagsRawForCollection() {
  97. // This should be all the flags that can be set by DxilModule::CollectShaderFlags.
  98. ShaderFlags Flags;
  99. Flags.SetEnableDoublePrecision(true);
  100. Flags.SetInt64Ops(true);
  101. Flags.SetLowPrecisionPresent(true);
  102. Flags.SetEnableDoubleExtensions(true);
  103. Flags.SetWaveOps(true);
  104. Flags.SetTiledResources(true);
  105. Flags.SetEnableMSAD(true);
  106. Flags.SetUAVLoadAdditionalFormats(true);
  107. Flags.SetStencilRef(true);
  108. Flags.SetInnerCoverage(true);
  109. Flags.SetViewportAndRTArrayIndex(true);
  110. Flags.Set64UAVs(true);
  111. Flags.SetUAVsAtEveryStage(true);
  112. Flags.SetEnableRawAndStructuredBuffers(true);
  113. Flags.SetCSRawAndStructuredViaShader4X(true);
  114. Flags.SetViewID(true);
  115. Flags.SetBarycentrics(true);
  116. return Flags.GetShaderFlagsRaw();
  117. }
  118. unsigned ShaderFlags::GetGlobalFlags() const {
  119. unsigned Flags = 0;
  120. Flags |= m_bDisableOptimizations ? DXIL::kDisableOptimizations : 0;
  121. Flags |= m_bDisableMathRefactoring ? DXIL::kDisableMathRefactoring : 0;
  122. Flags |= m_bEnableDoublePrecision ? DXIL::kEnableDoublePrecision : 0;
  123. Flags |= m_bForceEarlyDepthStencil ? DXIL::kForceEarlyDepthStencil : 0;
  124. Flags |= m_bEnableRawAndStructuredBuffers ? DXIL::kEnableRawAndStructuredBuffers : 0;
  125. Flags |= m_bLowPrecisionPresent && !m_bUseNativeLowPrecision? DXIL::kEnableMinPrecision : 0;
  126. Flags |= m_bEnableDoubleExtensions ? DXIL::kEnableDoubleExtensions : 0;
  127. Flags |= m_bEnableMSAD ? DXIL::kEnableMSAD : 0;
  128. Flags |= m_bAllResourcesBound ? DXIL::kAllResourcesBound : 0;
  129. return Flags;
  130. }
  131. // Given a CreateHandle call, returns arbitrary ConstantInt rangeID
  132. // Note: HLSL is currently assuming that rangeID is a constant value, but this code is assuming
  133. // that it can be either constant, phi node, or select instruction
  134. static ConstantInt *GetArbitraryConstantRangeID(CallInst *handleCall) {
  135. Value *rangeID =
  136. handleCall->getArgOperand(DXIL::OperandIndex::kCreateHandleResIDOpIdx);
  137. ConstantInt *ConstantRangeID = dyn_cast<ConstantInt>(rangeID);
  138. while (ConstantRangeID == nullptr) {
  139. if (ConstantInt *CI = dyn_cast<ConstantInt>(rangeID)) {
  140. ConstantRangeID = CI;
  141. } else if (PHINode *PN = dyn_cast<PHINode>(rangeID)) {
  142. rangeID = PN->getIncomingValue(0);
  143. } else if (SelectInst *SI = dyn_cast<SelectInst>(rangeID)) {
  144. rangeID = SI->getTrueValue();
  145. } else {
  146. return nullptr;
  147. }
  148. }
  149. return ConstantRangeID;
  150. }
  151. static bool IsResourceSingleComponent(llvm::Type *Ty) {
  152. if (llvm::ArrayType *arrType = llvm::dyn_cast<llvm::ArrayType>(Ty)) {
  153. if (arrType->getArrayNumElements() > 1) {
  154. return false;
  155. }
  156. return IsResourceSingleComponent(arrType->getArrayElementType());
  157. } else if (llvm::StructType *structType =
  158. llvm::dyn_cast<llvm::StructType>(Ty)) {
  159. if (structType->getStructNumElements() > 1) {
  160. return false;
  161. }
  162. return IsResourceSingleComponent(structType->getStructElementType(0));
  163. } else if (llvm::VectorType *vectorType =
  164. llvm::dyn_cast<llvm::VectorType>(Ty)) {
  165. if (vectorType->getNumElements() > 1) {
  166. return false;
  167. }
  168. return IsResourceSingleComponent(vectorType->getVectorElementType());
  169. }
  170. return true;
  171. }
  172. // Given a handle type, find an arbitrary call instructions to create handle
  173. static CallInst *FindCallToCreateHandle(Value *handleType) {
  174. Value *curVal = handleType;
  175. CallInst *CI = dyn_cast<CallInst>(handleType);
  176. while (CI == nullptr) {
  177. if (PHINode *PN = dyn_cast<PHINode>(curVal)) {
  178. curVal = PN->getIncomingValue(0);
  179. }
  180. else if (SelectInst *SI = dyn_cast<SelectInst>(curVal)) {
  181. curVal = SI->getTrueValue();
  182. }
  183. else {
  184. return nullptr;
  185. }
  186. CI = dyn_cast<CallInst>(curVal);
  187. }
  188. return CI;
  189. }
  190. ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F,
  191. const hlsl::DxilModule *M) {
  192. ShaderFlags flag;
  193. // Module level options
  194. flag.SetUseNativeLowPrecision(!M->GetUseMinPrecision());
  195. flag.SetDisableOptimizations(M->GetDisableOptimization());
  196. flag.SetAllResourcesBound(M->GetAllResourcesBound());
  197. bool hasDouble = false;
  198. // ddiv dfma drcp d2i d2u i2d u2d.
  199. // fma has dxil op. Others should check IR instruction div/cast.
  200. bool hasDoubleExtension = false;
  201. bool has64Int = false;
  202. bool has16 = false;
  203. bool hasWaveOps = false;
  204. bool hasCheckAccessFully = false;
  205. bool hasMSAD = false;
  206. bool hasInnerCoverage = false;
  207. bool hasViewID = false;
  208. bool hasMulticomponentUAVLoads = false;
  209. // Try to maintain compatibility with a v1.0 validator if that's what we have.
  210. uint32_t valMajor, valMinor;
  211. M->GetValidatorVersion(valMajor, valMinor);
  212. bool hasMulticomponentUAVLoadsBackCompat = valMajor == 1 && valMinor == 0;
  213. Type *int16Ty = Type::getInt16Ty(F->getContext());
  214. Type *int64Ty = Type::getInt64Ty(F->getContext());
  215. for (const BasicBlock &BB : F->getBasicBlockList()) {
  216. for (const Instruction &I : BB.getInstList()) {
  217. // Skip none dxil function call.
  218. if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
  219. if (!OP::IsDxilOpFunc(CI->getCalledFunction()))
  220. continue;
  221. }
  222. Type *Ty = I.getType();
  223. bool isDouble = Ty->isDoubleTy();
  224. bool isHalf = Ty->isHalfTy();
  225. bool isInt16 = Ty == int16Ty;
  226. bool isInt64 = Ty == int64Ty;
  227. if (isa<ExtractElementInst>(&I) ||
  228. isa<InsertElementInst>(&I))
  229. continue;
  230. for (Value *operand : I.operands()) {
  231. Type *Ty = operand->getType();
  232. isDouble |= Ty->isDoubleTy();
  233. isHalf |= Ty->isHalfTy();
  234. isInt16 |= Ty == int16Ty;
  235. isInt64 |= Ty == int64Ty;
  236. }
  237. if (isDouble) {
  238. hasDouble = true;
  239. switch (I.getOpcode()) {
  240. case Instruction::FDiv:
  241. case Instruction::UIToFP:
  242. case Instruction::SIToFP:
  243. case Instruction::FPToUI:
  244. case Instruction::FPToSI:
  245. hasDoubleExtension = true;
  246. break;
  247. }
  248. }
  249. has16 |= isHalf;
  250. has16 |= isInt16;
  251. has64Int |= isInt64;
  252. if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
  253. if (!OP::IsDxilOpFunc(CI->getCalledFunction()))
  254. continue;
  255. Value *opcodeArg = CI->getArgOperand(DXIL::OperandIndex::kOpcodeIdx);
  256. ConstantInt *opcodeConst = dyn_cast<ConstantInt>(opcodeArg);
  257. DXASSERT(opcodeConst, "DXIL opcode arg must be immediate");
  258. unsigned opcode = opcodeConst->getLimitedValue();
  259. DXASSERT(opcode < static_cast<unsigned>(DXIL::OpCode::NumOpCodes),
  260. "invalid DXIL opcode");
  261. DXIL::OpCode dxilOp = static_cast<DXIL::OpCode>(opcode);
  262. if (hlsl::OP::IsDxilOpWave(dxilOp))
  263. hasWaveOps = true;
  264. switch (dxilOp) {
  265. case DXIL::OpCode::CheckAccessFullyMapped:
  266. hasCheckAccessFully = true;
  267. break;
  268. case DXIL::OpCode::Msad:
  269. hasMSAD = true;
  270. break;
  271. case DXIL::OpCode::BufferLoad:
  272. case DXIL::OpCode::TextureLoad: {
  273. if (hasMulticomponentUAVLoads) continue;
  274. // This is the old-style computation (overestimating requirements).
  275. Value *resHandle = CI->getArgOperand(DXIL::OperandIndex::kBufferStoreHandleOpIdx);
  276. CallInst *handleCall = FindCallToCreateHandle(resHandle);
  277. // Check if this is a library handle or general create handle
  278. if (handleCall) {
  279. ConstantInt *HandleOpCodeConst = cast<ConstantInt>(
  280. handleCall->getArgOperand(DXIL::OperandIndex::kOpcodeIdx));
  281. DXIL::OpCode handleOp = static_cast<DXIL::OpCode>(HandleOpCodeConst->getLimitedValue());
  282. if (handleOp == DXIL::OpCode::CreateHandle) {
  283. if (ConstantInt *resClassArg =
  284. dyn_cast<ConstantInt>(handleCall->getArgOperand(
  285. DXIL::OperandIndex::kCreateHandleResClassOpIdx))) {
  286. DXIL::ResourceClass resClass = static_cast<DXIL::ResourceClass>(
  287. resClassArg->getLimitedValue());
  288. if (resClass == DXIL::ResourceClass::UAV) {
  289. // Validator 1.0 assumes that all uav load is multi component load.
  290. if (hasMulticomponentUAVLoadsBackCompat) {
  291. hasMulticomponentUAVLoads = true;
  292. continue;
  293. }
  294. else {
  295. ConstantInt *rangeID = GetArbitraryConstantRangeID(handleCall);
  296. if (rangeID) {
  297. DxilResource resource = M->GetUAV(rangeID->getLimitedValue());
  298. if ((resource.IsTypedBuffer() ||
  299. resource.IsAnyTexture()) &&
  300. !IsResourceSingleComponent(resource.GetRetType())) {
  301. hasMulticomponentUAVLoads = true;
  302. }
  303. }
  304. }
  305. }
  306. }
  307. else {
  308. DXASSERT(false, "Resource class must be constant.");
  309. }
  310. }
  311. else if (handleOp == DXIL::OpCode::CreateHandleForLib) {
  312. // If library handle, find DxilResource by checking the name
  313. if (LoadInst *LI = dyn_cast<LoadInst>(handleCall->getArgOperand(
  314. DXIL::OperandIndex::
  315. kCreateHandleForLibResOpIdx))) {
  316. Value *resType = LI->getOperand(0);
  317. for (auto &&res : M->GetUAVs()) {
  318. if (res->GetGlobalSymbol() == resType) {
  319. if ((res->IsTypedBuffer() || res->IsAnyTexture()) &&
  320. !IsResourceSingleComponent(res->GetRetType())) {
  321. hasMulticomponentUAVLoads = true;
  322. }
  323. }
  324. }
  325. }
  326. }
  327. }
  328. } break;
  329. case DXIL::OpCode::Fma:
  330. hasDoubleExtension |= isDouble;
  331. break;
  332. case DXIL::OpCode::InnerCoverage:
  333. hasInnerCoverage = true;
  334. break;
  335. case DXIL::OpCode::ViewID:
  336. hasViewID = true;
  337. break;
  338. default:
  339. // Normal opcodes.
  340. break;
  341. }
  342. }
  343. }
  344. }
  345. flag.SetEnableDoublePrecision(hasDouble);
  346. flag.SetInnerCoverage(hasInnerCoverage);
  347. flag.SetInt64Ops(has64Int);
  348. flag.SetLowPrecisionPresent(has16);
  349. flag.SetEnableDoubleExtensions(hasDoubleExtension);
  350. flag.SetWaveOps(hasWaveOps);
  351. flag.SetTiledResources(hasCheckAccessFully);
  352. flag.SetEnableMSAD(hasMSAD);
  353. flag.SetUAVLoadAdditionalFormats(hasMulticomponentUAVLoads);
  354. flag.SetViewID(hasViewID);
  355. return flag;
  356. }
  357. void ShaderFlags::CombineShaderFlags(const ShaderFlags &other) {
  358. SetShaderFlagsRaw(GetShaderFlagsRaw() | other.GetShaderFlagsRaw());
  359. }