name_mapper.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. // Copyright (c) 2016 Google Inc.
  2. // Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
  3. // reserved.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. #include "source/name_mapper.h"
  17. #include <algorithm>
  18. #include <cassert>
  19. #include <iterator>
  20. #include <sstream>
  21. #include <string>
  22. #include <unordered_map>
  23. #include <unordered_set>
  24. #include "source/binary.h"
  25. #include "source/latest_version_spirv_header.h"
  26. #include "source/parsed_operand.h"
  27. #include "source/table2.h"
  28. #include "source/to_string.h"
  29. #include "spirv-tools/libspirv.h"
  30. namespace spvtools {
  31. NameMapper GetTrivialNameMapper() {
  32. return [](uint32_t i) { return spvtools::to_string(i); };
  33. }
  34. FriendlyNameMapper::FriendlyNameMapper(const spv_const_context context,
  35. const uint32_t* code,
  36. const size_t wordCount)
  37. : grammar_(AssemblyGrammar(context)) {
  38. spv_diagnostic diag = nullptr;
  39. // We don't care if the parse fails.
  40. spvBinaryParse(context, this, code, wordCount, nullptr,
  41. ParseInstructionForwarder, &diag);
  42. spvDiagnosticDestroy(diag);
  43. }
  44. std::string FriendlyNameMapper::NameForId(uint32_t id) {
  45. auto iter = name_for_id_.find(id);
  46. if (iter == name_for_id_.end()) {
  47. // It must have been an invalid module, so just return a trivial mapping.
  48. // We don't care about uniqueness.
  49. return to_string(id);
  50. } else {
  51. return iter->second;
  52. }
  53. }
  54. std::string FriendlyNameMapper::Sanitize(const std::string& suggested_name) {
  55. if (suggested_name.empty()) return "_";
  56. // Otherwise, replace invalid characters by '_'.
  57. std::string result;
  58. std::string valid =
  59. "abcdefghijklmnopqrstuvwxyz"
  60. "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
  61. "_0123456789";
  62. std::transform(suggested_name.begin(), suggested_name.end(),
  63. std::back_inserter(result), [&valid](const char c) {
  64. return (std::string::npos == valid.find(c)) ? '_' : c;
  65. });
  66. return result;
  67. }
  68. void FriendlyNameMapper::SaveName(uint32_t id,
  69. const std::string& suggested_name) {
  70. if (name_for_id_.find(id) != name_for_id_.end()) return;
  71. const std::string sanitized_suggested_name = Sanitize(suggested_name);
  72. std::string name = sanitized_suggested_name;
  73. auto inserted = used_names_.insert(name);
  74. if (!inserted.second) {
  75. const std::string base_name = sanitized_suggested_name + "_";
  76. for (uint32_t index = 0; !inserted.second; ++index) {
  77. name = base_name + to_string(index);
  78. inserted = used_names_.insert(name);
  79. }
  80. }
  81. name_for_id_[id] = name;
  82. }
  83. void FriendlyNameMapper::SaveBuiltInName(uint32_t target_id,
  84. uint32_t built_in) {
  85. #define GLCASE(name) \
  86. case spv::BuiltIn::name: \
  87. SaveName(target_id, "gl_" #name); \
  88. return;
  89. #define GLCASE2(name, suggested) \
  90. case spv::BuiltIn::name: \
  91. SaveName(target_id, "gl_" #suggested); \
  92. return;
  93. #define CASE(name) \
  94. case spv::BuiltIn::name: \
  95. SaveName(target_id, #name); \
  96. return;
  97. switch (spv::BuiltIn(built_in)) {
  98. GLCASE(Position)
  99. GLCASE(PointSize)
  100. GLCASE(ClipDistance)
  101. GLCASE(CullDistance)
  102. GLCASE2(VertexId, VertexID)
  103. GLCASE2(InstanceId, InstanceID)
  104. GLCASE2(PrimitiveId, PrimitiveID)
  105. GLCASE2(InvocationId, InvocationID)
  106. GLCASE(Layer)
  107. GLCASE(ViewportIndex)
  108. GLCASE(TessLevelOuter)
  109. GLCASE(TessLevelInner)
  110. GLCASE(TessCoord)
  111. GLCASE(PatchVertices)
  112. GLCASE(FragCoord)
  113. GLCASE(PointCoord)
  114. GLCASE(FrontFacing)
  115. GLCASE2(SampleId, SampleID)
  116. GLCASE(SamplePosition)
  117. GLCASE(SampleMask)
  118. GLCASE(FragDepth)
  119. GLCASE(HelperInvocation)
  120. GLCASE2(NumWorkgroups, NumWorkGroups)
  121. GLCASE2(WorkgroupSize, WorkGroupSize)
  122. GLCASE2(WorkgroupId, WorkGroupID)
  123. GLCASE2(LocalInvocationId, LocalInvocationID)
  124. GLCASE2(GlobalInvocationId, GlobalInvocationID)
  125. GLCASE(LocalInvocationIndex)
  126. CASE(WorkDim)
  127. CASE(GlobalSize)
  128. CASE(EnqueuedWorkgroupSize)
  129. CASE(GlobalOffset)
  130. CASE(GlobalLinearId)
  131. CASE(SubgroupSize)
  132. CASE(SubgroupMaxSize)
  133. CASE(NumSubgroups)
  134. CASE(NumEnqueuedSubgroups)
  135. CASE(SubgroupId)
  136. CASE(SubgroupLocalInvocationId)
  137. GLCASE(VertexIndex)
  138. GLCASE(InstanceIndex)
  139. GLCASE(BaseInstance)
  140. CASE(SubgroupEqMaskKHR)
  141. CASE(SubgroupGeMaskKHR)
  142. CASE(SubgroupGtMaskKHR)
  143. CASE(SubgroupLeMaskKHR)
  144. CASE(SubgroupLtMaskKHR)
  145. default:
  146. break;
  147. }
  148. #undef GLCASE
  149. #undef GLCASE2
  150. #undef CASE
  151. }
// Records a friendly-name suggestion for whatever `inst` names or defines.
// This is called for each instruction while the constructor's spvBinaryParse
// call runs (presumably through ParseInstructionForwarder, which is declared
// elsewhere — confirm in the header). SaveName keeps the first name
// registered for an id, so instructions seen earlier take precedence.
//
// Always returns SPV_SUCCESS, so it never stops the parse.
spv_result_t FriendlyNameMapper::ParseInstruction(
    const spv_parsed_instruction_t& inst) {
  const auto result_id = inst.result_id;
  switch (spv::Op(inst.opcode)) {
    case spv::Op::OpName:
      // words[1] is the target id; operand 1 is the literal name string.
      SaveName(inst.words[1], spvDecodeLiteralStringOperand(inst, 1));
      break;
    case spv::Op::OpDecorate:
      // Decorations come after OpName. So OpName will take precedence over
      // decorations.
      //
      // In theory, we should also handle OpGroupDecorate. But that's unlikely
      // to occur.
      if (spv::Decoration(inst.words[2]) == spv::Decoration::BuiltIn) {
        // words[3] is the BuiltIn enumerant; the assert is debug-only.
        assert(inst.num_words > 3);
        SaveBuiltInName(inst.words[1], inst.words[3]);
      }
      break;
    case spv::Op::OpTypeVoid:
      SaveName(result_id, "void");
      break;
    case spv::Op::OpTypeBool:
      SaveName(result_id, "bool");
      break;
    case spv::Op::OpTypeInt: {
      // Widths 8/16/32/64 get C-like names ("char", "short", "int", "long"),
      // prefixed with "u" when unsigned. Any other width N becomes "iN" or
      // "uN". words[2] is the width; words[3] == 0 means unsigned.
      std::string signedness;
      std::string root;
      const auto bit_width = inst.words[2];
      switch (bit_width) {
        case 8:
          root = "char";
          break;
        case 16:
          root = "short";
          break;
        case 32:
          root = "int";
          break;
        case 64:
          root = "long";
          break;
        default:
          root = to_string(bit_width);
          signedness = "i";
          break;
      }
      if (0 == inst.words[3]) signedness = "u";
      SaveName(result_id, signedness + root);
    } break;
    case spv::Op::OpTypeFloat: {
      // If the optional FP encoding operand (words[3]) names one of the
      // encodings below, it determines the name. Otherwise the width does:
      // "half"/"float"/"double", or "fpN" for any other width N.
      const auto bit_width = inst.words[2];
      if (inst.num_words > 3) {
        if (spv::FPEncoding(inst.words[3]) == spv::FPEncoding::BFloat16KHR) {
          SaveName(result_id, "bfloat16");
          break;
        }
        if (spv::FPEncoding(inst.words[3]) == spv::FPEncoding::Float8E4M3EXT) {
          SaveName(result_id, "fp8e4m3");
          break;
        }
        if (spv::FPEncoding(inst.words[3]) == spv::FPEncoding::Float8E5M2EXT) {
          SaveName(result_id, "fp8e5m2");
          break;
        }
      }
      switch (bit_width) {
        case 16:
          SaveName(result_id, "half");
          break;
        case 32:
          SaveName(result_id, "float");
          break;
        case 64:
          SaveName(result_id, "double");
          break;
        default:
          SaveName(result_id, std::string("fp") + to_string(bit_width));
          break;
      }
    } break;
    case spv::Op::OpTypeVector:
      // "v" + component count (words[3]) + component type name (words[2]),
      // e.g. "v4float".
      SaveName(result_id, std::string("v") + to_string(inst.words[3]) +
                              NameForId(inst.words[2]));
      break;
    case spv::Op::OpTypeMatrix:
      // "mat" + column count (words[3]) + column type name (words[2]).
      SaveName(result_id, std::string("mat") + to_string(inst.words[3]) +
                              NameForId(inst.words[2]));
      break;
    case spv::Op::OpTypeArray:
      // Element type name, then the name of the length id (words[3]).
      SaveName(result_id, std::string("_arr_") + NameForId(inst.words[2]) +
                              "_" + NameForId(inst.words[3]));
      break;
    case spv::Op::OpTypeRuntimeArray:
      SaveName(result_id,
               std::string("_runtimearr_") + NameForId(inst.words[2]));
      break;
    case spv::Op::OpTypeNodePayloadArrayAMDX:
      SaveName(result_id,
               std::string("_payloadarr_") + NameForId(inst.words[2]));
      break;
    case spv::Op::OpTypePointer:
      // Storage class name (words[2]) then pointee type name (words[3]).
      SaveName(result_id, std::string("_ptr_") +
                              NameForEnumOperand(SPV_OPERAND_TYPE_STORAGE_CLASS,
                                                 inst.words[2]) +
                              "_" + NameForId(inst.words[3]));
      break;
    case spv::Op::OpTypeUntypedPointerKHR:
      // Untyped pointers have no pointee type, only a storage class.
      SaveName(result_id, std::string("_ptr_") +
                              NameForEnumOperand(SPV_OPERAND_TYPE_STORAGE_CLASS,
                                                 inst.words[2]));
      break;
    case spv::Op::OpTypePipe:
      SaveName(result_id,
               std::string("Pipe") +
                   NameForEnumOperand(SPV_OPERAND_TYPE_ACCESS_QUALIFIER,
                                      inst.words[2]));
      break;
    case spv::Op::OpTypeEvent:
      SaveName(result_id, "Event");
      break;
    case spv::Op::OpTypeDeviceEvent:
      SaveName(result_id, "DeviceEvent");
      break;
    case spv::Op::OpTypeReserveId:
      SaveName(result_id, "ReserveId");
      break;
    case spv::Op::OpTypeQueue:
      SaveName(result_id, "Queue");
      break;
    case spv::Op::OpTypeOpaque:
      SaveName(result_id, std::string("Opaque_") +
                              Sanitize(spvDecodeLiteralStringOperand(inst, 1)));
      break;
    case spv::Op::OpTypePipeStorage:
      SaveName(result_id, "PipeStorage");
      break;
    case spv::Op::OpTypeNamedBarrier:
      SaveName(result_id, "NamedBarrier");
      break;
    case spv::Op::OpTypeStruct:
      // Structs are mapped rather simplistically. Just indicate that they
      // are a struct and then give the raw Id number.
      SaveName(result_id, std::string("_struct_") + to_string(result_id));
      break;
    case spv::Op::OpConstantTrue:
      SaveName(result_id, "true");
      break;
    case spv::Op::OpConstantFalse:
      SaveName(result_id, "false");
      break;
    case spv::Op::OpConstant: {
      // Name is "<type name>_<literal text>". operands[2] is the value
      // operand; its text comes from EmitNumericLiteral.
      std::ostringstream value;
      EmitNumericLiteral(&value, inst, inst.operands[2]);
      auto value_str = value.str();
      // Use 'n' to signify negative. Other invalid characters will be mapped
      // to underscore.
      for (auto& c : value_str)
        if (c == '-') c = 'n';
      SaveName(result_id, NameForId(inst.type_id) + "_" + value_str);
    } break;
    default:
      // If this instruction otherwise defines an Id, then save a mapping for
      // it. This is needed to ensure uniqueness in case there is an OpName
      // with a string like "1" that might collide with this result_id.
      // We should only do this if a name hasn't already been registered by some
      // previous forward reference.
      if (result_id && name_for_id_.find(result_id) == name_for_id_.end())
        SaveName(result_id, to_string(result_id));
      break;
  }
  return SPV_SUCCESS;
}
  324. std::string FriendlyNameMapper::NameForEnumOperand(spv_operand_type_t type,
  325. uint32_t word) {
  326. const spvtools::OperandDesc* desc = nullptr;
  327. if (SPV_SUCCESS == spvtools::LookupOperand(type, word, &desc)) {
  328. return desc->name().data();
  329. } else {
  330. // Invalid input. Just give something.
  331. return std::string("StorageClass") + to_string(word);
  332. }
  333. }
  334. } // namespace spvtools