generate_webgpu_initializers_pass.cpp 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. // Copyright (c) 2019 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/generate_webgpu_initializers_pass.h"
  15. #include "source/opt/ir_context.h"
  16. namespace spvtools {
  17. namespace opt {
  18. using inst_iterator = InstructionList::iterator;
  19. namespace {
  20. bool NeedsWebGPUInitializer(Instruction* inst) {
  21. if (inst->opcode() != SpvOpVariable) return false;
  22. auto storage_class = inst->GetSingleWordOperand(2);
  23. if (storage_class != SpvStorageClassOutput &&
  24. storage_class != SpvStorageClassPrivate &&
  25. storage_class != SpvStorageClassFunction) {
  26. return false;
  27. }
  28. if (inst->NumOperands() > 3) return false;
  29. return true;
  30. }
  31. } // namespace
  32. Pass::Status GenerateWebGPUInitializersPass::Process() {
  33. auto* module = context()->module();
  34. bool changed = false;
  35. // Handle global/module scoped variables
  36. for (auto iter = module->types_values_begin();
  37. iter != module->types_values_end(); ++iter) {
  38. Instruction* inst = &(*iter);
  39. if (inst->opcode() == SpvOpConstantNull) {
  40. null_constant_type_map_[inst->type_id()] = inst;
  41. seen_null_constants_.insert(inst);
  42. continue;
  43. }
  44. if (!NeedsWebGPUInitializer(inst)) continue;
  45. changed = true;
  46. auto* constant_inst = GetNullConstantForVariable(inst);
  47. if (!constant_inst) return Status::Failure;
  48. if (seen_null_constants_.find(constant_inst) ==
  49. seen_null_constants_.end()) {
  50. constant_inst->InsertBefore(inst);
  51. null_constant_type_map_[inst->type_id()] = inst;
  52. seen_null_constants_.insert(inst);
  53. }
  54. AddNullInitializerToVariable(constant_inst, inst);
  55. }
  56. // Handle local/function scoped variables
  57. for (auto func = module->begin(); func != module->end(); ++func) {
  58. auto block = func->entry().get();
  59. for (auto iter = block->begin();
  60. iter != block->end() && iter->opcode() == SpvOpVariable; ++iter) {
  61. Instruction* inst = &(*iter);
  62. if (!NeedsWebGPUInitializer(inst)) continue;
  63. changed = true;
  64. auto* constant_inst = GetNullConstantForVariable(inst);
  65. if (!constant_inst) return Status::Failure;
  66. AddNullInitializerToVariable(constant_inst, inst);
  67. }
  68. }
  69. return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  70. }
  71. Instruction* GenerateWebGPUInitializersPass::GetNullConstantForVariable(
  72. Instruction* variable_inst) {
  73. auto constant_mgr = context()->get_constant_mgr();
  74. auto* def_use_mgr = get_def_use_mgr();
  75. auto* ptr_inst = def_use_mgr->GetDef(variable_inst->type_id());
  76. auto type_id = ptr_inst->GetInOperand(1).words[0];
  77. if (null_constant_type_map_.find(type_id) == null_constant_type_map_.end()) {
  78. auto* constant_type = context()->get_type_mgr()->GetType(type_id);
  79. auto* constant = constant_mgr->GetConstant(constant_type, {});
  80. return constant_mgr->GetDefiningInstruction(constant, type_id);
  81. } else {
  82. return null_constant_type_map_[type_id];
  83. }
  84. }
  85. void GenerateWebGPUInitializersPass::AddNullInitializerToVariable(
  86. Instruction* constant_inst, Instruction* variable_inst) {
  87. auto constant_id = constant_inst->result_id();
  88. variable_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {constant_id}));
  89. get_def_use_mgr()->AnalyzeInstUse(variable_inst);
  90. }
  91. } // namespace opt
  92. } // namespace spvtools