validate_ray_tracing_reorder.cpp 57 KB


  1. // Copyright (c) 2022 The Khronos Group Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Validates ray tracing instructions from SPV_NV_shader_invocation_reorder and
  15. // SPV_EXT_shader_invocation_reorder
  16. #include "source/opcode.h"
  17. #include "source/val/instruction.h"
  18. #include "source/val/validate.h"
  19. #include "source/val/validation_state.h"
  20. #include <limits>
  21. namespace spvtools {
  22. namespace val {
  // Sentinel passed in place of an operand index to mean "this instruction has
  // no such parameter; skip the check". It is the largest uint32_t value, far
  // beyond any operand index an instruction can actually have.
  static constexpr uint32_t KRayParamInvalidId =
      std::numeric_limits<uint32_t>::max();
  24. uint32_t GetArrayLength(ValidationState_t& _, const Instruction* array_type) {
  25. assert(array_type->opcode() == spv::Op::OpTypeArray);
  26. uint32_t const_int_id = array_type->GetOperandAs<uint32_t>(2U);
  27. Instruction* array_length_inst = _.FindDef(const_int_id);
  28. uint32_t array_length = 0;
  29. if (array_length_inst->opcode() == spv::Op::OpConstant) {
  30. array_length = array_length_inst->GetOperandAs<uint32_t>(2);
  31. }
  32. return array_length;
  33. }
  34. spv_result_t ValidateRayQueryPointer(ValidationState_t& _,
  35. const Instruction* inst,
  36. uint32_t ray_query_index) {
  37. const uint32_t ray_query_id = inst->GetOperandAs<uint32_t>(ray_query_index);
  38. auto variable = _.FindDef(ray_query_id);
  39. auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
  40. if (!pointer || pointer->opcode() != spv::Op::OpTypePointer) {
  41. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  42. << "Ray Query must be a pointer";
  43. }
  44. auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
  45. if (!type || type->opcode() != spv::Op::OpTypeRayQueryKHR) {
  46. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  47. << "Ray Query must be a pointer to OpTypeRayQueryKHR";
  48. }
  49. return SPV_SUCCESS;
  50. }
  51. spv_result_t ValidateHitObjectPointer(ValidationState_t& _,
  52. const Instruction* inst,
  53. uint32_t hit_object_index) {
  54. const uint32_t hit_object_id = inst->GetOperandAs<uint32_t>(hit_object_index);
  55. auto variable = _.FindDef(hit_object_id);
  56. auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
  57. if (!pointer || pointer->opcode() != spv::Op::OpTypePointer) {
  58. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  59. << "Hit Object must be a pointer";
  60. }
  61. auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
  62. if (!type || type->opcode() != spv::Op::OpTypeHitObjectNV) {
  63. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  64. << "Type must be OpTypeHitObjectNV";
  65. }
  66. return SPV_SUCCESS;
  67. }
  68. spv_result_t ValidateHitObjectPointerEXT(ValidationState_t& _,
  69. const Instruction* inst,
  70. uint32_t hit_object_index) {
  71. const uint32_t hit_object_id = inst->GetOperandAs<uint32_t>(hit_object_index);
  72. auto variable = _.FindDef(hit_object_id);
  73. auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
  74. if (!pointer || pointer->opcode() != spv::Op::OpTypePointer) {
  75. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  76. << "Hit Object must be a pointer";
  77. }
  78. auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
  79. if (!type || type->opcode() != spv::Op::OpTypeHitObjectEXT) {
  80. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  81. << "Type must be OpTypeHitObjectEXT";
  82. }
  83. return SPV_SUCCESS;
  84. }
  85. spv_result_t ValidateHitObjectInstructionCommonParameters(
  86. ValidationState_t& _, const Instruction* inst,
  87. uint32_t acceleration_struct_index, uint32_t instance_id_index,
  88. uint32_t primtive_id_index, uint32_t geometry_index,
  89. uint32_t ray_flags_index, uint32_t cull_mask_index, uint32_t hit_kind_index,
  90. uint32_t sbt_index, uint32_t sbt_offset_index, uint32_t sbt_stride_index,
  91. uint32_t sbt_record_offset_index, uint32_t sbt_record_stride_index,
  92. uint32_t miss_index, uint32_t ray_origin_index, uint32_t ray_tmin_index,
  93. uint32_t ray_direction_index, uint32_t ray_tmax_index,
  94. uint32_t payload_index, uint32_t hit_object_attr_index) {
  95. auto isValidId = [](uint32_t spvid) { return spvid < KRayParamInvalidId; };
  96. if (isValidId(acceleration_struct_index) &&
  97. _.GetIdOpcode(_.GetOperandTypeId(inst, acceleration_struct_index)) !=
  98. spv::Op::OpTypeAccelerationStructureKHR) {
  99. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  100. << "Expected Acceleration Structure to be of type "
  101. "OpTypeAccelerationStructureKHR";
  102. }
  103. if (isValidId(instance_id_index)) {
  104. const uint32_t instance_id = _.GetOperandTypeId(inst, instance_id_index);
  105. if (!_.IsIntScalarType(instance_id) || _.GetBitWidth(instance_id) != 32) {
  106. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  107. << "Instance Id must be a 32-bit int scalar";
  108. }
  109. }
  110. if (isValidId(primtive_id_index)) {
  111. const uint32_t primitive_id = _.GetOperandTypeId(inst, primtive_id_index);
  112. if (!_.IsIntScalarType(primitive_id) || _.GetBitWidth(primitive_id) != 32) {
  113. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  114. << "Primitive Id must be a 32-bit int scalar";
  115. }
  116. }
  117. if (isValidId(geometry_index)) {
  118. const uint32_t geometry_index_id = _.GetOperandTypeId(inst, geometry_index);
  119. if (!_.IsIntScalarType(geometry_index_id) ||
  120. _.GetBitWidth(geometry_index_id) != 32) {
  121. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  122. << "Geometry Index must be a 32-bit int scalar";
  123. }
  124. }
  125. if (isValidId(miss_index)) {
  126. const uint32_t miss_index_id = _.GetOperandTypeId(inst, miss_index);
  127. if (!_.IsUnsignedIntScalarType(miss_index_id) ||
  128. _.GetBitWidth(miss_index_id) != 32) {
  129. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  130. << "Miss Index must be a 32-bit int scalar";
  131. }
  132. }
  133. if (isValidId(cull_mask_index)) {
  134. const uint32_t cull_mask_id = _.GetOperandTypeId(inst, cull_mask_index);
  135. if (!_.IsUnsignedIntScalarType(cull_mask_id) ||
  136. _.GetBitWidth(cull_mask_id) != 32) {
  137. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  138. << "Cull mask must be a 32-bit int scalar";
  139. }
  140. }
  141. if (isValidId(sbt_index)) {
  142. const uint32_t sbt_index_id = _.GetOperandTypeId(inst, sbt_index);
  143. if (!_.IsUnsignedIntScalarType(sbt_index_id) ||
  144. _.GetBitWidth(sbt_index_id) != 32) {
  145. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  146. << "SBT Index must be a 32-bit unsigned int scalar";
  147. }
  148. }
  149. if (isValidId(sbt_offset_index)) {
  150. const uint32_t sbt_offset_id = _.GetOperandTypeId(inst, sbt_offset_index);
  151. if (!_.IsUnsignedIntScalarType(sbt_offset_id) ||
  152. _.GetBitWidth(sbt_offset_id) != 32) {
  153. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  154. << "SBT Offset must be a 32-bit unsigned int scalar";
  155. }
  156. }
  157. if (isValidId(sbt_stride_index)) {
  158. const uint32_t sbt_stride_index_id =
  159. _.GetOperandTypeId(inst, sbt_stride_index);
  160. if (!_.IsUnsignedIntScalarType(sbt_stride_index_id) ||
  161. _.GetBitWidth(sbt_stride_index_id) != 32) {
  162. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  163. << "SBT Stride must be a 32-bit unsigned int scalar";
  164. }
  165. }
  166. if (isValidId(sbt_record_offset_index)) {
  167. const uint32_t sbt_record_offset_index_id =
  168. _.GetOperandTypeId(inst, sbt_record_offset_index);
  169. if (!_.IsUnsignedIntScalarType(sbt_record_offset_index_id) ||
  170. _.GetBitWidth(sbt_record_offset_index_id) != 32) {
  171. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  172. << "SBT record offset must be a 32-bit unsigned int scalar";
  173. }
  174. }
  175. if (isValidId(sbt_record_stride_index)) {
  176. const uint32_t sbt_record_stride_index_id =
  177. _.GetOperandTypeId(inst, sbt_record_stride_index);
  178. if (!_.IsUnsignedIntScalarType(sbt_record_stride_index_id) ||
  179. _.GetBitWidth(sbt_record_stride_index_id) != 32) {
  180. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  181. << "SBT record stride must be a 32-bit unsigned int scalar";
  182. }
  183. }
  184. if (isValidId(ray_origin_index)) {
  185. const uint32_t ray_origin_id = _.GetOperandTypeId(inst, ray_origin_index);
  186. if (!_.IsFloatVectorType(ray_origin_id) ||
  187. _.GetDimension(ray_origin_id) != 3 ||
  188. _.GetBitWidth(ray_origin_id) != 32) {
  189. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  190. << "Ray Origin must be a 32-bit float 3-component vector";
  191. }
  192. }
  193. if (isValidId(ray_tmin_index)) {
  194. const uint32_t ray_tmin_id = _.GetOperandTypeId(inst, ray_tmin_index);
  195. if (!_.IsFloatScalarType(ray_tmin_id) || _.GetBitWidth(ray_tmin_id) != 32) {
  196. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  197. << "Ray TMin must be a 32-bit float scalar";
  198. }
  199. }
  200. if (isValidId(ray_direction_index)) {
  201. const uint32_t ray_direction_id =
  202. _.GetOperandTypeId(inst, ray_direction_index);
  203. if (!_.IsFloatVectorType(ray_direction_id) ||
  204. _.GetDimension(ray_direction_id) != 3 ||
  205. _.GetBitWidth(ray_direction_id) != 32) {
  206. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  207. << "Ray Direction must be a 32-bit float 3-component vector";
  208. }
  209. }
  210. if (isValidId(ray_tmax_index)) {
  211. const uint32_t ray_tmax_id = _.GetOperandTypeId(inst, ray_tmax_index);
  212. if (!_.IsFloatScalarType(ray_tmax_id) || _.GetBitWidth(ray_tmax_id) != 32) {
  213. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  214. << "Ray TMax must be a 32-bit float scalar";
  215. }
  216. }
  217. if (isValidId(ray_flags_index)) {
  218. const uint32_t ray_flags_id = _.GetOperandTypeId(inst, ray_flags_index);
  219. if (!_.IsIntScalarType(ray_flags_id) || _.GetBitWidth(ray_flags_id) != 32) {
  220. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  221. << "Ray Flags must be a 32-bit int scalar";
  222. }
  223. }
  224. if (isValidId(payload_index)) {
  225. const uint32_t payload_id = inst->GetOperandAs<uint32_t>(payload_index);
  226. auto variable = _.FindDef(payload_id);
  227. const auto var_opcode = variable->opcode();
  228. if (!variable || var_opcode != spv::Op::OpVariable ||
  229. (variable->GetOperandAs<spv::StorageClass>(2) !=
  230. spv::StorageClass::RayPayloadKHR &&
  231. variable->GetOperandAs<spv::StorageClass>(2) !=
  232. spv::StorageClass::IncomingRayPayloadKHR)) {
  233. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  234. << "payload must be a OpVariable of storage "
  235. "class RayPayloadKHR or IncomingRayPayloadKHR";
  236. }
  237. }
  238. if (isValidId(hit_kind_index)) {
  239. const uint32_t hit_kind_id = _.GetOperandTypeId(inst, hit_kind_index);
  240. if (!_.IsUnsignedIntScalarType(hit_kind_id) ||
  241. _.GetBitWidth(hit_kind_id) != 32) {
  242. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  243. << "Hit Kind must be a 32-bit unsigned int scalar";
  244. }
  245. }
  246. if (isValidId(hit_object_attr_index)) {
  247. const uint32_t hit_object_attr_id =
  248. inst->GetOperandAs<uint32_t>(hit_object_attr_index);
  249. auto variable = _.FindDef(hit_object_attr_id);
  250. const auto var_opcode = variable->opcode();
  251. if (!variable || var_opcode != spv::Op::OpVariable ||
  252. !((variable->GetOperandAs<spv::StorageClass>(2) ==
  253. spv::StorageClass::HitObjectAttributeNV) ||
  254. (variable->GetOperandAs<spv::StorageClass>(2) ==
  255. spv::StorageClass::HitObjectAttributeEXT))) {
  256. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  257. << "Hit Object Attributes id must be a OpVariable of storage "
  258. "class HitObjectAttributeNV";
  259. }
  260. }
  261. return SPV_SUCCESS;
  262. }
  263. spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
  264. const spv::Op opcode = inst->opcode();
  265. const uint32_t result_type = inst->type_id();
  266. auto RegisterOpcodeForValidModel = [](ValidationState_t& vs,
  267. const Instruction* rtinst) {
  268. std::string opcode_name = spvOpcodeString(rtinst->opcode());
  269. vs.function(rtinst->function()->id())
  270. ->RegisterExecutionModelLimitation(
  271. [opcode_name](spv::ExecutionModel model, std::string* message) {
  272. if (model != spv::ExecutionModel::RayGenerationKHR &&
  273. model != spv::ExecutionModel::ClosestHitKHR &&
  274. model != spv::ExecutionModel::MissKHR) {
  275. if (message) {
  276. *message = opcode_name +
  277. " requires RayGenerationKHR, ClosestHitKHR and "
  278. "MissKHR execution models";
  279. }
  280. return false;
  281. }
  282. return true;
  283. });
  284. return;
  285. };
  286. switch (opcode) {
  287. case spv::Op::OpHitObjectIsMissNV:
  288. case spv::Op::OpHitObjectIsHitNV:
  289. case spv::Op::OpHitObjectIsEmptyNV: {
  290. RegisterOpcodeForValidModel(_, inst);
  291. if (!_.IsBoolScalarType(result_type)) {
  292. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  293. << "expected Result Type to be bool scalar type";
  294. }
  295. if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
  296. break;
  297. }
  298. case spv::Op::OpHitObjectGetShaderRecordBufferHandleNV: {
  299. RegisterOpcodeForValidModel(_, inst);
  300. if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
  301. if (!_.IsIntVectorType(result_type) ||
  302. (_.GetDimension(result_type) != 2) ||
  303. (_.GetBitWidth(result_type) != 32))
  304. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  305. << "Expected 32-bit integer type 2-component vector as Result "
  306. "Type: "
  307. << spvOpcodeString(opcode);
  308. break;
  309. }
  310. case spv::Op::OpHitObjectGetHitKindNV:
  311. case spv::Op::OpHitObjectGetPrimitiveIndexNV:
  312. case spv::Op::OpHitObjectGetGeometryIndexNV:
  313. case spv::Op::OpHitObjectGetInstanceIdNV:
  314. case spv::Op::OpHitObjectGetInstanceCustomIndexNV:
  315. case spv::Op::OpHitObjectGetShaderBindingTableRecordIndexNV: {
  316. RegisterOpcodeForValidModel(_, inst);
  317. if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
  318. if (!_.IsIntScalarType(result_type) || !_.GetBitWidth(result_type))
  319. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  320. << "Expected 32-bit integer type scalar as Result Type: "
  321. << spvOpcodeString(opcode);
  322. break;
  323. }
  324. case spv::Op::OpHitObjectGetCurrentTimeNV:
  325. case spv::Op::OpHitObjectGetRayTMaxNV:
  326. case spv::Op::OpHitObjectGetRayTMinNV: {
  327. RegisterOpcodeForValidModel(_, inst);
  328. if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
  329. if (!_.IsFloatScalarType(result_type) || _.GetBitWidth(result_type) != 32)
  330. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  331. << "Expected 32-bit floating-point type scalar as Result Type: "
  332. << spvOpcodeString(opcode);
  333. break;
  334. }
  335. case spv::Op::OpHitObjectGetObjectToWorldNV:
  336. case spv::Op::OpHitObjectGetWorldToObjectNV: {
  337. RegisterOpcodeForValidModel(_, inst);
  338. if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
  339. uint32_t num_rows = 0;
  340. uint32_t num_cols = 0;
  341. uint32_t col_type = 0;
  342. uint32_t component_type = 0;
  343. if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
  344. &component_type)) {
  345. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  346. << "expected matrix type as Result Type: "
  347. << spvOpcodeString(opcode);
  348. }
  349. if (num_cols != 4) {
  350. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  351. << "expected Result Type matrix to have a Column Count of 4"
  352. << spvOpcodeString(opcode);
  353. }
  354. if (!_.IsFloatScalarType(component_type) ||
  355. _.GetBitWidth(result_type) != 32 || num_rows != 3) {
  356. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  357. << "expected Result Type matrix to have a Column Type of "
  358. "3-component 32-bit float vectors: "
  359. << spvOpcodeString(opcode);
  360. }
  361. break;
  362. }
  363. case spv::Op::OpHitObjectGetObjectRayOriginNV:
  364. case spv::Op::OpHitObjectGetObjectRayDirectionNV:
  365. case spv::Op::OpHitObjectGetWorldRayDirectionNV:
  366. case spv::Op::OpHitObjectGetWorldRayOriginNV: {
  367. RegisterOpcodeForValidModel(_, inst);
  368. if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
  369. if (!_.IsFloatVectorType(result_type) ||
  370. (_.GetDimension(result_type) != 3) ||
  371. (_.GetBitWidth(result_type) != 32))
  372. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  373. << "Expected 32-bit floating-point type 3-component vector as "
  374. "Result Type: "
  375. << spvOpcodeString(opcode);
  376. break;
  377. }
  378. case spv::Op::OpHitObjectGetAttributesNV: {
  379. RegisterOpcodeForValidModel(_, inst);
  380. if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
  381. const uint32_t hit_object_attr_id = inst->GetOperandAs<uint32_t>(1);
  382. auto variable = _.FindDef(hit_object_attr_id);
  383. const auto var_opcode = variable->opcode();
  384. if (!variable || var_opcode != spv::Op::OpVariable ||
  385. variable->GetOperandAs<spv::StorageClass>(2) !=
  386. spv::StorageClass::HitObjectAttributeNV) {
  387. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  388. << "Hit Object Attributes id must be a OpVariable of storage "
  389. "class HitObjectAttributeNV";
  390. }
  391. break;
  392. }
  393. case spv::Op::OpHitObjectExecuteShaderNV: {
  394. RegisterOpcodeForValidModel(_, inst);
  395. if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
  396. const uint32_t hit_object_attr_id = inst->GetOperandAs<uint32_t>(1);
  397. auto variable = _.FindDef(hit_object_attr_id);
  398. const auto var_opcode = variable->opcode();
  399. if (!variable || var_opcode != spv::Op::OpVariable ||
  400. (variable->GetOperandAs<spv::StorageClass>(2)) !=
  401. spv::StorageClass::RayPayloadKHR) {
  402. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  403. << "Hit Object Attributes id must be a OpVariable of storage "
  404. "class RayPayloadKHR";
  405. }
  406. break;
  407. }
  408. case spv::Op::OpHitObjectRecordEmptyNV: {
  409. RegisterOpcodeForValidModel(_, inst);
  410. if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
  411. break;
  412. }
  413. case spv::Op::OpHitObjectRecordMissNV: {
  414. RegisterOpcodeForValidModel(_, inst);
  415. if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
  416. const uint32_t miss_index = _.GetOperandTypeId(inst, 1);
  417. if (!_.IsUnsignedIntScalarType(miss_index) ||
  418. _.GetBitWidth(miss_index) != 32) {
  419. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  420. << "Miss Index must be a 32-bit int scalar";
  421. }
  422. const uint32_t ray_origin = _.GetOperandTypeId(inst, 2);
  423. if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
  424. _.GetBitWidth(ray_origin) != 32) {
  425. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  426. << "Ray Origin must be a 32-bit float 3-component vector";
  427. }
  428. const uint32_t ray_tmin = _.GetOperandTypeId(inst, 3);
  429. if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
  430. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  431. << "Ray TMin must be a 32-bit float scalar";
  432. }
  433. const uint32_t ray_direction = _.GetOperandTypeId(inst, 4);
  434. if (!_.IsFloatVectorType(ray_direction) ||
  435. _.GetDimension(ray_direction) != 3 ||
  436. _.GetBitWidth(ray_direction) != 32) {
  437. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  438. << "Ray Direction must be a 32-bit float 3-component vector";
  439. }
  440. const uint32_t ray_tmax = _.GetOperandTypeId(inst, 5);
  441. if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
  442. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  443. << "Ray TMax must be a 32-bit float scalar";
  444. }
  445. break;
  446. }
  447. case spv::Op::OpHitObjectRecordHitWithIndexNV: {
  448. RegisterOpcodeForValidModel(_, inst);
  449. if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
  450. if (auto error = ValidateHitObjectInstructionCommonParameters(
  451. _, inst, 1 /* Acceleration Struct */, 2 /* Instance Id */,
  452. 3 /* Primtive Id */, 4 /* Geometry Index */,
  453. KRayParamInvalidId /* Ray Flags */,
  454. KRayParamInvalidId /* Cull Mask */, 5 /* Hit Kind*/,
  455. 6 /* SBT index */, KRayParamInvalidId /* SBT Offset */,
  456. KRayParamInvalidId /* SBT Stride */,
  457. KRayParamInvalidId /* SBT Record Offset */,
  458. KRayParamInvalidId /* SBT Record Stride */,
  459. KRayParamInvalidId /* Miss Index */, 7 /* Ray Origin */,
  460. 8 /* Ray TMin */, 9 /* Ray Direction */, 10 /* Ray TMax */,
  461. KRayParamInvalidId /* Payload */, 11 /* Hit Object Attribute */))
  462. return error;
  463. break;
  464. }
  465. case spv::Op::OpHitObjectRecordHitNV: {
  466. RegisterOpcodeForValidModel(_, inst);
  467. if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
  468. if (auto error = ValidateHitObjectInstructionCommonParameters(
  469. _, inst, 1 /* Acceleration Struct */, 2 /* Instance Id */,
  470. 3 /* Primtive Id */, 4 /* Geometry Index */,
  471. KRayParamInvalidId /* Ray Flags */,
  472. KRayParamInvalidId /* Cull Mask */, 5 /* Hit Kind*/,
  473. KRayParamInvalidId /* SBT index */,
  474. KRayParamInvalidId /* SBT Offset */,
  475. KRayParamInvalidId /* SBT Stride */, 6 /* SBT Record Offset */,
  476. 7 /* SBT Record Stride */, KRayParamInvalidId /* Miss Index */,
  477. 8 /* Ray Origin */, 9 /* Ray TMin */, 10 /* Ray Direction */,
  478. 11 /* Ray TMax */, KRayParamInvalidId /* Payload */,
  479. 12 /* Hit Object Attribute */))
  480. return error;
  481. break;
  482. }
  483. case spv::Op::OpHitObjectTraceRayMotionNV: {
  484. RegisterOpcodeForValidModel(_, inst);
  485. if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
  486. if (auto error = ValidateHitObjectInstructionCommonParameters(
  487. _, inst, 1 /* Acceleration Struct */,
  488. KRayParamInvalidId /* Instance Id */,
  489. KRayParamInvalidId /* Primtive Id */,
  490. KRayParamInvalidId /* Geometry Index */, 2 /* Ray Flags */,
  491. 3 /* Cull Mask */, KRayParamInvalidId /* Hit Kind*/,
  492. KRayParamInvalidId /* SBT index */, 4 /* SBT Offset */,
  493. 5 /* SBT Stride */, KRayParamInvalidId /* SBT Record Offset */,
  494. KRayParamInvalidId /* SBT Record Stride */, 6 /* Miss Index */,
  495. 7 /* Ray Origin */, 8 /* Ray TMin */, 9 /* Ray Direction */,
  496. 10 /* Ray TMax */, 12 /* Payload */,
  497. KRayParamInvalidId /* Hit Object Attribute */))
  498. return error;
  499. // Current Time
  500. const uint32_t current_time_id = _.GetOperandTypeId(inst, 11);
  501. if (!_.IsFloatScalarType(current_time_id) ||
  502. _.GetBitWidth(current_time_id) != 32) {
  503. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  504. << "Current Times must be a 32-bit float scalar type";
  505. }
  506. break;
  507. }
  508. case spv::Op::OpHitObjectTraceRayNV: {
  509. RegisterOpcodeForValidModel(_, inst);
  510. if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
  511. if (auto error = ValidateHitObjectInstructionCommonParameters(
  512. _, inst, 1 /* Acceleration Struct */,
  513. KRayParamInvalidId /* Instance Id */,
  514. KRayParamInvalidId /* Primtive Id */,
  515. KRayParamInvalidId /* Geometry Index */, 2 /* Ray Flags */,
  516. 3 /* Cull Mask */, KRayParamInvalidId /* Hit Kind*/,
  517. KRayParamInvalidId /* SBT index */, 4 /* SBT Offset */,
  518. 5 /* SBT Stride */, KRayParamInvalidId /* SBT Record Offset */,
  519. KRayParamInvalidId /* SBT Record Stride */, 6 /* Miss Index */,
  520. 7 /* Ray Origin */, 8 /* Ray TMin */, 9 /* Ray Direction */,
  521. 10 /* Ray TMax */, 11 /* Payload */,
  522. KRayParamInvalidId /* Hit Object Attribute */))
  523. return error;
  524. break;
  525. }
  526. case spv::Op::OpReorderThreadWithHitObjectNV: {
  527. std::string opcode_name = spvOpcodeString(inst->opcode());
  528. _.function(inst->function()->id())
  529. ->RegisterExecutionModelLimitation(
  530. [opcode_name](spv::ExecutionModel model, std::string* message) {
  531. if (model != spv::ExecutionModel::RayGenerationKHR) {
  532. if (message) {
  533. *message = opcode_name +
  534. " requires RayGenerationKHR execution model";
  535. }
  536. return false;
  537. }
  538. return true;
  539. });
  540. if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
  541. if (inst->operands().size() > 1) {
  542. if (inst->operands().size() != 3) {
  543. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  544. << "Hint and Bits are optional together i.e "
  545. << " Either both Hint and Bits should be provided or neither.";
  546. }
  547. // Validate the optional opreands Hint and Bits
  548. const uint32_t hint_id = _.GetOperandTypeId(inst, 1);
  549. if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
  550. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  551. << "Hint must be a 32-bit int scalar";
  552. }
  553. const uint32_t bits_id = _.GetOperandTypeId(inst, 2);
  554. if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
  555. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  556. << "bits must be a 32-bit int scalar";
  557. }
  558. }
  559. break;
  560. }
  561. case spv::Op::OpReorderThreadWithHintNV: {
  562. std::string opcode_name = spvOpcodeString(inst->opcode());
  563. _.function(inst->function()->id())
  564. ->RegisterExecutionModelLimitation(
  565. [opcode_name](spv::ExecutionModel model, std::string* message) {
  566. if (model != spv::ExecutionModel::RayGenerationKHR) {
  567. if (message) {
  568. *message = opcode_name +
  569. " requires RayGenerationKHR execution model";
  570. }
  571. return false;
  572. }
  573. return true;
  574. });
  575. const uint32_t hint_id = _.GetOperandTypeId(inst, 0);
  576. if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
  577. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  578. << "Hint must be a 32-bit int scalar";
  579. }
  580. const uint32_t bits_id = _.GetOperandTypeId(inst, 1);
  581. if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
  582. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  583. << "bits must be a 32-bit int scalar";
  584. }
  585. break;
  586. }
  587. case spv::Op::OpHitObjectGetClusterIdNV: {
  588. RegisterOpcodeForValidModel(_, inst);
  589. if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
  590. if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32)
  591. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  592. << "Expected 32-bit integer type scalar as Result Type: "
  593. << spvOpcodeString(opcode);
  594. break;
  595. }
  596. case spv::Op::OpHitObjectGetSpherePositionNV: {
  597. RegisterOpcodeForValidModel(_, inst);
  598. if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
  599. if (!_.IsFloatVectorType(result_type) ||
  600. _.GetDimension(result_type) != 3 ||
  601. _.GetBitWidth(result_type) != 32) {
  602. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  603. << "Expected 32-bit floating point 2 component vector type as "
  604. "Result Type: "
  605. << spvOpcodeString(opcode);
  606. }
  607. break;
  608. }
  609. case spv::Op::OpHitObjectGetSphereRadiusNV: {
  610. RegisterOpcodeForValidModel(_, inst);
  611. if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
  612. if (!_.IsFloatScalarType(result_type) ||
  613. _.GetBitWidth(result_type) != 32) {
  614. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  615. << "Expected 32-bit floating point scalar as Result Type: "
  616. << spvOpcodeString(opcode);
  617. }
  618. break;
  619. }
  620. case spv::Op::OpHitObjectGetLSSPositionsNV: {
  621. RegisterOpcodeForValidModel(_, inst);
  622. if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
  623. auto result_id = _.FindDef(result_type);
  624. if ((result_id->opcode() != spv::Op::OpTypeArray) ||
  625. (GetArrayLength(_, result_id) != 2) ||
  626. !_.IsFloatVectorType(_.GetComponentType(result_type)) ||
  627. _.GetDimension(_.GetComponentType(result_type)) != 3) {
  628. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  629. << "Expected 2 element array of 32-bit 3 component float point "
  630. "vector as Result Type: "
  631. << spvOpcodeString(opcode);
  632. }
  633. break;
  634. }
  635. case spv::Op::OpHitObjectGetLSSRadiiNV: {
  636. RegisterOpcodeForValidModel(_, inst);
  637. if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
  638. if (!_.IsFloatArrayType(result_type) ||
  639. (GetArrayLength(_, _.FindDef(result_type)) != 2) ||
  640. !_.IsFloatScalarType(_.GetComponentType(result_type))) {
  641. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  642. << "Expected 2 element array of 32-bit floating point scalar as "
  643. "Result Type: "
  644. << spvOpcodeString(opcode);
  645. }
  646. break;
  647. }
  648. case spv::Op::OpHitObjectIsSphereHitNV: {
  649. RegisterOpcodeForValidModel(_, inst);
  650. if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
  651. if (!_.IsBoolScalarType(result_type)) {
  652. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  653. << "Expected Boolean scalar as Result Type: "
  654. << spvOpcodeString(opcode);
  655. }
  656. break;
  657. }
  658. case spv::Op::OpHitObjectIsLSSHitNV: {
  659. RegisterOpcodeForValidModel(_, inst);
  660. if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
  661. if (!_.IsBoolScalarType(result_type)) {
  662. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  663. << "Expected Boolean scalar as Result Type: "
  664. << spvOpcodeString(opcode);
  665. }
  666. break;
  667. }
  668. default:
  669. break;
  670. }
  671. return SPV_SUCCESS;
  672. }
  673. spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
  674. const spv::Op opcode = inst->opcode();
  675. const uint32_t result_type = inst->type_id();
  676. auto RegisterOpcodeForValidModel = [](ValidationState_t& vs,
  677. const Instruction* rtinst) {
  678. std::string opcode_name = spvOpcodeString(rtinst->opcode());
  679. vs.function(rtinst->function()->id())
  680. ->RegisterExecutionModelLimitation(
  681. [opcode_name](spv::ExecutionModel model, std::string* message) {
  682. if (model != spv::ExecutionModel::RayGenerationKHR &&
  683. model != spv::ExecutionModel::ClosestHitKHR &&
  684. model != spv::ExecutionModel::MissKHR) {
  685. if (message) {
  686. *message = opcode_name +
  687. " requires RayGenerationKHR, ClosestHitKHR and "
  688. "MissKHR execution models";
  689. }
  690. return false;
  691. }
  692. return true;
  693. });
  694. return;
  695. };
  696. switch (opcode) {
  697. case spv::Op::OpHitObjectIsMissEXT:
  698. case spv::Op::OpHitObjectIsHitEXT:
  699. case spv::Op::OpHitObjectIsEmptyEXT: {
  700. RegisterOpcodeForValidModel(_, inst);
  701. if (!_.IsBoolScalarType(result_type)) {
  702. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  703. << "expected Result Type to be bool scalar type";
  704. }
  705. if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
  706. break;
  707. }
  708. case spv::Op::OpHitObjectGetShaderRecordBufferHandleEXT: {
  709. RegisterOpcodeForValidModel(_, inst);
  710. if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
  711. if (!_.IsIntVectorType(result_type) ||
  712. (_.GetDimension(result_type) != 2) ||
  713. (_.GetBitWidth(result_type) != 32))
  714. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  715. << "Expected 32-bit integer type 2-component vector as Result "
  716. "Type: "
  717. << spvOpcodeString(opcode);
  718. break;
  719. }
  720. case spv::Op::OpHitObjectGetHitKindEXT:
  721. case spv::Op::OpHitObjectGetPrimitiveIndexEXT:
  722. case spv::Op::OpHitObjectGetGeometryIndexEXT:
  723. case spv::Op::OpHitObjectGetInstanceIdEXT:
  724. case spv::Op::OpHitObjectGetInstanceCustomIndexEXT:
  725. case spv::Op::OpHitObjectGetShaderBindingTableRecordIndexEXT:
  726. case spv::Op::OpHitObjectGetRayFlagsEXT: {
  727. RegisterOpcodeForValidModel(_, inst);
  728. if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
  729. if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32)
  730. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  731. << "Expected 32-bit integer type scalar as Result Type: "
  732. << spvOpcodeString(opcode);
  733. break;
  734. }
  735. case spv::Op::OpHitObjectGetCurrentTimeEXT:
  736. case spv::Op::OpHitObjectGetRayTMaxEXT:
  737. case spv::Op::OpHitObjectGetRayTMinEXT: {
  738. RegisterOpcodeForValidModel(_, inst);
  739. if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
  740. if (!_.IsFloatScalarType(result_type) || _.GetBitWidth(result_type) != 32)
  741. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  742. << "Expected 32-bit floating-point type scalar as Result Type: "
  743. << spvOpcodeString(opcode);
  744. break;
  745. }
  746. case spv::Op::OpHitObjectGetObjectToWorldEXT:
  747. case spv::Op::OpHitObjectGetWorldToObjectEXT: {
  748. RegisterOpcodeForValidModel(_, inst);
  749. if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
  750. uint32_t num_rows = 0;
  751. uint32_t num_cols = 0;
  752. uint32_t col_type = 0;
  753. uint32_t component_type = 0;
  754. if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
  755. &component_type)) {
  756. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  757. << "expected matrix type as Result Type: "
  758. << spvOpcodeString(opcode);
  759. }
  760. if (num_cols != 4) {
  761. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  762. << "expected Result Type matrix to have a Column Count of 4"
  763. << spvOpcodeString(opcode);
  764. }
  765. if (!_.IsFloatScalarType(component_type) ||
  766. _.GetBitWidth(result_type) != 32 || num_rows != 3) {
  767. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  768. << "expected Result Type matrix to have a Column Type of "
  769. "3-component 32-bit float vectors: "
  770. << spvOpcodeString(opcode);
  771. }
  772. break;
  773. }
  774. case spv::Op::OpHitObjectGetObjectRayOriginEXT:
  775. case spv::Op::OpHitObjectGetObjectRayDirectionEXT:
  776. case spv::Op::OpHitObjectGetWorldRayDirectionEXT:
  777. case spv::Op::OpHitObjectGetWorldRayOriginEXT: {
  778. RegisterOpcodeForValidModel(_, inst);
  779. if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
  780. if (!_.IsFloatVectorType(result_type) ||
  781. (_.GetDimension(result_type) != 3) ||
  782. (_.GetBitWidth(result_type) != 32))
  783. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  784. << "Expected 32-bit floating-point type 3-component vector as "
  785. "Result Type: "
  786. << spvOpcodeString(opcode);
  787. break;
  788. }
  789. case spv::Op::OpHitObjectGetIntersectionTriangleVertexPositionsEXT: {
  790. RegisterOpcodeForValidModel(_, inst);
  791. if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
  792. auto result_id = _.FindDef(result_type);
  793. if ((result_id->opcode() != spv::Op::OpTypeArray) ||
  794. (GetArrayLength(_, result_id) != 3) ||
  795. !_.IsFloatVectorType(_.GetComponentType(result_type)) ||
  796. _.GetDimension(_.GetComponentType(result_type)) != 3 ||
  797. _.GetBitWidth(_.GetComponentType(result_type)) != 32) {
  798. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  799. << "Expected 3 element array of 32-bit 3 component float "
  800. "vectors as Result Type: "
  801. << spvOpcodeString(opcode);
  802. }
  803. break;
  804. }
  805. case spv::Op::OpHitObjectGetAttributesEXT: {
  806. RegisterOpcodeForValidModel(_, inst);
  807. if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
  808. const uint32_t hit_object_attr_id = inst->GetOperandAs<uint32_t>(1);
  809. auto variable = _.FindDef(hit_object_attr_id);
  810. const auto var_opcode = variable->opcode();
  811. if (!variable || var_opcode != spv::Op::OpVariable ||
  812. variable->GetOperandAs<spv::StorageClass>(2) !=
  813. spv::StorageClass::HitObjectAttributeEXT) {
  814. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  815. << "Hit Object Attributes id must be a OpVariable of storage "
  816. "class HitObjectAttributeEXT";
  817. }
  818. break;
  819. }
  820. case spv::Op::OpHitObjectSetShaderBindingTableRecordIndexEXT: {
  821. RegisterOpcodeForValidModel(_, inst);
  822. if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
  823. const uint32_t sbt_index_id = _.GetOperandTypeId(inst, 1);
  824. if (!_.IsIntScalarType(sbt_index_id) ||
  825. _.GetBitWidth(sbt_index_id) != 32) {
  826. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  827. << "SBT Index must be a 32-bit integer scalar";
  828. }
  829. break;
  830. }
  831. case spv::Op::OpHitObjectExecuteShaderEXT: {
  832. RegisterOpcodeForValidModel(_, inst);
  833. if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
  834. const uint32_t payload_id = inst->GetOperandAs<uint32_t>(1);
  835. auto variable = _.FindDef(payload_id);
  836. const auto var_opcode = variable->opcode();
  837. if (!variable || var_opcode != spv::Op::OpVariable ||
  838. (variable->GetOperandAs<spv::StorageClass>(2) !=
  839. spv::StorageClass::RayPayloadKHR &&
  840. variable->GetOperandAs<spv::StorageClass>(2) !=
  841. spv::StorageClass::IncomingRayPayloadKHR)) {
  842. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  843. << "Payload must be a OpVariable of storage "
  844. "class RayPayloadKHR or IncomingRayPayloadKHR";
  845. }
  846. break;
  847. }
  848. case spv::Op::OpHitObjectRecordEmptyEXT: {
  849. RegisterOpcodeForValidModel(_, inst);
  850. if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
  851. break;
  852. }
  853. case spv::Op::OpHitObjectRecordFromQueryEXT: {
  854. RegisterOpcodeForValidModel(_, inst);
  855. if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
  856. if (auto error = ValidateRayQueryPointer(_, inst, 1)) return error;
  857. if (!_.HasCapability(spv::Capability::RayQueryKHR))
  858. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  859. << spvOpcodeString(opcode)
  860. << ": requires RayQueryKHR capability";
  861. // Validate SBT Record Index (operand 2)
  862. const uint32_t sbt_record_index_id = _.GetOperandTypeId(inst, 2);
  863. if (!_.IsIntScalarType(sbt_record_index_id) ||
  864. _.GetBitWidth(sbt_record_index_id) != 32) {
  865. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  866. << "SBT Record Index must be a 32-bit integer scalar";
  867. }
  868. // Validate Hit Object Attributes (operand 3)
  869. const uint32_t hit_object_attr_id = inst->GetOperandAs<uint32_t>(3);
  870. auto attr_variable = _.FindDef(hit_object_attr_id);
  871. const auto attr_var_opcode = attr_variable->opcode();
  872. if (!attr_variable || attr_var_opcode != spv::Op::OpVariable ||
  873. attr_variable->GetOperandAs<spv::StorageClass>(2) !=
  874. spv::StorageClass::HitObjectAttributeEXT) {
  875. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  876. << "Hit Object Attributes id must be a OpVariable of storage "
  877. "class HitObjectAttributeEXT";
  878. }
  879. break;
  880. }
  881. case spv::Op::OpHitObjectRecordMissEXT: {
  882. RegisterOpcodeForValidModel(_, inst);
  883. if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
  884. // Ray Flags (operand 1)
  885. const uint32_t ray_flags_id = _.GetOperandTypeId(inst, 1);
  886. if (!_.IsIntScalarType(ray_flags_id) ||
  887. _.GetBitWidth(ray_flags_id) != 32) {
  888. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  889. << "Ray Flags must be a 32-bit int scalar";
  890. }
  891. // Miss Index (operand 2)
  892. const uint32_t miss_index = _.GetOperandTypeId(inst, 2);
  893. if (!_.IsUnsignedIntScalarType(miss_index) ||
  894. _.GetBitWidth(miss_index) != 32) {
  895. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  896. << "Miss Index must be a 32-bit unsigned int scalar";
  897. }
  898. // Ray Origin (operand 3)
  899. const uint32_t ray_origin = _.GetOperandTypeId(inst, 3);
  900. if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
  901. _.GetBitWidth(ray_origin) != 32) {
  902. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  903. << "Ray Origin must be a 32-bit float 3-component vector";
  904. }
  905. // Ray TMin (operand 4)
  906. const uint32_t ray_tmin = _.GetOperandTypeId(inst, 4);
  907. if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
  908. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  909. << "Ray TMin must be a 32-bit float scalar";
  910. }
  911. // Ray Direction (operand 5)
  912. const uint32_t ray_direction = _.GetOperandTypeId(inst, 5);
  913. if (!_.IsFloatVectorType(ray_direction) ||
  914. _.GetDimension(ray_direction) != 3 ||
  915. _.GetBitWidth(ray_direction) != 32) {
  916. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  917. << "Ray Direction must be a 32-bit float 3-component vector";
  918. }
  919. // Ray TMax (operand 6)
  920. const uint32_t ray_tmax = _.GetOperandTypeId(inst, 6);
  921. if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
  922. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  923. << "Ray TMax must be a 32-bit float scalar";
  924. }
  925. break;
  926. }
  927. case spv::Op::OpHitObjectRecordMissMotionEXT: {
  928. RegisterOpcodeForValidModel(_, inst);
  929. if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
  930. // Ray Flags (operand 1)
  931. const uint32_t ray_flags_id = _.GetOperandTypeId(inst, 1);
  932. if (!_.IsIntScalarType(ray_flags_id) ||
  933. _.GetBitWidth(ray_flags_id) != 32) {
  934. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  935. << "Ray Flags must be a 32-bit int scalar";
  936. }
  937. // Miss Index (operand 2)
  938. const uint32_t miss_index = _.GetOperandTypeId(inst, 2);
  939. if (!_.IsUnsignedIntScalarType(miss_index) ||
  940. _.GetBitWidth(miss_index) != 32) {
  941. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  942. << "Miss Index must be a 32-bit unsigned int scalar";
  943. }
  944. // Ray Origin (operand 3)
  945. const uint32_t ray_origin = _.GetOperandTypeId(inst, 3);
  946. if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
  947. _.GetBitWidth(ray_origin) != 32) {
  948. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  949. << "Ray Origin must be a 32-bit float 3-component vector";
  950. }
  951. // Ray TMin (operand 4)
  952. const uint32_t ray_tmin = _.GetOperandTypeId(inst, 4);
  953. if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
  954. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  955. << "Ray TMin must be a 32-bit float scalar";
  956. }
  957. // Ray Direction (operand 5)
  958. const uint32_t ray_direction = _.GetOperandTypeId(inst, 5);
  959. if (!_.IsFloatVectorType(ray_direction) ||
  960. _.GetDimension(ray_direction) != 3 ||
  961. _.GetBitWidth(ray_direction) != 32) {
  962. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  963. << "Ray Direction must be a 32-bit float 3-component vector";
  964. }
  965. // Ray TMax (operand 6)
  966. const uint32_t ray_tmax = _.GetOperandTypeId(inst, 6);
  967. if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
  968. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  969. << "Ray TMax must be a 32-bit float scalar";
  970. }
  971. // Current Time (operand 7)
  972. const uint32_t current_time_id = _.GetOperandTypeId(inst, 7);
  973. if (!_.IsFloatScalarType(current_time_id) ||
  974. _.GetBitWidth(current_time_id) != 32) {
  975. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  976. << "Current Time must be a 32-bit float scalar";
  977. }
  978. break;
  979. }
  980. case spv::Op::OpReorderThreadWithHintEXT: {
  981. std::string opcode_name = spvOpcodeString(inst->opcode());
  982. _.function(inst->function()->id())
  983. ->RegisterExecutionModelLimitation(
  984. [opcode_name](spv::ExecutionModel model, std::string* message) {
  985. if (model != spv::ExecutionModel::RayGenerationKHR) {
  986. if (message) {
  987. *message = opcode_name +
  988. " requires RayGenerationKHR execution model";
  989. }
  990. return false;
  991. }
  992. return true;
  993. });
  994. const uint32_t hint_id = _.GetOperandTypeId(inst, 0);
  995. if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
  996. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  997. << "Hint must be a 32-bit int scalar";
  998. }
  999. const uint32_t bits_id = _.GetOperandTypeId(inst, 1);
  1000. if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
  1001. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1002. << "Bits must be a 32-bit int scalar";
  1003. }
  1004. break;
  1005. }
  1006. case spv::Op::OpReorderThreadWithHitObjectEXT: {
  1007. std::string opcode_name = spvOpcodeString(inst->opcode());
  1008. _.function(inst->function()->id())
  1009. ->RegisterExecutionModelLimitation(
  1010. [opcode_name](spv::ExecutionModel model, std::string* message) {
  1011. if (model != spv::ExecutionModel::RayGenerationKHR) {
  1012. if (message) {
  1013. *message = opcode_name +
  1014. " requires RayGenerationKHR execution model";
  1015. }
  1016. return false;
  1017. }
  1018. return true;
  1019. });
  1020. if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
  1021. if (inst->operands().size() > 1) {
  1022. if (inst->operands().size() != 3) {
  1023. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1024. << "Hint and Bits are optional together i.e "
  1025. << " Either both Hint and Bits should be provided or neither.";
  1026. }
  1027. // Validate the optional operands Hint and Bits
  1028. const uint32_t hint_id = _.GetOperandTypeId(inst, 1);
  1029. if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
  1030. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1031. << "Hint must be a 32-bit int scalar";
  1032. }
  1033. const uint32_t bits_id = _.GetOperandTypeId(inst, 2);
  1034. if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
  1035. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1036. << "Bits must be a 32-bit int scalar";
  1037. }
  1038. }
  1039. break;
  1040. }
  1041. case spv::Op::OpHitObjectTraceRayEXT: {
  1042. RegisterOpcodeForValidModel(_, inst);
  1043. if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
  1044. if (auto error = ValidateHitObjectInstructionCommonParameters(
  1045. _, inst, 1 /* Acceleration Struct */,
  1046. KRayParamInvalidId /* Instance Id */,
  1047. KRayParamInvalidId /* Primitive Id */,
  1048. KRayParamInvalidId /* Geometry Index */, 2 /* Ray Flags */,
  1049. 3 /* Cull Mask */, KRayParamInvalidId /* Hit Kind*/,
  1050. KRayParamInvalidId /* SBT index */, 4 /* SBT Offset */,
  1051. 5 /* SBT Stride */, KRayParamInvalidId /* SBT Record Offset */,
  1052. KRayParamInvalidId /* SBT Record Stride */, 6 /* Miss Index */,
  1053. 7 /* Ray Origin */, 8 /* Ray TMin */, 9 /* Ray Direction */,
  1054. 10 /* Ray TMax */, 11 /* Payload */,
  1055. KRayParamInvalidId /* Hit Object Attribute */))
  1056. return error;
  1057. break;
  1058. }
  1059. case spv::Op::OpHitObjectTraceRayMotionEXT: {
  1060. RegisterOpcodeForValidModel(_, inst);
  1061. if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
  1062. if (auto error = ValidateHitObjectInstructionCommonParameters(
  1063. _, inst, 1 /* Acceleration Struct */,
  1064. KRayParamInvalidId /* Instance Id */,
  1065. KRayParamInvalidId /* Primitive Id */,
  1066. KRayParamInvalidId /* Geometry Index */, 2 /* Ray Flags */,
  1067. 3 /* Cull Mask */, KRayParamInvalidId /* Hit Kind*/,
  1068. KRayParamInvalidId /* SBT index */, 4 /* SBT Offset */,
  1069. 5 /* SBT Stride */, KRayParamInvalidId /* SBT Record Offset */,
  1070. KRayParamInvalidId /* SBT Record Stride */, 6 /* Miss Index */,
  1071. 7 /* Ray Origin */, 8 /* Ray TMin */, 9 /* Ray Direction */,
  1072. 10 /* Ray TMax */, 12 /* Payload */,
  1073. KRayParamInvalidId /* Hit Object Attribute */))
  1074. return error;
  1075. // Current Time (operand 11)
  1076. const uint32_t current_time_id = _.GetOperandTypeId(inst, 11);
  1077. if (!_.IsFloatScalarType(current_time_id) ||
  1078. _.GetBitWidth(current_time_id) != 32) {
  1079. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1080. << "Current Time must be a 32-bit float scalar";
  1081. }
  1082. break;
  1083. }
  1084. case spv::Op::OpHitObjectReorderExecuteShaderEXT: {
  1085. std::string opcode_name = spvOpcodeString(inst->opcode());
  1086. _.function(inst->function()->id())
  1087. ->RegisterExecutionModelLimitation(
  1088. [opcode_name](spv::ExecutionModel model, std::string* message) {
  1089. if (model != spv::ExecutionModel::RayGenerationKHR) {
  1090. if (message) {
  1091. *message = opcode_name +
  1092. " requires RayGenerationKHR execution model";
  1093. }
  1094. return false;
  1095. }
  1096. return true;
  1097. });
  1098. if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
  1099. // Validate Payload (operand 1)
  1100. const uint32_t payload_id = inst->GetOperandAs<uint32_t>(1);
  1101. auto variable = _.FindDef(payload_id);
  1102. const auto var_opcode = variable->opcode();
  1103. if (!variable || var_opcode != spv::Op::OpVariable ||
  1104. (variable->GetOperandAs<spv::StorageClass>(2) !=
  1105. spv::StorageClass::RayPayloadKHR &&
  1106. variable->GetOperandAs<spv::StorageClass>(2) !=
  1107. spv::StorageClass::IncomingRayPayloadKHR)) {
  1108. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1109. << "Payload must be a OpVariable of storage "
  1110. "class RayPayloadKHR or IncomingRayPayloadKHR";
  1111. }
  1112. // Check for optional Hint and Bits (operands 2 and 3)
  1113. if (inst->operands().size() > 2) {
  1114. if (inst->operands().size() != 4) {
  1115. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1116. << "Hint and Bits are optional together i.e "
  1117. << " Either both Hint and Bits should be provided or neither.";
  1118. }
  1119. // Validate optional Hint and Bits
  1120. const uint32_t hint_id = _.GetOperandTypeId(inst, 2);
  1121. if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
  1122. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1123. << "Hint must be a 32-bit int scalar";
  1124. }
  1125. const uint32_t bits_id = _.GetOperandTypeId(inst, 3);
  1126. if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
  1127. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1128. << "Bits must be a 32-bit int scalar";
  1129. }
  1130. }
  1131. break;
  1132. }
  1133. case spv::Op::OpHitObjectTraceReorderExecuteEXT: {
  1134. std::string opcode_name = spvOpcodeString(inst->opcode());
  1135. _.function(inst->function()->id())
  1136. ->RegisterExecutionModelLimitation(
  1137. [opcode_name](spv::ExecutionModel model, std::string* message) {
  1138. if (model != spv::ExecutionModel::RayGenerationKHR) {
  1139. if (message) {
  1140. *message = opcode_name +
  1141. " requires RayGenerationKHR execution model";
  1142. }
  1143. return false;
  1144. }
  1145. return true;
  1146. });
  1147. if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
  1148. // Validate base trace ray parameters (operands 1-11)
  1149. if (auto error = ValidateHitObjectInstructionCommonParameters(
  1150. _, inst, 1 /* Acceleration Struct */,
  1151. KRayParamInvalidId /* Instance Id */,
  1152. KRayParamInvalidId /* Primitive Id */,
  1153. KRayParamInvalidId /* Geometry Index */, 2 /* Ray Flags */,
  1154. 3 /* Cull Mask */, KRayParamInvalidId /* Hit Kind*/,
  1155. KRayParamInvalidId /* SBT index */, 4 /* SBT Offset */,
  1156. 5 /* SBT Stride */, KRayParamInvalidId /* SBT Record Offset */,
  1157. KRayParamInvalidId /* SBT Record Stride */, 6 /* Miss Index */,
  1158. 7 /* Ray Origin */, 8 /* Ray TMin */, 9 /* Ray Direction */,
  1159. 10 /* Ray TMax */, 11 /* Payload */,
  1160. KRayParamInvalidId /* Hit Object Attribute */))
  1161. return error;
  1162. // Check for optional Hint and Bits (operands 12 and 13)
  1163. if (inst->operands().size() > 12) {
  1164. if (inst->operands().size() != 14) {
  1165. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1166. << "Hint and Bits are optional together i.e "
  1167. << " Either both Hint and Bits should be provided or neither.";
  1168. }
  1169. // Validate optional Hint and Bits
  1170. const uint32_t hint_id = _.GetOperandTypeId(inst, 12);
  1171. if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
  1172. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1173. << "Hint must be a 32-bit int scalar";
  1174. }
  1175. const uint32_t bits_id = _.GetOperandTypeId(inst, 13);
  1176. if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
  1177. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1178. << "Bits must be a 32-bit int scalar";
  1179. }
  1180. }
  1181. break;
  1182. }
  1183. case spv::Op::OpHitObjectTraceMotionReorderExecuteEXT: {
  1184. std::string opcode_name = spvOpcodeString(inst->opcode());
  1185. _.function(inst->function()->id())
  1186. ->RegisterExecutionModelLimitation(
  1187. [opcode_name](spv::ExecutionModel model, std::string* message) {
  1188. if (model != spv::ExecutionModel::RayGenerationKHR) {
  1189. if (message) {
  1190. *message = opcode_name +
  1191. " requires RayGenerationKHR execution model";
  1192. }
  1193. return false;
  1194. }
  1195. return true;
  1196. });
  1197. if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
  1198. // Validate base trace ray parameters (operands 1-12)
  1199. if (auto error = ValidateHitObjectInstructionCommonParameters(
  1200. _, inst, 1 /* Acceleration Struct */,
  1201. KRayParamInvalidId /* Instance Id */,
  1202. KRayParamInvalidId /* Primitive Id */,
  1203. KRayParamInvalidId /* Geometry Index */, 2 /* Ray Flags */,
  1204. 3 /* Cull Mask */, KRayParamInvalidId /* Hit Kind*/,
  1205. KRayParamInvalidId /* SBT index */, 4 /* SBT Offset */,
  1206. 5 /* SBT Stride */, KRayParamInvalidId /* SBT Record Offset */,
  1207. KRayParamInvalidId /* SBT Record Stride */, 6 /* Miss Index */,
  1208. 7 /* Ray Origin */, 8 /* Ray TMin */, 9 /* Ray Direction */,
  1209. 10 /* Ray TMax */, 12 /* Payload */,
  1210. KRayParamInvalidId /* Hit Object Attribute */))
  1211. return error;
  1212. // Current Time (operand 11)
  1213. const uint32_t current_time_id = _.GetOperandTypeId(inst, 11);
  1214. if (!_.IsFloatScalarType(current_time_id) ||
  1215. _.GetBitWidth(current_time_id) != 32) {
  1216. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1217. << "Current Time must be a 32-bit float scalar";
  1218. }
  1219. // Check for optional Hint and Bits (operands 13 and 14)
  1220. if (inst->operands().size() > 13) {
  1221. if (inst->operands().size() != 15) {
  1222. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1223. << "Hint and Bits are optional together i.e "
  1224. << " Either both Hint and Bits should be provided or neither.";
  1225. }
  1226. // Validate optional Hint and Bits
  1227. const uint32_t hint_id = _.GetOperandTypeId(inst, 13);
  1228. if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
  1229. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1230. << "Hint must be a 32-bit int scalar";
  1231. }
  1232. const uint32_t bits_id = _.GetOperandTypeId(inst, 14);
  1233. if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
  1234. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  1235. << "Bits must be a 32-bit int scalar";
  1236. }
  1237. }
  1238. break;
  1239. }
  1240. default:
  1241. break;
  1242. }
  1243. return SPV_SUCCESS;
  1244. }
  1245. } // namespace val
  1246. } // namespace spvtools