transformation_adjust_branch_weights.cpp 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. // Copyright (c) 2020 André Perez Maselco
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/fuzz/transformation_adjust_branch_weights.h"
  15. #include "source/fuzz/fuzzer_util.h"
  16. #include "source/fuzz/instruction_descriptor.h"
  17. namespace spvtools {
  18. namespace fuzz {
  19. namespace {
  20. const uint32_t kBranchWeightForTrueLabelIndex = 3;
  21. const uint32_t kBranchWeightForFalseLabelIndex = 4;
  22. } // namespace
  23. TransformationAdjustBranchWeights::TransformationAdjustBranchWeights(
  24. protobufs::TransformationAdjustBranchWeights message)
  25. : message_(std::move(message)) {}
  26. TransformationAdjustBranchWeights::TransformationAdjustBranchWeights(
  27. const protobufs::InstructionDescriptor& instruction_descriptor,
  28. const std::pair<uint32_t, uint32_t>& branch_weights) {
  29. *message_.mutable_instruction_descriptor() = instruction_descriptor;
  30. message_.mutable_branch_weights()->set_first(branch_weights.first);
  31. message_.mutable_branch_weights()->set_second(branch_weights.second);
  32. }
  33. bool TransformationAdjustBranchWeights::IsApplicable(
  34. opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
  35. auto instruction =
  36. FindInstruction(message_.instruction_descriptor(), ir_context);
  37. if (instruction == nullptr) {
  38. return false;
  39. }
  40. spv::Op opcode = static_cast<spv::Op>(
  41. message_.instruction_descriptor().target_instruction_opcode());
  42. assert(instruction->opcode() == opcode &&
  43. "The located instruction must have the same opcode as in the "
  44. "descriptor.");
  45. // Must be an OpBranchConditional instruction.
  46. if (opcode != spv::Op::OpBranchConditional) {
  47. return false;
  48. }
  49. assert((message_.branch_weights().first() != 0 ||
  50. message_.branch_weights().second() != 0) &&
  51. "At least one weight must be non-zero.");
  52. assert(message_.branch_weights().first() <=
  53. UINT32_MAX - message_.branch_weights().second() &&
  54. "The sum of the two weights must not be greater than UINT32_MAX.");
  55. return true;
  56. }
  57. void TransformationAdjustBranchWeights::Apply(
  58. opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
  59. auto instruction =
  60. FindInstruction(message_.instruction_descriptor(), ir_context);
  61. if (instruction->HasBranchWeights()) {
  62. instruction->SetOperand(kBranchWeightForTrueLabelIndex,
  63. {message_.branch_weights().first()});
  64. instruction->SetOperand(kBranchWeightForFalseLabelIndex,
  65. {message_.branch_weights().second()});
  66. } else {
  67. instruction->AddOperand({SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER,
  68. {message_.branch_weights().first()}});
  69. instruction->AddOperand({SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER,
  70. {message_.branch_weights().second()}});
  71. }
  72. }
  73. protobufs::Transformation TransformationAdjustBranchWeights::ToMessage() const {
  74. protobufs::Transformation result;
  75. *result.mutable_adjust_branch_weights() = message_;
  76. return result;
  77. }
  78. std::unordered_set<uint32_t> TransformationAdjustBranchWeights::GetFreshIds()
  79. const {
  80. return std::unordered_set<uint32_t>();
  81. }
  82. } // namespace fuzz
  83. } // namespace spvtools