LossFunctions.cpp 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Algorithms/LossFunctions.h>
  9. #include <AzCore/Math/SimdMath.h>
  10. namespace MachineLearning
  11. {
  12. float ComputeTotalCost(LossFunctions lossFunction, const AZ::VectorN& expected, const AZ::VectorN& actual)
  13. {
  14. AZ::VectorN costs;
  15. ComputeLoss(lossFunction, expected, actual, costs);
  16. AZ::Vector4 accumulator = AZ::Vector4::CreateZero();
  17. for (const AZ::Vector4& element : costs.GetVectorValues())
  18. {
  19. accumulator += element;
  20. }
  21. return accumulator.Dot(AZ::Vector4::CreateOne());
  22. }
  23. void ComputeLoss(LossFunctions costFunction, const AZ::VectorN& expected, const AZ::VectorN& actual, AZ::VectorN& output)
  24. {
  25. AZ_Assert(expected.GetDimensionality() == actual.GetDimensionality(), "The dimensionality of expected and actual must match");
  26. output.Resize(actual.GetDimensionality());
  27. switch (costFunction)
  28. {
  29. case LossFunctions::MeanSquaredError:
  30. MeanSquaredError(expected, actual, output);
  31. break;
  32. }
  33. }
  34. void MeanSquaredError(const AZ::VectorN& expected, const AZ::VectorN& actual, AZ::VectorN& output)
  35. {
  36. output = (actual - expected).GetSquare();
  37. }
  38. void ComputeLoss_Derivative(LossFunctions costFunction, const AZ::VectorN& expected, const AZ::VectorN& actual, AZ::VectorN& output)
  39. {
  40. AZ_Assert(expected.GetDimensionality() == actual.GetDimensionality(), "The dimensionality of expected and actual must match");
  41. output.Resize(actual.GetDimensionality());
  42. switch (costFunction)
  43. {
  44. case LossFunctions::MeanSquaredError:
  45. MeanSquaredError_Derivative(expected, actual, output);
  46. break;
  47. }
  48. }
  49. void MeanSquaredError_Derivative(const AZ::VectorN& expected, const AZ::VectorN& actual, AZ::VectorN& output)
  50. {
  51. output = (expected - actual);
  52. }
  53. }