LossFunctionTests.cpp 826 B

12345678910111213141516171819202122232425262728
/*
 * Copyright (c) Contributors to the Open 3D Engine Project.
 * For complete copyright and license terms please see the LICENSE at the root of this distribution.
 *
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 */
  8. #include <AzTest/AzTest.h>
  9. #include <AzCore/UnitTest/TestTypes.h>
  10. #include <Algorithms/LossFunctions.h>
  11. namespace UnitTest
  12. {
  13. class MachineLearning_LossFunctions
  14. : public UnitTest::LeakDetectionFixture
  15. {
  16. };
  17. TEST_F(MachineLearning_LossFunctions, TestMeanSquaredError)
  18. {
  19. AZ::VectorN expected = AZ::VectorN::CreateZero(1024);
  20. AZ::VectorN actual = AZ::VectorN::CreateOne(1024);
  21. const float totalLoss1 = MachineLearning::ComputeTotalCost(MachineLearning::LossFunctions::MeanSquaredError, expected, actual);
  22. EXPECT_EQ(totalLoss1, 1024.0f);
  23. }
  24. }