transformation_equation_instruction.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. // Copyright (c) 2020 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/fuzz/transformation_equation_instruction.h"
  15. #include "source/fuzz/fuzzer_util.h"
  16. #include "source/fuzz/instruction_descriptor.h"
  17. namespace spvtools {
  18. namespace fuzz {
  19. TransformationEquationInstruction::TransformationEquationInstruction(
  20. const spvtools::fuzz::protobufs::TransformationEquationInstruction& message)
  21. : message_(message) {}
  22. TransformationEquationInstruction::TransformationEquationInstruction(
  23. uint32_t fresh_id, SpvOp opcode, const std::vector<uint32_t>& in_operand_id,
  24. const protobufs::InstructionDescriptor& instruction_to_insert_before) {
  25. message_.set_fresh_id(fresh_id);
  26. message_.set_opcode(opcode);
  27. for (auto id : in_operand_id) {
  28. message_.add_in_operand_id(id);
  29. }
  30. *message_.mutable_instruction_to_insert_before() =
  31. instruction_to_insert_before;
  32. }
  33. bool TransformationEquationInstruction::IsApplicable(
  34. opt::IRContext* ir_context,
  35. const TransformationContext& transformation_context) const {
  36. // The result id must be fresh.
  37. if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) {
  38. return false;
  39. }
  40. // The instruction to insert before must exist.
  41. auto insert_before =
  42. FindInstruction(message_.instruction_to_insert_before(), ir_context);
  43. if (!insert_before) {
  44. return false;
  45. }
  46. // The input ids must all exist, not be OpUndef, not be irrelevant, and be
  47. // available before this instruction.
  48. for (auto id : message_.in_operand_id()) {
  49. auto inst = ir_context->get_def_use_mgr()->GetDef(id);
  50. if (!inst) {
  51. return false;
  52. }
  53. if (inst->opcode() == SpvOpUndef) {
  54. return false;
  55. }
  56. if (transformation_context.GetFactManager()->IdIsIrrelevant(id)) {
  57. return false;
  58. }
  59. if (!fuzzerutil::IdIsAvailableBeforeInstruction(ir_context, insert_before,
  60. id)) {
  61. return false;
  62. }
  63. }
  64. return MaybeGetResultTypeId(ir_context) != 0;
  65. }
  66. void TransformationEquationInstruction::Apply(
  67. opt::IRContext* ir_context,
  68. TransformationContext* transformation_context) const {
  69. fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id());
  70. opt::Instruction::OperandList in_operands;
  71. std::vector<uint32_t> rhs_id;
  72. for (auto id : message_.in_operand_id()) {
  73. in_operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
  74. rhs_id.push_back(id);
  75. }
  76. FindInstruction(message_.instruction_to_insert_before(), ir_context)
  77. ->InsertBefore(MakeUnique<opt::Instruction>(
  78. ir_context, static_cast<SpvOp>(message_.opcode()),
  79. MaybeGetResultTypeId(ir_context), message_.fresh_id(),
  80. std::move(in_operands)));
  81. ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
  82. // Add an equation fact as long as the result id is not irrelevant (it could
  83. // be if we are inserting into a dead block).
  84. if (!transformation_context->GetFactManager()->IdIsIrrelevant(
  85. message_.fresh_id())) {
  86. transformation_context->GetFactManager()->AddFactIdEquation(
  87. message_.fresh_id(), static_cast<SpvOp>(message_.opcode()), rhs_id);
  88. }
  89. }
  90. protobufs::Transformation TransformationEquationInstruction::ToMessage() const {
  91. protobufs::Transformation result;
  92. *result.mutable_equation_instruction() = message_;
  93. return result;
  94. }
  95. uint32_t TransformationEquationInstruction::MaybeGetResultTypeId(
  96. opt::IRContext* ir_context) const {
  97. auto opcode = static_cast<SpvOp>(message_.opcode());
  98. switch (opcode) {
  99. case SpvOpConvertUToF:
  100. case SpvOpConvertSToF: {
  101. if (message_.in_operand_id_size() != 1) {
  102. return 0;
  103. }
  104. const auto* type = ir_context->get_type_mgr()->GetType(
  105. fuzzerutil::GetTypeId(ir_context, message_.in_operand_id(0)));
  106. if (!type) {
  107. return 0;
  108. }
  109. if (const auto* vector = type->AsVector()) {
  110. if (!vector->element_type()->AsInteger()) {
  111. return 0;
  112. }
  113. if (auto element_type_id = fuzzerutil::MaybeGetFloatType(
  114. ir_context, vector->element_type()->AsInteger()->width())) {
  115. return fuzzerutil::MaybeGetVectorType(ir_context, element_type_id,
  116. vector->element_count());
  117. }
  118. return 0;
  119. } else {
  120. if (!type->AsInteger()) {
  121. return 0;
  122. }
  123. return fuzzerutil::MaybeGetFloatType(ir_context,
  124. type->AsInteger()->width());
  125. }
  126. }
  127. case SpvOpBitcast: {
  128. if (message_.in_operand_id_size() != 1) {
  129. return 0;
  130. }
  131. const auto* operand_inst =
  132. ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
  133. if (!operand_inst) {
  134. return 0;
  135. }
  136. const auto* operand_type =
  137. ir_context->get_type_mgr()->GetType(operand_inst->type_id());
  138. if (!operand_type) {
  139. return 0;
  140. }
  141. // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3539):
  142. // The only constraint on the types of OpBitcast's parameters is that
  143. // they must have the same number of bits. Consider improving the code
  144. // below to support this in full.
  145. if (const auto* vector = operand_type->AsVector()) {
  146. uint32_t component_type_id;
  147. if (const auto* int_type = vector->element_type()->AsInteger()) {
  148. component_type_id =
  149. fuzzerutil::MaybeGetFloatType(ir_context, int_type->width());
  150. } else if (const auto* float_type = vector->element_type()->AsFloat()) {
  151. component_type_id = fuzzerutil::MaybeGetIntegerType(
  152. ir_context, float_type->width(), true);
  153. if (component_type_id == 0 ||
  154. fuzzerutil::MaybeGetVectorType(ir_context, component_type_id,
  155. vector->element_count()) == 0) {
  156. component_type_id = fuzzerutil::MaybeGetIntegerType(
  157. ir_context, float_type->width(), false);
  158. }
  159. } else {
  160. assert(false && "Only vectors of numerical components are supported");
  161. return 0;
  162. }
  163. if (component_type_id == 0) {
  164. return 0;
  165. }
  166. return fuzzerutil::MaybeGetVectorType(ir_context, component_type_id,
  167. vector->element_count());
  168. } else if (const auto* int_type = operand_type->AsInteger()) {
  169. return fuzzerutil::MaybeGetFloatType(ir_context, int_type->width());
  170. } else if (const auto* float_type = operand_type->AsFloat()) {
  171. if (auto existing_id = fuzzerutil::MaybeGetIntegerType(
  172. ir_context, float_type->width(), true)) {
  173. return existing_id;
  174. }
  175. return fuzzerutil::MaybeGetIntegerType(ir_context, float_type->width(),
  176. false);
  177. } else {
  178. assert(false &&
  179. "Operand is not a scalar or a vector of numerical type");
  180. return 0;
  181. }
  182. }
  183. case SpvOpIAdd:
  184. case SpvOpISub: {
  185. if (message_.in_operand_id_size() != 2) {
  186. return 0;
  187. }
  188. uint32_t first_operand_width = 0;
  189. uint32_t first_operand_type_id = 0;
  190. for (uint32_t index = 0; index < 2; index++) {
  191. auto operand_inst = ir_context->get_def_use_mgr()->GetDef(
  192. message_.in_operand_id(index));
  193. if (!operand_inst || !operand_inst->type_id()) {
  194. return 0;
  195. }
  196. auto operand_type =
  197. ir_context->get_type_mgr()->GetType(operand_inst->type_id());
  198. if (!(operand_type->AsInteger() ||
  199. (operand_type->AsVector() &&
  200. operand_type->AsVector()->element_type()->AsInteger()))) {
  201. return 0;
  202. }
  203. uint32_t operand_width =
  204. operand_type->AsInteger()
  205. ? 1
  206. : operand_type->AsVector()->element_count();
  207. if (index == 0) {
  208. first_operand_width = operand_width;
  209. first_operand_type_id = operand_inst->type_id();
  210. } else {
  211. assert(first_operand_width != 0 &&
  212. "The first operand should have been processed.");
  213. if (operand_width != first_operand_width) {
  214. return 0;
  215. }
  216. }
  217. }
  218. assert(first_operand_type_id != 0 &&
  219. "A type must have been found for the first operand.");
  220. return first_operand_type_id;
  221. }
  222. case SpvOpLogicalNot: {
  223. if (message_.in_operand_id().size() != 1) {
  224. return 0;
  225. }
  226. auto operand_inst =
  227. ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
  228. if (!operand_inst || !operand_inst->type_id()) {
  229. return 0;
  230. }
  231. auto operand_type =
  232. ir_context->get_type_mgr()->GetType(operand_inst->type_id());
  233. if (!(operand_type->AsBool() ||
  234. (operand_type->AsVector() &&
  235. operand_type->AsVector()->element_type()->AsBool()))) {
  236. return 0;
  237. }
  238. return operand_inst->type_id();
  239. }
  240. case SpvOpSNegate: {
  241. if (message_.in_operand_id().size() != 1) {
  242. return 0;
  243. }
  244. auto operand_inst =
  245. ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
  246. if (!operand_inst || !operand_inst->type_id()) {
  247. return 0;
  248. }
  249. auto operand_type =
  250. ir_context->get_type_mgr()->GetType(operand_inst->type_id());
  251. if (!(operand_type->AsInteger() ||
  252. (operand_type->AsVector() &&
  253. operand_type->AsVector()->element_type()->AsInteger()))) {
  254. return 0;
  255. }
  256. return operand_inst->type_id();
  257. }
  258. default:
  259. assert(false && "Inappropriate opcode for equation instruction.");
  260. return 0;
  261. }
  262. }
  263. std::unordered_set<uint32_t> TransformationEquationInstruction::GetFreshIds()
  264. const {
  265. return {message_.fresh_id()};
  266. }
  267. } // namespace fuzz
  268. } // namespace spvtools