// validate_derivatives.cpp (4.1 KB)
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Validates correctness of derivative SPIR-V instructions.
  15. #include "source/val/validate.h"
  16. #include <string>
  17. #include "source/diagnostic.h"
  18. #include "source/opcode.h"
  19. #include "source/val/instruction.h"
  20. #include "source/val/validation_state.h"
  21. namespace spvtools {
  22. namespace val {
  23. // Validates correctness of derivative instructions.
  24. spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
  25. const spv::Op opcode = inst->opcode();
  26. const uint32_t result_type = inst->type_id();
  27. switch (opcode) {
  28. case spv::Op::OpDPdx:
  29. case spv::Op::OpDPdy:
  30. case spv::Op::OpFwidth:
  31. case spv::Op::OpDPdxFine:
  32. case spv::Op::OpDPdyFine:
  33. case spv::Op::OpFwidthFine:
  34. case spv::Op::OpDPdxCoarse:
  35. case spv::Op::OpDPdyCoarse:
  36. case spv::Op::OpFwidthCoarse: {
  37. if (!_.IsFloatScalarOrVectorType(result_type)) {
  38. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  39. << "Expected Result Type to be float scalar or vector type: "
  40. << spvOpcodeString(opcode);
  41. }
  42. if (!_.ContainsSizedIntOrFloatType(result_type, spv::Op::OpTypeFloat,
  43. 32)) {
  44. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  45. << "Result type component width must be 32 bits";
  46. }
  47. const uint32_t p_type = _.GetOperandTypeId(inst, 2);
  48. if (p_type != result_type) {
  49. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  50. << "Expected P type and Result Type to be the same: "
  51. << spvOpcodeString(opcode);
  52. }
  53. _.function(inst->function()->id())
  54. ->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model,
  55. std::string* message) {
  56. if (model != spv::ExecutionModel::Fragment &&
  57. model != spv::ExecutionModel::GLCompute) {
  58. if (message) {
  59. *message =
  60. std::string(
  61. "Derivative instructions require Fragment or GLCompute "
  62. "execution model: ") +
  63. spvOpcodeString(opcode);
  64. }
  65. return false;
  66. }
  67. return true;
  68. });
  69. _.function(inst->function()->id())
  70. ->RegisterLimitation([opcode](const ValidationState_t& state,
  71. const Function* entry_point,
  72. std::string* message) {
  73. const auto* models = state.GetExecutionModels(entry_point->id());
  74. const auto* modes = state.GetExecutionModes(entry_point->id());
  75. if (models &&
  76. models->find(spv::ExecutionModel::GLCompute) != models->end() &&
  77. (!modes ||
  78. (modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) ==
  79. modes->end() &&
  80. modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) ==
  81. modes->end()))) {
  82. if (message) {
  83. *message = std::string(
  84. "Derivative instructions require "
  85. "DerivativeGroupQuadsNV "
  86. "or DerivativeGroupLinearNV execution mode for "
  87. "GLCompute execution model: ") +
  88. spvOpcodeString(opcode);
  89. }
  90. return false;
  91. }
  92. return true;
  93. });
  94. break;
  95. }
  96. default:
  97. break;
  98. }
  99. return SPV_SUCCESS;
  100. }
  101. } // namespace val
  102. } // namespace spvtools