123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- /*
- * Copyright (c) Contributors to the Open 3D Engine Project.
- * For complete copyright and license terms please see the LICENSE at the root of this distribution.
- *
- * SPDX-License-Identifier: Apache-2.0 OR MIT
- *
- */
- #include "ONNXSystemComponent.h"
- #include <ONNX/Model.h>
- #include <AzCore/Serialization/EditContext.h>
- #include <AzCore/Serialization/EditContextConstants.inl>
- #include <AzCore/Serialization/SerializeContext.h>
- namespace ONNX
- {
- void ONNXSystemComponent::AddTimingSample(const char* modelName, float inferenceTimeInMilliseconds, AZ::Color modelColor)
- {
- m_timingStats.PushHistogramValue(modelName, inferenceTimeInMilliseconds, modelColor);
- }
- void ONNXSystemComponent::OnImGuiUpdate()
- {
- if (!m_timingStats.m_show)
- {
- return;
- }
- if (ImGui::Begin("ONNX"))
- {
- m_timingStats.OnImGuiUpdate();
- }
- }
- void ONNXSystemComponent::OnImGuiMainMenuUpdate()
- {
- if (ImGui::BeginMenu("ONNX"))
- {
- ImGui::MenuItem(m_timingStats.GetName(), "", &m_timingStats.m_show);
- ImGui::EndMenu();
- }
- }
- void ONNXSystemComponent::Reflect(AZ::ReflectContext* context)
- {
- if (AZ::SerializeContext* serialize = azrtti_cast<AZ::SerializeContext*>(context))
- {
- serialize->Class<ONNXSystemComponent, AZ::Component>()->Version(0);
- if (AZ::EditContext* ec = serialize->GetEditContext())
- {
- ec->Class<ONNXSystemComponent>("ONNX", "Provides ONNX Runtime functionality in O3DE")
- ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
- ->Attribute(AZ::Edit::Attributes::AppearsInAddComponentMenu, AZ_CRC("System"))
- ->Attribute(AZ::Edit::Attributes::AutoExpand, true);
- }
- }
- }
- void ONNXSystemComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
- {
- provided.push_back(AZ_CRC_CE("ONNXService"));
- }
- void ONNXSystemComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatible)
- {
- incompatible.push_back(AZ_CRC_CE("ONNXService"));
- }
- void ONNXSystemComponent::GetRequiredServices([[maybe_unused]] AZ::ComponentDescriptor::DependencyArrayType& required)
- {
- }
- void ONNXSystemComponent::GetDependentServices([[maybe_unused]] AZ::ComponentDescriptor::DependencyArrayType& dependent)
- {
- }
- ONNXSystemComponent::ONNXSystemComponent()
- {
- if (ONNXInterface::Get() == nullptr)
- {
- ONNXInterface::Register(this);
- }
- m_timingStats.SetName("ONNX Inference Timing Statistics");
- m_timingStats.SetHistogramBinCount(200);
- ImGui::ImGuiUpdateListenerBus::Handler::BusConnect();
- }
- ONNXSystemComponent::~ONNXSystemComponent()
- {
- ImGui::ImGuiUpdateListenerBus::Handler::BusDisconnect();
- if (ONNXInterface::Get() == this)
- {
- ONNXInterface::Unregister(this);
- }
- }
- Ort::Env* ONNXSystemComponent::GetEnv()
- {
- return m_env.get();
- }
- Ort::AllocatorWithDefaultOptions* ONNXSystemComponent::GetAllocator()
- {
- return m_allocator.get();
- }
- void OnnxLoggingFunction(void*, OrtLoggingLevel, const char* category, const char* logId, const char* codeLocation, const char* message)
- {
- AZ_Printf("ONNX", "%s %s %s %s\n", category, logId, codeLocation, message);
- }
- // The global environment and memory allocator are initialised with the system component, and are accessed via the EBus from within the
- // model. m_precomputedTimingData and m_precomputedTimingDataCuda are structs holding the test inference statistics run before the
- // editor starts up, and used by the ImGui dashboard.
- void ONNXSystemComponent::Init()
- {
- m_env = AZStd::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_VERBOSE, "test_log", OnnxLoggingFunction, nullptr);
- m_allocator = AZStd::make_unique<Ort::AllocatorWithDefaultOptions>();
- }
- void ONNXSystemComponent::Activate()
- {
- ONNXRequestBus::Handler::BusConnect();
- }
- void ONNXSystemComponent::Deactivate()
- {
- AZ::TickBus::Handler::BusDisconnect();
- }
- } // namespace ONNX
|