@@ -0,0 +1,539 @@
+// Copyright (c) 2019 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/amd_ext_to_khr.h"
+
+#include <cstring>
+
+#include "ir_builder.h"
+#include "source/opt/ir_context.h"
+#include "spv-amd-shader-ballot.insts.inc"
+#include "type_manager.h"
+
+namespace spvtools {
+namespace opt {
+
+namespace {
+
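+// Extended-instruction numbers assigned to the SPV_AMD_shader_ballot
+// instructions that this pass rewrites.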
+enum ExtOpcodes {
+  AmdShaderBallotSwizzleInvocationsAMD = 1,
+  AmdShaderBallotSwizzleInvocationsMaskedAMD = 2,
+  AmdShaderBallotWriteInvocationAMD = 3,
+  AmdShaderBallotMbcntAMD = 4
+};
+
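+// Returns the registered 32-bit unsigned integer type for |ctx|.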
+analysis::Type* GetUIntType(IRContext* ctx) {
+  analysis::Integer int_type(32, false);
+  return ctx->get_type_mgr()->GetRegisteredType(&int_type);
+}
+
+// Returns a folding rule that will replace the opcode with |new_opcode| and
+// add the capabilities required. The folding rule assumes it is folding an
+// OpGroup*NonUniformAMD instruction from the SPV_AMD_shader_ballot extension.
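+//
+// For example (illustrative; the AMD and KHR group arithmetic instructions
+// share the same operand layout, so only the opcode changes):
+//
+//   %sum = OpGroupIAddNonUniformAMD %uint %uint_3 Reduce %value
+//
+// becomes
+//
+//   %sum = OpGroupNonUniformIAdd %uint %uint_3 Reduce %value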
+FoldingRule ReplaceGroupNonuniformOperationOpCode(SpvOp new_opcode) {
+  switch (new_opcode) {
+    case SpvOpGroupNonUniformIAdd:
+    case SpvOpGroupNonUniformFAdd:
+    case SpvOpGroupNonUniformUMin:
+    case SpvOpGroupNonUniformSMin:
+    case SpvOpGroupNonUniformFMin:
+    case SpvOpGroupNonUniformUMax:
+    case SpvOpGroupNonUniformSMax:
+    case SpvOpGroupNonUniformFMax:
+      break;
+    default:
+      assert(
+          false &&
+          "Should be replacing with a group non uniform arithmetic operation.");
+  }
+
+  return [new_opcode](IRContext* ctx, Instruction* inst,
+                      const std::vector<const analysis::Constant*>&) {
+    switch (inst->opcode()) {
+      case SpvOpGroupIAddNonUniformAMD:
+      case SpvOpGroupFAddNonUniformAMD:
+      case SpvOpGroupUMinNonUniformAMD:
+      case SpvOpGroupSMinNonUniformAMD:
+      case SpvOpGroupFMinNonUniformAMD:
+      case SpvOpGroupUMaxNonUniformAMD:
+      case SpvOpGroupSMaxNonUniformAMD:
+      case SpvOpGroupFMaxNonUniformAMD:
+        break;
+      default:
+        assert(false &&
+               "Should be replacing a group non uniform arithmetic operation.");
+    }
+
+    ctx->AddCapability(SpvCapabilityGroupNonUniformArithmetic);
+    inst->SetOpcode(new_opcode);
+    return true;
+  };
+}
+
+// Returns a folding rule that will replace the SwizzleInvocationsAMD extended
+// instruction in the SPV_AMD_shader_ballot extension.
+//
+// The instruction
+//
+//  %offset = OpConstantComposite %v4uint %x %y %z %w
+//  %result = OpExtInst %type %1 SwizzleInvocationsAMD %data %offset
+//
+// is replaced with
+//
+// potentially new constants and types
+//
+// clang-format off
+// %uint_max = OpConstant %uint 0xFFFFFFFF
+// %v4uint = OpTypeVector %uint 4
+// %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
+// %null = OpConstantNull %type
+// clang-format on
+//
+// and the following code in the function body
+//
+// clang-format off
+// %id = OpLoad %uint %SubgroupLocalInvocationId
+// %quad_idx = OpBitwiseAnd %uint %id %uint_3
+// %quad_ldr = OpBitwiseXor %uint %id %quad_idx
+// %my_offset = OpVectorExtractDynamic %uint %offset %quad_idx
+// %target_inv = OpIAdd %uint %quad_ldr %my_offset
+// %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
+// %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
+// %result = OpSelect %type %is_active %shuffle %null
+// clang-format on
+//
+// The rule also adds the capabilities and builtins that are needed.
+FoldingRule ReplaceSwizzleInvocations() {
+  return [](IRContext* ctx, Instruction* inst,
+            const std::vector<const analysis::Constant*>&) {
+    analysis::TypeManager* type_mgr = ctx->get_type_mgr();
+    analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
+
+    ctx->AddExtension("SPV_KHR_shader_ballot");
+    ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
+    ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
+
+    InstructionBuilder ir_builder(
+        ctx, inst,
+        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
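+    // For an OpExtInst, in-operand 0 is the extended-instruction-set id and
+    // in-operand 1 is the instruction number, so the operands of the extended
+    // instruction itself start at in-operand 2.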
+    uint32_t data_id = inst->GetSingleWordInOperand(2);
+    uint32_t offset_id = inst->GetSingleWordInOperand(3);
+
+    // Get the subgroup invocation id.
+    uint32_t var_id =
+        ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
+    assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
+    Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
+    Instruction* var_ptr_type =
+        ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
+    uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
+
+    Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
+
+    uint32_t quad_mask = ir_builder.GetUintConstantId(3);
+
+    // This gives the offset in the group of 4 of this invocation.
+    Instruction* quad_idx = ir_builder.AddBinaryOp(
+        uint_type_id, SpvOpBitwiseAnd, id->result_id(), quad_mask);
+
+    // Get the invocation id of the first invocation in the group of 4.
+    Instruction* quad_ldr = ir_builder.AddBinaryOp(
+        uint_type_id, SpvOpBitwiseXor, id->result_id(), quad_idx->result_id());
+
+    // Get the offset of the target invocation from the offset vector.
+    Instruction* my_offset =
+        ir_builder.AddBinaryOp(uint_type_id, SpvOpVectorExtractDynamic,
+                               offset_id, quad_idx->result_id());
+
+    // Determine the index of the invocation to read from.
+    Instruction* target_inv = ir_builder.AddBinaryOp(
+        uint_type_id, SpvOpIAdd, quad_ldr->result_id(), my_offset->result_id());
+
+    // Do the group operations
+    uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
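+    // SpvScopeSubgroup has the numeric value 3, which is the %uint_3 scope
+    // shown in the comment above.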
+    uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
+    const auto* ballot_value_const = const_mgr->GetConstant(
+        type_mgr->GetUIntVectorType(4),
+        {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
+    Instruction* ballot_value =
+        const_mgr->GetDefiningInstruction(ballot_value_const);
+    Instruction* is_active = ir_builder.AddNaryOp(
+        type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
+        {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
+    Instruction* shuffle = ir_builder.AddNaryOp(
+        inst->type_id(), SpvOpGroupNonUniformShuffle,
+        {subgroup_scope, data_id, target_inv->result_id()});
+
+    // Create the null constant to use in the select.
+    const auto* null = const_mgr->GetConstant(
+        type_mgr->GetType(inst->type_id()), std::vector<uint32_t>());
+    Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
+
+    // Build the select.
+    inst->SetOpcode(SpvOpSelect);
+    Instruction::OperandList new_operands;
+    new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
+    new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
+    new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
+
+    inst->SetInOperands(std::move(new_operands));
+    ctx->UpdateDefUse(inst);
+    return true;
+  };
+}
+
+// Returns a folding rule that will replace the SwizzleInvocationsMaskedAMD
+// extended instruction in the SPV_AMD_shader_ballot extension.
+//
+// The instruction
+//
+//  %mask = OpConstantComposite %v3uint %uint_x %uint_y %uint_z
+//  %result = OpExtInst %uint %1 SwizzleInvocationsMaskedAMD %data %mask
+//
+// is replaced with
+//
+// potentially new constants and types
+//
+// clang-format off
+// %uint_mask_extend = OpConstant %uint 0xFFFFFFE0
+// %uint_max = OpConstant %uint 0xFFFFFFFF
+// %v4uint = OpTypeVector %uint 4
+// %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
+// clang-format on
+//
+// and the following code in the function body
+//
+// clang-format off
+// %id = OpLoad %uint %SubgroupLocalInvocationId
+// %and_mask = OpBitwiseOr %uint %uint_x %uint_mask_extend
+// %and = OpBitwiseAnd %uint %id %and_mask
+// %or = OpBitwiseOr %uint %and %uint_y
+// %target_inv = OpBitwiseXor %uint %or %uint_z
+// %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
+// %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
+// %result = OpSelect %type %is_active %shuffle %uint_0
+// clang-format on
+//
+// The rule also adds the capabilities and builtins that are needed.
+FoldingRule ReplaceSwizzleInvocationsMasked() {
+  return [](IRContext* ctx, Instruction* inst,
+            const std::vector<const analysis::Constant*>&) {
+    analysis::TypeManager* type_mgr = ctx->get_type_mgr();
+    analysis::DefUseManager* def_use_mgr = ctx->get_def_use_mgr();
+    analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
+
+    // ctx->AddCapability(SpvCapabilitySubgroupBallotKHR);
+    ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
+    ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
+
+    InstructionBuilder ir_builder(
+        ctx, inst,
+        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+    // Get the operands to inst, and the components of the mask.
+    uint32_t data_id = inst->GetSingleWordInOperand(2);
+
+    Instruction* mask_inst =
+        def_use_mgr->GetDef(inst->GetSingleWordInOperand(3));
+    assert(mask_inst->opcode() == SpvOpConstantComposite &&
+           "The mask is supposed to be a vector constant.");
+    assert(mask_inst->NumInOperands() == 3 &&
+           "The mask is supposed to have 3 components.");
+
+    uint32_t uint_x = mask_inst->GetSingleWordInOperand(0);
+    uint32_t uint_y = mask_inst->GetSingleWordInOperand(1);
+    uint32_t uint_z = mask_inst->GetSingleWordInOperand(2);
+
+    // Get the subgroup invocation id.
+    uint32_t var_id =
+        ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
+    ctx->AddExtension("SPV_KHR_shader_ballot");
+    assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
+    Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
+    Instruction* var_ptr_type =
+        ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
+    uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
+
+    Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
+
+    // Do the bitwise operations.
+    uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0);
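+    // ORing 0xFFFFFFE0 into the AND mask keeps all but the low 5 bits of the
+    // invocation id, so the swizzle stays within a group of 32 invocations.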
+    Instruction* and_mask = ir_builder.AddBinaryOp(uint_type_id, SpvOpBitwiseOr,
+                                                   uint_x, mask_extended);
+    Instruction* and_result = ir_builder.AddBinaryOp(
+        uint_type_id, SpvOpBitwiseAnd, id->result_id(), and_mask->result_id());
+    Instruction* or_result = ir_builder.AddBinaryOp(
+        uint_type_id, SpvOpBitwiseOr, and_result->result_id(), uint_y);
+    Instruction* target_inv = ir_builder.AddBinaryOp(
+        uint_type_id, SpvOpBitwiseXor, or_result->result_id(), uint_z);
+
+    // Do the group operations
+    uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
+    uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
+    const auto* ballot_value_const = const_mgr->GetConstant(
+        type_mgr->GetUIntVectorType(4),
+        {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
+    Instruction* ballot_value =
+        const_mgr->GetDefiningInstruction(ballot_value_const);
+    Instruction* is_active = ir_builder.AddNaryOp(
+        type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
+        {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
+    Instruction* shuffle = ir_builder.AddNaryOp(
+        inst->type_id(), SpvOpGroupNonUniformShuffle,
+        {subgroup_scope, data_id, target_inv->result_id()});
+
+    // Create the null constant to use in the select.
+    const auto* null = const_mgr->GetConstant(
+        type_mgr->GetType(inst->type_id()), std::vector<uint32_t>());
+    Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
+
+    // Build the select.
+    inst->SetOpcode(SpvOpSelect);
+    Instruction::OperandList new_operands;
+    new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
+    new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
+    new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
+
+    inst->SetInOperands(std::move(new_operands));
+    ctx->UpdateDefUse(inst);
+    return true;
+  };
+}
+
+// Returns a folding rule that will replace the WriteInvocationAMD extended
+// instruction in the SPV_AMD_shader_ballot extension.
+//
+// The instruction
+//
+// clang-format off
+// %result = OpExtInst %type %1 WriteInvocationAMD %input_value %write_value %invocation_index
+// clang-format on
+//
+// is replaced with
+//
+// %id = OpLoad %uint %SubgroupLocalInvocationId
+// %cmp = OpIEqual %bool %id %invocation_index
+// %result = OpSelect %type %cmp %write_value %input_value
+//
+// The rule also adds the capabilities and builtins that are needed.
+FoldingRule ReplaceWriteInvocation() {
+  return [](IRContext* ctx, Instruction* inst,
+            const std::vector<const analysis::Constant*>&) {
+    uint32_t var_id =
+        ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
+    ctx->AddCapability(SpvCapabilitySubgroupBallotKHR);
+    ctx->AddExtension("SPV_KHR_shader_ballot");
+    assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
+    Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
+    Instruction* var_ptr_type =
+        ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
+
+    InstructionBuilder ir_builder(
+        ctx, inst,
+        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+    Instruction* t =
+        ir_builder.AddLoad(var_ptr_type->GetSingleWordInOperand(1), var_id);
+    analysis::Bool bool_type;
+    uint32_t bool_type_id = ctx->get_type_mgr()->GetTypeInstruction(&bool_type);
+    Instruction* cmp =
+        ir_builder.AddBinaryOp(bool_type_id, SpvOpIEqual, t->result_id(),
+                               inst->GetSingleWordInOperand(4));
+
+    // Build a select.
+    inst->SetOpcode(SpvOpSelect);
+    Instruction::OperandList new_operands;
+    new_operands.push_back({SPV_OPERAND_TYPE_ID, {cmp->result_id()}});
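+    // In-operand 3 of the original OpExtInst is |write_value| and in-operand 2
+    // is |input_value|, so the select returns the write value only on the
+    // invocation whose id matches |invocation_index|.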
+    new_operands.push_back(inst->GetInOperand(3));
+    new_operands.push_back(inst->GetInOperand(2));
+
+    inst->SetInOperands(std::move(new_operands));
+    ctx->UpdateDefUse(inst);
+    return true;
+  };
+}
+
+// Returns a folding rule that will replace the MbcntAMD extended instruction in
+// the SPV_AMD_shader_ballot extension.
+//
+// The instruction
+//
+//  %result = OpExtInst %uint %1 MbcntAMD %mask
+//
+// is replaced with
+//
+// Get SubgroupLtMask and convert the first 64 bits into a uint64_t because
+// AMD's shader compiler expects a 64-bit integer mask.
+//
+// %var = OpLoad %v4uint %SubgroupLtMaskKHR
+// %shuffle = OpVectorShuffle %v2uint %var %var 0 1
+// %cast = OpBitcast %ulong %shuffle
+//
+// Perform the mask and count the bits.
+//
+// %and = OpBitwiseAnd %ulong %cast %mask
+// %result = OpBitCount %uint %and
+//
+// The rule also adds the capabilities and builtins that are needed.
+FoldingRule ReplaceMbcnt() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>&) {
+    analysis::TypeManager* type_mgr = context->get_type_mgr();
+    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+
+    uint32_t var_id = context->GetBuiltinInputVarId(SpvBuiltInSubgroupLtMask);
+    assert(var_id != 0 && "Could not get SubgroupLtMask variable.");
+    context->AddCapability(SpvCapabilityGroupNonUniformBallot);
+    Instruction* var_inst = def_use_mgr->GetDef(var_id);
+    Instruction* var_ptr_type = def_use_mgr->GetDef(var_inst->type_id());
+    Instruction* var_type =
+        def_use_mgr->GetDef(var_ptr_type->GetSingleWordInOperand(1));
+    assert(var_type->opcode() == SpvOpTypeVector &&
+           "Variable is supposed to be a vector of 4 ints");
+
+    // Get the type for the shuffle.
+    analysis::Vector temp_type(GetUIntType(context), 2);
+    const analysis::Type* shuffle_type =
+        context->get_type_mgr()->GetRegisteredType(&temp_type);
+    uint32_t shuffle_type_id = type_mgr->GetTypeInstruction(shuffle_type);
+
+    uint32_t mask_id = inst->GetSingleWordInOperand(2);
+    Instruction* mask_inst = def_use_mgr->GetDef(mask_id);
+
+    // Testing with AMD's shader compiler shows that a 64-bit mask is expected.
+    assert(type_mgr->GetType(mask_inst->type_id())->AsInteger() != nullptr);
+    assert(type_mgr->GetType(mask_inst->type_id())->AsInteger()->width() == 64);
+
+    InstructionBuilder ir_builder(
+        context, inst,
+        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+    Instruction* load = ir_builder.AddLoad(var_type->result_id(), var_id);
+    Instruction* shuffle = ir_builder.AddVectorShuffle(
+        shuffle_type_id, load->result_id(), load->result_id(), {0, 1});
+    Instruction* bitcast = ir_builder.AddUnaryOp(
+        mask_inst->type_id(), SpvOpBitcast, shuffle->result_id());
+    Instruction* t = ir_builder.AddBinaryOp(
+        mask_inst->type_id(), SpvOpBitwiseAnd, bitcast->result_id(), mask_id);
+
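+    // Rewrite the original OpExtInst in place as an OpBitCount of the masked
+    // ballot value; its result id, and therefore all of its uses, stay valid.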
+    inst->SetOpcode(SpvOpBitCount);
+    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}});
+    context->UpdateDefUse(inst);
+    return true;
+  };
+}
+
+class AmdExtFoldingRules : public FoldingRules {
+ public:
+  explicit AmdExtFoldingRules(IRContext* ctx) : FoldingRules(ctx) {}
+
+ protected:
+  virtual void AddFoldingRules() override {
+    rules_[SpvOpGroupIAddNonUniformAMD].push_back(
+        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformIAdd));
+    rules_[SpvOpGroupFAddNonUniformAMD].push_back(
+        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformFAdd));
+    rules_[SpvOpGroupUMinNonUniformAMD].push_back(
+        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformUMin));
+    rules_[SpvOpGroupSMinNonUniformAMD].push_back(
+        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformSMin));
+    rules_[SpvOpGroupFMinNonUniformAMD].push_back(
+        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformFMin));
+    rules_[SpvOpGroupUMaxNonUniformAMD].push_back(
+        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformUMax));
+    rules_[SpvOpGroupSMaxNonUniformAMD].push_back(
+        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformSMax));
+    rules_[SpvOpGroupFMaxNonUniformAMD].push_back(
+        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformFMax));
+
+    uint32_t extension_id =
+        context()->module()->GetExtInstImportId("SPV_AMD_shader_ballot");
+
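+    // Rules for extended instructions are keyed on the id of the
+    // OpExtInstImport for SPV_AMD_shader_ballot and the instruction number
+    // within that set.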
+    ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsAMD}].push_back(
+        ReplaceSwizzleInvocations());
+    ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsMaskedAMD}]
+        .push_back(ReplaceSwizzleInvocationsMasked());
+    ext_rules_[{extension_id, AmdShaderBallotWriteInvocationAMD}].push_back(
+        ReplaceWriteInvocation());
+    ext_rules_[{extension_id, AmdShaderBallotMbcntAMD}].push_back(
+        ReplaceMbcnt());
+  }
+};
+
+class AmdExtConstFoldingRules : public ConstantFoldingRules {
+ public:
+  explicit AmdExtConstFoldingRules(IRContext* ctx)
+      : ConstantFoldingRules(ctx) {}
+
+ protected:
+  virtual void AddFoldingRules() override {}
+};
+
+}  // namespace
+
+Pass::Status AmdExtensionToKhrPass::Process() {
+  bool changed = false;
+
+  // Traverse the body of the functions to replace instructions that require
+  // the extensions.
+  InstructionFolder folder(
+      context(),
+      MakeUnique<AmdExtFoldingRules>(context()),
+      MakeUnique<AmdExtConstFoldingRules>(context()));
+  for (Function& func : *get_module()) {
+    func.ForEachInst([&changed, &folder](Instruction* inst) {
+      if (folder.FoldInstruction(inst)) {
+        changed = true;
+      }
+    });
+  }
+
+  // Now that the instructions that require the extension have been replaced,
+  // we can remove the extension instructions.
+  std::vector<Instruction*> to_be_killed;
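+  // The instructions are collected first and killed afterwards so that the
+  // extension and import lists are not modified while they are being
+  // traversed.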
+  for (Instruction& inst : context()->module()->extensions()) {
+    if (inst.opcode() == SpvOpExtension) {
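+      // The extension name is stored in the operand as a nul-terminated
+      // string, so it can be compared directly with strcmp.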
+      if (!strcmp("SPV_AMD_shader_ballot",
+                  reinterpret_cast<const char*>(
+                      &(inst.GetInOperand(0).words[0])))) {
+        to_be_killed.push_back(&inst);
+      }
+    }
+  }
+
+  for (Instruction& inst : context()->ext_inst_imports()) {
+    if (inst.opcode() == SpvOpExtInstImport) {
+      if (!strcmp("SPV_AMD_shader_ballot",
+                  reinterpret_cast<const char*>(
+                      &(inst.GetInOperand(0).words[0])))) {
+        to_be_killed.push_back(&inst);
+      }
+    }
+  }
+
+  for (Instruction* inst : to_be_killed) {
+    context()->KillInst(inst);
+    changed = true;
+  }
+
+  // The replacements use instructions that are not available before SPIR-V
+  // 1.3. If anything was changed, make sure the module version is at least
+  // SPIR-V 1.3 so that those instructions can be used.
+  if (changed) {
+    uint32_t version = get_module()->version();
+    if (version < 0x00010300 /*1.3*/) {
+      get_module()->set_version(0x00010300);
+    }
+  }
+  return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+}  // namespace opt
+}  // namespace spvtools