transformation_equation_instruction.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. // Copyright (c) 2020 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/fuzz/transformation_equation_instruction.h"
  15. #include "source/fuzz/fuzzer_util.h"
  16. #include "source/fuzz/instruction_descriptor.h"
  17. namespace spvtools {
  18. namespace fuzz {
  19. TransformationEquationInstruction::TransformationEquationInstruction(
  20. protobufs::TransformationEquationInstruction message)
  21. : message_(std::move(message)) {}
  22. TransformationEquationInstruction::TransformationEquationInstruction(
  23. uint32_t fresh_id, spv::Op opcode,
  24. const std::vector<uint32_t>& in_operand_id,
  25. const protobufs::InstructionDescriptor& instruction_to_insert_before) {
  26. message_.set_fresh_id(fresh_id);
  27. message_.set_opcode(uint32_t(opcode));
  28. for (auto id : in_operand_id) {
  29. message_.add_in_operand_id(id);
  30. }
  31. *message_.mutable_instruction_to_insert_before() =
  32. instruction_to_insert_before;
  33. }
  34. bool TransformationEquationInstruction::IsApplicable(
  35. opt::IRContext* ir_context,
  36. const TransformationContext& transformation_context) const {
  37. // The result id must be fresh.
  38. if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) {
  39. return false;
  40. }
  41. // The instruction to insert before must exist.
  42. auto insert_before =
  43. FindInstruction(message_.instruction_to_insert_before(), ir_context);
  44. if (!insert_before) {
  45. return false;
  46. }
  47. // The input ids must all exist, not be OpUndef, not be irrelevant, and be
  48. // available before this instruction.
  49. for (auto id : message_.in_operand_id()) {
  50. auto inst = ir_context->get_def_use_mgr()->GetDef(id);
  51. if (!inst) {
  52. return false;
  53. }
  54. if (inst->opcode() == spv::Op::OpUndef) {
  55. return false;
  56. }
  57. if (transformation_context.GetFactManager()->IdIsIrrelevant(id)) {
  58. return false;
  59. }
  60. if (!fuzzerutil::IdIsAvailableBeforeInstruction(ir_context, insert_before,
  61. id)) {
  62. return false;
  63. }
  64. }
  65. return MaybeGetResultTypeId(ir_context) != 0;
  66. }
  67. void TransformationEquationInstruction::Apply(
  68. opt::IRContext* ir_context,
  69. TransformationContext* transformation_context) const {
  70. fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id());
  71. opt::Instruction::OperandList in_operands;
  72. std::vector<uint32_t> rhs_id;
  73. for (auto id : message_.in_operand_id()) {
  74. in_operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
  75. rhs_id.push_back(id);
  76. }
  77. auto insert_before =
  78. FindInstruction(message_.instruction_to_insert_before(), ir_context);
  79. opt::Instruction* new_instruction =
  80. insert_before->InsertBefore(MakeUnique<opt::Instruction>(
  81. ir_context, static_cast<spv::Op>(message_.opcode()),
  82. MaybeGetResultTypeId(ir_context), message_.fresh_id(),
  83. std::move(in_operands)));
  84. ir_context->get_def_use_mgr()->AnalyzeInstDefUse(new_instruction);
  85. ir_context->set_instr_block(new_instruction,
  86. ir_context->get_instr_block(insert_before));
  87. // Add an equation fact as long as the result id is not irrelevant (it could
  88. // be if we are inserting into a dead block).
  89. if (!transformation_context->GetFactManager()->IdIsIrrelevant(
  90. message_.fresh_id())) {
  91. transformation_context->GetFactManager()->AddFactIdEquation(
  92. message_.fresh_id(), static_cast<spv::Op>(message_.opcode()), rhs_id);
  93. }
  94. }
  95. protobufs::Transformation TransformationEquationInstruction::ToMessage() const {
  96. protobufs::Transformation result;
  97. *result.mutable_equation_instruction() = message_;
  98. return result;
  99. }
  100. uint32_t TransformationEquationInstruction::MaybeGetResultTypeId(
  101. opt::IRContext* ir_context) const {
  102. auto opcode = static_cast<spv::Op>(message_.opcode());
  103. switch (opcode) {
  104. case spv::Op::OpConvertUToF:
  105. case spv::Op::OpConvertSToF: {
  106. if (message_.in_operand_id_size() != 1) {
  107. return 0;
  108. }
  109. const auto* type = ir_context->get_type_mgr()->GetType(
  110. fuzzerutil::GetTypeId(ir_context, message_.in_operand_id(0)));
  111. if (!type) {
  112. return 0;
  113. }
  114. if (const auto* vector = type->AsVector()) {
  115. if (!vector->element_type()->AsInteger()) {
  116. return 0;
  117. }
  118. if (auto element_type_id = fuzzerutil::MaybeGetFloatType(
  119. ir_context, vector->element_type()->AsInteger()->width())) {
  120. return fuzzerutil::MaybeGetVectorType(ir_context, element_type_id,
  121. vector->element_count());
  122. }
  123. return 0;
  124. } else {
  125. if (!type->AsInteger()) {
  126. return 0;
  127. }
  128. return fuzzerutil::MaybeGetFloatType(ir_context,
  129. type->AsInteger()->width());
  130. }
  131. }
  132. case spv::Op::OpBitcast: {
  133. if (message_.in_operand_id_size() != 1) {
  134. return 0;
  135. }
  136. const auto* operand_inst =
  137. ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
  138. if (!operand_inst) {
  139. return 0;
  140. }
  141. const auto* operand_type =
  142. ir_context->get_type_mgr()->GetType(operand_inst->type_id());
  143. if (!operand_type) {
  144. return 0;
  145. }
  146. // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3539):
  147. // The only constraint on the types of OpBitcast's parameters is that
  148. // they must have the same number of bits. Consider improving the code
  149. // below to support this in full.
  150. if (const auto* vector = operand_type->AsVector()) {
  151. uint32_t component_type_id;
  152. if (const auto* int_type = vector->element_type()->AsInteger()) {
  153. component_type_id =
  154. fuzzerutil::MaybeGetFloatType(ir_context, int_type->width());
  155. } else if (const auto* float_type = vector->element_type()->AsFloat()) {
  156. component_type_id = fuzzerutil::MaybeGetIntegerType(
  157. ir_context, float_type->width(), true);
  158. if (component_type_id == 0 ||
  159. fuzzerutil::MaybeGetVectorType(ir_context, component_type_id,
  160. vector->element_count()) == 0) {
  161. component_type_id = fuzzerutil::MaybeGetIntegerType(
  162. ir_context, float_type->width(), false);
  163. }
  164. } else {
  165. assert(false && "Only vectors of numerical components are supported");
  166. return 0;
  167. }
  168. if (component_type_id == 0) {
  169. return 0;
  170. }
  171. return fuzzerutil::MaybeGetVectorType(ir_context, component_type_id,
  172. vector->element_count());
  173. } else if (const auto* int_type = operand_type->AsInteger()) {
  174. return fuzzerutil::MaybeGetFloatType(ir_context, int_type->width());
  175. } else if (const auto* float_type = operand_type->AsFloat()) {
  176. if (auto existing_id = fuzzerutil::MaybeGetIntegerType(
  177. ir_context, float_type->width(), true)) {
  178. return existing_id;
  179. }
  180. return fuzzerutil::MaybeGetIntegerType(ir_context, float_type->width(),
  181. false);
  182. } else {
  183. assert(false &&
  184. "Operand is not a scalar or a vector of numerical type");
  185. return 0;
  186. }
  187. }
  188. case spv::Op::OpIAdd:
  189. case spv::Op::OpISub: {
  190. if (message_.in_operand_id_size() != 2) {
  191. return 0;
  192. }
  193. uint32_t first_operand_width = 0;
  194. uint32_t first_operand_type_id = 0;
  195. for (uint32_t index = 0; index < 2; index++) {
  196. auto operand_inst = ir_context->get_def_use_mgr()->GetDef(
  197. message_.in_operand_id(index));
  198. if (!operand_inst || !operand_inst->type_id()) {
  199. return 0;
  200. }
  201. auto operand_type =
  202. ir_context->get_type_mgr()->GetType(operand_inst->type_id());
  203. if (!(operand_type->AsInteger() ||
  204. (operand_type->AsVector() &&
  205. operand_type->AsVector()->element_type()->AsInteger()))) {
  206. return 0;
  207. }
  208. uint32_t operand_width =
  209. operand_type->AsInteger()
  210. ? 1
  211. : operand_type->AsVector()->element_count();
  212. if (index == 0) {
  213. first_operand_width = operand_width;
  214. first_operand_type_id = operand_inst->type_id();
  215. } else {
  216. assert(first_operand_width != 0 &&
  217. "The first operand should have been processed.");
  218. if (operand_width != first_operand_width) {
  219. return 0;
  220. }
  221. }
  222. }
  223. assert(first_operand_type_id != 0 &&
  224. "A type must have been found for the first operand.");
  225. return first_operand_type_id;
  226. }
  227. case spv::Op::OpLogicalNot: {
  228. if (message_.in_operand_id().size() != 1) {
  229. return 0;
  230. }
  231. auto operand_inst =
  232. ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
  233. if (!operand_inst || !operand_inst->type_id()) {
  234. return 0;
  235. }
  236. auto operand_type =
  237. ir_context->get_type_mgr()->GetType(operand_inst->type_id());
  238. if (!(operand_type->AsBool() ||
  239. (operand_type->AsVector() &&
  240. operand_type->AsVector()->element_type()->AsBool()))) {
  241. return 0;
  242. }
  243. return operand_inst->type_id();
  244. }
  245. case spv::Op::OpSNegate: {
  246. if (message_.in_operand_id().size() != 1) {
  247. return 0;
  248. }
  249. auto operand_inst =
  250. ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
  251. if (!operand_inst || !operand_inst->type_id()) {
  252. return 0;
  253. }
  254. auto operand_type =
  255. ir_context->get_type_mgr()->GetType(operand_inst->type_id());
  256. if (!(operand_type->AsInteger() ||
  257. (operand_type->AsVector() &&
  258. operand_type->AsVector()->element_type()->AsInteger()))) {
  259. return 0;
  260. }
  261. return operand_inst->type_id();
  262. }
  263. default:
  264. assert(false && "Inappropriate opcode for equation instruction.");
  265. return 0;
  266. }
  267. }
  268. std::unordered_set<uint32_t> TransformationEquationInstruction::GetFreshIds()
  269. const {
  270. return {message_.fresh_id()};
  271. }
  272. } // namespace fuzz
  273. } // namespace spvtools