fix_storage_class.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "fix_storage_class.h"
  15. #include <set>
  16. #include "source/opt/instruction.h"
  17. #include "source/opt/ir_context.h"
  18. namespace spvtools {
  19. namespace opt {
  20. Pass::Status FixStorageClass::Process() {
  21. bool modified = false;
  22. get_module()->ForEachInst([this, &modified](Instruction* inst) {
  23. if (inst->opcode() == SpvOpVariable) {
  24. std::set<uint32_t> seen;
  25. std::vector<std::pair<Instruction*, uint32_t>> uses;
  26. get_def_use_mgr()->ForEachUse(inst,
  27. [&uses](Instruction* use, uint32_t op_idx) {
  28. uses.push_back({use, op_idx});
  29. });
  30. for (auto& use : uses) {
  31. modified |= PropagateStorageClass(
  32. use.first,
  33. static_cast<SpvStorageClass>(inst->GetSingleWordInOperand(0)),
  34. &seen);
  35. assert(seen.empty() && "Seen was not properly reset.");
  36. modified |=
  37. PropagateType(use.first, inst->type_id(), use.second, &seen);
  38. assert(seen.empty() && "Seen was not properly reset.");
  39. }
  40. }
  41. });
  42. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  43. }
  44. bool FixStorageClass::PropagateStorageClass(Instruction* inst,
  45. SpvStorageClass storage_class,
  46. std::set<uint32_t>* seen) {
  47. if (!IsPointerResultType(inst)) {
  48. return false;
  49. }
  50. if (IsPointerToStorageClass(inst, storage_class)) {
  51. if (inst->opcode() == SpvOpPhi) {
  52. if (!seen->insert(inst->result_id()).second) {
  53. return false;
  54. }
  55. }
  56. bool modified = false;
  57. std::vector<Instruction*> uses;
  58. get_def_use_mgr()->ForEachUser(
  59. inst, [&uses](Instruction* use) { uses.push_back(use); });
  60. for (Instruction* use : uses) {
  61. modified |= PropagateStorageClass(use, storage_class, seen);
  62. }
  63. if (inst->opcode() == SpvOpPhi) {
  64. seen->erase(inst->result_id());
  65. }
  66. return modified;
  67. }
  68. switch (inst->opcode()) {
  69. case SpvOpAccessChain:
  70. case SpvOpPtrAccessChain:
  71. case SpvOpInBoundsAccessChain:
  72. case SpvOpCopyObject:
  73. case SpvOpPhi:
  74. case SpvOpSelect:
  75. FixInstructionStorageClass(inst, storage_class, seen);
  76. return true;
  77. case SpvOpFunctionCall:
  78. // We cannot be sure of the actual connection between the storage class
  79. // of the parameter and the storage class of the result, so we should not
  80. // do anything. If the result type needs to be fixed, the function call
  81. // should be inlined.
  82. return false;
  83. case SpvOpImageTexelPointer:
  84. case SpvOpLoad:
  85. case SpvOpStore:
  86. case SpvOpCopyMemory:
  87. case SpvOpCopyMemorySized:
  88. case SpvOpVariable:
  89. case SpvOpBitcast:
  90. // Nothing to change for these opcode. The result type is the same
  91. // regardless of the storage class of the operand.
  92. return false;
  93. default:
  94. assert(false &&
  95. "Not expecting instruction to have a pointer result type.");
  96. return false;
  97. }
  98. }
  99. void FixStorageClass::FixInstructionStorageClass(Instruction* inst,
  100. SpvStorageClass storage_class,
  101. std::set<uint32_t>* seen) {
  102. assert(IsPointerResultType(inst) &&
  103. "The result type of the instruction must be a pointer.");
  104. ChangeResultStorageClass(inst, storage_class);
  105. std::vector<Instruction*> uses;
  106. get_def_use_mgr()->ForEachUser(
  107. inst, [&uses](Instruction* use) { uses.push_back(use); });
  108. for (Instruction* use : uses) {
  109. PropagateStorageClass(use, storage_class, seen);
  110. }
  111. }
  112. void FixStorageClass::ChangeResultStorageClass(
  113. Instruction* inst, SpvStorageClass storage_class) const {
  114. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  115. Instruction* result_type_inst = get_def_use_mgr()->GetDef(inst->type_id());
  116. assert(result_type_inst->opcode() == SpvOpTypePointer);
  117. uint32_t pointee_type_id = result_type_inst->GetSingleWordInOperand(1);
  118. uint32_t new_result_type_id =
  119. type_mgr->FindPointerToType(pointee_type_id, storage_class);
  120. inst->SetResultType(new_result_type_id);
  121. context()->UpdateDefUse(inst);
  122. }
  123. bool FixStorageClass::IsPointerResultType(Instruction* inst) {
  124. if (inst->type_id() == 0) {
  125. return false;
  126. }
  127. const analysis::Type* ret_type =
  128. context()->get_type_mgr()->GetType(inst->type_id());
  129. return ret_type->AsPointer() != nullptr;
  130. }
  131. bool FixStorageClass::IsPointerToStorageClass(Instruction* inst,
  132. SpvStorageClass storage_class) {
  133. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  134. analysis::Type* pType = type_mgr->GetType(inst->type_id());
  135. const analysis::Pointer* result_type = pType->AsPointer();
  136. if (result_type == nullptr) {
  137. return false;
  138. }
  139. return (result_type->storage_class() == storage_class);
  140. }
  141. bool FixStorageClass::ChangeResultType(Instruction* inst,
  142. uint32_t new_type_id) {
  143. if (inst->type_id() == new_type_id) {
  144. return false;
  145. }
  146. context()->ForgetUses(inst);
  147. inst->SetResultType(new_type_id);
  148. context()->AnalyzeUses(inst);
  149. return true;
  150. }
  151. bool FixStorageClass::PropagateType(Instruction* inst, uint32_t type_id,
  152. uint32_t op_idx, std::set<uint32_t>* seen) {
  153. assert(type_id != 0 && "Not given a valid type in PropagateType");
  154. bool modified = false;
  155. // If the type of operand |op_idx| forces the result type of |inst| to a
  156. // particular type, then we want find that type.
  157. uint32_t new_type_id = 0;
  158. switch (inst->opcode()) {
  159. case SpvOpAccessChain:
  160. case SpvOpPtrAccessChain:
  161. case SpvOpInBoundsAccessChain:
  162. case SpvOpInBoundsPtrAccessChain:
  163. if (op_idx == 2) {
  164. new_type_id = WalkAccessChainType(inst, type_id);
  165. }
  166. break;
  167. case SpvOpCopyObject:
  168. new_type_id = type_id;
  169. break;
  170. case SpvOpPhi:
  171. if (seen->insert(inst->result_id()).second) {
  172. new_type_id = type_id;
  173. }
  174. break;
  175. case SpvOpSelect:
  176. if (op_idx > 2) {
  177. new_type_id = type_id;
  178. }
  179. break;
  180. case SpvOpFunctionCall:
  181. // We cannot be sure of the actual connection between the type
  182. // of the parameter and the type of the result, so we should not
  183. // do anything. If the result type needs to be fixed, the function call
  184. // should be inlined.
  185. return false;
  186. case SpvOpLoad: {
  187. Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
  188. new_type_id = type_inst->GetSingleWordInOperand(1);
  189. break;
  190. }
  191. case SpvOpStore: {
  192. uint32_t obj_id = inst->GetSingleWordInOperand(1);
  193. Instruction* obj_inst = get_def_use_mgr()->GetDef(obj_id);
  194. uint32_t obj_type_id = obj_inst->type_id();
  195. uint32_t ptr_id = inst->GetSingleWordInOperand(0);
  196. Instruction* ptr_inst = get_def_use_mgr()->GetDef(ptr_id);
  197. uint32_t pointee_type_id = GetPointeeTypeId(ptr_inst);
  198. if (obj_type_id != pointee_type_id) {
  199. if (context()->get_type_mgr()->GetType(obj_type_id)->AsImage() &&
  200. context()->get_type_mgr()->GetType(pointee_type_id)->AsImage()) {
  201. // When storing an image, allow the type mismatch
  202. // and let the later legalization passes eliminate the OpStore.
  203. // This is to support assigning an image to a variable,
  204. // where the assigned image does not have a pre-defined
  205. // image format.
  206. return false;
  207. }
  208. uint32_t copy_id = GenerateCopy(obj_inst, pointee_type_id, inst);
  209. inst->SetInOperand(1, {copy_id});
  210. context()->UpdateDefUse(inst);
  211. }
  212. } break;
  213. case SpvOpCopyMemory:
  214. case SpvOpCopyMemorySized:
  215. // TODO: May need to expand the copy as we do with the stores.
  216. break;
  217. case SpvOpCompositeConstruct:
  218. case SpvOpCompositeExtract:
  219. case SpvOpCompositeInsert:
  220. // TODO: DXC does not seem to generate code that will require changes to
  221. // these opcode. The can be implemented when they come up.
  222. break;
  223. case SpvOpImageTexelPointer:
  224. case SpvOpBitcast:
  225. // Nothing to change for these opcode. The result type is the same
  226. // regardless of the type of the operand.
  227. return false;
  228. default:
  229. // I expect the remaining instructions to act on types that are guaranteed
  230. // to be unique, so no change will be necessary.
  231. break;
  232. }
  233. // If the operand forces the result type, then make sure the result type
  234. // matches, and update the uses of |inst|. We do not have to check the uses
  235. // of |inst| in the result type is not forced because we are only looking for
  236. // issue that come from mismatches between function formal and actual
  237. // parameters after the function has been inlined. These parameters are
  238. // pointers. Once the type no longer depends on the type of the parameter,
  239. // then the types should have be correct.
  240. if (new_type_id != 0) {
  241. modified = ChangeResultType(inst, new_type_id);
  242. std::vector<std::pair<Instruction*, uint32_t>> uses;
  243. get_def_use_mgr()->ForEachUse(inst,
  244. [&uses](Instruction* use, uint32_t idx) {
  245. uses.push_back({use, idx});
  246. });
  247. for (auto& use : uses) {
  248. PropagateType(use.first, new_type_id, use.second, seen);
  249. }
  250. if (inst->opcode() == SpvOpPhi) {
  251. seen->erase(inst->result_id());
  252. }
  253. }
  254. return modified;
  255. }
  256. uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) {
  257. uint32_t start_idx = 0;
  258. switch (inst->opcode()) {
  259. case SpvOpAccessChain:
  260. case SpvOpInBoundsAccessChain:
  261. start_idx = 1;
  262. break;
  263. case SpvOpPtrAccessChain:
  264. case SpvOpInBoundsPtrAccessChain:
  265. start_idx = 2;
  266. break;
  267. default:
  268. assert(false);
  269. break;
  270. }
  271. Instruction* orig_type_inst = get_def_use_mgr()->GetDef(id);
  272. assert(orig_type_inst->opcode() == SpvOpTypePointer);
  273. id = orig_type_inst->GetSingleWordInOperand(1);
  274. for (uint32_t i = start_idx; i < inst->NumInOperands(); ++i) {
  275. Instruction* type_inst = get_def_use_mgr()->GetDef(id);
  276. switch (type_inst->opcode()) {
  277. case SpvOpTypeArray:
  278. case SpvOpTypeRuntimeArray:
  279. case SpvOpTypeMatrix:
  280. case SpvOpTypeVector:
  281. id = type_inst->GetSingleWordInOperand(0);
  282. break;
  283. case SpvOpTypeStruct: {
  284. const analysis::Constant* index_const =
  285. context()->get_constant_mgr()->FindDeclaredConstant(
  286. inst->GetSingleWordInOperand(i));
  287. uint32_t index = index_const->GetU32();
  288. id = type_inst->GetSingleWordInOperand(index);
  289. break;
  290. }
  291. default:
  292. break;
  293. }
  294. assert(id != 0 &&
  295. "Tried to extract from an object where it cannot be done.");
  296. }
  297. return context()->get_type_mgr()->FindPointerToType(
  298. id,
  299. static_cast<SpvStorageClass>(orig_type_inst->GetSingleWordInOperand(0)));
  300. }
  301. // namespace opt
  302. } // namespace opt
  303. } // namespace spvtools