validate_ray_query.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. // Copyright (c) 2022 The Khronos Group Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Validates ray query instructions from SPV_KHR_ray_query
  15. #include "source/opcode.h"
  16. #include "source/val/instruction.h"
  17. #include "source/val/validate.h"
  18. #include "source/val/validation_state.h"
  19. namespace spvtools {
  20. namespace val {
  21. namespace {
  22. spv_result_t ValidateRayQueryPointer(ValidationState_t& _,
  23. const Instruction* inst,
  24. uint32_t ray_query_index) {
  25. const uint32_t ray_query_id = inst->GetOperandAs<uint32_t>(ray_query_index);
  26. auto variable = _.FindDef(ray_query_id);
  27. const auto var_opcode = variable->opcode();
  28. if (!variable || (var_opcode != spv::Op::OpVariable &&
  29. var_opcode != spv::Op::OpFunctionParameter &&
  30. var_opcode != spv::Op::OpAccessChain)) {
  31. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  32. << "Ray Query must be a memory object declaration";
  33. }
  34. auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
  35. if (!pointer || pointer->opcode() != spv::Op::OpTypePointer) {
  36. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  37. << "Ray Query must be a pointer";
  38. }
  39. auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
  40. if (!type || type->opcode() != spv::Op::OpTypeRayQueryKHR) {
  41. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  42. << "Ray Query must be a pointer to OpTypeRayQueryKHR";
  43. }
  44. return SPV_SUCCESS;
  45. }
  46. spv_result_t ValidateIntersectionId(ValidationState_t& _,
  47. const Instruction* inst,
  48. uint32_t intersection_index) {
  49. const uint32_t intersection_id =
  50. inst->GetOperandAs<uint32_t>(intersection_index);
  51. const uint32_t intersection_type = _.GetTypeId(intersection_id);
  52. const spv::Op intersection_opcode = _.GetIdOpcode(intersection_id);
  53. if (!_.IsIntScalarType(intersection_type) ||
  54. _.GetBitWidth(intersection_type) != 32 ||
  55. !spvOpcodeIsConstant(intersection_opcode)) {
  56. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  57. << "expected Intersection ID to be a constant 32-bit int scalar";
  58. }
  59. return SPV_SUCCESS;
  60. }
  61. } // namespace
  62. spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
  63. const spv::Op opcode = inst->opcode();
  64. const uint32_t result_type = inst->type_id();
  65. switch (opcode) {
  66. case spv::Op::OpRayQueryInitializeKHR: {
  67. if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
  68. if (_.GetIdOpcode(_.GetOperandTypeId(inst, 1)) !=
  69. spv::Op::OpTypeAccelerationStructureKHR) {
  70. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  71. << "Expected Acceleration Structure to be of type "
  72. "OpTypeAccelerationStructureKHR";
  73. }
  74. const uint32_t ray_flags = _.GetOperandTypeId(inst, 2);
  75. if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
  76. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  77. << "Ray Flags must be a 32-bit int scalar";
  78. }
  79. const uint32_t cull_mask = _.GetOperandTypeId(inst, 3);
  80. if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
  81. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  82. << "Cull Mask must be a 32-bit int scalar";
  83. }
  84. const uint32_t ray_origin = _.GetOperandTypeId(inst, 4);
  85. if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
  86. _.GetBitWidth(ray_origin) != 32) {
  87. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  88. << "Ray Origin must be a 32-bit float 3-component vector";
  89. }
  90. const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5);
  91. if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
  92. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  93. << "Ray TMin must be a 32-bit float scalar";
  94. }
  95. const uint32_t ray_direction = _.GetOperandTypeId(inst, 6);
  96. if (!_.IsFloatVectorType(ray_direction) ||
  97. _.GetDimension(ray_direction) != 3 ||
  98. _.GetBitWidth(ray_direction) != 32) {
  99. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  100. << "Ray Direction must be a 32-bit float 3-component vector";
  101. }
  102. const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7);
  103. if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
  104. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  105. << "Ray TMax must be a 32-bit float scalar";
  106. }
  107. break;
  108. }
  109. case spv::Op::OpRayQueryTerminateKHR:
  110. case spv::Op::OpRayQueryConfirmIntersectionKHR: {
  111. if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
  112. break;
  113. }
  114. case spv::Op::OpRayQueryGenerateIntersectionKHR: {
  115. if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
  116. const uint32_t hit_t_id = _.GetOperandTypeId(inst, 1);
  117. if (!_.IsFloatScalarType(hit_t_id) || _.GetBitWidth(hit_t_id) != 32) {
  118. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  119. << "Hit T must be a 32-bit float scalar";
  120. }
  121. break;
  122. }
  123. case spv::Op::OpRayQueryGetIntersectionFrontFaceKHR:
  124. case spv::Op::OpRayQueryProceedKHR:
  125. case spv::Op::OpRayQueryGetIntersectionCandidateAABBOpaqueKHR: {
  126. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  127. if (!_.IsBoolScalarType(result_type)) {
  128. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  129. << "expected Result Type to be bool scalar type";
  130. }
  131. if (opcode == spv::Op::OpRayQueryGetIntersectionFrontFaceKHR) {
  132. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  133. }
  134. break;
  135. }
  136. case spv::Op::OpRayQueryGetIntersectionTKHR:
  137. case spv::Op::OpRayQueryGetRayTMinKHR: {
  138. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  139. if (!_.IsFloatScalarType(result_type) ||
  140. _.GetBitWidth(result_type) != 32) {
  141. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  142. << "expected Result Type to be 32-bit float scalar type";
  143. }
  144. if (opcode == spv::Op::OpRayQueryGetIntersectionTKHR) {
  145. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  146. }
  147. break;
  148. }
  149. case spv::Op::OpRayQueryGetIntersectionTypeKHR:
  150. case spv::Op::OpRayQueryGetIntersectionInstanceCustomIndexKHR:
  151. case spv::Op::OpRayQueryGetIntersectionInstanceIdKHR:
  152. case spv::Op::
  153. OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
  154. case spv::Op::OpRayQueryGetIntersectionGeometryIndexKHR:
  155. case spv::Op::OpRayQueryGetIntersectionPrimitiveIndexKHR:
  156. case spv::Op::OpRayQueryGetRayFlagsKHR: {
  157. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  158. if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
  159. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  160. << "expected Result Type to be 32-bit int scalar type";
  161. }
  162. if (opcode != spv::Op::OpRayQueryGetRayFlagsKHR) {
  163. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  164. }
  165. break;
  166. }
  167. case spv::Op::OpRayQueryGetIntersectionObjectRayDirectionKHR:
  168. case spv::Op::OpRayQueryGetIntersectionObjectRayOriginKHR:
  169. case spv::Op::OpRayQueryGetWorldRayDirectionKHR:
  170. case spv::Op::OpRayQueryGetWorldRayOriginKHR: {
  171. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  172. if (!_.IsFloatVectorType(result_type) ||
  173. _.GetDimension(result_type) != 3 ||
  174. _.GetBitWidth(result_type) != 32) {
  175. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  176. << "expected Result Type to be 32-bit float 3-component "
  177. "vector type";
  178. }
  179. if (opcode == spv::Op::OpRayQueryGetIntersectionObjectRayDirectionKHR ||
  180. opcode == spv::Op::OpRayQueryGetIntersectionObjectRayOriginKHR) {
  181. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  182. }
  183. break;
  184. }
  185. case spv::Op::OpRayQueryGetIntersectionBarycentricsKHR: {
  186. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  187. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  188. if (!_.IsFloatVectorType(result_type) ||
  189. _.GetDimension(result_type) != 2 ||
  190. _.GetBitWidth(result_type) != 32) {
  191. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  192. << "expected Result Type to be 32-bit float 2-component "
  193. "vector type";
  194. }
  195. break;
  196. }
  197. case spv::Op::OpRayQueryGetIntersectionObjectToWorldKHR:
  198. case spv::Op::OpRayQueryGetIntersectionWorldToObjectKHR: {
  199. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  200. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  201. uint32_t num_rows = 0;
  202. uint32_t num_cols = 0;
  203. uint32_t col_type = 0;
  204. uint32_t component_type = 0;
  205. if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
  206. &component_type)) {
  207. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  208. << "expected matrix type as Result Type";
  209. }
  210. if (num_cols != 4) {
  211. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  212. << "expected Result Type matrix to have a Column Count of 4";
  213. }
  214. if (!_.IsFloatScalarType(component_type) ||
  215. _.GetBitWidth(result_type) != 32 || num_rows != 3) {
  216. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  217. << "expected Result Type matrix to have a Column Type of "
  218. "3-component 32-bit float vectors";
  219. }
  220. break;
  221. }
  222. default:
  223. break;
  224. }
  225. return SPV_SUCCESS;
  226. }
  227. } // namespace val
  228. } // namespace spvtools