/* * Copyright (c) Contributors to the Open 3D Engine Project. * For complete copyright and license terms please see the LICENSE at the root of this distribution. * * SPDX-License-Identifier: Apache-2.0 OR MIT * */ #include #include #include #include #include #include #include #include namespace MachineLearning { void MachineLearningDebugSystemComponent::Reflect(AZ::ReflectContext* context) { if (AZ::SerializeContext* serializeContext = azrtti_cast(context)) { serializeContext->Class() ->Version(1); } } void MachineLearningDebugSystemComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided) { provided.push_back(AZ_CRC_CE("MachineLearningDebugSystemComponent")); } void MachineLearningDebugSystemComponent::GetRequiredServices([[maybe_unused]] AZ::ComponentDescriptor::DependencyArrayType& required) { ; } void MachineLearningDebugSystemComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatbile) { incompatbile.push_back(AZ_CRC_CE("MachineLearningDebugSystemComponent")); } void MachineLearningDebugSystemComponent::Activate() { #ifdef IMGUI_ENABLED ImGui::ImGuiUpdateListenerBus::Handler::BusConnect(); #endif } void MachineLearningDebugSystemComponent::Deactivate() { #ifdef IMGUI_ENABLED ImGui::ImGuiUpdateListenerBus::Handler::BusDisconnect(); #endif } #ifdef IMGUI_ENABLED void MachineLearningDebugSystemComponent::OnModelRegistryDisplay() { const float TEXT_BASE_WIDTH = ImGui::CalcTextSize("A").x; const ImGuiTableFlags flags = ImGuiTableFlags_BordersV | ImGuiTableFlags_BordersOuterH | ImGuiTableFlags_Resizable | ImGuiTableFlags_RowBg | ImGuiTableFlags_NoBordersInBody; IMachineLearning* machineLearning = MachineLearningInterface::Get(); const ModelSet& modelSet = machineLearning->GetModelSet(); ImGui::Text("Total registered models: %u", static_cast(modelSet.size())); ImGui::NewLine(); if (ImGui::BeginTable("Model Details", 5, flags)) { ImGui::TableSetupColumn("Name", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 32.0f); ImGui::TableSetupColumn("Input Neurons", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f); ImGui::TableSetupColumn("Output Neurons", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f); ImGui::TableSetupColumn("Layers", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f); ImGui::TableSetupColumn("Parameters", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f); ImGui::TableHeadersRow(); AZStd::size_t index = 0; for (auto& neuralNetwork : modelSet) { ImGui::TableNextRow(); ImGui::TableNextColumn(); ImGui::Text(neuralNetwork->GetName().c_str()); ImGui::TableNextColumn(); ImGui::Text("%lld", aznumeric_cast(neuralNetwork->GetInputDimensionality())); ImGui::TableNextColumn(); ImGui::Text("%lld", aznumeric_cast(neuralNetwork->GetOutputDimensionality())); ImGui::TableNextColumn(); ImGui::Text("%lld", aznumeric_cast(neuralNetwork->GetLayerCount())); ImGui::TableNextColumn(); ImGui::Text("%llu", neuralNetwork->GetParameterCount()); ++index; } ImGui::EndTable(); ImGui::NewLine(); } ImGui::End(); } void MachineLearningDebugSystemComponent::OnModelTrainingDisplay() { m_trainingWindow.OnImGuiUpdate(); } void MachineLearningDebugSystemComponent::OnImGuiMainMenuUpdate() { if (ImGui::BeginMenu("MachineLearning")) { ImGui::Checkbox("Model Registry", &m_displayModelRegistry); ImGui::Checkbox("Model Training", &m_displayTrainingWindow); ImGui::EndMenu(); } } void MachineLearningDebugSystemComponent::OnImGuiUpdate() { if (m_displayModelRegistry) { if (ImGui::Begin("Model Registry", &m_displayModelRegistry, ImGuiWindowFlags_None)) { OnModelRegistryDisplay(); } ImGui::End(); } if (m_displayTrainingWindow) { if (ImGui::Begin("Model Training", &m_displayTrainingWindow, ImGuiWindowFlags_None)) { OnModelTrainingDisplay(); } ImGui::End(); } } #endif }