MachineLearningDebugSystemComponent.cpp 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Source/Debug/MachineLearningDebugSystemComponent.h>
  9. #include <AzCore/Component/ComponentApplicationBus.h>
  10. #include <AzCore/Interface/Interface.h>
  11. #include <Atom/Feature/ImGui/SystemBus.h>
  12. #include <ImGuiContextScope.h>
  13. #include <ImGui/ImGuiPass.h>
  14. #include <imgui/imgui.h>
  15. #include <imgui/imgui_internal.h>
  16. namespace MachineLearning
  17. {
  18. void MachineLearningDebugSystemComponent::Reflect(AZ::ReflectContext* context)
  19. {
  20. if (AZ::SerializeContext* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  21. {
  22. serializeContext->Class<MachineLearningDebugSystemComponent, AZ::Component>()
  23. ->Version(1);
  24. }
  25. }
  26. void MachineLearningDebugSystemComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
  27. {
  28. provided.push_back(AZ_CRC_CE("MachineLearningDebugSystemComponent"));
  29. }
  30. void MachineLearningDebugSystemComponent::GetRequiredServices([[maybe_unused]] AZ::ComponentDescriptor::DependencyArrayType& required)
  31. {
  32. ;
  33. }
  34. void MachineLearningDebugSystemComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatbile)
  35. {
  36. incompatbile.push_back(AZ_CRC_CE("MachineLearningDebugSystemComponent"));
  37. }
  38. void MachineLearningDebugSystemComponent::Activate()
  39. {
  40. #ifdef IMGUI_ENABLED
  41. ImGui::ImGuiUpdateListenerBus::Handler::BusConnect();
  42. #endif
  43. }
  44. void MachineLearningDebugSystemComponent::Deactivate()
  45. {
  46. #ifdef IMGUI_ENABLED
  47. ImGui::ImGuiUpdateListenerBus::Handler::BusDisconnect();
  48. #endif
  49. }
  50. #ifdef IMGUI_ENABLED
  51. void MachineLearningDebugSystemComponent::OnModelRegistryDisplay()
  52. {
  53. const float TEXT_BASE_WIDTH = ImGui::CalcTextSize("A").x;
  54. const ImGuiTableFlags flags = ImGuiTableFlags_BordersV
  55. | ImGuiTableFlags_BordersOuterH
  56. | ImGuiTableFlags_Resizable
  57. | ImGuiTableFlags_RowBg
  58. | ImGuiTableFlags_NoBordersInBody;
  59. const ImGuiTreeNodeFlags nodeFlags = (ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth);
  60. IMachineLearning* machineLearning = MachineLearningInterface::Get();
  61. const ModelSet& modelSet = machineLearning->GetModelSet();
  62. ImGui::Text("Total registered models: %u", static_cast<uint32_t>(modelSet.size()));
  63. ImGui::NewLine();
  64. if (ImGui::BeginTable("Model Details", 6, flags))
  65. {
  66. ImGui::TableSetupColumn("Name", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 32.0f);
  67. ImGui::TableSetupColumn("File", ImGuiTableColumnFlags_WidthStretch);
  68. ImGui::TableSetupColumn("Input Neurons", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
  69. ImGui::TableSetupColumn("Output Neurons", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
  70. ImGui::TableSetupColumn("Layers", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
  71. ImGui::TableSetupColumn("Parameters", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
  72. ImGui::TableHeadersRow();
  73. AZStd::size_t index = 0;
  74. for (auto& neuralNetwork : modelSet)
  75. {
  76. ImGui::TableNextRow();
  77. ImGui::TableNextColumn();
  78. ImGui::Text(neuralNetwork->GetName().c_str());
  79. ImGui::TableNextColumn();
  80. ImGui::Text(neuralNetwork->GetAssetFile(AssetTypes::Model).c_str());
  81. ImGui::TableNextColumn();
  82. ImGui::Text("%lld", aznumeric_cast<AZ::s64>(neuralNetwork->GetInputDimensionality()));
  83. ImGui::TableNextColumn();
  84. ImGui::Text("%lld", aznumeric_cast<AZ::s64>(neuralNetwork->GetOutputDimensionality()));
  85. ImGui::TableNextColumn();
  86. ImGui::Text("%lld", aznumeric_cast<AZ::s64>(neuralNetwork->GetLayerCount()));
  87. ImGui::TableNextColumn();
  88. ImGui::Text("%llu", neuralNetwork->GetParameterCount());
  89. ++index;
  90. }
  91. ImGui::EndTable();
  92. ImGui::NewLine();
  93. }
  94. ImGui::End();
  95. }
  96. void MachineLearningDebugSystemComponent::OnModelTrainingDisplay()
  97. {
  98. m_trainingWindow.OnImGuiUpdate();
  99. }
  100. void MachineLearningDebugSystemComponent::OnImGuiMainMenuUpdate()
  101. {
  102. if (ImGui::BeginMenu("MachineLearning"))
  103. {
  104. ImGui::Checkbox("Model Registry", &m_displayModelRegistry);
  105. ImGui::Checkbox("Model Training", &m_displayTrainingWindow);
  106. ImGui::EndMenu();
  107. }
  108. }
  109. void MachineLearningDebugSystemComponent::OnImGuiUpdate()
  110. {
  111. if (m_displayModelRegistry)
  112. {
  113. if (ImGui::Begin("Model Registry", &m_displayModelRegistry, ImGuiWindowFlags_None))
  114. {
  115. OnModelRegistryDisplay();
  116. }
  117. ImGui::End();
  118. }
  119. if (m_displayTrainingWindow)
  120. {
  121. if (ImGui::Begin("Model Training", &m_displayTrainingWindow, ImGuiWindowFlags_None))
  122. {
  123. OnModelTrainingDisplay();
  124. }
  125. ImGui::End();
  126. }
  127. }
  128. #endif
  129. }