validate_ray_query.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. // Copyright (c) 2022 The Khronos Group Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
// Validates ray query instructions from SPV_KHR_ray_query, as well as the NV
// ray query instructions for cluster IDs, spheres, and linear swept spheres.
  15. #include "source/opcode.h"
  16. #include "source/val/instruction.h"
  17. #include "source/val/validate.h"
  18. #include "source/val/validation_state.h"
  19. namespace spvtools {
  20. namespace val {
  21. namespace {
  22. uint32_t GetArrayLength(ValidationState_t& _, const Instruction* array_type) {
  23. assert(array_type->opcode() == spv::Op::OpTypeArray);
  24. uint32_t const_int_id = array_type->GetOperandAs<uint32_t>(2U);
  25. Instruction* array_length_inst = _.FindDef(const_int_id);
  26. uint32_t array_length = 0;
  27. if (array_length_inst->opcode() == spv::Op::OpConstant) {
  28. array_length = array_length_inst->GetOperandAs<uint32_t>(2);
  29. }
  30. return array_length;
  31. }
  32. spv_result_t ValidateRayQueryPointer(ValidationState_t& _,
  33. const Instruction* inst,
  34. uint32_t ray_query_index) {
  35. const uint32_t ray_query_id = inst->GetOperandAs<uint32_t>(ray_query_index);
  36. auto variable = _.FindDef(ray_query_id);
  37. const auto var_opcode = variable->opcode();
  38. if (!variable || (var_opcode != spv::Op::OpVariable &&
  39. var_opcode != spv::Op::OpFunctionParameter &&
  40. var_opcode != spv::Op::OpAccessChain)) {
  41. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  42. << "Ray Query must be a memory object declaration";
  43. }
  44. auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
  45. if (!pointer || pointer->opcode() != spv::Op::OpTypePointer) {
  46. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  47. << "Ray Query must be a pointer";
  48. }
  49. auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
  50. if (!type || type->opcode() != spv::Op::OpTypeRayQueryKHR) {
  51. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  52. << "Ray Query must be a pointer to OpTypeRayQueryKHR";
  53. }
  54. return SPV_SUCCESS;
  55. }
  56. spv_result_t ValidateIntersectionId(ValidationState_t& _,
  57. const Instruction* inst,
  58. uint32_t intersection_index) {
  59. const uint32_t intersection_id =
  60. inst->GetOperandAs<uint32_t>(intersection_index);
  61. const uint32_t intersection_type = _.GetTypeId(intersection_id);
  62. const spv::Op intersection_opcode = _.GetIdOpcode(intersection_id);
  63. if (!_.IsIntScalarType(intersection_type) ||
  64. _.GetBitWidth(intersection_type) != 32 ||
  65. !spvOpcodeIsConstant(intersection_opcode)) {
  66. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  67. << "expected Intersection ID to be a constant 32-bit int scalar";
  68. }
  69. return SPV_SUCCESS;
  70. }
  71. } // namespace
  72. spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
  73. const spv::Op opcode = inst->opcode();
  74. const uint32_t result_type = inst->type_id();
  75. switch (opcode) {
  76. case spv::Op::OpRayQueryInitializeKHR: {
  77. if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
  78. if (_.GetIdOpcode(_.GetOperandTypeId(inst, 1)) !=
  79. spv::Op::OpTypeAccelerationStructureKHR) {
  80. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  81. << "Expected Acceleration Structure to be of type "
  82. "OpTypeAccelerationStructureKHR";
  83. }
  84. const uint32_t ray_flags = _.GetOperandTypeId(inst, 2);
  85. if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
  86. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  87. << "Ray Flags must be a 32-bit int scalar";
  88. }
  89. const uint32_t cull_mask = _.GetOperandTypeId(inst, 3);
  90. if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
  91. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  92. << "Cull Mask must be a 32-bit int scalar";
  93. }
  94. const uint32_t ray_origin = _.GetOperandTypeId(inst, 4);
  95. if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
  96. _.GetBitWidth(ray_origin) != 32) {
  97. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  98. << "Ray Origin must be a 32-bit float 3-component vector";
  99. }
  100. const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5);
  101. if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
  102. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  103. << "Ray TMin must be a 32-bit float scalar";
  104. }
  105. const uint32_t ray_direction = _.GetOperandTypeId(inst, 6);
  106. if (!_.IsFloatVectorType(ray_direction) ||
  107. _.GetDimension(ray_direction) != 3 ||
  108. _.GetBitWidth(ray_direction) != 32) {
  109. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  110. << "Ray Direction must be a 32-bit float 3-component vector";
  111. }
  112. const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7);
  113. if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
  114. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  115. << "Ray TMax must be a 32-bit float scalar";
  116. }
  117. break;
  118. }
  119. case spv::Op::OpRayQueryTerminateKHR:
  120. case spv::Op::OpRayQueryConfirmIntersectionKHR: {
  121. if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
  122. break;
  123. }
  124. case spv::Op::OpRayQueryGenerateIntersectionKHR: {
  125. if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
  126. const uint32_t hit_t_id = _.GetOperandTypeId(inst, 1);
  127. if (!_.IsFloatScalarType(hit_t_id) || _.GetBitWidth(hit_t_id) != 32) {
  128. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  129. << "Hit T must be a 32-bit float scalar";
  130. }
  131. break;
  132. }
  133. case spv::Op::OpRayQueryGetIntersectionFrontFaceKHR:
  134. case spv::Op::OpRayQueryProceedKHR:
  135. case spv::Op::OpRayQueryGetIntersectionCandidateAABBOpaqueKHR: {
  136. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  137. if (!_.IsBoolScalarType(result_type)) {
  138. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  139. << "expected Result Type to be bool scalar type";
  140. }
  141. if (opcode == spv::Op::OpRayQueryGetIntersectionFrontFaceKHR) {
  142. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  143. }
  144. break;
  145. }
  146. case spv::Op::OpRayQueryGetIntersectionTKHR:
  147. case spv::Op::OpRayQueryGetRayTMinKHR: {
  148. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  149. if (!_.IsFloatScalarType(result_type) ||
  150. _.GetBitWidth(result_type) != 32) {
  151. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  152. << "expected Result Type to be 32-bit float scalar type";
  153. }
  154. if (opcode == spv::Op::OpRayQueryGetIntersectionTKHR) {
  155. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  156. }
  157. break;
  158. }
  159. case spv::Op::OpRayQueryGetIntersectionTypeKHR:
  160. case spv::Op::OpRayQueryGetIntersectionInstanceCustomIndexKHR:
  161. case spv::Op::OpRayQueryGetIntersectionInstanceIdKHR:
  162. case spv::Op::
  163. OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
  164. case spv::Op::OpRayQueryGetIntersectionGeometryIndexKHR:
  165. case spv::Op::OpRayQueryGetIntersectionPrimitiveIndexKHR:
  166. case spv::Op::OpRayQueryGetRayFlagsKHR: {
  167. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  168. if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
  169. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  170. << "expected Result Type to be 32-bit int scalar type";
  171. }
  172. if (opcode != spv::Op::OpRayQueryGetRayFlagsKHR) {
  173. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  174. }
  175. break;
  176. }
  177. case spv::Op::OpRayQueryGetIntersectionObjectRayDirectionKHR:
  178. case spv::Op::OpRayQueryGetIntersectionObjectRayOriginKHR:
  179. case spv::Op::OpRayQueryGetWorldRayDirectionKHR:
  180. case spv::Op::OpRayQueryGetWorldRayOriginKHR: {
  181. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  182. if (!_.IsFloatVectorType(result_type) ||
  183. _.GetDimension(result_type) != 3 ||
  184. _.GetBitWidth(result_type) != 32) {
  185. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  186. << "expected Result Type to be 32-bit float 3-component "
  187. "vector type";
  188. }
  189. if (opcode == spv::Op::OpRayQueryGetIntersectionObjectRayDirectionKHR ||
  190. opcode == spv::Op::OpRayQueryGetIntersectionObjectRayOriginKHR) {
  191. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  192. }
  193. break;
  194. }
  195. case spv::Op::OpRayQueryGetIntersectionBarycentricsKHR: {
  196. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  197. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  198. if (!_.IsFloatVectorType(result_type) ||
  199. _.GetDimension(result_type) != 2 ||
  200. _.GetBitWidth(result_type) != 32) {
  201. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  202. << "expected Result Type to be 32-bit float 2-component "
  203. "vector type";
  204. }
  205. break;
  206. }
  207. case spv::Op::OpRayQueryGetIntersectionObjectToWorldKHR:
  208. case spv::Op::OpRayQueryGetIntersectionWorldToObjectKHR: {
  209. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  210. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  211. uint32_t num_rows = 0;
  212. uint32_t num_cols = 0;
  213. uint32_t col_type = 0;
  214. uint32_t component_type = 0;
  215. if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
  216. &component_type)) {
  217. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  218. << "expected matrix type as Result Type";
  219. }
  220. if (num_cols != 4) {
  221. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  222. << "expected Result Type matrix to have a Column Count of 4";
  223. }
  224. if (!_.IsFloatScalarType(component_type) ||
  225. _.GetBitWidth(result_type) != 32 || num_rows != 3) {
  226. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  227. << "expected Result Type matrix to have a Column Type of "
  228. "3-component 32-bit float vectors";
  229. }
  230. break;
  231. }
  232. case spv::Op::OpRayQueryGetClusterIdNV: {
  233. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  234. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  235. if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
  236. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  237. << "expected Result Type to be 32-bit int scalar type";
  238. }
  239. break;
  240. }
  241. case spv::Op::OpRayQueryGetIntersectionSpherePositionNV: {
  242. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  243. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  244. if (!_.IsFloatVectorType(result_type) ||
  245. _.GetDimension(result_type) != 3 ||
  246. _.GetBitWidth(result_type) != 32) {
  247. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  248. << "expected Result Type to be 32-bit float 3-component "
  249. "vector type";
  250. }
  251. break;
  252. }
  253. case spv::Op::OpRayQueryGetIntersectionLSSPositionsNV: {
  254. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  255. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  256. auto result_id = _.FindDef(result_type);
  257. if ((result_id->opcode() != spv::Op::OpTypeArray) ||
  258. (GetArrayLength(_, result_id) != 2) ||
  259. !_.IsFloatVectorType(_.GetComponentType(result_type)) ||
  260. _.GetDimension(_.GetComponentType(result_type)) != 3) {
  261. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  262. << "Expected 2 element array of 32-bit 3 component float point "
  263. "vector as Result Type: "
  264. << spvOpcodeString(opcode);
  265. }
  266. break;
  267. }
  268. case spv::Op::OpRayQueryGetIntersectionLSSRadiiNV: {
  269. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  270. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  271. if (!_.IsFloatArrayType(result_type) ||
  272. (GetArrayLength(_, _.FindDef(result_type)) != 2) ||
  273. !_.IsFloatScalarType(_.GetComponentType(result_type))) {
  274. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  275. << "Expected 32-bit floating point scalar as Result Type: "
  276. << spvOpcodeString(opcode);
  277. }
  278. break;
  279. }
  280. case spv::Op::OpRayQueryGetIntersectionSphereRadiusNV:
  281. case spv::Op::OpRayQueryGetIntersectionLSSHitValueNV: {
  282. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  283. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  284. if (!_.IsFloatScalarType(result_type) ||
  285. _.GetBitWidth(result_type) != 32) {
  286. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  287. << "expected Result Type to be 32-bit floating point "
  288. "scalar type";
  289. }
  290. break;
  291. }
  292. case spv::Op::OpRayQueryIsSphereHitNV:
  293. case spv::Op::OpRayQueryIsLSSHitNV: {
  294. if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
  295. if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
  296. if (!_.IsBoolScalarType(result_type)) {
  297. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  298. << "expected Result Type to be Boolean "
  299. "scalar type";
  300. }
  301. break;
  302. }
  303. default:
  304. break;
  305. }
  306. return SPV_SUCCESS;
  307. }
  308. } // namespace val
  309. } // namespace spvtools