2
0

SupervisedLearning.cpp 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Nodes/SupervisedLearning.h>
  9. #include <Models/MultilayerPerceptron.h>
  10. #include <MachineLearning/ILabeledTrainingData.h>
  11. #include <Algorithms/Training.h>
  12. #include <AzCore/Console/ILogger.h>
  13. #include <AzCore/std/chrono/chrono.h>
  14. namespace MachineLearning
  15. {
  16. INeuralNetworkPtr SupervisedLearning::In
  17. (
  18. INeuralNetworkPtr Model,
  19. ILabeledTrainingDataPtr TrainingData,
  20. ILabeledTrainingDataPtr TestData,
  21. //LossFunctions CostFunction,
  22. AZStd::size_t CostFunction,
  23. AZStd::size_t TotalIterations,
  24. AZStd::size_t BatchSize,
  25. float LearningRate,
  26. float LearningRateDecay,
  27. float EarlyStopCost
  28. )
  29. {
  30. SupervisedLearningCycle trainingInstance(Model, TrainingData, TestData, static_cast<LossFunctions>(CostFunction), TotalIterations, BatchSize, LearningRate, LearningRateDecay, EarlyStopCost);
  31. trainingInstance.StartTraining();
  32. while (!trainingInstance.m_trainingComplete)
  33. {
  34. AZStd::this_thread::sleep_for(AZStd::chrono::milliseconds(1));
  35. }
  36. trainingInstance.StopTraining();
  37. return Model;
  38. }
  39. }