// transformation_set_loop_control.cpp (9.1 KB)
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/fuzz/transformation_set_loop_control.h"
  15. namespace spvtools {
  16. namespace fuzz {
  17. TransformationSetLoopControl::TransformationSetLoopControl(
  18. protobufs::TransformationSetLoopControl message)
  19. : message_(std::move(message)) {}
  20. TransformationSetLoopControl::TransformationSetLoopControl(
  21. uint32_t block_id, uint32_t loop_control, uint32_t peel_count,
  22. uint32_t partial_count) {
  23. message_.set_block_id(block_id);
  24. message_.set_loop_control(loop_control);
  25. message_.set_peel_count(peel_count);
  26. message_.set_partial_count(partial_count);
  27. }
  28. bool TransformationSetLoopControl::IsApplicable(
  29. opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
  30. // |message_.block_id| must identify a block that ends with OpLoopMerge.
  31. auto block = ir_context->get_instr_block(message_.block_id());
  32. if (!block) {
  33. return false;
  34. }
  35. auto merge_inst = block->GetMergeInst();
  36. if (!merge_inst || merge_inst->opcode() != spv::Op::OpLoopMerge) {
  37. return false;
  38. }
  39. // We assert that the transformation does not try to set any meaningless bits
  40. // of the loop control mask.
  41. uint32_t all_loop_control_mask_bits_set = uint32_t(
  42. spv::LoopControlMask::Unroll | spv::LoopControlMask::DontUnroll |
  43. spv::LoopControlMask::DependencyInfinite |
  44. spv::LoopControlMask::DependencyLength |
  45. spv::LoopControlMask::MinIterations |
  46. spv::LoopControlMask::MaxIterations |
  47. spv::LoopControlMask::IterationMultiple |
  48. spv::LoopControlMask::PeelCount | spv::LoopControlMask::PartialCount);
  49. // The variable is only used in an assertion; the following keeps release-mode
  50. // compilers happy.
  51. (void)(all_loop_control_mask_bits_set);
  52. // No additional bits should be set.
  53. assert(!(message_.loop_control() & ~all_loop_control_mask_bits_set));
  54. // Grab the loop control mask currently associated with the OpLoopMerge
  55. // instruction.
  56. auto existing_loop_control_mask =
  57. merge_inst->GetSingleWordInOperand(kLoopControlMaskInOperandIndex);
  58. // Check that there is no attempt to set one of the loop controls that
  59. // requires guarantees to hold.
  60. for (spv::LoopControlMask mask : {spv::LoopControlMask::DependencyInfinite,
  61. spv::LoopControlMask::DependencyLength,
  62. spv::LoopControlMask::MinIterations,
  63. spv::LoopControlMask::MaxIterations,
  64. spv::LoopControlMask::IterationMultiple}) {
  65. // We have a problem if this loop control bit was not set in the original
  66. // loop control mask but is set by the transformation.
  67. if (LoopControlBitIsAddedByTransformation(mask,
  68. existing_loop_control_mask)) {
  69. return false;
  70. }
  71. }
  72. // Check that PeelCount and PartialCount are supported if used.
  73. if ((message_.loop_control() & uint32_t(spv::LoopControlMask::PeelCount)) &&
  74. !PeelCountIsSupported(ir_context)) {
  75. return false;
  76. }
  77. if ((message_.loop_control() &
  78. uint32_t(spv::LoopControlMask::PartialCount)) &&
  79. !PartialCountIsSupported(ir_context)) {
  80. return false;
  81. }
  82. if (message_.peel_count() > 0 &&
  83. !(message_.loop_control() & uint32_t(spv::LoopControlMask::PeelCount))) {
  84. // Peel count provided, but peel count mask bit not set.
  85. return false;
  86. }
  87. if (message_.partial_count() > 0 &&
  88. !(message_.loop_control() &
  89. uint32_t(spv::LoopControlMask::PartialCount))) {
  90. // Partial count provided, but partial count mask bit not set.
  91. return false;
  92. }
  93. // We must not set both 'don't unroll' and one of 'peel count' or 'partial
  94. // count'.
  95. return !(
  96. (message_.loop_control() & uint32_t(spv::LoopControlMask::DontUnroll)) &&
  97. (message_.loop_control() & uint32_t(spv::LoopControlMask::PeelCount |
  98. spv::LoopControlMask::PartialCount)));
  99. }
  100. void TransformationSetLoopControl::Apply(
  101. opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
  102. // Grab the loop merge instruction and its associated loop control mask.
  103. auto merge_inst =
  104. ir_context->get_instr_block(message_.block_id())->GetMergeInst();
  105. auto existing_loop_control_mask =
  106. merge_inst->GetSingleWordInOperand(kLoopControlMaskInOperandIndex);
  107. // We are going to replace the OpLoopMerge's operands with this list.
  108. opt::Instruction::OperandList new_operands;
  109. // We add the existing merge block and continue target ids.
  110. new_operands.push_back(merge_inst->GetInOperand(0));
  111. new_operands.push_back(merge_inst->GetInOperand(1));
  112. // We use the loop control mask from the transformation.
  113. new_operands.push_back(
  114. {SPV_OPERAND_TYPE_LOOP_CONTROL, {message_.loop_control()}});
  115. // It remains to determine what literals to provide, in association with
  116. // the new loop control mask.
  117. //
  118. // For the loop controls that require guarantees to hold about the number
  119. // of loop iterations, we need to keep, from the original OpLoopMerge, any
  120. // literals associated with loop control bits that are still set.
  121. uint32_t literal_index = 0; // Indexes into the literals from the original
  122. // instruction.
  123. for (spv::LoopControlMask mask : {spv::LoopControlMask::DependencyLength,
  124. spv::LoopControlMask::MinIterations,
  125. spv::LoopControlMask::MaxIterations,
  126. spv::LoopControlMask::IterationMultiple}) {
  127. // Check whether the bit was set in the original loop control mask.
  128. if (existing_loop_control_mask & uint32_t(mask)) {
  129. // Check whether the bit is set in the new loop control mask.
  130. if (message_.loop_control() & uint32_t(mask)) {
  131. // Add the associated literal to our sequence of replacement operands.
  132. new_operands.push_back(
  133. {SPV_OPERAND_TYPE_LITERAL_INTEGER,
  134. {merge_inst->GetSingleWordInOperand(
  135. kLoopControlFirstLiteralInOperandIndex + literal_index)}});
  136. }
  137. // Increment our index into the original loop control mask's literals,
  138. // whether or not the bit was set in the new mask.
  139. literal_index++;
  140. }
  141. }
  142. // If PeelCount is set in the new mask, |message_.peel_count| provides the
  143. // associated peel count.
  144. if (message_.loop_control() & uint32_t(spv::LoopControlMask::PeelCount)) {
  145. new_operands.push_back(
  146. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {message_.peel_count()}});
  147. }
  148. // Similar, but for PartialCount.
  149. if (message_.loop_control() & uint32_t(spv::LoopControlMask::PartialCount)) {
  150. new_operands.push_back(
  151. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {message_.partial_count()}});
  152. }
  153. // Replace the input operands of the OpLoopMerge with the new operands we have
  154. // accumulated.
  155. merge_inst->SetInOperands(std::move(new_operands));
  156. }
  157. protobufs::Transformation TransformationSetLoopControl::ToMessage() const {
  158. protobufs::Transformation result;
  159. *result.mutable_set_loop_control() = message_;
  160. return result;
  161. }
  162. bool TransformationSetLoopControl::LoopControlBitIsAddedByTransformation(
  163. spv::LoopControlMask loop_control_single_bit_mask,
  164. uint32_t existing_loop_control_mask) const {
  165. return !(uint32_t(loop_control_single_bit_mask) &
  166. existing_loop_control_mask) &&
  167. (uint32_t(loop_control_single_bit_mask) & message_.loop_control());
  168. }
  169. bool TransformationSetLoopControl::PartialCountIsSupported(
  170. opt::IRContext* ir_context) {
  171. // TODO(afd): We capture the environments for which this loop control is
  172. // definitely not supported. The check should be refined on demand for other
  173. // target environments.
  174. switch (ir_context->grammar().target_env()) {
  175. case SPV_ENV_UNIVERSAL_1_0:
  176. case SPV_ENV_UNIVERSAL_1_1:
  177. case SPV_ENV_UNIVERSAL_1_2:
  178. case SPV_ENV_UNIVERSAL_1_3:
  179. case SPV_ENV_VULKAN_1_0:
  180. case SPV_ENV_VULKAN_1_1:
  181. return false;
  182. default:
  183. return true;
  184. }
  185. }
  186. bool TransformationSetLoopControl::PeelCountIsSupported(
  187. opt::IRContext* ir_context) {
  188. // TODO(afd): We capture the environments for which this loop control is
  189. // definitely not supported. The check should be refined on demand for other
  190. // target environments.
  191. switch (ir_context->grammar().target_env()) {
  192. case SPV_ENV_UNIVERSAL_1_0:
  193. case SPV_ENV_UNIVERSAL_1_1:
  194. case SPV_ENV_UNIVERSAL_1_2:
  195. case SPV_ENV_UNIVERSAL_1_3:
  196. case SPV_ENV_VULKAN_1_0:
  197. case SPV_ENV_VULKAN_1_1:
  198. return false;
  199. default:
  200. return true;
  201. }
  202. }
  203. std::unordered_set<uint32_t> TransformationSetLoopControl::GetFreshIds() const {
  204. return std::unordered_set<uint32_t>();
  205. }
  206. } // namespace fuzz
  207. } // namespace spvtools