transformation_replace_linear_algebra_instruction.cpp 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656
  1. // Copyright (c) 2020 André Perez Maselco
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
#include "source/fuzz/transformation_replace_linear_algebra_instruction.h"

#include <unordered_set>

#include "source/fuzz/fuzzer_util.h"
#include "source/fuzz/instruction_descriptor.h"
  17. namespace spvtools {
  18. namespace fuzz {
  19. TransformationReplaceLinearAlgebraInstruction::
  20. TransformationReplaceLinearAlgebraInstruction(
  21. const spvtools::fuzz::protobufs::
  22. TransformationReplaceLinearAlgebraInstruction& message)
  23. : message_(message) {}
  24. TransformationReplaceLinearAlgebraInstruction::
  25. TransformationReplaceLinearAlgebraInstruction(
  26. const std::vector<uint32_t>& fresh_ids,
  27. const protobufs::InstructionDescriptor& instruction_descriptor) {
  28. for (auto fresh_id : fresh_ids) {
  29. message_.add_fresh_ids(fresh_id);
  30. }
  31. *message_.mutable_instruction_descriptor() = instruction_descriptor;
  32. }
  33. bool TransformationReplaceLinearAlgebraInstruction::IsApplicable(
  34. opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
  35. auto instruction =
  36. FindInstruction(message_.instruction_descriptor(), ir_context);
  37. // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354):
  38. // Right now we only support certain operations. When this issue is addressed
  39. // the following conditional can use the function |spvOpcodeIsLinearAlgebra|.
  40. // It must be a supported linear algebra instruction.
  41. if (instruction->opcode() != SpvOpVectorTimesScalar &&
  42. instruction->opcode() != SpvOpMatrixTimesScalar &&
  43. instruction->opcode() != SpvOpVectorTimesMatrix &&
  44. instruction->opcode() != SpvOpMatrixTimesVector &&
  45. instruction->opcode() != SpvOpDot) {
  46. return false;
  47. }
  48. // |message_.fresh_ids.size| must be the exact number of fresh ids needed to
  49. // apply the transformation.
  50. if (static_cast<uint32_t>(message_.fresh_ids().size()) !=
  51. GetRequiredFreshIdCount(ir_context, instruction)) {
  52. return false;
  53. }
  54. // All ids in |message_.fresh_ids| must be fresh.
  55. for (uint32_t fresh_id : message_.fresh_ids()) {
  56. if (!fuzzerutil::IsFreshId(ir_context, fresh_id)) {
  57. return false;
  58. }
  59. }
  60. return true;
  61. }
  62. void TransformationReplaceLinearAlgebraInstruction::Apply(
  63. opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
  64. auto linear_algebra_instruction =
  65. FindInstruction(message_.instruction_descriptor(), ir_context);
  66. switch (linear_algebra_instruction->opcode()) {
  67. case SpvOpVectorTimesScalar:
  68. ReplaceOpVectorTimesScalar(ir_context, linear_algebra_instruction);
  69. break;
  70. case SpvOpMatrixTimesScalar:
  71. ReplaceOpMatrixTimesScalar(ir_context, linear_algebra_instruction);
  72. break;
  73. case SpvOpVectorTimesMatrix:
  74. ReplaceOpVectorTimesMatrix(ir_context, linear_algebra_instruction);
  75. break;
  76. case SpvOpMatrixTimesVector:
  77. ReplaceOpMatrixTimesVector(ir_context, linear_algebra_instruction);
  78. break;
  79. case SpvOpDot:
  80. ReplaceOpDot(ir_context, linear_algebra_instruction);
  81. break;
  82. default:
  83. assert(false && "Should be unreachable.");
  84. break;
  85. }
  86. ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
  87. }
  88. protobufs::Transformation
  89. TransformationReplaceLinearAlgebraInstruction::ToMessage() const {
  90. protobufs::Transformation result;
  91. *result.mutable_replace_linear_algebra_instruction() = message_;
  92. return result;
  93. }
  94. uint32_t TransformationReplaceLinearAlgebraInstruction::GetRequiredFreshIdCount(
  95. opt::IRContext* ir_context, opt::Instruction* instruction) {
  96. // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354):
  97. // Right now we only support certain operations.
  98. switch (instruction->opcode()) {
  99. case SpvOpVectorTimesScalar:
  100. // For each vector component, 1 OpCompositeExtract and 1 OpFMul will be
  101. // inserted.
  102. return 2 *
  103. ir_context->get_type_mgr()
  104. ->GetType(ir_context->get_def_use_mgr()
  105. ->GetDef(instruction->GetSingleWordInOperand(0))
  106. ->type_id())
  107. ->AsVector()
  108. ->element_count();
  109. case SpvOpMatrixTimesScalar: {
  110. // For each matrix column, |1 + column.size| OpCompositeExtract,
  111. // |column.size| OpFMul and 1 OpCompositeConstruct instructions will be
  112. // inserted.
  113. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
  114. instruction->GetSingleWordInOperand(0));
  115. auto matrix_type =
  116. ir_context->get_type_mgr()->GetType(matrix_instruction->type_id());
  117. return 2 * matrix_type->AsMatrix()->element_count() *
  118. (1 + matrix_type->AsMatrix()
  119. ->element_type()
  120. ->AsVector()
  121. ->element_count());
  122. }
  123. case SpvOpVectorTimesMatrix: {
  124. // For each vector component, 1 OpCompositeExtract instruction will be
  125. // inserted. For each matrix column, |1 + vector_component_count|
  126. // OpCompositeExtract, |vector_component_count| OpFMul and
  127. // |vector_component_count - 1| OpFAdd instructions will be inserted.
  128. auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
  129. instruction->GetSingleWordInOperand(0));
  130. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
  131. instruction->GetSingleWordInOperand(1));
  132. uint32_t vector_component_count =
  133. ir_context->get_type_mgr()
  134. ->GetType(vector_instruction->type_id())
  135. ->AsVector()
  136. ->element_count();
  137. uint32_t matrix_column_count =
  138. ir_context->get_type_mgr()
  139. ->GetType(matrix_instruction->type_id())
  140. ->AsMatrix()
  141. ->element_count();
  142. return vector_component_count * (3 * matrix_column_count + 1);
  143. }
  144. case SpvOpMatrixTimesVector: {
  145. // For each matrix column, |1 + matrix_row_count| OpCompositeExtract
  146. // will be inserted. For each matrix row, |matrix_column_count| OpFMul and
  147. // |matrix_column_count - 1| OpFAdd instructions will be inserted. For
  148. // each vector component, 1 OpCompositeExtract instruction will be
  149. // inserted.
  150. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
  151. instruction->GetSingleWordInOperand(0));
  152. uint32_t matrix_column_count =
  153. ir_context->get_type_mgr()
  154. ->GetType(matrix_instruction->type_id())
  155. ->AsMatrix()
  156. ->element_count();
  157. uint32_t matrix_row_count = ir_context->get_type_mgr()
  158. ->GetType(matrix_instruction->type_id())
  159. ->AsMatrix()
  160. ->element_type()
  161. ->AsVector()
  162. ->element_count();
  163. return 3 * matrix_column_count * matrix_row_count +
  164. 2 * matrix_column_count - matrix_row_count;
  165. }
  166. case SpvOpDot:
  167. // For each pair of vector components, 2 OpCompositeExtract and 1 OpFMul
  168. // will be inserted. The first two OpFMul instructions will result the
  169. // first OpFAdd instruction to be inserted. For each remaining OpFMul, 1
  170. // OpFAdd will be inserted. The last OpFAdd instruction is got by changing
  171. // the OpDot instruction.
  172. return 4 * ir_context->get_type_mgr()
  173. ->GetType(
  174. ir_context->get_def_use_mgr()
  175. ->GetDef(instruction->GetSingleWordInOperand(0))
  176. ->type_id())
  177. ->AsVector()
  178. ->element_count() -
  179. 2;
  180. default:
  181. assert(false && "Unsupported linear algebra instruction.");
  182. return 0;
  183. }
  184. }
  185. void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesScalar(
  186. opt::IRContext* ir_context,
  187. opt::Instruction* linear_algebra_instruction) const {
  188. // Gets OpVectorTimesScalar in operands.
  189. auto vector = ir_context->get_def_use_mgr()->GetDef(
  190. linear_algebra_instruction->GetSingleWordInOperand(0));
  191. auto scalar = ir_context->get_def_use_mgr()->GetDef(
  192. linear_algebra_instruction->GetSingleWordInOperand(1));
  193. uint32_t vector_component_count = ir_context->get_type_mgr()
  194. ->GetType(vector->type_id())
  195. ->AsVector()
  196. ->element_count();
  197. std::vector<uint32_t> float_multiplication_ids(vector_component_count);
  198. uint32_t fresh_id_index = 0;
  199. for (uint32_t i = 0; i < vector_component_count; i++) {
  200. // Extracts |vector| component.
  201. uint32_t vector_extract_id = message_.fresh_ids(fresh_id_index++);
  202. fuzzerutil::UpdateModuleIdBound(ir_context, vector_extract_id);
  203. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  204. ir_context, SpvOpCompositeExtract, scalar->type_id(), vector_extract_id,
  205. opt::Instruction::OperandList(
  206. {{SPV_OPERAND_TYPE_ID, {vector->result_id()}},
  207. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  208. // Multiplies the |vector| component with the |scalar|.
  209. uint32_t float_multiplication_id = message_.fresh_ids(fresh_id_index++);
  210. float_multiplication_ids[i] = float_multiplication_id;
  211. fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_id);
  212. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  213. ir_context, SpvOpFMul, scalar->type_id(), float_multiplication_id,
  214. opt::Instruction::OperandList(
  215. {{SPV_OPERAND_TYPE_ID, {vector_extract_id}},
  216. {SPV_OPERAND_TYPE_ID, {scalar->result_id()}}})));
  217. }
  218. // The OpVectorTimesScalar instruction is changed to an OpCompositeConstruct
  219. // instruction.
  220. linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
  221. linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]});
  222. linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]});
  223. for (uint32_t i = 2; i < float_multiplication_ids.size(); i++) {
  224. linear_algebra_instruction->AddOperand(
  225. {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}});
  226. }
  227. }
  228. void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesScalar(
  229. opt::IRContext* ir_context,
  230. opt::Instruction* linear_algebra_instruction) const {
  231. // Gets OpMatrixTimesScalar in operands.
  232. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
  233. linear_algebra_instruction->GetSingleWordInOperand(0));
  234. auto scalar_instruction = ir_context->get_def_use_mgr()->GetDef(
  235. linear_algebra_instruction->GetSingleWordInOperand(1));
  236. // Gets matrix information.
  237. uint32_t matrix_column_count = ir_context->get_type_mgr()
  238. ->GetType(matrix_instruction->type_id())
  239. ->AsMatrix()
  240. ->element_count();
  241. auto matrix_column_type = ir_context->get_type_mgr()
  242. ->GetType(matrix_instruction->type_id())
  243. ->AsMatrix()
  244. ->element_type();
  245. uint32_t matrix_column_size = matrix_column_type->AsVector()->element_count();
  246. std::vector<uint32_t> composite_construct_ids(matrix_column_count);
  247. uint32_t fresh_id_index = 0;
  248. for (uint32_t i = 0; i < matrix_column_count; i++) {
  249. // Extracts |matrix| column.
  250. uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++);
  251. fuzzerutil::UpdateModuleIdBound(ir_context, matrix_extract_id);
  252. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  253. ir_context, SpvOpCompositeExtract,
  254. ir_context->get_type_mgr()->GetId(matrix_column_type),
  255. matrix_extract_id,
  256. opt::Instruction::OperandList(
  257. {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
  258. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  259. std::vector<uint32_t> float_multiplication_ids(matrix_column_size);
  260. for (uint32_t j = 0; j < matrix_column_size; j++) {
  261. // Extracts |column| component.
  262. uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
  263. fuzzerutil::UpdateModuleIdBound(ir_context, column_extract_id);
  264. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  265. ir_context, SpvOpCompositeExtract, scalar_instruction->type_id(),
  266. column_extract_id,
  267. opt::Instruction::OperandList(
  268. {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}},
  269. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
  270. // Multiplies the |column| component with the |scalar|.
  271. float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
  272. fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[j]);
  273. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  274. ir_context, SpvOpFMul, scalar_instruction->type_id(),
  275. float_multiplication_ids[j],
  276. opt::Instruction::OperandList(
  277. {{SPV_OPERAND_TYPE_ID, {column_extract_id}},
  278. {SPV_OPERAND_TYPE_ID, {scalar_instruction->result_id()}}})));
  279. }
  280. // Constructs a new column multiplied by |scalar|.
  281. opt::Instruction::OperandList composite_construct_in_operands;
  282. for (uint32_t& float_multiplication_id : float_multiplication_ids) {
  283. composite_construct_in_operands.push_back(
  284. {SPV_OPERAND_TYPE_ID, {float_multiplication_id}});
  285. }
  286. composite_construct_ids[i] = message_.fresh_ids(fresh_id_index++);
  287. fuzzerutil::UpdateModuleIdBound(ir_context, composite_construct_ids[i]);
  288. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  289. ir_context, SpvOpCompositeConstruct,
  290. ir_context->get_type_mgr()->GetId(matrix_column_type),
  291. composite_construct_ids[i], composite_construct_in_operands));
  292. }
  293. // The OpMatrixTimesScalar instruction is changed to an OpCompositeConstruct
  294. // instruction.
  295. linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
  296. linear_algebra_instruction->SetInOperand(0, {composite_construct_ids[0]});
  297. linear_algebra_instruction->SetInOperand(1, {composite_construct_ids[1]});
  298. for (uint32_t i = 2; i < composite_construct_ids.size(); i++) {
  299. linear_algebra_instruction->AddOperand(
  300. {SPV_OPERAND_TYPE_ID, {composite_construct_ids[i]}});
  301. }
  302. }
  303. void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesMatrix(
  304. opt::IRContext* ir_context,
  305. opt::Instruction* linear_algebra_instruction) const {
  306. // Gets vector information.
  307. auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
  308. linear_algebra_instruction->GetSingleWordInOperand(0));
  309. uint32_t vector_component_count = ir_context->get_type_mgr()
  310. ->GetType(vector_instruction->type_id())
  311. ->AsVector()
  312. ->element_count();
  313. auto vector_component_type = ir_context->get_type_mgr()
  314. ->GetType(vector_instruction->type_id())
  315. ->AsVector()
  316. ->element_type();
  317. // Extracts vector components.
  318. uint32_t fresh_id_index = 0;
  319. std::vector<uint32_t> vector_component_ids(vector_component_count);
  320. for (uint32_t i = 0; i < vector_component_count; i++) {
  321. vector_component_ids[i] = message_.fresh_ids(fresh_id_index++);
  322. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  323. ir_context, SpvOpCompositeExtract,
  324. ir_context->get_type_mgr()->GetId(vector_component_type),
  325. vector_component_ids[i],
  326. opt::Instruction::OperandList(
  327. {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}},
  328. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  329. }
  330. // Gets matrix information.
  331. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
  332. linear_algebra_instruction->GetSingleWordInOperand(1));
  333. uint32_t matrix_column_count = ir_context->get_type_mgr()
  334. ->GetType(matrix_instruction->type_id())
  335. ->AsMatrix()
  336. ->element_count();
  337. auto matrix_column_type = ir_context->get_type_mgr()
  338. ->GetType(matrix_instruction->type_id())
  339. ->AsMatrix()
  340. ->element_type();
  341. std::vector<uint32_t> result_component_ids(matrix_column_count);
  342. for (uint32_t i = 0; i < matrix_column_count; i++) {
  343. // Extracts matrix column.
  344. uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++);
  345. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  346. ir_context, SpvOpCompositeExtract,
  347. ir_context->get_type_mgr()->GetId(matrix_column_type),
  348. matrix_extract_id,
  349. opt::Instruction::OperandList(
  350. {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
  351. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  352. std::vector<uint32_t> float_multiplication_ids(vector_component_count);
  353. for (uint32_t j = 0; j < vector_component_count; j++) {
  354. // Extracts column component.
  355. uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
  356. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  357. ir_context, SpvOpCompositeExtract,
  358. ir_context->get_type_mgr()->GetId(vector_component_type),
  359. column_extract_id,
  360. opt::Instruction::OperandList(
  361. {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}},
  362. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
  363. // Multiplies corresponding vector and column components.
  364. float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
  365. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  366. ir_context, SpvOpFMul,
  367. ir_context->get_type_mgr()->GetId(vector_component_type),
  368. float_multiplication_ids[j],
  369. opt::Instruction::OperandList(
  370. {{SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}},
  371. {SPV_OPERAND_TYPE_ID, {column_extract_id}}})));
  372. }
  373. // Adds the multiplication results.
  374. std::vector<uint32_t> float_add_ids;
  375. uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
  376. float_add_ids.push_back(float_add_id);
  377. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  378. ir_context, SpvOpFAdd,
  379. ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id,
  380. opt::Instruction::OperandList(
  381. {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
  382. {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
  383. for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) {
  384. float_add_id = message_.fresh_ids(fresh_id_index++);
  385. float_add_ids.push_back(float_add_id);
  386. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  387. ir_context, SpvOpFAdd,
  388. ir_context->get_type_mgr()->GetId(vector_component_type),
  389. float_add_id,
  390. opt::Instruction::OperandList(
  391. {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}},
  392. {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}})));
  393. }
  394. result_component_ids[i] = float_add_ids.back();
  395. }
  396. // The OpVectorTimesMatrix instruction is changed to an OpCompositeConstruct
  397. // instruction.
  398. linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
  399. linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]});
  400. linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]});
  401. for (uint32_t i = 2; i < result_component_ids.size(); i++) {
  402. linear_algebra_instruction->AddOperand(
  403. {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}});
  404. }
  405. fuzzerutil::UpdateModuleIdBound(
  406. ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
  407. }
  408. void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesVector(
  409. opt::IRContext* ir_context,
  410. opt::Instruction* linear_algebra_instruction) const {
  411. // Gets matrix information.
  412. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
  413. linear_algebra_instruction->GetSingleWordInOperand(0));
  414. uint32_t matrix_column_count = ir_context->get_type_mgr()
  415. ->GetType(matrix_instruction->type_id())
  416. ->AsMatrix()
  417. ->element_count();
  418. auto matrix_column_type = ir_context->get_type_mgr()
  419. ->GetType(matrix_instruction->type_id())
  420. ->AsMatrix()
  421. ->element_type();
  422. uint32_t matrix_row_count = matrix_column_type->AsVector()->element_count();
  423. // Extracts matrix columns.
  424. uint32_t fresh_id_index = 0;
  425. std::vector<uint32_t> matrix_column_ids(matrix_column_count);
  426. for (uint32_t i = 0; i < matrix_column_count; i++) {
  427. matrix_column_ids[i] = message_.fresh_ids(fresh_id_index++);
  428. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  429. ir_context, SpvOpCompositeExtract,
  430. ir_context->get_type_mgr()->GetId(matrix_column_type),
  431. matrix_column_ids[i],
  432. opt::Instruction::OperandList(
  433. {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
  434. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  435. }
  436. // Gets vector information.
  437. auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
  438. linear_algebra_instruction->GetSingleWordInOperand(1));
  439. auto vector_component_type = ir_context->get_type_mgr()
  440. ->GetType(vector_instruction->type_id())
  441. ->AsVector()
  442. ->element_type();
  443. // Extracts vector components.
  444. std::vector<uint32_t> vector_component_ids(matrix_column_count);
  445. for (uint32_t i = 0; i < matrix_column_count; i++) {
  446. vector_component_ids[i] = message_.fresh_ids(fresh_id_index++);
  447. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  448. ir_context, SpvOpCompositeExtract,
  449. ir_context->get_type_mgr()->GetId(vector_component_type),
  450. vector_component_ids[i],
  451. opt::Instruction::OperandList(
  452. {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}},
  453. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  454. }
  455. std::vector<uint32_t> result_component_ids(matrix_row_count);
  456. for (uint32_t i = 0; i < matrix_row_count; i++) {
  457. std::vector<uint32_t> float_multiplication_ids(matrix_column_count);
  458. for (uint32_t j = 0; j < matrix_column_count; j++) {
  459. // Extracts column component.
  460. uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
  461. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  462. ir_context, SpvOpCompositeExtract,
  463. ir_context->get_type_mgr()->GetId(vector_component_type),
  464. column_extract_id,
  465. opt::Instruction::OperandList(
  466. {{SPV_OPERAND_TYPE_ID, {matrix_column_ids[j]}},
  467. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  468. // Multiplies corresponding vector and column components.
  469. float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
  470. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  471. ir_context, SpvOpFMul,
  472. ir_context->get_type_mgr()->GetId(vector_component_type),
  473. float_multiplication_ids[j],
  474. opt::Instruction::OperandList(
  475. {{SPV_OPERAND_TYPE_ID, {column_extract_id}},
  476. {SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}}})));
  477. }
  478. // Adds the multiplication results.
  479. std::vector<uint32_t> float_add_ids;
  480. uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
  481. float_add_ids.push_back(float_add_id);
  482. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  483. ir_context, SpvOpFAdd,
  484. ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id,
  485. opt::Instruction::OperandList(
  486. {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
  487. {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
  488. for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) {
  489. float_add_id = message_.fresh_ids(fresh_id_index++);
  490. float_add_ids.push_back(float_add_id);
  491. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  492. ir_context, SpvOpFAdd,
  493. ir_context->get_type_mgr()->GetId(vector_component_type),
  494. float_add_id,
  495. opt::Instruction::OperandList(
  496. {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}},
  497. {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}})));
  498. }
  499. result_component_ids[i] = float_add_ids.back();
  500. }
  501. // The OpMatrixTimesVector instruction is changed to an OpCompositeConstruct
  502. // instruction.
  503. linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
  504. linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]});
  505. linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]});
  506. for (uint32_t i = 2; i < result_component_ids.size(); i++) {
  507. linear_algebra_instruction->AddOperand(
  508. {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}});
  509. }
  510. fuzzerutil::UpdateModuleIdBound(
  511. ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
  512. }
  513. void TransformationReplaceLinearAlgebraInstruction::ReplaceOpDot(
  514. opt::IRContext* ir_context,
  515. opt::Instruction* linear_algebra_instruction) const {
  516. // Gets OpDot in operands.
  517. auto vector_1 = ir_context->get_def_use_mgr()->GetDef(
  518. linear_algebra_instruction->GetSingleWordInOperand(0));
  519. auto vector_2 = ir_context->get_def_use_mgr()->GetDef(
  520. linear_algebra_instruction->GetSingleWordInOperand(1));
  521. uint32_t vectors_component_count = ir_context->get_type_mgr()
  522. ->GetType(vector_1->type_id())
  523. ->AsVector()
  524. ->element_count();
  525. std::vector<uint32_t> float_multiplication_ids(vectors_component_count);
  526. uint32_t fresh_id_index = 0;
  527. for (uint32_t i = 0; i < vectors_component_count; i++) {
  528. // Extracts |vector_1| component.
  529. uint32_t vector_1_extract_id = message_.fresh_ids(fresh_id_index++);
  530. fuzzerutil::UpdateModuleIdBound(ir_context, vector_1_extract_id);
  531. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  532. ir_context, SpvOpCompositeExtract,
  533. linear_algebra_instruction->type_id(), vector_1_extract_id,
  534. opt::Instruction::OperandList(
  535. {{SPV_OPERAND_TYPE_ID, {vector_1->result_id()}},
  536. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  537. // Extracts |vector_2| component.
  538. uint32_t vector_2_extract_id = message_.fresh_ids(fresh_id_index++);
  539. fuzzerutil::UpdateModuleIdBound(ir_context, vector_2_extract_id);
  540. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  541. ir_context, SpvOpCompositeExtract,
  542. linear_algebra_instruction->type_id(), vector_2_extract_id,
  543. opt::Instruction::OperandList(
  544. {{SPV_OPERAND_TYPE_ID, {vector_2->result_id()}},
  545. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  546. // Multiplies the pair of components.
  547. float_multiplication_ids[i] = message_.fresh_ids(fresh_id_index++);
  548. fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[i]);
  549. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  550. ir_context, SpvOpFMul, linear_algebra_instruction->type_id(),
  551. float_multiplication_ids[i],
  552. opt::Instruction::OperandList(
  553. {{SPV_OPERAND_TYPE_ID, {vector_1_extract_id}},
  554. {SPV_OPERAND_TYPE_ID, {vector_2_extract_id}}})));
  555. }
  556. // If the vector has 2 components, then there will be 2 float multiplication
  557. // instructions.
  558. if (vectors_component_count == 2) {
  559. linear_algebra_instruction->SetOpcode(SpvOpFAdd);
  560. linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]});
  561. linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]});
  562. } else {
  563. // The first OpFAdd instruction has as operands the first two OpFMul
  564. // instructions.
  565. std::vector<uint32_t> float_add_ids;
  566. uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
  567. float_add_ids.push_back(float_add_id);
  568. fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id);
  569. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  570. ir_context, SpvOpFAdd, linear_algebra_instruction->type_id(),
  571. float_add_id,
  572. opt::Instruction::OperandList(
  573. {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
  574. {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
  575. // The remaining OpFAdd instructions has as operands an OpFMul and an OpFAdd
  576. // instruction.
  577. for (uint32_t i = 2; i < float_multiplication_ids.size() - 1; i++) {
  578. float_add_id = message_.fresh_ids(fresh_id_index++);
  579. fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id);
  580. float_add_ids.push_back(float_add_id);
  581. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  582. ir_context, SpvOpFAdd, linear_algebra_instruction->type_id(),
  583. float_add_id,
  584. opt::Instruction::OperandList(
  585. {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}},
  586. {SPV_OPERAND_TYPE_ID, {float_add_ids[i - 2]}}})));
  587. }
  588. // The last OpFAdd instruction is got by changing some of the OpDot
  589. // instruction attributes.
  590. linear_algebra_instruction->SetOpcode(SpvOpFAdd);
  591. linear_algebra_instruction->SetInOperand(
  592. 0, {float_multiplication_ids[float_multiplication_ids.size() - 1]});
  593. linear_algebra_instruction->SetInOperand(
  594. 1, {float_add_ids[float_add_ids.size() - 1]});
  595. }
  596. }
  597. } // namespace fuzz
  598. } // namespace spvtools