validate_derivatives.cpp 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Validates correctness of derivative SPIR-V instructions.
  15. #include "source/val/validate.h"
  16. #include <string>
  17. #include "source/diagnostic.h"
  18. #include "source/opcode.h"
  19. #include "source/val/instruction.h"
  20. #include "source/val/validation_state.h"
  21. namespace spvtools {
  22. namespace val {
  23. // Validates correctness of derivative instructions.
  24. spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
  25. const SpvOp opcode = inst->opcode();
  26. const uint32_t result_type = inst->type_id();
  27. switch (opcode) {
  28. case SpvOpDPdx:
  29. case SpvOpDPdy:
  30. case SpvOpFwidth:
  31. case SpvOpDPdxFine:
  32. case SpvOpDPdyFine:
  33. case SpvOpFwidthFine:
  34. case SpvOpDPdxCoarse:
  35. case SpvOpDPdyCoarse:
  36. case SpvOpFwidthCoarse: {
  37. if (!_.IsFloatScalarOrVectorType(result_type)) {
  38. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  39. << "Expected Result Type to be float scalar or vector type: "
  40. << spvOpcodeString(opcode);
  41. }
  42. const uint32_t p_type = _.GetOperandTypeId(inst, 2);
  43. if (p_type != result_type) {
  44. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  45. << "Expected P type and Result Type to be the same: "
  46. << spvOpcodeString(opcode);
  47. }
  48. _.function(inst->function()->id())
  49. ->RegisterExecutionModelLimitation([opcode](SpvExecutionModel model,
  50. std::string* message) {
  51. if (model != SpvExecutionModelFragment &&
  52. model != SpvExecutionModelGLCompute) {
  53. if (message) {
  54. *message =
  55. std::string(
  56. "Derivative instructions require Fragment or GLCompute "
  57. "execution model: ") +
  58. spvOpcodeString(opcode);
  59. }
  60. return false;
  61. }
  62. return true;
  63. });
  64. _.function(inst->function()->id())
  65. ->RegisterLimitation([opcode](const ValidationState_t& state,
  66. const Function* entry_point,
  67. std::string* message) {
  68. const auto* models = state.GetExecutionModels(entry_point->id());
  69. const auto* modes = state.GetExecutionModes(entry_point->id());
  70. if (models->find(SpvExecutionModelGLCompute) != models->end() &&
  71. modes->find(SpvExecutionModeDerivativeGroupLinearNV) ==
  72. modes->end() &&
  73. modes->find(SpvExecutionModeDerivativeGroupQuadsNV) ==
  74. modes->end()) {
  75. if (message) {
  76. *message = std::string(
  77. "Derivative instructions require "
  78. "DerivativeGroupQuadsNV "
  79. "or DerivativeGroupLinearNV execution mode for "
  80. "GLCompute execution model: ") +
  81. spvOpcodeString(opcode);
  82. }
  83. return false;
  84. }
  85. return true;
  86. });
  87. break;
  88. }
  89. default:
  90. break;
  91. }
  92. return SPV_SUCCESS;
  93. }
  94. } // namespace val
  95. } // namespace spvtools