MachineLearningDebugTrainingWindow.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Source/Debug/MachineLearningDebugTrainingWindow.h>
  9. #include <Source/Assets/MnistDataLoader.h>
  10. #include <Source/Algorithms/Activations.h>
  11. #include <ImGuiContextScope.h>
  12. #include <ImGui/ImGuiPass.h>
  13. #include <imgui/imgui.h>
  14. #include <imgui/imgui_internal.h>
  15. namespace MachineLearning
  16. {
  17. #ifdef IMGUI_ENABLED
  18. int32_t AZStdStringResizeCallback(ImGuiInputTextCallbackData* data)
  19. {
  20. if (data->EventFlag == ImGuiInputTextFlags_CallbackResize)
  21. {
  22. AZStd::string* azString = (AZStd::string*)data->UserData;
  23. AZ_Assert(azString->begin() == data->Buf, "Invalid type");
  24. azString->resize(data->BufSize);
  25. data->Buf = azString->begin();
  26. }
  27. return 0;
  28. }
  29. void TextInputHelper(const char* label, AZStd::string& data)
  30. {
  31. ImGui::InputText(label, data.begin(), data.size(), ImGuiInputTextFlags_CallbackResize, AZStdStringResizeCallback, (void*)(&data));
  32. }
  33. MachineLearningDebugTrainingWindow::~MachineLearningDebugTrainingWindow()
  34. {
  35. for (auto iter : m_trainingInstances)
  36. {
  37. delete iter.second;
  38. }
  39. }
  40. TrainingInstance* MachineLearningDebugTrainingWindow::RetrieveTrainingInstance(INeuralNetworkPtr modelPtr)
  41. {
  42. TrainingInstance* trainingInstance = m_trainingInstances[m_selectedModel];
  43. if (trainingInstance == nullptr)
  44. {
  45. m_trainingInstances[m_selectedModel] = new TrainingInstance();
  46. trainingInstance = m_trainingInstances[m_selectedModel];
  47. trainingInstance->m_trainingCycle.m_model = m_selectedModel;
  48. trainingInstance->m_testHistogram.Init("Test Cost", 250, ImGui::LYImGuiUtils::HistogramContainer::ViewType::Histogram, true, 0.0f, 0.2f, ImGui::LYImGuiUtils::HistogramContainer::AutoExpand);
  49. trainingInstance->m_testHistogram.SetMoveDirection(ImGui::LYImGuiUtils::HistogramContainer::PushRightMoveLeft);
  50. trainingInstance->m_trainHistogram.Init("Train Cost", 250, ImGui::LYImGuiUtils::HistogramContainer::ViewType::Histogram, true, 0.0f, 0.2f, ImGui::LYImGuiUtils::HistogramContainer::AutoExpand);
  51. trainingInstance->m_trainHistogram.SetMoveDirection(ImGui::LYImGuiUtils::HistogramContainer::PushRightMoveLeft);
  52. trainingInstance->m_testDataName = m_selectedModel->GetAssetFile(AssetTypes::TestData);
  53. trainingInstance->m_testLabelName = m_selectedModel->GetAssetFile(AssetTypes::TestLabels);
  54. trainingInstance->m_trainDataName = m_selectedModel->GetAssetFile(AssetTypes::TrainingData);
  55. trainingInstance->m_trainLabelName = m_selectedModel->GetAssetFile(AssetTypes::TrainingLabels);
  56. }
  57. return trainingInstance;
  58. }
  59. void MachineLearningDebugTrainingWindow::LoadTestTrainData(TrainingInstance* trainingInstance)
  60. {
  61. if (trainingInstance->m_trainingCycle.m_trainingData == nullptr)
  62. {
  63. trainingInstance->m_trainingCycle.m_trainingData = AZStd::make_shared<MnistDataLoader>();
  64. trainingInstance->m_trainingCycle.m_trainingData->LoadArchive(trainingInstance->m_trainDataName, trainingInstance->m_trainLabelName);
  65. }
  66. if (trainingInstance->m_trainingCycle.m_testData == nullptr)
  67. {
  68. trainingInstance->m_trainingCycle.m_testData = AZStd::make_shared<MnistDataLoader>();
  69. trainingInstance->m_trainingCycle.m_testData->LoadArchive(trainingInstance->m_testDataName, trainingInstance->m_testLabelName);
  70. }
  71. }
  72. void MachineLearningDebugTrainingWindow::RecalculateAccuracy(TrainingInstance* trainingInstance, ILabeledTrainingDataPtr data)
  73. {
  74. trainingInstance->m_trainingCycle.InitializeContexts();
  75. trainingInstance->m_totalSamples = static_cast<int32_t>(data->GetSampleCount());
  76. trainingInstance->m_correctPredictions = 0;
  77. trainingInstance->m_incorrectPredictions = 0;
  78. for (int32_t iter = 0; iter < trainingInstance->m_totalSamples; ++iter)
  79. {
  80. const AZ::VectorN& activations = data->GetDataByIndex(iter);
  81. const AZStd::size_t label = data->GetLabelAsValueByIndex(iter);
  82. const AZ::VectorN* output = m_selectedModel->Forward(trainingInstance->m_trainingCycle.m_inferenceContext.get(), activations);
  83. AZStd::size_t prediction = ArgMaxDecode(*output);
  84. if (label == prediction)
  85. {
  86. ++trainingInstance->m_correctPredictions;
  87. }
  88. else
  89. {
  90. ++trainingInstance->m_incorrectPredictions;
  91. }
  92. }
  93. }
  94. void MachineLearningDebugTrainingWindow::OnImGuiUpdate()
  95. {
  96. const float TEXT_BASE_WIDTH = ImGui::CalcTextSize("A").x;
  97. const ImGuiTableFlags flags = ImGuiTableFlags_BordersV
  98. | ImGuiTableFlags_BordersOuterH
  99. | ImGuiTableFlags_Resizable
  100. | ImGuiTableFlags_RowBg
  101. | ImGuiTableFlags_NoBordersInBody;
  102. const ImGuiTreeNodeFlags nodeFlags = (ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth);
  103. IMachineLearning* machineLearning = MachineLearningInterface::Get();
  104. const ModelSet& modelSet = machineLearning->GetModelSet();
  105. ImGui::PushStyleVar(ImGuiStyleVar_ItemSpacing, ImVec2(0.0f, 0.0f));
  106. ImGui::BeginChild("LeftPanel", ImVec2(m_trainingSplitWidth, -1.0f), true);
  107. if (ImGui::BeginTable("Models", 1, flags))
  108. {
  109. ImGui::TableSetupColumn("Name", ImGuiTableColumnFlags_WidthStretch);
  110. ImGui::TableHeadersRow();
  111. AZStd::size_t index = 0;
  112. for (auto& neuralNetwork : modelSet)
  113. {
  114. ImGui::TableNextRow();
  115. ImGui::TableNextColumn();
  116. const bool isSelected = (m_selectedModelIndex == index);
  117. if (ImGui::Selectable(neuralNetwork->GetName().c_str(), isSelected))
  118. {
  119. m_selectedModel = neuralNetwork;
  120. m_selectedModelIndex = index;
  121. }
  122. ++index;
  123. }
  124. ImGui::EndTable();
  125. ImGui::NewLine();
  126. }
  127. ImGui::EndChild();
  128. ImGui::SameLine();
  129. ImGui::InvisibleButton("vsplitter", ImVec2(8.0f, -1.0f));
  130. if (ImGui::IsItemActive())
  131. {
  132. m_trainingSplitWidth += ImGui::GetIO().MouseDelta.x;
  133. }
  134. ImGui::SameLine();
  135. ImGui::BeginChild("RightPanel", ImVec2(0.0f, -1.0f), true);
  136. if (m_selectedModel != nullptr)
  137. {
  138. ImGui::PushStyleVar(ImGuiStyleVar_ItemSpacing, ImVec2(8.0f, 4.0f));
  139. TrainingInstance* trainingInstance = RetrieveTrainingInstance(m_selectedModel);
  140. float currentTestCost = 0.0f;
  141. float currentTrainCost = 0.0f;
  142. if (trainingInstance->m_trainingCycle.m_testData != nullptr && trainingInstance->m_trainingCycle.m_trainingData != nullptr)
  143. {
  144. currentTestCost = trainingInstance->m_trainingCycle.ComputeCurrentCost(trainingInstance->m_trainingCycle.m_testData, trainingInstance->m_trainingCycle.m_costFunction, m_costSampleSize);
  145. currentTrainCost = trainingInstance->m_trainingCycle.ComputeCurrentCost(trainingInstance->m_trainingCycle.m_trainingData, trainingInstance->m_trainingCycle.m_costFunction, m_costSampleSize);
  146. }
  147. if (!trainingInstance->m_trainingCycle.m_trainingComplete)
  148. {
  149. trainingInstance->m_testHistogram.PushValue(currentTestCost);
  150. trainingInstance->m_trainHistogram.PushValue(currentTrainCost);
  151. if (ImGui::Button("Stop training"))
  152. {
  153. trainingInstance->m_trainingCycle.StopTraining();
  154. }
  155. ImGui::SameLine();
  156. int32_t epoch = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_currentEpoch);
  157. ImGui::Text("Epoch: %d", epoch);
  158. }
  159. else
  160. {
  161. if (ImGui::Button("Start training"))
  162. {
  163. LoadTestTrainData(trainingInstance);
  164. trainingInstance->m_trainingCycle.StartTraining();
  165. }
  166. ImGui::SameLine();
  167. if (ImGui::Button("Save"))
  168. {
  169. m_selectedModel->SaveModel();
  170. }
  171. ImGui::SameLine();
  172. if (ImGui::Button("Load"))
  173. {
  174. m_selectedModel->LoadModel();
  175. }
  176. ImGui::SameLine();
  177. if (ImGui::Button("Recalculate accuracy on test data"))
  178. {
  179. LoadTestTrainData(trainingInstance);
  180. RecalculateAccuracy(trainingInstance, trainingInstance->m_trainingCycle.m_testData);
  181. }
  182. ImGui::SameLine();
  183. if (ImGui::Button("Recalculate accuracy on training data"))
  184. {
  185. LoadTestTrainData(trainingInstance);
  186. RecalculateAccuracy(trainingInstance, trainingInstance->m_trainingCycle.m_trainingData);
  187. }
  188. }
  189. ImGui::Text("Model Name: %s", m_selectedModel->GetName().c_str());
  190. ImGui::NewLine();
  191. ImGui::Text("Asset location: %s", m_selectedModel->GetAssetFile(AssetTypes::Model).c_str());
  192. ImGui::NewLine();
  193. ImGui::Text("Total samples: %d", trainingInstance->m_totalSamples);
  194. ImGui::Text("Correct predictions: %d", trainingInstance->m_correctPredictions);
  195. ImGui::Text("Incorrect predictions: %d", trainingInstance->m_incorrectPredictions);
  196. const float accuracy = (static_cast<float>(trainingInstance->m_correctPredictions) * 100.0f) / static_cast<float>(trainingInstance->m_totalSamples);
  197. ImGui::Text("Accuracy: %f", accuracy);
  198. ImGui::Text("Test score: %f", currentTestCost);
  199. trainingInstance->m_testHistogram.Draw(ImGui::GetColumnWidth(), 200.0f);
  200. ImGui::Text("Train score: %f", currentTrainCost);
  201. trainingInstance->m_trainHistogram.Draw(ImGui::GetColumnWidth(), 200.0f);
  202. ImGui::SliderInt("Cost evaluation sample size", &m_costSampleSize, 10, 10000);
  203. ImGui::NewLine();
  204. TextInputHelper("Test data asset file", trainingInstance->m_testDataName);
  205. TextInputHelper("Test data label file", trainingInstance->m_testLabelName);
  206. ImGui::NewLine();
  207. TextInputHelper("Train data asset file", trainingInstance->m_trainDataName);
  208. TextInputHelper("Train data label file", trainingInstance->m_trainLabelName);
  209. ImGui::NewLine();
  210. //AZStd::fixed_string<64> valueString;
  211. //valueString = AZStd::fixed_string<64>::format("%0.3f", trainingInstance->m_trainingCycle.m_learningRate);
  212. //float logValue = log(trainingInstance->m_trainingCycle.m_learningRate);
  213. //ImGui::SliderFloat("LearningRate", &logValue, log(0.0001f), log(1.0f), valueString.c_str());
  214. //trainingInstance->m_trainingCycle.m_learningRate = exp(logValue);
  215. ImGui::SliderFloat("LearningRate", &trainingInstance->m_trainingCycle.m_learningRate, 0.0f, 0.1f);
  216. ImGui::SliderFloat("LearningRateDecay", &trainingInstance->m_trainingCycle.m_learningRateDecay, 0.0f, 1.0f);
  217. ImGui::NewLine();
  218. int32_t batchSize = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_batchSize);
  219. ImGui::SliderInt("Batch size", &batchSize, 1, 1000);
  220. trainingInstance->m_trainingCycle.m_batchSize = batchSize;
  221. int32_t totalIterations = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_totalIterations);
  222. ImGui::SliderInt("Number of iterations", &totalIterations, 1, 1000);
  223. trainingInstance->m_trainingCycle.m_totalIterations = totalIterations;
  224. int32_t costMetric = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_costFunction);
  225. ImGui::Combo("Cost metric", &costMetric, "MeanSquaredError\0");
  226. trainingInstance->m_trainingCycle.m_costFunction = static_cast<LossFunctions>(costMetric);
  227. ImGui::PopStyleVar();
  228. }
  229. ImGui::EndChild();
  230. ImGui::PopStyleVar();
  231. }
  232. #endif
  233. }