transformation_adjust_branch_weights_test.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
// Copyright (c) 2020 André Perez Maselco
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/transformation_adjust_branch_weights.h"

#include "gtest/gtest.h"
#include "source/fuzz/fuzzer_util.h"
#include "source/fuzz/instruction_descriptor.h"
#include "test/fuzz/fuzz_test_util.h"
  19. namespace spvtools {
  20. namespace fuzz {
  21. namespace {
  22. TEST(TransformationAdjustBranchWeightsTest, IsApplicableTest) {
  23. std::string shader = R"(
  24. OpCapability Shader
  25. %1 = OpExtInstImport "GLSL.std.450"
  26. OpMemoryModel Logical GLSL450
  27. OpEntryPoint Fragment %4 "main" %51 %27
  28. OpExecutionMode %4 OriginUpperLeft
  29. OpSource ESSL 310
  30. OpName %4 "main"
  31. OpName %25 "buf"
  32. OpMemberName %25 0 "value"
  33. OpName %27 ""
  34. OpName %51 "color"
  35. OpMemberDecorate %25 0 Offset 0
  36. OpDecorate %25 Block
  37. OpDecorate %27 DescriptorSet 0
  38. OpDecorate %27 Binding 0
  39. OpDecorate %51 Location 0
  40. %2 = OpTypeVoid
  41. %3 = OpTypeFunction %2
  42. %6 = OpTypeFloat 32
  43. %7 = OpTypeVector %6 4
  44. %150 = OpTypeVector %6 2
  45. %10 = OpConstant %6 0.300000012
  46. %11 = OpConstant %6 0.400000006
  47. %12 = OpConstant %6 0.5
  48. %13 = OpConstant %6 1
  49. %14 = OpConstantComposite %7 %10 %11 %12 %13
  50. %15 = OpTypeInt 32 1
  51. %18 = OpConstant %15 0
  52. %25 = OpTypeStruct %6
  53. %26 = OpTypePointer Uniform %25
  54. %27 = OpVariable %26 Uniform
  55. %28 = OpTypePointer Uniform %6
  56. %32 = OpTypeBool
  57. %103 = OpConstantTrue %32
  58. %34 = OpConstant %6 0.100000001
  59. %48 = OpConstant %15 1
  60. %50 = OpTypePointer Output %7
  61. %51 = OpVariable %50 Output
  62. %100 = OpTypePointer Function %6
  63. %4 = OpFunction %2 None %3
  64. %5 = OpLabel
  65. %101 = OpVariable %100 Function
  66. %102 = OpVariable %100 Function
  67. OpBranch %19
  68. %19 = OpLabel
  69. %60 = OpPhi %7 %14 %5 %58 %20
  70. %59 = OpPhi %15 %18 %5 %49 %20
  71. %29 = OpAccessChain %28 %27 %18
  72. %30 = OpLoad %6 %29
  73. %31 = OpConvertFToS %15 %30
  74. %33 = OpSLessThan %32 %59 %31
  75. OpLoopMerge %21 %20 None
  76. OpBranchConditional %33 %20 %21 1 2
  77. %20 = OpLabel
  78. %39 = OpCompositeExtract %6 %60 0
  79. %40 = OpFAdd %6 %39 %34
  80. %55 = OpCompositeInsert %7 %40 %60 0
  81. %44 = OpCompositeExtract %6 %60 1
  82. %45 = OpFSub %6 %44 %34
  83. %58 = OpCompositeInsert %7 %45 %55 1
  84. %49 = OpIAdd %15 %59 %48
  85. OpBranch %19
  86. %21 = OpLabel
  87. OpStore %51 %60
  88. OpSelectionMerge %105 None
  89. OpBranchConditional %103 %104 %105
  90. %104 = OpLabel
  91. OpBranch %105
  92. %105 = OpLabel
  93. OpReturn
  94. OpFunctionEnd
  95. )";
  96. const auto env = SPV_ENV_UNIVERSAL_1_5;
  97. const auto consumer = nullptr;
  98. const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
  99. spvtools::ValidatorOptions validator_options;
  100. ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
  101. kConsoleMessageConsumer));
  102. TransformationContext transformation_context(
  103. MakeUnique<FactManager>(context.get()), validator_options);
  104. // Tests OpBranchConditional instruction with weights.
  105. auto instruction_descriptor =
  106. MakeInstructionDescriptor(33, spv::Op::OpBranchConditional, 0);
  107. auto transformation =
  108. TransformationAdjustBranchWeights(instruction_descriptor, {0, 1});
  109. ASSERT_TRUE(
  110. transformation.IsApplicable(context.get(), transformation_context));
  111. // Tests the two branch weights equal to 0.
  112. instruction_descriptor =
  113. MakeInstructionDescriptor(33, spv::Op::OpBranchConditional, 0);
  114. transformation =
  115. TransformationAdjustBranchWeights(instruction_descriptor, {0, 0});
  116. #ifndef NDEBUG
  117. ASSERT_DEATH(
  118. transformation.IsApplicable(context.get(), transformation_context),
  119. "At least one weight must be non-zero");
  120. #endif
  121. // Tests 32-bit unsigned integer overflow.
  122. instruction_descriptor =
  123. MakeInstructionDescriptor(33, spv::Op::OpBranchConditional, 0);
  124. transformation = TransformationAdjustBranchWeights(instruction_descriptor,
  125. {UINT32_MAX, 0});
  126. ASSERT_TRUE(
  127. transformation.IsApplicable(context.get(), transformation_context));
  128. instruction_descriptor =
  129. MakeInstructionDescriptor(33, spv::Op::OpBranchConditional, 0);
  130. transformation = TransformationAdjustBranchWeights(instruction_descriptor,
  131. {1, UINT32_MAX});
  132. #ifndef NDEBUG
  133. ASSERT_DEATH(
  134. transformation.IsApplicable(context.get(), transformation_context),
  135. "The sum of the two weights must not be greater than UINT32_MAX");
  136. #endif
  137. // Tests OpBranchConditional instruction with no weights.
  138. instruction_descriptor =
  139. MakeInstructionDescriptor(21, spv::Op::OpBranchConditional, 0);
  140. transformation =
  141. TransformationAdjustBranchWeights(instruction_descriptor, {0, 1});
  142. ASSERT_TRUE(
  143. transformation.IsApplicable(context.get(), transformation_context));
  144. // Tests non-OpBranchConditional instructions.
  145. instruction_descriptor = MakeInstructionDescriptor(2, spv::Op::OpTypeVoid, 0);
  146. transformation =
  147. TransformationAdjustBranchWeights(instruction_descriptor, {5, 6});
  148. ASSERT_FALSE(
  149. transformation.IsApplicable(context.get(), transformation_context));
  150. instruction_descriptor = MakeInstructionDescriptor(20, spv::Op::OpLabel, 0);
  151. transformation =
  152. TransformationAdjustBranchWeights(instruction_descriptor, {1, 2});
  153. ASSERT_FALSE(
  154. transformation.IsApplicable(context.get(), transformation_context));
  155. instruction_descriptor = MakeInstructionDescriptor(49, spv::Op::OpIAdd, 0);
  156. transformation =
  157. TransformationAdjustBranchWeights(instruction_descriptor, {1, 2});
  158. ASSERT_FALSE(
  159. transformation.IsApplicable(context.get(), transformation_context));
  160. }
  161. TEST(TransformationAdjustBranchWeightsTest, ApplyTest) {
  162. std::string shader = R"(
  163. OpCapability Shader
  164. %1 = OpExtInstImport "GLSL.std.450"
  165. OpMemoryModel Logical GLSL450
  166. OpEntryPoint Fragment %4 "main" %51 %27
  167. OpExecutionMode %4 OriginUpperLeft
  168. OpSource ESSL 310
  169. OpName %4 "main"
  170. OpName %25 "buf"
  171. OpMemberName %25 0 "value"
  172. OpName %27 ""
  173. OpName %51 "color"
  174. OpMemberDecorate %25 0 Offset 0
  175. OpDecorate %25 Block
  176. OpDecorate %27 DescriptorSet 0
  177. OpDecorate %27 Binding 0
  178. OpDecorate %51 Location 0
  179. %2 = OpTypeVoid
  180. %3 = OpTypeFunction %2
  181. %6 = OpTypeFloat 32
  182. %7 = OpTypeVector %6 4
  183. %150 = OpTypeVector %6 2
  184. %10 = OpConstant %6 0.300000012
  185. %11 = OpConstant %6 0.400000006
  186. %12 = OpConstant %6 0.5
  187. %13 = OpConstant %6 1
  188. %14 = OpConstantComposite %7 %10 %11 %12 %13
  189. %15 = OpTypeInt 32 1
  190. %18 = OpConstant %15 0
  191. %25 = OpTypeStruct %6
  192. %26 = OpTypePointer Uniform %25
  193. %27 = OpVariable %26 Uniform
  194. %28 = OpTypePointer Uniform %6
  195. %32 = OpTypeBool
  196. %103 = OpConstantTrue %32
  197. %34 = OpConstant %6 0.100000001
  198. %48 = OpConstant %15 1
  199. %50 = OpTypePointer Output %7
  200. %51 = OpVariable %50 Output
  201. %100 = OpTypePointer Function %6
  202. %4 = OpFunction %2 None %3
  203. %5 = OpLabel
  204. %101 = OpVariable %100 Function
  205. %102 = OpVariable %100 Function
  206. OpBranch %19
  207. %19 = OpLabel
  208. %60 = OpPhi %7 %14 %5 %58 %20
  209. %59 = OpPhi %15 %18 %5 %49 %20
  210. %29 = OpAccessChain %28 %27 %18
  211. %30 = OpLoad %6 %29
  212. %31 = OpConvertFToS %15 %30
  213. %33 = OpSLessThan %32 %59 %31
  214. OpLoopMerge %21 %20 None
  215. OpBranchConditional %33 %20 %21 1 2
  216. %20 = OpLabel
  217. %39 = OpCompositeExtract %6 %60 0
  218. %40 = OpFAdd %6 %39 %34
  219. %55 = OpCompositeInsert %7 %40 %60 0
  220. %44 = OpCompositeExtract %6 %60 1
  221. %45 = OpFSub %6 %44 %34
  222. %58 = OpCompositeInsert %7 %45 %55 1
  223. %49 = OpIAdd %15 %59 %48
  224. OpBranch %19
  225. %21 = OpLabel
  226. OpStore %51 %60
  227. OpSelectionMerge %105 None
  228. OpBranchConditional %103 %104 %105
  229. %104 = OpLabel
  230. OpBranch %105
  231. %105 = OpLabel
  232. OpReturn
  233. OpFunctionEnd
  234. )";
  235. const auto env = SPV_ENV_UNIVERSAL_1_5;
  236. const auto consumer = nullptr;
  237. const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
  238. spvtools::ValidatorOptions validator_options;
  239. ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
  240. kConsoleMessageConsumer));
  241. TransformationContext transformation_context(
  242. MakeUnique<FactManager>(context.get()), validator_options);
  243. auto instruction_descriptor =
  244. MakeInstructionDescriptor(33, spv::Op::OpBranchConditional, 0);
  245. auto transformation =
  246. TransformationAdjustBranchWeights(instruction_descriptor, {5, 6});
  247. ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
  248. instruction_descriptor =
  249. MakeInstructionDescriptor(21, spv::Op::OpBranchConditional, 0);
  250. transformation =
  251. TransformationAdjustBranchWeights(instruction_descriptor, {7, 8});
  252. ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
  253. std::string variant_shader = R"(
  254. OpCapability Shader
  255. %1 = OpExtInstImport "GLSL.std.450"
  256. OpMemoryModel Logical GLSL450
  257. OpEntryPoint Fragment %4 "main" %51 %27
  258. OpExecutionMode %4 OriginUpperLeft
  259. OpSource ESSL 310
  260. OpName %4 "main"
  261. OpName %25 "buf"
  262. OpMemberName %25 0 "value"
  263. OpName %27 ""
  264. OpName %51 "color"
  265. OpMemberDecorate %25 0 Offset 0
  266. OpDecorate %25 Block
  267. OpDecorate %27 DescriptorSet 0
  268. OpDecorate %27 Binding 0
  269. OpDecorate %51 Location 0
  270. %2 = OpTypeVoid
  271. %3 = OpTypeFunction %2
  272. %6 = OpTypeFloat 32
  273. %7 = OpTypeVector %6 4
  274. %150 = OpTypeVector %6 2
  275. %10 = OpConstant %6 0.300000012
  276. %11 = OpConstant %6 0.400000006
  277. %12 = OpConstant %6 0.5
  278. %13 = OpConstant %6 1
  279. %14 = OpConstantComposite %7 %10 %11 %12 %13
  280. %15 = OpTypeInt 32 1
  281. %18 = OpConstant %15 0
  282. %25 = OpTypeStruct %6
  283. %26 = OpTypePointer Uniform %25
  284. %27 = OpVariable %26 Uniform
  285. %28 = OpTypePointer Uniform %6
  286. %32 = OpTypeBool
  287. %103 = OpConstantTrue %32
  288. %34 = OpConstant %6 0.100000001
  289. %48 = OpConstant %15 1
  290. %50 = OpTypePointer Output %7
  291. %51 = OpVariable %50 Output
  292. %100 = OpTypePointer Function %6
  293. %4 = OpFunction %2 None %3
  294. %5 = OpLabel
  295. %101 = OpVariable %100 Function
  296. %102 = OpVariable %100 Function
  297. OpBranch %19
  298. %19 = OpLabel
  299. %60 = OpPhi %7 %14 %5 %58 %20
  300. %59 = OpPhi %15 %18 %5 %49 %20
  301. %29 = OpAccessChain %28 %27 %18
  302. %30 = OpLoad %6 %29
  303. %31 = OpConvertFToS %15 %30
  304. %33 = OpSLessThan %32 %59 %31
  305. OpLoopMerge %21 %20 None
  306. OpBranchConditional %33 %20 %21 5 6
  307. %20 = OpLabel
  308. %39 = OpCompositeExtract %6 %60 0
  309. %40 = OpFAdd %6 %39 %34
  310. %55 = OpCompositeInsert %7 %40 %60 0
  311. %44 = OpCompositeExtract %6 %60 1
  312. %45 = OpFSub %6 %44 %34
  313. %58 = OpCompositeInsert %7 %45 %55 1
  314. %49 = OpIAdd %15 %59 %48
  315. OpBranch %19
  316. %21 = OpLabel
  317. OpStore %51 %60
  318. OpSelectionMerge %105 None
  319. OpBranchConditional %103 %104 %105 7 8
  320. %104 = OpLabel
  321. OpBranch %105
  322. %105 = OpLabel
  323. OpReturn
  324. OpFunctionEnd
  325. )";
  326. ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
  327. }
  328. } // namespace
  329. } // namespace fuzz
  330. } // namespace spvtools