inst_bindless_check_pass.cpp 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633
  1. // Copyright (c) 2018 The Khronos Group Inc.
  2. // Copyright (c) 2018 Valve Corporation
  3. // Copyright (c) 2018 LunarG Inc.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. #include "inst_bindless_check_pass.h"
namespace {
// Input Operand Indices
// Word positions of the in-operands this pass extracts from specific SPIR-V
// instructions (indices are into the in-operand list, not the full word list).
static const int kSpvImageSampleImageIdInIdx = 0;
static const int kSpvSampledImageImageIdInIdx = 0;
static const int kSpvSampledImageSamplerIdInIdx = 1;
static const int kSpvImageSampledImageIdInIdx = 0;
static const int kSpvLoadPtrIdInIdx = 0;
static const int kSpvAccessChainBaseIdInIdx = 0;
static const int kSpvAccessChainIndex0IdInIdx = 1;
static const int kSpvTypeArrayLengthIdInIdx = 1;
static const int kSpvConstantValueInIdx = 0;
static const int kSpvVariableStorageClassInIdx = 0;
}  // anonymous namespace
// Avoid unused variable warning/error on Linux
// In release (NDEBUG) builds assert() compiles away, which would leave the
// asserted variable unused; the cast-to-void keeps compilers quiet.
#ifndef NDEBUG
#define USE_ASSERT(x) assert(x)
#else
#define USE_ASSERT(x) ((void)(x))
#endif
  36. namespace spvtools {
  37. namespace opt {
  38. uint32_t InstBindlessCheckPass::GenDebugReadLength(
  39. uint32_t var_id, InstructionBuilder* builder) {
  40. uint32_t desc_set_idx =
  41. var2desc_set_[var_id] + kDebugInputBindlessOffsetLengths;
  42. uint32_t desc_set_idx_id = builder->GetUintConstantId(desc_set_idx);
  43. uint32_t binding_idx_id = builder->GetUintConstantId(var2binding_[var_id]);
  44. return GenDebugDirectRead({desc_set_idx_id, binding_idx_id}, builder);
  45. }
  46. uint32_t InstBindlessCheckPass::GenDebugReadInit(uint32_t var_id,
  47. uint32_t desc_idx_id,
  48. InstructionBuilder* builder) {
  49. uint32_t binding_idx_id = builder->GetUintConstantId(var2binding_[var_id]);
  50. uint32_t u_desc_idx_id = GenUintCastCode(desc_idx_id, builder);
  51. // If desc index checking is not enabled, we know the offset of initialization
  52. // entries is 1, so we can avoid loading this value and just add 1 to the
  53. // descriptor set.
  54. if (!desc_idx_enabled_) {
  55. uint32_t desc_set_idx_id =
  56. builder->GetUintConstantId(var2desc_set_[var_id] + 1);
  57. return GenDebugDirectRead({desc_set_idx_id, binding_idx_id, u_desc_idx_id},
  58. builder);
  59. } else {
  60. uint32_t desc_set_base_id =
  61. builder->GetUintConstantId(kDebugInputBindlessInitOffset);
  62. uint32_t desc_set_idx_id =
  63. builder->GetUintConstantId(var2desc_set_[var_id]);
  64. return GenDebugDirectRead(
  65. {desc_set_base_id, desc_set_idx_id, binding_idx_id, u_desc_idx_id},
  66. builder);
  67. }
  68. }
// Clones the instrumented reference described by |ref| into the current
// insert point of |builder| (the "valid" branch), rebuilding its descriptor
// load and Image/SampledImage chain so the clone has fresh result ids.
// Returns the result id of the cloned reference, or 0 if the original
// reference produces no result (e.g. OpStore / OpImageWrite).
uint32_t InstBindlessCheckPass::CloneOriginalReference(
    ref_analysis* ref, InstructionBuilder* builder) {
  // If original is image based, start by cloning descriptor load
  uint32_t new_image_id = 0;
  if (ref->desc_load_id != 0) {
    Instruction* desc_load_inst = get_def_use_mgr()->GetDef(ref->desc_load_id);
    Instruction* new_load_inst = builder->AddLoad(
        desc_load_inst->type_id(),
        desc_load_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx));
    // Carry the debug-output offset across so errors reported through the
    // clone still point at the original source location.
    uid2offset_[new_load_inst->unique_id()] =
        uid2offset_[desc_load_inst->unique_id()];
    uint32_t new_load_id = new_load_inst->result_id();
    get_decoration_mgr()->CloneDecorations(desc_load_inst->result_id(),
                                           new_load_id);
    new_image_id = new_load_id;
    // Clone Image/SampledImage with new load, if needed
    if (ref->image_id != 0) {
      Instruction* image_inst = get_def_use_mgr()->GetDef(ref->image_id);
      if (image_inst->opcode() == SpvOp::SpvOpSampledImage) {
        // Re-pair the new descriptor load with the original sampler.
        Instruction* new_image_inst = builder->AddBinaryOp(
            image_inst->type_id(), SpvOpSampledImage, new_load_id,
            image_inst->GetSingleWordInOperand(kSpvSampledImageSamplerIdInIdx));
        uid2offset_[new_image_inst->unique_id()] =
            uid2offset_[image_inst->unique_id()];
        new_image_id = new_image_inst->result_id();
      } else {
        assert(image_inst->opcode() == SpvOp::SpvOpImage &&
               "expecting OpImage");
        Instruction* new_image_inst =
            builder->AddUnaryOp(image_inst->type_id(), SpvOpImage, new_load_id);
        uid2offset_[new_image_inst->unique_id()] =
            uid2offset_[image_inst->unique_id()];
        new_image_id = new_image_inst->result_id();
      }
      get_decoration_mgr()->CloneDecorations(ref->image_id, new_image_id);
    }
  }
  // Clone original reference
  std::unique_ptr<Instruction> new_ref_inst(ref->ref_inst->Clone(context()));
  uint32_t ref_result_id = ref->ref_inst->result_id();
  uint32_t new_ref_id = 0;
  if (ref_result_id != 0) {
    new_ref_id = TakeNextId();
    new_ref_inst->SetResultId(new_ref_id);
  }
  // Update new ref with new image if created
  if (new_image_id != 0)
    new_ref_inst->SetInOperand(kSpvImageSampleImageIdInIdx, {new_image_id});
  // Register new reference and add to new block
  Instruction* added_inst = builder->AddInstruction(std::move(new_ref_inst));
  uid2offset_[added_inst->unique_id()] =
      uid2offset_[ref->ref_inst->unique_id()];
  if (new_ref_id != 0)
    get_decoration_mgr()->CloneDecorations(ref_result_id, new_ref_id);
  return new_ref_id;
}
  125. uint32_t InstBindlessCheckPass::GetImageId(Instruction* inst) {
  126. switch (inst->opcode()) {
  127. case SpvOp::SpvOpImageSampleImplicitLod:
  128. case SpvOp::SpvOpImageSampleExplicitLod:
  129. case SpvOp::SpvOpImageSampleDrefImplicitLod:
  130. case SpvOp::SpvOpImageSampleDrefExplicitLod:
  131. case SpvOp::SpvOpImageSampleProjImplicitLod:
  132. case SpvOp::SpvOpImageSampleProjExplicitLod:
  133. case SpvOp::SpvOpImageSampleProjDrefImplicitLod:
  134. case SpvOp::SpvOpImageSampleProjDrefExplicitLod:
  135. case SpvOp::SpvOpImageGather:
  136. case SpvOp::SpvOpImageDrefGather:
  137. case SpvOp::SpvOpImageQueryLod:
  138. case SpvOp::SpvOpImageSparseSampleImplicitLod:
  139. case SpvOp::SpvOpImageSparseSampleExplicitLod:
  140. case SpvOp::SpvOpImageSparseSampleDrefImplicitLod:
  141. case SpvOp::SpvOpImageSparseSampleDrefExplicitLod:
  142. case SpvOp::SpvOpImageSparseSampleProjImplicitLod:
  143. case SpvOp::SpvOpImageSparseSampleProjExplicitLod:
  144. case SpvOp::SpvOpImageSparseSampleProjDrefImplicitLod:
  145. case SpvOp::SpvOpImageSparseSampleProjDrefExplicitLod:
  146. case SpvOp::SpvOpImageSparseGather:
  147. case SpvOp::SpvOpImageSparseDrefGather:
  148. case SpvOp::SpvOpImageFetch:
  149. case SpvOp::SpvOpImageRead:
  150. case SpvOp::SpvOpImageQueryFormat:
  151. case SpvOp::SpvOpImageQueryOrder:
  152. case SpvOp::SpvOpImageQuerySizeLod:
  153. case SpvOp::SpvOpImageQuerySize:
  154. case SpvOp::SpvOpImageQueryLevels:
  155. case SpvOp::SpvOpImageQuerySamples:
  156. case SpvOp::SpvOpImageSparseFetch:
  157. case SpvOp::SpvOpImageSparseRead:
  158. case SpvOp::SpvOpImageWrite:
  159. return inst->GetSingleWordInOperand(kSpvImageSampleImageIdInIdx);
  160. default:
  161. break;
  162. }
  163. return 0;
  164. }
  165. Instruction* InstBindlessCheckPass::GetPointeeTypeInst(Instruction* ptr_inst) {
  166. uint32_t pte_ty_id = GetPointeeTypeId(ptr_inst);
  167. return get_def_use_mgr()->GetDef(pte_ty_id);
  168. }
  169. bool InstBindlessCheckPass::AnalyzeDescriptorReference(Instruction* ref_inst,
  170. ref_analysis* ref) {
  171. ref->ref_inst = ref_inst;
  172. if (ref_inst->opcode() == SpvOpLoad || ref_inst->opcode() == SpvOpStore) {
  173. ref->desc_load_id = 0;
  174. ref->ptr_id = ref_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx);
  175. Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
  176. if (ptr_inst->opcode() != SpvOp::SpvOpAccessChain) return false;
  177. ref->var_id = ptr_inst->GetSingleWordInOperand(kSpvAccessChainBaseIdInIdx);
  178. Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
  179. if (var_inst->opcode() != SpvOp::SpvOpVariable) return false;
  180. uint32_t storage_class =
  181. var_inst->GetSingleWordInOperand(kSpvVariableStorageClassInIdx);
  182. switch (storage_class) {
  183. case SpvStorageClassUniform:
  184. case SpvStorageClassUniformConstant:
  185. case SpvStorageClassStorageBuffer:
  186. break;
  187. default:
  188. return false;
  189. break;
  190. }
  191. Instruction* desc_type_inst = GetPointeeTypeInst(var_inst);
  192. switch (desc_type_inst->opcode()) {
  193. case SpvOpTypeArray:
  194. case SpvOpTypeRuntimeArray:
  195. // A load through a descriptor array will have at least 3 operands. We
  196. // do not want to instrument loads of descriptors here which are part of
  197. // an image-based reference.
  198. if (ptr_inst->NumInOperands() < 3) return false;
  199. ref->desc_idx_id =
  200. ptr_inst->GetSingleWordInOperand(kSpvAccessChainIndex0IdInIdx);
  201. break;
  202. default:
  203. ref->desc_idx_id = 0;
  204. break;
  205. }
  206. return true;
  207. }
  208. // Reference is not load or store. If not an image-based reference, return.
  209. ref->image_id = GetImageId(ref_inst);
  210. if (ref->image_id == 0) return false;
  211. Instruction* image_inst = get_def_use_mgr()->GetDef(ref->image_id);
  212. Instruction* desc_load_inst = nullptr;
  213. if (image_inst->opcode() == SpvOp::SpvOpSampledImage) {
  214. ref->desc_load_id =
  215. image_inst->GetSingleWordInOperand(kSpvSampledImageImageIdInIdx);
  216. desc_load_inst = get_def_use_mgr()->GetDef(ref->desc_load_id);
  217. } else if (image_inst->opcode() == SpvOp::SpvOpImage) {
  218. ref->desc_load_id =
  219. image_inst->GetSingleWordInOperand(kSpvImageSampledImageIdInIdx);
  220. desc_load_inst = get_def_use_mgr()->GetDef(ref->desc_load_id);
  221. } else {
  222. ref->desc_load_id = ref->image_id;
  223. desc_load_inst = image_inst;
  224. ref->image_id = 0;
  225. }
  226. if (desc_load_inst->opcode() != SpvOp::SpvOpLoad) {
  227. // TODO(greg-lunarg): Handle additional possibilities?
  228. return false;
  229. }
  230. ref->ptr_id = desc_load_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx);
  231. Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
  232. if (ptr_inst->opcode() == SpvOp::SpvOpVariable) {
  233. ref->desc_idx_id = 0;
  234. ref->var_id = ref->ptr_id;
  235. } else if (ptr_inst->opcode() == SpvOp::SpvOpAccessChain) {
  236. if (ptr_inst->NumInOperands() != 2) {
  237. assert(false && "unexpected bindless index number");
  238. return false;
  239. }
  240. ref->desc_idx_id =
  241. ptr_inst->GetSingleWordInOperand(kSpvAccessChainIndex0IdInIdx);
  242. ref->var_id = ptr_inst->GetSingleWordInOperand(kSpvAccessChainBaseIdInIdx);
  243. Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
  244. if (var_inst->opcode() != SpvOpVariable) {
  245. assert(false && "unexpected bindless base");
  246. return false;
  247. }
  248. } else {
  249. // TODO(greg-lunarg): Handle additional possibilities?
  250. return false;
  251. }
  252. return true;
  253. }
  254. uint32_t InstBindlessCheckPass::FindStride(uint32_t ty_id,
  255. uint32_t stride_deco) {
  256. uint32_t stride = 0xdeadbeef;
  257. bool found = !get_decoration_mgr()->WhileEachDecoration(
  258. ty_id, stride_deco, [&stride](const Instruction& deco_inst) {
  259. stride = deco_inst.GetSingleWordInOperand(2u);
  260. return false;
  261. });
  262. USE_ASSERT(found && "stride not found");
  263. return stride;
  264. }
  265. uint32_t InstBindlessCheckPass::ByteSize(uint32_t ty_id) {
  266. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  267. const analysis::Type* sz_ty = type_mgr->GetType(ty_id);
  268. if (sz_ty->kind() == analysis::Type::kPointer) {
  269. // Assuming PhysicalStorageBuffer pointer
  270. return 8;
  271. }
  272. uint32_t size = 1;
  273. if (sz_ty->kind() == analysis::Type::kMatrix) {
  274. const analysis::Matrix* m_ty = sz_ty->AsMatrix();
  275. size = m_ty->element_count() * size;
  276. uint32_t stride = FindStride(ty_id, SpvDecorationMatrixStride);
  277. if (stride != 0) return size * stride;
  278. sz_ty = m_ty->element_type();
  279. }
  280. if (sz_ty->kind() == analysis::Type::kVector) {
  281. const analysis::Vector* v_ty = sz_ty->AsVector();
  282. size = v_ty->element_count() * size;
  283. sz_ty = v_ty->element_type();
  284. }
  285. switch (sz_ty->kind()) {
  286. case analysis::Type::kFloat: {
  287. const analysis::Float* f_ty = sz_ty->AsFloat();
  288. size *= f_ty->width();
  289. } break;
  290. case analysis::Type::kInteger: {
  291. const analysis::Integer* i_ty = sz_ty->AsInteger();
  292. size *= i_ty->width();
  293. } break;
  294. default: { assert(false && "unexpected type"); } break;
  295. }
  296. size /= 8;
  297. return size;
  298. }
// Generates code computing the buffer-relative byte offset of the LAST byte
// touched by the reference described by |ref|, walking its access chain and
// summing per-level offsets (array/matrix strides, vector component sizes,
// struct member Offset decorations). Returns the id of the computed offset.
uint32_t InstBindlessCheckPass::GenLastByteIdx(ref_analysis* ref,
                                               InstructionBuilder* builder) {
  // Find outermost buffer type and its access chain index
  Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
  Instruction* desc_ty_inst = GetPointeeTypeInst(var_inst);
  uint32_t buff_ty_id;
  uint32_t ac_in_idx = 1;
  switch (desc_ty_inst->opcode()) {
    case SpvOpTypeArray:
    case SpvOpTypeRuntimeArray:
      // Descriptor array: first chain index selects the descriptor, so the
      // buffer type is the element type and offsets start at index 2.
      buff_ty_id = desc_ty_inst->GetSingleWordInOperand(0);
      ++ac_in_idx;
      break;
    default:
      assert(desc_ty_inst->opcode() == SpvOpTypeStruct &&
             "unexpected descriptor type");
      buff_ty_id = desc_ty_inst->result_id();
      break;
  }
  // Process remaining access chain indices
  Instruction* ac_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
  uint32_t curr_ty_id = buff_ty_id;
  uint32_t sum_id = 0;  // running byte-offset expression; 0 = none yet
  while (ac_in_idx < ac_inst->NumInOperands()) {
    uint32_t curr_idx_id = ac_inst->GetSingleWordInOperand(ac_in_idx);
    Instruction* curr_idx_inst = get_def_use_mgr()->GetDef(curr_idx_id);
    Instruction* curr_ty_inst = get_def_use_mgr()->GetDef(curr_ty_id);
    uint32_t curr_offset_id = 0;
    switch (curr_ty_inst->opcode()) {
      case SpvOpTypeArray:
      case SpvOpTypeRuntimeArray:
      case SpvOpTypeMatrix: {
        // Get array/matrix stride and multiply by current index
        uint32_t stride_deco = (curr_ty_inst->opcode() == SpvOpTypeMatrix)
                                   ? SpvDecorationMatrixStride
                                   : SpvDecorationArrayStride;
        uint32_t arr_stride = FindStride(curr_ty_id, stride_deco);
        uint32_t arr_stride_id = builder->GetUintConstantId(arr_stride);
        Instruction* curr_offset_inst = builder->AddBinaryOp(
            GetUintId(), SpvOpIMul, arr_stride_id, curr_idx_id);
        curr_offset_id = curr_offset_inst->result_id();
        // Get element type for next step
        curr_ty_id = curr_ty_inst->GetSingleWordInOperand(0);
      } break;
      case SpvOpTypeVector: {
        // Stride is size of component type
        uint32_t comp_ty_id = curr_ty_inst->GetSingleWordInOperand(0u);
        uint32_t vec_stride = ByteSize(comp_ty_id);
        uint32_t vec_stride_id = builder->GetUintConstantId(vec_stride);
        Instruction* curr_offset_inst = builder->AddBinaryOp(
            GetUintId(), SpvOpIMul, vec_stride_id, curr_idx_id);
        curr_offset_id = curr_offset_inst->result_id();
        // Get element type for next step
        curr_ty_id = comp_ty_id;
      } break;
      case SpvOpTypeStruct: {
        // Get buffer byte offset for the referenced member
        // Struct indices must be compile-time constants per SPIR-V, so the
        // member offset can be resolved statically from Offset decorations.
        assert(curr_idx_inst->opcode() == SpvOpConstant &&
               "unexpected struct index");
        uint32_t member_idx = curr_idx_inst->GetSingleWordInOperand(0);
        uint32_t member_offset = 0xdeadbeef;
        bool found = !get_decoration_mgr()->WhileEachDecoration(
            curr_ty_id, SpvDecorationOffset,
            [&member_idx, &member_offset](const Instruction& deco_inst) {
              if (deco_inst.GetSingleWordInOperand(1u) != member_idx)
                return true;
              member_offset = deco_inst.GetSingleWordInOperand(3u);
              return false;
            });
        USE_ASSERT(found && "member offset not found");
        curr_offset_id = builder->GetUintConstantId(member_offset);
        // Get element type for next step
        curr_ty_id = curr_ty_inst->GetSingleWordInOperand(member_idx);
      } break;
      default: { assert(false && "unexpected non-composite type"); } break;
    }
    if (sum_id == 0)
      sum_id = curr_offset_id;
    else {
      Instruction* sum_inst =
          builder->AddBinaryOp(GetUintId(), SpvOpIAdd, sum_id, curr_offset_id);
      sum_id = sum_inst->result_id();
    }
    ++ac_in_idx;
  }
  // Add in offset of last byte of referenced object
  uint32_t bsize = ByteSize(curr_ty_id);
  uint32_t last = bsize - 1;
  uint32_t last_id = builder->GetUintConstantId(last);
  Instruction* sum_inst =
      builder->AddBinaryOp(GetUintId(), SpvOpIAdd, sum_id, last_id);
  return sum_inst->result_id();
}
// Generates the conditional-branch skeleton around the instrumented
// reference: the valid branch re-executes a clone of the original reference;
// the invalid branch writes an error record (|error_id|, plus |offset_id| /
// |length_id| payload depending on the error mode) to the debug output
// stream; a merge block Phis the clone's result with null when the original
// reference produced a value. The original reference is killed at the end.
void InstBindlessCheckPass::GenCheckCode(
    uint32_t check_id, uint32_t error_id, uint32_t offset_id,
    uint32_t length_id, uint32_t stage_idx, ref_analysis* ref,
    std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  BasicBlock* back_blk_ptr = &*new_blocks->back();
  InstructionBuilder builder(
      context(), back_blk_ptr,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  // Gen conditional branch on check_id. Valid branch generates original
  // reference. Invalid generates debug output and zero result (if needed).
  uint32_t merge_blk_id = TakeNextId();
  uint32_t valid_blk_id = TakeNextId();
  uint32_t invalid_blk_id = TakeNextId();
  std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
  std::unique_ptr<Instruction> valid_label(NewLabel(valid_blk_id));
  std::unique_ptr<Instruction> invalid_label(NewLabel(invalid_blk_id));
  (void)builder.AddConditionalBranch(check_id, valid_blk_id, invalid_blk_id,
                                     merge_blk_id, SpvSelectionControlMaskNone);
  // Gen valid bounds branch
  std::unique_ptr<BasicBlock> new_blk_ptr(
      new BasicBlock(std::move(valid_label)));
  builder.SetInsertPoint(&*new_blk_ptr);
  uint32_t new_ref_id = CloneOriginalReference(ref, &builder);
  (void)builder.AddBranch(merge_blk_id);
  new_blocks->push_back(std::move(new_blk_ptr));
  // Gen invalid block
  new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
  builder.SetInsertPoint(&*new_blk_ptr);
  uint32_t u_index_id = GenUintCastCode(ref->desc_idx_id, &builder);
  if (offset_id != 0)
    // Buffer-overrun error: record both offset and length.
    GenDebugStreamWrite(uid2offset_[ref->ref_inst->unique_id()], stage_idx,
                        {error_id, u_index_id, offset_id, length_id}, &builder);
  else if (buffer_bounds_enabled_)
    // So all error modes will use same debug stream write function
    GenDebugStreamWrite(
        uid2offset_[ref->ref_inst->unique_id()], stage_idx,
        {error_id, u_index_id, length_id, builder.GetUintConstantId(0)},
        &builder);
  else
    GenDebugStreamWrite(uid2offset_[ref->ref_inst->unique_id()], stage_idx,
                        {error_id, u_index_id, length_id}, &builder);
  // Remember last invalid block id
  uint32_t last_invalid_blk_id = new_blk_ptr->GetLabelInst()->result_id();
  // Gen zero for invalid reference
  uint32_t ref_type_id = ref->ref_inst->type_id();
  (void)builder.AddBranch(merge_blk_id);
  new_blocks->push_back(std::move(new_blk_ptr));
  // Gen merge block
  new_blk_ptr.reset(new BasicBlock(std::move(merge_label)));
  builder.SetInsertPoint(&*new_blk_ptr);
  // Gen phi of new reference and zero, if necessary, and replace the
  // result id of the original reference with that of the Phi. Kill original
  // reference.
  if (new_ref_id != 0) {
    Instruction* phi_inst = builder.AddPhi(
        ref_type_id, {new_ref_id, valid_blk_id, GetNullId(ref_type_id),
                      last_invalid_blk_id});
    context()->ReplaceAllUsesWith(ref->ref_inst->result_id(),
                                  phi_inst->result_id());
  }
  new_blocks->push_back(std::move(new_blk_ptr));
  context()->KillInst(ref->ref_inst);
}
// Instruments the reference at |ref_inst_itr| with a descriptor-index
// bounds check: index < array length. For compile-time constant index and
// length that already satisfy the check, no code is emitted. For runtime
// arrays, the length is read from the stage's debug input buffer. New
// blocks implementing the check are appended to |new_blocks|.
void InstBindlessCheckPass::GenDescIdxCheckCode(
    BasicBlock::iterator ref_inst_itr,
    UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
    std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  // Look for reference through indexed descriptor. If found, analyze and
  // save components. If not, return.
  ref_analysis ref;
  if (!AnalyzeDescriptorReference(&*ref_inst_itr, &ref)) return;
  Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref.ptr_id);
  if (ptr_inst->opcode() != SpvOp::SpvOpAccessChain) return;
  // If index and bound both compile-time constants and index < bound,
  // return without changing
  Instruction* var_inst = get_def_use_mgr()->GetDef(ref.var_id);
  Instruction* desc_type_inst = GetPointeeTypeInst(var_inst);
  uint32_t length_id = 0;
  if (desc_type_inst->opcode() == SpvOpTypeArray) {
    length_id =
        desc_type_inst->GetSingleWordInOperand(kSpvTypeArrayLengthIdInIdx);
    Instruction* index_inst = get_def_use_mgr()->GetDef(ref.desc_idx_id);
    Instruction* length_inst = get_def_use_mgr()->GetDef(length_id);
    if (index_inst->opcode() == SpvOpConstant &&
        length_inst->opcode() == SpvOpConstant &&
        index_inst->GetSingleWordInOperand(kSpvConstantValueInIdx) <
            length_inst->GetSingleWordInOperand(kSpvConstantValueInIdx))
      return;
  } else if (!desc_idx_enabled_ ||
             desc_type_inst->opcode() != SpvOpTypeRuntimeArray) {
    // Runtime arrays require descriptor-index checking to be enabled; any
    // other type is not an indexed descriptor reference.
    return;
  }
  // Move original block's preceding instructions into first new block
  std::unique_ptr<BasicBlock> new_blk_ptr;
  MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
  InstructionBuilder builder(
      context(), &*new_blk_ptr,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  new_blocks->push_back(std::move(new_blk_ptr));
  uint32_t error_id = builder.GetUintConstantId(kInstErrorBindlessBounds);
  // If length id not yet set, descriptor array is runtime size so
  // generate load of length from stage's debug input buffer.
  if (length_id == 0) {
    assert(desc_type_inst->opcode() == SpvOpTypeRuntimeArray &&
           "unexpected bindless type");
    length_id = GenDebugReadLength(ref.var_id, &builder);
  }
  // Generate full runtime bounds test code with true branch
  // being full reference and false branch being debug output and zero
  // for the referenced value.
  Instruction* ult_inst = builder.AddBinaryOp(GetBoolId(), SpvOpULessThan,
                                              ref.desc_idx_id, length_id);
  GenCheckCode(ult_inst->result_id(), error_id, 0u, length_id, stage_idx, &ref,
               new_blocks);
  // Move original block's remaining code into remainder/merge block and add
  // to new blocks
  BasicBlock* back_blk_ptr = &*new_blocks->back();
  MovePostludeCode(ref_block_itr, back_blk_ptr);
}
// Instruments the reference at |ref_inst_itr| with either a descriptor
// initialization check (reference value 0 < init word) or a buffer bounds
// check (last referenced byte index < buffer size), depending on the kind of
// reference and which checks are enabled. New blocks implementing the check
// are appended to |new_blocks|.
void InstBindlessCheckPass::GenDescInitCheckCode(
    BasicBlock::iterator ref_inst_itr,
    UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
    std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  // Look for reference through descriptor. If not, return.
  ref_analysis ref;
  if (!AnalyzeDescriptorReference(&*ref_inst_itr, &ref)) return;
  // Determine if we can only do initialization check
  bool init_check = false;
  if (ref.desc_load_id != 0 || !buffer_bounds_enabled_) {
    // Image-based references (or bounds checking disabled) only get the
    // initialization check.
    init_check = true;
  } else {
    // For now, only do bounds check for non-aggregate types. Otherwise
    // just do descriptor initialization check.
    // TODO(greg-lunarg): Do bounds check for aggregate loads and stores
    Instruction* ref_ptr_inst = get_def_use_mgr()->GetDef(ref.ptr_id);
    Instruction* pte_type_inst = GetPointeeTypeInst(ref_ptr_inst);
    uint32_t pte_type_op = pte_type_inst->opcode();
    if (pte_type_op == SpvOpTypeArray || pte_type_op == SpvOpTypeRuntimeArray ||
        pte_type_op == SpvOpTypeStruct)
      init_check = true;
  }
  // If initialization check and not enabled, return
  if (init_check && !desc_init_enabled_) return;
  // Move original block's preceding instructions into first new block
  std::unique_ptr<BasicBlock> new_blk_ptr;
  MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
  InstructionBuilder builder(
      context(), &*new_blk_ptr,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  new_blocks->push_back(std::move(new_blk_ptr));
  // If initialization check, use reference value of zero.
  // Else use the index of the last byte referenced.
  uint32_t ref_id = init_check ? builder.GetUintConstantId(0u)
                               : GenLastByteIdx(&ref, &builder);
  // Read initialization/bounds from debug input buffer. If index id not yet
  // set, binding is single descriptor, so set index to constant 0.
  if (ref.desc_idx_id == 0) ref.desc_idx_id = builder.GetUintConstantId(0u);
  uint32_t init_id = GenDebugReadInit(ref.var_id, ref.desc_idx_id, &builder);
  // Generate runtime initialization/bounds test code with true branch
  // being full reference and false branch being debug output and zero
  // for the referenced value.
  Instruction* ult_inst =
      builder.AddBinaryOp(GetBoolId(), SpvOpULessThan, ref_id, init_id);
  uint32_t error =
      init_check ? kInstErrorBindlessUninit : kInstErrorBindlessBuffOOB;
  uint32_t error_id = builder.GetUintConstantId(error);
  GenCheckCode(ult_inst->result_id(), error_id, init_check ? 0 : ref_id,
               init_check ? builder.GetUintConstantId(0u) : init_id, stage_idx,
               &ref, new_blocks);
  // Move original block's remaining code into remainder/merge block and add
  // to new blocks
  BasicBlock* back_blk_ptr = &*new_blocks->back();
  MovePostludeCode(ref_block_itr, back_blk_ptr);
}
  566. void InstBindlessCheckPass::InitializeInstBindlessCheck() {
  567. // Initialize base class
  568. InitializeInstrument();
  569. // If runtime array length support enabled, create variable mappings. Length
  570. // support is always enabled if descriptor init check is enabled.
  571. if (desc_idx_enabled_ || buffer_bounds_enabled_)
  572. for (auto& anno : get_module()->annotations())
  573. if (anno.opcode() == SpvOpDecorate) {
  574. if (anno.GetSingleWordInOperand(1u) == SpvDecorationDescriptorSet)
  575. var2desc_set_[anno.GetSingleWordInOperand(0u)] =
  576. anno.GetSingleWordInOperand(2u);
  577. else if (anno.GetSingleWordInOperand(1u) == SpvDecorationBinding)
  578. var2binding_[anno.GetSingleWordInOperand(0u)] =
  579. anno.GetSingleWordInOperand(2u);
  580. }
  581. }
  582. Pass::Status InstBindlessCheckPass::ProcessImpl() {
  583. // Perform bindless bounds check on each entry point function in module
  584. InstProcessFunction pfn =
  585. [this](BasicBlock::iterator ref_inst_itr,
  586. UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
  587. std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  588. return GenDescIdxCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
  589. new_blocks);
  590. };
  591. bool modified = InstProcessEntryPointCallTree(pfn);
  592. if (desc_init_enabled_ || buffer_bounds_enabled_) {
  593. // Perform descriptor initialization check on each entry point function in
  594. // module
  595. pfn = [this](BasicBlock::iterator ref_inst_itr,
  596. UptrVectorIterator<BasicBlock> ref_block_itr,
  597. uint32_t stage_idx,
  598. std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  599. return GenDescInitCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
  600. new_blocks);
  601. };
  602. modified |= InstProcessEntryPointCallTree(pfn);
  603. }
  604. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  605. }
// Pass entry point: build variable->(set, binding) maps, then instrument all
// entry-point call trees.
Pass::Status InstBindlessCheckPass::Process() {
  InitializeInstBindlessCheck();
  return ProcessImpl();
}
  610. } // namespace opt
  611. } // namespace spvtools