fix_storage_class.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. // Copyright (c) 2019 Google LLC
  2. // Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
  3. // reserved.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. #include "fix_storage_class.h"
  17. #include <set>
  18. #include "source/opt/instruction.h"
  19. #include "source/opt/ir_context.h"
  20. namespace spvtools {
  21. namespace opt {
  22. Pass::Status FixStorageClass::Process() {
  23. bool modified = false;
  24. get_module()->ForEachInst([this, &modified](Instruction* inst) {
  25. if (inst->opcode() == spv::Op::OpVariable) {
  26. std::set<uint32_t> seen;
  27. std::vector<std::pair<Instruction*, uint32_t>> uses;
  28. get_def_use_mgr()->ForEachUse(inst,
  29. [&uses](Instruction* use, uint32_t op_idx) {
  30. uses.push_back({use, op_idx});
  31. });
  32. for (auto& use : uses) {
  33. modified |= PropagateStorageClass(
  34. use.first,
  35. static_cast<spv::StorageClass>(inst->GetSingleWordInOperand(0)),
  36. &seen);
  37. assert(seen.empty() && "Seen was not properly reset.");
  38. modified |=
  39. PropagateType(use.first, inst->type_id(), use.second, &seen);
  40. assert(seen.empty() && "Seen was not properly reset.");
  41. }
  42. }
  43. });
  44. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  45. }
  46. bool FixStorageClass::PropagateStorageClass(Instruction* inst,
  47. spv::StorageClass storage_class,
  48. std::set<uint32_t>* seen) {
  49. if (!IsPointerResultType(inst)) {
  50. return false;
  51. }
  52. if (IsPointerToStorageClass(inst, storage_class)) {
  53. if (inst->opcode() == spv::Op::OpPhi) {
  54. if (!seen->insert(inst->result_id()).second) {
  55. return false;
  56. }
  57. }
  58. bool modified = false;
  59. std::vector<Instruction*> uses;
  60. get_def_use_mgr()->ForEachUser(
  61. inst, [&uses](Instruction* use) { uses.push_back(use); });
  62. for (Instruction* use : uses) {
  63. modified |= PropagateStorageClass(use, storage_class, seen);
  64. }
  65. if (inst->opcode() == spv::Op::OpPhi) {
  66. seen->erase(inst->result_id());
  67. }
  68. return modified;
  69. }
  70. switch (inst->opcode()) {
  71. case spv::Op::OpAccessChain:
  72. case spv::Op::OpPtrAccessChain:
  73. case spv::Op::OpInBoundsAccessChain:
  74. case spv::Op::OpCopyObject:
  75. case spv::Op::OpPhi:
  76. case spv::Op::OpSelect:
  77. FixInstructionStorageClass(inst, storage_class, seen);
  78. return true;
  79. case spv::Op::OpFunctionCall:
  80. // We cannot be sure of the actual connection between the storage class
  81. // of the parameter and the storage class of the result, so we should not
  82. // do anything. If the result type needs to be fixed, the function call
  83. // should be inlined.
  84. return false;
  85. case spv::Op::OpImageTexelPointer:
  86. case spv::Op::OpLoad:
  87. case spv::Op::OpStore:
  88. case spv::Op::OpCopyMemory:
  89. case spv::Op::OpCopyMemorySized:
  90. case spv::Op::OpVariable:
  91. case spv::Op::OpBitcast:
  92. case spv::Op::OpAllocateNodePayloadsAMDX:
  93. // Nothing to change for these opcode. The result type is the same
  94. // regardless of the storage class of the operand.
  95. return false;
  96. default:
  97. assert(false &&
  98. "Not expecting instruction to have a pointer result type.");
  99. return false;
  100. }
  101. }
  102. void FixStorageClass::FixInstructionStorageClass(
  103. Instruction* inst, spv::StorageClass storage_class,
  104. std::set<uint32_t>* seen) {
  105. assert(IsPointerResultType(inst) &&
  106. "The result type of the instruction must be a pointer.");
  107. ChangeResultStorageClass(inst, storage_class);
  108. std::vector<Instruction*> uses;
  109. get_def_use_mgr()->ForEachUser(
  110. inst, [&uses](Instruction* use) { uses.push_back(use); });
  111. for (Instruction* use : uses) {
  112. PropagateStorageClass(use, storage_class, seen);
  113. }
  114. }
  115. void FixStorageClass::ChangeResultStorageClass(
  116. Instruction* inst, spv::StorageClass storage_class) const {
  117. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  118. Instruction* result_type_inst = get_def_use_mgr()->GetDef(inst->type_id());
  119. assert(result_type_inst->opcode() == spv::Op::OpTypePointer);
  120. uint32_t pointee_type_id = result_type_inst->GetSingleWordInOperand(1);
  121. uint32_t new_result_type_id =
  122. type_mgr->FindPointerToType(pointee_type_id, storage_class);
  123. inst->SetResultType(new_result_type_id);
  124. context()->UpdateDefUse(inst);
  125. }
  126. bool FixStorageClass::IsPointerResultType(Instruction* inst) {
  127. if (inst->type_id() == 0) {
  128. return false;
  129. }
  130. Instruction* type_def = get_def_use_mgr()->GetDef(inst->type_id());
  131. return type_def->opcode() == spv::Op::OpTypePointer;
  132. }
  133. bool FixStorageClass::IsPointerToStorageClass(Instruction* inst,
  134. spv::StorageClass storage_class) {
  135. if (inst->type_id() == 0) {
  136. return false;
  137. }
  138. Instruction* type_def = get_def_use_mgr()->GetDef(inst->type_id());
  139. if (type_def->opcode() != spv::Op::OpTypePointer) {
  140. return false;
  141. }
  142. const uint32_t kPointerTypeStorageClassIndex = 0;
  143. spv::StorageClass pointer_storage_class = static_cast<spv::StorageClass>(
  144. type_def->GetSingleWordInOperand(kPointerTypeStorageClassIndex));
  145. return pointer_storage_class == storage_class;
  146. }
  147. bool FixStorageClass::ChangeResultType(Instruction* inst,
  148. uint32_t new_type_id) {
  149. if (inst->type_id() == new_type_id) {
  150. return false;
  151. }
  152. context()->ForgetUses(inst);
  153. inst->SetResultType(new_type_id);
  154. context()->AnalyzeUses(inst);
  155. return true;
  156. }
  157. bool FixStorageClass::PropagateType(Instruction* inst, uint32_t type_id,
  158. uint32_t op_idx, std::set<uint32_t>* seen) {
  159. assert(type_id != 0 && "Not given a valid type in PropagateType");
  160. bool modified = false;
  161. // If the type of operand |op_idx| forces the result type of |inst| to a
  162. // particular type, then we want find that type.
  163. uint32_t new_type_id = 0;
  164. switch (inst->opcode()) {
  165. case spv::Op::OpAccessChain:
  166. case spv::Op::OpPtrAccessChain:
  167. case spv::Op::OpInBoundsAccessChain:
  168. case spv::Op::OpInBoundsPtrAccessChain:
  169. if (op_idx == 2) {
  170. new_type_id = WalkAccessChainType(inst, type_id);
  171. }
  172. break;
  173. case spv::Op::OpCopyObject:
  174. new_type_id = type_id;
  175. break;
  176. case spv::Op::OpPhi:
  177. if (seen->insert(inst->result_id()).second) {
  178. new_type_id = type_id;
  179. }
  180. break;
  181. case spv::Op::OpSelect:
  182. if (op_idx > 2) {
  183. new_type_id = type_id;
  184. }
  185. break;
  186. case spv::Op::OpFunctionCall:
  187. // We cannot be sure of the actual connection between the type
  188. // of the parameter and the type of the result, so we should not
  189. // do anything. If the result type needs to be fixed, the function call
  190. // should be inlined.
  191. return false;
  192. case spv::Op::OpLoad: {
  193. Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
  194. new_type_id = type_inst->GetSingleWordInOperand(1);
  195. break;
  196. }
  197. case spv::Op::OpStore: {
  198. uint32_t obj_id = inst->GetSingleWordInOperand(1);
  199. Instruction* obj_inst = get_def_use_mgr()->GetDef(obj_id);
  200. uint32_t obj_type_id = obj_inst->type_id();
  201. uint32_t ptr_id = inst->GetSingleWordInOperand(0);
  202. Instruction* ptr_inst = get_def_use_mgr()->GetDef(ptr_id);
  203. uint32_t pointee_type_id = GetPointeeTypeId(ptr_inst);
  204. if (obj_type_id != pointee_type_id) {
  205. if (context()->get_type_mgr()->GetType(obj_type_id)->AsImage() &&
  206. context()->get_type_mgr()->GetType(pointee_type_id)->AsImage()) {
  207. // When storing an image, allow the type mismatch
  208. // and let the later legalization passes eliminate the OpStore.
  209. // This is to support assigning an image to a variable,
  210. // where the assigned image does not have a pre-defined
  211. // image format.
  212. return false;
  213. }
  214. uint32_t copy_id = GenerateCopy(obj_inst, pointee_type_id, inst);
  215. if (copy_id == 0) {
  216. return false;
  217. }
  218. inst->SetInOperand(1, {copy_id});
  219. context()->UpdateDefUse(inst);
  220. }
  221. } break;
  222. case spv::Op::OpCopyMemory:
  223. case spv::Op::OpCopyMemorySized:
  224. // TODO: May need to expand the copy as we do with the stores.
  225. break;
  226. case spv::Op::OpCompositeConstruct:
  227. case spv::Op::OpCompositeExtract:
  228. case spv::Op::OpCompositeInsert:
  229. // TODO: DXC does not seem to generate code that will require changes to
  230. // these opcode. The can be implemented when they come up.
  231. break;
  232. case spv::Op::OpImageTexelPointer:
  233. case spv::Op::OpBitcast:
  234. // Nothing to change for these opcode. The result type is the same
  235. // regardless of the type of the operand.
  236. return false;
  237. default:
  238. // I expect the remaining instructions to act on types that are guaranteed
  239. // to be unique, so no change will be necessary.
  240. break;
  241. }
  242. // If the operand forces the result type, then make sure the result type
  243. // matches, and update the uses of |inst|. We do not have to check the uses
  244. // of |inst| in the result type is not forced because we are only looking for
  245. // issue that come from mismatches between function formal and actual
  246. // parameters after the function has been inlined. These parameters are
  247. // pointers. Once the type no longer depends on the type of the parameter,
  248. // then the types should have be correct.
  249. if (new_type_id != 0) {
  250. modified = ChangeResultType(inst, new_type_id);
  251. std::vector<std::pair<Instruction*, uint32_t>> uses;
  252. get_def_use_mgr()->ForEachUse(inst,
  253. [&uses](Instruction* use, uint32_t idx) {
  254. uses.push_back({use, idx});
  255. });
  256. for (auto& use : uses) {
  257. PropagateType(use.first, new_type_id, use.second, seen);
  258. }
  259. if (inst->opcode() == spv::Op::OpPhi) {
  260. seen->erase(inst->result_id());
  261. }
  262. }
  263. return modified;
  264. }
  265. uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) {
  266. uint32_t start_idx = 0;
  267. switch (inst->opcode()) {
  268. case spv::Op::OpAccessChain:
  269. case spv::Op::OpInBoundsAccessChain:
  270. start_idx = 1;
  271. break;
  272. case spv::Op::OpPtrAccessChain:
  273. case spv::Op::OpInBoundsPtrAccessChain:
  274. start_idx = 2;
  275. break;
  276. default:
  277. assert(false);
  278. break;
  279. }
  280. Instruction* id_type_inst = get_def_use_mgr()->GetDef(id);
  281. assert(id_type_inst->opcode() == spv::Op::OpTypePointer);
  282. id = id_type_inst->GetSingleWordInOperand(1);
  283. spv::StorageClass input_storage_class =
  284. static_cast<spv::StorageClass>(id_type_inst->GetSingleWordInOperand(0));
  285. for (uint32_t i = start_idx; i < inst->NumInOperands(); ++i) {
  286. Instruction* type_inst = get_def_use_mgr()->GetDef(id);
  287. switch (type_inst->opcode()) {
  288. case spv::Op::OpTypeArray:
  289. case spv::Op::OpTypeRuntimeArray:
  290. case spv::Op::OpTypeNodePayloadArrayAMDX:
  291. case spv::Op::OpTypeMatrix:
  292. case spv::Op::OpTypeVector:
  293. case spv::Op::OpTypeCooperativeMatrixKHR:
  294. id = type_inst->GetSingleWordInOperand(0);
  295. break;
  296. case spv::Op::OpTypeStruct: {
  297. const analysis::Constant* index_const =
  298. context()->get_constant_mgr()->FindDeclaredConstant(
  299. inst->GetSingleWordInOperand(i));
  300. // It is highly unlikely that any type would have more fields than could
  301. // be indexed by a 32-bit integer, and GetSingleWordInOperand only takes
  302. // a 32-bit value, so we would not be able to handle it anyway. But the
  303. // specification does allow any scalar integer type, treated as signed,
  304. // so we simply downcast the index to 32-bits.
  305. uint32_t index =
  306. static_cast<uint32_t>(index_const->GetSignExtendedValue());
  307. id = type_inst->GetSingleWordInOperand(index);
  308. break;
  309. }
  310. default:
  311. break;
  312. }
  313. assert(id != 0 &&
  314. "Tried to extract from an object where it cannot be done.");
  315. }
  316. Instruction* orig_type_inst = get_def_use_mgr()->GetDef(inst->type_id());
  317. spv::StorageClass orig_storage_class =
  318. static_cast<spv::StorageClass>(orig_type_inst->GetSingleWordInOperand(0));
  319. assert(orig_type_inst->opcode() == spv::Op::OpTypePointer);
  320. if (orig_type_inst->GetSingleWordInOperand(1) == id &&
  321. input_storage_class == orig_storage_class) {
  322. // The existing type is correct. Avoid the search for the type. Note that if
  323. // there is a duplicate type, the search below could return a different type
  324. // forcing more changes to the code than necessary.
  325. return inst->type_id();
  326. }
  327. return context()->get_type_mgr()->FindPointerToType(id, input_storage_class);
  328. }
  329. // namespace opt
  330. } // namespace opt
  331. } // namespace spvtools