val_ray_query_test.cpp 23 KB


  1. // Copyright (c) 2022 The Khronos Group Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Tests ray query instructions from SPV_KHR_ray_query.
  15. #include <sstream>
  16. #include <string>
  17. #include "gmock/gmock.h"
  18. #include "spirv-tools/libspirv.h"
  19. #include "test/val/val_fixtures.h"
  20. namespace spvtools {
  21. namespace val {
  22. namespace {
  23. using ::testing::HasSubstr;
  24. using ::testing::Values;
  25. using ValidateRayQuery = spvtest::ValidateBase<bool>;
  // Assembles a complete GLCompute SPIR-V module around `body`.
//
// The module always enables Shader, Int64, Float64 and RayQueryKHR plus the
// SPV_KHR_ray_query extension, and declares the shared types, constants and
// variables the tests refer to (%ray_query, %top_level_as, %f32vec3_0, ...).
// Caller-supplied text is spliced in at three points:
//   capabilities_and_extensions - directly after the built-in OpExtension;
//   declarations                - after the built-in globals, before %main;
//   body                        - inside %main's entry block.
// OpReturn/OpFunctionEnd are appended after `body`, so a body may close %main
// itself and open a further function that those trailing lines then close.
std::string GenerateShaderCode(
    const std::string& body,
    const std::string& capabilities_and_extensions = "",
    const std::string& declarations = "") {
  static const char kCapabilities[] = R"(
OpCapability Shader
OpCapability Int64
OpCapability Float64
OpCapability RayQueryKHR
OpExtension "SPV_KHR_ray_query"
)";
  static const char kGlobals[] = R"(
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %top_level_as DescriptorSet 0
OpDecorate %top_level_as Binding 0
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%f32 = OpTypeFloat 32
%f64 = OpTypeFloat 64
%u32 = OpTypeInt 32 0
%s32 = OpTypeInt 32 1
%u64 = OpTypeInt 64 0
%s64 = OpTypeInt 64 1
%type_rq = OpTypeRayQueryKHR
%type_as = OpTypeAccelerationStructureKHR
%s32vec2 = OpTypeVector %s32 2
%u32vec2 = OpTypeVector %u32 2
%f32vec2 = OpTypeVector %f32 2
%u32vec3 = OpTypeVector %u32 3
%s32vec3 = OpTypeVector %s32 3
%f32vec3 = OpTypeVector %f32 3
%u32vec4 = OpTypeVector %u32 4
%s32vec4 = OpTypeVector %s32 4
%f32vec4 = OpTypeVector %f32 4
%mat4x3 = OpTypeMatrix %f32vec3 4
%f32_0 = OpConstant %f32 0
%f64_0 = OpConstant %f64 0
%s32_0 = OpConstant %s32 0
%u32_0 = OpConstant %u32 0
%u64_0 = OpConstant %u64 0
%u32vec3_0 = OpConstantComposite %u32vec3 %u32_0 %u32_0 %u32_0
%f32vec3_0 = OpConstantComposite %f32vec3 %f32_0 %f32_0 %f32_0
%f32vec4_0 = OpConstantComposite %f32vec4 %f32_0 %f32_0 %f32_0 %f32_0
%ptr_rq = OpTypePointer Private %type_rq
%ray_query = OpVariable %ptr_rq Private
%ptr_as = OpTypePointer UniformConstant %type_as
%top_level_as = OpVariable %ptr_as UniformConstant
%ptr_function_u32 = OpTypePointer Function %u32
%ptr_function_f32 = OpTypePointer Function %f32
%ptr_function_f32vec3 = OpTypePointer Function %f32vec3
)";
  static const char kMainHeader[] = R"(
%main = OpFunction %void None %func
%main_entry = OpLabel
)";
  static const char kMainFooter[] = R"(
OpReturn
OpFunctionEnd)";

  std::string module = kCapabilities;
  module += capabilities_and_extensions;
  module += kGlobals;
  module += declarations;
  module += kMainHeader;
  module += body;
  module += kMainFooter;
  return module;
}
  // Returns the "%result =" prefix for ray query opcodes that produce a value,
// or "" for opcodes without a result id (OpRayQueryTerminateKHR,
// OpRayQueryConfirmIntersectionKHR) and for any opcode not listed below.
//
// `opcode` is only read, so it is taken by const reference instead of by value.
// The set of result-producing opcodes lives in a single table rather than a
// chain of compare() calls.
std::string RayQueryResult(const std::string& opcode) {
  static const char* const kOpcodesWithResult[] = {
      "OpRayQueryProceedKHR",
      "OpRayQueryGetIntersectionTypeKHR",
      "OpRayQueryGetRayTMinKHR",
      "OpRayQueryGetRayFlagsKHR",
      "OpRayQueryGetIntersectionTKHR",
      "OpRayQueryGetIntersectionInstanceCustomIndexKHR",
      "OpRayQueryGetIntersectionInstanceIdKHR",
      "OpRayQueryGetIntersectionInstanceShaderBindingTableRecord"
      "OffsetKHR",
      "OpRayQueryGetIntersectionGeometryIndexKHR",
      "OpRayQueryGetIntersectionPrimitiveIndexKHR",
      "OpRayQueryGetIntersectionBarycentricsKHR",
      "OpRayQueryGetIntersectionFrontFaceKHR",
      "OpRayQueryGetIntersectionCandidateAABBOpaqueKHR",
      "OpRayQueryGetIntersectionObjectRayDirectionKHR",
      "OpRayQueryGetIntersectionObjectRayOriginKHR",
      "OpRayQueryGetWorldRayDirectionKHR",
      "OpRayQueryGetWorldRayOriginKHR",
      "OpRayQueryGetIntersectionObjectToWorldKHR",
      "OpRayQueryGetIntersectionWorldToObjectKHR",
  };
  for (const char* candidate : kOpcodesWithResult) {
    if (opcode == candidate) return "%result =";
  }
  return "";
}
  // Returns the Result Type id to use for `opcode` in a generated instruction.
//
// When `valid` is true this is the type the validator requires (e.g. %u32 for
// OpRayQueryGetRayFlagsKHR, %mat4x3 for the object/world transforms). When
// `valid` is false it is always %f64, which none of these opcodes accept.
// Opcodes with no result (Terminate/Confirm) and unknown opcodes map to "".
//
// `opcode` is taken by const reference because it is only read. The opcode to
// type mapping is a single table instead of six copy-pasted compare() chains.
std::string RayQueryResultType(const std::string& opcode, bool valid) {
  struct ResultTypeEntry {
    const char* opcode;
    const char* result_type;
  };
  static const ResultTypeEntry kResultTypes[] = {
      {"OpRayQueryGetIntersectionTypeKHR", "%u32"},
      {"OpRayQueryGetRayFlagsKHR", "%u32"},
      {"OpRayQueryGetIntersectionInstanceCustomIndexKHR", "%u32"},
      {"OpRayQueryGetIntersectionInstanceIdKHR", "%u32"},
      {"OpRayQueryGetIntersectionInstanceShaderBindingTableRecord"
       "OffsetKHR",
       "%u32"},
      {"OpRayQueryGetIntersectionGeometryIndexKHR", "%u32"},
      {"OpRayQueryGetIntersectionPrimitiveIndexKHR", "%u32"},
      {"OpRayQueryGetRayTMinKHR", "%f32"},
      {"OpRayQueryGetIntersectionTKHR", "%f32"},
      {"OpRayQueryGetIntersectionBarycentricsKHR", "%f32vec2"},
      {"OpRayQueryGetIntersectionObjectRayDirectionKHR", "%f32vec3"},
      {"OpRayQueryGetIntersectionObjectRayOriginKHR", "%f32vec3"},
      {"OpRayQueryGetWorldRayDirectionKHR", "%f32vec3"},
      {"OpRayQueryGetWorldRayOriginKHR", "%f32vec3"},
      {"OpRayQueryProceedKHR", "%bool"},
      {"OpRayQueryGetIntersectionFrontFaceKHR", "%bool"},
      {"OpRayQueryGetIntersectionCandidateAABBOpaqueKHR", "%bool"},
      {"OpRayQueryGetIntersectionObjectToWorldKHR", "%mat4x3"},
      {"OpRayQueryGetIntersectionWorldToObjectKHR", "%mat4x3"},
  };
  for (const ResultTypeEntry& entry : kResultTypes) {
    if (opcode == entry.opcode) return valid ? entry.result_type : "%f64";
  }
  return "";
}
  // Returns the Intersection operand to append for `opcode`, or "" when the
// opcode takes no Intersection operand (e.g. OpRayQueryProceedKHR,
// OpRayQueryGetRayTMinKHR, OpRayQueryGetIntersectionCandidateAABBOpaqueKHR).
//
// With `valid` true the operand is %s32_0, a constant 32-bit int. With `valid`
// false it is %f32_0, which the validator rejects as an Intersection ID.
//
// `opcode` is taken by const reference because it is only read. The opcodes
// live in a table rather than a compare() chain.
std::string RayQueryIntersection(const std::string& opcode, bool valid) {
  static const char* const kOpcodesWithIntersection[] = {
      "OpRayQueryGetIntersectionTypeKHR",
      "OpRayQueryGetIntersectionTKHR",
      "OpRayQueryGetIntersectionInstanceCustomIndexKHR",
      "OpRayQueryGetIntersectionInstanceIdKHR",
      "OpRayQueryGetIntersectionInstanceShaderBindingTableRecord"
      "OffsetKHR",
      "OpRayQueryGetIntersectionGeometryIndexKHR",
      "OpRayQueryGetIntersectionPrimitiveIndexKHR",
      "OpRayQueryGetIntersectionBarycentricsKHR",
      "OpRayQueryGetIntersectionFrontFaceKHR",
      "OpRayQueryGetIntersectionObjectRayDirectionKHR",
      "OpRayQueryGetIntersectionObjectRayOriginKHR",
      "OpRayQueryGetIntersectionObjectToWorldKHR",
      "OpRayQueryGetIntersectionWorldToObjectKHR",
  };
  for (const char* candidate : kOpcodesWithIntersection) {
    if (opcode == candidate) return valid ? "%s32_0" : "%f32_0";
  }
  return "";
}
  172. using RayQueryCommon = spvtest::ValidateBase<std::string>;
  173. TEST_P(RayQueryCommon, Success) {
  174. std::string opcode = GetParam();
  175. std::ostringstream ss;
  176. ss << RayQueryResult(opcode);
  177. ss << " " << opcode << " ";
  178. ss << RayQueryResultType(opcode, true);
  179. ss << " %ray_query ";
  180. ss << RayQueryIntersection(opcode, true);
  181. CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
  182. EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
  183. }
  184. TEST_P(RayQueryCommon, BadQuery) {
  185. std::string opcode = GetParam();
  186. std::ostringstream ss;
  187. ss << RayQueryResult(opcode);
  188. ss << " " << opcode << " ";
  189. ss << RayQueryResultType(opcode, true);
  190. ss << " %top_level_as ";
  191. ss << RayQueryIntersection(opcode, true);
  192. CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
  193. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  194. EXPECT_THAT(getDiagnosticString(),
  195. HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR"));
  196. }
  197. TEST_P(RayQueryCommon, BadResult) {
  198. std::string opcode = GetParam();
  199. std::string result_type = RayQueryResultType(opcode, false);
  200. if (!result_type.empty()) {
  201. std::ostringstream ss;
  202. ss << RayQueryResult(opcode);
  203. ss << " " << opcode << " ";
  204. ss << result_type;
  205. ss << " %ray_query ";
  206. ss << RayQueryIntersection(opcode, true);
  207. CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
  208. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  209. std::string correct_result_type = RayQueryResultType(opcode, true);
  210. if (correct_result_type.compare("%u32") == 0) {
  211. EXPECT_THAT(
  212. getDiagnosticString(),
  213. HasSubstr("expected Result Type to be 32-bit int scalar type"));
  214. } else if (correct_result_type.compare("%f32") == 0) {
  215. EXPECT_THAT(
  216. getDiagnosticString(),
  217. HasSubstr("expected Result Type to be 32-bit float scalar type"));
  218. } else if (correct_result_type.compare("%f32vec2") == 0) {
  219. EXPECT_THAT(getDiagnosticString(),
  220. HasSubstr("expected Result Type to be 32-bit float "
  221. "2-component vector type"));
  222. } else if (correct_result_type.compare("%f32vec3") == 0) {
  223. EXPECT_THAT(getDiagnosticString(),
  224. HasSubstr("expected Result Type to be 32-bit float "
  225. "3-component vector type"));
  226. } else if (correct_result_type.compare("%bool") == 0) {
  227. EXPECT_THAT(getDiagnosticString(),
  228. HasSubstr("expected Result Type to be bool scalar type"));
  229. } else if (correct_result_type.compare("%mat4x3") == 0) {
  230. EXPECT_THAT(getDiagnosticString(),
  231. HasSubstr("expected matrix type as Result Type"));
  232. }
  233. }
  234. }
  235. TEST_P(RayQueryCommon, BadIntersection) {
  236. std::string opcode = GetParam();
  237. std::string intersection = RayQueryIntersection(opcode, false);
  238. if (!intersection.empty()) {
  239. std::ostringstream ss;
  240. ss << RayQueryResult(opcode);
  241. ss << " " << opcode << " ";
  242. ss << RayQueryResultType(opcode, true);
  243. ss << " %ray_query ";
  244. ss << intersection;
  245. CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
  246. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  247. EXPECT_THAT(
  248. getDiagnosticString(),
  249. HasSubstr(
  250. "expected Intersection ID to be a constant 32-bit int scalar"));
  251. }
  252. }
  253. INSTANTIATE_TEST_SUITE_P(
  254. ValidateRayQueryCommon, RayQueryCommon,
  255. Values("OpRayQueryTerminateKHR", "OpRayQueryConfirmIntersectionKHR",
  256. "OpRayQueryProceedKHR", "OpRayQueryGetIntersectionTypeKHR",
  257. "OpRayQueryGetRayTMinKHR", "OpRayQueryGetRayFlagsKHR",
  258. "OpRayQueryGetWorldRayDirectionKHR",
  259. "OpRayQueryGetWorldRayOriginKHR", "OpRayQueryGetIntersectionTKHR",
  260. "OpRayQueryGetIntersectionInstanceCustomIndexKHR",
  261. "OpRayQueryGetIntersectionInstanceIdKHR",
  262. "OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR",
  263. "OpRayQueryGetIntersectionGeometryIndexKHR",
  264. "OpRayQueryGetIntersectionPrimitiveIndexKHR",
  265. "OpRayQueryGetIntersectionBarycentricsKHR",
  266. "OpRayQueryGetIntersectionFrontFaceKHR",
  267. "OpRayQueryGetIntersectionCandidateAABBOpaqueKHR",
  268. "OpRayQueryGetIntersectionObjectRayDirectionKHR",
  269. "OpRayQueryGetIntersectionObjectRayOriginKHR",
  270. "OpRayQueryGetIntersectionObjectToWorldKHR",
  271. "OpRayQueryGetIntersectionWorldToObjectKHR"));
  272. // tests various Intersection operand types
  273. TEST_F(ValidateRayQuery, IntersectionSuccess) {
  274. const std::string body = R"(
  275. %result_1 = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %s32_0
  276. %result_2 = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %u32_0
  277. )";
  278. CompileSuccessfully(GenerateShaderCode(body).c_str());
  279. EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
  280. }
  281. TEST_F(ValidateRayQuery, IntersectionVector) {
  282. const std::string body = R"(
  283. %result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %u32vec3_0
  284. )";
  285. CompileSuccessfully(GenerateShaderCode(body).c_str());
  286. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  287. EXPECT_THAT(
  288. getDiagnosticString(),
  289. HasSubstr("expected Intersection ID to be a constant 32-bit int scalar"));
  290. }
  291. TEST_F(ValidateRayQuery, IntersectionNonConstantVariable) {
  292. const std::string body = R"(
  293. %var = OpVariable %ptr_function_u32 Function
  294. %result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %var
  295. )";
  296. CompileSuccessfully(GenerateShaderCode(body).c_str());
  297. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  298. EXPECT_THAT(
  299. getDiagnosticString(),
  300. HasSubstr("expected Intersection ID to be a constant 32-bit int scalar"));
  301. }
  302. TEST_F(ValidateRayQuery, IntersectionNonConstantLoad) {
  303. const std::string body = R"(
  304. %var = OpVariable %ptr_function_u32 Function
  305. %load = OpLoad %u32 %var
  306. %result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %load
  307. )";
  308. CompileSuccessfully(GenerateShaderCode(body).c_str());
  309. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  310. EXPECT_THAT(
  311. getDiagnosticString(),
  312. HasSubstr("expected Intersection ID to be a constant 32-bit int scalar"));
  313. }
  314. TEST_F(ValidateRayQuery, InitializeSuccess) {
  315. const std::string body = R"(
  316. %var_u32 = OpVariable %ptr_function_u32 Function
  317. %var_f32 = OpVariable %ptr_function_f32 Function
  318. %var_f32vec3 = OpVariable %ptr_function_f32vec3 Function
  319. %as = OpLoad %type_as %top_level_as
  320. OpRayQueryInitializeKHR %ray_query %as %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  321. %_u32 = OpLoad %u32 %var_u32
  322. %_f32 = OpLoad %f32 %var_f32
  323. %_f32vec3 = OpLoad %f32vec3 %var_f32vec3
  324. OpRayQueryInitializeKHR %ray_query %as %_u32 %_u32 %_f32vec3 %_f32 %_f32vec3 %_f32
  325. )";
  326. CompileSuccessfully(GenerateShaderCode(body).c_str());
  327. EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
  328. }
  329. TEST_F(ValidateRayQuery, InitializeFunctionSuccess) {
  330. const std::string declaration = R"(
  331. %rq_ptr = OpTypePointer Private %type_rq
  332. %rq_func_type = OpTypeFunction %void %rq_ptr
  333. %rq_var_1 = OpVariable %rq_ptr Private
  334. %rq_var_2 = OpVariable %rq_ptr Private
  335. )";
  336. const std::string body = R"(
  337. %fcall_1 = OpFunctionCall %void %rq_func %rq_var_1
  338. %as_1 = OpLoad %type_as %top_level_as
  339. OpRayQueryInitializeKHR %rq_var_1 %as_1 %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  340. %fcall_2 = OpFunctionCall %void %rq_func %rq_var_2
  341. OpReturn
  342. OpFunctionEnd
  343. %rq_func = OpFunction %void None %rq_func_type
  344. %rq_param = OpFunctionParameter %rq_ptr
  345. %label = OpLabel
  346. %as_2 = OpLoad %type_as %top_level_as
  347. OpRayQueryInitializeKHR %rq_param %as_2 %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  348. )";
  349. CompileSuccessfully(GenerateShaderCode(body, "", declaration).c_str());
  350. EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
  351. }
  352. TEST_F(ValidateRayQuery, InitializeBadRayQuery) {
  353. const std::string body = R"(
  354. %load = OpLoad %type_as %top_level_as
  355. OpRayQueryInitializeKHR %top_level_as %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  356. )";
  357. CompileSuccessfully(GenerateShaderCode(body).c_str());
  358. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  359. EXPECT_THAT(getDiagnosticString(),
  360. HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR"));
  361. }
  362. TEST_F(ValidateRayQuery, InitializeBadAS) {
  363. const std::string body = R"(
  364. OpRayQueryInitializeKHR %ray_query %ray_query %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  365. )";
  366. CompileSuccessfully(GenerateShaderCode(body).c_str());
  367. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  368. EXPECT_THAT(getDiagnosticString(),
  369. HasSubstr("Expected Acceleration Structure to be of type "
  370. "OpTypeAccelerationStructureKHR"));
  371. }
  372. TEST_F(ValidateRayQuery, InitializeBadRayFlags64) {
  373. const std::string body = R"(
  374. %load = OpLoad %type_as %top_level_as
  375. OpRayQueryInitializeKHR %ray_query %load %u64_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  376. )";
  377. CompileSuccessfully(GenerateShaderCode(body).c_str());
  378. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  379. EXPECT_THAT(getDiagnosticString(),
  380. HasSubstr("Ray Flags must be a 32-bit int scalar"));
  381. }
  382. TEST_F(ValidateRayQuery, InitializeBadRayFlagsVector) {
  383. const std::string body = R"(
  384. %load = OpLoad %type_as %top_level_as
  385. OpRayQueryInitializeKHR %ray_query %load %u32vec2 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  386. )";
  387. CompileSuccessfully(GenerateShaderCode(body).c_str());
  388. EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  389. EXPECT_THAT(getDiagnosticString(),
  390. HasSubstr("Operand '15[%v2uint]' cannot be a type"));
  391. }
  392. TEST_F(ValidateRayQuery, InitializeBadCullMask) {
  393. const std::string body = R"(
  394. %load = OpLoad %type_as %top_level_as
  395. OpRayQueryInitializeKHR %ray_query %load %u32_0 %f32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  396. )";
  397. CompileSuccessfully(GenerateShaderCode(body).c_str());
  398. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  399. EXPECT_THAT(getDiagnosticString(),
  400. HasSubstr("Cull Mask must be a 32-bit int scalar"));
  401. }
  402. TEST_F(ValidateRayQuery, InitializeBadRayOriginVec4) {
  403. const std::string body = R"(
  404. %load = OpLoad %type_as %top_level_as
  405. OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec4_0 %f32_0 %f32vec3_0 %f32_0
  406. )";
  407. CompileSuccessfully(GenerateShaderCode(body).c_str());
  408. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  409. EXPECT_THAT(
  410. getDiagnosticString(),
  411. HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
  412. }
  413. TEST_F(ValidateRayQuery, InitializeBadRayOriginFloat) {
  414. const std::string body = R"(
  415. %var_f32 = OpVariable %ptr_function_f32 Function
  416. %_f32 = OpLoad %f32 %var_f32
  417. %load = OpLoad %type_as %top_level_as
  418. OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %_f32 %f32_0 %f32vec3_0 %f32_0
  419. )";
  420. CompileSuccessfully(GenerateShaderCode(body).c_str());
  421. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  422. EXPECT_THAT(
  423. getDiagnosticString(),
  424. HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
  425. }
  426. TEST_F(ValidateRayQuery, InitializeBadRayOriginInt) {
  427. const std::string body = R"(
  428. %load = OpLoad %type_as %top_level_as
  429. OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %u32vec3_0 %f32_0 %f32vec3_0 %f32_0
  430. )";
  431. CompileSuccessfully(GenerateShaderCode(body).c_str());
  432. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  433. EXPECT_THAT(
  434. getDiagnosticString(),
  435. HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
  436. }
  437. TEST_F(ValidateRayQuery, InitializeBadRayTMin) {
  438. const std::string body = R"(
  439. %load = OpLoad %type_as %top_level_as
  440. OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %u32_0 %f32vec3_0 %f32_0
  441. )";
  442. CompileSuccessfully(GenerateShaderCode(body).c_str());
  443. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  444. EXPECT_THAT(getDiagnosticString(),
  445. HasSubstr("Ray TMin must be a 32-bit float scalar"));
  446. }
  447. TEST_F(ValidateRayQuery, InitializeBadRayDirection) {
  448. const std::string body = R"(
  449. %load = OpLoad %type_as %top_level_as
  450. OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec4_0 %f32_0
  451. )";
  452. CompileSuccessfully(GenerateShaderCode(body).c_str());
  453. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  454. EXPECT_THAT(
  455. getDiagnosticString(),
  456. HasSubstr("Ray Direction must be a 32-bit float 3-component vector"));
  457. }
  458. TEST_F(ValidateRayQuery, InitializeBadRayTMax) {
  459. const std::string body = R"(
  460. %load = OpLoad %type_as %top_level_as
  461. OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f64_0
  462. )";
  463. CompileSuccessfully(GenerateShaderCode(body).c_str());
  464. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  465. EXPECT_THAT(getDiagnosticString(),
  466. HasSubstr("Ray TMax must be a 32-bit float scalar"));
  467. }
  468. TEST_F(ValidateRayQuery, GenerateIntersectionSuccess) {
  469. const std::string body = R"(
  470. %var = OpVariable %ptr_function_f32 Function
  471. %load = OpLoad %f32 %var
  472. OpRayQueryGenerateIntersectionKHR %ray_query %f32_0
  473. OpRayQueryGenerateIntersectionKHR %ray_query %load
  474. )";
  475. CompileSuccessfully(GenerateShaderCode(body).c_str());
  476. EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
  477. }
  478. TEST_F(ValidateRayQuery, GenerateIntersectionBadRayQuery) {
  479. const std::string body = R"(
  480. OpRayQueryGenerateIntersectionKHR %top_level_as %f32_0
  481. )";
  482. CompileSuccessfully(GenerateShaderCode(body).c_str());
  483. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  484. EXPECT_THAT(getDiagnosticString(),
  485. HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR"));
  486. }
  487. TEST_F(ValidateRayQuery, GenerateIntersectionBadHitT) {
  488. const std::string body = R"(
  489. OpRayQueryGenerateIntersectionKHR %ray_query %u32_0
  490. )";
  491. CompileSuccessfully(GenerateShaderCode(body).c_str());
  492. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  493. EXPECT_THAT(getDiagnosticString(),
  494. HasSubstr("Hit T must be a 32-bit float scalar"));
  495. }
  496. TEST_F(ValidateRayQuery, RayQueryArraySuccess) {
  497. // This shader is slightly different to the ones above, so it doesn't reuse
  498. // the shader code generator.
  499. const std::string shader = R"(
  500. OpCapability Shader
  501. OpCapability RayQueryKHR
  502. OpExtension "SPV_KHR_ray_query"
  503. OpMemoryModel Logical GLSL450
  504. OpEntryPoint GLCompute %main "main"
  505. OpExecutionMode %main LocalSize 1 1 1
  506. OpSource GLSL 460
  507. OpDecorate %topLevelAS DescriptorSet 0
  508. OpDecorate %topLevelAS Binding 0
  509. OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
  510. %void = OpTypeVoid
  511. %func = OpTypeFunction %void
  512. %ray_query = OpTypeRayQueryKHR
  513. %uint = OpTypeInt 32 0
  514. %uint_2 = OpConstant %uint 2
  515. %ray_query_array = OpTypeArray %ray_query %uint_2
  516. %ptr_ray_query_array = OpTypePointer Private %ray_query_array
  517. %rayQueries = OpVariable %ptr_ray_query_array Private
  518. %int = OpTypeInt 32 1
  519. %int_0 = OpConstant %int 0
  520. %ptr_ray_query = OpTypePointer Private %ray_query
  521. %accel_struct = OpTypeAccelerationStructureKHR
  522. %ptr_accel_struct = OpTypePointer UniformConstant %accel_struct
  523. %topLevelAS = OpVariable %ptr_accel_struct UniformConstant
  524. %uint_0 = OpConstant %uint 0
  525. %uint_255 = OpConstant %uint 255
  526. %float = OpTypeFloat 32
  527. %v3float = OpTypeVector %float 3
  528. %float_0 = OpConstant %float 0
  529. %vec3_zero = OpConstantComposite %v3float %float_0 %float_0 %float_0
  530. %float_1 = OpConstant %float 1
  531. %vec3_xy_0_z_1 = OpConstantComposite %v3float %float_0 %float_0 %float_1
  532. %float_10 = OpConstant %float 10
  533. %v3uint = OpTypeVector %uint 3
  534. %uint_1 = OpConstant %uint 1
  535. %gl_WorkGroupSize = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_1
  536. %main = OpFunction %void None %func
  537. %main_label = OpLabel
  538. %first_ray_query = OpAccessChain %ptr_ray_query %rayQueries %int_0
  539. %topLevelAS_val = OpLoad %accel_struct %topLevelAS
  540. OpRayQueryInitializeKHR %first_ray_query %topLevelAS_val %uint_0 %uint_255 %vec3_zero %float_0 %vec3_xy_0_z_1 %float_10
  541. OpReturn
  542. OpFunctionEnd
  543. )";
  544. CompileSuccessfully(shader);
  545. EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
  546. }
  547. } // namespace
  548. } // namespace val
  549. } // namespace spvtools