convert_to_sampled_image_pass.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. // Copyright (c) 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/convert_to_sampled_image_pass.h"
  15. #include <cctype>
  16. #include <cstring>
  17. #include <tuple>
  18. #include "source/opt/ir_builder.h"
  19. #include "source/util/make_unique.h"
  20. #include "source/util/parse_number.h"
  21. namespace spvtools {
  22. namespace opt {
  23. using VectorOfDescriptorSetAndBindingPairs =
  24. std::vector<DescriptorSetAndBinding>;
  25. using DescriptorSetBindingToInstruction =
  26. ConvertToSampledImagePass::DescriptorSetBindingToInstruction;
  27. namespace {
  28. using utils::ParseNumber;
  // Returns true if |ch| is ':', '\0', or a character classified as blank
  // space by std::isspace (in the default "C" locale: '\n', '\r', '\v', '\t',
  // '\f' and ' ').
  //
  // |ch| is widened through unsigned char before calling std::isspace: passing
  // a negative value other than EOF (e.g. a non-ASCII byte on platforms where
  // char is signed) to std::isspace is undefined behavior.
  bool IsSeparator(char ch) {
    if (ch == ':' || ch == '\0') return true;
    return std::isspace(static_cast<unsigned char>(ch)) != 0;
  }
  34. // Reads characters starting from |str| until it meets a separator. Parses a
  35. // number from the characters and stores it into |number|. Returns the pointer
  36. // to the separator if it succeeds. Otherwise, returns nullptr.
  37. const char* ParseNumberUntilSeparator(const char* str, uint32_t* number) {
  38. const char* number_begin = str;
  39. while (!IsSeparator(*str)) str++;
  40. const char* number_end = str;
  41. std::string number_in_str(number_begin, number_end - number_begin);
  42. if (!utils::ParseNumber(number_in_str.c_str(), number)) {
  43. // The descriptor set is not a valid uint32 number.
  44. return nullptr;
  45. }
  46. return str;
  47. }
  48. // Returns id of the image type used for the sampled image type of
  49. // |sampled_image|.
  50. uint32_t GetImageTypeOfSampledImage(analysis::TypeManager* type_mgr,
  51. Instruction* sampled_image) {
  52. auto* sampled_image_type =
  53. type_mgr->GetType(sampled_image->type_id())->AsSampledImage();
  54. return type_mgr->GetTypeInstruction(sampled_image_type->image_type());
  55. }
  56. // Finds the instruction whose id is |inst_id|. Follows the operand of
  57. // OpCopyObject recursively if the opcode of the instruction is OpCopyObject
  58. // and returns the first instruction that does not have OpCopyObject as opcode.
  59. Instruction* GetNonCopyObjectDef(analysis::DefUseManager* def_use_mgr,
  60. uint32_t inst_id) {
  61. Instruction* inst = def_use_mgr->GetDef(inst_id);
  62. while (inst->opcode() == SpvOpCopyObject) {
  63. inst_id = inst->GetSingleWordInOperand(0u);
  64. inst = def_use_mgr->GetDef(inst_id);
  65. }
  66. return inst;
  67. }
  68. } // namespace
  69. bool ConvertToSampledImagePass::GetDescriptorSetBinding(
  70. const Instruction& inst,
  71. DescriptorSetAndBinding* descriptor_set_binding) const {
  72. auto* decoration_manager = context()->get_decoration_mgr();
  73. bool found_descriptor_set_to_convert = false;
  74. bool found_binding_to_convert = false;
  75. for (auto decorate :
  76. decoration_manager->GetDecorationsFor(inst.result_id(), false)) {
  77. uint32_t decoration = decorate->GetSingleWordInOperand(1u);
  78. if (decoration == SpvDecorationDescriptorSet) {
  79. if (found_descriptor_set_to_convert) {
  80. assert(false && "A resource has two OpDecorate for the descriptor set");
  81. return false;
  82. }
  83. descriptor_set_binding->descriptor_set =
  84. decorate->GetSingleWordInOperand(2u);
  85. found_descriptor_set_to_convert = true;
  86. } else if (decoration == SpvDecorationBinding) {
  87. if (found_binding_to_convert) {
  88. assert(false && "A resource has two OpDecorate for the binding");
  89. return false;
  90. }
  91. descriptor_set_binding->binding = decorate->GetSingleWordInOperand(2u);
  92. found_binding_to_convert = true;
  93. }
  94. }
  95. return found_descriptor_set_to_convert && found_binding_to_convert;
  96. }
  97. bool ConvertToSampledImagePass::ShouldResourceBeConverted(
  98. const DescriptorSetAndBinding& descriptor_set_binding) const {
  99. return descriptor_set_binding_pairs_.find(descriptor_set_binding) !=
  100. descriptor_set_binding_pairs_.end();
  101. }
  102. const analysis::Type* ConvertToSampledImagePass::GetVariableType(
  103. const Instruction& variable) const {
  104. if (variable.opcode() != SpvOpVariable) return nullptr;
  105. auto* type = context()->get_type_mgr()->GetType(variable.type_id());
  106. auto* pointer_type = type->AsPointer();
  107. if (!pointer_type) return nullptr;
  108. return pointer_type->pointee_type();
  109. }
  110. SpvStorageClass ConvertToSampledImagePass::GetStorageClass(
  111. const Instruction& variable) const {
  112. assert(variable.opcode() == SpvOpVariable);
  113. auto* type = context()->get_type_mgr()->GetType(variable.type_id());
  114. auto* pointer_type = type->AsPointer();
  115. if (!pointer_type) return SpvStorageClassMax;
  116. return pointer_type->storage_class();
  117. }
  118. bool ConvertToSampledImagePass::CollectResourcesToConvert(
  119. DescriptorSetBindingToInstruction* descriptor_set_binding_pair_to_sampler,
  120. DescriptorSetBindingToInstruction* descriptor_set_binding_pair_to_image)
  121. const {
  122. for (auto& inst : context()->types_values()) {
  123. const auto* variable_type = GetVariableType(inst);
  124. if (variable_type == nullptr) continue;
  125. DescriptorSetAndBinding descriptor_set_binding;
  126. if (!GetDescriptorSetBinding(inst, &descriptor_set_binding)) continue;
  127. if (!ShouldResourceBeConverted(descriptor_set_binding)) {
  128. continue;
  129. }
  130. if (variable_type->AsImage()) {
  131. if (!descriptor_set_binding_pair_to_image
  132. ->insert({descriptor_set_binding, &inst})
  133. .second) {
  134. return false;
  135. }
  136. } else if (variable_type->AsSampler()) {
  137. if (!descriptor_set_binding_pair_to_sampler
  138. ->insert({descriptor_set_binding, &inst})
  139. .second) {
  140. return false;
  141. }
  142. }
  143. }
  144. return true;
  145. }
  146. Pass::Status ConvertToSampledImagePass::Process() {
  147. Status status = Status::SuccessWithoutChange;
  148. DescriptorSetBindingToInstruction descriptor_set_binding_pair_to_sampler,
  149. descriptor_set_binding_pair_to_image;
  150. if (!CollectResourcesToConvert(&descriptor_set_binding_pair_to_sampler,
  151. &descriptor_set_binding_pair_to_image)) {
  152. return Status::Failure;
  153. }
  154. for (auto& image : descriptor_set_binding_pair_to_image) {
  155. status = CombineStatus(
  156. status, UpdateImageVariableToSampledImage(image.second, image.first));
  157. if (status == Status::Failure) {
  158. return status;
  159. }
  160. }
  161. for (const auto& sampler : descriptor_set_binding_pair_to_sampler) {
  162. // Converting only a Sampler to Sampled Image is not allowed. It must have a
  163. // corresponding image to combine the sampler with.
  164. auto image_itr = descriptor_set_binding_pair_to_image.find(sampler.first);
  165. if (image_itr == descriptor_set_binding_pair_to_image.end() ||
  166. image_itr->second == nullptr) {
  167. return Status::Failure;
  168. }
  169. status = CombineStatus(
  170. status, CheckUsesOfSamplerVariable(sampler.second, image_itr->second));
  171. if (status == Status::Failure) {
  172. return status;
  173. }
  174. }
  175. return status;
  176. }
  177. void ConvertToSampledImagePass::FindUses(const Instruction* inst,
  178. std::vector<Instruction*>* uses,
  179. uint32_t user_opcode) const {
  180. auto* def_use_mgr = context()->get_def_use_mgr();
  181. def_use_mgr->ForEachUser(inst, [uses, user_opcode, this](Instruction* user) {
  182. if (user->opcode() == user_opcode) {
  183. uses->push_back(user);
  184. } else if (user->opcode() == SpvOpCopyObject) {
  185. FindUses(user, uses, user_opcode);
  186. }
  187. });
  188. }
  189. void ConvertToSampledImagePass::FindUsesOfImage(
  190. const Instruction* image, std::vector<Instruction*>* uses) const {
  191. auto* def_use_mgr = context()->get_def_use_mgr();
  192. def_use_mgr->ForEachUser(image, [uses, this](Instruction* user) {
  193. switch (user->opcode()) {
  194. case SpvOpImageFetch:
  195. case SpvOpImageRead:
  196. case SpvOpImageWrite:
  197. case SpvOpImageQueryFormat:
  198. case SpvOpImageQueryOrder:
  199. case SpvOpImageQuerySizeLod:
  200. case SpvOpImageQuerySize:
  201. case SpvOpImageQueryLevels:
  202. case SpvOpImageQuerySamples:
  203. case SpvOpImageSparseFetch:
  204. uses->push_back(user);
  205. default:
  206. break;
  207. }
  208. if (user->opcode() == SpvOpCopyObject) {
  209. FindUsesOfImage(user, uses);
  210. }
  211. });
  212. }
  213. Instruction* ConvertToSampledImagePass::CreateImageExtraction(
  214. Instruction* sampled_image) {
  215. InstructionBuilder builder(
  216. context(), sampled_image->NextNode(),
  217. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  218. return builder.AddUnaryOp(
  219. GetImageTypeOfSampledImage(context()->get_type_mgr(), sampled_image),
  220. SpvOpImage, sampled_image->result_id());
  221. }
  222. uint32_t ConvertToSampledImagePass::GetSampledImageTypeForImage(
  223. Instruction* image_variable) {
  224. const auto* variable_type = GetVariableType(*image_variable);
  225. if (variable_type == nullptr) return 0;
  226. const auto* image_type = variable_type->AsImage();
  227. if (image_type == nullptr) return 0;
  228. analysis::Image image_type_for_sampled_image(*image_type);
  229. analysis::SampledImage sampled_image_type(&image_type_for_sampled_image);
  230. return context()->get_type_mgr()->GetTypeInstruction(&sampled_image_type);
  231. }
  232. Instruction* ConvertToSampledImagePass::UpdateImageUses(
  233. Instruction* sampled_image_load) {
  234. std::vector<Instruction*> uses_of_load;
  235. FindUsesOfImage(sampled_image_load, &uses_of_load);
  236. if (uses_of_load.empty()) return nullptr;
  237. auto* extracted_image = CreateImageExtraction(sampled_image_load);
  238. for (auto* user : uses_of_load) {
  239. user->SetInOperand(0, {extracted_image->result_id()});
  240. context()->get_def_use_mgr()->AnalyzeInstUse(user);
  241. }
  242. return extracted_image;
  243. }
  244. bool ConvertToSampledImagePass::
  245. IsSamplerOfSampledImageDecoratedByDescriptorSetBinding(
  246. Instruction* sampled_image_inst,
  247. const DescriptorSetAndBinding& descriptor_set_binding) {
  248. auto* def_use_mgr = context()->get_def_use_mgr();
  249. uint32_t sampler_id = sampled_image_inst->GetSingleWordInOperand(1u);
  250. auto* sampler_load = def_use_mgr->GetDef(sampler_id);
  251. if (sampler_load->opcode() != SpvOpLoad) return false;
  252. auto* sampler = def_use_mgr->GetDef(sampler_load->GetSingleWordInOperand(0u));
  253. DescriptorSetAndBinding sampler_descriptor_set_binding;
  254. return GetDescriptorSetBinding(*sampler, &sampler_descriptor_set_binding) &&
  255. sampler_descriptor_set_binding == descriptor_set_binding;
  256. }
  257. void ConvertToSampledImagePass::UpdateSampledImageUses(
  258. Instruction* image_load, Instruction* image_extraction,
  259. const DescriptorSetAndBinding& image_descriptor_set_binding) {
  260. std::vector<Instruction*> sampled_image_users;
  261. FindUses(image_load, &sampled_image_users, SpvOpSampledImage);
  262. auto* def_use_mgr = context()->get_def_use_mgr();
  263. for (auto* sampled_image_inst : sampled_image_users) {
  264. if (IsSamplerOfSampledImageDecoratedByDescriptorSetBinding(
  265. sampled_image_inst, image_descriptor_set_binding)) {
  266. context()->ReplaceAllUsesWith(sampled_image_inst->result_id(),
  267. image_load->result_id());
  268. def_use_mgr->AnalyzeInstUse(image_load);
  269. context()->KillInst(sampled_image_inst);
  270. } else {
  271. if (!image_extraction)
  272. image_extraction = CreateImageExtraction(image_load);
  273. sampled_image_inst->SetInOperand(0, {image_extraction->result_id()});
  274. def_use_mgr->AnalyzeInstUse(sampled_image_inst);
  275. }
  276. }
  277. }
  278. void ConvertToSampledImagePass::MoveInstructionNextToType(Instruction* inst,
  279. uint32_t type_id) {
  280. auto* type_inst = context()->get_def_use_mgr()->GetDef(type_id);
  281. inst->SetResultType(type_id);
  282. inst->RemoveFromList();
  283. inst->InsertAfter(type_inst);
  284. }
  285. bool ConvertToSampledImagePass::ConvertImageVariableToSampledImage(
  286. Instruction* image_variable, uint32_t sampled_image_type_id) {
  287. auto* sampled_image_type =
  288. context()->get_type_mgr()->GetType(sampled_image_type_id);
  289. if (sampled_image_type == nullptr) return false;
  290. auto storage_class = GetStorageClass(*image_variable);
  291. if (storage_class == SpvStorageClassMax) return false;
  292. analysis::Pointer sampled_image_pointer(sampled_image_type, storage_class);
  293. // Make sure |image_variable| is behind its type i.e., avoid the forward
  294. // reference.
  295. uint32_t type_id =
  296. context()->get_type_mgr()->GetTypeInstruction(&sampled_image_pointer);
  297. MoveInstructionNextToType(image_variable, type_id);
  298. return true;
  299. }
  300. Pass::Status ConvertToSampledImagePass::UpdateImageVariableToSampledImage(
  301. Instruction* image_variable,
  302. const DescriptorSetAndBinding& descriptor_set_binding) {
  303. std::vector<Instruction*> image_variable_loads;
  304. FindUses(image_variable, &image_variable_loads, SpvOpLoad);
  305. if (image_variable_loads.empty()) return Status::SuccessWithoutChange;
  306. const uint32_t sampled_image_type_id =
  307. GetSampledImageTypeForImage(image_variable);
  308. if (!sampled_image_type_id) return Status::Failure;
  309. for (auto* load : image_variable_loads) {
  310. load->SetResultType(sampled_image_type_id);
  311. auto* image_extraction = UpdateImageUses(load);
  312. UpdateSampledImageUses(load, image_extraction, descriptor_set_binding);
  313. }
  314. return ConvertImageVariableToSampledImage(image_variable,
  315. sampled_image_type_id)
  316. ? Status::SuccessWithChange
  317. : Status::Failure;
  318. }
  319. bool ConvertToSampledImagePass::DoesSampledImageReferenceImage(
  320. Instruction* sampled_image_inst, Instruction* image_variable) {
  321. if (sampled_image_inst->opcode() != SpvOpSampledImage) return false;
  322. auto* def_use_mgr = context()->get_def_use_mgr();
  323. auto* image_load = GetNonCopyObjectDef(
  324. def_use_mgr, sampled_image_inst->GetSingleWordInOperand(0u));
  325. if (image_load->opcode() != SpvOpLoad) return false;
  326. auto* image =
  327. GetNonCopyObjectDef(def_use_mgr, image_load->GetSingleWordInOperand(0u));
  328. return image->opcode() == SpvOpVariable &&
  329. image->result_id() == image_variable->result_id();
  330. }
  331. Pass::Status ConvertToSampledImagePass::CheckUsesOfSamplerVariable(
  332. const Instruction* sampler_variable,
  333. Instruction* image_to_be_combined_with) {
  334. if (image_to_be_combined_with == nullptr) return Status::Failure;
  335. std::vector<Instruction*> sampler_variable_loads;
  336. FindUses(sampler_variable, &sampler_variable_loads, SpvOpLoad);
  337. for (auto* load : sampler_variable_loads) {
  338. std::vector<Instruction*> sampled_image_users;
  339. FindUses(load, &sampled_image_users, SpvOpSampledImage);
  340. for (auto* sampled_image_inst : sampled_image_users) {
  341. if (!DoesSampledImageReferenceImage(sampled_image_inst,
  342. image_to_be_combined_with)) {
  343. return Status::Failure;
  344. }
  345. }
  346. }
  347. return Status::SuccessWithoutChange;
  348. }
  349. std::unique_ptr<VectorOfDescriptorSetAndBindingPairs>
  350. ConvertToSampledImagePass::ParseDescriptorSetBindingPairsString(
  351. const char* str) {
  352. if (!str) return nullptr;
  353. auto descriptor_set_binding_pairs =
  354. MakeUnique<VectorOfDescriptorSetAndBindingPairs>();
  355. while (std::isspace(*str)) str++; // skip leading spaces.
  356. // The parsing loop, break when points to the end.
  357. while (*str) {
  358. // Parse the descriptor set.
  359. uint32_t descriptor_set = 0;
  360. str = ParseNumberUntilSeparator(str, &descriptor_set);
  361. if (str == nullptr) return nullptr;
  362. // Find the ':', spaces between the descriptor set and the ':' are not
  363. // allowed.
  364. if (*str++ != ':') {
  365. // ':' not found
  366. return nullptr;
  367. }
  368. // Parse the binding.
  369. uint32_t binding = 0;
  370. str = ParseNumberUntilSeparator(str, &binding);
  371. if (str == nullptr) return nullptr;
  372. descriptor_set_binding_pairs->push_back({descriptor_set, binding});
  373. // Skip trailing spaces.
  374. while (std::isspace(*str)) str++;
  375. }
  376. return descriptor_set_binding_pairs;
  377. }
  378. } // namespace opt
  379. } // namespace spvtools