Training.cpp

/*
 * Copyright (c) Contributors to the Open 3D Engine Project.
 * For complete copyright and license terms please see the LICENSE at the root of this distribution.
 *
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 */

#include <Algorithms/Training.h>
#include <Algorithms/LossFunctions.h>
#include <AzCore/Math/SimdMath.h>
#include <AzCore/Console/ILogger.h>
#include <AzCore/Jobs/JobCompletion.h>
#include <AzCore/Jobs/JobFunction.h>
#include <algorithm> // for std::shuffle
#include <numeric>
#include <random>

namespace MachineLearning
{
    SupervisedLearningCycle::SupervisedLearningCycle()
    {
        AZ::JobManagerDesc jobDesc;
        jobDesc.m_jobManagerName = "MachineLearning Training";
        jobDesc.m_workerThreads.push_back(AZ::JobManagerThreadDesc()); // Just one thread
        m_trainingJobManager = AZStd::make_unique<AZ::JobManager>(jobDesc);
        m_trainingjobContext = AZStd::make_unique<AZ::JobContext>(*m_trainingJobManager);
    }

    SupervisedLearningCycle::SupervisedLearningCycle
    (
        INeuralNetworkPtr model,
        ILabeledTrainingDataPtr trainingData,
        ILabeledTrainingDataPtr testData,
        LossFunctions costFunction,
        AZStd::size_t totalIterations,
        AZStd::size_t batchSize,
        float learningRate,
        float learningRateDecay,
        float earlyStopCost
    ) : SupervisedLearningCycle()
    {
        m_model = model;
        m_trainingData = trainingData;
        m_testData = testData;
        m_costFunction = costFunction;
        m_totalIterations = totalIterations;
        m_batchSize = batchSize;
        m_learningRate = learningRate;
        m_learningRateDecay = learningRateDecay;
        m_earlyStopCost = earlyStopCost;
    }
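
    // Example usage (a sketch, not part of this file): assuming a concrete INeuralNetwork
    // implementation and ILabeledTrainingData sets loaded elsewhere in the Gem, a caller
    // might drive a training cycle like this:
    //
    //     SupervisedLearningCycle cycle(model, trainData, testData,
    //         LossFunctions::MeanSquaredError, /*totalIterations*/ 100, /*batchSize*/ 32,
    //         /*learningRate*/ 0.05f, /*learningRateDecay*/ 0.99f, /*earlyStopCost*/ 0.01f);
    //     cycle.StartTraining(); // returns immediately; training runs on the job thread
    //
    // LossFunctions::MeanSquaredError and the parameter values are assumed here purely
    // for illustration; see Algorithms/LossFunctions.h for the actual enum values.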
    void SupervisedLearningCycle::InitializeContexts()
    {
        if (m_inferenceContext == nullptr)
        {
            m_inferenceContext.reset(m_model->CreateInferenceContext());
            m_trainingContext.reset(m_model->CreateTrainingContext());
        }
    }
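
    // Builds a shuffled index buffer over the training set, resets the epoch and index
    // counters, and kicks off ExecTraining asynchronously on the dedicated training job
    // thread created in the constructor.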
    void SupervisedLearningCycle::StartTraining()
    {
        InitializeContexts();

        const AZStd::size_t totalTrainingSize = m_trainingData->GetSampleCount();

        // Generate a set of training indices that we can later shuffle
        m_indices.resize(totalTrainingSize);
        std::iota(m_indices.begin(), m_indices.end(), 0);
        std::shuffle(m_indices.begin(), m_indices.end(), std::mt19937(std::random_device{}()));

        // Start training
        m_currentEpoch = 0;
        m_trainingComplete = false;
        m_currentIndex = 0;

        auto job = [this]()
        {
            ExecTraining();
        };
        AZ::Job* trainingJob = AZ::CreateJobFunction(job, true, m_trainingjobContext.get());
        trainingJob->Start();
    }
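
    // The core mini-batch stochastic gradient descent loop. Each pass accumulates
    // gradients for up to m_batchSize samples via backpropagation (Reverse), then
    // applies a single GradientDescent step under the mutex. At each epoch boundary
    // the learning rate is decayed, the indices are reshuffled, and the test-set cost
    // is checked against the early-stop threshold and the iteration cap.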
    void SupervisedLearningCycle::ExecTraining()
    {
        const AZStd::size_t totalTrainingSize = m_trainingData->GetSampleCount();
        while (!m_trainingComplete)
        {
            if (m_currentIndex >= totalTrainingSize)
            {
                // If we run out of training samples, we increment our epoch and reset for a new pass over the training data
                m_currentIndex = 0;
                m_learningRate *= m_learningRateDecay;

                // We reshuffle the training data indices each epoch to avoid patterns in the training data
                std::shuffle(m_indices.begin(), m_indices.end(), std::mt19937(std::random_device{}()));
                ++m_currentEpoch;

                // Generally we want to keep monitoring the model's performance on both test and training data
                // This allows us to detect if we're overfitting the model to the training data
                float currentTestCost = ComputeCurrentCost(m_testData, m_costFunction);
                if ((currentTestCost < m_earlyStopCost) || (m_currentEpoch >= m_totalIterations))
                {
                    m_trainingComplete = true;
                    return;
                }
            }

            for (uint32_t batch = 0; (batch < m_batchSize) && (m_currentIndex < totalTrainingSize); ++batch, ++m_currentIndex)
            {
                const AZ::VectorN& activations = m_trainingData->GetDataByIndex(m_indices[m_currentIndex]);
                const AZ::VectorN& label = m_trainingData->GetLabelByIndex(m_indices[m_currentIndex]);
                m_model->Reverse(m_trainingContext.get(), m_costFunction, activations, label);
            }

            AZStd::lock_guard lock(m_mutex);
            m_model->GradientDescent(m_trainingContext.get(), m_learningRate);
        }
    }
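
    // Requests that training stop. The flag is polled at the top of each pass in
    // ExecTraining, so the in-flight batch and gradient step finish before the job exits.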
    void SupervisedLearningCycle::StopTraining()
    {
        m_trainingComplete = true;
    }
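
    // Runs forward inference over up to maxSamples samples of the provided data set and
    // returns the mean cost under the given loss function. Passing maxSamples == 0
    // evaluates the entire set. Holds the mutex so that evaluation does not race the
    // gradient step in ExecTraining.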
    float SupervisedLearningCycle::ComputeCurrentCost(ILabeledTrainingDataPtr TestData, LossFunctions CostFunction, AZStd::size_t maxSamples)
    {
        InitializeContexts();

        const AZStd::size_t totalTestSize = TestData->GetSampleCount();
        maxSamples = (maxSamples == 0) ? totalTestSize : AZStd::min(maxSamples, totalTestSize);

        AZStd::lock_guard lock(m_mutex);
        double result = 0.0;
        for (uint32_t iter = 0; iter < maxSamples; ++iter)
        {
            const AZ::VectorN& activations = TestData->GetDataByIndex(iter);
            const AZ::VectorN& label = TestData->GetLabelByIndex(iter);
            const AZ::VectorN* output = m_model->Forward(m_inferenceContext.get(), activations);
            result += static_cast<double>(ComputeTotalCost(CostFunction, label, *output));
        }
        result /= static_cast<double>(maxSamples);
        return static_cast<float>(result);
    }
}