LayerTests.cpp 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <AzTest/AzTest.h>
  9. #include <AzCore/UnitTest/TestTypes.h>
  10. #include <Models/Layer.h>
  11. namespace UnitTest
  12. {
  13. class MachineLearning_Layers
  14. : public UnitTest::LeakDetectionFixture
  15. {
  16. };
  17. TEST_F(MachineLearning_Layers, TestConstructor)
  18. {
  19. // Construct a layer that takes 8 inputs and generates 4 outputs
  20. MachineLearning::Layer testLayer(MachineLearning::ActivationFunctions::Linear, 8, 4);
  21. EXPECT_EQ(testLayer.m_inputSize, 8);
  22. EXPECT_EQ(testLayer.m_outputSize, 4);
  23. EXPECT_EQ(testLayer.m_weights.GetColumnCount(), 8);
  24. EXPECT_EQ(testLayer.m_weights.GetRowCount(), 4);
  25. EXPECT_EQ(testLayer.m_biases.GetDimensionality(), 4);
  26. }
  27. TEST_F(MachineLearning_Layers, TestForward)
  28. {
  29. // Construct a layer that takes 8 inputs and generates 4 outputs
  30. MachineLearning::Layer testLayer(MachineLearning::ActivationFunctions::Linear, 8, 4);
  31. MachineLearning::LayerInferenceData inferenceData;
  32. testLayer.m_biases = AZ::VectorN::CreateOne(testLayer.m_biases.GetDimensionality());
  33. testLayer.m_weights = AZ::MatrixMxN::CreateZero(testLayer.m_weights.GetRowCount(), testLayer.m_weights.GetColumnCount());
  34. testLayer.m_weights += 1.0f;
  35. const AZ::VectorN ones = AZ::VectorN::CreateOne(8); // Input of all ones
  36. testLayer.Forward(inferenceData, ones);
  37. for (AZStd::size_t iter = 0; iter < inferenceData.m_output.GetDimensionality(); ++iter)
  38. {
  39. ASSERT_FLOAT_EQ(inferenceData.m_output.GetElement(iter), 9.0f); // 8 edges of 1's + 1 for the bias
  40. }
  41. const AZ::VectorN zeros = AZ::VectorN::CreateZero(8); // Input of all zeros
  42. testLayer.Forward(inferenceData, zeros);
  43. for (AZStd::size_t iter = 0; iter < inferenceData.m_output.GetDimensionality(); ++iter)
  44. {
  45. ASSERT_FLOAT_EQ(inferenceData.m_output.GetElement(iter), 1.0f); // Weights are all zero, leaving only the layer biases which are all set to 1
  46. }
  47. }
  48. }