validate_ray_query.cpp 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. // Copyright (c) 2022 The Khronos Group Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Validates ray query instructions from SPV_KHR_ray_query
  15. #include "source/opcode.h"
  16. #include "source/val/instruction.h"
  17. #include "source/val/validate.h"
  18. #include "source/val/validation_state.h"
  19. namespace spvtools {
  20. namespace val {
  21. namespace {
  22. spv_result_t ValidateRayQueryPointer(ValidationState_t& _,
  23. const Instruction* inst,
  24. uint32_t ray_query_index) {
  25. const uint32_t ray_query_id = inst->GetOperandAs<uint32_t>(ray_query_index);
  26. auto variable = _.FindDef(ray_query_id);
  27. if (!variable || (variable->opcode() != SpvOpVariable &&
  28. variable->opcode() != SpvOpFunctionParameter)) {
  29. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  30. << "Ray Query must be a memory object declaration";
  31. }
  32. auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
  33. if (!pointer || pointer->opcode() != SpvOpTypePointer) {
  34. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  35. << "Ray Query must be a pointer";
  36. }
  37. auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
  38. if (!type || type->opcode() != SpvOpTypeRayQueryKHR) {
  39. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  40. << "Ray Query must be a pointer to OpTypeRayQueryKHR";
  41. }
  42. return SPV_SUCCESS;
  43. }
  44. spv_result_t ValidateIntersectionId(ValidationState_t& _,
  45. const Instruction* inst,
  46. uint32_t intersection_index) {
  47. const uint32_t intersection_id =
  48. inst->GetOperandAs<uint32_t>(intersection_index);
  49. const uint32_t intersection_type = _.GetTypeId(intersection_id);
  50. const SpvOp intersection_opcode = _.GetIdOpcode(intersection_id);
  51. if (!_.IsIntScalarType(intersection_type) ||
  52. _.GetBitWidth(intersection_type) != 32 ||
  53. !spvOpcodeIsConstant(intersection_opcode)) {
  54. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  55. << "expected Intersection ID to be a constant 32-bit int scalar";
  56. }
  57. return SPV_SUCCESS;
  58. }
  59. } // namespace
  60. spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
  61. const SpvOp opcode = inst->opcode();
  62. const uint32_t result_type = inst->type_id();
  63. switch (opcode) {
  64. case SpvOpRayQueryInitializeKHR: {
  65. if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
  66. if (_.GetIdOpcode(_.GetOperandTypeId(inst, 1)) !=
  67. SpvOpTypeAccelerationStructureKHR) {
  68. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  69. << "Expected Acceleration Structure to be of type "
  70. "OpTypeAccelerationStructureKHR";
  71. }
  72. const uint32_t ray_flags = _.GetOperandTypeId(inst, 2);
  73. if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
  74. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  75. << "Ray Flags must be a 32-bit int scalar";
  76. }
  77. const uint32_t cull_mask = _.GetOperandTypeId(inst, 3);
  78. if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
  79. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  80. << "Cull Mask must be a 32-bit int scalar";
  81. }
  82. const uint32_t ray_origin = _.GetOperandTypeId(inst, 4);
  83. if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
  84. _.GetBitWidth(ray_origin) != 32) {
  85. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  86. << "Ray Origin must be a 32-bit float 3-component vector";
  87. }
  88. const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5);
  89. if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
  90. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  91. << "Ray TMin must be a 32-bit float scalar";
  92. }
  93. const uint32_t ray_direction = _.GetOperandTypeId(inst, 6);
  94. if (!_.IsFloatVectorType(ray_direction) ||
  95. _.GetDimension(ray_direction) != 3 ||
  96. _.GetBitWidth(ray_direction) != 32) {
  97. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  98. << "Ray Direction must be a 32-bit float 3-component vector";
  99. }
  100. const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7);
  101. if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
  102. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  103. << "Ray TMax must be a 32-bit float scalar";
  104. }
  105. break;
  106. }
  107. case SpvOpRayQueryTerminateKHR:
  108. case SpvOpRayQueryConfirmIntersectionKHR: {
  109. if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
  110. break;
  111. }
  112. case SpvOpRayQueryGenerateIntersectionKHR: {
  113. if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
  114. const uint32_t hit_t_id = _.GetOperandTypeId(inst, 1);
  115. if (!_.IsFloatScalarType(hit_t_id) || _.GetBitWidth(hit_t_id) != 32) {
  116. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  117. << "Hit T must be a 32-bit float scalar";
  118. }
  119. break;
  120. }
  121. case SpvOpRayQueryGetIntersectionFrontFaceKHR:
  122. case SpvOpRayQueryProceedKHR:
  123. case SpvOpRayQueryGetIntersectionCandidateAABBOpaqueKHR: {
  124. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  125. if (!_.IsBoolScalarType(result_type)) {
  126. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  127. << "expected Result Type to be bool scalar type";
  128. }
  129. if (opcode == SpvOpRayQueryGetIntersectionFrontFaceKHR) {
  130. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  131. }
  132. break;
  133. }
  134. case SpvOpRayQueryGetIntersectionTKHR:
  135. case SpvOpRayQueryGetRayTMinKHR: {
  136. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  137. if (!_.IsFloatScalarType(result_type) ||
  138. _.GetBitWidth(result_type) != 32) {
  139. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  140. << "expected Result Type to be 32-bit float scalar type";
  141. }
  142. if (opcode == SpvOpRayQueryGetIntersectionTKHR) {
  143. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  144. }
  145. break;
  146. }
  147. case SpvOpRayQueryGetIntersectionTypeKHR:
  148. case SpvOpRayQueryGetIntersectionInstanceCustomIndexKHR:
  149. case SpvOpRayQueryGetIntersectionInstanceIdKHR:
  150. case SpvOpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
  151. case SpvOpRayQueryGetIntersectionGeometryIndexKHR:
  152. case SpvOpRayQueryGetIntersectionPrimitiveIndexKHR:
  153. case SpvOpRayQueryGetRayFlagsKHR: {
  154. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  155. if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
  156. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  157. << "expected Result Type to be 32-bit int scalar type";
  158. }
  159. if (opcode != SpvOpRayQueryGetRayFlagsKHR) {
  160. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  161. }
  162. break;
  163. }
  164. case SpvOpRayQueryGetIntersectionObjectRayDirectionKHR:
  165. case SpvOpRayQueryGetIntersectionObjectRayOriginKHR:
  166. case SpvOpRayQueryGetWorldRayDirectionKHR:
  167. case SpvOpRayQueryGetWorldRayOriginKHR: {
  168. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  169. if (!_.IsFloatVectorType(result_type) ||
  170. _.GetDimension(result_type) != 3 ||
  171. _.GetBitWidth(result_type) != 32) {
  172. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  173. << "expected Result Type to be 32-bit float 3-component "
  174. "vector type";
  175. }
  176. if (opcode == SpvOpRayQueryGetIntersectionObjectRayDirectionKHR ||
  177. opcode == SpvOpRayQueryGetIntersectionObjectRayOriginKHR) {
  178. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  179. }
  180. break;
  181. }
  182. case SpvOpRayQueryGetIntersectionBarycentricsKHR: {
  183. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  184. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  185. if (!_.IsFloatVectorType(result_type) ||
  186. _.GetDimension(result_type) != 2 ||
  187. _.GetBitWidth(result_type) != 32) {
  188. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  189. << "expected Result Type to be 32-bit float 2-component "
  190. "vector type";
  191. }
  192. break;
  193. }
  194. case SpvOpRayQueryGetIntersectionObjectToWorldKHR:
  195. case SpvOpRayQueryGetIntersectionWorldToObjectKHR: {
  196. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  197. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  198. uint32_t num_rows = 0;
  199. uint32_t num_cols = 0;
  200. uint32_t col_type = 0;
  201. uint32_t component_type = 0;
  202. if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
  203. &component_type)) {
  204. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  205. << "expected matrix type as Result Type";
  206. }
  207. if (num_cols != 4) {
  208. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  209. << "expected Result Type matrix to have a Column Count of 4";
  210. }
  211. if (!_.IsFloatScalarType(component_type) ||
  212. _.GetBitWidth(result_type) != 32 || num_rows != 3) {
  213. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  214. << "expected Result Type matrix to have a Column Type of "
  215. "3-component 32-bit float vectors";
  216. }
  217. break;
  218. }
  219. default:
  220. break;
  221. }
  222. return SPV_SUCCESS;
  223. }
  224. } // namespace val
  225. } // namespace spvtools