ONNXSystemComponent.cpp 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include "ONNXSystemComponent.h"
  9. #include <ONNX/Model.h>
  10. #include <AzCore/Serialization/EditContext.h>
  11. #include <AzCore/Serialization/EditContextConstants.inl>
  12. #include <AzCore/Serialization/SerializeContext.h>
  13. namespace ONNX
  14. {
  15. void ONNXSystemComponent::AddTimingSample(const char* modelName, float inferenceTimeInMilliseconds, AZ::Color modelColor)
  16. {
  17. m_timingStats.PushHistogramValue(modelName, inferenceTimeInMilliseconds, modelColor);
  18. }
  19. void ONNXSystemComponent::OnImGuiUpdate()
  20. {
  21. if (!m_timingStats.m_show)
  22. {
  23. return;
  24. }
  25. if (ImGui::Begin("ONNX"))
  26. {
  27. m_timingStats.OnImGuiUpdate();
  28. }
  29. }
  30. void ONNXSystemComponent::OnImGuiMainMenuUpdate()
  31. {
  32. if (ImGui::BeginMenu("ONNX"))
  33. {
  34. ImGui::MenuItem(m_timingStats.GetName(), "", &m_timingStats.m_show);
  35. ImGui::EndMenu();
  36. }
  37. }
  38. void ONNXSystemComponent::Reflect(AZ::ReflectContext* context)
  39. {
  40. if (AZ::SerializeContext* serialize = azrtti_cast<AZ::SerializeContext*>(context))
  41. {
  42. serialize->Class<ONNXSystemComponent, AZ::Component>()->Version(0);
  43. if (AZ::EditContext* ec = serialize->GetEditContext())
  44. {
  45. ec->Class<ONNXSystemComponent>("ONNX", "Provides ONNX Runtime functionality in O3DE")
  46. ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
  47. ->Attribute(AZ::Edit::Attributes::AppearsInAddComponentMenu, AZ_CRC("System"))
  48. ->Attribute(AZ::Edit::Attributes::AutoExpand, true);
  49. }
  50. }
  51. }
  52. void ONNXSystemComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
  53. {
  54. provided.push_back(AZ_CRC_CE("ONNXService"));
  55. }
  56. void ONNXSystemComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatible)
  57. {
  58. incompatible.push_back(AZ_CRC_CE("ONNXService"));
  59. }
  60. void ONNXSystemComponent::GetRequiredServices([[maybe_unused]] AZ::ComponentDescriptor::DependencyArrayType& required)
  61. {
  62. }
  63. void ONNXSystemComponent::GetDependentServices([[maybe_unused]] AZ::ComponentDescriptor::DependencyArrayType& dependent)
  64. {
  65. }
  66. ONNXSystemComponent::ONNXSystemComponent()
  67. {
  68. if (ONNXInterface::Get() == nullptr)
  69. {
  70. ONNXInterface::Register(this);
  71. }
  72. m_timingStats.SetName("ONNX Inference Timing Statistics");
  73. m_timingStats.SetHistogramBinCount(200);
  74. ImGui::ImGuiUpdateListenerBus::Handler::BusConnect();
  75. }
  76. ONNXSystemComponent::~ONNXSystemComponent()
  77. {
  78. ImGui::ImGuiUpdateListenerBus::Handler::BusDisconnect();
  79. if (ONNXInterface::Get() == this)
  80. {
  81. ONNXInterface::Unregister(this);
  82. }
  83. }
  84. Ort::Env* ONNXSystemComponent::GetEnv()
  85. {
  86. return m_env.get();
  87. }
  88. Ort::AllocatorWithDefaultOptions* ONNXSystemComponent::GetAllocator()
  89. {
  90. return m_allocator.get();
  91. }
  92. void OnnxLoggingFunction(void*, OrtLoggingLevel, const char* category, const char* logId, const char* codeLocation, const char* message)
  93. {
  94. AZ_Printf("ONNX", "%s %s %s %s\n", category, logId, codeLocation, message);
  95. }
  96. // The global environment and memory allocator are initialised with the system component, and are accessed via the EBus from within the
  97. // model. m_precomputedTimingData and m_precomputedTimingDataCuda are structs holding the test inference statistics run before the
  98. // editor starts up, and used by the ImGui dashboard.
  99. void ONNXSystemComponent::Init()
  100. {
  101. m_env = AZStd::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_VERBOSE, "test_log", OnnxLoggingFunction, nullptr);
  102. m_allocator = AZStd::make_unique<Ort::AllocatorWithDefaultOptions>();
  103. }
  104. void ONNXSystemComponent::Activate()
  105. {
  106. ONNXRequestBus::Handler::BusConnect();
  107. }
  108. void ONNXSystemComponent::Deactivate()
  109. {
  110. AZ::TickBus::Handler::BusDisconnect();
  111. }
  112. } // namespace ONNX