Training.cpp

/*
 * Copyright (c) Contributors to the Open 3D Engine Project.
 * For complete copyright and license terms please see the LICENSE at the root of this distribution.
 *
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 */

#include <Algorithms/Training.h>
#include <Algorithms/LossFunctions.h>
#include <AzCore/Math/SimdMath.h>
#include <AzCore/Console/ILogger.h>
#include <AzCore/Jobs/JobCompletion.h>
#include <AzCore/Jobs/JobFunction.h>
#include <numeric>
#include <random>

namespace MachineLearning
{
    SupervisedLearningCycle::SupervisedLearningCycle()
    {
        // Training gets a dedicated job manager with a single worker thread, so training
        // jobs run off the main thread without competing for the engine's default job system
        AZ::JobManagerDesc jobDesc;
        jobDesc.m_jobManagerName = "MachineLearning Training";
        jobDesc.m_workerThreads.push_back(AZ::JobManagerThreadDesc()); // Just one thread
        m_trainingJobManager = AZStd::make_unique<AZ::JobManager>(jobDesc);
        m_trainingjobContext = AZStd::make_unique<AZ::JobContext>(*m_trainingJobManager);
    }

    SupervisedLearningCycle::SupervisedLearningCycle
    (
        INeuralNetworkPtr model,
        ILabeledTrainingDataPtr trainingData,
        ILabeledTrainingDataPtr testData,
        LossFunctions costFunction,
        AZStd::size_t totalIterations,
        AZStd::size_t batchSize,
        float learningRate,
        float learningRateDecay,
        float earlyStopCost
    ) : SupervisedLearningCycle()
    {
        m_model = model;
        m_trainData = trainingData;
        m_testData = testData;
        m_costFunction = costFunction;
        m_totalIterations = totalIterations;
        m_batchSize = batchSize;
        m_learningRate = learningRate;
        m_learningRateDecay = learningRateDecay;
        m_earlyStopCost = earlyStopCost;
    }

    void SupervisedLearningCycle::InitializeContexts()
    {
        // Create the inference and training contexts on first use
        if (m_inferenceContext == nullptr)
        {
            m_inferenceContext.reset(m_model->CreateInferenceContext());
            m_trainingContext.reset(m_model->CreateTrainingContext());
        }
    }

    void SupervisedLearningCycle::StartTraining()
    {
        InitializeContexts();

        // Reset the epoch state, then kick off training on the dedicated job thread
        m_currentEpoch = 0;
        m_trainingComplete = false;
        m_currentIndex = 0;
        if (m_shuffleTrainingData)
        {
            m_trainData.ShuffleSamples();
        }

        auto job = [this]()
        {
            ExecTraining();
        };
        AZ::Job* trainingJob = AZ::CreateJobFunction(job, true, m_trainingjobContext.get()); // isAutoDelete == true, the job frees itself on completion
        trainingJob->Start();
    }

    void SupervisedLearningCycle::StopTraining()
    {
        // Signals the training job to exit at its next loop iteration
        m_trainingComplete = true;
    }

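    // ExecTraining implements mini-batch gradient descent: Reverse() accumulates gradients
    // over up to m_batchSize samples via backpropagation, and GradientDescent() then applies
    // them in a single weight update. Each completed epoch multiplies the learning rate by
    // m_learningRateDecay, so after N epochs the effective step size is the initial
    // m_learningRate scaled by m_learningRateDecay^N. The loop exits once the mean test cost
    // falls below m_earlyStopCost, m_totalIterations epochs have elapsed, or StopTraining()
    // is called.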
    void SupervisedLearningCycle::ExecTraining()
    {
        const AZStd::size_t totalTrainingSize = m_trainData.GetSampleCount();
        while (!m_trainingComplete)
        {
            if (m_currentIndex >= totalTrainingSize)
            {
                // We've consumed all the training samples, so increment the epoch and reset for a new pass over the training data
                m_currentIndex = 0;
                m_learningRate *= m_learningRateDecay;

                // We reshuffle the training data each epoch so the model doesn't fit to the ordering of the samples
                if (m_shuffleTrainingData)
                {
                    AZStd::lock_guard lock(m_mutex);
                    m_trainData.ShuffleSamples();
                }
                ++m_currentEpoch;

                // Keep monitoring the model's performance on both test and training data;
                // a training cost that keeps falling while the test cost stalls or rises indicates overfitting
                float currentTestCost = ComputeCurrentCost(m_testData, m_costFunction);
                float currentTrainCost = ComputeCurrentCost(m_trainData, m_costFunction);
                m_testCosts.PushBackItem(currentTestCost);
                m_trainCosts.PushBackItem(currentTrainCost);

                // Stop early if the test cost is already good enough, or if we've exhausted the epoch budget
                if ((currentTestCost < m_earlyStopCost) || (m_currentEpoch >= m_totalIterations))
                {
                    m_trainingComplete = true;
                    return;
                }
            }

            // Accumulate gradients over one mini-batch via backpropagation
            for (uint32_t batch = 0; (batch < m_batchSize) && (m_currentIndex < totalTrainingSize); ++batch, ++m_currentIndex)
            {
                const AZ::VectorN& activations = m_trainData.GetDataByIndex(m_currentIndex);
                const AZ::VectorN& label = m_trainData.GetLabelByIndex(m_currentIndex);
                m_model->Reverse(m_trainingContext.get(), m_costFunction, activations, label);
            }

            // Apply the accumulated batch gradients under the lock
            AZStd::lock_guard lock(m_mutex);
            m_model->GradientDescent(m_trainingContext.get(), m_learningRate);
        }
    }

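    // Computes the current model's mean loss over every sample in the given data set using
    // forward inference only. ExecTraining calls this twice per epoch (once each for the
    // test and training sets), so its cost grows linearly with the data set sizes.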
    float SupervisedLearningCycle::ComputeCurrentCost(ILabeledTrainingData& testData, LossFunctions costFunction)
    {
        InitializeContexts();
        double result = 0.0;
        const AZStd::size_t totalTestSize = testData.GetSampleCount();
        for (uint32_t iter = 0; iter < totalTestSize; ++iter)
        {
            const AZ::VectorN& activations = testData.GetDataByIndex(iter);
            const AZ::VectorN& label = testData.GetLabelByIndex(iter);
            const AZ::VectorN* output = m_model->Forward(m_inferenceContext.get(), activations);
            result += static_cast<double>(ComputeTotalCost(costFunction, label, *output));
        }

        // Accumulate in double precision, then return the average cost per sample
        result /= static_cast<double>(totalTestSize);
        return static_cast<float>(result);
    }
}
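
For reference, a minimal usage sketch of the class above. MyModel and MyDataSet are hypothetical stand-ins for project-side INeuralNetwork and ILabeledTrainingData implementations, the sketch assumes the Ptr aliases are AZStd::shared_ptr typedefs, and the LossFunctions enumerator name is likewise an assumption (see LossFunctions.h for the real values):

#include <Algorithms/Training.h>

void RunTrainingExample()
{
    using namespace MachineLearning;

    // Hypothetical concrete types; substitute your own model and data set implementations
    INeuralNetworkPtr model = AZStd::make_shared<MyModel>();
    ILabeledTrainingDataPtr trainData = AZStd::make_shared<MyDataSet>("train");
    ILabeledTrainingDataPtr testData = AZStd::make_shared<MyDataSet>("test");

    SupervisedLearningCycle cycle
    (
        model,
        trainData,
        testData,
        LossFunctions::MeanSquaredError, // assumed enumerator
        100,    // totalIterations: maximum number of epochs
        32,     // batchSize: samples per gradient-descent step
        0.01f,  // learningRate
        0.99f,  // learningRateDecay, applied once per epoch
        0.05f   // earlyStopCost: stop once the mean test cost drops below this
    );

    // Runs asynchronously on the cycle's dedicated job thread
    cycle.StartTraining();

    // ... later, e.g. when the owning system shuts down ...
    cycle.StopTraining();
}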