MachineLearningDebugSystemComponent.cpp 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Source/Debug/MachineLearningDebugSystemComponent.h>
  9. #include <AzCore/Component/ComponentApplicationBus.h>
  10. #include <AzCore/Interface/Interface.h>
  11. #include <Atom/Feature/ImGui/SystemBus.h>
  12. #include <ImGuiContextScope.h>
  13. #include <ImGui/ImGuiPass.h>
  14. #include <imgui/imgui.h>
  15. #include <imgui/imgui_internal.h>
  16. namespace MachineLearning
  17. {
  18. void MachineLearningDebugSystemComponent::Reflect(AZ::ReflectContext* context)
  19. {
  20. if (AZ::SerializeContext* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  21. {
  22. serializeContext->Class<MachineLearningDebugSystemComponent, AZ::Component>()
  23. ->Version(1);
  24. }
  25. }
  26. void MachineLearningDebugSystemComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
  27. {
  28. provided.push_back(AZ_CRC_CE("MachineLearningDebugSystemComponent"));
  29. }
  30. void MachineLearningDebugSystemComponent::GetRequiredServices([[maybe_unused]] AZ::ComponentDescriptor::DependencyArrayType& required)
  31. {
  32. ;
  33. }
  34. void MachineLearningDebugSystemComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatbile)
  35. {
  36. incompatbile.push_back(AZ_CRC_CE("MachineLearningDebugSystemComponent"));
  37. }
  38. void MachineLearningDebugSystemComponent::Activate()
  39. {
  40. #ifdef IMGUI_ENABLED
  41. ImGui::ImGuiUpdateListenerBus::Handler::BusConnect();
  42. #endif
  43. }
  44. void MachineLearningDebugSystemComponent::Deactivate()
  45. {
  46. #ifdef IMGUI_ENABLED
  47. ImGui::ImGuiUpdateListenerBus::Handler::BusDisconnect();
  48. #endif
  49. }
  50. #ifdef IMGUI_ENABLED
  51. void MachineLearningDebugSystemComponent::OnModelRegistryDisplay()
  52. {
  53. const float TEXT_BASE_WIDTH = ImGui::CalcTextSize("A").x;
  54. const ImGuiTableFlags flags = ImGuiTableFlags_BordersV
  55. | ImGuiTableFlags_BordersOuterH
  56. | ImGuiTableFlags_Resizable
  57. | ImGuiTableFlags_RowBg
  58. | ImGuiTableFlags_NoBordersInBody;
  59. IMachineLearning* machineLearning = MachineLearningInterface::Get();
  60. const ModelSet& modelSet = machineLearning->GetModelSet();
  61. ImGui::Text("Total registered models: %u", static_cast<uint32_t>(modelSet.size()));
  62. ImGui::NewLine();
  63. if (ImGui::BeginTable("Model Details", 5, flags))
  64. {
  65. ImGui::TableSetupColumn("Name", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 32.0f);
  66. ImGui::TableSetupColumn("Input Neurons", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
  67. ImGui::TableSetupColumn("Output Neurons", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
  68. ImGui::TableSetupColumn("Layers", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
  69. ImGui::TableSetupColumn("Parameters", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
  70. ImGui::TableHeadersRow();
  71. AZStd::size_t index = 0;
  72. for (auto& neuralNetwork : modelSet)
  73. {
  74. ImGui::TableNextRow();
  75. ImGui::TableNextColumn();
  76. ImGui::Text(neuralNetwork->GetName().c_str());
  77. ImGui::TableNextColumn();
  78. ImGui::Text("%lld", aznumeric_cast<AZ::s64>(neuralNetwork->GetInputDimensionality()));
  79. ImGui::TableNextColumn();
  80. ImGui::Text("%lld", aznumeric_cast<AZ::s64>(neuralNetwork->GetOutputDimensionality()));
  81. ImGui::TableNextColumn();
  82. ImGui::Text("%lld", aznumeric_cast<AZ::s64>(neuralNetwork->GetLayerCount()));
  83. ImGui::TableNextColumn();
  84. ImGui::Text("%llu", neuralNetwork->GetParameterCount());
  85. ++index;
  86. }
  87. ImGui::EndTable();
  88. ImGui::NewLine();
  89. }
  90. ImGui::End();
  91. }
  92. void MachineLearningDebugSystemComponent::OnModelTrainingDisplay()
  93. {
  94. m_trainingWindow.OnImGuiUpdate();
  95. }
  96. void MachineLearningDebugSystemComponent::OnImGuiMainMenuUpdate()
  97. {
  98. if (ImGui::BeginMenu("MachineLearning"))
  99. {
  100. ImGui::Checkbox("Model Registry", &m_displayModelRegistry);
  101. ImGui::Checkbox("Model Training", &m_displayTrainingWindow);
  102. ImGui::EndMenu();
  103. }
  104. }
  105. void MachineLearningDebugSystemComponent::OnImGuiUpdate()
  106. {
  107. if (m_displayModelRegistry)
  108. {
  109. if (ImGui::Begin("Model Registry", &m_displayModelRegistry, ImGuiWindowFlags_None))
  110. {
  111. OnModelRegistryDisplay();
  112. }
  113. ImGui::End();
  114. }
  115. if (m_displayTrainingWindow)
  116. {
  117. if (ImGui::Begin("Model Training", &m_displayTrainingWindow, ImGuiWindowFlags_None))
  118. {
  119. OnModelTrainingDisplay();
  120. }
  121. ImGui::End();
  122. }
  123. }
  124. #endif
  125. }