val_ray_query_test.cpp 26 KB


  1. // Copyright (c) 2022 The Khronos Group Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Tests ray query instructions from SPV_KHR_ray_query.
  15. #include <sstream>
  16. #include <string>
  17. #include "gmock/gmock.h"
  18. #include "spirv-tools/libspirv.h"
  19. #include "test/val/val_fixtures.h"
  20. namespace spvtools {
  21. namespace val {
  22. namespace {
using ::testing::HasSubstr;
using ::testing::Values;
// Fixture for the non-parameterized TEST_F ray-query tests in this file. The
// bool template argument is not read by those tests (none call GetParam()).
using ValidateRayQuery = spvtest::ValidateBase<bool>;
  26. std::string GenerateShaderCode(const std::string& body,
  27. const std::string& capabilities = "",
  28. const std::string& extensions = "",
  29. const std::string& declarations = "") {
  30. std::ostringstream ss;
  31. ss << R"(
  32. OpCapability Shader
  33. OpCapability Int64
  34. OpCapability Float64
  35. OpCapability RayQueryKHR
  36. )";
  37. ss << capabilities;
  38. ss << R"(
  39. OpExtension "SPV_KHR_ray_query"
  40. )";
  41. ss << extensions;
  42. ss << R"(
  43. OpMemoryModel Logical GLSL450
  44. OpEntryPoint GLCompute %main "main"
  45. OpExecutionMode %main LocalSize 1 1 1
  46. OpDecorate %top_level_as DescriptorSet 0
  47. OpDecorate %top_level_as Binding 0
  48. %void = OpTypeVoid
  49. %func = OpTypeFunction %void
  50. %bool = OpTypeBool
  51. %f32 = OpTypeFloat 32
  52. %f64 = OpTypeFloat 64
  53. %u32 = OpTypeInt 32 0
  54. %s32 = OpTypeInt 32 1
  55. %u64 = OpTypeInt 64 0
  56. %s64 = OpTypeInt 64 1
  57. %type_rq = OpTypeRayQueryKHR
  58. %type_as = OpTypeAccelerationStructureKHR
  59. %s32vec2 = OpTypeVector %s32 2
  60. %u32vec2 = OpTypeVector %u32 2
  61. %f32vec2 = OpTypeVector %f32 2
  62. %u32vec3 = OpTypeVector %u32 3
  63. %s32vec3 = OpTypeVector %s32 3
  64. %f32vec3 = OpTypeVector %f32 3
  65. %u32vec4 = OpTypeVector %u32 4
  66. %s32vec4 = OpTypeVector %s32 4
  67. %f32vec4 = OpTypeVector %f32 4
  68. %mat4x3 = OpTypeMatrix %f32vec3 4
  69. %f32_0 = OpConstant %f32 0
  70. %f64_0 = OpConstant %f64 0
  71. %s32_0 = OpConstant %s32 0
  72. %u32_0 = OpConstant %u32 0
  73. %u64_0 = OpConstant %u64 0
  74. %u32_2 = OpConstant %u32 2
  75. %arr2v3 = OpTypeArray %f32vec3 %u32_2
  76. %arr2f3 = OpTypeArray %f32 %u32_2
  77. %u32vec3_0 = OpConstantComposite %u32vec3 %u32_0 %u32_0 %u32_0
  78. %f32vec3_0 = OpConstantComposite %f32vec3 %f32_0 %f32_0 %f32_0
  79. %f32vec4_0 = OpConstantComposite %f32vec4 %f32_0 %f32_0 %f32_0 %f32_0
  80. %ptr_rq = OpTypePointer Function %type_rq
  81. %ptr_as = OpTypePointer UniformConstant %type_as
  82. %top_level_as = OpVariable %ptr_as UniformConstant
  83. %ptr_function_u32 = OpTypePointer Function %u32
  84. %ptr_function_f32 = OpTypePointer Function %f32
  85. %ptr_function_f32vec3 = OpTypePointer Function %f32vec3
  86. )";
  87. ss << declarations;
  88. ss << R"(
  89. %main = OpFunction %void None %func
  90. %main_entry = OpLabel
  91. %ray_query = OpVariable %ptr_rq Function
  92. )";
  93. ss << body;
  94. ss << R"(
  95. OpReturn
  96. OpFunctionEnd)";
  97. return ss.str();
  98. }
  99. std::string RayQueryResult(std::string opcode) {
  100. if (opcode.compare("OpRayQueryProceedKHR") == 0 ||
  101. opcode.compare("OpRayQueryGetIntersectionTypeKHR") == 0 ||
  102. opcode.compare("OpRayQueryGetRayTMinKHR") == 0 ||
  103. opcode.compare("OpRayQueryGetRayFlagsKHR") == 0 ||
  104. opcode.compare("OpRayQueryGetIntersectionTKHR") == 0 ||
  105. opcode.compare("OpRayQueryGetIntersectionInstanceCustomIndexKHR") == 0 ||
  106. opcode.compare("OpRayQueryGetIntersectionInstanceIdKHR") == 0 ||
  107. opcode.compare("OpRayQueryGetIntersectionInstanceShaderBindingTableRecord"
  108. "OffsetKHR") == 0 ||
  109. opcode.compare("OpRayQueryGetIntersectionGeometryIndexKHR") == 0 ||
  110. opcode.compare("OpRayQueryGetIntersectionPrimitiveIndexKHR") == 0 ||
  111. opcode.compare("OpRayQueryGetIntersectionBarycentricsKHR") == 0 ||
  112. opcode.compare("OpRayQueryGetIntersectionFrontFaceKHR") == 0 ||
  113. opcode.compare("OpRayQueryGetIntersectionCandidateAABBOpaqueKHR") == 0 ||
  114. opcode.compare("OpRayQueryGetIntersectionObjectRayDirectionKHR") == 0 ||
  115. opcode.compare("OpRayQueryGetIntersectionObjectRayOriginKHR") == 0 ||
  116. opcode.compare("OpRayQueryGetWorldRayDirectionKHR") == 0 ||
  117. opcode.compare("OpRayQueryGetWorldRayOriginKHR") == 0 ||
  118. opcode.compare("OpRayQueryGetIntersectionObjectToWorldKHR") == 0 ||
  119. opcode.compare("OpRayQueryGetIntersectionWorldToObjectKHR") == 0) {
  120. return "%result =";
  121. }
  122. return "";
  123. }
  124. std::string RayQueryResultType(std::string opcode, bool valid) {
  125. if (opcode.compare("OpRayQueryGetIntersectionTypeKHR") == 0 ||
  126. opcode.compare("OpRayQueryGetRayFlagsKHR") == 0 ||
  127. opcode.compare("OpRayQueryGetIntersectionInstanceCustomIndexKHR") == 0 ||
  128. opcode.compare("OpRayQueryGetIntersectionInstanceIdKHR") == 0 ||
  129. opcode.compare("OpRayQueryGetIntersectionInstanceShaderBindingTableRecord"
  130. "OffsetKHR") == 0 ||
  131. opcode.compare("OpRayQueryGetIntersectionGeometryIndexKHR") == 0 ||
  132. opcode.compare("OpRayQueryGetIntersectionPrimitiveIndexKHR") == 0) {
  133. return valid ? "%u32" : "%f64";
  134. }
  135. if (opcode.compare("OpRayQueryGetRayTMinKHR") == 0 ||
  136. opcode.compare("OpRayQueryGetIntersectionTKHR") == 0) {
  137. return valid ? "%f32" : "%f64";
  138. }
  139. if (opcode.compare("OpRayQueryGetIntersectionBarycentricsKHR") == 0) {
  140. return valid ? "%f32vec2" : "%f64";
  141. }
  142. if (opcode.compare("OpRayQueryGetIntersectionObjectRayDirectionKHR") == 0 ||
  143. opcode.compare("OpRayQueryGetIntersectionObjectRayOriginKHR") == 0 ||
  144. opcode.compare("OpRayQueryGetWorldRayDirectionKHR") == 0 ||
  145. opcode.compare("OpRayQueryGetWorldRayOriginKHR") == 0) {
  146. return valid ? "%f32vec3" : "%f64";
  147. }
  148. if (opcode.compare("OpRayQueryProceedKHR") == 0 ||
  149. opcode.compare("OpRayQueryGetIntersectionFrontFaceKHR") == 0 ||
  150. opcode.compare("OpRayQueryGetIntersectionCandidateAABBOpaqueKHR") == 0) {
  151. return valid ? "%bool" : "%f64";
  152. }
  153. if (opcode.compare("OpRayQueryGetIntersectionObjectToWorldKHR") == 0 ||
  154. opcode.compare("OpRayQueryGetIntersectionWorldToObjectKHR") == 0) {
  155. return valid ? "%mat4x3" : "%f64";
  156. }
  157. return "";
  158. }
  159. std::string RayQueryIntersection(std::string opcode, bool valid) {
  160. if (opcode.compare("OpRayQueryGetIntersectionTypeKHR") == 0 ||
  161. opcode.compare("OpRayQueryGetIntersectionTKHR") == 0 ||
  162. opcode.compare("OpRayQueryGetIntersectionInstanceCustomIndexKHR") == 0 ||
  163. opcode.compare("OpRayQueryGetIntersectionInstanceIdKHR") == 0 ||
  164. opcode.compare("OpRayQueryGetIntersectionInstanceShaderBindingTableRecord"
  165. "OffsetKHR") == 0 ||
  166. opcode.compare("OpRayQueryGetIntersectionGeometryIndexKHR") == 0 ||
  167. opcode.compare("OpRayQueryGetIntersectionPrimitiveIndexKHR") == 0 ||
  168. opcode.compare("OpRayQueryGetIntersectionBarycentricsKHR") == 0 ||
  169. opcode.compare("OpRayQueryGetIntersectionFrontFaceKHR") == 0 ||
  170. opcode.compare("OpRayQueryGetIntersectionObjectRayDirectionKHR") == 0 ||
  171. opcode.compare("OpRayQueryGetIntersectionObjectRayOriginKHR") == 0 ||
  172. opcode.compare("OpRayQueryGetIntersectionObjectToWorldKHR") == 0 ||
  173. opcode.compare("OpRayQueryGetIntersectionWorldToObjectKHR") == 0) {
  174. return valid ? "%s32_0" : "%f32_0";
  175. }
  176. return "";
  177. }
// Value-parameterized fixture: each test instance receives one ray-query
// opcode name (e.g. "OpRayQueryProceedKHR") through GetParam().
using RayQueryCommon = spvtest::ValidateBase<std::string>;
  179. TEST_P(RayQueryCommon, Success) {
  180. std::string opcode = GetParam();
  181. std::ostringstream ss;
  182. ss << RayQueryResult(opcode);
  183. ss << " " << opcode << " ";
  184. ss << RayQueryResultType(opcode, true);
  185. ss << " %ray_query ";
  186. ss << RayQueryIntersection(opcode, true);
  187. CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
  188. EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
  189. }
  190. TEST_P(RayQueryCommon, BadQuery) {
  191. std::string opcode = GetParam();
  192. std::ostringstream ss;
  193. ss << RayQueryResult(opcode);
  194. ss << " " << opcode << " ";
  195. ss << RayQueryResultType(opcode, true);
  196. ss << " %top_level_as ";
  197. ss << RayQueryIntersection(opcode, true);
  198. CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
  199. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  200. EXPECT_THAT(getDiagnosticString(),
  201. HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR"));
  202. }
  203. TEST_P(RayQueryCommon, BadResult) {
  204. std::string opcode = GetParam();
  205. std::string result_type = RayQueryResultType(opcode, false);
  206. if (!result_type.empty()) {
  207. std::ostringstream ss;
  208. ss << RayQueryResult(opcode);
  209. ss << " " << opcode << " ";
  210. ss << result_type;
  211. ss << " %ray_query ";
  212. ss << RayQueryIntersection(opcode, true);
  213. CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
  214. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  215. std::string correct_result_type = RayQueryResultType(opcode, true);
  216. if (correct_result_type.compare("%u32") == 0) {
  217. EXPECT_THAT(
  218. getDiagnosticString(),
  219. HasSubstr("expected Result Type to be 32-bit int scalar type"));
  220. } else if (correct_result_type.compare("%f32") == 0) {
  221. EXPECT_THAT(
  222. getDiagnosticString(),
  223. HasSubstr("expected Result Type to be 32-bit float scalar type"));
  224. } else if (correct_result_type.compare("%f32vec2") == 0) {
  225. EXPECT_THAT(getDiagnosticString(),
  226. HasSubstr("expected Result Type to be 32-bit float "
  227. "2-component vector type"));
  228. } else if (correct_result_type.compare("%f32vec3") == 0) {
  229. EXPECT_THAT(getDiagnosticString(),
  230. HasSubstr("expected Result Type to be 32-bit float "
  231. "3-component vector type"));
  232. } else if (correct_result_type.compare("%bool") == 0) {
  233. EXPECT_THAT(getDiagnosticString(),
  234. HasSubstr("expected Result Type to be bool scalar type"));
  235. } else if (correct_result_type.compare("%mat4x3") == 0) {
  236. EXPECT_THAT(getDiagnosticString(),
  237. HasSubstr("expected matrix type as Result Type"));
  238. }
  239. }
  240. }
  241. TEST_P(RayQueryCommon, BadIntersection) {
  242. std::string opcode = GetParam();
  243. std::string intersection = RayQueryIntersection(opcode, false);
  244. if (!intersection.empty()) {
  245. std::ostringstream ss;
  246. ss << RayQueryResult(opcode);
  247. ss << " " << opcode << " ";
  248. ss << RayQueryResultType(opcode, true);
  249. ss << " %ray_query ";
  250. ss << intersection;
  251. CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
  252. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  253. EXPECT_THAT(
  254. getDiagnosticString(),
  255. HasSubstr(
  256. "expected Intersection ID to be a constant 32-bit int scalar"));
  257. }
  258. }
// Runs every RayQueryCommon test once per ray-query opcode below. Opcodes that
// lack a result (Terminate, ConfirmIntersection) or an Intersection operand
// are handled by the helpers returning "" for them. The argument order
// determines the generated instance indices (the test-name suffixes).
INSTANTIATE_TEST_SUITE_P(
    ValidateRayQueryCommon, RayQueryCommon,
    Values("OpRayQueryTerminateKHR", "OpRayQueryConfirmIntersectionKHR",
           "OpRayQueryProceedKHR", "OpRayQueryGetIntersectionTypeKHR",
           "OpRayQueryGetRayTMinKHR", "OpRayQueryGetRayFlagsKHR",
           "OpRayQueryGetWorldRayDirectionKHR",
           "OpRayQueryGetWorldRayOriginKHR", "OpRayQueryGetIntersectionTKHR",
           "OpRayQueryGetIntersectionInstanceCustomIndexKHR",
           "OpRayQueryGetIntersectionInstanceIdKHR",
           "OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR",
           "OpRayQueryGetIntersectionGeometryIndexKHR",
           "OpRayQueryGetIntersectionPrimitiveIndexKHR",
           "OpRayQueryGetIntersectionBarycentricsKHR",
           "OpRayQueryGetIntersectionFrontFaceKHR",
           "OpRayQueryGetIntersectionCandidateAABBOpaqueKHR",
           "OpRayQueryGetIntersectionObjectRayDirectionKHR",
           "OpRayQueryGetIntersectionObjectRayOriginKHR",
           "OpRayQueryGetIntersectionObjectToWorldKHR",
           "OpRayQueryGetIntersectionWorldToObjectKHR"));
// Tests which operand kinds are accepted or rejected as the Intersection ID.
  279. TEST_F(ValidateRayQuery, IntersectionSuccess) {
  280. const std::string body = R"(
  281. %result_1 = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %s32_0
  282. %result_2 = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %u32_0
  283. )";
  284. CompileSuccessfully(GenerateShaderCode(body).c_str());
  285. EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
  286. }
  287. TEST_F(ValidateRayQuery, IntersectionVector) {
  288. const std::string body = R"(
  289. %result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %u32vec3_0
  290. )";
  291. CompileSuccessfully(GenerateShaderCode(body).c_str());
  292. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  293. EXPECT_THAT(
  294. getDiagnosticString(),
  295. HasSubstr("expected Intersection ID to be a constant 32-bit int scalar"));
  296. }
  297. TEST_F(ValidateRayQuery, IntersectionNonConstantVariable) {
  298. const std::string body = R"(
  299. %var = OpVariable %ptr_function_u32 Function
  300. %result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %var
  301. )";
  302. CompileSuccessfully(GenerateShaderCode(body).c_str());
  303. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  304. EXPECT_THAT(
  305. getDiagnosticString(),
  306. HasSubstr("expected Intersection ID to be a constant 32-bit int scalar"));
  307. }
  308. TEST_F(ValidateRayQuery, IntersectionNonConstantLoad) {
  309. const std::string body = R"(
  310. %var = OpVariable %ptr_function_u32 Function
  311. %load = OpLoad %u32 %var
  312. %result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %load
  313. )";
  314. CompileSuccessfully(GenerateShaderCode(body).c_str());
  315. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  316. EXPECT_THAT(
  317. getDiagnosticString(),
  318. HasSubstr("expected Intersection ID to be a constant 32-bit int scalar"));
  319. }
  320. TEST_F(ValidateRayQuery, InitializeSuccess) {
  321. const std::string body = R"(
  322. %var_u32 = OpVariable %ptr_function_u32 Function
  323. %var_f32 = OpVariable %ptr_function_f32 Function
  324. %var_f32vec3 = OpVariable %ptr_function_f32vec3 Function
  325. %as = OpLoad %type_as %top_level_as
  326. OpRayQueryInitializeKHR %ray_query %as %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  327. %_u32 = OpLoad %u32 %var_u32
  328. %_f32 = OpLoad %f32 %var_f32
  329. %_f32vec3 = OpLoad %f32vec3 %var_f32vec3
  330. OpRayQueryInitializeKHR %ray_query %as %_u32 %_u32 %_f32vec3 %_f32 %_f32vec3 %_f32
  331. )";
  332. CompileSuccessfully(GenerateShaderCode(body).c_str());
  333. EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
  334. }
  335. TEST_F(ValidateRayQuery, InitializeFunctionSuccess) {
  336. const std::string declaration = R"(
  337. %rq_ptr = OpTypePointer Private %type_rq
  338. %rq_func_type = OpTypeFunction %void %rq_ptr
  339. %rq_var_1 = OpVariable %rq_ptr Private
  340. %rq_var_2 = OpVariable %rq_ptr Private
  341. )";
  342. const std::string body = R"(
  343. %fcall_1 = OpFunctionCall %void %rq_func %rq_var_1
  344. %as_1 = OpLoad %type_as %top_level_as
  345. OpRayQueryInitializeKHR %rq_var_1 %as_1 %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  346. %fcall_2 = OpFunctionCall %void %rq_func %rq_var_2
  347. OpReturn
  348. OpFunctionEnd
  349. %rq_func = OpFunction %void None %rq_func_type
  350. %rq_param = OpFunctionParameter %rq_ptr
  351. %label = OpLabel
  352. %as_2 = OpLoad %type_as %top_level_as
  353. OpRayQueryInitializeKHR %rq_param %as_2 %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  354. )";
  355. CompileSuccessfully(GenerateShaderCode(body, "", "", declaration).c_str());
  356. EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
  357. }
  358. TEST_F(ValidateRayQuery, InitializeBadRayQuery) {
  359. const std::string body = R"(
  360. %load = OpLoad %type_as %top_level_as
  361. OpRayQueryInitializeKHR %top_level_as %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  362. )";
  363. CompileSuccessfully(GenerateShaderCode(body).c_str());
  364. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  365. EXPECT_THAT(getDiagnosticString(),
  366. HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR"));
  367. }
  368. TEST_F(ValidateRayQuery, InitializeBadAS) {
  369. const std::string body = R"(
  370. OpRayQueryInitializeKHR %ray_query %ray_query %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  371. )";
  372. CompileSuccessfully(GenerateShaderCode(body).c_str());
  373. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  374. EXPECT_THAT(getDiagnosticString(),
  375. HasSubstr("Expected Acceleration Structure to be of type "
  376. "OpTypeAccelerationStructureKHR"));
  377. }
  378. TEST_F(ValidateRayQuery, InitializeBadRayFlags64) {
  379. const std::string body = R"(
  380. %load = OpLoad %type_as %top_level_as
  381. OpRayQueryInitializeKHR %ray_query %load %u64_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  382. )";
  383. CompileSuccessfully(GenerateShaderCode(body).c_str());
  384. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  385. EXPECT_THAT(getDiagnosticString(),
  386. HasSubstr("Ray Flags must be a 32-bit int scalar"));
  387. }
  388. TEST_F(ValidateRayQuery, InitializeBadRayFlagsVector) {
  389. const std::string body = R"(
  390. %load = OpLoad %type_as %top_level_as
  391. OpRayQueryInitializeKHR %ray_query %load %u32vec2 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  392. )";
  393. CompileSuccessfully(GenerateShaderCode(body).c_str());
  394. EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  395. EXPECT_THAT(getDiagnosticString(),
  396. HasSubstr("Operand '15[%v2uint]' cannot be a type"));
  397. }
  398. TEST_F(ValidateRayQuery, InitializeBadCullMask) {
  399. const std::string body = R"(
  400. %load = OpLoad %type_as %top_level_as
  401. OpRayQueryInitializeKHR %ray_query %load %u32_0 %f32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
  402. )";
  403. CompileSuccessfully(GenerateShaderCode(body).c_str());
  404. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  405. EXPECT_THAT(getDiagnosticString(),
  406. HasSubstr("Cull Mask must be a 32-bit int scalar"));
  407. }
  408. TEST_F(ValidateRayQuery, InitializeBadRayOriginVec4) {
  409. const std::string body = R"(
  410. %load = OpLoad %type_as %top_level_as
  411. OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec4_0 %f32_0 %f32vec3_0 %f32_0
  412. )";
  413. CompileSuccessfully(GenerateShaderCode(body).c_str());
  414. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  415. EXPECT_THAT(
  416. getDiagnosticString(),
  417. HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
  418. }
  419. TEST_F(ValidateRayQuery, InitializeBadRayOriginFloat) {
  420. const std::string body = R"(
  421. %var_f32 = OpVariable %ptr_function_f32 Function
  422. %_f32 = OpLoad %f32 %var_f32
  423. %load = OpLoad %type_as %top_level_as
  424. OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %_f32 %f32_0 %f32vec3_0 %f32_0
  425. )";
  426. CompileSuccessfully(GenerateShaderCode(body).c_str());
  427. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  428. EXPECT_THAT(
  429. getDiagnosticString(),
  430. HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
  431. }
  432. TEST_F(ValidateRayQuery, InitializeBadRayOriginInt) {
  433. const std::string body = R"(
  434. %load = OpLoad %type_as %top_level_as
  435. OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %u32vec3_0 %f32_0 %f32vec3_0 %f32_0
  436. )";
  437. CompileSuccessfully(GenerateShaderCode(body).c_str());
  438. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  439. EXPECT_THAT(
  440. getDiagnosticString(),
  441. HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
  442. }
  443. TEST_F(ValidateRayQuery, InitializeBadRayTMin) {
  444. const std::string body = R"(
  445. %load = OpLoad %type_as %top_level_as
  446. OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %u32_0 %f32vec3_0 %f32_0
  447. )";
  448. CompileSuccessfully(GenerateShaderCode(body).c_str());
  449. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  450. EXPECT_THAT(getDiagnosticString(),
  451. HasSubstr("Ray TMin must be a 32-bit float scalar"));
  452. }
  453. TEST_F(ValidateRayQuery, InitializeBadRayDirection) {
  454. const std::string body = R"(
  455. %load = OpLoad %type_as %top_level_as
  456. OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec4_0 %f32_0
  457. )";
  458. CompileSuccessfully(GenerateShaderCode(body).c_str());
  459. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  460. EXPECT_THAT(
  461. getDiagnosticString(),
  462. HasSubstr("Ray Direction must be a 32-bit float 3-component vector"));
  463. }
  464. TEST_F(ValidateRayQuery, InitializeBadRayTMax) {
  465. const std::string body = R"(
  466. %load = OpLoad %type_as %top_level_as
  467. OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f64_0
  468. )";
  469. CompileSuccessfully(GenerateShaderCode(body).c_str());
  470. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  471. EXPECT_THAT(getDiagnosticString(),
  472. HasSubstr("Ray TMax must be a 32-bit float scalar"));
  473. }
  474. TEST_F(ValidateRayQuery, GenerateIntersectionSuccess) {
  475. const std::string body = R"(
  476. %var = OpVariable %ptr_function_f32 Function
  477. %load = OpLoad %f32 %var
  478. OpRayQueryGenerateIntersectionKHR %ray_query %f32_0
  479. OpRayQueryGenerateIntersectionKHR %ray_query %load
  480. )";
  481. CompileSuccessfully(GenerateShaderCode(body).c_str());
  482. EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
  483. }
  484. TEST_F(ValidateRayQuery, GenerateIntersectionBadRayQuery) {
  485. const std::string body = R"(
  486. OpRayQueryGenerateIntersectionKHR %top_level_as %f32_0
  487. )";
  488. CompileSuccessfully(GenerateShaderCode(body).c_str());
  489. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  490. EXPECT_THAT(getDiagnosticString(),
  491. HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR"));
  492. }
  493. TEST_F(ValidateRayQuery, GenerateIntersectionBadHitT) {
  494. const std::string body = R"(
  495. OpRayQueryGenerateIntersectionKHR %ray_query %u32_0
  496. )";
  497. CompileSuccessfully(GenerateShaderCode(body).c_str());
  498. EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  499. EXPECT_THAT(getDiagnosticString(),
  500. HasSubstr("Hit T must be a 32-bit float scalar"));
  501. }
  502. TEST_F(ValidateRayQuery, RayQueryArraySuccess) {
  503. // This shader is slightly different to the ones above, so it doesn't reuse
  504. // the shader code generator.
  505. const std::string shader = R"(
  506. OpCapability Shader
  507. OpCapability RayQueryKHR
  508. OpExtension "SPV_KHR_ray_query"
  509. OpMemoryModel Logical GLSL450
  510. OpEntryPoint GLCompute %main "main"
  511. OpExecutionMode %main LocalSize 1 1 1
  512. OpSource GLSL 460
  513. OpDecorate %topLevelAS DescriptorSet 0
  514. OpDecorate %topLevelAS Binding 0
  515. OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
  516. %void = OpTypeVoid
  517. %func = OpTypeFunction %void
  518. %ray_query = OpTypeRayQueryKHR
  519. %uint = OpTypeInt 32 0
  520. %uint_2 = OpConstant %uint 2
  521. %ray_query_array = OpTypeArray %ray_query %uint_2
  522. %ptr_ray_query_array = OpTypePointer Private %ray_query_array
  523. %rayQueries = OpVariable %ptr_ray_query_array Private
  524. %int = OpTypeInt 32 1
  525. %int_0 = OpConstant %int 0
  526. %ptr_ray_query = OpTypePointer Private %ray_query
  527. %accel_struct = OpTypeAccelerationStructureKHR
  528. %ptr_accel_struct = OpTypePointer UniformConstant %accel_struct
  529. %topLevelAS = OpVariable %ptr_accel_struct UniformConstant
  530. %uint_0 = OpConstant %uint 0
  531. %uint_255 = OpConstant %uint 255
  532. %float = OpTypeFloat 32
  533. %v3float = OpTypeVector %float 3
  534. %float_0 = OpConstant %float 0
  535. %vec3_zero = OpConstantComposite %v3float %float_0 %float_0 %float_0
  536. %float_1 = OpConstant %float 1
  537. %vec3_xy_0_z_1 = OpConstantComposite %v3float %float_0 %float_0 %float_1
  538. %float_10 = OpConstant %float 10
  539. %v3uint = OpTypeVector %uint 3
  540. %uint_1 = OpConstant %uint 1
  541. %gl_WorkGroupSize = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_1
  542. %main = OpFunction %void None %func
  543. %main_label = OpLabel
  544. %first_ray_query = OpAccessChain %ptr_ray_query %rayQueries %int_0
  545. %topLevelAS_val = OpLoad %accel_struct %topLevelAS
  546. OpRayQueryInitializeKHR %first_ray_query %topLevelAS_val %uint_0 %uint_255 %vec3_zero %float_0 %vec3_xy_0_z_1 %float_10
  547. OpReturn
  548. OpFunctionEnd
  549. )";
  550. CompileSuccessfully(shader);
  551. EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
  552. }
  553. TEST_F(ValidateRayQuery, ClusterASNV) {
  554. const std::string cap = R"(
  555. OpCapability RayTracingClusterAccelerationStructureNV
  556. )";
  557. const std::string ext = R"(
  558. OpExtension "SPV_NV_cluster_acceleration_structure"
  559. )";
  560. const std::string body = R"(
  561. %clusterid = OpRayQueryGetClusterIdNV %s32 %ray_query %s32_0
  562. )";
  563. CompileSuccessfully(GenerateShaderCode(body, cap, ext).c_str(),
  564. SPV_ENV_VULKAN_1_2);
  565. EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
  566. }
// Value-parameterized fixture for the SPV_NV_linear_swept_spheres ray-query
// opcodes; GetParam() yields the opcode name.
using RayQueryLSSNVCommon = spvtest::ValidateBase<std::string>;
  568. std::string RayQueryLSSNVResultType(std::string opcode, bool valid) {
  569. if (opcode.compare("OpRayQueryGetIntersectionLSSPositionsNV") == 0)
  570. return valid ? "%arr2v3" : "%f64";
  571. if (opcode.compare("OpRayQueryGetIntersectionLSSRadiiNV") == 0)
  572. return valid ? "%arr2f3" : "%f64";
  573. if (opcode.compare("OpRayQueryGetIntersectionSphereRadiusNV") == 0 ||
  574. opcode.compare("OpRayQueryGetIntersectionLSSHitValueNV") == 0) {
  575. return valid ? "%f32" : "%f64";
  576. }
  577. if (opcode.compare("OpRayQueryGetIntersectionSpherePositionNV") == 0) {
  578. return valid ? "%f32vec3" : "%f64";
  579. }
  580. if (opcode.compare("OpRayQueryIsSphereHitNV") == 0 ||
  581. opcode.compare("OpRayQueryIsLSSHitNV") == 0) {
  582. return valid ? "%bool" : "%f64";
  583. }
  584. return "";
  585. }
  586. TEST_P(RayQueryLSSNVCommon, Success) {
  587. const std::string cap = R"(
  588. OpCapability RayTracingSpheresGeometryNV
  589. OpCapability RayTracingLinearSweptSpheresGeometryNV
  590. )";
  591. const std::string ext = R"(
  592. OpExtension "SPV_NV_linear_swept_spheres"
  593. )";
  594. std::string opcode = GetParam();
  595. std::ostringstream ss;
  596. ss << "%result = ";
  597. ss << " " << opcode << " ";
  598. ss << RayQueryLSSNVResultType(opcode, true);
  599. ss << " %ray_query ";
  600. ss << " %s32_0 ";
  601. CompileSuccessfully(GenerateShaderCode(ss.str(), cap, ext).c_str(),
  602. SPV_ENV_VULKAN_1_2);
  603. EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
  604. }
// Runs RayQueryLSSNVCommon once per sphere/LSS query opcode. Every opcode
// listed here must have an entry in RayQueryLSSNVResultType. The argument
// order determines the generated instance indices.
INSTANTIATE_TEST_SUITE_P(ValidateRayQueryLSSNVCommon, RayQueryLSSNVCommon,
                         Values("OpRayQueryGetIntersectionSpherePositionNV",
                                "OpRayQueryGetIntersectionLSSPositionsNV",
                                "OpRayQueryGetIntersectionSphereRadiusNV",
                                "OpRayQueryGetIntersectionLSSRadiiNV",
                                "OpRayQueryGetIntersectionLSSHitValueNV",
                                "OpRayQueryIsSphereHitNV",
                                "OpRayQueryIsLSSHitNV"));
  613. } // namespace
  614. } // namespace val
  615. } // namespace spvtools