Training.h 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #pragma once
  9. #include <AzCore/Math/VectorN.h>
  10. #include <AzCore/Jobs/JobManager.h>
  11. #include <AzCore/Jobs/JobContext.h>
  12. #include <AzCore/Threading/ThreadSafeDeque.h>
  13. #include <MachineLearning/INeuralNetwork.h>
  14. #include <Assets/TrainingDataView.h>
  15. namespace MachineLearning
  16. {
  17. //! Performs a supervised learning training cycle.
  18. //! Supervised learning is a form of machine learning where a model is provided a set of training data with expected output
  19. //! Training then takes place in an iterative loop where the total error (cost, loss) of the model is minimized
  20. //! This differs from unsupervised learning, where the training data lacks any form of labeling (expected correct output), and
  21. //! the model is expected to learn the underlying structures of data on its own.
  22. class SupervisedLearningCycle
  23. {
  24. public:
  25. SupervisedLearningCycle();
  26. SupervisedLearningCycle
  27. (
  28. INeuralNetworkPtr model,
  29. ILabeledTrainingDataPtr trainingData,
  30. ILabeledTrainingDataPtr testData,
  31. LossFunctions costFunction,
  32. AZStd::size_t totalIterations,
  33. AZStd::size_t batchSize,
  34. float learningRate,
  35. float learningRateDecay,
  36. float earlyStopCost
  37. );
  38. void InitializeContexts();
  39. void StartTraining();
  40. void StopTraining();
  41. AZStd::atomic<AZStd::size_t> m_currentEpoch = 0;
  42. std::atomic<bool> m_trainingComplete = true;
  43. AZ::ThreadSafeDeque<float> m_testCosts;
  44. AZ::ThreadSafeDeque<float> m_trainCosts;
  45. INeuralNetworkPtr m_model;
  46. bool m_shuffleTrainingData = true;
  47. TrainingDataView m_trainData;
  48. TrainingDataView m_testData;
  49. LossFunctions m_costFunction = LossFunctions::MeanSquaredError;
  50. AZStd::size_t m_totalIterations = 0;
  51. AZStd::size_t m_batchSize = 0;
  52. float m_learningRate = 0.0f;
  53. float m_learningRateDecay = 0.0f;
  54. float m_earlyStopCost = 0.0f;
  55. AZStd::size_t m_currentIndex = 0;
  56. AZStd::unique_ptr<IInferenceContext> m_inferenceContext;
  57. AZStd::unique_ptr<ITrainingContext> m_trainingContext;
  58. private:
  59. //! Calculates the average cost of the provided model on the set of labeled test data using the requested loss function.
  60. float ComputeCurrentCost(ILabeledTrainingData& testData, LossFunctions costFunction);
  61. void ExecTraining();
  62. AZStd::unique_ptr<AZ::JobManager> m_trainingJobManager;
  63. AZStd::unique_ptr<AZ::JobContext> m_trainingjobContext;
  64. //! Guards model state.
  65. mutable AZStd::recursive_mutex m_mutex;
  66. };
  67. }