ActivationTests.cpp 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <AzTest/AzTest.h>
  9. #include <AzCore/UnitTest/TestTypes.h>
  10. #include <Algorithms/Activations.h>
  11. namespace UnitTest
  12. {
  13. class MachineLearning_Activations
  14. : public UnitTest::LeakDetectionFixture
  15. {
  16. };
  17. TEST_F(MachineLearning_Activations, OneHotArgMax)
  18. {
  19. AZStd::size_t testValue = 1;
  20. AZ::VectorN testVector;
  21. MachineLearning::OneHotEncode(testValue, 10, testVector);
  22. EXPECT_FLOAT_EQ(testVector.GetElement(0), 0.0f);
  23. EXPECT_FLOAT_EQ(testVector.GetElement(1), 1.0f);
  24. EXPECT_EQ(MachineLearning::ArgMaxDecode(testVector), testValue);
  25. testValue = 3;
  26. MachineLearning::OneHotEncode(testValue, 10, testVector);
  27. EXPECT_EQ(MachineLearning::ArgMaxDecode(testVector), testValue);
  28. testValue = 7;
  29. MachineLearning::OneHotEncode(testValue, 10, testVector);
  30. EXPECT_EQ(MachineLearning::ArgMaxDecode(testVector), testValue);
  31. testValue = 8;
  32. MachineLearning::OneHotEncode(testValue, 10, testVector);
  33. EXPECT_EQ(MachineLearning::ArgMaxDecode(testVector), testValue);
  34. }
  35. TEST_F(MachineLearning_Activations, TestRelu)
  36. {
  37. AZ::VectorN output = AZ::VectorN::CreateZero(1024);
  38. AZ::VectorN sourceVector = AZ::VectorN::CreateRandom(1024);
  39. sourceVector *= 100.0f;
  40. sourceVector -= 50.0f;
  41. MachineLearning::ReLU(sourceVector, output);
  42. for (AZStd::size_t iter = 0; iter < output.GetDimensionality(); ++iter)
  43. {
  44. ASSERT_GE(output.GetElement(iter), 0.0f);
  45. }
  46. }
  47. TEST_F(MachineLearning_Activations, TestSigmoid)
  48. {
  49. AZ::VectorN output = AZ::VectorN::CreateZero(1024);
  50. AZ::VectorN sourceVector = AZ::VectorN::CreateRandom(1024);
  51. sourceVector *= 100.0f;
  52. sourceVector -= 50.0f;
  53. MachineLearning::Sigmoid(sourceVector, output);
  54. // Sigmoid guarantees all outputs get squished between 0 and 1
  55. for (AZStd::size_t iter = 0; iter < output.GetDimensionality(); ++iter)
  56. {
  57. ASSERT_GE(output.GetElement(iter), 0.0f);
  58. ASSERT_LE(output.GetElement(iter), 1.0f);
  59. }
  60. }
  61. TEST_F(MachineLearning_Activations, TestSoftmax)
  62. {
  63. AZ::VectorN output = AZ::VectorN::CreateZero(1024);
  64. AZ::VectorN sourceVector = AZ::VectorN::CreateRandom(1024);
  65. sourceVector *= 100.0f;
  66. sourceVector -= 50.0f;
  67. MachineLearning::Softmax(sourceVector, output);
  68. // Sigmoid guarantees all outputs get squished between 0 and 1
  69. for (AZStd::size_t iter = 0; iter < output.GetDimensionality(); ++iter)
  70. {
  71. ASSERT_GE(output.GetElement(iter), 0.0f);
  72. ASSERT_LE(output.GetElement(iter), 1.0f);
  73. }
  74. // Additionally, the sum of all the elements should be <= 1, as softmax returns a probability distribution
  75. const float totalSum = output.L1Norm();
  76. // Between floating point precision and the estimates we use for exp(x), the total sum probability can be slightly greater than one
  77. // We add a small epsilon to account for this error
  78. ASSERT_GE(totalSum, 1.0f - AZ::Constants::Tolerance);
  79. ASSERT_LE(totalSum, 1.0f + AZ::Constants::Tolerance);
  80. }
  81. TEST_F(MachineLearning_Activations, TestLinear)
  82. {
  83. AZ::VectorN output = AZ::VectorN::CreateZero(1024);
  84. AZ::VectorN sourceVector = AZ::VectorN::CreateRandom(1024);
  85. sourceVector *= 100.0f;
  86. sourceVector -= 50.0f;
  87. MachineLearning::Linear(sourceVector, output);
  88. // Linear just returns the input provided
  89. // This makes it not suitable for anything with more than one layer
  90. for (AZStd::size_t iter = 0; iter < output.GetDimensionality(); ++iter)
  91. {
  92. ASSERT_EQ(output.GetElement(iter), sourceVector.GetElement(iter));
  93. }
  94. }
  95. }