MachineLearningDebugTrainingWindow.h 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #pragma once
  9. #include <AzCore/Component/Component.h>
  10. #include <AzCore/Interface/Interface.h>
  11. #include <AzCore/std/containers/map.h>
  12. #include <MachineLearning/IMachineLearning.h>
  13. #include <Algorithms/Training.h>
  14. #ifdef IMGUI_ENABLED
  15. # include <imgui/imgui.h>
  16. # include <ImGuiBus.h>
  17. # include <LYImGuiUtils/HistogramContainer.h>
  18. #endif
  19. namespace MachineLearning
  20. {
  21. struct TrainingInstance
  22. {
  23. SupervisedLearningCycle m_trainingCycle;
  24. AZStd::string m_testDataName;
  25. AZStd::string m_testLabelName;
  26. AZStd::string m_trainDataName;
  27. AZStd::string m_trainLabelName;
  28. int32_t m_totalSamples = 0;
  29. int32_t m_correctPredictions = 0;
  30. int32_t m_incorrectPredictions = 0;
  31. #ifdef IMGUI_ENABLED
  32. ImGui::LYImGuiUtils::HistogramContainer m_testHistogram;
  33. ImGui::LYImGuiUtils::HistogramContainer m_trainHistogram;
  34. #endif
  35. };
  36. class MachineLearningDebugTrainingWindow
  37. {
  38. public:
  39. ~MachineLearningDebugTrainingWindow();
  40. TrainingInstance* RetrieveTrainingInstance(INeuralNetworkPtr modelPtr);
  41. void LoadTestTrainData(TrainingInstance* trainingInstance);
  42. void RecalculateAccuracy(TrainingInstance* trainingInstance, ILabeledTrainingData& data);
  43. #ifdef IMGUI_ENABLED
  44. void OnImGuiUpdate();
  45. #endif
  46. AZStd::size_t m_selectedModelIndex = 0;
  47. INeuralNetworkPtr m_selectedModel = nullptr;
  48. float m_trainingSplitWidth = 400.0f;
  49. AZStd::map<INeuralNetworkPtr, TrainingInstance*> m_trainingInstances;
  50. };
  51. }