SpvPatternTest.cpp 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. // Copyright (C) 2025 NVIDIA Corporation
  2. //
  3. // Permission is hereby granted, free of charge, to any person obtaining a copy
  4. // of this software and associated documentation files (the "Software"), to deal
  5. // in the Software without restriction, including without limitation the rights
  6. // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  7. // copies of the Software, and to permit persons to whom the Software is
  8. // furnished to do so, subject to the following conditions:
  9. //
  10. // The above copyright notice and this permission notice shall be included in all
  11. // copies or substantial portions of the Software.
  12. //
  13. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  14. // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  15. // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  16. // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  17. // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  18. // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  19. // SOFTWARE.
  20. #include "TestFixture.h"
  21. #include "glslang/Public/ResourceLimits.h"
  22. #include <gtest/gtest.h>
  23. #include <regex>
  24. #include <sstream>
  25. #include <string>
  26. namespace glslangtest {
  27. class SpvPatternTest : public ::testing::Test {
  28. protected:
  29. void SetUp() override
  30. {
  31. // Set up any common test state.
  32. }
  33. void TearDown() override
  34. {
  35. // Clean up any common test state.
  36. }
  37. // Helper function to compile shader and get SPIR-V disassembly.
  38. std::string compileShaderToSpirv(const std::string& shaderSource, EShLanguage stage)
  39. {
  40. glslang::TShader shader(stage);
  41. glslang::TProgram program;
  42. // Compile the shader
  43. const char* shaderStrings = shaderSource.c_str();
  44. shader.setStrings(&shaderStrings, 1);
  45. if (!shader.parse(GetDefaultResources(), 450, false, EShMsgDefault)) {
  46. return "COMPILATION_FAILED: " + std::string(shader.getInfoLog());
  47. }
  48. program.addShader(&shader);
  49. if (!program.link(EShMsgDefault)) {
  50. return "LINKING_FAILED: " + std::string(program.getInfoLog());
  51. }
  52. // Generate SPIR-V.
  53. std::vector<uint32_t> spirv;
  54. glslang::GlslangToSpv(*program.getIntermediate(stage), spirv);
  55. // Disassemble SPIR-V to text.
  56. std::ostringstream disassembly_stream;
  57. spv::Disassemble(disassembly_stream, spirv);
  58. return disassembly_stream.str();
  59. }
  60. // Helper function to check if the given SPIR-V string contains a specific pattern.
  61. bool containsPattern(const std::string& spirvText, const std::string& pattern)
  62. {
  63. return spirvText.find(pattern) != std::string::npos;
  64. }
  65. // Helper function to check if the given SPIR-V string contains a UConvert instruction.
  66. bool containsUConvert(const std::string& spirvText) { return containsPattern(spirvText, "UConvert"); }
  67. };
  68. // Test 1: Indexing an array with a regular int or uint should not generate a zero extension.
  69. TEST_F(SpvPatternTest, RegularIntUintArrayIndexNoConversion)
  70. {
  71. const std::string shaderSource = R"(
  72. #version 450 core
  73. layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
  74. void main() {
  75. uint u = 150u;
  76. int i = 100;
  77. float arr[200];
  78. float x = arr[u]; // Regular uint index
  79. float y = arr[i]; // Regular int index
  80. }
  81. )";
  82. std::string spirv = compileShaderToSpirv(shaderSource, EShLangCompute);
  83. // Check that the SPIR-V does NOT contain conversion instructions for regular int/uint indices.
  84. EXPECT_FALSE(containsUConvert(spirv))
  85. << "SPIR-V should not contain OpUConvert instruction for regular int/uint array indexing.\n"
  86. << "Generated SPIR-V:\n"
  87. << spirv;
  88. }
  89. // Test 2: Indexing an array with a variable index of type uint8_t should generate a zero extension.
  90. TEST_F(SpvPatternTest, Uint8VariableIndexGeneratesUConvert)
  91. {
  92. const std::string shaderSource = R"(
  93. #version 450 core
  94. #extension GL_EXT_shader_explicit_arithmetic_types : enable
  95. layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
  96. void main() {
  97. uint8_t u8 = uint8_t(150);
  98. float arr[200];
  99. float x = arr[u8]; // Variable uint8_t index
  100. }
  101. )";
  102. std::string spirv = compileShaderToSpirv(shaderSource, EShLangCompute);
  103. // Check that the SPIR-V contains OpUConvert instruction for variable uint8_t index.
  104. EXPECT_TRUE(containsUConvert(spirv))
  105. << "SPIR-V should contain OpUConvert instruction for variable uint8_t array indexing.\n"
  106. << "Generated SPIR-V:\n"
  107. << spirv;
  108. }
  109. // Test 2: Indexing an array with a variable index of type uint16_t should generate a zero extension.
  110. TEST_F(SpvPatternTest, Uint16VariableIndexGeneratesUConvert)
  111. {
  112. const std::string shaderSource = R"(
  113. #version 450 core
  114. #extension GL_EXT_shader_explicit_arithmetic_types : enable
  115. layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
  116. void main() {
  117. uint16_t u16 = uint16_t(150);
  118. float arr[200];
  119. float x = arr[u16]; // Variable uint16_t index
  120. }
  121. )";
  122. std::string spirv = compileShaderToSpirv(shaderSource, EShLangCompute);
  123. // Check that the SPIR-V contains OpUConvert instruction for variable uint16_t index.
  124. EXPECT_TRUE(containsUConvert(spirv))
  125. << "SPIR-V should contain OpUConvert instruction for variable uint16_t array indexing.\n"
  126. << "Generated SPIR-V:\n"
  127. << spirv;
  128. }
  129. // Test 3: Indexing an array with a constant index of type uint8_t should NOT generate a zero extension.
  130. // Glslang generates small constants as regular 32-bit integers.
  131. TEST_F(SpvPatternTest, Uint8ConstantIndexNoConversion)
  132. {
  133. const std::string shaderSource = R"(
  134. #version 450 core
  135. #extension GL_EXT_shader_explicit_arithmetic_types : enable
  136. layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
  137. void main() {
  138. float arr[200];
  139. float x = arr[uint8_t(150)]; // Constant uint8_t index
  140. }
  141. )";
  142. std::string spirv = compileShaderToSpirv(shaderSource, EShLangCompute);
  143. // Check that the SPIR-V does NOT contain OpUConvert instruction for constant uint8_t index.
  144. // Glslang generates small constants as regular 32-bit integers, so no conversion is needed.
  145. EXPECT_FALSE(containsUConvert(spirv))
  146. << "SPIR-V should not contain OpUConvert instruction for constant uint8_t array indexing.\n"
  147. << "Generated SPIR-V:\n"
  148. << spirv;
  149. }
  150. // Test 3: Indexing an array with a constant index of type uint16_t should NOT generate a zero extension.
  151. // (Glslang generates small constants as regular 32-bit integers.)
  152. TEST_F(SpvPatternTest, Uint16ConstantIndexNoConversion)
  153. {
  154. const std::string shaderSource = R"(
  155. #version 450 core
  156. #extension GL_EXT_shader_explicit_arithmetic_types : enable
  157. layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
  158. void main() {
  159. float arr[200];
  160. float x = arr[uint16_t(150)]; // Constant uint16_t index
  161. }
  162. )";
  163. std::string spirv = compileShaderToSpirv(shaderSource, EShLangCompute);
  164. // Check that the SPIR-V does NOT contain OpUConvert instruction for constant uint16_t index.
  165. // Glslang generates small constants as regular 32-bit integers, so no conversion is needed.
  166. EXPECT_FALSE(containsUConvert(spirv))
  167. << "SPIR-V should not contain OpUConvert instruction for constant uint16_t array indexing.\n"
  168. << "Generated SPIR-V:\n"
  169. << spirv;
  170. }
  171. } // namespace glslangtest