Training.h 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #pragma once
  9. #include <AzCore/Math/VectorN.h>
  10. #include <AzCore/Jobs/JobManager.h>
  11. #include <AzCore/Jobs/JobContext.h>
  12. #include <MachineLearning/INeuralNetwork.h>
  13. #include <MachineLearning/ILabeledTrainingData.h>
  14. namespace MachineLearning
  15. {
  16. //! Performs a supervised learning training cycle.
  17. //! Supervised learning is a form of machine learning where a model is provided a set of training data with expected output
  18. //! Training then takes place in an iterative loop where the total error (cost, loss) of the model is minimized
  19. //! This differs from unsupervised learning, where the training data lacks any form of labeling (expected correct output), and
  20. //! the model is expected to learn the underlying structures of data on its own.
  21. class SupervisedLearningCycle
  22. {
  23. public:
  24. SupervisedLearningCycle();
  25. SupervisedLearningCycle
  26. (
  27. INeuralNetworkPtr model,
  28. ILabeledTrainingDataPtr trainingData,
  29. ILabeledTrainingDataPtr testData,
  30. LossFunctions costFunction,
  31. AZStd::size_t totalIterations,
  32. AZStd::size_t batchSize,
  33. float learningRate,
  34. float learningRateDecay,
  35. float earlyStopCost
  36. );
  37. void InitializeContexts();
  38. void StartTraining();
  39. void StopTraining();
  40. //! Calculates the average cost of the provided model on the set of labeled test data using the requested loss function.
  41. float ComputeCurrentCost(ILabeledTrainingDataPtr TestData, LossFunctions CostFunction, AZStd::size_t maxSamples = 0);
  42. AZStd::atomic<AZStd::size_t> m_currentEpoch = 0;
  43. std::atomic<bool> m_trainingComplete = true;
  44. //private:
  45. void ExecTraining();
  46. INeuralNetworkPtr m_model;
  47. ILabeledTrainingDataPtr m_trainingData;
  48. ILabeledTrainingDataPtr m_testData;
  49. LossFunctions m_costFunction = LossFunctions::MeanSquaredError;
  50. AZStd::size_t m_totalIterations = 0;
  51. AZStd::size_t m_batchSize = 0;
  52. float m_learningRate = 0.0f;
  53. float m_learningRateDecay = 0.0f;
  54. float m_earlyStopCost = 0.0f;
  55. AZStd::vector<AZStd::size_t> m_indices;
  56. AZStd::size_t m_currentIndex = 0;
  57. AZStd::unique_ptr<IInferenceContext> m_inferenceContext;
  58. AZStd::unique_ptr<ITrainingContext> m_trainingContext;
  59. AZStd::unique_ptr<AZ::JobManager> m_trainingJobManager;
  60. AZStd::unique_ptr<AZ::JobContext> m_trainingjobContext;
  61. //! Guards model state.
  62. mutable AZStd::recursive_mutex m_mutex;
  63. };
  64. }