MachineLearningSystemComponent.cpp 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include "MachineLearningSystemComponent.h"
  9. #include <MachineLearning/MachineLearningTypeIds.h>
  10. #include <MachineLearning/Types.h>
  11. #include <AzCore/Serialization/SerializeContext.h>
  12. #include <AzCore/Serialization/EditContext.h>
  13. #include <AzCore/RTTI/BehaviorContext.h>
  14. #include <AzCore/Preprocessor/EnumReflectUtils.h>
  15. #include <Algorithms/Activations.h>
  16. #include <Assets/MnistDataLoader.h>
  17. #include <Models/Layer.h>
  18. #include <Models/MultilayerPerceptron.h>
  19. #include <AutoGenNodeableRegistry.generated.h>
  20. static ScriptCanvas::MachineLearningPrivateObjectNodeableRegistry s_MachineLearningPrivateObjectNodeableRegistry;
  21. namespace AZ
  22. {
  23. AZ_TYPE_INFO_SPECIALIZE(MachineLearning::ActivationFunctions, "{2ABF758E-CA69-41AC-BC95-B47AD7DEA31B}");
  24. AZ_TYPE_INFO_SPECIALIZE(MachineLearning::LossFunctions, "{18098C74-9AD0-4F1D-8093-545344620AD1}");
  25. }
  26. namespace MachineLearning
  27. {
  28. AZ_ENUM_DEFINE_REFLECT_UTILITIES(ActivationFunctions);
  29. AZ_ENUM_DEFINE_REFLECT_UTILITIES(LossFunctions);
  30. AZ_COMPONENT_IMPL(MachineLearningSystemComponent, "MachineLearningSystemComponent", MachineLearningSystemComponentTypeId);
  31. void MachineLearningSystemComponent::Reflect(AZ::ReflectContext* context)
  32. {
  33. if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  34. {
  35. serializeContext->Class<MachineLearningSystemComponent, AZ::Component>()->Version(0);
  36. serializeContext->Class<ILabeledTrainingData>()->Version(0);
  37. serializeContext->Class<INeuralNetwork>()->Version(0);
  38. serializeContext->RegisterGenericType<INeuralNetworkPtr>();
  39. serializeContext->RegisterGenericType<ILabeledTrainingDataPtr>();
  40. }
  41. if (auto behaviorContext = azrtti_cast<AZ::BehaviorContext*>(context))
  42. {
  43. behaviorContext->Class<MachineLearningSystemComponent>();
  44. behaviorContext->Class<INeuralNetwork>("INeuralNetwork")->
  45. Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Common)->
  46. Attribute(AZ::Script::Attributes::Module, "machineLearning")->
  47. Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)->
  48. Constructor<>()->
  49. Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)->
  50. Method("GetName", &INeuralNetwork::GetName)
  51. ;
  52. }
  53. Layer::Reflect(context);
  54. ModelAsset::Reflect(context);
  55. MnistDataLoader::Reflect(context);
  56. MultilayerPerceptron::Reflect(context);
  57. }
  58. void MachineLearningSystemComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
  59. {
  60. provided.push_back(AZ_CRC_CE("MachineLearningService"));
  61. }
  62. void MachineLearningSystemComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatible)
  63. {
  64. incompatible.push_back(AZ_CRC_CE("MachineLearningService"));
  65. }
  66. void MachineLearningSystemComponent::GetRequiredServices([[maybe_unused]] AZ::ComponentDescriptor::DependencyArrayType& required)
  67. {
  68. }
  69. void MachineLearningSystemComponent::GetDependentServices([[maybe_unused]] AZ::ComponentDescriptor::DependencyArrayType& dependent)
  70. {
  71. }
  72. MachineLearningSystemComponent::MachineLearningSystemComponent()
  73. {
  74. if (MachineLearningInterface::Get() == nullptr)
  75. {
  76. MachineLearningInterface::Register(this);
  77. }
  78. }
  79. MachineLearningSystemComponent::~MachineLearningSystemComponent()
  80. {
  81. if (MachineLearningInterface::Get() == this)
  82. {
  83. MachineLearningInterface::Unregister(this);
  84. }
  85. }
  86. void MachineLearningSystemComponent::Init()
  87. {
  88. }
  89. void MachineLearningSystemComponent::Activate()
  90. {
  91. MachineLearningRequestBus::Handler::BusConnect();
  92. m_assetHandler = AZStd::make_unique<ModelAssetHandler>();
  93. m_assetHandler->Register();
  94. }
  95. void MachineLearningSystemComponent::Deactivate()
  96. {
  97. m_assetHandler->Unregister();
  98. MachineLearningRequestBus::Handler::BusDisconnect();
  99. }
  100. void MachineLearningSystemComponent::RegisterModel(INeuralNetworkPtr model)
  101. {
  102. m_registeredModels.emplace(model);
  103. }
  104. void MachineLearningSystemComponent::UnregisterModel(INeuralNetworkPtr model)
  105. {
  106. m_registeredModels.erase(model);
  107. }
  108. ModelSet& MachineLearningSystemComponent::GetModelSet()
  109. {
  110. return m_registeredModels;
  111. }
  112. }