transformation_replace_linear_algebra_instruction.cpp 48 KB


  1. // Copyright (c) 2020 André Perez Maselco
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
#include "source/fuzz/transformation_replace_linear_algebra_instruction.h"

#include <unordered_set>

#include "source/fuzz/fuzzer_util.h"
#include "source/fuzz/instruction_descriptor.h"
  17. namespace spvtools {
  18. namespace fuzz {
  19. TransformationReplaceLinearAlgebraInstruction::
  20. TransformationReplaceLinearAlgebraInstruction(
  21. protobufs::TransformationReplaceLinearAlgebraInstruction message)
  22. : message_(std::move(message)) {}
  23. TransformationReplaceLinearAlgebraInstruction::
  24. TransformationReplaceLinearAlgebraInstruction(
  25. const std::vector<uint32_t>& fresh_ids,
  26. const protobufs::InstructionDescriptor& instruction_descriptor) {
  27. for (auto fresh_id : fresh_ids) {
  28. message_.add_fresh_ids(fresh_id);
  29. }
  30. *message_.mutable_instruction_descriptor() = instruction_descriptor;
  31. }
  32. bool TransformationReplaceLinearAlgebraInstruction::IsApplicable(
  33. opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
  34. auto instruction =
  35. FindInstruction(message_.instruction_descriptor(), ir_context);
  36. // It must be a linear algebra instruction.
  37. if (!spvOpcodeIsLinearAlgebra(instruction->opcode())) {
  38. return false;
  39. }
  40. // |message_.fresh_ids.size| must be the exact number of fresh ids needed to
  41. // apply the transformation.
  42. if (static_cast<uint32_t>(message_.fresh_ids().size()) !=
  43. GetRequiredFreshIdCount(ir_context, instruction)) {
  44. return false;
  45. }
  46. // All ids in |message_.fresh_ids| must be fresh.
  47. for (uint32_t fresh_id : message_.fresh_ids()) {
  48. if (!fuzzerutil::IsFreshId(ir_context, fresh_id)) {
  49. return false;
  50. }
  51. }
  52. return true;
  53. }
  54. void TransformationReplaceLinearAlgebraInstruction::Apply(
  55. opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
  56. auto linear_algebra_instruction =
  57. FindInstruction(message_.instruction_descriptor(), ir_context);
  58. switch (linear_algebra_instruction->opcode()) {
  59. case spv::Op::OpTranspose:
  60. ReplaceOpTranspose(ir_context, linear_algebra_instruction);
  61. break;
  62. case spv::Op::OpVectorTimesScalar:
  63. ReplaceOpVectorTimesScalar(ir_context, linear_algebra_instruction);
  64. break;
  65. case spv::Op::OpMatrixTimesScalar:
  66. ReplaceOpMatrixTimesScalar(ir_context, linear_algebra_instruction);
  67. break;
  68. case spv::Op::OpVectorTimesMatrix:
  69. ReplaceOpVectorTimesMatrix(ir_context, linear_algebra_instruction);
  70. break;
  71. case spv::Op::OpMatrixTimesVector:
  72. ReplaceOpMatrixTimesVector(ir_context, linear_algebra_instruction);
  73. break;
  74. case spv::Op::OpMatrixTimesMatrix:
  75. ReplaceOpMatrixTimesMatrix(ir_context, linear_algebra_instruction);
  76. break;
  77. case spv::Op::OpOuterProduct:
  78. ReplaceOpOuterProduct(ir_context, linear_algebra_instruction);
  79. break;
  80. case spv::Op::OpDot:
  81. ReplaceOpDot(ir_context, linear_algebra_instruction);
  82. break;
  83. default:
  84. assert(false && "Should be unreachable.");
  85. break;
  86. }
  87. ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
  88. }
  89. protobufs::Transformation
  90. TransformationReplaceLinearAlgebraInstruction::ToMessage() const {
  91. protobufs::Transformation result;
  92. *result.mutable_replace_linear_algebra_instruction() = message_;
  93. return result;
  94. }
  95. uint32_t TransformationReplaceLinearAlgebraInstruction::GetRequiredFreshIdCount(
  96. opt::IRContext* ir_context, opt::Instruction* instruction) {
  97. // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354):
  98. // Right now we only support certain operations.
  99. switch (instruction->opcode()) {
  100. case spv::Op::OpTranspose: {
  101. // For each matrix row, |2 * matrix_column_count| OpCompositeExtract and 1
  102. // OpCompositeConstruct will be inserted.
  103. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
  104. instruction->GetSingleWordInOperand(0));
  105. uint32_t matrix_column_count =
  106. ir_context->get_type_mgr()
  107. ->GetType(matrix_instruction->type_id())
  108. ->AsMatrix()
  109. ->element_count();
  110. uint32_t matrix_row_count = ir_context->get_type_mgr()
  111. ->GetType(matrix_instruction->type_id())
  112. ->AsMatrix()
  113. ->element_type()
  114. ->AsVector()
  115. ->element_count();
  116. return matrix_row_count * (2 * matrix_column_count + 1);
  117. }
  118. case spv::Op::OpVectorTimesScalar:
  119. // For each vector component, 1 OpCompositeExtract and 1 OpFMul will be
  120. // inserted.
  121. return 2 *
  122. ir_context->get_type_mgr()
  123. ->GetType(ir_context->get_def_use_mgr()
  124. ->GetDef(instruction->GetSingleWordInOperand(0))
  125. ->type_id())
  126. ->AsVector()
  127. ->element_count();
  128. case spv::Op::OpMatrixTimesScalar: {
  129. // For each matrix column, |1 + column.size| OpCompositeExtract,
  130. // |column.size| OpFMul and 1 OpCompositeConstruct instructions will be
  131. // inserted.
  132. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
  133. instruction->GetSingleWordInOperand(0));
  134. auto matrix_type =
  135. ir_context->get_type_mgr()->GetType(matrix_instruction->type_id());
  136. return 2 * matrix_type->AsMatrix()->element_count() *
  137. (1 + matrix_type->AsMatrix()
  138. ->element_type()
  139. ->AsVector()
  140. ->element_count());
  141. }
  142. case spv::Op::OpVectorTimesMatrix: {
  143. // For each vector component, 1 OpCompositeExtract instruction will be
  144. // inserted. For each matrix column, |1 + vector_component_count|
  145. // OpCompositeExtract, |vector_component_count| OpFMul and
  146. // |vector_component_count - 1| OpFAdd instructions will be inserted.
  147. auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
  148. instruction->GetSingleWordInOperand(0));
  149. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
  150. instruction->GetSingleWordInOperand(1));
  151. uint32_t vector_component_count =
  152. ir_context->get_type_mgr()
  153. ->GetType(vector_instruction->type_id())
  154. ->AsVector()
  155. ->element_count();
  156. uint32_t matrix_column_count =
  157. ir_context->get_type_mgr()
  158. ->GetType(matrix_instruction->type_id())
  159. ->AsMatrix()
  160. ->element_count();
  161. return vector_component_count * (3 * matrix_column_count + 1);
  162. }
  163. case spv::Op::OpMatrixTimesVector: {
  164. // For each matrix column, |1 + matrix_row_count| OpCompositeExtract
  165. // will be inserted. For each matrix row, |matrix_column_count| OpFMul and
  166. // |matrix_column_count - 1| OpFAdd instructions will be inserted. For
  167. // each vector component, 1 OpCompositeExtract instruction will be
  168. // inserted.
  169. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
  170. instruction->GetSingleWordInOperand(0));
  171. uint32_t matrix_column_count =
  172. ir_context->get_type_mgr()
  173. ->GetType(matrix_instruction->type_id())
  174. ->AsMatrix()
  175. ->element_count();
  176. uint32_t matrix_row_count = ir_context->get_type_mgr()
  177. ->GetType(matrix_instruction->type_id())
  178. ->AsMatrix()
  179. ->element_type()
  180. ->AsVector()
  181. ->element_count();
  182. return 3 * matrix_column_count * matrix_row_count +
  183. 2 * matrix_column_count - matrix_row_count;
  184. }
  185. case spv::Op::OpMatrixTimesMatrix: {
  186. // For each matrix 2 column, 1 OpCompositeExtract, 1 OpCompositeConstruct,
  187. // |3 * matrix_1_row_count * matrix_1_column_count| OpCompositeExtract,
  188. // |matrix_1_row_count * matrix_1_column_count| OpFMul,
  189. // |matrix_1_row_count * (matrix_1_column_count - 1)| OpFAdd instructions
  190. // will be inserted.
  191. auto matrix_1_instruction = ir_context->get_def_use_mgr()->GetDef(
  192. instruction->GetSingleWordInOperand(0));
  193. uint32_t matrix_1_column_count =
  194. ir_context->get_type_mgr()
  195. ->GetType(matrix_1_instruction->type_id())
  196. ->AsMatrix()
  197. ->element_count();
  198. uint32_t matrix_1_row_count =
  199. ir_context->get_type_mgr()
  200. ->GetType(matrix_1_instruction->type_id())
  201. ->AsMatrix()
  202. ->element_type()
  203. ->AsVector()
  204. ->element_count();
  205. auto matrix_2_instruction = ir_context->get_def_use_mgr()->GetDef(
  206. instruction->GetSingleWordInOperand(1));
  207. uint32_t matrix_2_column_count =
  208. ir_context->get_type_mgr()
  209. ->GetType(matrix_2_instruction->type_id())
  210. ->AsMatrix()
  211. ->element_count();
  212. return matrix_2_column_count *
  213. (2 + matrix_1_row_count * (5 * matrix_1_column_count - 1));
  214. }
  215. case spv::Op::OpOuterProduct: {
  216. // For each |vector_2| component, |vector_1_component_count + 1|
  217. // OpCompositeExtract, |vector_1_component_count| OpFMul and 1
  218. // OpCompositeConstruct instructions will be inserted.
  219. auto vector_1_instruction = ir_context->get_def_use_mgr()->GetDef(
  220. instruction->GetSingleWordInOperand(0));
  221. auto vector_2_instruction = ir_context->get_def_use_mgr()->GetDef(
  222. instruction->GetSingleWordInOperand(1));
  223. uint32_t vector_1_component_count =
  224. ir_context->get_type_mgr()
  225. ->GetType(vector_1_instruction->type_id())
  226. ->AsVector()
  227. ->element_count();
  228. uint32_t vector_2_component_count =
  229. ir_context->get_type_mgr()
  230. ->GetType(vector_2_instruction->type_id())
  231. ->AsVector()
  232. ->element_count();
  233. return 2 * vector_2_component_count * (vector_1_component_count + 1);
  234. }
  235. case spv::Op::OpDot:
  236. // For each pair of vector components, 2 OpCompositeExtract and 1 OpFMul
  237. // will be inserted. The first two OpFMul instructions will result the
  238. // first OpFAdd instruction to be inserted. For each remaining OpFMul, 1
  239. // OpFAdd will be inserted. The last OpFAdd instruction is got by changing
  240. // the OpDot instruction.
  241. return 4 * ir_context->get_type_mgr()
  242. ->GetType(
  243. ir_context->get_def_use_mgr()
  244. ->GetDef(instruction->GetSingleWordInOperand(0))
  245. ->type_id())
  246. ->AsVector()
  247. ->element_count() -
  248. 2;
  249. default:
  250. assert(false && "Unsupported linear algebra instruction.");
  251. return 0;
  252. }
  253. }
  254. void TransformationReplaceLinearAlgebraInstruction::ReplaceOpTranspose(
  255. opt::IRContext* ir_context,
  256. opt::Instruction* linear_algebra_instruction) const {
  257. // Gets OpTranspose instruction information.
  258. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
  259. linear_algebra_instruction->GetSingleWordInOperand(0));
  260. uint32_t matrix_column_count = ir_context->get_type_mgr()
  261. ->GetType(matrix_instruction->type_id())
  262. ->AsMatrix()
  263. ->element_count();
  264. auto matrix_column_type = ir_context->get_type_mgr()
  265. ->GetType(matrix_instruction->type_id())
  266. ->AsMatrix()
  267. ->element_type();
  268. auto matrix_column_component_type =
  269. matrix_column_type->AsVector()->element_type();
  270. uint32_t matrix_row_count = matrix_column_type->AsVector()->element_count();
  271. auto resulting_matrix_column_type =
  272. ir_context->get_type_mgr()
  273. ->GetType(linear_algebra_instruction->type_id())
  274. ->AsMatrix()
  275. ->element_type();
  276. uint32_t fresh_id_index = 0;
  277. std::vector<uint32_t> result_column_ids(matrix_row_count);
  278. for (uint32_t i = 0; i < matrix_row_count; i++) {
  279. std::vector<uint32_t> column_component_ids(matrix_column_count);
  280. for (uint32_t j = 0; j < matrix_column_count; j++) {
  281. // Extracts the matrix column.
  282. uint32_t matrix_column_id = message_.fresh_ids(fresh_id_index++);
  283. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  284. ir_context, spv::Op::OpCompositeExtract,
  285. ir_context->get_type_mgr()->GetId(matrix_column_type),
  286. matrix_column_id,
  287. opt::Instruction::OperandList(
  288. {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
  289. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
  290. // Extracts the matrix column component.
  291. column_component_ids[j] = message_.fresh_ids(fresh_id_index++);
  292. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  293. ir_context, spv::Op::OpCompositeExtract,
  294. ir_context->get_type_mgr()->GetId(matrix_column_component_type),
  295. column_component_ids[j],
  296. opt::Instruction::OperandList(
  297. {{SPV_OPERAND_TYPE_ID, {matrix_column_id}},
  298. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  299. }
  300. // Inserts the resulting matrix column.
  301. opt::Instruction::OperandList in_operands;
  302. for (auto& column_component_id : column_component_ids) {
  303. in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
  304. }
  305. result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
  306. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  307. ir_context, spv::Op::OpCompositeConstruct,
  308. ir_context->get_type_mgr()->GetId(resulting_matrix_column_type),
  309. result_column_ids[i], opt::Instruction::OperandList(in_operands)));
  310. }
  311. // The OpTranspose instruction is changed to an OpCompositeConstruct
  312. // instruction.
  313. linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
  314. linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
  315. for (uint32_t i = 1; i < result_column_ids.size(); i++) {
  316. linear_algebra_instruction->AddOperand(
  317. {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
  318. }
  319. fuzzerutil::UpdateModuleIdBound(
  320. ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
  321. }
  322. void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesScalar(
  323. opt::IRContext* ir_context,
  324. opt::Instruction* linear_algebra_instruction) const {
  325. // Gets OpVectorTimesScalar in operands.
  326. auto vector = ir_context->get_def_use_mgr()->GetDef(
  327. linear_algebra_instruction->GetSingleWordInOperand(0));
  328. auto scalar = ir_context->get_def_use_mgr()->GetDef(
  329. linear_algebra_instruction->GetSingleWordInOperand(1));
  330. uint32_t vector_component_count = ir_context->get_type_mgr()
  331. ->GetType(vector->type_id())
  332. ->AsVector()
  333. ->element_count();
  334. std::vector<uint32_t> float_multiplication_ids(vector_component_count);
  335. uint32_t fresh_id_index = 0;
  336. for (uint32_t i = 0; i < vector_component_count; i++) {
  337. // Extracts |vector| component.
  338. uint32_t vector_extract_id = message_.fresh_ids(fresh_id_index++);
  339. fuzzerutil::UpdateModuleIdBound(ir_context, vector_extract_id);
  340. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  341. ir_context, spv::Op::OpCompositeExtract, scalar->type_id(),
  342. vector_extract_id,
  343. opt::Instruction::OperandList(
  344. {{SPV_OPERAND_TYPE_ID, {vector->result_id()}},
  345. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  346. // Multiplies the |vector| component with the |scalar|.
  347. uint32_t float_multiplication_id = message_.fresh_ids(fresh_id_index++);
  348. float_multiplication_ids[i] = float_multiplication_id;
  349. fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_id);
  350. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  351. ir_context, spv::Op::OpFMul, scalar->type_id(), float_multiplication_id,
  352. opt::Instruction::OperandList(
  353. {{SPV_OPERAND_TYPE_ID, {vector_extract_id}},
  354. {SPV_OPERAND_TYPE_ID, {scalar->result_id()}}})));
  355. }
  356. // The OpVectorTimesScalar instruction is changed to an OpCompositeConstruct
  357. // instruction.
  358. linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
  359. linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]});
  360. linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]});
  361. for (uint32_t i = 2; i < float_multiplication_ids.size(); i++) {
  362. linear_algebra_instruction->AddOperand(
  363. {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}});
  364. }
  365. }
  366. void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesScalar(
  367. opt::IRContext* ir_context,
  368. opt::Instruction* linear_algebra_instruction) const {
  369. // Gets OpMatrixTimesScalar in operands.
  370. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
  371. linear_algebra_instruction->GetSingleWordInOperand(0));
  372. auto scalar_instruction = ir_context->get_def_use_mgr()->GetDef(
  373. linear_algebra_instruction->GetSingleWordInOperand(1));
  374. // Gets matrix information.
  375. uint32_t matrix_column_count = ir_context->get_type_mgr()
  376. ->GetType(matrix_instruction->type_id())
  377. ->AsMatrix()
  378. ->element_count();
  379. auto matrix_column_type = ir_context->get_type_mgr()
  380. ->GetType(matrix_instruction->type_id())
  381. ->AsMatrix()
  382. ->element_type();
  383. uint32_t matrix_column_size = matrix_column_type->AsVector()->element_count();
  384. std::vector<uint32_t> composite_construct_ids(matrix_column_count);
  385. uint32_t fresh_id_index = 0;
  386. for (uint32_t i = 0; i < matrix_column_count; i++) {
  387. // Extracts |matrix| column.
  388. uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++);
  389. fuzzerutil::UpdateModuleIdBound(ir_context, matrix_extract_id);
  390. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  391. ir_context, spv::Op::OpCompositeExtract,
  392. ir_context->get_type_mgr()->GetId(matrix_column_type),
  393. matrix_extract_id,
  394. opt::Instruction::OperandList(
  395. {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
  396. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  397. std::vector<uint32_t> float_multiplication_ids(matrix_column_size);
  398. for (uint32_t j = 0; j < matrix_column_size; j++) {
  399. // Extracts |column| component.
  400. uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
  401. fuzzerutil::UpdateModuleIdBound(ir_context, column_extract_id);
  402. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  403. ir_context, spv::Op::OpCompositeExtract,
  404. scalar_instruction->type_id(), column_extract_id,
  405. opt::Instruction::OperandList(
  406. {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}},
  407. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
  408. // Multiplies the |column| component with the |scalar|.
  409. float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
  410. fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[j]);
  411. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  412. ir_context, spv::Op::OpFMul, scalar_instruction->type_id(),
  413. float_multiplication_ids[j],
  414. opt::Instruction::OperandList(
  415. {{SPV_OPERAND_TYPE_ID, {column_extract_id}},
  416. {SPV_OPERAND_TYPE_ID, {scalar_instruction->result_id()}}})));
  417. }
  418. // Constructs a new column multiplied by |scalar|.
  419. opt::Instruction::OperandList composite_construct_in_operands;
  420. for (uint32_t& float_multiplication_id : float_multiplication_ids) {
  421. composite_construct_in_operands.push_back(
  422. {SPV_OPERAND_TYPE_ID, {float_multiplication_id}});
  423. }
  424. composite_construct_ids[i] = message_.fresh_ids(fresh_id_index++);
  425. fuzzerutil::UpdateModuleIdBound(ir_context, composite_construct_ids[i]);
  426. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  427. ir_context, spv::Op::OpCompositeConstruct,
  428. ir_context->get_type_mgr()->GetId(matrix_column_type),
  429. composite_construct_ids[i], composite_construct_in_operands));
  430. }
  431. // The OpMatrixTimesScalar instruction is changed to an OpCompositeConstruct
  432. // instruction.
  433. linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
  434. linear_algebra_instruction->SetInOperand(0, {composite_construct_ids[0]});
  435. linear_algebra_instruction->SetInOperand(1, {composite_construct_ids[1]});
  436. for (uint32_t i = 2; i < composite_construct_ids.size(); i++) {
  437. linear_algebra_instruction->AddOperand(
  438. {SPV_OPERAND_TYPE_ID, {composite_construct_ids[i]}});
  439. }
  440. }
  441. void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesMatrix(
  442. opt::IRContext* ir_context,
  443. opt::Instruction* linear_algebra_instruction) const {
  444. // Gets vector information.
  445. auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
  446. linear_algebra_instruction->GetSingleWordInOperand(0));
  447. uint32_t vector_component_count = ir_context->get_type_mgr()
  448. ->GetType(vector_instruction->type_id())
  449. ->AsVector()
  450. ->element_count();
  451. auto vector_component_type = ir_context->get_type_mgr()
  452. ->GetType(vector_instruction->type_id())
  453. ->AsVector()
  454. ->element_type();
  455. // Extracts vector components.
  456. uint32_t fresh_id_index = 0;
  457. std::vector<uint32_t> vector_component_ids(vector_component_count);
  458. for (uint32_t i = 0; i < vector_component_count; i++) {
  459. vector_component_ids[i] = message_.fresh_ids(fresh_id_index++);
  460. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  461. ir_context, spv::Op::OpCompositeExtract,
  462. ir_context->get_type_mgr()->GetId(vector_component_type),
  463. vector_component_ids[i],
  464. opt::Instruction::OperandList(
  465. {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}},
  466. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  467. }
  468. // Gets matrix information.
  469. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
  470. linear_algebra_instruction->GetSingleWordInOperand(1));
  471. uint32_t matrix_column_count = ir_context->get_type_mgr()
  472. ->GetType(matrix_instruction->type_id())
  473. ->AsMatrix()
  474. ->element_count();
  475. auto matrix_column_type = ir_context->get_type_mgr()
  476. ->GetType(matrix_instruction->type_id())
  477. ->AsMatrix()
  478. ->element_type();
  479. std::vector<uint32_t> result_component_ids(matrix_column_count);
  480. for (uint32_t i = 0; i < matrix_column_count; i++) {
  481. // Extracts matrix column.
  482. uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++);
  483. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  484. ir_context, spv::Op::OpCompositeExtract,
  485. ir_context->get_type_mgr()->GetId(matrix_column_type),
  486. matrix_extract_id,
  487. opt::Instruction::OperandList(
  488. {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
  489. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  490. std::vector<uint32_t> float_multiplication_ids(vector_component_count);
  491. for (uint32_t j = 0; j < vector_component_count; j++) {
  492. // Extracts column component.
  493. uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
  494. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  495. ir_context, spv::Op::OpCompositeExtract,
  496. ir_context->get_type_mgr()->GetId(vector_component_type),
  497. column_extract_id,
  498. opt::Instruction::OperandList(
  499. {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}},
  500. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
  501. // Multiplies corresponding vector and column components.
  502. float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
  503. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  504. ir_context, spv::Op::OpFMul,
  505. ir_context->get_type_mgr()->GetId(vector_component_type),
  506. float_multiplication_ids[j],
  507. opt::Instruction::OperandList(
  508. {{SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}},
  509. {SPV_OPERAND_TYPE_ID, {column_extract_id}}})));
  510. }
  511. // Adds the multiplication results.
  512. std::vector<uint32_t> float_add_ids;
  513. uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
  514. float_add_ids.push_back(float_add_id);
  515. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  516. ir_context, spv::Op::OpFAdd,
  517. ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id,
  518. opt::Instruction::OperandList(
  519. {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
  520. {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
  521. for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) {
  522. float_add_id = message_.fresh_ids(fresh_id_index++);
  523. float_add_ids.push_back(float_add_id);
  524. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  525. ir_context, spv::Op::OpFAdd,
  526. ir_context->get_type_mgr()->GetId(vector_component_type),
  527. float_add_id,
  528. opt::Instruction::OperandList(
  529. {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}},
  530. {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}})));
  531. }
  532. result_component_ids[i] = float_add_ids.back();
  533. }
  534. // The OpVectorTimesMatrix instruction is changed to an OpCompositeConstruct
  535. // instruction.
  536. linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
  537. linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]});
  538. linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]});
  539. for (uint32_t i = 2; i < result_component_ids.size(); i++) {
  540. linear_algebra_instruction->AddOperand(
  541. {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}});
  542. }
  543. fuzzerutil::UpdateModuleIdBound(
  544. ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
  545. }
  546. void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesVector(
  547. opt::IRContext* ir_context,
  548. opt::Instruction* linear_algebra_instruction) const {
  549. // Gets matrix information.
  550. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
  551. linear_algebra_instruction->GetSingleWordInOperand(0));
  552. uint32_t matrix_column_count = ir_context->get_type_mgr()
  553. ->GetType(matrix_instruction->type_id())
  554. ->AsMatrix()
  555. ->element_count();
  556. auto matrix_column_type = ir_context->get_type_mgr()
  557. ->GetType(matrix_instruction->type_id())
  558. ->AsMatrix()
  559. ->element_type();
  560. uint32_t matrix_row_count = matrix_column_type->AsVector()->element_count();
  561. // Extracts matrix columns.
  562. uint32_t fresh_id_index = 0;
  563. std::vector<uint32_t> matrix_column_ids(matrix_column_count);
  564. for (uint32_t i = 0; i < matrix_column_count; i++) {
  565. matrix_column_ids[i] = message_.fresh_ids(fresh_id_index++);
  566. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  567. ir_context, spv::Op::OpCompositeExtract,
  568. ir_context->get_type_mgr()->GetId(matrix_column_type),
  569. matrix_column_ids[i],
  570. opt::Instruction::OperandList(
  571. {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
  572. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  573. }
  574. // Gets vector information.
  575. auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
  576. linear_algebra_instruction->GetSingleWordInOperand(1));
  577. auto vector_component_type = ir_context->get_type_mgr()
  578. ->GetType(vector_instruction->type_id())
  579. ->AsVector()
  580. ->element_type();
  581. // Extracts vector components.
  582. std::vector<uint32_t> vector_component_ids(matrix_column_count);
  583. for (uint32_t i = 0; i < matrix_column_count; i++) {
  584. vector_component_ids[i] = message_.fresh_ids(fresh_id_index++);
  585. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  586. ir_context, spv::Op::OpCompositeExtract,
  587. ir_context->get_type_mgr()->GetId(vector_component_type),
  588. vector_component_ids[i],
  589. opt::Instruction::OperandList(
  590. {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}},
  591. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  592. }
  593. std::vector<uint32_t> result_component_ids(matrix_row_count);
  594. for (uint32_t i = 0; i < matrix_row_count; i++) {
  595. std::vector<uint32_t> float_multiplication_ids(matrix_column_count);
  596. for (uint32_t j = 0; j < matrix_column_count; j++) {
  597. // Extracts column component.
  598. uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
  599. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  600. ir_context, spv::Op::OpCompositeExtract,
  601. ir_context->get_type_mgr()->GetId(vector_component_type),
  602. column_extract_id,
  603. opt::Instruction::OperandList(
  604. {{SPV_OPERAND_TYPE_ID, {matrix_column_ids[j]}},
  605. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  606. // Multiplies corresponding vector and column components.
  607. float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
  608. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  609. ir_context, spv::Op::OpFMul,
  610. ir_context->get_type_mgr()->GetId(vector_component_type),
  611. float_multiplication_ids[j],
  612. opt::Instruction::OperandList(
  613. {{SPV_OPERAND_TYPE_ID, {column_extract_id}},
  614. {SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}}})));
  615. }
  616. // Adds the multiplication results.
  617. std::vector<uint32_t> float_add_ids;
  618. uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
  619. float_add_ids.push_back(float_add_id);
  620. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  621. ir_context, spv::Op::OpFAdd,
  622. ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id,
  623. opt::Instruction::OperandList(
  624. {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
  625. {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
  626. for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) {
  627. float_add_id = message_.fresh_ids(fresh_id_index++);
  628. float_add_ids.push_back(float_add_id);
  629. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  630. ir_context, spv::Op::OpFAdd,
  631. ir_context->get_type_mgr()->GetId(vector_component_type),
  632. float_add_id,
  633. opt::Instruction::OperandList(
  634. {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}},
  635. {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}})));
  636. }
  637. result_component_ids[i] = float_add_ids.back();
  638. }
  639. // The OpMatrixTimesVector instruction is changed to an OpCompositeConstruct
  640. // instruction.
  641. linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
  642. linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]});
  643. linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]});
  644. for (uint32_t i = 2; i < result_component_ids.size(); i++) {
  645. linear_algebra_instruction->AddOperand(
  646. {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}});
  647. }
  648. fuzzerutil::UpdateModuleIdBound(
  649. ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
  650. }
  651. void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesMatrix(
  652. opt::IRContext* ir_context,
  653. opt::Instruction* linear_algebra_instruction) const {
  654. // Gets matrix 1 information.
  655. auto matrix_1_instruction = ir_context->get_def_use_mgr()->GetDef(
  656. linear_algebra_instruction->GetSingleWordInOperand(0));
  657. uint32_t matrix_1_column_count =
  658. ir_context->get_type_mgr()
  659. ->GetType(matrix_1_instruction->type_id())
  660. ->AsMatrix()
  661. ->element_count();
  662. auto matrix_1_column_type = ir_context->get_type_mgr()
  663. ->GetType(matrix_1_instruction->type_id())
  664. ->AsMatrix()
  665. ->element_type();
  666. auto matrix_1_column_component_type =
  667. matrix_1_column_type->AsVector()->element_type();
  668. uint32_t matrix_1_row_count =
  669. matrix_1_column_type->AsVector()->element_count();
  670. // Gets matrix 2 information.
  671. auto matrix_2_instruction = ir_context->get_def_use_mgr()->GetDef(
  672. linear_algebra_instruction->GetSingleWordInOperand(1));
  673. uint32_t matrix_2_column_count =
  674. ir_context->get_type_mgr()
  675. ->GetType(matrix_2_instruction->type_id())
  676. ->AsMatrix()
  677. ->element_count();
  678. auto matrix_2_column_type = ir_context->get_type_mgr()
  679. ->GetType(matrix_2_instruction->type_id())
  680. ->AsMatrix()
  681. ->element_type();
  682. uint32_t fresh_id_index = 0;
  683. std::vector<uint32_t> result_column_ids(matrix_2_column_count);
  684. for (uint32_t i = 0; i < matrix_2_column_count; i++) {
  685. // Extracts matrix 2 column.
  686. uint32_t matrix_2_column_id = message_.fresh_ids(fresh_id_index++);
  687. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  688. ir_context, spv::Op::OpCompositeExtract,
  689. ir_context->get_type_mgr()->GetId(matrix_2_column_type),
  690. matrix_2_column_id,
  691. opt::Instruction::OperandList(
  692. {{SPV_OPERAND_TYPE_ID, {matrix_2_instruction->result_id()}},
  693. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  694. std::vector<uint32_t> column_component_ids(matrix_1_row_count);
  695. for (uint32_t j = 0; j < matrix_1_row_count; j++) {
  696. std::vector<uint32_t> float_multiplication_ids(matrix_1_column_count);
  697. for (uint32_t k = 0; k < matrix_1_column_count; k++) {
  698. // Extracts matrix 1 column.
  699. uint32_t matrix_1_column_id = message_.fresh_ids(fresh_id_index++);
  700. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  701. ir_context, spv::Op::OpCompositeExtract,
  702. ir_context->get_type_mgr()->GetId(matrix_1_column_type),
  703. matrix_1_column_id,
  704. opt::Instruction::OperandList(
  705. {{SPV_OPERAND_TYPE_ID, {matrix_1_instruction->result_id()}},
  706. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {k}}})));
  707. // Extracts matrix 1 column component.
  708. uint32_t matrix_1_column_component_id =
  709. message_.fresh_ids(fresh_id_index++);
  710. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  711. ir_context, spv::Op::OpCompositeExtract,
  712. ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
  713. matrix_1_column_component_id,
  714. opt::Instruction::OperandList(
  715. {{SPV_OPERAND_TYPE_ID, {matrix_1_column_id}},
  716. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
  717. // Extracts matrix 2 column component.
  718. uint32_t matrix_2_column_component_id =
  719. message_.fresh_ids(fresh_id_index++);
  720. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  721. ir_context, spv::Op::OpCompositeExtract,
  722. ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
  723. matrix_2_column_component_id,
  724. opt::Instruction::OperandList(
  725. {{SPV_OPERAND_TYPE_ID, {matrix_2_column_id}},
  726. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {k}}})));
  727. // Multiplies corresponding matrix 1 and matrix 2 column components.
  728. float_multiplication_ids[k] = message_.fresh_ids(fresh_id_index++);
  729. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  730. ir_context, spv::Op::OpFMul,
  731. ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
  732. float_multiplication_ids[k],
  733. opt::Instruction::OperandList(
  734. {{SPV_OPERAND_TYPE_ID, {matrix_1_column_component_id}},
  735. {SPV_OPERAND_TYPE_ID, {matrix_2_column_component_id}}})));
  736. }
  737. // Adds the multiplication results.
  738. std::vector<uint32_t> float_add_ids;
  739. uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
  740. float_add_ids.push_back(float_add_id);
  741. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  742. ir_context, spv::Op::OpFAdd,
  743. ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
  744. float_add_id,
  745. opt::Instruction::OperandList(
  746. {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
  747. {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
  748. for (uint32_t k = 2; k < float_multiplication_ids.size(); k++) {
  749. float_add_id = message_.fresh_ids(fresh_id_index++);
  750. float_add_ids.push_back(float_add_id);
  751. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  752. ir_context, spv::Op::OpFAdd,
  753. ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
  754. float_add_id,
  755. opt::Instruction::OperandList(
  756. {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[k]}},
  757. {SPV_OPERAND_TYPE_ID, {float_add_ids[k - 2]}}})));
  758. }
  759. column_component_ids[j] = float_add_ids.back();
  760. }
  761. // Inserts the resulting matrix column.
  762. opt::Instruction::OperandList in_operands;
  763. for (auto& column_component_id : column_component_ids) {
  764. in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
  765. }
  766. result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
  767. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  768. ir_context, spv::Op::OpCompositeConstruct,
  769. ir_context->get_type_mgr()->GetId(matrix_1_column_type),
  770. result_column_ids[i], opt::Instruction::OperandList(in_operands)));
  771. }
  772. // The OpMatrixTimesMatrix instruction is changed to an OpCompositeConstruct
  773. // instruction.
  774. linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
  775. linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
  776. linear_algebra_instruction->SetInOperand(1, {result_column_ids[1]});
  777. for (uint32_t i = 2; i < result_column_ids.size(); i++) {
  778. linear_algebra_instruction->AddOperand(
  779. {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
  780. }
  781. fuzzerutil::UpdateModuleIdBound(
  782. ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
  783. }
  784. void TransformationReplaceLinearAlgebraInstruction::ReplaceOpOuterProduct(
  785. opt::IRContext* ir_context,
  786. opt::Instruction* linear_algebra_instruction) const {
  787. // Gets vector 1 information.
  788. auto vector_1_instruction = ir_context->get_def_use_mgr()->GetDef(
  789. linear_algebra_instruction->GetSingleWordInOperand(0));
  790. uint32_t vector_1_component_count =
  791. ir_context->get_type_mgr()
  792. ->GetType(vector_1_instruction->type_id())
  793. ->AsVector()
  794. ->element_count();
  795. auto vector_1_component_type = ir_context->get_type_mgr()
  796. ->GetType(vector_1_instruction->type_id())
  797. ->AsVector()
  798. ->element_type();
  799. // Gets vector 2 information.
  800. auto vector_2_instruction = ir_context->get_def_use_mgr()->GetDef(
  801. linear_algebra_instruction->GetSingleWordInOperand(1));
  802. uint32_t vector_2_component_count =
  803. ir_context->get_type_mgr()
  804. ->GetType(vector_2_instruction->type_id())
  805. ->AsVector()
  806. ->element_count();
  807. uint32_t fresh_id_index = 0;
  808. std::vector<uint32_t> result_column_ids(vector_2_component_count);
  809. for (uint32_t i = 0; i < vector_2_component_count; i++) {
  810. // Extracts |vector_2| component.
  811. uint32_t vector_2_component_id = message_.fresh_ids(fresh_id_index++);
  812. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  813. ir_context, spv::Op::OpCompositeExtract,
  814. ir_context->get_type_mgr()->GetId(vector_1_component_type),
  815. vector_2_component_id,
  816. opt::Instruction::OperandList(
  817. {{SPV_OPERAND_TYPE_ID, {vector_2_instruction->result_id()}},
  818. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  819. std::vector<uint32_t> column_component_ids(vector_1_component_count);
  820. for (uint32_t j = 0; j < vector_1_component_count; j++) {
  821. // Extracts |vector_1| component.
  822. uint32_t vector_1_component_id = message_.fresh_ids(fresh_id_index++);
  823. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  824. ir_context, spv::Op::OpCompositeExtract,
  825. ir_context->get_type_mgr()->GetId(vector_1_component_type),
  826. vector_1_component_id,
  827. opt::Instruction::OperandList(
  828. {{SPV_OPERAND_TYPE_ID, {vector_1_instruction->result_id()}},
  829. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
  830. // Multiplies |vector_1| and |vector_2| components.
  831. column_component_ids[j] = message_.fresh_ids(fresh_id_index++);
  832. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  833. ir_context, spv::Op::OpFMul,
  834. ir_context->get_type_mgr()->GetId(vector_1_component_type),
  835. column_component_ids[j],
  836. opt::Instruction::OperandList(
  837. {{SPV_OPERAND_TYPE_ID, {vector_2_component_id}},
  838. {SPV_OPERAND_TYPE_ID, {vector_1_component_id}}})));
  839. }
  840. // Inserts the resulting matrix column.
  841. opt::Instruction::OperandList in_operands;
  842. for (auto& column_component_id : column_component_ids) {
  843. in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
  844. }
  845. result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
  846. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  847. ir_context, spv::Op::OpCompositeConstruct,
  848. vector_1_instruction->type_id(), result_column_ids[i], in_operands));
  849. }
  850. // The OpOuterProduct instruction is changed to an OpCompositeConstruct
  851. // instruction.
  852. linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
  853. linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
  854. linear_algebra_instruction->SetInOperand(1, {result_column_ids[1]});
  855. for (uint32_t i = 2; i < result_column_ids.size(); i++) {
  856. linear_algebra_instruction->AddOperand(
  857. {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
  858. }
  859. fuzzerutil::UpdateModuleIdBound(
  860. ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
  861. }
  862. void TransformationReplaceLinearAlgebraInstruction::ReplaceOpDot(
  863. opt::IRContext* ir_context,
  864. opt::Instruction* linear_algebra_instruction) const {
  865. // Gets OpDot in operands.
  866. auto vector_1 = ir_context->get_def_use_mgr()->GetDef(
  867. linear_algebra_instruction->GetSingleWordInOperand(0));
  868. auto vector_2 = ir_context->get_def_use_mgr()->GetDef(
  869. linear_algebra_instruction->GetSingleWordInOperand(1));
  870. uint32_t vectors_component_count = ir_context->get_type_mgr()
  871. ->GetType(vector_1->type_id())
  872. ->AsVector()
  873. ->element_count();
  874. std::vector<uint32_t> float_multiplication_ids(vectors_component_count);
  875. uint32_t fresh_id_index = 0;
  876. for (uint32_t i = 0; i < vectors_component_count; i++) {
  877. // Extracts |vector_1| component.
  878. uint32_t vector_1_extract_id = message_.fresh_ids(fresh_id_index++);
  879. fuzzerutil::UpdateModuleIdBound(ir_context, vector_1_extract_id);
  880. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  881. ir_context, spv::Op::OpCompositeExtract,
  882. linear_algebra_instruction->type_id(), vector_1_extract_id,
  883. opt::Instruction::OperandList(
  884. {{SPV_OPERAND_TYPE_ID, {vector_1->result_id()}},
  885. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  886. // Extracts |vector_2| component.
  887. uint32_t vector_2_extract_id = message_.fresh_ids(fresh_id_index++);
  888. fuzzerutil::UpdateModuleIdBound(ir_context, vector_2_extract_id);
  889. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  890. ir_context, spv::Op::OpCompositeExtract,
  891. linear_algebra_instruction->type_id(), vector_2_extract_id,
  892. opt::Instruction::OperandList(
  893. {{SPV_OPERAND_TYPE_ID, {vector_2->result_id()}},
  894. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
  895. // Multiplies the pair of components.
  896. float_multiplication_ids[i] = message_.fresh_ids(fresh_id_index++);
  897. fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[i]);
  898. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  899. ir_context, spv::Op::OpFMul, linear_algebra_instruction->type_id(),
  900. float_multiplication_ids[i],
  901. opt::Instruction::OperandList(
  902. {{SPV_OPERAND_TYPE_ID, {vector_1_extract_id}},
  903. {SPV_OPERAND_TYPE_ID, {vector_2_extract_id}}})));
  904. }
  905. // If the vector has 2 components, then there will be 2 float multiplication
  906. // instructions.
  907. if (vectors_component_count == 2) {
  908. linear_algebra_instruction->SetOpcode(spv::Op::OpFAdd);
  909. linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]});
  910. linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]});
  911. } else {
  912. // The first OpFAdd instruction has as operands the first two OpFMul
  913. // instructions.
  914. std::vector<uint32_t> float_add_ids;
  915. uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
  916. float_add_ids.push_back(float_add_id);
  917. fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id);
  918. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  919. ir_context, spv::Op::OpFAdd, linear_algebra_instruction->type_id(),
  920. float_add_id,
  921. opt::Instruction::OperandList(
  922. {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
  923. {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
  924. // The remaining OpFAdd instructions has as operands an OpFMul and an OpFAdd
  925. // instruction.
  926. for (uint32_t i = 2; i < float_multiplication_ids.size() - 1; i++) {
  927. float_add_id = message_.fresh_ids(fresh_id_index++);
  928. fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id);
  929. float_add_ids.push_back(float_add_id);
  930. linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
  931. ir_context, spv::Op::OpFAdd, linear_algebra_instruction->type_id(),
  932. float_add_id,
  933. opt::Instruction::OperandList(
  934. {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}},
  935. {SPV_OPERAND_TYPE_ID, {float_add_ids[i - 2]}}})));
  936. }
  937. // The last OpFAdd instruction is got by changing some of the OpDot
  938. // instruction attributes.
  939. linear_algebra_instruction->SetOpcode(spv::Op::OpFAdd);
  940. linear_algebra_instruction->SetInOperand(
  941. 0, {float_multiplication_ids[float_multiplication_ids.size() - 1]});
  942. linear_algebra_instruction->SetInOperand(
  943. 1, {float_add_ids[float_add_ids.size() - 1]});
  944. }
  945. }
  946. std::unordered_set<uint32_t>
  947. TransformationReplaceLinearAlgebraInstruction::GetFreshIds() const {
  948. std::unordered_set<uint32_t> result;
  949. for (auto id : message_.fresh_ids()) {
  950. result.insert(id);
  951. }
  952. return result;
  953. }
  954. } // namespace fuzz
  955. } // namespace spvtools