amd_ext_to_khr.cpp 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980
  1. // Copyright (c) 2019 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/amd_ext_to_khr.h"
  15. #include <set>
  16. #include <string>
  17. #include "ir_builder.h"
  18. #include "source/opt/ir_context.h"
  19. #include "spv-amd-shader-ballot.insts.inc"
  20. #include "type_manager.h"
  21. namespace spvtools {
  22. namespace opt {
  23. namespace {
// Extended-instruction numbers for the SPV_AMD_shader_ballot extended
// instruction set.  These must match the numbers assigned by the extension.
enum AmdShaderBallotExtOpcodes {
  AmdShaderBallotSwizzleInvocationsAMD = 1,
  AmdShaderBallotSwizzleInvocationsMaskedAMD = 2,
  AmdShaderBallotWriteInvocationAMD = 3,
  AmdShaderBallotMbcntAMD = 4
};
// Extended-instruction numbers for the SPV_AMD_shader_trinary_minmax extended
// instruction set: 3-operand min, max, and mid (median) for float, unsigned,
// and signed operands.
enum AmdShaderTrinaryMinMaxExtOpCodes {
  FMin3AMD = 1,
  UMin3AMD = 2,
  SMin3AMD = 3,
  FMax3AMD = 4,
  UMax3AMD = 5,
  SMax3AMD = 6,
  FMid3AMD = 7,
  UMid3AMD = 8,
  SMid3AMD = 9
};
  41. enum AmdGcnShader { CubeFaceCoordAMD = 2, CubeFaceIndexAMD = 1, TimeAMD = 3 };
  42. analysis::Type* GetUIntType(IRContext* ctx) {
  43. analysis::Integer int_type(32, false);
  44. return ctx->get_type_mgr()->GetRegisteredType(&int_type);
  45. }
  46. // Returns a folding rule that replaces |op(a,b,c)| by |op(op(a,b),c)|, where
  47. // |op| is either min or max. |opcode| is the binary opcode in the GLSLstd450
  48. // extended instruction set that corresponds to the trinary instruction being
  49. // replaced.
  50. template <GLSLstd450 opcode>
  51. bool ReplaceTrinaryMinMax(IRContext* ctx, Instruction* inst,
  52. const std::vector<const analysis::Constant*>&) {
  53. uint32_t glsl405_ext_inst_id =
  54. ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  55. if (glsl405_ext_inst_id == 0) {
  56. ctx->AddExtInstImport("GLSL.std.450");
  57. glsl405_ext_inst_id =
  58. ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  59. }
  60. InstructionBuilder ir_builder(
  61. ctx, inst,
  62. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  63. uint32_t op1 = inst->GetSingleWordInOperand(2);
  64. uint32_t op2 = inst->GetSingleWordInOperand(3);
  65. uint32_t op3 = inst->GetSingleWordInOperand(4);
  66. Instruction* temp = ir_builder.AddNaryExtendedInstruction(
  67. inst->type_id(), glsl405_ext_inst_id, opcode, {op1, op2});
  68. Instruction::OperandList new_operands;
  69. new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}});
  70. new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER,
  71. {static_cast<uint32_t>(opcode)}});
  72. new_operands.push_back({SPV_OPERAND_TYPE_ID, {temp->result_id()}});
  73. new_operands.push_back({SPV_OPERAND_TYPE_ID, {op3}});
  74. inst->SetInOperands(std::move(new_operands));
  75. ctx->UpdateDefUse(inst);
  76. return true;
  77. }
  78. // Returns a folding rule that replaces |mid(a,b,c)| by |clamp(a, min(b,c),
  79. // max(b,c)|. The three parameters are the opcode that correspond to the min,
  80. // max, and clamp operations for the type of the instruction being replaced.
  81. template <GLSLstd450 min_opcode, GLSLstd450 max_opcode, GLSLstd450 clamp_opcode>
  82. bool ReplaceTrinaryMid(IRContext* ctx, Instruction* inst,
  83. const std::vector<const analysis::Constant*>&) {
  84. uint32_t glsl405_ext_inst_id =
  85. ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  86. if (glsl405_ext_inst_id == 0) {
  87. ctx->AddExtInstImport("GLSL.std.450");
  88. glsl405_ext_inst_id =
  89. ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  90. }
  91. InstructionBuilder ir_builder(
  92. ctx, inst,
  93. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  94. uint32_t op1 = inst->GetSingleWordInOperand(2);
  95. uint32_t op2 = inst->GetSingleWordInOperand(3);
  96. uint32_t op3 = inst->GetSingleWordInOperand(4);
  97. Instruction* min = ir_builder.AddNaryExtendedInstruction(
  98. inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(min_opcode),
  99. {op2, op3});
  100. Instruction* max = ir_builder.AddNaryExtendedInstruction(
  101. inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(max_opcode),
  102. {op2, op3});
  103. Instruction::OperandList new_operands;
  104. new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}});
  105. new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER,
  106. {static_cast<uint32_t>(clamp_opcode)}});
  107. new_operands.push_back({SPV_OPERAND_TYPE_ID, {op1}});
  108. new_operands.push_back({SPV_OPERAND_TYPE_ID, {min->result_id()}});
  109. new_operands.push_back({SPV_OPERAND_TYPE_ID, {max->result_id()}});
  110. inst->SetInOperands(std::move(new_operands));
  111. ctx->UpdateDefUse(inst);
  112. return true;
  113. }
  114. // Returns a folding rule that will replace the opcode with |opcode| and add
  115. // the capabilities required. The folding rule assumes it is folding an
  116. // OpGroup*NonUniformAMD instruction from the SPV_AMD_shader_ballot extension.
  117. template <spv::Op new_opcode>
  118. bool ReplaceGroupNonuniformOperationOpCode(
  119. IRContext* ctx, Instruction* inst,
  120. const std::vector<const analysis::Constant*>&) {
  121. switch (new_opcode) {
  122. case spv::Op::OpGroupNonUniformIAdd:
  123. case spv::Op::OpGroupNonUniformFAdd:
  124. case spv::Op::OpGroupNonUniformUMin:
  125. case spv::Op::OpGroupNonUniformSMin:
  126. case spv::Op::OpGroupNonUniformFMin:
  127. case spv::Op::OpGroupNonUniformUMax:
  128. case spv::Op::OpGroupNonUniformSMax:
  129. case spv::Op::OpGroupNonUniformFMax:
  130. break;
  131. default:
  132. assert(
  133. false &&
  134. "Should be replacing with a group non uniform arithmetic operation.");
  135. }
  136. switch (inst->opcode()) {
  137. case spv::Op::OpGroupIAddNonUniformAMD:
  138. case spv::Op::OpGroupFAddNonUniformAMD:
  139. case spv::Op::OpGroupUMinNonUniformAMD:
  140. case spv::Op::OpGroupSMinNonUniformAMD:
  141. case spv::Op::OpGroupFMinNonUniformAMD:
  142. case spv::Op::OpGroupUMaxNonUniformAMD:
  143. case spv::Op::OpGroupSMaxNonUniformAMD:
  144. case spv::Op::OpGroupFMaxNonUniformAMD:
  145. break;
  146. default:
  147. assert(false &&
  148. "Should be replacing a group non uniform arithmetic operation.");
  149. }
  150. ctx->AddCapability(spv::Capability::GroupNonUniformArithmetic);
  151. inst->SetOpcode(new_opcode);
  152. return true;
  153. }
// Returns a folding rule that will replace the SwizzleInvocationsAMD extended
// instruction in the SPV_AMD_shader_ballot extension.
//
// The instruction
//
//  %offset = OpConstantComposite %v4uint %x %y %z %w
//  %result = OpExtInst %type %1 SwizzleInvocationsAMD %data %offset
//
// is replaced with
//
// potentially new constants and types
//
// clang-format off
// %uint_max = OpConstant %uint 0xFFFFFFFF
// %v4uint = OpTypeVector %uint 4
// %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
// %null = OpConstantNull %type
// clang-format on
//
// and the following code in the function body
//
// clang-format off
// %id = OpLoad %uint %SubgroupLocalInvocationId
// %quad_idx = OpBitwiseAnd %uint %id %uint_3
// %quad_ldr = OpBitwiseXor %uint %id %quad_idx
// %my_offset = OpVectorExtractDynamic %uint %offset %quad_idx
// %target_inv = OpIAdd %uint %quad_ldr %my_offset
// %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
// %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
// %result = OpSelect %type %is_active %shuffle %null
// clang-format on
//
// Also adding the capabilities and builtins that are needed.
bool ReplaceSwizzleInvocations(IRContext* ctx, Instruction* inst,
                               const std::vector<const analysis::Constant*>&) {
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();

  // The replacement uses ballot and shuffle group operations.
  ctx->AddExtension("SPV_KHR_shader_ballot");
  ctx->AddCapability(spv::Capability::GroupNonUniformBallot);
  ctx->AddCapability(spv::Capability::GroupNonUniformShuffle);

  InstructionBuilder ir_builder(
      ctx, inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);

  // In-operands 0 and 1 are the extended-instruction-set id and instruction
  // number, so |data| and |offset| are at in-operands 2 and 3.
  uint32_t data_id = inst->GetSingleWordInOperand(2);
  uint32_t offset_id = inst->GetSingleWordInOperand(3);

  // Get the subgroup invocation id.
  uint32_t var_id = ctx->GetBuiltinInputVarId(
      uint32_t(spv::BuiltIn::SubgroupLocalInvocationId));
  assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
  Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
  Instruction* var_ptr_type =
      ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
  // The pointee type (uint) is in-operand 1 of the variable's OpTypePointer.
  uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);

  Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);

  uint32_t quad_mask = ir_builder.GetUintConstantId(3);

  // This gives the offset in the group of 4 of this invocation.
  Instruction* quad_idx = ir_builder.AddBinaryOp(
      uint_type_id, spv::Op::OpBitwiseAnd, id->result_id(), quad_mask);

  // Get the invocation id of the first invocation in the group of 4.
  Instruction* quad_ldr =
      ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpBitwiseXor,
                             id->result_id(), quad_idx->result_id());

  // Get the offset of the target invocation from the offset vector.
  // |quad_idx| is in [0,3], so |offset| must be a 4-element vector.
  Instruction* my_offset =
      ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpVectorExtractDynamic,
                             offset_id, quad_idx->result_id());

  // Determine the index of the invocation to read from.
  Instruction* target_inv =
      ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpIAdd,
                             quad_ldr->result_id(), my_offset->result_id());

  // Do the group operations.  The all-ones ballot selects every invocation;
  // the bit-extract then tells us whether |target_inv| is active.
  uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
  uint32_t subgroup_scope =
      ir_builder.GetUintConstantId(uint32_t(spv::Scope::Subgroup));
  const auto* ballot_value_const = const_mgr->GetConstant(
      type_mgr->GetUIntVectorType(4),
      {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
  Instruction* ballot_value =
      const_mgr->GetDefiningInstruction(ballot_value_const);
  Instruction* is_active = ir_builder.AddNaryOp(
      type_mgr->GetBoolTypeId(), spv::Op::OpGroupNonUniformBallotBitExtract,
      {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
  Instruction* shuffle =
      ir_builder.AddNaryOp(inst->type_id(), spv::Op::OpGroupNonUniformShuffle,
                           {subgroup_scope, data_id, target_inv->result_id()});

  // Create the null constant to use in the select.
  const auto* null = const_mgr->GetConstant(type_mgr->GetType(inst->type_id()),
                                            std::vector<uint32_t>());
  Instruction* null_inst = const_mgr->GetDefiningInstruction(null);

  // Build the select: the shuffled value if the target is active, else null.
  inst->SetOpcode(spv::Op::OpSelect);
  Instruction::OperandList new_operands;
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
  inst->SetInOperands(std::move(new_operands));

  ctx->UpdateDefUse(inst);
  return true;
}
// Returns a folding rule that will replace the SwizzleInvocationsMaskedAMD
// extended instruction in the SPV_AMD_shader_ballot extension.
//
// The instruction
//
//  %mask = OpConstantComposite %v3uint %uint_x %uint_y %uint_z
//  %result = OpExtInst %uint %1 SwizzleInvocationsMaskedAMD %data %mask
//
// is replaced with
//
// potentially new constants and types
//
// clang-format off
// %uint_mask_extend = OpConstant %uint 0xFFFFFFE0
// %uint_max = OpConstant %uint 0xFFFFFFFF
// %v4uint = OpTypeVector %uint 4
// %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
// clang-format on
//
// and the following code in the function body
//
// clang-format off
// %id = OpLoad %uint %SubgroupLocalInvocationId
// %and_mask = OpBitwiseOr %uint %uint_x %uint_mask_extend
// %and = OpBitwiseAnd %uint %id %and_mask
// %or = OpBitwiseOr %uint %and %uint_y
// %target_inv = OpBitwiseXor %uint %or %uint_z
// %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
// %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
// %result = OpSelect %type %is_active %shuffle %uint_0
// clang-format on
//
// Also adding the capabilities and builtins that are needed.
bool ReplaceSwizzleInvocationsMasked(
    IRContext* ctx, Instruction* inst,
    const std::vector<const analysis::Constant*>&) {
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
  analysis::DefUseManager* def_use_mgr = ctx->get_def_use_mgr();
  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();

  // The replacement uses ballot and shuffle group operations.
  ctx->AddCapability(spv::Capability::GroupNonUniformBallot);
  ctx->AddCapability(spv::Capability::GroupNonUniformShuffle);

  InstructionBuilder ir_builder(
      ctx, inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);

  // Get the operands to inst, and the components of the mask.  The mask must
  // be a 3-component constant composite (checked in debug builds only).
  uint32_t data_id = inst->GetSingleWordInOperand(2);

  Instruction* mask_inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(3));
  assert(mask_inst->opcode() == spv::Op::OpConstantComposite &&
         "The mask is suppose to be a vector constant.");
  assert(mask_inst->NumInOperands() == 3 &&
         "The mask is suppose to have 3 components.");

  uint32_t uint_x = mask_inst->GetSingleWordInOperand(0);
  uint32_t uint_y = mask_inst->GetSingleWordInOperand(1);
  uint32_t uint_z = mask_inst->GetSingleWordInOperand(2);

  // Get the subgroup invocation id.
  uint32_t var_id = ctx->GetBuiltinInputVarId(
      uint32_t(spv::BuiltIn::SubgroupLocalInvocationId));
  ctx->AddExtension("SPV_KHR_shader_ballot");
  assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
  Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
  Instruction* var_ptr_type =
      ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
  // The pointee type (uint) is in-operand 1 of the variable's OpTypePointer.
  uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);

  Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);

  // Do the bitwise operations: target = ((id & (x | 0xFFFFFFE0)) | y) ^ z.
  uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0);
  Instruction* and_mask = ir_builder.AddBinaryOp(
      uint_type_id, spv::Op::OpBitwiseOr, uint_x, mask_extended);
  Instruction* and_result =
      ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpBitwiseAnd,
                             id->result_id(), and_mask->result_id());
  Instruction* or_result = ir_builder.AddBinaryOp(
      uint_type_id, spv::Op::OpBitwiseOr, and_result->result_id(), uint_y);
  Instruction* target_inv = ir_builder.AddBinaryOp(
      uint_type_id, spv::Op::OpBitwiseXor, or_result->result_id(), uint_z);

  // Do the group operations.  The all-ones ballot plus bit-extract tests
  // whether the computed target invocation is active.
  uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
  uint32_t subgroup_scope =
      ir_builder.GetUintConstantId(uint32_t(spv::Scope::Subgroup));
  const auto* ballot_value_const = const_mgr->GetConstant(
      type_mgr->GetUIntVectorType(4),
      {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
  Instruction* ballot_value =
      const_mgr->GetDefiningInstruction(ballot_value_const);
  Instruction* is_active = ir_builder.AddNaryOp(
      type_mgr->GetBoolTypeId(), spv::Op::OpGroupNonUniformBallotBitExtract,
      {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
  Instruction* shuffle =
      ir_builder.AddNaryOp(inst->type_id(), spv::Op::OpGroupNonUniformShuffle,
                           {subgroup_scope, data_id, target_inv->result_id()});

  // Create the null constant to use in the select.
  const auto* null = const_mgr->GetConstant(type_mgr->GetType(inst->type_id()),
                                            std::vector<uint32_t>());
  Instruction* null_inst = const_mgr->GetDefiningInstruction(null);

  // Build the select: the shuffled value if the target is active, else null.
  inst->SetOpcode(spv::Op::OpSelect);
  Instruction::OperandList new_operands;
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
  inst->SetInOperands(std::move(new_operands));

  ctx->UpdateDefUse(inst);
  return true;
}
  357. // Returns a folding rule that will replace the WriteInvocationAMD extended
  358. // instruction in the SPV_AMD_shader_ballot extension.
  359. //
  360. // The instruction
  361. //
  362. // clang-format off
  363. // %result = OpExtInst %type %1 WriteInvocationAMD %input_value %write_value %invocation_index
  364. // clang-format on
  365. //
  366. // with
  367. //
  368. // %id = OpLoad %uint %SubgroupLocalInvocationId
  369. // %cmp = OpIEqual %bool %id %invocation_index
  370. // %result = OpSelect %type %cmp %write_value %input_value
  371. //
  372. // Also adding the capabilities and builtins that are needed.
  373. bool ReplaceWriteInvocation(IRContext* ctx, Instruction* inst,
  374. const std::vector<const analysis::Constant*>&) {
  375. uint32_t var_id = ctx->GetBuiltinInputVarId(
  376. uint32_t(spv::BuiltIn::SubgroupLocalInvocationId));
  377. ctx->AddCapability(spv::Capability::SubgroupBallotKHR);
  378. ctx->AddExtension("SPV_KHR_shader_ballot");
  379. assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
  380. Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
  381. Instruction* var_ptr_type =
  382. ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
  383. InstructionBuilder ir_builder(
  384. ctx, inst,
  385. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  386. Instruction* t =
  387. ir_builder.AddLoad(var_ptr_type->GetSingleWordInOperand(1), var_id);
  388. analysis::Bool bool_type;
  389. uint32_t bool_type_id = ctx->get_type_mgr()->GetTypeInstruction(&bool_type);
  390. Instruction* cmp =
  391. ir_builder.AddBinaryOp(bool_type_id, spv::Op::OpIEqual, t->result_id(),
  392. inst->GetSingleWordInOperand(4));
  393. // Build a select.
  394. inst->SetOpcode(spv::Op::OpSelect);
  395. Instruction::OperandList new_operands;
  396. new_operands.push_back({SPV_OPERAND_TYPE_ID, {cmp->result_id()}});
  397. new_operands.push_back(inst->GetInOperand(3));
  398. new_operands.push_back(inst->GetInOperand(2));
  399. inst->SetInOperands(std::move(new_operands));
  400. ctx->UpdateDefUse(inst);
  401. return true;
  402. }
// Returns a folding rule that will replace the MbcntAMD extended instruction
// in the SPV_AMD_shader_ballot extension.
//
// The instruction
//
// %result = OpExtInst %uint %1 MbcntAMD %mask
//
// is replaced with code that:
//
// Gets SubgroupLtMask and converts the first 64-bits into a uint64_t because
// AMD's shader compiler expects a 64-bit integer mask.
//
// %var = OpLoad %v4uint %SubgroupLtMaskKHR
// %shuffle = OpVectorShuffle %v2uint %var %var 0 1
// %cast = OpBitcast %ulong %shuffle
//
// Then performs the mask and counts the bits.
//
// %and = OpBitwiseAnd %ulong %cast %mask
// %result = OpBitCount %uint %and
//
// Also adding the capabilities and builtins that are needed.
bool ReplaceMbcnt(IRContext* context, Instruction* inst,
                  const std::vector<const analysis::Constant*>&) {
  analysis::TypeManager* type_mgr = context->get_type_mgr();
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();

  // Get (or create) the SubgroupLtMask builtin input variable.
  uint32_t var_id =
      context->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::SubgroupLtMask));
  assert(var_id != 0 && "Could not get SubgroupLtMask variable.");
  context->AddCapability(spv::Capability::GroupNonUniformBallot);
  Instruction* var_inst = def_use_mgr->GetDef(var_id);
  Instruction* var_ptr_type = def_use_mgr->GetDef(var_inst->type_id());
  // The pointee type is in-operand 1 of the variable's OpTypePointer.
  Instruction* var_type =
      def_use_mgr->GetDef(var_ptr_type->GetSingleWordInOperand(1));
  assert(var_type->opcode() == spv::Op::OpTypeVector &&
         "Variable is suppose to be a vector of 4 ints");

  // Get the type for the shuffle: a 2-element uint vector holding the low
  // 64 bits of the ballot mask.
  analysis::Vector temp_type(GetUIntType(context), 2);
  const analysis::Type* shuffle_type =
      context->get_type_mgr()->GetRegisteredType(&temp_type);
  uint32_t shuffle_type_id = type_mgr->GetTypeInstruction(shuffle_type);

  // The |mask| argument is in-operand 2 of the OpExtInst.
  uint32_t mask_id = inst->GetSingleWordInOperand(2);
  Instruction* mask_inst = def_use_mgr->GetDef(mask_id);

  // Testing with amd's shader compiler shows that a 64-bit mask is expected.
  assert(type_mgr->GetType(mask_inst->type_id())->AsInteger() != nullptr);
  assert(type_mgr->GetType(mask_inst->type_id())->AsInteger()->width() == 64);

  InstructionBuilder ir_builder(
      context, inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  Instruction* load = ir_builder.AddLoad(var_type->result_id(), var_id);
  // Keep components 0 and 1 (the low 64 bits) and bitcast them to the mask's
  // 64-bit integer type so the two can be ANDed together.
  Instruction* shuffle = ir_builder.AddVectorShuffle(
      shuffle_type_id, load->result_id(), load->result_id(), {0, 1});
  Instruction* bitcast = ir_builder.AddUnaryOp(
      mask_inst->type_id(), spv::Op::OpBitcast, shuffle->result_id());
  Instruction* t =
      ir_builder.AddBinaryOp(mask_inst->type_id(), spv::Op::OpBitwiseAnd,
                             bitcast->result_id(), mask_id);

  // The result is the population count of the masked ballot.
  inst->SetOpcode(spv::Op::OpBitCount);
  inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}});
  context->UpdateDefUse(inst);
  return true;
}
// A folding rule that will replace the CubeFaceCoordAMD extended
// instruction in the SPV_AMD_gcn_shader_ballot. Returns true if the folding is
// successful.
//
// The instruction
//
// %result = OpExtInst %v2float %1 CubeFaceCoordAMD %input
//
// is replaced with
//
// %x = OpCompositeExtract %float %input 0
// %y = OpCompositeExtract %float %input 1
// %z = OpCompositeExtract %float %input 2
// %nx = OpFNegate %float %x
// %ny = OpFNegate %float %y
// %nz = OpFNegate %float %z
// %ax = OpExtInst %float %n_1 FAbs %x
// %ay = OpExtInst %float %n_1 FAbs %y
// %az = OpExtInst %float %n_1 FAbs %z
// %amax_x_y = OpExtInst %float %n_1 FMax %ay %ax
// %amax = OpExtInst %float %n_1 FMax %az %amax_x_y
// %cubema = OpFMul %float %float_2 %amax
// %is_z_max = OpFOrdGreaterThanEqual %bool %az %amax_x_y
// %not_is_z_max = OpLogicalNot %bool %is_z_max
// %y_gt_x = OpFOrdGreaterThanEqual %bool %ay %ax
// %is_y_max = OpLogicalAnd %bool %not_is_z_max %y_gt_x
// %is_z_neg = OpFOrdLessThan %bool %z %float_0
// %cubesc_case_1 = OpSelect %float %is_z_neg %nx %x
// %is_x_neg = OpFOrdLessThan %bool %x %float_0
// %cubesc_case_2 = OpSelect %float %is_x_neg %z %nz
// %sel = OpSelect %float %is_y_max %x %cubesc_case_2
// %cubesc = OpSelect %float %is_z_max %cubesc_case_1 %sel
// %is_y_neg = OpFOrdLessThan %bool %y %float_0
// %cubetc_case_1 = OpSelect %float %is_y_neg %nz %z
// %cubetc = OpSelect %float %is_y_max %cubetc_case_1 %ny
// %cube = OpCompositeConstruct %v2float %cubesc %cubetc
// %denom = OpCompositeConstruct %v2float %cubema %cubema
// %div = OpFDiv %v2float %cube %denom
// %result = OpFAdd %v2float %div %const
//
// Also adding the capabilities and builtins that are needed.
bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst,
                          const std::vector<const analysis::Constant*>&) {
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();

  uint32_t float_type_id = type_mgr->GetFloatTypeId();
  const analysis::Type* v2_float_type = type_mgr->GetFloatVectorType(2);
  uint32_t v2_float_type_id = type_mgr->GetId(v2_float_type);
  uint32_t bool_id = type_mgr->GetBoolTypeId();

  InstructionBuilder ir_builder(
      ctx, inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  // The |input| argument is in-operand 2 of the OpExtInst.
  uint32_t input_id = inst->GetSingleWordInOperand(2);
  // Make sure the module imports GLSL.std.450, adding the import if needed.
  uint32_t glsl405_ext_inst_id =
      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  if (glsl405_ext_inst_id == 0) {
    ctx->AddExtInstImport("GLSL.std.450");
    glsl405_ext_inst_id =
        ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  }

  // Get the constants that will be used.
  uint32_t f0_const_id = const_mgr->GetFloatConstId(0.0);
  uint32_t f2_const_id = const_mgr->GetFloatConstId(2.0);
  uint32_t f0_5_const_id = const_mgr->GetFloatConstId(0.5);
  const analysis::Constant* vec_const =
      const_mgr->GetConstant(v2_float_type, {f0_5_const_id, f0_5_const_id});
  uint32_t vec_const_id =
      const_mgr->GetDefiningInstruction(vec_const)->result_id();

  // Extract the input values.
  Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0});
  Instruction* y = ir_builder.AddCompositeExtract(float_type_id, input_id, {1});
  Instruction* z = ir_builder.AddCompositeExtract(float_type_id, input_id, {2});

  // Negate the input values.
  Instruction* nx =
      ir_builder.AddUnaryOp(float_type_id, spv::Op::OpFNegate, x->result_id());
  Instruction* ny =
      ir_builder.AddUnaryOp(float_type_id, spv::Op::OpFNegate, y->result_id());
  Instruction* nz =
      ir_builder.AddUnaryOp(float_type_id, spv::Op::OpFNegate, z->result_id());

  // Get the absolute values of the inputs.
  Instruction* ax = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {x->result_id()});
  Instruction* ay = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {y->result_id()});
  Instruction* az = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {z->result_id()});

  // Find which values are negative. Used in later computations.
  Instruction* is_z_neg = ir_builder.AddBinaryOp(
      bool_id, spv::Op::OpFOrdLessThan, z->result_id(), f0_const_id);
  Instruction* is_y_neg = ir_builder.AddBinaryOp(
      bool_id, spv::Op::OpFOrdLessThan, y->result_id(), f0_const_id);
  Instruction* is_x_neg = ir_builder.AddBinaryOp(
      bool_id, spv::Op::OpFOrdLessThan, x->result_id(), f0_const_id);

  // Compute cubema: twice the largest component magnitude.
  Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
      {ax->result_id(), ay->result_id()});
  Instruction* amax = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
      {az->result_id(), amax_x_y->result_id()});
  Instruction* cubema = ir_builder.AddBinaryOp(float_type_id, spv::Op::OpFMul,
                                               f2_const_id, amax->result_id());

  // Do the comparisons needed for computing cubesc and cubetc.
  Instruction* is_z_max =
      ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
                             az->result_id(), amax_x_y->result_id());
  Instruction* not_is_z_max = ir_builder.AddUnaryOp(
      bool_id, spv::Op::OpLogicalNot, is_z_max->result_id());
  Instruction* y_gr_x =
      ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
                             ay->result_id(), ax->result_id());
  Instruction* is_y_max =
      ir_builder.AddBinaryOp(bool_id, spv::Op::OpLogicalAnd,
                             not_is_z_max->result_id(), y_gr_x->result_id());

  // Select the correct value for cubesc.
  Instruction* cubesc_case_1 = ir_builder.AddSelect(
      float_type_id, is_z_neg->result_id(), nx->result_id(), x->result_id());
  Instruction* cubesc_case_2 = ir_builder.AddSelect(
      float_type_id, is_x_neg->result_id(), z->result_id(), nz->result_id());
  Instruction* sel =
      ir_builder.AddSelect(float_type_id, is_y_max->result_id(), x->result_id(),
                           cubesc_case_2->result_id());
  Instruction* cubesc =
      ir_builder.AddSelect(float_type_id, is_z_max->result_id(),
                           cubesc_case_1->result_id(), sel->result_id());

  // Select the correct value for cubetc.
  Instruction* cubetc_case_1 = ir_builder.AddSelect(
      float_type_id, is_y_neg->result_id(), nz->result_id(), z->result_id());
  Instruction* cubetc =
      ir_builder.AddSelect(float_type_id, is_y_max->result_id(),
                           cubetc_case_1->result_id(), ny->result_id());

  // Do the division.
  Instruction* cube = ir_builder.AddCompositeConstruct(
      v2_float_type_id, {cubesc->result_id(), cubetc->result_id()});
  Instruction* denom = ir_builder.AddCompositeConstruct(
      v2_float_type_id, {cubema->result_id(), cubema->result_id()});
  Instruction* div = ir_builder.AddBinaryOp(
      v2_float_type_id, spv::Op::OpFDiv, cube->result_id(), denom->result_id());

  // Get the final result by adding 0.5 to |div|.
  inst->SetOpcode(spv::Op::OpFAdd);
  Instruction::OperandList new_operands;
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {div->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {vec_const_id}});

  inst->SetInOperands(std::move(new_operands));
  ctx->UpdateDefUse(inst);
  return true;
}
// A folding rule that will replace the CubeFaceIndexAMD extended
// instruction in the SPV_AMD_gcn_shader extension.  Returns true if the
// folding is successful.
//
// The instruction
//
//    %result = OpExtInst %float %1 CubeFaceIndexAMD %input
//
// is replaced with
//
//           %x = OpCompositeExtract %float %input 0
//           %y = OpCompositeExtract %float %input 1
//           %z = OpCompositeExtract %float %input 2
//          %ax = OpExtInst %float %n_1 FAbs %x
//          %ay = OpExtInst %float %n_1 FAbs %y
//          %az = OpExtInst %float %n_1 FAbs %z
//    %is_z_neg = OpFOrdLessThan %bool %z %float_0
//    %is_y_neg = OpFOrdLessThan %bool %y %float_0
//    %is_x_neg = OpFOrdLessThan %bool %x %float_0
//    %amax_x_y = OpExtInst %float %n_1 FMax %ax %ay
//    %is_z_max = OpFOrdGreaterThanEqual %bool %az %amax_x_y
//      %y_gt_x = OpFOrdGreaterThanEqual %bool %ay %ax
//      %case_z = OpSelect %float %is_z_neg %float_5 %float_4
//      %case_y = OpSelect %float %is_y_neg %float_3 %float_2
//      %case_x = OpSelect %float %is_x_neg %float_1 %float_0
//         %sel = OpSelect %float %y_gt_x %case_y %case_x
//      %result = OpSelect %float %is_z_max %case_z %sel
//
// Also adds the GLSL.std.450 extended-instruction import if the module does
// not already have one.
bool ReplaceCubeFaceIndex(IRContext* ctx, Instruction* inst,
                          const std::vector<const analysis::Constant*>&) {
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();

  uint32_t float_type_id = type_mgr->GetFloatTypeId();
  uint32_t bool_id = type_mgr->GetBoolTypeId();

  InstructionBuilder ir_builder(
      ctx, inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);

  // In-operand 2 is the first operand after the extended-instruction-set id
  // and the opcode: the input vector.
  uint32_t input_id = inst->GetSingleWordInOperand(2);
  uint32_t glsl405_ext_inst_id =
      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  if (glsl405_ext_inst_id == 0) {
    // The replacement uses FAbs/FMax from GLSL.std.450; import it on demand.
    ctx->AddExtInstImport("GLSL.std.450");
    glsl405_ext_inst_id =
        ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  }

  // Get the constants that will be used.  0..5 are the six cube face
  // indices (+x, -x, +y, -y, +z, -z).
  uint32_t f0_const_id = const_mgr->GetFloatConstId(0.0);
  uint32_t f1_const_id = const_mgr->GetFloatConstId(1.0);
  uint32_t f2_const_id = const_mgr->GetFloatConstId(2.0);
  uint32_t f3_const_id = const_mgr->GetFloatConstId(3.0);
  uint32_t f4_const_id = const_mgr->GetFloatConstId(4.0);
  uint32_t f5_const_id = const_mgr->GetFloatConstId(5.0);

  // Extract the input values.
  Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0});
  Instruction* y = ir_builder.AddCompositeExtract(float_type_id, input_id, {1});
  Instruction* z = ir_builder.AddCompositeExtract(float_type_id, input_id, {2});

  // Get the absolute values of the inputs.
  Instruction* ax = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {x->result_id()});
  Instruction* ay = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {y->result_id()});
  Instruction* az = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {z->result_id()});

  // Find which values are negative. Used in later computations.
  Instruction* is_z_neg = ir_builder.AddBinaryOp(
      bool_id, spv::Op::OpFOrdLessThan, z->result_id(), f0_const_id);
  Instruction* is_y_neg = ir_builder.AddBinaryOp(
      bool_id, spv::Op::OpFOrdLessThan, y->result_id(), f0_const_id);
  Instruction* is_x_neg = ir_builder.AddBinaryOp(
      bool_id, spv::Op::OpFOrdLessThan, x->result_id(), f0_const_id);

  // Find the max value.
  Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
      {ax->result_id(), ay->result_id()});
  Instruction* is_z_max =
      ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
                             az->result_id(), amax_x_y->result_id());
  Instruction* y_gr_x =
      ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
                             ay->result_id(), ax->result_id());

  // Get the value for each case: the face index depends on which axis has
  // the largest magnitude, and on the sign of that component.
  Instruction* case_z = ir_builder.AddSelect(
      float_type_id, is_z_neg->result_id(), f5_const_id, f4_const_id);
  Instruction* case_y = ir_builder.AddSelect(
      float_type_id, is_y_neg->result_id(), f3_const_id, f2_const_id);
  Instruction* case_x = ir_builder.AddSelect(
      float_type_id, is_x_neg->result_id(), f1_const_id, f0_const_id);

  // Select the correct case.
  Instruction* sel =
      ir_builder.AddSelect(float_type_id, y_gr_x->result_id(),
                           case_y->result_id(), case_x->result_id());

  // Rewrite |inst| in place into the final select between the z case and
  // the x/y cases, so that all uses of the original result id see the new
  // value.
  inst->SetOpcode(spv::Op::OpSelect);
  Instruction::OperandList new_operands;
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_z_max->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {case_z->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {sel->result_id()}});
  inst->SetInOperands(std::move(new_operands));
  ctx->UpdateDefUse(inst);
  return true;
}
  714. // A folding rule that will replace the TimeAMD extended instruction in the
  715. // SPV_AMD_gcn_shader_ballot. It returns true if the folding is successful.
  716. // It returns False, otherwise.
  717. //
  718. // The instruction
  719. //
  720. // %result = OpExtInst %uint64 %1 TimeAMD
  721. //
  722. // with
  723. //
  724. // %result = OpReadClockKHR %uint64 %uint_3
  725. //
  726. // NOTE: TimeAMD uses subgroup scope (it is not a real time clock).
  727. bool ReplaceTimeAMD(IRContext* ctx, Instruction* inst,
  728. const std::vector<const analysis::Constant*>&) {
  729. InstructionBuilder ir_builder(
  730. ctx, inst,
  731. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  732. ctx->AddExtension("SPV_KHR_shader_clock");
  733. ctx->AddCapability(spv::Capability::ShaderClockKHR);
  734. inst->SetOpcode(spv::Op::OpReadClockKHR);
  735. Instruction::OperandList args;
  736. uint32_t subgroup_scope_id =
  737. ir_builder.GetUintConstantId(uint32_t(spv::Scope::Subgroup));
  738. args.push_back({SPV_OPERAND_TYPE_ID, {subgroup_scope_id}});
  739. inst->SetInOperands(std::move(args));
  740. ctx->UpdateDefUse(inst);
  741. return true;
  742. }
  743. class AmdExtFoldingRules : public FoldingRules {
  744. public:
  745. explicit AmdExtFoldingRules(IRContext* ctx) : FoldingRules(ctx) {}
  746. protected:
  747. virtual void AddFoldingRules() override {
  748. rules_[spv::Op::OpGroupIAddNonUniformAMD].push_back(
  749. ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformIAdd>);
  750. rules_[spv::Op::OpGroupFAddNonUniformAMD].push_back(
  751. ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformFAdd>);
  752. rules_[spv::Op::OpGroupUMinNonUniformAMD].push_back(
  753. ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformUMin>);
  754. rules_[spv::Op::OpGroupSMinNonUniformAMD].push_back(
  755. ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformSMin>);
  756. rules_[spv::Op::OpGroupFMinNonUniformAMD].push_back(
  757. ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformFMin>);
  758. rules_[spv::Op::OpGroupUMaxNonUniformAMD].push_back(
  759. ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformUMax>);
  760. rules_[spv::Op::OpGroupSMaxNonUniformAMD].push_back(
  761. ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformSMax>);
  762. rules_[spv::Op::OpGroupFMaxNonUniformAMD].push_back(
  763. ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformFMax>);
  764. uint32_t extension_id =
  765. context()->module()->GetExtInstImportId("SPV_AMD_shader_ballot");
  766. if (extension_id != 0) {
  767. ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsAMD}]
  768. .push_back(ReplaceSwizzleInvocations);
  769. ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsMaskedAMD}]
  770. .push_back(ReplaceSwizzleInvocationsMasked);
  771. ext_rules_[{extension_id, AmdShaderBallotWriteInvocationAMD}].push_back(
  772. ReplaceWriteInvocation);
  773. ext_rules_[{extension_id, AmdShaderBallotMbcntAMD}].push_back(
  774. ReplaceMbcnt);
  775. }
  776. extension_id = context()->module()->GetExtInstImportId(
  777. "SPV_AMD_shader_trinary_minmax");
  778. if (extension_id != 0) {
  779. ext_rules_[{extension_id, FMin3AMD}].push_back(
  780. ReplaceTrinaryMinMax<GLSLstd450FMin>);
  781. ext_rules_[{extension_id, UMin3AMD}].push_back(
  782. ReplaceTrinaryMinMax<GLSLstd450UMin>);
  783. ext_rules_[{extension_id, SMin3AMD}].push_back(
  784. ReplaceTrinaryMinMax<GLSLstd450SMin>);
  785. ext_rules_[{extension_id, FMax3AMD}].push_back(
  786. ReplaceTrinaryMinMax<GLSLstd450FMax>);
  787. ext_rules_[{extension_id, UMax3AMD}].push_back(
  788. ReplaceTrinaryMinMax<GLSLstd450UMax>);
  789. ext_rules_[{extension_id, SMax3AMD}].push_back(
  790. ReplaceTrinaryMinMax<GLSLstd450SMax>);
  791. ext_rules_[{extension_id, FMid3AMD}].push_back(
  792. ReplaceTrinaryMid<GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp>);
  793. ext_rules_[{extension_id, UMid3AMD}].push_back(
  794. ReplaceTrinaryMid<GLSLstd450UMin, GLSLstd450UMax, GLSLstd450UClamp>);
  795. ext_rules_[{extension_id, SMid3AMD}].push_back(
  796. ReplaceTrinaryMid<GLSLstd450SMin, GLSLstd450SMax, GLSLstd450SClamp>);
  797. }
  798. extension_id =
  799. context()->module()->GetExtInstImportId("SPV_AMD_gcn_shader");
  800. if (extension_id != 0) {
  801. ext_rules_[{extension_id, CubeFaceCoordAMD}].push_back(
  802. ReplaceCubeFaceCoord);
  803. ext_rules_[{extension_id, CubeFaceIndexAMD}].push_back(
  804. ReplaceCubeFaceIndex);
  805. ext_rules_[{extension_id, TimeAMD}].push_back(ReplaceTimeAMD);
  806. }
  807. }
  808. };
  809. class AmdExtConstFoldingRules : public ConstantFoldingRules {
  810. public:
  811. AmdExtConstFoldingRules(IRContext* ctx) : ConstantFoldingRules(ctx) {}
  812. protected:
  813. virtual void AddFoldingRules() override {}
  814. };
  815. } // namespace
  816. Pass::Status AmdExtensionToKhrPass::Process() {
  817. bool changed = false;
  818. // Traverse the body of the functions to replace instructions that require
  819. // the extensions.
  820. InstructionFolder folder(
  821. context(),
  822. std::unique_ptr<AmdExtFoldingRules>(new AmdExtFoldingRules(context())),
  823. MakeUnique<AmdExtConstFoldingRules>(context()));
  824. for (Function& func : *get_module()) {
  825. func.ForEachInst([&changed, &folder](Instruction* inst) {
  826. if (folder.FoldInstruction(inst)) {
  827. changed = true;
  828. }
  829. });
  830. }
  831. // Now that instruction that require the extensions have been removed, we can
  832. // remove the extension instructions.
  833. std::set<std::string> ext_to_remove = {"SPV_AMD_shader_ballot",
  834. "SPV_AMD_shader_trinary_minmax",
  835. "SPV_AMD_gcn_shader"};
  836. std::vector<Instruction*> to_be_killed;
  837. for (Instruction& inst : context()->module()->extensions()) {
  838. if (inst.opcode() == spv::Op::OpExtension) {
  839. if (ext_to_remove.count(inst.GetInOperand(0).AsString()) != 0) {
  840. to_be_killed.push_back(&inst);
  841. }
  842. }
  843. }
  844. for (Instruction& inst : context()->ext_inst_imports()) {
  845. if (inst.opcode() == spv::Op::OpExtInstImport) {
  846. if (ext_to_remove.count(inst.GetInOperand(0).AsString()) != 0) {
  847. to_be_killed.push_back(&inst);
  848. }
  849. }
  850. }
  851. for (Instruction* inst : to_be_killed) {
  852. context()->KillInst(inst);
  853. changed = true;
  854. }
  855. // The replacements that take place use instructions that are missing before
  856. // SPIR-V 1.3. If we changed something, we will have to make sure the version
  857. // is at least SPIR-V 1.3 to make sure those instruction can be used.
  858. if (changed) {
  859. uint32_t version = get_module()->version();
  860. if (version < 0x00010300 /*1.3*/) {
  861. get_module()->set_version(0x00010300);
  862. }
  863. }
  864. return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  865. }
  866. } // namespace opt
  867. } // namespace spvtools