SupervisedLearning.cpp 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Nodes/SupervisedLearning.h>
  9. #include <Models/MultilayerPerceptron.h>
  10. #include <MachineLearning/ILabeledTrainingData.h>
  11. #include <Algorithms/Training.h>
  12. #include <AzCore/Console/ILogger.h>
  13. #include <AzCore/std/chrono/chrono.h>
  14. namespace MachineLearning
  15. {
  16. INeuralNetworkPtr SupervisedLearning::In
  17. (
  18. INeuralNetworkPtr Model,
  19. ILabeledTrainingDataPtr TrainingData,
  20. ILabeledTrainingDataPtr TestData,
  21. AZStd::size_t CostFunction,
  22. AZStd::size_t TotalIterations,
  23. AZStd::size_t BatchSize,
  24. float LearningRate,
  25. float LearningRateDecay,
  26. float EarlyStopCost
  27. )
  28. {
  29. SupervisedLearningCycle trainingInstance(Model, TrainingData, TestData, static_cast<LossFunctions>(CostFunction), TotalIterations, BatchSize, LearningRate, LearningRateDecay, EarlyStopCost);
  30. trainingInstance.StartTraining();
  31. while (!trainingInstance.m_trainingComplete)
  32. {
  33. AZStd::this_thread::sleep_for(AZStd::chrono::milliseconds(1));
  34. }
  35. trainingInstance.StopTraining();
  36. return Model;
  37. }
  38. }