LossFunctions.h 1.5 KB

123456789101112131415161718192021222324252627282930
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #pragma once
  9. #include <AzCore/Math/VectorN.h>
  10. #include <MachineLearning/INeuralNetwork.h>
  11. namespace MachineLearning
  12. {
  13. //! This is a useful helper that simply computes the total cost provided a loss function, and expected and actual outputs.
  14. float ComputeTotalCost(LossFunctions lossFunction, const AZ::VectorN& expected, const AZ::VectorN& actual);
  15. //! Computes the gradient of the loss using across all elements of the source vectors using the requested cost function.
  16. void ComputeLoss(LossFunctions lossFunction, const AZ::VectorN& expected, const AZ::VectorN& actual, AZ::VectorN& output);
  17. //! Computes the derivative of the rectified linear unit function (ReLU) applied to all elements of the source vector.
  18. void MeanSquaredError(const AZ::VectorN& expected, const AZ::VectorN& actual, AZ::VectorN& output);
  19. //! Computes the gradient of the loss using across all elements of the source vectors using the requested cost function.
  20. void ComputeLoss_Derivative(LossFunctions lossFunction, const AZ::VectorN& expected, const AZ::VectorN& actual, AZ::VectorN& output);
  21. //! Computes the derivative of the rectified linear unit function (ReLU) applied to all elements of the source vector.
  22. void MeanSquaredError_Derivative(const AZ::VectorN& expected, const AZ::VectorN& actual, AZ::VectorN& output);
  23. }