convert_to_sampled_image_pass.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. // Copyright (c) 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/convert_to_sampled_image_pass.h"
  15. #include <cctype>
  16. #include <cstring>
  17. #include "source/opt/ir_builder.h"
  18. #include "source/util/make_unique.h"
  19. #include "source/util/parse_number.h"
  20. namespace spvtools {
  21. namespace opt {
  22. using VectorOfDescriptorSetAndBindingPairs =
  23. std::vector<DescriptorSetAndBinding>;
  24. using DescriptorSetBindingToInstruction =
  25. ConvertToSampledImagePass::DescriptorSetBindingToInstruction;
  26. namespace {
  27. using utils::ParseNumber;
  28. // Returns true if the given char is ':', '\0' or considered as blank space
  29. // (i.e.: '\n', '\r', '\v', '\t', '\f' and ' ').
  30. bool IsSeparator(char ch) {
  31. return std::strchr(":\0", ch) || std::isspace(ch) != 0;
  32. }
  33. // Reads characters starting from |str| until it meets a separator. Parses a
  34. // number from the characters and stores it into |number|. Returns the pointer
  35. // to the separator if it succeeds. Otherwise, returns nullptr.
  36. const char* ParseNumberUntilSeparator(const char* str, uint32_t* number) {
  37. const char* number_begin = str;
  38. while (!IsSeparator(*str)) str++;
  39. const char* number_end = str;
  40. std::string number_in_str(number_begin, number_end - number_begin);
  41. if (!utils::ParseNumber(number_in_str.c_str(), number)) {
  42. // The descriptor set is not a valid uint32 number.
  43. return nullptr;
  44. }
  45. return str;
  46. }
  47. // Returns id of the image type used for the sampled image type of
  48. // |sampled_image|.
  49. uint32_t GetImageTypeOfSampledImage(analysis::TypeManager* type_mgr,
  50. Instruction* sampled_image) {
  51. auto* sampled_image_type =
  52. type_mgr->GetType(sampled_image->type_id())->AsSampledImage();
  53. return type_mgr->GetTypeInstruction(sampled_image_type->image_type());
  54. }
  55. // Finds the instruction whose id is |inst_id|. Follows the operand of
  56. // OpCopyObject recursively if the opcode of the instruction is OpCopyObject
  57. // and returns the first instruction that does not have OpCopyObject as opcode.
  58. Instruction* GetNonCopyObjectDef(analysis::DefUseManager* def_use_mgr,
  59. uint32_t inst_id) {
  60. Instruction* inst = def_use_mgr->GetDef(inst_id);
  61. while (inst->opcode() == spv::Op::OpCopyObject) {
  62. inst_id = inst->GetSingleWordInOperand(0u);
  63. inst = def_use_mgr->GetDef(inst_id);
  64. }
  65. return inst;
  66. }
  67. } // namespace
  68. bool ConvertToSampledImagePass::GetDescriptorSetBinding(
  69. const Instruction& inst,
  70. DescriptorSetAndBinding* descriptor_set_binding) const {
  71. auto* decoration_manager = context()->get_decoration_mgr();
  72. bool found_descriptor_set_to_convert = false;
  73. bool found_binding_to_convert = false;
  74. for (auto decorate :
  75. decoration_manager->GetDecorationsFor(inst.result_id(), false)) {
  76. spv::Decoration decoration =
  77. spv::Decoration(decorate->GetSingleWordInOperand(1u));
  78. if (decoration == spv::Decoration::DescriptorSet) {
  79. if (found_descriptor_set_to_convert) {
  80. assert(false && "A resource has two OpDecorate for the descriptor set");
  81. return false;
  82. }
  83. descriptor_set_binding->descriptor_set =
  84. decorate->GetSingleWordInOperand(2u);
  85. found_descriptor_set_to_convert = true;
  86. } else if (decoration == spv::Decoration::Binding) {
  87. if (found_binding_to_convert) {
  88. assert(false && "A resource has two OpDecorate for the binding");
  89. return false;
  90. }
  91. descriptor_set_binding->binding = decorate->GetSingleWordInOperand(2u);
  92. found_binding_to_convert = true;
  93. }
  94. }
  95. return found_descriptor_set_to_convert && found_binding_to_convert;
  96. }
  97. bool ConvertToSampledImagePass::ShouldResourceBeConverted(
  98. const DescriptorSetAndBinding& descriptor_set_binding) const {
  99. return descriptor_set_binding_pairs_.find(descriptor_set_binding) !=
  100. descriptor_set_binding_pairs_.end();
  101. }
  102. const analysis::Type* ConvertToSampledImagePass::GetVariableType(
  103. const Instruction& variable) const {
  104. if (variable.opcode() != spv::Op::OpVariable) return nullptr;
  105. auto* type = context()->get_type_mgr()->GetType(variable.type_id());
  106. auto* pointer_type = type->AsPointer();
  107. if (!pointer_type) return nullptr;
  108. return pointer_type->pointee_type();
  109. }
  110. spv::StorageClass ConvertToSampledImagePass::GetStorageClass(
  111. const Instruction& variable) const {
  112. assert(variable.opcode() == spv::Op::OpVariable);
  113. auto* type = context()->get_type_mgr()->GetType(variable.type_id());
  114. auto* pointer_type = type->AsPointer();
  115. if (!pointer_type) return spv::StorageClass::Max;
  116. return pointer_type->storage_class();
  117. }
  118. bool ConvertToSampledImagePass::CollectResourcesToConvert(
  119. DescriptorSetBindingToInstruction* descriptor_set_binding_pair_to_sampler,
  120. DescriptorSetBindingToInstruction* descriptor_set_binding_pair_to_image)
  121. const {
  122. for (auto& inst : context()->types_values()) {
  123. const auto* variable_type = GetVariableType(inst);
  124. if (variable_type == nullptr) continue;
  125. DescriptorSetAndBinding descriptor_set_binding;
  126. if (!GetDescriptorSetBinding(inst, &descriptor_set_binding)) continue;
  127. if (!ShouldResourceBeConverted(descriptor_set_binding)) {
  128. continue;
  129. }
  130. if (variable_type->AsImage()) {
  131. if (!descriptor_set_binding_pair_to_image
  132. ->insert({descriptor_set_binding, &inst})
  133. .second) {
  134. return false;
  135. }
  136. } else if (variable_type->AsSampler()) {
  137. if (!descriptor_set_binding_pair_to_sampler
  138. ->insert({descriptor_set_binding, &inst})
  139. .second) {
  140. return false;
  141. }
  142. }
  143. }
  144. return true;
  145. }
  146. Pass::Status ConvertToSampledImagePass::Process() {
  147. Status status = Status::SuccessWithoutChange;
  148. DescriptorSetBindingToInstruction descriptor_set_binding_pair_to_sampler,
  149. descriptor_set_binding_pair_to_image;
  150. if (!CollectResourcesToConvert(&descriptor_set_binding_pair_to_sampler,
  151. &descriptor_set_binding_pair_to_image)) {
  152. return Status::Failure;
  153. }
  154. for (auto& image : descriptor_set_binding_pair_to_image) {
  155. status = CombineStatus(
  156. status, UpdateImageVariableToSampledImage(image.second, image.first));
  157. if (status == Status::Failure) {
  158. return status;
  159. }
  160. }
  161. for (const auto& sampler : descriptor_set_binding_pair_to_sampler) {
  162. // Converting only a Sampler to Sampled Image is not allowed. It must have a
  163. // corresponding image to combine the sampler with.
  164. auto image_itr = descriptor_set_binding_pair_to_image.find(sampler.first);
  165. if (image_itr == descriptor_set_binding_pair_to_image.end() ||
  166. image_itr->second == nullptr) {
  167. return Status::Failure;
  168. }
  169. status = CombineStatus(
  170. status, CheckUsesOfSamplerVariable(sampler.second, image_itr->second));
  171. if (status == Status::Failure) {
  172. return status;
  173. }
  174. }
  175. return status;
  176. }
  177. void ConvertToSampledImagePass::FindUses(const Instruction* inst,
  178. std::vector<Instruction*>* uses,
  179. spv::Op user_opcode) const {
  180. auto* def_use_mgr = context()->get_def_use_mgr();
  181. def_use_mgr->ForEachUser(inst, [uses, user_opcode, this](Instruction* user) {
  182. if (user->opcode() == user_opcode) {
  183. uses->push_back(user);
  184. } else if (user->opcode() == spv::Op::OpCopyObject) {
  185. FindUses(user, uses, user_opcode);
  186. }
  187. });
  188. }
  189. void ConvertToSampledImagePass::FindUsesOfImage(
  190. const Instruction* image, std::vector<Instruction*>* uses) const {
  191. auto* def_use_mgr = context()->get_def_use_mgr();
  192. def_use_mgr->ForEachUser(image, [uses, this](Instruction* user) {
  193. switch (user->opcode()) {
  194. case spv::Op::OpImageFetch:
  195. case spv::Op::OpImageRead:
  196. case spv::Op::OpImageWrite:
  197. case spv::Op::OpImageQueryFormat:
  198. case spv::Op::OpImageQueryOrder:
  199. case spv::Op::OpImageQuerySizeLod:
  200. case spv::Op::OpImageQuerySize:
  201. case spv::Op::OpImageQueryLevels:
  202. case spv::Op::OpImageQuerySamples:
  203. case spv::Op::OpImageSparseFetch:
  204. uses->push_back(user);
  205. default:
  206. break;
  207. }
  208. if (user->opcode() == spv::Op::OpCopyObject) {
  209. FindUsesOfImage(user, uses);
  210. }
  211. });
  212. }
  213. Instruction* ConvertToSampledImagePass::CreateImageExtraction(
  214. Instruction* sampled_image) {
  215. InstructionBuilder builder(
  216. context(), sampled_image->NextNode(),
  217. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  218. return builder.AddUnaryOp(
  219. GetImageTypeOfSampledImage(context()->get_type_mgr(), sampled_image),
  220. spv::Op::OpImage, sampled_image->result_id());
  221. }
  222. uint32_t ConvertToSampledImagePass::GetSampledImageTypeForImage(
  223. Instruction* image_variable) {
  224. const auto* variable_type = GetVariableType(*image_variable);
  225. if (variable_type == nullptr) return 0;
  226. const auto* image_type = variable_type->AsImage();
  227. if (image_type == nullptr) return 0;
  228. analysis::Image image_type_for_sampled_image(*image_type);
  229. analysis::SampledImage sampled_image_type(&image_type_for_sampled_image);
  230. return context()->get_type_mgr()->GetTypeInstruction(&sampled_image_type);
  231. }
  232. Instruction* ConvertToSampledImagePass::UpdateImageUses(
  233. Instruction* sampled_image_load) {
  234. std::vector<Instruction*> uses_of_load;
  235. FindUsesOfImage(sampled_image_load, &uses_of_load);
  236. if (uses_of_load.empty()) return nullptr;
  237. auto* extracted_image = CreateImageExtraction(sampled_image_load);
  238. for (auto* user : uses_of_load) {
  239. user->SetInOperand(0, {extracted_image->result_id()});
  240. context()->get_def_use_mgr()->AnalyzeInstUse(user);
  241. }
  242. return extracted_image;
  243. }
  244. bool ConvertToSampledImagePass::
  245. IsSamplerOfSampledImageDecoratedByDescriptorSetBinding(
  246. Instruction* sampled_image_inst,
  247. const DescriptorSetAndBinding& descriptor_set_binding) {
  248. auto* def_use_mgr = context()->get_def_use_mgr();
  249. uint32_t sampler_id = sampled_image_inst->GetSingleWordInOperand(1u);
  250. auto* sampler_load = def_use_mgr->GetDef(sampler_id);
  251. if (sampler_load->opcode() != spv::Op::OpLoad) return false;
  252. auto* sampler = def_use_mgr->GetDef(sampler_load->GetSingleWordInOperand(0u));
  253. DescriptorSetAndBinding sampler_descriptor_set_binding;
  254. return GetDescriptorSetBinding(*sampler, &sampler_descriptor_set_binding) &&
  255. sampler_descriptor_set_binding == descriptor_set_binding;
  256. }
  257. void ConvertToSampledImagePass::UpdateSampledImageUses(
  258. Instruction* image_load, Instruction* image_extraction,
  259. const DescriptorSetAndBinding& image_descriptor_set_binding) {
  260. std::vector<Instruction*> sampled_image_users;
  261. FindUses(image_load, &sampled_image_users, spv::Op::OpSampledImage);
  262. auto* def_use_mgr = context()->get_def_use_mgr();
  263. for (auto* sampled_image_inst : sampled_image_users) {
  264. if (IsSamplerOfSampledImageDecoratedByDescriptorSetBinding(
  265. sampled_image_inst, image_descriptor_set_binding)) {
  266. context()->ReplaceAllUsesWith(sampled_image_inst->result_id(),
  267. image_load->result_id());
  268. def_use_mgr->AnalyzeInstUse(image_load);
  269. context()->KillInst(sampled_image_inst);
  270. } else {
  271. if (!image_extraction)
  272. image_extraction = CreateImageExtraction(image_load);
  273. sampled_image_inst->SetInOperand(0, {image_extraction->result_id()});
  274. def_use_mgr->AnalyzeInstUse(sampled_image_inst);
  275. }
  276. }
  277. }
  278. void ConvertToSampledImagePass::MoveInstructionNextToType(Instruction* inst,
  279. uint32_t type_id) {
  280. auto* type_inst = context()->get_def_use_mgr()->GetDef(type_id);
  281. inst->SetResultType(type_id);
  282. inst->RemoveFromList();
  283. inst->InsertAfter(type_inst);
  284. }
  285. bool ConvertToSampledImagePass::ConvertImageVariableToSampledImage(
  286. Instruction* image_variable, uint32_t sampled_image_type_id) {
  287. auto* sampled_image_type =
  288. context()->get_type_mgr()->GetType(sampled_image_type_id);
  289. if (sampled_image_type == nullptr) return false;
  290. auto storage_class = GetStorageClass(*image_variable);
  291. if (storage_class == spv::StorageClass::Max) return false;
  292. // Make sure |image_variable| is behind its type i.e., avoid the forward
  293. // reference.
  294. uint32_t type_id = context()->get_type_mgr()->FindPointerToType(
  295. sampled_image_type_id, storage_class);
  296. MoveInstructionNextToType(image_variable, type_id);
  297. return true;
  298. }
  299. Pass::Status ConvertToSampledImagePass::UpdateImageVariableToSampledImage(
  300. Instruction* image_variable,
  301. const DescriptorSetAndBinding& descriptor_set_binding) {
  302. std::vector<Instruction*> image_variable_loads;
  303. FindUses(image_variable, &image_variable_loads, spv::Op::OpLoad);
  304. if (image_variable_loads.empty()) return Status::SuccessWithoutChange;
  305. const uint32_t sampled_image_type_id =
  306. GetSampledImageTypeForImage(image_variable);
  307. if (!sampled_image_type_id) return Status::Failure;
  308. for (auto* load : image_variable_loads) {
  309. load->SetResultType(sampled_image_type_id);
  310. auto* image_extraction = UpdateImageUses(load);
  311. UpdateSampledImageUses(load, image_extraction, descriptor_set_binding);
  312. }
  313. return ConvertImageVariableToSampledImage(image_variable,
  314. sampled_image_type_id)
  315. ? Status::SuccessWithChange
  316. : Status::Failure;
  317. }
  318. bool ConvertToSampledImagePass::DoesSampledImageReferenceImage(
  319. Instruction* sampled_image_inst, Instruction* image_variable) {
  320. if (sampled_image_inst->opcode() != spv::Op::OpSampledImage) return false;
  321. auto* def_use_mgr = context()->get_def_use_mgr();
  322. auto* image_load = GetNonCopyObjectDef(
  323. def_use_mgr, sampled_image_inst->GetSingleWordInOperand(0u));
  324. if (image_load->opcode() != spv::Op::OpLoad) return false;
  325. auto* image =
  326. GetNonCopyObjectDef(def_use_mgr, image_load->GetSingleWordInOperand(0u));
  327. return image->opcode() == spv::Op::OpVariable &&
  328. image->result_id() == image_variable->result_id();
  329. }
  330. Pass::Status ConvertToSampledImagePass::CheckUsesOfSamplerVariable(
  331. const Instruction* sampler_variable,
  332. Instruction* image_to_be_combined_with) {
  333. if (image_to_be_combined_with == nullptr) return Status::Failure;
  334. std::vector<Instruction*> sampler_variable_loads;
  335. FindUses(sampler_variable, &sampler_variable_loads, spv::Op::OpLoad);
  336. for (auto* load : sampler_variable_loads) {
  337. std::vector<Instruction*> sampled_image_users;
  338. FindUses(load, &sampled_image_users, spv::Op::OpSampledImage);
  339. for (auto* sampled_image_inst : sampled_image_users) {
  340. if (!DoesSampledImageReferenceImage(sampled_image_inst,
  341. image_to_be_combined_with)) {
  342. return Status::Failure;
  343. }
  344. }
  345. }
  346. return Status::SuccessWithoutChange;
  347. }
  348. std::unique_ptr<VectorOfDescriptorSetAndBindingPairs>
  349. ConvertToSampledImagePass::ParseDescriptorSetBindingPairsString(
  350. const char* str) {
  351. if (!str) return nullptr;
  352. auto descriptor_set_binding_pairs =
  353. MakeUnique<VectorOfDescriptorSetAndBindingPairs>();
  354. while (std::isspace(*str)) str++; // skip leading spaces.
  355. // The parsing loop, break when points to the end.
  356. while (*str) {
  357. // Parse the descriptor set.
  358. uint32_t descriptor_set = 0;
  359. str = ParseNumberUntilSeparator(str, &descriptor_set);
  360. if (str == nullptr) return nullptr;
  361. // Find the ':', spaces between the descriptor set and the ':' are not
  362. // allowed.
  363. if (*str++ != ':') {
  364. // ':' not found
  365. return nullptr;
  366. }
  367. // Parse the binding.
  368. uint32_t binding = 0;
  369. str = ParseNumberUntilSeparator(str, &binding);
  370. if (str == nullptr) return nullptr;
  371. descriptor_set_binding_pairs->push_back({descriptor_set, binding});
  372. // Skip trailing spaces.
  373. while (std::isspace(*str)) str++;
  374. }
  375. return descriptor_set_binding_pairs;
  376. }
  377. } // namespace opt
  378. } // namespace spvtools