set_spec_constant_default_value_pass.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. // Copyright (c) 2016 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/set_spec_constant_default_value_pass.h"
  15. #include <algorithm>
  16. #include <cctype>
  17. #include <cstring>
  18. #include <tuple>
  19. #include <vector>
  20. #include "source/opt/def_use_manager.h"
  21. #include "source/opt/ir_context.h"
  22. #include "source/opt/type_manager.h"
  23. #include "source/opt/types.h"
  24. #include "source/util/make_unique.h"
  25. #include "source/util/parse_number.h"
  26. #include "spirv-tools/libspirv.h"
  27. namespace spvtools {
  28. namespace opt {
  29. namespace {
  30. using utils::EncodeNumberStatus;
  31. using utils::NumberType;
  32. using utils::ParseAndEncodeNumber;
  33. using utils::ParseNumber;
  34. // Given a numeric value in a null-terminated c string and the expected type of
  35. // the value, parses the string and encodes it in a vector of words. If the
  36. // value is a scalar integer or floating point value, encodes the value in
  37. // SPIR-V encoding format. If the value is 'false' or 'true', returns a vector
  38. // with single word with value 0 or 1 respectively. Returns the vector
  39. // containing the encoded value on success. Otherwise returns an empty vector.
  40. std::vector<uint32_t> ParseDefaultValueStr(const char* text,
  41. const analysis::Type* type) {
  42. std::vector<uint32_t> result;
  43. if (!strcmp(text, "true") && type->AsBool()) {
  44. result.push_back(1u);
  45. } else if (!strcmp(text, "false") && type->AsBool()) {
  46. result.push_back(0u);
  47. } else {
  48. NumberType number_type = {32, SPV_NUMBER_UNSIGNED_INT};
  49. if (const auto* IT = type->AsInteger()) {
  50. number_type.bitwidth = IT->width();
  51. number_type.kind =
  52. IT->IsSigned() ? SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT;
  53. } else if (const auto* FT = type->AsFloat()) {
  54. number_type.bitwidth = FT->width();
  55. number_type.kind = SPV_NUMBER_FLOATING;
  56. } else {
  57. // Does not handle types other then boolean, integer or float. Returns
  58. // empty vector.
  59. result.clear();
  60. return result;
  61. }
  62. EncodeNumberStatus rc = ParseAndEncodeNumber(
  63. text, number_type, [&result](uint32_t word) { result.push_back(word); },
  64. nullptr);
  65. // Clear the result vector on failure.
  66. if (rc != EncodeNumberStatus::kSuccess) {
  67. result.clear();
  68. }
  69. }
  70. return result;
  71. }
  72. // Given a bit pattern and a type, checks if the bit pattern is compatible
  73. // with the type. If so, returns the bit pattern, otherwise returns an empty
  74. // bit pattern. If the given bit pattern is empty, returns an empty bit
  75. // pattern. If the given type represents a SPIR-V Boolean type, the bit pattern
  76. // to be returned is determined with the following standard:
  77. // If any words in the input bit pattern are non zero, returns a bit pattern
  78. // with 0x1, which represents a 'true'.
  79. // If all words in the bit pattern are zero, returns a bit pattern with 0x0,
  80. // which represents a 'false'.
  81. // For integer and floating point types narrower than 32 bits, the upper bits
  82. // in the input bit pattern are ignored. Instead the upper bits are set
  83. // according to SPIR-V literal requirements: sign extend a signed integer, and
  84. // otherwise set the upper bits to zero.
  85. std::vector<uint32_t> ParseDefaultValueBitPattern(
  86. const std::vector<uint32_t>& input_bit_pattern,
  87. const analysis::Type* type) {
  88. std::vector<uint32_t> result;
  89. if (type->AsBool()) {
  90. if (std::any_of(input_bit_pattern.begin(), input_bit_pattern.end(),
  91. [](uint32_t i) { return i != 0; })) {
  92. result.push_back(1u);
  93. } else {
  94. result.push_back(0u);
  95. }
  96. return result;
  97. } else if (const auto* IT = type->AsInteger()) {
  98. const auto width = IT->width();
  99. assert(width > 0);
  100. const auto adjusted_width = std::max(32u, width);
  101. if (adjusted_width == input_bit_pattern.size() * sizeof(uint32_t) * 8) {
  102. result = std::vector<uint32_t>(input_bit_pattern);
  103. if (width < 32) {
  104. const uint32_t high_active_bit = (1u << width) >> 1;
  105. if (IT->IsSigned() && (high_active_bit & result[0])) {
  106. // Sign extend. This overwrites the sign bit again, but that's ok.
  107. result[0] = result[0] | ~(high_active_bit - 1);
  108. } else {
  109. // Upper bits must be zero.
  110. result[0] = result[0] & ((1u << width) - 1);
  111. }
  112. }
  113. return result;
  114. }
  115. } else if (const auto* FT = type->AsFloat()) {
  116. const auto width = FT->width();
  117. const auto adjusted_width = std::max(32u, width);
  118. if (adjusted_width == input_bit_pattern.size() * sizeof(uint32_t) * 8) {
  119. result = std::vector<uint32_t>(input_bit_pattern);
  120. if (width < 32) {
  121. // Upper bits must be zero.
  122. result[0] = result[0] & ((1u << width) - 1);
  123. }
  124. return result;
  125. }
  126. }
  127. result.clear();
  128. return result;
  129. }
  130. // Returns true if the given instruction's result id could have a SpecId
  131. // decoration.
  132. bool CanHaveSpecIdDecoration(const Instruction& inst) {
  133. switch (inst.opcode()) {
  134. case SpvOp::SpvOpSpecConstant:
  135. case SpvOp::SpvOpSpecConstantFalse:
  136. case SpvOp::SpvOpSpecConstantTrue:
  137. return true;
  138. default:
  139. return false;
  140. }
  141. }
  142. // Given a decoration group defining instruction that is decorated with SpecId
  143. // decoration, finds the spec constant defining instruction which is the real
  144. // target of the SpecId decoration. Returns the spec constant defining
  145. // instruction if such an instruction is found, otherwise returns a nullptr.
  146. Instruction* GetSpecIdTargetFromDecorationGroup(
  147. const Instruction& decoration_group_defining_inst,
  148. analysis::DefUseManager* def_use_mgr) {
  149. // Find the OpGroupDecorate instruction which consumes the given decoration
  150. // group. Note that the given decoration group has SpecId decoration, which
  151. // is unique for different spec constants. So the decoration group cannot be
  152. // consumed by different OpGroupDecorate instructions. Therefore we only need
  153. // the first OpGroupDecoration instruction that uses the given decoration
  154. // group.
  155. Instruction* group_decorate_inst = nullptr;
  156. if (def_use_mgr->WhileEachUser(&decoration_group_defining_inst,
  157. [&group_decorate_inst](Instruction* user) {
  158. if (user->opcode() ==
  159. SpvOp::SpvOpGroupDecorate) {
  160. group_decorate_inst = user;
  161. return false;
  162. }
  163. return true;
  164. }))
  165. return nullptr;
  166. // Scan through the target ids of the OpGroupDecorate instruction. There
  167. // should be only one spec constant target consumes the SpecId decoration.
  168. // If multiple target ids are presented in the OpGroupDecorate instruction,
  169. // they must be the same one that defined by an eligible spec constant
  170. // instruction. If the OpGroupDecorate instruction has different target ids
  171. // or a target id is not defined by an eligible spec cosntant instruction,
  172. // returns a nullptr.
  173. Instruction* target_inst = nullptr;
  174. for (uint32_t i = 1; i < group_decorate_inst->NumInOperands(); i++) {
  175. // All the operands of a OpGroupDecorate instruction should be of type
  176. // SPV_OPERAND_TYPE_ID.
  177. uint32_t candidate_id = group_decorate_inst->GetSingleWordInOperand(i);
  178. Instruction* candidate_inst = def_use_mgr->GetDef(candidate_id);
  179. if (!candidate_inst) {
  180. continue;
  181. }
  182. if (!target_inst) {
  183. // If the spec constant target has not been found yet, check if the
  184. // candidate instruction is the target.
  185. if (CanHaveSpecIdDecoration(*candidate_inst)) {
  186. target_inst = candidate_inst;
  187. } else {
  188. // Spec id decoration should not be applied on other instructions.
  189. // TODO(qining): Emit an error message in the invalid case once the
  190. // error handling is done.
  191. return nullptr;
  192. }
  193. } else {
  194. // If the spec constant target has been found, check if the candidate
  195. // instruction is the same one as the target. The module is invalid if
  196. // the candidate instruction is different with the found target.
  197. // TODO(qining): Emit an error messaage in the invalid case once the
  198. // error handling is done.
  199. if (candidate_inst != target_inst) return nullptr;
  200. }
  201. }
  202. return target_inst;
  203. }
  204. } // namespace
  205. Pass::Status SetSpecConstantDefaultValuePass::Process() {
  206. // The operand index of decoration target in an OpDecorate instruction.
  207. const uint32_t kTargetIdOperandIndex = 0;
  208. // The operand index of the decoration literal in an OpDecorate instruction.
  209. const uint32_t kDecorationOperandIndex = 1;
  210. // The operand index of Spec id literal value in an OpDecorate SpecId
  211. // instruction.
  212. const uint32_t kSpecIdLiteralOperandIndex = 2;
  213. // The number of operands in an OpDecorate SpecId instruction.
  214. const uint32_t kOpDecorateSpecIdNumOperands = 3;
  215. // The in-operand index of the default value in a OpSpecConstant instruction.
  216. const uint32_t kOpSpecConstantLiteralInOperandIndex = 0;
  217. bool modified = false;
  218. // Scan through all the annotation instructions to find 'OpDecorate SpecId'
  219. // instructions. Then extract the decoration target of those instructions.
  220. // The decoration targets should be spec constant defining instructions with
  221. // opcode: OpSpecConstant{|True|False}. The spec id of those spec constants
  222. // will be used to look up their new default values in the mapping from
  223. // spec id to new default value strings. Once a new default value string
  224. // is found for a spec id, the string will be parsed according to the target
  225. // spec constant type. The parsed value will be used to replace the original
  226. // default value of the target spec constant.
  227. for (Instruction& inst : context()->annotations()) {
  228. // Only process 'OpDecorate SpecId' instructions
  229. if (inst.opcode() != SpvOp::SpvOpDecorate) continue;
  230. if (inst.NumOperands() != kOpDecorateSpecIdNumOperands) continue;
  231. if (inst.GetSingleWordInOperand(kDecorationOperandIndex) !=
  232. uint32_t(SpvDecoration::SpvDecorationSpecId)) {
  233. continue;
  234. }
  235. // 'inst' is an OpDecorate SpecId instruction.
  236. uint32_t spec_id = inst.GetSingleWordOperand(kSpecIdLiteralOperandIndex);
  237. uint32_t target_id = inst.GetSingleWordOperand(kTargetIdOperandIndex);
  238. // Find the spec constant defining instruction. Note that the
  239. // target_id might be a decoration group id.
  240. Instruction* spec_inst = nullptr;
  241. if (Instruction* target_inst = get_def_use_mgr()->GetDef(target_id)) {
  242. if (target_inst->opcode() == SpvOp::SpvOpDecorationGroup) {
  243. spec_inst =
  244. GetSpecIdTargetFromDecorationGroup(*target_inst, get_def_use_mgr());
  245. } else {
  246. spec_inst = target_inst;
  247. }
  248. } else {
  249. continue;
  250. }
  251. if (!spec_inst) continue;
  252. // Get the default value bit pattern for this spec id.
  253. std::vector<uint32_t> bit_pattern;
  254. if (spec_id_to_value_str_.size() != 0) {
  255. // Search for the new string-form default value for this spec id.
  256. auto iter = spec_id_to_value_str_.find(spec_id);
  257. if (iter == spec_id_to_value_str_.end()) {
  258. continue;
  259. }
  260. // Gets the string of the default value and parses it to bit pattern
  261. // with the type of the spec constant.
  262. const std::string& default_value_str = iter->second;
  263. bit_pattern = ParseDefaultValueStr(
  264. default_value_str.c_str(),
  265. context()->get_type_mgr()->GetType(spec_inst->type_id()));
  266. } else {
  267. // Search for the new bit-pattern-form default value for this spec id.
  268. auto iter = spec_id_to_value_bit_pattern_.find(spec_id);
  269. if (iter == spec_id_to_value_bit_pattern_.end()) {
  270. continue;
  271. }
  272. // Gets the bit-pattern of the default value from the map directly.
  273. bit_pattern = ParseDefaultValueBitPattern(
  274. iter->second,
  275. context()->get_type_mgr()->GetType(spec_inst->type_id()));
  276. }
  277. if (bit_pattern.empty()) continue;
  278. // Update the operand bit patterns of the spec constant defining
  279. // instruction.
  280. switch (spec_inst->opcode()) {
  281. case SpvOp::SpvOpSpecConstant:
  282. // If the new value is the same with the original value, no
  283. // need to do anything. Otherwise update the operand words.
  284. if (spec_inst->GetInOperand(kOpSpecConstantLiteralInOperandIndex)
  285. .words != bit_pattern) {
  286. spec_inst->SetInOperand(kOpSpecConstantLiteralInOperandIndex,
  287. std::move(bit_pattern));
  288. modified = true;
  289. }
  290. break;
  291. case SpvOp::SpvOpSpecConstantTrue:
  292. // If the new value is also 'true', no need to change anything.
  293. // Otherwise, set the opcode to OpSpecConstantFalse;
  294. if (!static_cast<bool>(bit_pattern.front())) {
  295. spec_inst->SetOpcode(SpvOp::SpvOpSpecConstantFalse);
  296. modified = true;
  297. }
  298. break;
  299. case SpvOp::SpvOpSpecConstantFalse:
  300. // If the new value is also 'false', no need to change anything.
  301. // Otherwise, set the opcode to OpSpecConstantTrue;
  302. if (static_cast<bool>(bit_pattern.front())) {
  303. spec_inst->SetOpcode(SpvOp::SpvOpSpecConstantTrue);
  304. modified = true;
  305. }
  306. break;
  307. default:
  308. break;
  309. }
  310. // No need to update the DefUse manager, as this pass does not change any
  311. // ids.
  312. }
  313. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  314. }
  315. // Returns true if the given char is ':', '\0' or considered as blank space
  316. // (i.e.: '\n', '\r', '\v', '\t', '\f' and ' ').
  317. bool IsSeparator(char ch) {
  318. return std::strchr(":\0", ch) || std::isspace(ch) != 0;
  319. }
  320. std::unique_ptr<SetSpecConstantDefaultValuePass::SpecIdToValueStrMap>
  321. SetSpecConstantDefaultValuePass::ParseDefaultValuesString(const char* str) {
  322. if (!str) return nullptr;
  323. auto spec_id_to_value = MakeUnique<SpecIdToValueStrMap>();
  324. // The parsing loop, break when points to the end.
  325. while (*str) {
  326. // Find the spec id.
  327. while (std::isspace(*str)) str++; // skip leading spaces.
  328. const char* entry_begin = str;
  329. while (!IsSeparator(*str)) str++;
  330. const char* entry_end = str;
  331. std::string spec_id_str(entry_begin, entry_end - entry_begin);
  332. uint32_t spec_id = 0;
  333. if (!ParseNumber(spec_id_str.c_str(), &spec_id)) {
  334. // The spec id is not a valid uint32 number.
  335. return nullptr;
  336. }
  337. auto iter = spec_id_to_value->find(spec_id);
  338. if (iter != spec_id_to_value->end()) {
  339. // Same spec id has been defined before
  340. return nullptr;
  341. }
  342. // Find the ':', spaces between the spec id and the ':' are not allowed.
  343. if (*str++ != ':') {
  344. // ':' not found
  345. return nullptr;
  346. }
  347. // Find the value string
  348. const char* val_begin = str;
  349. while (!IsSeparator(*str)) str++;
  350. const char* val_end = str;
  351. if (val_end == val_begin) {
  352. // Value string is empty.
  353. return nullptr;
  354. }
  355. // Update the mapping with spec id and value string.
  356. (*spec_id_to_value)[spec_id] = std::string(val_begin, val_end - val_begin);
  357. // Skip trailing spaces.
  358. while (std::isspace(*str)) str++;
  359. }
  360. return spec_id_to_value;
  361. }
  362. } // namespace opt
  363. } // namespace spvtools