inst_buff_addr_check_pass.cpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. // Copyright (c) 2019 The Khronos Group Inc.
  2. // Copyright (c) 2019 Valve Corporation
  3. // Copyright (c) 2019 LunarG Inc.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. #include "inst_buff_addr_check_pass.h"
  17. namespace spvtools {
  18. namespace opt {
  19. uint32_t InstBuffAddrCheckPass::CloneOriginalReference(
  20. Instruction* ref_inst, InstructionBuilder* builder) {
  21. // Clone original ref with new result id (if load)
  22. assert(
  23. (ref_inst->opcode() == SpvOpLoad || ref_inst->opcode() == SpvOpStore) &&
  24. "unexpected ref");
  25. std::unique_ptr<Instruction> new_ref_inst(ref_inst->Clone(context()));
  26. uint32_t ref_result_id = ref_inst->result_id();
  27. uint32_t new_ref_id = 0;
  28. if (ref_result_id != 0) {
  29. new_ref_id = TakeNextId();
  30. new_ref_inst->SetResultId(new_ref_id);
  31. }
  32. // Register new reference and add to new block
  33. Instruction* added_inst = builder->AddInstruction(std::move(new_ref_inst));
  34. uid2offset_[added_inst->unique_id()] = uid2offset_[ref_inst->unique_id()];
  35. if (new_ref_id != 0)
  36. get_decoration_mgr()->CloneDecorations(ref_result_id, new_ref_id);
  37. return new_ref_id;
  38. }
  39. bool InstBuffAddrCheckPass::IsPhysicalBuffAddrReference(Instruction* ref_inst) {
  40. if (ref_inst->opcode() != SpvOpLoad && ref_inst->opcode() != SpvOpStore)
  41. return false;
  42. uint32_t ptr_id = ref_inst->GetSingleWordInOperand(0);
  43. analysis::DefUseManager* du_mgr = get_def_use_mgr();
  44. Instruction* ptr_inst = du_mgr->GetDef(ptr_id);
  45. if (ptr_inst->opcode() != SpvOpAccessChain) return false;
  46. uint32_t ptr_ty_id = ptr_inst->type_id();
  47. Instruction* ptr_ty_inst = du_mgr->GetDef(ptr_ty_id);
  48. if (ptr_ty_inst->GetSingleWordInOperand(0) !=
  49. SpvStorageClassPhysicalStorageBufferEXT)
  50. return false;
  51. return true;
  52. }
  53. // TODO(greg-lunarg): Refactor with InstBindlessCheckPass::GenCheckCode() ??
  54. void InstBuffAddrCheckPass::GenCheckCode(
  55. uint32_t check_id, uint32_t error_id, uint32_t ref_uptr_id,
  56. uint32_t stage_idx, Instruction* ref_inst,
  57. std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  58. BasicBlock* back_blk_ptr = &*new_blocks->back();
  59. InstructionBuilder builder(
  60. context(), back_blk_ptr,
  61. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  62. // Gen conditional branch on check_id. Valid branch generates original
  63. // reference. Invalid generates debug output and zero result (if needed).
  64. uint32_t merge_blk_id = TakeNextId();
  65. uint32_t valid_blk_id = TakeNextId();
  66. uint32_t invalid_blk_id = TakeNextId();
  67. std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
  68. std::unique_ptr<Instruction> valid_label(NewLabel(valid_blk_id));
  69. std::unique_ptr<Instruction> invalid_label(NewLabel(invalid_blk_id));
  70. (void)builder.AddConditionalBranch(check_id, valid_blk_id, invalid_blk_id,
  71. merge_blk_id, SpvSelectionControlMaskNone);
  72. // Gen valid branch
  73. std::unique_ptr<BasicBlock> new_blk_ptr(
  74. new BasicBlock(std::move(valid_label)));
  75. builder.SetInsertPoint(&*new_blk_ptr);
  76. uint32_t new_ref_id = CloneOriginalReference(ref_inst, &builder);
  77. (void)builder.AddBranch(merge_blk_id);
  78. new_blocks->push_back(std::move(new_blk_ptr));
  79. // Gen invalid block
  80. new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
  81. builder.SetInsertPoint(&*new_blk_ptr);
  82. // Convert uptr from uint64 to 2 uint32
  83. Instruction* lo_uptr_inst =
  84. builder.AddUnaryOp(GetUintId(), SpvOpUConvert, ref_uptr_id);
  85. Instruction* rshift_uptr_inst =
  86. builder.AddBinaryOp(GetUint64Id(), SpvOpShiftRightLogical, ref_uptr_id,
  87. builder.GetUintConstantId(32));
  88. Instruction* hi_uptr_inst = builder.AddUnaryOp(GetUintId(), SpvOpUConvert,
  89. rshift_uptr_inst->result_id());
  90. GenDebugStreamWrite(
  91. uid2offset_[ref_inst->unique_id()], stage_idx,
  92. {error_id, lo_uptr_inst->result_id(), hi_uptr_inst->result_id()},
  93. &builder);
  94. // Gen zero for invalid load. If pointer type, need to convert uint64
  95. // zero to pointer; cannot create ConstantNull of pointer type.
  96. uint32_t null_id = 0;
  97. if (new_ref_id != 0) {
  98. uint32_t ref_type_id = ref_inst->type_id();
  99. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  100. analysis::Type* ref_type = type_mgr->GetType(ref_type_id);
  101. if (ref_type->AsPointer() != nullptr) {
  102. uint32_t null_u64_id = GetNullId(GetUint64Id());
  103. Instruction* null_ptr_inst =
  104. builder.AddUnaryOp(ref_type_id, SpvOpConvertUToPtr, null_u64_id);
  105. null_id = null_ptr_inst->result_id();
  106. } else {
  107. null_id = GetNullId(ref_type_id);
  108. }
  109. }
  110. (void)builder.AddBranch(merge_blk_id);
  111. new_blocks->push_back(std::move(new_blk_ptr));
  112. // Gen merge block
  113. new_blk_ptr.reset(new BasicBlock(std::move(merge_label)));
  114. builder.SetInsertPoint(&*new_blk_ptr);
  115. // Gen phi of new reference and zero, if necessary, and replace the
  116. // result id of the original reference with that of the Phi. Kill original
  117. // reference.
  118. if (new_ref_id != 0) {
  119. Instruction* phi_inst =
  120. builder.AddPhi(ref_inst->type_id(),
  121. {new_ref_id, valid_blk_id, null_id, invalid_blk_id});
  122. context()->ReplaceAllUsesWith(ref_inst->result_id(), phi_inst->result_id());
  123. }
  124. new_blocks->push_back(std::move(new_blk_ptr));
  125. context()->KillInst(ref_inst);
  126. }
  127. uint32_t InstBuffAddrCheckPass::GetTypeAlignment(uint32_t type_id) {
  128. Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
  129. switch (type_inst->opcode()) {
  130. case SpvOpTypeFloat:
  131. case SpvOpTypeInt:
  132. case SpvOpTypeVector:
  133. return GetTypeLength(type_id);
  134. case SpvOpTypeMatrix:
  135. return GetTypeAlignment(type_inst->GetSingleWordInOperand(0));
  136. case SpvOpTypeArray:
  137. case SpvOpTypeRuntimeArray:
  138. return GetTypeAlignment(type_inst->GetSingleWordInOperand(0));
  139. case SpvOpTypeStruct: {
  140. uint32_t max = 0;
  141. type_inst->ForEachInId([&max, this](const uint32_t* iid) {
  142. uint32_t alignment = GetTypeAlignment(*iid);
  143. max = (alignment > max) ? alignment : max;
  144. });
  145. return max;
  146. }
  147. case SpvOpTypePointer:
  148. assert(type_inst->GetSingleWordInOperand(0) ==
  149. SpvStorageClassPhysicalStorageBufferEXT &&
  150. "unexpected pointer type");
  151. return 8u;
  152. default:
  153. assert(false && "unexpected type");
  154. return 0;
  155. }
  156. }
  157. uint32_t InstBuffAddrCheckPass::GetTypeLength(uint32_t type_id) {
  158. Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
  159. switch (type_inst->opcode()) {
  160. case SpvOpTypeFloat:
  161. case SpvOpTypeInt:
  162. return type_inst->GetSingleWordInOperand(0) / 8u;
  163. case SpvOpTypeVector: {
  164. uint32_t raw_cnt = type_inst->GetSingleWordInOperand(1);
  165. uint32_t adj_cnt = (raw_cnt == 3u) ? 4u : raw_cnt;
  166. return adj_cnt * GetTypeLength(type_inst->GetSingleWordInOperand(0));
  167. }
  168. case SpvOpTypeMatrix:
  169. return type_inst->GetSingleWordInOperand(1) *
  170. GetTypeLength(type_inst->GetSingleWordInOperand(0));
  171. case SpvOpTypePointer:
  172. assert(type_inst->GetSingleWordInOperand(0) ==
  173. SpvStorageClassPhysicalStorageBufferEXT &&
  174. "unexpected pointer type");
  175. return 8u;
  176. case SpvOpTypeArray: {
  177. uint32_t const_id = type_inst->GetSingleWordInOperand(1);
  178. Instruction* const_inst = get_def_use_mgr()->GetDef(const_id);
  179. uint32_t cnt = const_inst->GetSingleWordInOperand(0);
  180. return cnt * GetTypeLength(type_inst->GetSingleWordInOperand(0));
  181. }
  182. case SpvOpTypeStruct: {
  183. uint32_t len = 0;
  184. type_inst->ForEachInId([&len, this](const uint32_t* iid) {
  185. // Align struct length
  186. uint32_t alignment = GetTypeAlignment(*iid);
  187. uint32_t mod = len % alignment;
  188. uint32_t diff = (mod != 0) ? alignment - mod : 0;
  189. len += diff;
  190. // Increment struct length by component length
  191. uint32_t comp_len = GetTypeLength(*iid);
  192. len += comp_len;
  193. });
  194. return len;
  195. }
  196. case SpvOpTypeRuntimeArray:
  197. default:
  198. assert(false && "unexpected type");
  199. return 0;
  200. }
  201. }
  202. void InstBuffAddrCheckPass::AddParam(uint32_t type_id,
  203. std::vector<uint32_t>* param_vec,
  204. std::unique_ptr<Function>* input_func) {
  205. uint32_t pid = TakeNextId();
  206. param_vec->push_back(pid);
  207. std::unique_ptr<Instruction> param_inst(new Instruction(
  208. get_module()->context(), SpvOpFunctionParameter, type_id, pid, {}));
  209. get_def_use_mgr()->AnalyzeInstDefUse(&*param_inst);
  210. (*input_func)->AddParameter(std::move(param_inst));
  211. }
  212. uint32_t InstBuffAddrCheckPass::GetSearchAndTestFuncId() {
  213. if (search_test_func_id_ == 0) {
  214. // Generate function "bool search_and_test(uint64_t ref_ptr, uint32_t len)"
  215. // which searches input buffer for buffer which most likely contains the
  216. // pointer value |ref_ptr| and verifies that the entire reference of
  217. // length |len| bytes is contained in the buffer.
  218. search_test_func_id_ = TakeNextId();
  219. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  220. std::vector<const analysis::Type*> param_types = {
  221. type_mgr->GetType(GetUint64Id()), type_mgr->GetType(GetUintId())};
  222. analysis::Function func_ty(type_mgr->GetType(GetBoolId()), param_types);
  223. analysis::Type* reg_func_ty = type_mgr->GetRegisteredType(&func_ty);
  224. std::unique_ptr<Instruction> func_inst(
  225. new Instruction(get_module()->context(), SpvOpFunction, GetBoolId(),
  226. search_test_func_id_,
  227. {{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
  228. {SpvFunctionControlMaskNone}},
  229. {spv_operand_type_t::SPV_OPERAND_TYPE_ID,
  230. {type_mgr->GetTypeInstruction(reg_func_ty)}}}));
  231. get_def_use_mgr()->AnalyzeInstDefUse(&*func_inst);
  232. std::unique_ptr<Function> input_func =
  233. MakeUnique<Function>(std::move(func_inst));
  234. std::vector<uint32_t> param_vec;
  235. // Add ref_ptr and length parameters
  236. AddParam(GetUint64Id(), &param_vec, &input_func);
  237. AddParam(GetUintId(), &param_vec, &input_func);
  238. // Empty first block.
  239. uint32_t first_blk_id = TakeNextId();
  240. std::unique_ptr<Instruction> first_blk_label(NewLabel(first_blk_id));
  241. std::unique_ptr<BasicBlock> first_blk_ptr =
  242. MakeUnique<BasicBlock>(std::move(first_blk_label));
  243. InstructionBuilder builder(
  244. context(), &*first_blk_ptr,
  245. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  246. uint32_t hdr_blk_id = TakeNextId();
  247. // Branch to search loop header
  248. std::unique_ptr<Instruction> hdr_blk_label(NewLabel(hdr_blk_id));
  249. (void)builder.AddInstruction(MakeUnique<Instruction>(
  250. context(), SpvOpBranch, 0, 0,
  251. std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {hdr_blk_id}}}));
  252. input_func->AddBasicBlock(std::move(first_blk_ptr));
  253. // Linear search loop header block
  254. // TODO(greg-lunarg): Implement binary search
  255. std::unique_ptr<BasicBlock> hdr_blk_ptr =
  256. MakeUnique<BasicBlock>(std::move(hdr_blk_label));
  257. builder.SetInsertPoint(&*hdr_blk_ptr);
  258. // Phi for search index. Starts with 1.
  259. uint32_t cont_blk_id = TakeNextId();
  260. std::unique_ptr<Instruction> cont_blk_label(NewLabel(cont_blk_id));
  261. // Deal with def-use cycle caused by search loop index computation.
  262. // Create Add and Phi instructions first, then do Def analysis on Add.
  263. // Add Phi and Add instructions and do Use analysis later.
  264. uint32_t idx_phi_id = TakeNextId();
  265. uint32_t idx_inc_id = TakeNextId();
  266. std::unique_ptr<Instruction> idx_inc_inst(new Instruction(
  267. context(), SpvOpIAdd, GetUintId(), idx_inc_id,
  268. {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {idx_phi_id}},
  269. {spv_operand_type_t::SPV_OPERAND_TYPE_ID,
  270. {builder.GetUintConstantId(1u)}}}));
  271. std::unique_ptr<Instruction> idx_phi_inst(new Instruction(
  272. context(), SpvOpPhi, GetUintId(), idx_phi_id,
  273. {{spv_operand_type_t::SPV_OPERAND_TYPE_ID,
  274. {builder.GetUintConstantId(1u)}},
  275. {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {first_blk_id}},
  276. {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {idx_inc_id}},
  277. {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {cont_blk_id}}}));
  278. get_def_use_mgr()->AnalyzeInstDef(&*idx_inc_inst);
  279. // Add (previously created) search index phi
  280. (void)builder.AddInstruction(std::move(idx_phi_inst));
  281. // LoopMerge
  282. uint32_t bound_test_blk_id = TakeNextId();
  283. std::unique_ptr<Instruction> bound_test_blk_label(
  284. NewLabel(bound_test_blk_id));
  285. (void)builder.AddInstruction(MakeUnique<Instruction>(
  286. context(), SpvOpLoopMerge, 0, 0,
  287. std::initializer_list<Operand>{
  288. {SPV_OPERAND_TYPE_ID, {bound_test_blk_id}},
  289. {SPV_OPERAND_TYPE_ID, {cont_blk_id}},
  290. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {SpvLoopControlMaskNone}}}));
  291. // Branch to continue/work block
  292. (void)builder.AddInstruction(MakeUnique<Instruction>(
  293. context(), SpvOpBranch, 0, 0,
  294. std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {cont_blk_id}}}));
  295. input_func->AddBasicBlock(std::move(hdr_blk_ptr));
  296. // Continue/Work Block. Read next buffer pointer and break if greater
  297. // than ref_ptr arg.
  298. std::unique_ptr<BasicBlock> cont_blk_ptr =
  299. MakeUnique<BasicBlock>(std::move(cont_blk_label));
  300. builder.SetInsertPoint(&*cont_blk_ptr);
  301. // Add (previously created) search index increment now.
  302. (void)builder.AddInstruction(std::move(idx_inc_inst));
  303. // Load next buffer address from debug input buffer
  304. uint32_t ibuf_id = GetInputBufferId();
  305. uint32_t ibuf_ptr_id = GetInputBufferPtrId();
  306. Instruction* uptr_ac_inst = builder.AddTernaryOp(
  307. ibuf_ptr_id, SpvOpAccessChain, ibuf_id,
  308. builder.GetUintConstantId(kDebugInputDataOffset), idx_inc_id);
  309. uint32_t ibuf_type_id = GetInputBufferTypeId();
  310. Instruction* uptr_load_inst =
  311. builder.AddUnaryOp(ibuf_type_id, SpvOpLoad, uptr_ac_inst->result_id());
  312. // If loaded address greater than ref_ptr arg, break, else branch back to
  313. // loop header
  314. Instruction* uptr_test_inst =
  315. builder.AddBinaryOp(GetBoolId(), SpvOpUGreaterThan,
  316. uptr_load_inst->result_id(), param_vec[0]);
  317. (void)builder.AddConditionalBranch(uptr_test_inst->result_id(),
  318. bound_test_blk_id, hdr_blk_id,
  319. kInvalidId, SpvSelectionControlMaskNone);
  320. input_func->AddBasicBlock(std::move(cont_blk_ptr));
  321. // Bounds test block. Read length of selected buffer and test that
  322. // all len arg bytes are in buffer.
  323. std::unique_ptr<BasicBlock> bound_test_blk_ptr =
  324. MakeUnique<BasicBlock>(std::move(bound_test_blk_label));
  325. builder.SetInsertPoint(&*bound_test_blk_ptr);
  326. // Decrement index to point to previous/candidate buffer address
  327. Instruction* cand_idx_inst = builder.AddBinaryOp(
  328. GetUintId(), SpvOpISub, idx_inc_id, builder.GetUintConstantId(1u));
  329. // Load candidate buffer address
  330. Instruction* cand_ac_inst =
  331. builder.AddTernaryOp(ibuf_ptr_id, SpvOpAccessChain, ibuf_id,
  332. builder.GetUintConstantId(kDebugInputDataOffset),
  333. cand_idx_inst->result_id());
  334. Instruction* cand_load_inst =
  335. builder.AddUnaryOp(ibuf_type_id, SpvOpLoad, cand_ac_inst->result_id());
  336. // Compute offset of ref_ptr from candidate buffer address
  337. Instruction* offset_inst = builder.AddBinaryOp(
  338. ibuf_type_id, SpvOpISub, param_vec[0], cand_load_inst->result_id());
  339. // Convert ref length to uint64
  340. Instruction* ref_len_64_inst =
  341. builder.AddUnaryOp(ibuf_type_id, SpvOpUConvert, param_vec[1]);
  342. // Add ref length to ref offset to compute end of reference
  343. Instruction* ref_end_inst =
  344. builder.AddBinaryOp(ibuf_type_id, SpvOpIAdd, offset_inst->result_id(),
  345. ref_len_64_inst->result_id());
  346. // Load starting index of lengths in input buffer and convert to uint32
  347. Instruction* len_start_ac_inst =
  348. builder.AddTernaryOp(ibuf_ptr_id, SpvOpAccessChain, ibuf_id,
  349. builder.GetUintConstantId(kDebugInputDataOffset),
  350. builder.GetUintConstantId(0u));
  351. Instruction* len_start_load_inst = builder.AddUnaryOp(
  352. ibuf_type_id, SpvOpLoad, len_start_ac_inst->result_id());
  353. Instruction* len_start_32_inst = builder.AddUnaryOp(
  354. GetUintId(), SpvOpUConvert, len_start_load_inst->result_id());
  355. // Decrement search index to get candidate buffer length index
  356. Instruction* cand_len_idx_inst =
  357. builder.AddBinaryOp(GetUintId(), SpvOpISub, cand_idx_inst->result_id(),
  358. builder.GetUintConstantId(1u));
  359. // Add candidate length index to start index
  360. Instruction* len_idx_inst = builder.AddBinaryOp(
  361. GetUintId(), SpvOpIAdd, cand_len_idx_inst->result_id(),
  362. len_start_32_inst->result_id());
  363. // Load candidate buffer length
  364. Instruction* len_ac_inst =
  365. builder.AddTernaryOp(ibuf_ptr_id, SpvOpAccessChain, ibuf_id,
  366. builder.GetUintConstantId(kDebugInputDataOffset),
  367. len_idx_inst->result_id());
  368. Instruction* len_load_inst =
  369. builder.AddUnaryOp(ibuf_type_id, SpvOpLoad, len_ac_inst->result_id());
  370. // Test if reference end within candidate buffer length
  371. Instruction* len_test_inst = builder.AddBinaryOp(
  372. GetBoolId(), SpvOpULessThanEqual, ref_end_inst->result_id(),
  373. len_load_inst->result_id());
  374. // Return test result
  375. (void)builder.AddInstruction(MakeUnique<Instruction>(
  376. context(), SpvOpReturnValue, 0, 0,
  377. std::initializer_list<Operand>{
  378. {SPV_OPERAND_TYPE_ID, {len_test_inst->result_id()}}}));
  379. // Close block
  380. input_func->AddBasicBlock(std::move(bound_test_blk_ptr));
  381. // Close function and add function to module
  382. std::unique_ptr<Instruction> func_end_inst(
  383. new Instruction(get_module()->context(), SpvOpFunctionEnd, 0, 0, {}));
  384. get_def_use_mgr()->AnalyzeInstDefUse(&*func_end_inst);
  385. input_func->SetFunctionEnd(std::move(func_end_inst));
  386. context()->AddFunction(std::move(input_func));
  387. }
  388. return search_test_func_id_;
  389. }
  390. uint32_t InstBuffAddrCheckPass::GenSearchAndTest(Instruction* ref_inst,
  391. InstructionBuilder* builder,
  392. uint32_t* ref_uptr_id) {
  393. // Enable Int64 if necessary
  394. if (!get_feature_mgr()->HasCapability(SpvCapabilityInt64)) {
  395. std::unique_ptr<Instruction> cap_int64_inst(new Instruction(
  396. context(), SpvOpCapability, 0, 0,
  397. std::initializer_list<Operand>{
  398. {SPV_OPERAND_TYPE_CAPABILITY, {SpvCapabilityInt64}}}));
  399. get_def_use_mgr()->AnalyzeInstDefUse(&*cap_int64_inst);
  400. context()->AddCapability(std::move(cap_int64_inst));
  401. }
  402. // Convert reference pointer to uint64
  403. uint32_t ref_ptr_id = ref_inst->GetSingleWordInOperand(0);
  404. Instruction* ref_uptr_inst =
  405. builder->AddUnaryOp(GetUint64Id(), SpvOpConvertPtrToU, ref_ptr_id);
  406. *ref_uptr_id = ref_uptr_inst->result_id();
  407. // Compute reference length in bytes
  408. analysis::DefUseManager* du_mgr = get_def_use_mgr();
  409. Instruction* ref_ptr_inst = du_mgr->GetDef(ref_ptr_id);
  410. uint32_t ref_ptr_ty_id = ref_ptr_inst->type_id();
  411. Instruction* ref_ptr_ty_inst = du_mgr->GetDef(ref_ptr_ty_id);
  412. uint32_t ref_len = GetTypeLength(ref_ptr_ty_inst->GetSingleWordInOperand(1));
  413. uint32_t ref_len_id = builder->GetUintConstantId(ref_len);
  414. // Gen call to search and test function
  415. const std::vector<uint32_t> args = {GetSearchAndTestFuncId(), *ref_uptr_id,
  416. ref_len_id};
  417. Instruction* call_inst =
  418. builder->AddNaryOp(GetBoolId(), SpvOpFunctionCall, args);
  419. uint32_t retval = call_inst->result_id();
  420. return retval;
  421. }
  422. void InstBuffAddrCheckPass::GenBuffAddrCheckCode(
  423. BasicBlock::iterator ref_inst_itr,
  424. UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
  425. std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  426. // Look for reference through indexed descriptor. If found, analyze and
  427. // save components. If not, return.
  428. Instruction* ref_inst = &*ref_inst_itr;
  429. if (!IsPhysicalBuffAddrReference(ref_inst)) return;
  430. // Move original block's preceding instructions into first new block
  431. std::unique_ptr<BasicBlock> new_blk_ptr;
  432. MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
  433. InstructionBuilder builder(
  434. context(), &*new_blk_ptr,
  435. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  436. new_blocks->push_back(std::move(new_blk_ptr));
  437. uint32_t error_id = builder.GetUintConstantId(kInstErrorBuffAddrUnallocRef);
  438. // Generate code to do search and test if all bytes of reference
  439. // are within a listed buffer. Return reference pointer converted to uint64.
  440. uint32_t ref_uptr_id;
  441. uint32_t valid_id = GenSearchAndTest(ref_inst, &builder, &ref_uptr_id);
  442. // Generate test of search results with true branch
  443. // being full reference and false branch being debug output and zero
  444. // for the referenced value.
  445. GenCheckCode(valid_id, error_id, ref_uptr_id, stage_idx, ref_inst,
  446. new_blocks);
  447. // Move original block's remaining code into remainder/merge block and add
  448. // to new blocks
  449. BasicBlock* back_blk_ptr = &*new_blocks->back();
  450. MovePostludeCode(ref_block_itr, back_blk_ptr);
  451. }
  452. void InstBuffAddrCheckPass::InitInstBuffAddrCheck() {
  453. // Initialize base class
  454. InitializeInstrument();
  455. // Initialize class
  456. search_test_func_id_ = 0;
  457. }
  458. Pass::Status InstBuffAddrCheckPass::ProcessImpl() {
  459. // Perform bindless bounds check on each entry point function in module
  460. InstProcessFunction pfn =
  461. [this](BasicBlock::iterator ref_inst_itr,
  462. UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
  463. std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  464. return GenBuffAddrCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
  465. new_blocks);
  466. };
  467. bool modified = InstProcessEntryPointCallTree(pfn);
  468. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  469. }
  470. Pass::Status InstBuffAddrCheckPass::Process() {
  471. if (!get_feature_mgr()->HasCapability(
  472. SpvCapabilityPhysicalStorageBufferAddressesEXT))
  473. return Status::SuccessWithoutChange;
  474. InitInstBuffAddrCheck();
  475. return ProcessImpl();
  476. }
  477. } // namespace opt
  478. } // namespace spvtools