ShaderResourceGroupConstantBufferTests.cpp 16 KB


  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
#include <AzTest/AzTest.h>
#include <Common/RPITestFixture.h>
#include <Common/ShaderAssetTestUtils.h>
#include <Atom/RPI.Public/Shader/ShaderResourceGroup.h>

#include <cstring>
  12. namespace UnitTest
  13. {
  14. class ShaderResourceGroupConstantBufferTests
  15. : public RPITestFixture
  16. {
  17. protected:
  18. struct SimpleStruct
  19. {
  20. SimpleStruct() = default;
  21. SimpleStruct(float f, uint32_t u)
  22. : m_float{f}
  23. , m_uint{u}
  24. {}
  25. float m_float = 0;
  26. uint32_t m_uint = 0;
  27. };
  28. AZ::Data::Asset<AZ::RPI::ShaderAsset> m_shaderAsset;
  29. AZ::RHI::Ptr<AZ::RHI::ShaderResourceGroupLayout> m_srgLayout;
  30. AZ::Data::Instance<AZ::RPI::ShaderResourceGroup> m_srg;
  31. void SetUp() override
  32. {
  33. RPITestFixture::SetUp();
  34. // This provides the high-level metadata and low-level srg layout
  35. m_srgLayout = BuildSrgLayoutWithShaderConstants(m_shaderAsset);
  36. ASSERT_TRUE(m_srgLayout);
  37. ASSERT_TRUE(m_shaderAsset.IsReady());
  38. m_srg = AZ::RPI::ShaderResourceGroup::Create(m_shaderAsset, AZ::RPI::DefaultSupervariantIndex, m_srgLayout->GetName());
  39. ASSERT_TRUE(m_srg != nullptr);
  40. }
  41. void TearDown() override
  42. {
  43. m_srg.reset();
  44. m_srgLayout = nullptr;
  45. m_shaderAsset.Release();
  46. RPITestFixture::TearDown();
  47. }
  48. template<typename T>
  49. void ExpectEqual(AZStd::initializer_list<T> expectedValues, AZStd::span<const T> arrayView)
  50. {
  51. EXPECT_EQ(expectedValues.size(), arrayView.size());
  52. const T* expected = expectedValues.begin();
  53. for (int i = 0; i < expectedValues.size() && i < arrayView.size(); ++i)
  54. {
  55. EXPECT_EQ(expected[i], arrayView[i]);
  56. }
  57. }
  58. AZ::RHI::Ptr<AZ::RHI::ShaderResourceGroupLayout> BuildSrgLayoutWithShaderConstants(
  59. AZ::Data::Asset<AZ::RPI::ShaderAsset>& shaderAsset, [[maybe_unused]] bool includeMetadata = true)
  60. {
  61. using namespace AZ;
  62. AZ::RHI::Ptr<AZ::RHI::ShaderResourceGroupLayout> srgLayout = RHI::ShaderResourceGroupLayout::Create();
  63. srgLayout->SetName(Name{"TestSrg"});
  64. uint32_t offset = 0;
  65. uint32_t count;
  66. uint32_t size;
  67. uint32_t registerIndex = 0;
  68. uint32_t spaceIndex = 0;
  69. uint32_t sizeOfBool = 4;
  70. srgLayout->SetBindingSlot(0);
  71. // bool, binding index 0
  72. count = 1;
  73. size = count * sizeOfBool;
  74. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyBool"), offset, size, registerIndex, spaceIndex});
  75. offset += size;
  76. // bool2, binding index 1
  77. count = 2;
  78. size = count * sizeOfBool;
  79. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyBool2"), offset, size, registerIndex, spaceIndex});
  80. offset += size;
  81. // bool3, binding index 2
  82. count = 3;
  83. size = count * sizeOfBool;
  84. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyBool3"), offset, size, registerIndex, spaceIndex});
  85. offset += size;
  86. // bool4, binding index 3
  87. count = 4;
  88. size = count * sizeOfBool;
  89. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyBool4"), offset, size, registerIndex, spaceIndex});
  90. offset += size;
  91. // int, binding index 4
  92. count = 1;
  93. size = count * sizeof(int32_t);
  94. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyInt"), offset, size, registerIndex, spaceIndex});
  95. offset += size;
  96. // int2, binding index 5
  97. count = 2;
  98. size = count * sizeof(int32_t);
  99. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyInt2"), offset, size, registerIndex, spaceIndex});
  100. offset += size;
  101. // int3, binding index 6
  102. count = 3;
  103. size = count * sizeof(int32_t);
  104. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyInt3"), offset, size, registerIndex, spaceIndex});
  105. offset += size;
  106. // int4, binding index 7
  107. count = 4;
  108. size = count * sizeof(int32_t);
  109. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyInt4"), offset, size, registerIndex, spaceIndex});
  110. offset += size;
  111. // uint, binding index 8
  112. count = 1;
  113. size = count * sizeof(uint32_t);
  114. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyUint"), offset, size, registerIndex, spaceIndex});
  115. offset += size;
  116. // uint2, binding index 9
  117. count = 2;
  118. size = count * sizeof(uint32_t);
  119. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyUint2"), offset, size, registerIndex, spaceIndex});
  120. offset += size;
  121. // uint3, binding index 10
  122. count = 3;
  123. size = count * sizeof(uint32_t);
  124. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyUint3"), offset, size, registerIndex, spaceIndex});
  125. offset += size;
  126. // uint4, binding index 11
  127. count = 4;
  128. size = count * sizeof(uint32_t);
  129. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyUint4"), offset, size, registerIndex, spaceIndex});
  130. offset += size;
  131. // float, binding index 12
  132. count = 1;
  133. size = count * sizeof(float);
  134. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyFloat"), offset, size, registerIndex, spaceIndex});
  135. offset += size;
  136. // float2, binding index 13
  137. count = 2;
  138. size = count * sizeof(float);
  139. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyFloat2"), offset, size, registerIndex, spaceIndex});
  140. offset += size;
  141. // float3, binding index 14
  142. count = 3;
  143. size = count * sizeof(float);
  144. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyFloat3"), offset, size, registerIndex, spaceIndex});
  145. offset += size;
  146. // float4, binding index 15
  147. count = 4;
  148. size = count * sizeof(float);
  149. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MyFloat4"), offset, size, registerIndex, spaceIndex});
  150. offset += size;
  151. // simple struct, binding index 16
  152. // [GFX TODO][ATOM-111] This is not very fleshed out right now. We still need to do more to support structs, but at least I want to verify that SRG templatized setters and getters can work with structs
  153. count = 1;
  154. size = 8;
  155. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MySimpleStruct"), offset, size, registerIndex, spaceIndex});
  156. offset += size;
  157. // array of 2 simple structs, binding index 17
  158. // [GFX TODO][ATOM-111] This is not very fleshed out right now. We still need to do more to support structs, but at least I want to verify that SRG templatized setters and getters can work with structs
  159. count = 2;
  160. size = 16;
  161. srgLayout->AddShaderInput(RHI::ShaderInputConstantDescriptor{Name("MySimpleStructArray2"), offset, size, registerIndex, spaceIndex});
  162. offset += size;
  163. srgLayout->SetBindingSlot(0);
  164. EXPECT_TRUE(srgLayout->Finalize());
  165. shaderAsset = CreateTestShaderAsset(Uuid::CreateRandom(), srgLayout);
  166. return srgLayout;
  167. }
  168. };
  169. TEST_F(ShaderResourceGroupConstantBufferTests, SetConstant_GetConstant_ValidInput_Bool)
  170. {
  171. using namespace AZ;
  172. {
  173. const RHI::ShaderInputConstantIndex inputIndex(0);
  174. // Check using inputIndex
  175. EXPECT_TRUE(m_srg->SetConstant(inputIndex, true));
  176. EXPECT_EQ(true, m_srg->GetConstant<bool>(inputIndex));
  177. AZStd::span<const uint8_t> result = m_srg->GetConstantRaw(inputIndex);
  178. AZStd::span<const uint32_t> resultInUint = AZStd::span<const uint32_t>(reinterpret_cast<const uint32_t*>(result.data()), 1);
  179. ExpectEqual<uint32_t>({ 1 /*true*/ }, resultInUint);
  180. EXPECT_TRUE(m_srg->SetConstant(inputIndex, false));
  181. EXPECT_EQ(false, m_srg->GetConstant<bool>(inputIndex));
  182. result = m_srg->GetConstantRaw(inputIndex);
  183. resultInUint = AZStd::span<const uint32_t>(reinterpret_cast<const uint32_t*>(result.data()), 1);
  184. ExpectEqual<uint32_t>({ 0 /*false*/ }, resultInUint);
  185. }
  186. {
  187. const RHI::ShaderInputConstantIndex inputIndex(1);
  188. // Check using inputIndex
  189. EXPECT_TRUE(m_srg->SetConstantArray<bool>(inputIndex, AZStd::array<bool, 2>({ true, false })));
  190. AZStd::span<const uint8_t> result = m_srg->GetConstantRaw(inputIndex);
  191. AZStd::span<const uint32_t> resultInUint = AZStd::span<const uint32_t>(reinterpret_cast<const uint32_t*>(result.data()), 2);
  192. ExpectEqual<uint32_t>({ 1 /*true*/, 0 /*false*/ }, resultInUint);
  193. EXPECT_TRUE(m_srg->SetConstantArray<bool>(inputIndex, AZStd::array<bool, 2>({ false, true })));
  194. result = m_srg->GetConstantRaw(inputIndex);
  195. resultInUint = AZStd::span<const uint32_t>(reinterpret_cast<const uint32_t*>(result.data()), 2);
  196. ExpectEqual<uint32_t>({ 0 /*false*/, 1 /*true*/ }, resultInUint);
  197. }
  198. }
  199. TEST_F(ShaderResourceGroupConstantBufferTests, SetConstant_GetConstant_FalsePackedInGarbage_Bool)
  200. {
  201. using namespace AZ;
  202. uint32_t falsePackedInGarbage = 0xab00cdef;
  203. bool* asBools = reinterpret_cast<bool*>(&falsePackedInGarbage);
  204. {
  205. const RHI::ShaderInputConstantIndex inputIndex(0);
  206. // Check using inputIndex
  207. EXPECT_TRUE(m_srg->SetConstant<bool>(inputIndex, asBools[2]));
  208. EXPECT_EQ(false, m_srg->GetConstant<bool>(inputIndex));
  209. }
  210. {
  211. // Check using inputIndex
  212. const RHI::ShaderInputConstantIndex inputIndex(1);
  213. EXPECT_TRUE(m_srg->SetConstantArray<bool>(inputIndex, AZStd::array<bool, 2>({ asBools[1], asBools[2] })));
  214. AZStd::span<const uint8_t> result = m_srg->GetConstantRaw(inputIndex);
  215. AZStd::span<const uint32_t> resultInUint = AZStd::span<const uint32_t>(reinterpret_cast<const uint32_t*>(result.data()), 2);
  216. EXPECT_THAT(resultInUint, testing::ElementsAre(testing::IsTrue(), testing::IsFalse()));
  217. }
  218. }
  219. //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  220. // Test valid inputs for SetConstant and GetConstant
  221. TEST_F(ShaderResourceGroupConstantBufferTests, SetConstant_GetConstant_ValidInput_Int)
  222. {
  223. using namespace AZ;
  224. {
  225. const RHI::ShaderInputConstantIndex inputIndex(4);
  226. // Check using inputIndex
  227. EXPECT_TRUE(m_srg->SetConstant(inputIndex, 51));
  228. EXPECT_EQ(51, m_srg->GetConstant<int32_t>(inputIndex));
  229. ExpectEqual<int32_t>({ 51 }, m_srg->GetConstantArray<int32_t>(inputIndex));
  230. }
  231. {
  232. const RHI::ShaderInputConstantIndex inputIndex(5);
  233. // Check using inputIndex
  234. EXPECT_TRUE(m_srg->SetConstantArray<int32_t>(inputIndex, AZStd::array<int32_t, 2>({ 54, 55 })));
  235. ExpectEqual<int32_t>({ 54, 55 }, m_srg->GetConstantArray<int32_t>(inputIndex));
  236. }
  237. }
  238. TEST_F(ShaderResourceGroupConstantBufferTests, SetConstant_GetConstant_ValidInput_Float)
  239. {
  240. using namespace AZ;
  241. {
  242. const RHI::ShaderInputConstantIndex inputIndex(12);
  243. // Check using inputIndex
  244. EXPECT_TRUE(m_srg->SetConstant(inputIndex, 1.1f));
  245. EXPECT_EQ(1.1f, m_srg->GetConstant<float>(inputIndex));
  246. ExpectEqual<float>({ 1.1f }, m_srg->GetConstantArray<float>(inputIndex));
  247. }
  248. {
  249. const RHI::ShaderInputConstantIndex inputIndex(13);
  250. // Check using inputIndex
  251. EXPECT_TRUE(m_srg->SetConstantArray<float>(inputIndex, AZStd::array<float, 2>({ 1.4f, 1.5f })));
  252. ExpectEqual<float>({ 1.4f, 1.5f }, m_srg->GetConstantArray<float>(inputIndex));
  253. }
  254. }
  255. TEST_F(ShaderResourceGroupConstantBufferTests, SetConstant_GetConstant_ValidInput_Vector4)
  256. {
  257. using namespace AZ;
  258. AZ::Vector4 value;
  259. const RHI::ShaderInputConstantIndex inputIndex(15);
  260. // Check using inputIndex
  261. EXPECT_TRUE(m_srg->SetConstant(inputIndex, AZ::Vector4(2.6f, 2.7f, 2.8f, 2.9f)));
  262. value = m_srg->GetConstant<AZ::Vector4>(inputIndex);
  263. EXPECT_EQ(2.6f, static_cast<float>(value.GetX()));
  264. EXPECT_EQ(2.7f, static_cast<float>(value.GetY()));
  265. EXPECT_EQ(2.8f, static_cast<float>(value.GetZ()));
  266. EXPECT_EQ(2.9f, static_cast<float>(value.GetW()));
  267. }
  268. TEST_F(ShaderResourceGroupConstantBufferTests, SetConstant_GetConstant_ValidInput_SimpleStruct)
  269. {
  270. using namespace AZ;
  271. SimpleStruct value;
  272. const RHI::ShaderInputConstantIndex inputIndex(16);
  273. // Demonstrate the syntax of setting with a variable, and inputIndex
  274. {
  275. SimpleStruct inputValues = { 2.1f, 101 };
  276. EXPECT_TRUE(m_srg->SetConstant(inputIndex, inputValues));
  277. value = m_srg->GetConstant<SimpleStruct>(inputIndex);
  278. EXPECT_EQ(2.1f, value.m_float);
  279. EXPECT_EQ(101, value.m_uint);
  280. }
  281. }
  282. TEST_F(ShaderResourceGroupConstantBufferTests, SetConstant_GetConstant_ValidInput_SimpleStruct_Array)
  283. {
  284. using namespace AZ;
  285. AZStd::span<const SimpleStruct> values;
  286. const RHI::ShaderInputConstantIndex inputIndex(17);
  287. // Demonstrate the syntax of setting with a variable, and inputIndex...
  288. // Unfortunately, with arrays of custom types, you have to specify the element type explicitly
  289. {
  290. AZStd::vector<SimpleStruct> inputValues;
  291. inputValues.push_back({ 0.3f, 3 });
  292. inputValues.push_back({ 0.4f, 4 });
  293. EXPECT_TRUE(m_srg->SetConstantArray<SimpleStruct>(inputIndex, inputValues));
  294. values = m_srg->GetConstantArray<SimpleStruct>(inputIndex);
  295. EXPECT_EQ(2, values.size());
  296. EXPECT_EQ(0.3f, values[0].m_float);
  297. EXPECT_EQ(3, values[0].m_uint);
  298. EXPECT_EQ(0.4f, values[1].m_float);
  299. EXPECT_EQ(4, values[1].m_uint);
  300. }
  301. }
  302. TEST_F(ShaderResourceGroupConstantBufferTests, TestErrorReporting_SetConstant_WrongNumberOfElements_ArrayInput)
  303. {
  304. using namespace AZ;
  305. {
  306. AZ_TEST_START_ASSERTTEST;
  307. // MyFloat2
  308. EXPECT_FALSE(m_srg->SetConstantArray<float>(RHI::ShaderInputConstantIndex(13), AZStd::array<float, 3>({ 0.1f, 0.2f, 0.3f })));
  309. AZ_TEST_STOP_ASSERTTEST(1);
  310. }
  311. }
  312. TEST_F(ShaderResourceGroupConstantBufferTests, TestErrorReporting_GetConstants_WrongNumberOfElements_ArrayOutput)
  313. {
  314. using namespace AZ;
  315. {
  316. AZ_TEST_START_ASSERTTEST;
  317. // MyFloat2
  318. m_srg->GetConstantArray<AZ::Vector4>(RHI::ShaderInputConstantIndex(13));
  319. AZ_TEST_STOP_ASSERTTEST(1);
  320. }
  321. }
  322. TEST_F(ShaderResourceGroupConstantBufferTests, TestErrorReporting_SetConstant_WrongNumberOfElements_SingleInput)
  323. {
  324. using namespace AZ;
  325. {
  326. AZ_TEST_START_ASSERTTEST;
  327. // MyBool2
  328. EXPECT_FALSE(m_srg->SetConstant<bool>(RHI::ShaderInputConstantIndex(1), false));
  329. AZ_TEST_STOP_ASSERTTEST(1);
  330. }
  331. }
  332. TEST_F(ShaderResourceGroupConstantBufferTests, TestErrorReporting_GetConstant_WrongNumberOfElements_SingleOutput)
  333. {
  334. using namespace AZ;
  335. {
  336. AZ_TEST_START_ASSERTTEST;
  337. // MyBool3
  338. EXPECT_FALSE(m_srg->GetConstant<bool>(RHI::ShaderInputConstantIndex(2)));
  339. AZ_TEST_STOP_ASSERTTEST(1);
  340. }
  341. }
  342. }