graphics_robust_access_pass.cpp 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // This pass injects code in a graphics shader to implement guarantees
  15. // satisfying Vulkan's robustBufferAccess rules. Robust access rules permit
  16. // an out-of-bounds access to be redirected to an access of the same type
  17. // (load, store, etc.) but within the same root object.
  18. //
  19. // We assume baseline functionality in Vulkan, i.e. the module uses
  20. // logical addressing mode, without VK_KHR_variable_pointers.
  21. //
  22. // - Logical addressing mode implies:
  23. // - Each root pointer (a pointer that exists other than by the
  24. // execution of a shader instruction) is the result of an OpVariable.
  25. //
  26. // - Instructions that result in pointers are:
  27. // OpVariable
  28. // OpAccessChain
  29. // OpInBoundsAccessChain
  30. // OpFunctionParameter
  31. // OpImageTexelPointer
  32. // OpCopyObject
  33. //
  34. // - Instructions that use a pointer are:
  35. // OpLoad
  36. // OpStore
  37. // OpAccessChain
  38. // OpInBoundsAccessChain
  39. // OpFunctionCall
  40. // OpImageTexelPointer
  41. // OpCopyMemory
  42. // OpCopyObject
  43. // all OpAtomic* instructions
  44. //
  45. // We classify pointer-users into:
  46. // - Accesses:
  47. // - OpLoad
  48. // - OpStore
  49. // - OpAtomic*
  50. // - OpCopyMemory
  51. //
  52. // - Address calculations:
  53. // - OpAccessChain
  54. // - OpInBoundsAccessChain
  55. //
  56. // - Pass-through:
  57. // - OpFunctionCall
  58. // - OpFunctionParameter
  59. // - OpCopyObject
  60. //
  61. // The strategy is:
  62. //
  63. // - Handle only logical addressing mode. In particular, don't handle a module
  64. // if it uses one of the variable-pointers capabilities.
  65. //
  66. // - Don't handle modules using capability RuntimeDescriptorArrayEXT. So the
  67. // only runtime arrays are those that are the last member in a
  68. // Block-decorated struct. This allows us to feasibly/easily compute the
  69. // length of the runtime array. See below.
  70. //
  71. // - The memory locations accessed by OpLoad, OpStore, OpCopyMemory, and
  72. // OpAtomic* are determined by their pointer parameter or parameters.
  73. // Pointers are always (correctly) typed and so the address and number of
  74. // consecutive locations are fully determined by the pointer.
  75. //
  76. // - A pointer value originates as one of a few cases:
  77. //
  78. // - OpVariable for an interface object or an array of them: image,
  79. // buffer (UBO or SSBO), sampler, sampled-image, push-constant, input
  80. // variable, output variable. The execution environment is responsible for
  81. // allocating the correct amount of storage for these, and for ensuring
  82. // each resource bound to such a variable is big enough to contain the
  83. // SPIR-V pointee type of the variable.
  84. //
  85. // - OpVariable for a non-interface object. These are variables in
  86. // Workgroup, Private, and Function storage classes. The compiler ensures
  87. // the underlying allocation is big enough to store the entire SPIR-V
  88. // pointee type of the variable.
  89. //
  90. // - An OpFunctionParameter. This always maps to a pointer parameter to an
  91. // OpFunctionCall.
  92. //
  93. // - In logical addressing mode, these are severely limited:
  94. // "Any pointer operand to an OpFunctionCall must be:
  95. // - a memory object declaration, or
  96. // - a pointer to an element in an array that is a memory object
  97. // declaration, where the element type is OpTypeSampler or OpTypeImage"
  98. //
  99. // - This has an important simplifying consequence:
  100. //
  101. // - When looking for a pointer to the structure containing a runtime
  102. // array, you begin with a pointer to the runtime array and trace
  103. // backward in the function. You never have to trace back beyond
  104. // your function call boundary. So you can't take a partial access
  105. // chain into an SSBO, then pass that pointer into a function. So
  106. // we don't resort to using fat pointers to compute array length.
  107. // We can trace back to a pointer to the containing structure,
  108. // and use that in an OpArrayLength instruction. (The structure type
  109. // gives us the member index of the runtime array.)
  110. //
  111. // - Otherwise, the pointer type fully encodes the range of valid
  112. // addresses. In particular, the type of a pointer to an aggregate
  113. // value fully encodes the range of indices when indexing into
  114. // that aggregate.
  115. //
  116. // - The pointer is the result of an access chain instruction. We clamp
  117. // indices contributing to address calculations. As noted above, the
  118. // valid ranges are either bound by the length of a runtime array, or
  119. // by the type of the base pointer. The length of a runtime array is
  120. // the result of an OpArrayLength instruction acting on the pointer of
  121. // the containing structure as noted above.
  122. //
  123. // - TODO(dneto): OpImageTexelPointer:
  124. // - Clamp coordinate to the image size returned by OpImageQuerySize
  125. // - If multi-sampled, clamp the sample index to the count returned by
  126. // OpImageQuerySamples.
  127. // - If not multi-sampled, set the sample index to 0.
  128. //
  129. // - Rely on the external validator to check that pointers are only
  130. // used by the instructions as above.
  131. //
  132. // - Handles OpTypeRuntimeArray
  133. // Track pointer back to original resource (pointer to struct), so we can
  134. // query the runtime array size.
  135. //
  136. #include "graphics_robust_access_pass.h"
  137. #include <algorithm>
  138. #include <cstring>
  139. #include <functional>
  140. #include <initializer_list>
  141. #include <utility>
  142. #include "constants.h"
  143. #include "def_use_manager.h"
  144. #include "function.h"
  145. #include "ir_context.h"
  146. #include "module.h"
  147. #include "pass.h"
  148. #include "source/diagnostic.h"
  149. #include "source/util/make_unique.h"
  150. #include "spirv-tools/libspirv.h"
  151. #include "spirv/unified1/GLSL.std.450.h"
  152. #include "spirv/unified1/spirv.h"
  153. #include "type_manager.h"
  154. #include "types.h"
  155. namespace spvtools {
  156. namespace opt {
  157. using opt::BasicBlock;
  158. using opt::Instruction;
  159. using opt::Operand;
  160. using spvtools::MakeUnique;
// Value-initializes the per-module state; Process() re-initializes it on entry.
GraphicsRobustAccessPass::GraphicsRobustAccessPass() : module_status_() {}
  162. Pass::Status GraphicsRobustAccessPass::Process() {
  163. module_status_ = PerModuleState();
  164. ProcessCurrentModule();
  165. auto result = module_status_.failed
  166. ? Status::Failure
  167. : (module_status_.modified ? Status::SuccessWithChange
  168. : Status::SuccessWithoutChange);
  169. return result;
  170. }
// Records that the pass failed, and returns a diagnostic stream the caller
// can append a message to with operator<<.
spvtools::DiagnosticStream GraphicsRobustAccessPass::Fail() {
  module_status_.failed = true;
  // We don't really have a position, and we'll ignore the result.
  return std::move(
      spvtools::DiagnosticStream({}, consumer(), "", SPV_ERROR_INVALID_BINARY)
      << name() << ": ");
}
  178. spv_result_t GraphicsRobustAccessPass::IsCompatibleModule() {
  179. auto* feature_mgr = context()->get_feature_mgr();
  180. if (!feature_mgr->HasCapability(SpvCapabilityShader))
  181. return Fail() << "Can only process Shader modules";
  182. if (feature_mgr->HasCapability(SpvCapabilityVariablePointers))
  183. return Fail() << "Can't process modules with VariablePointers capability";
  184. if (feature_mgr->HasCapability(SpvCapabilityVariablePointersStorageBuffer))
  185. return Fail() << "Can't process modules with VariablePointersStorageBuffer "
  186. "capability";
  187. if (feature_mgr->HasCapability(SpvCapabilityRuntimeDescriptorArrayEXT)) {
  188. // These have a RuntimeArray outside of Block-decorated struct. There
  189. // is no way to compute the array length from within SPIR-V.
  190. return Fail() << "Can't process modules with RuntimeDescriptorArrayEXT "
  191. "capability";
  192. }
  193. {
  194. auto* inst = context()->module()->GetMemoryModel();
  195. const auto addressing_model = inst->GetSingleWordOperand(0);
  196. if (addressing_model != SpvAddressingModelLogical)
  197. return Fail() << "Addressing model must be Logical. Found "
  198. << inst->PrettyPrint();
  199. }
  200. return SPV_SUCCESS;
  201. }
  202. spv_result_t GraphicsRobustAccessPass::ProcessCurrentModule() {
  203. auto err = IsCompatibleModule();
  204. if (err != SPV_SUCCESS) return err;
  205. ProcessFunction fn = [this](opt::Function* f) { return ProcessAFunction(f); };
  206. module_status_.modified |= context()->ProcessReachableCallTree(fn);
  207. // Need something here. It's the price we pay for easier failure paths.
  208. return SPV_SUCCESS;
  209. }
  210. bool GraphicsRobustAccessPass::ProcessAFunction(opt::Function* function) {
  211. // Ensure that all pointers computed inside a function are within bounds.
  212. // Find the access chains in this block before trying to modify them.
  213. std::vector<Instruction*> access_chains;
  214. std::vector<Instruction*> image_texel_pointers;
  215. for (auto& block : *function) {
  216. for (auto& inst : block) {
  217. switch (inst.opcode()) {
  218. case SpvOpAccessChain:
  219. case SpvOpInBoundsAccessChain:
  220. access_chains.push_back(&inst);
  221. break;
  222. case SpvOpImageTexelPointer:
  223. image_texel_pointers.push_back(&inst);
  224. break;
  225. default:
  226. break;
  227. }
  228. }
  229. }
  230. for (auto* inst : access_chains) {
  231. ClampIndicesForAccessChain(inst);
  232. }
  233. for (auto* inst : image_texel_pointers) {
  234. if (SPV_SUCCESS != ClampCoordinateForImageTexelPointer(inst)) break;
  235. }
  236. return module_status_.modified;
  237. }
// Rewrites the indices of |access_chain| so every address it computes is in
// bounds: each index is clamped to the valid range implied by the pointee
// type at that step (vector/matrix/array sizes, OpArrayLength for runtime
// arrays).  Struct member indices are validated rather than clamped, since
// SPIR-V requires them to be constants.  On an unrecoverable problem this
// calls Fail() and returns early.
void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
    Instruction* access_chain) {
  Instruction& inst = *access_chain;

  auto* constant_mgr = context()->get_constant_mgr();
  auto* def_use_mgr = context()->get_def_use_mgr();
  auto* type_mgr = context()->get_type_mgr();

  // Replaces one of the OpAccessChain index operands with a new value.
  // Updates def-use analysis.
  auto replace_index = [&inst, def_use_mgr](uint32_t operand_index,
                                            Instruction* new_value) {
    inst.SetOperand(operand_index, {new_value->result_id()});
    def_use_mgr->AnalyzeInstUse(&inst);
  };

  // Replaces one of the OpAccessChain index operands with a clamped value.
  // Replace the operand at |operand_index| with the value computed from
  // unsigned_clamp(%old_value, %min_value, %max_value).  It also analyzes
  // the new instruction and records that the module is modified.
  auto clamp_index = [&inst, this, &replace_index](
                         uint32_t operand_index, Instruction* old_value,
                         Instruction* min_value, Instruction* max_value) {
    auto* clamp_inst = MakeClampInst(old_value, min_value, max_value, &inst);
    replace_index(operand_index, clamp_inst);
  };

  // Ensures the specified index of access chain |inst| has a value that is
  // at most |count| - 1.  If the index is already a constant value less than
  // |count| then no change is made.
  auto clamp_to_literal_count = [&inst, this, &constant_mgr, &type_mgr,
                                 &replace_index, &clamp_index](
                                    uint32_t operand_index, uint64_t count) {
    Instruction* index_inst =
        this->GetDef(inst.GetSingleWordOperand(operand_index));
    const auto* index_type =
        type_mgr->GetType(index_inst->type_id())->AsInteger();
    assert(index_type);
    if (count <= 1) {
      // Only index 0 can be valid: replace the index with constant 0.
      replace_index(operand_index, GetValueForType(0, index_type));
      return;
    }
    const auto index_width = index_type->width();

    // If the index is a constant then |index_constant| will not be a null
    // pointer.  (An OpConstantNull index also yields a non-null
    // |index_constant|.)  Since access chain indices must be scalar
    // integers, this can't be a spec constant.
    if (auto* index_constant = constant_mgr->GetConstantFromInst(index_inst)) {
      auto* int_index_constant = index_constant->AsIntConstant();
      int64_t value = 0;
      // OpAccessChain indices are treated as signed.  So get the signed
      // constant value here.
      if (index_width <= 32) {
        value = int64_t(int_index_constant->GetS32BitValue());
      } else if (index_width <= 64) {
        value = int_index_constant->GetS64BitValue();
      } else {
        this->Fail() << "Can't handle indices wider than 64 bits, found "
                        "constant index with "
                     << index_type->width() << "bits";
        return;
      }
      if (value < 0) {
        // Negative constants clamp to 0.
        replace_index(operand_index, GetValueForType(0, index_type));
      } else if (uint64_t(value) < count) {
        // Nothing to do.
        return;
      } else {
        // Replace with count - 1.
        assert(count > 0);  // Already took care of this case above.
        replace_index(operand_index, GetValueForType(count - 1, index_type));
      }
    } else {
      // Generate a clamp instruction.

      // Compute the bit width of a viable type to hold (count-1).
      const auto maxval = count - 1;
      const auto* maxval_type = index_type;
      // Look for a bit width, up to 64 bits wide, to fit maxval.
      uint32_t maxval_width = index_width;
      while ((maxval_width < 64) && (0 != (maxval >> maxval_width))) {
        maxval_width *= 2;
      }
      // Widen the index value if necessary.
      if (maxval_width > index_width) {
        // Find the wider type.  We only need this case if a constant (array)
        // bound is too big.  This never requires us to *add* a capability
        // declaration for Int64 because the existence of the array bound would
        // already have required that declaration.
        index_inst = WidenInteger(index_type->IsSigned(), maxval_width,
                                  index_inst, &inst);
        maxval_type = type_mgr->GetType(index_inst->type_id())->AsInteger();
      }
      // Finally, clamp the index to [0, maxval].
      clamp_index(operand_index, index_inst, GetValueForType(0, maxval_type),
                  GetValueForType(maxval, maxval_type));
    }
  };

  // Ensures the specified index of access chain |inst| has a value that is at
  // most the value of |count_inst| minus 1, where |count_inst| is treated as
  // an unsigned integer.
  auto clamp_to_count = [&inst, this, &constant_mgr, &clamp_to_literal_count,
                         &clamp_index, &type_mgr](uint32_t operand_index,
                                                  Instruction* count_inst) {
    Instruction* index_inst =
        this->GetDef(inst.GetSingleWordOperand(operand_index));
    const auto* index_type =
        type_mgr->GetType(index_inst->type_id())->AsInteger();
    const auto* count_type =
        type_mgr->GetType(count_inst->type_id())->AsInteger();
    assert(index_type);
    if (const auto* count_constant =
            constant_mgr->GetConstantFromInst(count_inst)) {
      // The count is known at compile time: delegate to the literal case.
      uint64_t value = 0;
      const auto width = count_constant->type()->AsInteger()->width();
      if (width <= 32) {
        value = count_constant->AsIntConstant()->GetU32BitValue();
      } else if (width <= 64) {
        value = count_constant->AsIntConstant()->GetU64BitValue();
      } else {
        this->Fail() << "Can't handle indices wider than 64 bits, found "
                        "constant index with "
                     << index_type->width() << "bits";
        return;
      }
      clamp_to_literal_count(operand_index, value);
    } else {
      // The count is only known at runtime.  Widen index and count to the
      // same width.
      const auto index_width = index_type->width();
      const auto count_width = count_type->width();
      const auto target_width = std::max(index_width, count_width);
      // UConvert requires the result type to have 0 signedness.  So enforce
      // that here.
      auto* wider_type = index_width < count_width ? count_type : index_type;
      if (index_type->width() < target_width) {
        // Access chain indices are treated as signed integers.
        index_inst = WidenInteger(true, target_width, index_inst, &inst);
      } else if (count_type->width() < target_width) {
        // Assume type sizes are treated as unsigned.
        count_inst = WidenInteger(false, target_width, count_inst, &inst);
      }
      // Compute count - 1.
      // It doesn't matter if 1 is signed or unsigned.
      auto* one = GetValueForType(1, wider_type);
      auto* count_minus_1 = InsertInst(
          &inst, SpvOpISub, type_mgr->GetId(wider_type), TakeNextId(),
          {{SPV_OPERAND_TYPE_ID, {count_inst->result_id()}},
           {SPV_OPERAND_TYPE_ID, {one->result_id()}}});
      clamp_index(operand_index, index_inst, GetValueForType(0, wider_type),
                  count_minus_1);
    }
  };

  const Instruction* base_inst = GetDef(inst.GetSingleWordInOperand(0));
  const Instruction* base_type = GetDef(base_inst->type_id());
  // The pointee type of the base pointer (operand 1 of the OpTypePointer).
  Instruction* pointee_type = GetDef(base_type->GetSingleWordInOperand(1));

  // Walk the indices from earliest to latest, replacing indices with a
  // clamped value, and updating the pointee_type.  The order matters for
  // the case when we have to compute the length of a runtime array.  In
  // that case the algorithm relies on the fact that the earlier indices
  // have already been clamped.
  const uint32_t num_operands = inst.NumOperands();
  // Operand 3 is the first index (operands 0-2 are result type, result id,
  // and base pointer).
  for (uint32_t idx = 3; !module_status_.failed && idx < num_operands; ++idx) {
    const uint32_t index_id = inst.GetSingleWordOperand(idx);
    Instruction* index_inst = GetDef(index_id);

    switch (pointee_type->opcode()) {
      case SpvOpTypeMatrix:  // Use column count
      case SpvOpTypeVector:  // Use component count
      {
        const uint32_t count = pointee_type->GetSingleWordOperand(2);
        clamp_to_literal_count(idx, count);
        pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
      } break;

      case SpvOpTypeArray: {
        // The array length can be a spec constant, so go through the general
        // case.
        Instruction* array_len = GetDef(pointee_type->GetSingleWordOperand(2));
        clamp_to_count(idx, array_len);
        pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
      } break;

      case SpvOpTypeStruct: {
        // SPIR-V requires the index to be an OpConstant.
        // We need to know the index literal value so we can compute the next
        // pointee type.
        if (index_inst->opcode() != SpvOpConstant ||
            !constant_mgr->GetConstantFromInst(index_inst)
                 ->type()
                 ->AsInteger()) {
          Fail() << "Member index into struct is not a constant integer: "
                 << index_inst->PrettyPrint(
                        SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)
                 << "\nin access chain: "
                 << inst.PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
          return;
        }
        const auto num_members = pointee_type->NumInOperands();
        const auto* index_constant =
            constant_mgr->GetConstantFromInst(index_inst);
        // Get the sign-extended value, since access index is always treated
        // as signed.
        const auto index_value = index_constant->GetSignExtendedValue();
        if (index_value < 0 || index_value >= num_members) {
          Fail() << "Member index " << index_value
                 << " is out of bounds for struct type: "
                 << pointee_type->PrettyPrint(
                        SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)
                 << "\nin access chain: "
                 << inst.PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
          return;
        }
        pointee_type = GetDef(pointee_type->GetSingleWordInOperand(
            static_cast<uint32_t>(index_value)));
        // No need to clamp this index.  We just checked that it's valid.
      } break;

      case SpvOpTypeRuntimeArray: {
        auto* array_len = MakeRuntimeArrayLengthInst(&inst, idx);
        if (!array_len) {  // We've already signaled an error.
          return;
        }
        clamp_to_count(idx, array_len);
        pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
      } break;

      default:
        Fail() << " Unhandled pointee type for access chain "
               << pointee_type->PrettyPrint(
                      SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
    }
  }
}
  462. uint32_t GraphicsRobustAccessPass::GetGlslInsts() {
  463. if (module_status_.glsl_insts_id == 0) {
  464. // This string serves double-duty as raw data for a string and for a vector
  465. // of 32-bit words
  466. const char glsl[] = "GLSL.std.450\0\0\0\0";
  467. const size_t glsl_str_byte_len = 16;
  468. // Use an existing import if we can.
  469. for (auto& inst : context()->module()->ext_inst_imports()) {
  470. const auto& name_words = inst.GetInOperand(0).words;
  471. if (0 == std::strncmp(reinterpret_cast<const char*>(name_words.data()),
  472. glsl, glsl_str_byte_len)) {
  473. module_status_.glsl_insts_id = inst.result_id();
  474. }
  475. }
  476. if (module_status_.glsl_insts_id == 0) {
  477. // Make a new import instruction.
  478. module_status_.glsl_insts_id = TakeNextId();
  479. std::vector<uint32_t> words(glsl_str_byte_len / sizeof(uint32_t));
  480. std::memcpy(words.data(), glsl, glsl_str_byte_len);
  481. auto import_inst = MakeUnique<Instruction>(
  482. context(), SpvOpExtInstImport, 0, module_status_.glsl_insts_id,
  483. std::initializer_list<Operand>{
  484. Operand{SPV_OPERAND_TYPE_LITERAL_STRING, std::move(words)}});
  485. Instruction* inst = import_inst.get();
  486. context()->module()->AddExtInstImport(std::move(import_inst));
  487. module_status_.modified = true;
  488. context()->AnalyzeDefUse(inst);
  489. // Reanalyze the feature list, since we added an extended instruction
  490. // set improt.
  491. context()->get_feature_mgr()->Analyze(context()->module());
  492. }
  493. }
  494. return module_status_.glsl_insts_id;
  495. }
  496. opt::Instruction* opt::GraphicsRobustAccessPass::GetValueForType(
  497. uint64_t value, const analysis::Integer* type) {
  498. auto* mgr = context()->get_constant_mgr();
  499. assert(type->width() <= 64);
  500. std::vector<uint32_t> words;
  501. words.push_back(uint32_t(value));
  502. if (type->width() > 32) {
  503. words.push_back(uint32_t(value >> 32u));
  504. }
  505. const auto* constant = mgr->GetConstant(type, words);
  506. return mgr->GetDefiningInstruction(
  507. constant, context()->get_type_mgr()->GetTypeInstruction(type));
  508. }
  509. opt::Instruction* opt::GraphicsRobustAccessPass::WidenInteger(
  510. bool sign_extend, uint32_t bit_width, Instruction* value,
  511. Instruction* before_inst) {
  512. analysis::Integer unsigned_type_for_query(bit_width, false);
  513. auto* type_mgr = context()->get_type_mgr();
  514. auto* unsigned_type = type_mgr->GetRegisteredType(&unsigned_type_for_query);
  515. auto type_id = context()->get_type_mgr()->GetId(unsigned_type);
  516. auto conversion_id = TakeNextId();
  517. auto* conversion = InsertInst(
  518. before_inst, (sign_extend ? SpvOpSConvert : SpvOpUConvert), type_id,
  519. conversion_id, {{SPV_OPERAND_TYPE_ID, {value->result_id()}}});
  520. return conversion;
  521. }
// Returns a new instruction clamping |x| to the range [|min|, |max|],
// inserted immediately before |where|.  Implemented as an OpExtInst
// GLSL.std.450 UClamp, so the comparison is unsigned.
// NOTE(review): callers appear to arrange operands so an unsigned clamp is
// correct (constant negative indices are replaced with 0 first) -- confirm
// this holds for non-constant signed index values.
Instruction* GraphicsRobustAccessPass::MakeClampInst(Instruction* x,
                                                     Instruction* min,
                                                     Instruction* max,
                                                     Instruction* where) {
  // Get IDs of instructions we'll be referencing. Evaluate them before calling
  // the function so we force a deterministic ordering in case both of them need
  // to take a new ID.
  const uint32_t glsl_insts_id = GetGlslInsts();
  uint32_t clamp_id = TakeNextId();
  // All three operands must share the same integer type.
  assert(x->type_id() == min->type_id());
  assert(x->type_id() == max->type_id());
  auto* clamp_inst = InsertInst(
      where, SpvOpExtInst, x->type_id(), clamp_id,
      {
          {SPV_OPERAND_TYPE_ID, {glsl_insts_id}},
          {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, {GLSLstd450UClamp}},
          {SPV_OPERAND_TYPE_ID, {x->result_id()}},
          {SPV_OPERAND_TYPE_ID, {min->result_id()}},
          {SPV_OPERAND_TYPE_ID, {max->result_id()}},
      });
  return clamp_inst;
}
  544. Instruction* GraphicsRobustAccessPass::MakeRuntimeArrayLengthInst(
  545. Instruction* access_chain, uint32_t operand_index) {
  546. // The Index parameter to the access chain at |operand_index| is indexing
  547. // *into* the runtime-array. To get the number of elements in the runtime
  548. // array we need a pointer to the Block-decorated struct that contains the
  549. // runtime array. So conceptually we have to go 2 steps backward in the
  550. // access chain. The two steps backward might forces us to traverse backward
  551. // across multiple dominating instructions.
  552. auto* type_mgr = context()->get_type_mgr();
  553. // How many access chain indices do we have to unwind to find the pointer
  554. // to the struct containing the runtime array?
  555. uint32_t steps_remaining = 2;
  556. // Find or create an instruction computing the pointer to the structure
  557. // containing the runtime array.
  558. // Walk backward through pointer address calculations until we either get
  559. // to exactly the right base pointer, or to an access chain instruction
  560. // that we can replicate but truncate to compute the address of the right
  561. // struct.
  562. Instruction* current_access_chain = access_chain;
  563. Instruction* pointer_to_containing_struct = nullptr;
  564. while (steps_remaining > 0) {
  565. switch (current_access_chain->opcode()) {
  566. case SpvOpCopyObject:
  567. // Whoops. Walk right through this one.
  568. current_access_chain =
  569. GetDef(current_access_chain->GetSingleWordInOperand(0));
  570. break;
  571. case SpvOpAccessChain:
  572. case SpvOpInBoundsAccessChain: {
  573. const int first_index_operand = 3;
  574. // How many indices in this access chain contribute to getting us
  575. // to an element in the runtime array?
  576. const auto num_contributing_indices =
  577. current_access_chain == access_chain
  578. ? operand_index - (first_index_operand - 1)
  579. : current_access_chain->NumInOperands() - 1 /* skip the base */;
  580. Instruction* base =
  581. GetDef(current_access_chain->GetSingleWordInOperand(0));
  582. if (num_contributing_indices == steps_remaining) {
  583. // The base pointer points to the structure.
  584. pointer_to_containing_struct = base;
  585. steps_remaining = 0;
  586. break;
  587. } else if (num_contributing_indices < steps_remaining) {
  588. // Peel off the index and keep going backward.
  589. steps_remaining -= num_contributing_indices;
  590. current_access_chain = base;
  591. } else {
  592. // This access chain has more indices than needed. Generate a new
  593. // access chain instruction, but truncating the list of indices.
  594. const int base_operand = 2;
  595. // We'll use the base pointer and the indices up to but not including
  596. // the one indexing into the runtime array.
  597. Instruction::OperandList ops;
  598. // Use the base pointer
  599. ops.push_back(current_access_chain->GetOperand(base_operand));
  600. const uint32_t num_indices_to_keep =
  601. num_contributing_indices - steps_remaining - 1;
  602. for (uint32_t i = 0; i <= num_indices_to_keep; i++) {
  603. ops.push_back(
  604. current_access_chain->GetOperand(first_index_operand + i));
  605. }
  606. // Compute the type of the result of the new access chain. Start at
  607. // the base and walk the indices in a forward direction.
  608. auto* constant_mgr = context()->get_constant_mgr();
  609. std::vector<uint32_t> indices_for_type;
  610. for (uint32_t i = 0; i < ops.size() - 1; i++) {
  611. uint32_t index_for_type_calculation = 0;
  612. Instruction* index =
  613. GetDef(current_access_chain->GetSingleWordOperand(
  614. first_index_operand + i));
  615. if (auto* index_constant =
  616. constant_mgr->GetConstantFromInst(index)) {
  617. // We only need 32 bits. For the type calculation, it's sufficient
  618. // to take the zero-extended value. It only matters for the struct
  619. // case, and struct member indices are unsigned.
  620. index_for_type_calculation =
  621. uint32_t(index_constant->GetZeroExtendedValue());
  622. } else {
  623. // Indexing into a variably-sized thing like an array. Use 0.
  624. index_for_type_calculation = 0;
  625. }
  626. indices_for_type.push_back(index_for_type_calculation);
  627. }
  628. auto* base_ptr_type = type_mgr->GetType(base->type_id())->AsPointer();
  629. auto* base_pointee_type = base_ptr_type->pointee_type();
  630. auto* new_access_chain_result_pointee_type =
  631. type_mgr->GetMemberType(base_pointee_type, indices_for_type);
  632. const uint32_t new_access_chain_type_id = type_mgr->FindPointerToType(
  633. type_mgr->GetId(new_access_chain_result_pointee_type),
  634. base_ptr_type->storage_class());
  635. // Create the instruction and insert it.
  636. const auto new_access_chain_id = TakeNextId();
  637. auto* new_access_chain =
  638. InsertInst(current_access_chain, current_access_chain->opcode(),
  639. new_access_chain_type_id, new_access_chain_id, ops);
  640. pointer_to_containing_struct = new_access_chain;
  641. steps_remaining = 0;
  642. break;
  643. }
  644. } break;
  645. default:
  646. Fail() << "Unhandled access chain in logical addressing mode passes "
  647. "through "
  648. << current_access_chain->PrettyPrint(
  649. SPV_BINARY_TO_TEXT_OPTION_SHOW_BYTE_OFFSET |
  650. SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  651. return nullptr;
  652. }
  653. }
  654. assert(pointer_to_containing_struct);
  655. auto* pointee_type =
  656. type_mgr->GetType(pointer_to_containing_struct->type_id())
  657. ->AsPointer()
  658. ->pointee_type();
  659. auto* struct_type = pointee_type->AsStruct();
  660. const uint32_t member_index_of_runtime_array =
  661. uint32_t(struct_type->element_types().size() - 1);
  662. // Create the length-of-array instruction before the original access chain,
  663. // but after the generation of the pointer to the struct.
  664. const auto array_len_id = TakeNextId();
  665. analysis::Integer uint_type_for_query(32, false);
  666. auto* uint_type = type_mgr->GetRegisteredType(&uint_type_for_query);
  667. auto* array_len = InsertInst(
  668. access_chain, SpvOpArrayLength, type_mgr->GetId(uint_type), array_len_id,
  669. {{SPV_OPERAND_TYPE_ID, {pointer_to_containing_struct->result_id()}},
  670. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {member_index_of_runtime_array}}});
  671. return array_len;
  672. }
  673. spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
  674. opt::Instruction* image_texel_pointer) {
  675. // TODO(dneto): Write tests for this code.
  676. return SPV_SUCCESS;
  677. // Example:
  678. // %texel_ptr = OpImageTexelPointer %texel_ptr_type %image_ptr %coord
  679. // %sample
  680. //
  681. // We want to clamp %coord components between vector-0 and the result
  682. // of OpImageQuerySize acting on the underlying image. So insert:
  683. // %image = OpLoad %image_type %image_ptr
  684. // %query_size = OpImageQuerySize %query_size_type %image
  685. //
  686. // For a multi-sampled image, %sample is the sample index, and we need
  687. // to clamp it between zero and the number of samples in the image.
  688. // %sample_count = OpImageQuerySamples %uint %image
  689. // %max_sample_index = OpISub %uint %sample_count %uint_1
  690. // For non-multi-sampled images, the sample index must be constant zero.
  691. auto* def_use_mgr = context()->get_def_use_mgr();
  692. auto* type_mgr = context()->get_type_mgr();
  693. auto* constant_mgr = context()->get_constant_mgr();
  694. auto* image_ptr = GetDef(image_texel_pointer->GetSingleWordInOperand(0));
  695. auto* image_ptr_type = GetDef(image_ptr->type_id());
  696. auto image_type_id = image_ptr_type->GetSingleWordInOperand(1);
  697. auto* image_type = GetDef(image_type_id);
  698. auto* coord = GetDef(image_texel_pointer->GetSingleWordInOperand(1));
  699. auto* samples = GetDef(image_texel_pointer->GetSingleWordInOperand(2));
  700. // We will modify the module, at least by adding image query instructions.
  701. module_status_.modified = true;
  702. // Declare the ImageQuery capability if the module doesn't already have it.
  703. auto* feature_mgr = context()->get_feature_mgr();
  704. if (!feature_mgr->HasCapability(SpvCapabilityImageQuery)) {
  705. auto cap = MakeUnique<Instruction>(
  706. context(), SpvOpCapability, 0, 0,
  707. std::initializer_list<Operand>{
  708. {SPV_OPERAND_TYPE_CAPABILITY, {SpvCapabilityImageQuery}}});
  709. def_use_mgr->AnalyzeInstDefUse(cap.get());
  710. context()->AddCapability(std::move(cap));
  711. feature_mgr->Analyze(context()->module());
  712. }
  713. // OpImageTexelPointer is used to translate a coordinate and sample index
  714. // into an address for use with an atomic operation. That is, it may only
  715. // used with what Vulkan calls a "storage image"
  716. // (OpTypeImage parameter Sampled=2).
  717. // Note: A storage image never has a level-of-detail associated with it.
  718. // Constraints on the sample id:
  719. // - Only 2D images can be multi-sampled: OpTypeImage parameter MS=1
  720. // only if Dim=2D.
  721. // - Non-multi-sampled images (OpTypeImage parameter MS=0) must use
  722. // sample ID to a constant 0.
  723. // The coordinate is treated as unsigned, and should be clamped against the
  724. // image "size", returned by OpImageQuerySize. (Note: OpImageQuerySizeLod
  725. // is only usable with a sampled image, i.e. its image type has Sampled=1).
  726. // Determine the result type for the OpImageQuerySize.
  727. // For non-arrayed images:
  728. // non-Cube:
  729. // - Always the same as the coordinate type
  730. // Cube:
  731. // - Use all but the last component of the coordinate (which is the face
  732. // index from 0 to 5).
  733. // For arrayed images (in Vulkan the Dim is 1D, 2D, or Cube):
  734. // non-Cube:
  735. // - A vector with the components in the coordinate, and one more for
  736. // the layer index.
  737. // Cube:
  738. // - The same as the coordinate type: 3-element integer vector.
  739. // - The third component from the size query is the layer count.
  740. // - The third component in the texel pointer calculation is
  741. // 6 * layer + face, where 0 <= face < 6.
  742. // Cube: Use all but the last component of the coordinate (which is the face
  743. // index from 0 to 5).
  744. const auto dim = SpvDim(image_type->GetSingleWordInOperand(1));
  745. const bool arrayed = image_type->GetSingleWordInOperand(3) == 1;
  746. const bool multisampled = image_type->GetSingleWordInOperand(4) != 0;
  747. const auto query_num_components = [dim, arrayed, this]() -> int {
  748. const int arrayness_bonus = arrayed ? 1 : 0;
  749. int num_coords = 0;
  750. switch (dim) {
  751. case SpvDimBuffer:
  752. case SpvDim1D:
  753. num_coords = 1;
  754. break;
  755. case SpvDimCube:
  756. // For cube, we need bounds for x, y, but not face.
  757. case SpvDimRect:
  758. case SpvDim2D:
  759. num_coords = 2;
  760. break;
  761. case SpvDim3D:
  762. num_coords = 3;
  763. break;
  764. case SpvDimSubpassData:
  765. case SpvDimMax:
  766. return Fail() << "Invalid image dimension for OpImageTexelPointer: "
  767. << int(dim);
  768. break;
  769. }
  770. return num_coords + arrayness_bonus;
  771. }();
  772. const auto* coord_component_type = [type_mgr, coord]() {
  773. const analysis::Type* coord_type = type_mgr->GetType(coord->type_id());
  774. if (auto* vector_type = coord_type->AsVector()) {
  775. return vector_type->element_type()->AsInteger();
  776. }
  777. return coord_type->AsInteger();
  778. }();
  779. // For now, only handle 32-bit case for coordinates.
  780. if (!coord_component_type) {
  781. return Fail() << " Coordinates for OpImageTexelPointer are not integral: "
  782. << image_texel_pointer->PrettyPrint(
  783. SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  784. }
  785. if (coord_component_type->width() != 32) {
  786. return Fail() << " Expected OpImageTexelPointer coordinate components to "
  787. "be 32-bits wide. They are "
  788. << coord_component_type->width() << " bits. "
  789. << image_texel_pointer->PrettyPrint(
  790. SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  791. }
  792. const auto* query_size_type =
  793. [type_mgr, coord_component_type,
  794. query_num_components]() -> const analysis::Type* {
  795. if (query_num_components == 1) return coord_component_type;
  796. analysis::Vector proposed(coord_component_type, query_num_components);
  797. return type_mgr->GetRegisteredType(&proposed);
  798. }();
  799. const uint32_t image_id = TakeNextId();
  800. auto* image =
  801. InsertInst(image_texel_pointer, SpvOpLoad, image_type_id, image_id,
  802. {{SPV_OPERAND_TYPE_ID, {image_ptr->result_id()}}});
  803. const uint32_t query_size_id = TakeNextId();
  804. auto* query_size =
  805. InsertInst(image_texel_pointer, SpvOpImageQuerySize,
  806. type_mgr->GetTypeInstruction(query_size_type), query_size_id,
  807. {{SPV_OPERAND_TYPE_ID, {image->result_id()}}});
  808. auto* component_1 = constant_mgr->GetConstant(coord_component_type, {1});
  809. const uint32_t component_1_id =
  810. constant_mgr->GetDefiningInstruction(component_1)->result_id();
  811. auto* component_0 = constant_mgr->GetConstant(coord_component_type, {0});
  812. const uint32_t component_0_id =
  813. constant_mgr->GetDefiningInstruction(component_0)->result_id();
  814. // If the image is a cube array, then the last component of the queried
  815. // size is the layer count. In the query, we have to accomodate folding
  816. // in the face index ranging from 0 through 5. The inclusive upper bound
  817. // on the third coordinate therefore is multiplied by 6.
  818. auto* query_size_including_faces = query_size;
  819. if (arrayed && (dim == SpvDimCube)) {
  820. // Multiply the last coordinate by 6.
  821. auto* component_6 = constant_mgr->GetConstant(coord_component_type, {6});
  822. const uint32_t component_6_id =
  823. constant_mgr->GetDefiningInstruction(component_6)->result_id();
  824. assert(query_num_components == 3);
  825. auto* multiplicand = constant_mgr->GetConstant(
  826. query_size_type, {component_1_id, component_1_id, component_6_id});
  827. auto* multiplicand_inst =
  828. constant_mgr->GetDefiningInstruction(multiplicand);
  829. const auto query_size_including_faces_id = TakeNextId();
  830. query_size_including_faces = InsertInst(
  831. image_texel_pointer, SpvOpIMul,
  832. type_mgr->GetTypeInstruction(query_size_type),
  833. query_size_including_faces_id,
  834. {{SPV_OPERAND_TYPE_ID, {query_size_including_faces->result_id()}},
  835. {SPV_OPERAND_TYPE_ID, {multiplicand_inst->result_id()}}});
  836. }
  837. // Make a coordinate-type with all 1 components.
  838. auto* coordinate_1 =
  839. query_num_components == 1
  840. ? component_1
  841. : constant_mgr->GetConstant(
  842. query_size_type,
  843. std::vector<uint32_t>(query_num_components, component_1_id));
  844. // Make a coordinate-type with all 1 components.
  845. auto* coordinate_0 =
  846. query_num_components == 0
  847. ? component_0
  848. : constant_mgr->GetConstant(
  849. query_size_type,
  850. std::vector<uint32_t>(query_num_components, component_0_id));
  851. const uint32_t query_max_including_faces_id = TakeNextId();
  852. auto* query_max_including_faces = InsertInst(
  853. image_texel_pointer, SpvOpISub,
  854. type_mgr->GetTypeInstruction(query_size_type),
  855. query_max_including_faces_id,
  856. {{SPV_OPERAND_TYPE_ID, {query_size_including_faces->result_id()}},
  857. {SPV_OPERAND_TYPE_ID,
  858. {constant_mgr->GetDefiningInstruction(coordinate_1)->result_id()}}});
  859. // Clamp the coordinate
  860. auto* clamp_coord =
  861. MakeClampInst(coord, constant_mgr->GetDefiningInstruction(coordinate_0),
  862. query_max_including_faces, image_texel_pointer);
  863. image_texel_pointer->SetInOperand(1, {clamp_coord->result_id()});
  864. // Clamp the sample index
  865. if (multisampled) {
  866. // Get the sample count via OpImageQuerySamples
  867. const auto query_samples_id = TakeNextId();
  868. auto* query_samples = InsertInst(
  869. image_texel_pointer, SpvOpImageQuerySamples,
  870. constant_mgr->GetDefiningInstruction(component_0)->type_id(),
  871. query_samples_id, {{SPV_OPERAND_TYPE_ID, {image->result_id()}}});
  872. const auto max_samples_id = TakeNextId();
  873. auto* max_samples = InsertInst(image_texel_pointer, SpvOpImageQuerySamples,
  874. query_samples->type_id(), max_samples_id,
  875. {{SPV_OPERAND_TYPE_ID, {query_samples_id}},
  876. {SPV_OPERAND_TYPE_ID, {component_1_id}}});
  877. auto* clamp_samples = MakeClampInst(
  878. samples, constant_mgr->GetDefiningInstruction(coordinate_0),
  879. max_samples, image_texel_pointer);
  880. image_texel_pointer->SetInOperand(2, {clamp_samples->result_id()});
  881. } else {
  882. // Just replace it with 0. Don't even check what was there before.
  883. image_texel_pointer->SetInOperand(2, {component_0_id});
  884. }
  885. def_use_mgr->AnalyzeInstUse(image_texel_pointer);
  886. return SPV_SUCCESS;
  887. }
  888. opt::Instruction* GraphicsRobustAccessPass::InsertInst(
  889. opt::Instruction* where_inst, SpvOp opcode, uint32_t type_id,
  890. uint32_t result_id, const Instruction::OperandList& operands) {
  891. module_status_.modified = true;
  892. auto* result = where_inst->InsertBefore(
  893. MakeUnique<Instruction>(context(), opcode, type_id, result_id, operands));
  894. context()->get_def_use_mgr()->AnalyzeInstDefUse(result);
  895. auto* basic_block = context()->get_instr_block(where_inst);
  896. context()->set_instr_block(result, basic_block);
  897. return result;
  898. }
  899. } // namespace opt
  900. } // namespace spvtools