MachineLearningDebugTrainingWindow.cpp 15 KB


  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Source/Debug/MachineLearningDebugTrainingWindow.h>
  9. #include <Source/Assets/MnistDataLoader.h>
  10. #include <Source/Algorithms/Activations.h>
  11. #include <AzCore/Console/IConsole.h>
  12. #include <AzCore/Console/ILogger.h>
  13. #ifdef IMGUI_ENABLED
  14. # include <ImGuiContextScope.h>
  15. # include <ImGui/ImGuiPass.h>
  16. # include <imgui/imgui.h>
  17. # include <imgui/imgui_internal.h>
  18. #endif
  19. namespace MachineLearning
  20. {
  // Console variable, off by default. When enabled, RecalculateAccuracy logs the label vector ("Actual")
  // and the model output vector ("Output") for every sample it evaluates.
  AZ_CVAR(bool, ml_logAccuracyValues, false, nullptr, AZ::ConsoleFunctorFlags::Null, "Dumps the actual and expected labels during accuracy calculations");
  22. #ifdef IMGUI_ENABLED
  23. int32_t AZStdStringResizeCallback(ImGuiInputTextCallbackData* data)
  24. {
  25. if (data->EventFlag == ImGuiInputTextFlags_CallbackResize)
  26. {
  27. AZStd::string* azString = (AZStd::string*)data->UserData;
  28. AZ_Assert(azString->begin() == data->Buf, "Invalid type");
  29. azString->resize(data->BufSize);
  30. data->Buf = azString->begin();
  31. }
  32. return 0;
  33. }
  34. void TextInputHelper(const char* label, AZStd::string& data)
  35. {
  36. ImGui::InputText(label, data.begin(), data.size(), ImGuiInputTextFlags_CallbackResize, AZStdStringResizeCallback, (void*)(&data));
  37. }
  38. MachineLearningDebugTrainingWindow::~MachineLearningDebugTrainingWindow()
  39. {
  40. for (auto iter : m_trainingInstances)
  41. {
  42. delete iter.second;
  43. }
  44. }
  45. TrainingInstance* MachineLearningDebugTrainingWindow::RetrieveTrainingInstance(INeuralNetworkPtr modelPtr)
  46. {
  47. TrainingInstance* trainingInstance = m_trainingInstances[m_selectedModel];
  48. if (trainingInstance == nullptr)
  49. {
  50. m_trainingInstances[m_selectedModel] = new TrainingInstance();
  51. trainingInstance = m_trainingInstances[m_selectedModel];
  52. trainingInstance->m_trainingCycle.m_model = m_selectedModel;
  53. trainingInstance->m_testHistogram.Init("Test Cost", 250, ImGui::LYImGuiUtils::HistogramContainer::ViewType::Histogram, true, 0.0f, 0.2f, ImGui::LYImGuiUtils::HistogramContainer::AutoExpand);
  54. trainingInstance->m_testHistogram.SetMoveDirection(ImGui::LYImGuiUtils::HistogramContainer::PushRightMoveLeft);
  55. trainingInstance->m_trainHistogram.Init("Train Cost", 250, ImGui::LYImGuiUtils::HistogramContainer::ViewType::Histogram, true, 0.0f, 0.2f, ImGui::LYImGuiUtils::HistogramContainer::AutoExpand);
  56. trainingInstance->m_trainHistogram.SetMoveDirection(ImGui::LYImGuiUtils::HistogramContainer::PushRightMoveLeft);
  57. trainingInstance->m_testDataName = m_selectedModel->GetAssetFile(AssetTypes::TestData);
  58. trainingInstance->m_testLabelName = m_selectedModel->GetAssetFile(AssetTypes::TestLabels);
  59. trainingInstance->m_trainDataName = m_selectedModel->GetAssetFile(AssetTypes::TrainingData);
  60. trainingInstance->m_trainLabelName = m_selectedModel->GetAssetFile(AssetTypes::TrainingLabels);
  61. }
  62. return trainingInstance;
  63. }
  64. void MachineLearningDebugTrainingWindow::LoadTestTrainData(TrainingInstance* trainingInstance)
  65. {
  66. if (!trainingInstance->m_trainingCycle.m_trainData.IsValid())
  67. {
  68. ILabeledTrainingDataPtr dataPtr = AZStd::make_shared<MnistDataLoader>();
  69. trainingInstance->m_trainingCycle.m_trainData.SetSourceData(dataPtr);
  70. trainingInstance->m_trainingCycle.m_trainData.LoadArchive(trainingInstance->m_trainDataName, trainingInstance->m_trainLabelName);
  71. }
  72. if (!trainingInstance->m_trainingCycle.m_testData.IsValid())
  73. {
  74. ILabeledTrainingDataPtr dataPtr = AZStd::make_shared<MnistDataLoader>();
  75. trainingInstance->m_trainingCycle.m_testData.SetSourceData(dataPtr);
  76. trainingInstance->m_trainingCycle.m_testData.LoadArchive(trainingInstance->m_testDataName, trainingInstance->m_testLabelName);
  77. }
  78. }
  79. void LogVectorN(const AZ::VectorN& value, const char* label)
  80. {
  81. AZStd::string vectorString(label);
  82. for (AZStd::size_t iter = 0; iter < value.GetDimensionality(); ++iter)
  83. {
  84. vectorString += AZStd::string::format(" %.02f", value.GetElement(iter));
  85. }
  86. AZLOG_INFO(vectorString.c_str());
  87. }
  88. void MachineLearningDebugTrainingWindow::RecalculateAccuracy(TrainingInstance* trainingInstance, ILabeledTrainingData& data)
  89. {
  90. trainingInstance->m_trainingCycle.InitializeContexts();
  91. trainingInstance->m_totalSamples = static_cast<int32_t>(data.GetSampleCount());
  92. trainingInstance->m_correctPredictions = 0;
  93. trainingInstance->m_incorrectPredictions = 0;
  94. for (int32_t iter = 0; iter < trainingInstance->m_totalSamples; ++iter)
  95. {
  96. const AZ::VectorN& activations = data.GetDataByIndex(iter);
  97. const AZ::VectorN& label = data.GetLabelByIndex(iter);
  98. const AZ::VectorN* output = m_selectedModel->Forward(trainingInstance->m_trainingCycle.m_inferenceContext.get(), activations);
  99. AZStd::size_t prediction = ArgMaxDecode(*output);
  100. AZStd::size_t actual = ArgMaxDecode(label);
  101. if (prediction == actual)
  102. {
  103. ++trainingInstance->m_correctPredictions;
  104. }
  105. else
  106. {
  107. ++trainingInstance->m_incorrectPredictions;
  108. }
  109. if (ml_logAccuracyValues)
  110. {
  111. LogVectorN(label, "Actual");
  112. LogVectorN(*output, "Output");
  113. }
  114. }
  115. }
  116. void DrawDataPanel(TrainingDataView& data, AZStd::string& dataName, AZStd::string& labelName)
  117. {
  118. ImGui::PushID(&data);
  119. int32_t firstElement = static_cast<int32_t>(data.m_first);
  120. int32_t span = static_cast<int32_t>(data.m_last) - firstElement;
  121. TextInputHelper("Asset file", dataName);
  122. TextInputHelper("Label file", labelName);
  123. ImGui::SliderInt("First", &firstElement, 0, static_cast<int32_t>(data.GetOriginalSize()));
  124. ImGui::SameLine();
  125. ImGui::SliderInt("Count", &span, 0, static_cast<int32_t>(data.GetOriginalSize() - firstElement));
  126. data.m_first = firstElement;
  127. data.m_last = firstElement + span;
  128. ImGui::PopID();
  129. }
  130. void MachineLearningDebugTrainingWindow::OnImGuiUpdate()
  131. {
  132. const float TEXT_BASE_WIDTH = ImGui::CalcTextSize("A").x;
  133. const ImGuiTableFlags flags = ImGuiTableFlags_BordersV
  134. | ImGuiTableFlags_BordersOuterH
  135. | ImGuiTableFlags_Resizable
  136. | ImGuiTableFlags_RowBg
  137. | ImGuiTableFlags_NoBordersInBody;
  138. const ImGuiTreeNodeFlags nodeFlags = (ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth);
  139. IMachineLearning* machineLearning = MachineLearningInterface::Get();
  140. const ModelSet& modelSet = machineLearning->GetModelSet();
  141. ImGui::PushStyleVar(ImGuiStyleVar_ItemSpacing, ImVec2(0.0f, 0.0f));
  142. ImGui::BeginChild("LeftPanel", ImVec2(m_trainingSplitWidth, -1.0f), true);
  143. if (ImGui::BeginTable("Models", 1, flags))
  144. {
  145. ImGui::TableSetupColumn("Name", ImGuiTableColumnFlags_WidthStretch);
  146. ImGui::TableHeadersRow();
  147. AZStd::size_t index = 0;
  148. for (auto& neuralNetwork : modelSet)
  149. {
  150. ImGui::TableNextRow();
  151. ImGui::TableNextColumn();
  152. const bool isSelected = (m_selectedModelIndex == index);
  153. if (ImGui::Selectable(neuralNetwork->GetName().c_str(), isSelected))
  154. {
  155. m_selectedModel = neuralNetwork;
  156. m_selectedModelIndex = index;
  157. }
  158. ++index;
  159. }
  160. ImGui::EndTable();
  161. ImGui::NewLine();
  162. }
  163. ImGui::EndChild();
  164. ImGui::SameLine();
  165. ImGui::InvisibleButton("vsplitter", ImVec2(8.0f, -1.0f));
  166. if (ImGui::IsItemActive())
  167. {
  168. m_trainingSplitWidth += ImGui::GetIO().MouseDelta.x;
  169. }
  170. ImGui::SameLine();
  171. ImGui::BeginChild("RightPanel", ImVec2(0.0f, -1.0f), true);
  172. if (m_selectedModel != nullptr)
  173. {
  174. ImGui::PushStyleVar(ImGuiStyleVar_ItemSpacing, ImVec2(8.0f, 4.0f));
  175. TrainingInstance* trainingInstance = RetrieveTrainingInstance(m_selectedModel);
  176. {
  177. AZStd::deque<float> costs;
  178. trainingInstance->m_trainingCycle.m_testCosts.Swap(costs);
  179. for (float item : costs)
  180. {
  181. trainingInstance->m_testHistogram.PushValue(item);
  182. }
  183. }
  184. {
  185. AZStd::deque<float> costs;
  186. trainingInstance->m_trainingCycle.m_trainCosts.Swap(costs);
  187. for (float item : costs)
  188. {
  189. trainingInstance->m_trainHistogram.PushValue(item);
  190. }
  191. }
  192. int32_t batchSize = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_batchSize);
  193. int32_t totalIterations = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_totalIterations);
  194. if (!trainingInstance->m_trainingCycle.m_trainingComplete)
  195. {
  196. if (ImGui::Button("Stop training"))
  197. {
  198. trainingInstance->m_trainingCycle.StopTraining();
  199. }
  200. ImGui::SameLine();
  201. int32_t epoch = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_currentEpoch);
  202. ImGui::Text("Epoch: %d", epoch);
  203. }
  204. else
  205. {
  206. if (ImGui::Button("Start training"))
  207. {
  208. LoadTestTrainData(trainingInstance);
  209. trainingInstance->m_testHistogram.Init("Test Cost", totalIterations, ImGui::LYImGuiUtils::HistogramContainer::ViewType::Histogram, true, 0.0f, 0.2f, ImGui::LYImGuiUtils::HistogramContainer::AutoExpand);
  210. trainingInstance->m_trainHistogram.Init("Train Cost", totalIterations, ImGui::LYImGuiUtils::HistogramContainer::ViewType::Histogram, true, 0.0f, 0.2f, ImGui::LYImGuiUtils::HistogramContainer::AutoExpand);
  211. trainingInstance->m_trainingCycle.StartTraining();
  212. }
  213. ImGui::SameLine();
  214. if (ImGui::Button("Save"))
  215. {
  216. m_selectedModel->SaveModel();
  217. }
  218. ImGui::SameLine();
  219. if (ImGui::Button("Load"))
  220. {
  221. m_selectedModel->LoadModel();
  222. }
  223. ImGui::SameLine();
  224. if (ImGui::Button("Recalculate accuracy on test data"))
  225. {
  226. LoadTestTrainData(trainingInstance);
  227. RecalculateAccuracy(trainingInstance, trainingInstance->m_trainingCycle.m_testData);
  228. }
  229. ImGui::SameLine();
  230. if (ImGui::Button("Recalculate accuracy on training data"))
  231. {
  232. LoadTestTrainData(trainingInstance);
  233. RecalculateAccuracy(trainingInstance, trainingInstance->m_trainingCycle.m_trainData);
  234. }
  235. }
  236. ImGui::NewLine();
  237. ImGui::Text("Model Name: %s", m_selectedModel->GetName().c_str());
  238. ImGui::Text("Asset location: %s", m_selectedModel->GetAssetFile(AssetTypes::Model).c_str());
  239. if (ImGui::BeginTable("Accuracy", 2, flags))
  240. {
  241. ImGui::TableSetupColumn("Name", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 32.0f);
  242. ImGui::TableSetupColumn("Value", ImGuiTableColumnFlags_WidthStretch);
  243. ImGui::TableHeadersRow();
  244. ImGui::TableNextRow();
  245. ImGui::TableNextColumn();
  246. ImGui::Text("Total samples");
  247. ImGui::TableNextColumn();
  248. ImGui::Text("%d", trainingInstance->m_totalSamples);
  249. ImGui::TableNextRow();
  250. ImGui::TableNextColumn();
  251. ImGui::Text("Correct predictions");
  252. ImGui::TableNextColumn();
  253. ImGui::Text("%d", trainingInstance->m_correctPredictions);
  254. ImGui::TableNextRow();
  255. ImGui::TableNextColumn();
  256. ImGui::Text("Incorrect predictions");
  257. ImGui::TableNextColumn();
  258. ImGui::Text("%d", trainingInstance->m_incorrectPredictions);
  259. ImGui::TableNextRow();
  260. ImGui::TableNextColumn();
  261. ImGui::Text("Accuracy");
  262. ImGui::TableNextColumn();
  263. const float accuracy = (static_cast<float>(trainingInstance->m_correctPredictions) * 100.0f) / static_cast<float>(trainingInstance->m_totalSamples);
  264. ImGui::Text("%f", accuracy);
  265. ImGui::TableNextRow();
  266. ImGui::TableNextColumn();
  267. ImGui::Text("Test score");
  268. ImGui::TableNextColumn();
  269. ImGui::Text("%f", trainingInstance->m_testHistogram.GetLastValue());
  270. ImGui::TableNextRow();
  271. ImGui::TableNextColumn();
  272. ImGui::Text("Train score");
  273. ImGui::TableNextColumn();
  274. ImGui::Text("%f", trainingInstance->m_trainHistogram.GetLastValue());
  275. ImGui::EndTable();
  276. ImGui::NewLine();
  277. }
  278. trainingInstance->m_testHistogram.Draw(ImGui::GetColumnWidth(), 200.0f);
  279. trainingInstance->m_trainHistogram.Draw(ImGui::GetColumnWidth(), 200.0f);
  280. ImGui::NewLine();
  281. ImGui::SliderInt("Batch size", &batchSize, 1, 1000);
  282. trainingInstance->m_trainingCycle.m_batchSize = batchSize;
  283. ImGui::SliderInt("Number of iterations", &totalIterations, 1, 1000);
  284. trainingInstance->m_trainingCycle.m_totalIterations = totalIterations;
  285. int32_t costMetric = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_costFunction);
  286. ImGui::Combo("Cost metric", &costMetric, "MeanSquaredError\0");
  287. trainingInstance->m_trainingCycle.m_costFunction = static_cast<LossFunctions>(costMetric);
  288. ImGui::NewLine();
  289. ImGui::SliderFloat("LearningRate", &trainingInstance->m_trainingCycle.m_learningRate, 0.0f, 0.1f);
  290. ImGui::SliderFloat("LearningRateDecay", &trainingInstance->m_trainingCycle.m_learningRateDecay, 0.0f, 1.0f);
  291. ImGui::SliderFloat("EarlyStop", &trainingInstance->m_trainingCycle.m_earlyStopCost, 0.0f, 1.0f);
  292. ImGui::NewLine();
  293. ImGui::PushItemWidth(ImGui::GetWindowWidth() * 0.4f);
  294. ImGui::Checkbox("Shuffle data", &trainingInstance->m_trainingCycle.m_shuffleTrainingData);
  295. if (ImGui::CollapsingHeader("Test data", ImGuiTreeNodeFlags_Framed))
  296. {
  297. DrawDataPanel(trainingInstance->m_trainingCycle.m_testData, trainingInstance->m_testDataName, trainingInstance->m_testLabelName);
  298. }
  299. ImGui::NewLine();
  300. if (ImGui::CollapsingHeader("Train data", ImGuiTreeNodeFlags_Framed))
  301. {
  302. DrawDataPanel(trainingInstance->m_trainingCycle.m_trainData, trainingInstance->m_trainDataName, trainingInstance->m_trainLabelName);
  303. }
  304. ImGui::PopItemWidth();
  305. ImGui::PopStyleVar();
  306. }
  307. ImGui::EndChild();
  308. ImGui::PopStyleVar();
  309. }
  310. #endif
  311. }