inst_bindless_check_pass.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. // Copyright (c) 2018 The Khronos Group Inc.
  2. // Copyright (c) 2018 Valve Corporation
  3. // Copyright (c) 2018 LunarG Inc.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. #include "inst_bindless_check_pass.h"
  namespace {
  // Operand positions handed to Instruction::GetSingleWordInOperand /
  // SetInOperand when this pass picks apart SPIR-V instructions.
  // `static` is unnecessary here: the anonymous namespace already gives
  // these constants internal linkage.

  // Image operand of the image-access opcodes listed in GetDescriptorValueId.
  constexpr int kSpvImageSampleImageIdInIdx = 0;

  // OpSampledImage: the image operand, then the sampler operand.
  constexpr int kSpvSampledImageImageIdInIdx = 0;
  constexpr int kSpvSampledImageSamplerIdInIdx = 1;

  // OpImage: the sampled-image operand it extracts the image from.
  constexpr int kSpvImageSampledImageIdInIdx = 0;

  // OpLoad: the pointer being loaded.
  constexpr int kSpvLoadPtrIdInIdx = 0;

  // OpAccessChain: the base pointer, then the first index.
  constexpr int kSpvAccessChainBaseIdInIdx = 0;
  constexpr int kSpvAccessChainIndex0IdInIdx = 1;

  // OpTypePointer: the pointee type (operand 0 is the storage class).
  constexpr int kSpvTypePointerTypeIdInIdx = 1;

  // OpTypeArray: the length id (operand 0 is the element type).
  constexpr int kSpvTypeArrayLengthIdInIdx = 1;

  // OpConstant: the first literal value word.
  constexpr int kSpvConstantValueInIdx = 0;
  }  // anonymous namespace
  30. namespace spvtools {
  31. namespace opt {
  32. uint32_t InstBindlessCheckPass::GenDebugReadLength(
  33. uint32_t var_id, InstructionBuilder* builder) {
  34. uint32_t desc_set_idx =
  35. var2desc_set_[var_id] + kDebugInputBindlessOffsetLengths;
  36. uint32_t desc_set_idx_id = builder->GetUintConstantId(desc_set_idx);
  37. uint32_t binding_idx_id = builder->GetUintConstantId(var2binding_[var_id]);
  38. return GenDebugDirectRead({desc_set_idx_id, binding_idx_id}, builder);
  39. }
  40. uint32_t InstBindlessCheckPass::GenDebugReadInit(uint32_t var_id,
  41. uint32_t desc_idx_id,
  42. InstructionBuilder* builder) {
  43. uint32_t desc_set_base_id =
  44. builder->GetUintConstantId(kDebugInputBindlessInitOffset);
  45. uint32_t desc_set_idx_id = builder->GetUintConstantId(var2desc_set_[var_id]);
  46. uint32_t binding_idx_id = builder->GetUintConstantId(var2binding_[var_id]);
  47. uint32_t u_desc_idx_id = GenUintCastCode(desc_idx_id, builder);
  48. return GenDebugDirectRead(
  49. {desc_set_base_id, desc_set_idx_id, binding_idx_id, u_desc_idx_id},
  50. builder);
  51. }
  52. uint32_t InstBindlessCheckPass::CloneOriginalReference(
  53. ref_analysis* ref, InstructionBuilder* builder) {
  54. // Clone descriptor load
  55. Instruction* load_inst = get_def_use_mgr()->GetDef(ref->load_id);
  56. Instruction* new_load_inst =
  57. builder->AddLoad(load_inst->type_id(),
  58. load_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx));
  59. uid2offset_[new_load_inst->unique_id()] = uid2offset_[load_inst->unique_id()];
  60. uint32_t new_load_id = new_load_inst->result_id();
  61. get_decoration_mgr()->CloneDecorations(load_inst->result_id(), new_load_id);
  62. uint32_t new_image_id = new_load_id;
  63. // Clone Image/SampledImage with new load, if needed
  64. if (ref->image_id != 0) {
  65. Instruction* image_inst = get_def_use_mgr()->GetDef(ref->image_id);
  66. if (image_inst->opcode() == SpvOp::SpvOpSampledImage) {
  67. Instruction* new_image_inst = builder->AddBinaryOp(
  68. image_inst->type_id(), SpvOpSampledImage, new_load_id,
  69. image_inst->GetSingleWordInOperand(kSpvSampledImageSamplerIdInIdx));
  70. uid2offset_[new_image_inst->unique_id()] =
  71. uid2offset_[image_inst->unique_id()];
  72. new_image_id = new_image_inst->result_id();
  73. } else {
  74. assert(image_inst->opcode() == SpvOp::SpvOpImage && "expecting OpImage");
  75. Instruction* new_image_inst =
  76. builder->AddUnaryOp(image_inst->type_id(), SpvOpImage, new_load_id);
  77. uid2offset_[new_image_inst->unique_id()] =
  78. uid2offset_[image_inst->unique_id()];
  79. new_image_id = new_image_inst->result_id();
  80. }
  81. get_decoration_mgr()->CloneDecorations(ref->image_id, new_image_id);
  82. }
  83. // Clone original reference using new image code
  84. std::unique_ptr<Instruction> new_ref_inst(ref->ref_inst->Clone(context()));
  85. uint32_t ref_result_id = ref->ref_inst->result_id();
  86. uint32_t new_ref_id = 0;
  87. if (ref_result_id != 0) {
  88. new_ref_id = TakeNextId();
  89. new_ref_inst->SetResultId(new_ref_id);
  90. }
  91. new_ref_inst->SetInOperand(kSpvImageSampleImageIdInIdx, {new_image_id});
  92. // Register new reference and add to new block
  93. Instruction* added_inst = builder->AddInstruction(std::move(new_ref_inst));
  94. uid2offset_[added_inst->unique_id()] =
  95. uid2offset_[ref->ref_inst->unique_id()];
  96. if (new_ref_id != 0)
  97. get_decoration_mgr()->CloneDecorations(ref_result_id, new_ref_id);
  98. return new_ref_id;
  99. }
  100. uint32_t InstBindlessCheckPass::GetDescriptorValueId(Instruction* inst) {
  101. switch (inst->opcode()) {
  102. case SpvOp::SpvOpImageSampleImplicitLod:
  103. case SpvOp::SpvOpImageSampleExplicitLod:
  104. case SpvOp::SpvOpImageSampleDrefImplicitLod:
  105. case SpvOp::SpvOpImageSampleDrefExplicitLod:
  106. case SpvOp::SpvOpImageSampleProjImplicitLod:
  107. case SpvOp::SpvOpImageSampleProjExplicitLod:
  108. case SpvOp::SpvOpImageSampleProjDrefImplicitLod:
  109. case SpvOp::SpvOpImageSampleProjDrefExplicitLod:
  110. case SpvOp::SpvOpImageGather:
  111. case SpvOp::SpvOpImageDrefGather:
  112. case SpvOp::SpvOpImageQueryLod:
  113. case SpvOp::SpvOpImageSparseSampleImplicitLod:
  114. case SpvOp::SpvOpImageSparseSampleExplicitLod:
  115. case SpvOp::SpvOpImageSparseSampleDrefImplicitLod:
  116. case SpvOp::SpvOpImageSparseSampleDrefExplicitLod:
  117. case SpvOp::SpvOpImageSparseSampleProjImplicitLod:
  118. case SpvOp::SpvOpImageSparseSampleProjExplicitLod:
  119. case SpvOp::SpvOpImageSparseSampleProjDrefImplicitLod:
  120. case SpvOp::SpvOpImageSparseSampleProjDrefExplicitLod:
  121. case SpvOp::SpvOpImageSparseGather:
  122. case SpvOp::SpvOpImageSparseDrefGather:
  123. case SpvOp::SpvOpImageFetch:
  124. case SpvOp::SpvOpImageRead:
  125. case SpvOp::SpvOpImageQueryFormat:
  126. case SpvOp::SpvOpImageQueryOrder:
  127. case SpvOp::SpvOpImageQuerySizeLod:
  128. case SpvOp::SpvOpImageQuerySize:
  129. case SpvOp::SpvOpImageQueryLevels:
  130. case SpvOp::SpvOpImageQuerySamples:
  131. case SpvOp::SpvOpImageSparseFetch:
  132. case SpvOp::SpvOpImageSparseRead:
  133. case SpvOp::SpvOpImageWrite:
  134. return inst->GetSingleWordInOperand(kSpvImageSampleImageIdInIdx);
  135. default:
  136. break;
  137. }
  138. return 0;
  139. }
  140. bool InstBindlessCheckPass::AnalyzeDescriptorReference(Instruction* ref_inst,
  141. ref_analysis* ref) {
  142. ref->image_id = GetDescriptorValueId(ref_inst);
  143. if (ref->image_id == 0) return false;
  144. Instruction* image_inst = get_def_use_mgr()->GetDef(ref->image_id);
  145. if (image_inst->opcode() == SpvOp::SpvOpSampledImage) {
  146. ref->load_id =
  147. image_inst->GetSingleWordInOperand(kSpvSampledImageImageIdInIdx);
  148. } else if (image_inst->opcode() == SpvOp::SpvOpImage) {
  149. ref->load_id =
  150. image_inst->GetSingleWordInOperand(kSpvImageSampledImageIdInIdx);
  151. } else {
  152. ref->load_id = ref->image_id;
  153. ref->image_id = 0;
  154. }
  155. Instruction* load_inst = get_def_use_mgr()->GetDef(ref->load_id);
  156. if (load_inst->opcode() != SpvOp::SpvOpLoad) {
  157. // TODO(greg-lunarg): Handle additional possibilities?
  158. return false;
  159. }
  160. ref->ptr_id = load_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx);
  161. Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
  162. if (ptr_inst->opcode() == SpvOp::SpvOpVariable) {
  163. ref->index_id = 0;
  164. ref->var_id = ref->ptr_id;
  165. } else if (ptr_inst->opcode() == SpvOp::SpvOpAccessChain) {
  166. if (ptr_inst->NumInOperands() != 2) {
  167. assert(false && "unexpected bindless index number");
  168. return false;
  169. }
  170. ref->index_id =
  171. ptr_inst->GetSingleWordInOperand(kSpvAccessChainIndex0IdInIdx);
  172. ref->var_id = ptr_inst->GetSingleWordInOperand(kSpvAccessChainBaseIdInIdx);
  173. Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
  174. if (var_inst->opcode() != SpvOpVariable) {
  175. assert(false && "unexpected bindless base");
  176. return false;
  177. }
  178. } else {
  179. // TODO(greg-lunarg): Handle additional possibilities?
  180. return false;
  181. }
  182. ref->ref_inst = ref_inst;
  183. return true;
  184. }
  185. void InstBindlessCheckPass::GenCheckCode(
  186. uint32_t check_id, uint32_t error_id, uint32_t length_id,
  187. uint32_t stage_idx, ref_analysis* ref,
  188. std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  189. BasicBlock* back_blk_ptr = &*new_blocks->back();
  190. InstructionBuilder builder(
  191. context(), back_blk_ptr,
  192. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  193. // Gen conditional branch on check_id. Valid branch generates original
  194. // reference. Invalid generates debug output and zero result (if needed).
  195. uint32_t merge_blk_id = TakeNextId();
  196. uint32_t valid_blk_id = TakeNextId();
  197. uint32_t invalid_blk_id = TakeNextId();
  198. std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
  199. std::unique_ptr<Instruction> valid_label(NewLabel(valid_blk_id));
  200. std::unique_ptr<Instruction> invalid_label(NewLabel(invalid_blk_id));
  201. (void)builder.AddConditionalBranch(check_id, valid_blk_id, invalid_blk_id,
  202. merge_blk_id, SpvSelectionControlMaskNone);
  203. // Gen valid bounds branch
  204. std::unique_ptr<BasicBlock> new_blk_ptr(
  205. new BasicBlock(std::move(valid_label)));
  206. builder.SetInsertPoint(&*new_blk_ptr);
  207. uint32_t new_ref_id = CloneOriginalReference(ref, &builder);
  208. (void)builder.AddBranch(merge_blk_id);
  209. new_blocks->push_back(std::move(new_blk_ptr));
  210. // Gen invalid block
  211. new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
  212. builder.SetInsertPoint(&*new_blk_ptr);
  213. uint32_t u_index_id = GenUintCastCode(ref->index_id, &builder);
  214. GenDebugStreamWrite(uid2offset_[ref->ref_inst->unique_id()], stage_idx,
  215. {error_id, u_index_id, length_id}, &builder);
  216. // Remember last invalid block id
  217. uint32_t last_invalid_blk_id = new_blk_ptr->GetLabelInst()->result_id();
  218. // Gen zero for invalid reference
  219. uint32_t ref_type_id = ref->ref_inst->type_id();
  220. (void)builder.AddBranch(merge_blk_id);
  221. new_blocks->push_back(std::move(new_blk_ptr));
  222. // Gen merge block
  223. new_blk_ptr.reset(new BasicBlock(std::move(merge_label)));
  224. builder.SetInsertPoint(&*new_blk_ptr);
  225. // Gen phi of new reference and zero, if necessary, and replace the
  226. // result id of the original reference with that of the Phi. Kill original
  227. // reference.
  228. if (new_ref_id != 0) {
  229. Instruction* phi_inst = builder.AddPhi(
  230. ref_type_id, {new_ref_id, valid_blk_id, builder.GetNullId(ref_type_id),
  231. last_invalid_blk_id});
  232. context()->ReplaceAllUsesWith(ref->ref_inst->result_id(),
  233. phi_inst->result_id());
  234. }
  235. new_blocks->push_back(std::move(new_blk_ptr));
  236. context()->KillInst(ref->ref_inst);
  237. }
  238. void InstBindlessCheckPass::GenBoundsCheckCode(
  239. BasicBlock::iterator ref_inst_itr,
  240. UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
  241. std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  242. // Look for reference through indexed descriptor. If found, analyze and
  243. // save components. If not, return.
  244. ref_analysis ref;
  245. if (!AnalyzeDescriptorReference(&*ref_inst_itr, &ref)) return;
  246. Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref.ptr_id);
  247. if (ptr_inst->opcode() != SpvOp::SpvOpAccessChain) return;
  248. // If index and bound both compile-time constants and index < bound,
  249. // return without changing
  250. Instruction* var_inst = get_def_use_mgr()->GetDef(ref.var_id);
  251. uint32_t var_type_id = var_inst->type_id();
  252. Instruction* var_type_inst = get_def_use_mgr()->GetDef(var_type_id);
  253. uint32_t desc_type_id =
  254. var_type_inst->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx);
  255. Instruction* desc_type_inst = get_def_use_mgr()->GetDef(desc_type_id);
  256. uint32_t length_id = 0;
  257. if (desc_type_inst->opcode() == SpvOpTypeArray) {
  258. length_id =
  259. desc_type_inst->GetSingleWordInOperand(kSpvTypeArrayLengthIdInIdx);
  260. Instruction* index_inst = get_def_use_mgr()->GetDef(ref.index_id);
  261. Instruction* length_inst = get_def_use_mgr()->GetDef(length_id);
  262. if (index_inst->opcode() == SpvOpConstant &&
  263. length_inst->opcode() == SpvOpConstant &&
  264. index_inst->GetSingleWordInOperand(kSpvConstantValueInIdx) <
  265. length_inst->GetSingleWordInOperand(kSpvConstantValueInIdx))
  266. return;
  267. } else if (!input_length_enabled_ ||
  268. desc_type_inst->opcode() != SpvOpTypeRuntimeArray) {
  269. return;
  270. }
  271. // Move original block's preceding instructions into first new block
  272. std::unique_ptr<BasicBlock> new_blk_ptr;
  273. MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
  274. InstructionBuilder builder(
  275. context(), &*new_blk_ptr,
  276. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  277. new_blocks->push_back(std::move(new_blk_ptr));
  278. uint32_t error_id = builder.GetUintConstantId(kInstErrorBindlessBounds);
  279. // If length id not yet set, descriptor array is runtime size so
  280. // generate load of length from stage's debug input buffer.
  281. if (length_id == 0) {
  282. assert(desc_type_inst->opcode() == SpvOpTypeRuntimeArray &&
  283. "unexpected bindless type");
  284. length_id = GenDebugReadLength(ref.var_id, &builder);
  285. }
  286. // Generate full runtime bounds test code with true branch
  287. // being full reference and false branch being debug output and zero
  288. // for the referenced value.
  289. Instruction* ult_inst =
  290. builder.AddBinaryOp(GetBoolId(), SpvOpULessThan, ref.index_id, length_id);
  291. GenCheckCode(ult_inst->result_id(), error_id, length_id, stage_idx, &ref,
  292. new_blocks);
  293. // Move original block's remaining code into remainder/merge block and add
  294. // to new blocks
  295. BasicBlock* back_blk_ptr = &*new_blocks->back();
  296. MovePostludeCode(ref_block_itr, back_blk_ptr);
  297. }
  298. void InstBindlessCheckPass::GenInitCheckCode(
  299. BasicBlock::iterator ref_inst_itr,
  300. UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
  301. std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  302. // Look for reference through descriptor. If not, return.
  303. ref_analysis ref;
  304. if (!AnalyzeDescriptorReference(&*ref_inst_itr, &ref)) return;
  305. // Move original block's preceding instructions into first new block
  306. std::unique_ptr<BasicBlock> new_blk_ptr;
  307. MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
  308. InstructionBuilder builder(
  309. context(), &*new_blk_ptr,
  310. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  311. new_blocks->push_back(std::move(new_blk_ptr));
  312. // Read initialization status from debug input buffer. If index id not yet
  313. // set, binding is single descriptor, so set index to constant 0.
  314. uint32_t zero_id = builder.GetUintConstantId(0u);
  315. if (ref.index_id == 0) ref.index_id = zero_id;
  316. uint32_t init_id = GenDebugReadInit(ref.var_id, ref.index_id, &builder);
  317. // Generate full runtime non-zero init test code with true branch
  318. // being full reference and false branch being debug output and zero
  319. // for the referenced value.
  320. Instruction* uneq_inst =
  321. builder.AddBinaryOp(GetBoolId(), SpvOpINotEqual, init_id, zero_id);
  322. uint32_t error_id = builder.GetUintConstantId(kInstErrorBindlessUninit);
  323. GenCheckCode(uneq_inst->result_id(), error_id, zero_id, stage_idx, &ref,
  324. new_blocks);
  325. // Move original block's remaining code into remainder/merge block and add
  326. // to new blocks
  327. BasicBlock* back_blk_ptr = &*new_blocks->back();
  328. MovePostludeCode(ref_block_itr, back_blk_ptr);
  329. }
  330. void InstBindlessCheckPass::InitializeInstBindlessCheck() {
  331. // Initialize base class
  332. InitializeInstrument();
  333. // Look for related extensions
  334. ext_descriptor_indexing_defined_ = false;
  335. for (auto& ei : get_module()->extensions()) {
  336. const char* ext_name =
  337. reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]);
  338. if (strcmp(ext_name, "SPV_EXT_descriptor_indexing") == 0) {
  339. ext_descriptor_indexing_defined_ = true;
  340. break;
  341. }
  342. }
  343. // If descriptor indexing extension and runtime array length support enabled,
  344. // create variable mappings. Length support is always enabled if descriptor
  345. // init check is enabled.
  346. if (ext_descriptor_indexing_defined_ && input_length_enabled_)
  347. for (auto& anno : get_module()->annotations())
  348. if (anno.opcode() == SpvOpDecorate) {
  349. if (anno.GetSingleWordInOperand(1u) == SpvDecorationDescriptorSet)
  350. var2desc_set_[anno.GetSingleWordInOperand(0u)] =
  351. anno.GetSingleWordInOperand(2u);
  352. else if (anno.GetSingleWordInOperand(1u) == SpvDecorationBinding)
  353. var2binding_[anno.GetSingleWordInOperand(0u)] =
  354. anno.GetSingleWordInOperand(2u);
  355. }
  356. }
  357. Pass::Status InstBindlessCheckPass::ProcessImpl() {
  358. // Perform bindless bounds check on each entry point function in module
  359. InstProcessFunction pfn =
  360. [this](BasicBlock::iterator ref_inst_itr,
  361. UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
  362. std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  363. return GenBoundsCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
  364. new_blocks);
  365. };
  366. bool modified = InstProcessEntryPointCallTree(pfn);
  367. if (ext_descriptor_indexing_defined_ && input_init_enabled_) {
  368. // Perform descriptor initialization check on each entry point function in
  369. // module
  370. pfn = [this](BasicBlock::iterator ref_inst_itr,
  371. UptrVectorIterator<BasicBlock> ref_block_itr,
  372. uint32_t stage_idx,
  373. std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  374. return GenInitCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
  375. new_blocks);
  376. };
  377. modified |= InstProcessEntryPointCallTree(pfn);
  378. }
  379. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  380. }
  381. Pass::Status InstBindlessCheckPass::Process() {
  382. InitializeInstBindlessCheck();
  383. return ProcessImpl();
  384. }
  385. } // namespace opt
  386. } // namespace spvtools