MultilayerPerceptronComponent.cpp 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #pragma once
  9. #include <Components/MultilayerPerceptronComponent.h>
  10. #include <MachineLearning/IMachineLearning.h>
  11. #include <AzCore/RTTI/RTTI.h>
  12. #include <AzCore/RTTI/BehaviorContext.h>
  13. #include <AzCore/Serialization/EditContext.h>
  14. #include <AzCore/Serialization/SerializeContext.h>
  15. #include <AzCore/Console/ILogger.h>
  16. namespace MachineLearning
  17. {
  18. void MultilayerPerceptronComponent::Reflect(AZ::ReflectContext* context)
  19. {
  20. if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  21. {
  22. serializeContext->Class<MultilayerPerceptronComponent>()
  23. ->Version(0)
  24. ->Field("Asset", &MultilayerPerceptronComponent::m_asset)
  25. ->Field("Model", &MultilayerPerceptronComponent::m_model)
  26. ;
  27. }
  28. auto behaviorContext = azrtti_cast<AZ::BehaviorContext*>(context);
  29. if (behaviorContext)
  30. {
  31. behaviorContext->Class<MultilayerPerceptronComponent>("MultilayerPerceptron Component")
  32. ->Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Common)
  33. ->Attribute(AZ::Script::Attributes::Module, "machineLearning")
  34. ->Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)
  35. ->Constructor<>()
  36. ->Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)
  37. ->Property("Model", BehaviorValueProperty(&MultilayerPerceptronComponent::m_model))
  38. ;
  39. behaviorContext->EBus<MultilayerPerceptronComponentRequestBus>("Multilayer perceptron requests")
  40. ->Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Common)
  41. ->Attribute(AZ::Script::Attributes::Module, "machinelearning")
  42. ->Attribute(AZ::Script::Attributes::Category, "MachineLearning")
  43. ->Event("Get model", &MachineLearning::MultilayerPerceptronComponentRequestBus::Events::GetModel)
  44. ;
  45. }
  46. }
  47. void MultilayerPerceptronComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
  48. {
  49. provided.push_back(AZ_CRC("MultilayerPerceptronService"));
  50. }
  51. MultilayerPerceptronComponent::MultilayerPerceptronComponent()
  52. {
  53. m_handle.reset(&m_model);
  54. MachineLearningInterface::Get()->RegisterModel(m_handle);
  55. }
  56. MultilayerPerceptronComponent::~MultilayerPerceptronComponent()
  57. {
  58. MachineLearningInterface::Get()->UnregisterModel(m_handle);
  59. }
  60. void MultilayerPerceptronComponent::Activate()
  61. {
  62. MultilayerPerceptronComponentRequestBus::Handler::BusConnect(GetEntityId());
  63. AssetChanged();
  64. }
  65. void MultilayerPerceptronComponent::Deactivate()
  66. {
  67. AZ::Data::AssetBus::Handler::BusDisconnect();
  68. MultilayerPerceptronComponentRequestBus::Handler::BusDisconnect();
  69. }
  70. INeuralNetworkPtr MultilayerPerceptronComponent::GetModel()
  71. {
  72. return m_handle;
  73. }
  74. void MultilayerPerceptronComponent::AssetChanged()
  75. {
  76. AZ::Data::AssetBus::Handler::BusDisconnect();
  77. if (m_asset.GetStatus() == AZ::Data::AssetData::AssetStatus::Error ||
  78. m_asset.GetStatus() == AZ::Data::AssetData::AssetStatus::NotLoaded)
  79. {
  80. m_asset.QueueLoad();
  81. }
  82. AZ::Data::AssetBus::Handler::BusConnect(m_asset.GetId());
  83. }
  84. void MultilayerPerceptronComponent::AssetCleared()
  85. {
  86. ;
  87. }
  88. void MultilayerPerceptronComponent::OnAssetReady(AZ::Data::Asset<AZ::Data::AssetData> asset)
  89. {
  90. ModelAsset* modelAsset = asset.GetAs<ModelAsset>();
  91. if ((asset == m_asset) && (modelAsset != nullptr))
  92. {
  93. m_model = *modelAsset;
  94. }
  95. }
  96. void MultilayerPerceptronComponent::OnAssetReloaded(AZ::Data::Asset<AZ::Data::AssetData> asset)
  97. {
  98. OnAssetReady(asset);
  99. }
  100. void MultilayerPerceptronComponent::OnAssetError(AZ::Data::Asset<AZ::Data::AssetData> asset)
  101. {
  102. if (asset == m_asset)
  103. {
  104. AZLOG_WARN("OnAssetError: %s", asset.GetHint().c_str());
  105. }
  106. }
  107. void MultilayerPerceptronComponent::OnAssetReloadError(AZ::Data::Asset<AZ::Data::AssetData> asset)
  108. {
  109. if (asset == m_asset)
  110. {
  111. AZLOG_WARN("OnAssetReloadError: %s", asset.GetHint().c_str());
  112. }
  113. }
  114. }