MachineLearningSystemComponent.cpp 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include "MachineLearningSystemComponent.h"
  9. #include <MachineLearning/MachineLearningTypeIds.h>
  10. #include <AzCore/Serialization/SerializeContext.h>
  11. #include <AzCore/Serialization/EditContext.h>
  12. #include <AzCore/RTTI/BehaviorContext.h>
  13. #include <Models/Layer.h>
  14. #include <Models/MultilayerPerceptron.h>
  15. #include <AutoGenNodeableRegistry.generated.h>
  16. static ScriptCanvas::MachineLearningPrivateObjectNodeableRegistry s_MachineLearningPrivateObjectNodeableRegistry;
  17. namespace AZ
  18. {
  19. AZ_TYPE_INFO_SPECIALIZE(MachineLearning::ActivationFunctions, "{2ABF758E-CA69-41AC-BC95-B47AD7DEA31B}");
  20. AZ_TYPE_INFO_SPECIALIZE(MachineLearning::LossFunctions, "{18098C74-9AD0-4F1D-8093-545344620AD1}");
  21. }
  22. namespace MachineLearning
  23. {
  24. AZ_COMPONENT_IMPL(MachineLearningSystemComponent, "MachineLearningSystemComponent", MachineLearningSystemComponentTypeId);
  25. void LayerParams::Reflect(AZ::ReflectContext* context)
  26. {
  27. if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  28. {
  29. serializeContext->Class<LayerParams>()
  30. ->Version(1)
  31. ->Field("Size", &LayerParams::m_layerSize)
  32. ->Field("ActivationFunction", &LayerParams::m_activationFunction)
  33. ;
  34. if (AZ::EditContext* editContext = serializeContext->GetEditContext())
  35. {
  36. editContext->Class<LayerParams>("Parameters defining a single layer of a neural network", "")
  37. ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
  38. ->DataElement(AZ::Edit::UIHandlers::Default, &LayerParams::m_layerSize, "Layer Size", "The number of neurons this layer should have")
  39. ->DataElement(AZ::Edit::UIHandlers::ComboBox, &LayerParams::m_activationFunction, "Activation Function", "The activation function applied to this layer")
  40. ;
  41. }
  42. }
  43. auto behaviorContext = azrtti_cast<AZ::BehaviorContext*>(context);
  44. if (behaviorContext)
  45. {
  46. behaviorContext->Class<LayerParams>()->
  47. Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Common)->
  48. Attribute(AZ::Script::Attributes::Module, "machineLearning")->
  49. Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)->
  50. Constructor<AZStd::size_t, ActivationFunctions>()->
  51. Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)
  52. ;
  53. }
  54. }
  55. void MachineLearningSystemComponent::Reflect(AZ::ReflectContext* context)
  56. {
  57. if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  58. {
  59. serializeContext->Class<MachineLearningSystemComponent, AZ::Component>()->Version(0);
  60. serializeContext->Class<Layer>()->Version(0);
  61. serializeContext->Class<MultilayerPerceptron>()->Version(0);
  62. serializeContext->Class<INeuralNetwork>()->Version(0);
  63. serializeContext->Class<LayerParams>();
  64. serializeContext->RegisterGenericType<INeuralNetworkPtr>();
  65. serializeContext->RegisterGenericType<HiddenLayerParams>();
  66. }
  67. if (auto behaviorContext = azrtti_cast<AZ::BehaviorContext*>(context))
  68. {
  69. behaviorContext->Class<MachineLearningSystemComponent>();
  70. behaviorContext->Class<LayerParams>();
  71. behaviorContext->Class<Layer>();
  72. behaviorContext->Class<MultilayerPerceptron>();
  73. behaviorContext
  74. ->Enum<static_cast<int>(ActivationFunctions::Linear)>("Linear activation function")
  75. ->Enum<static_cast<int>(ActivationFunctions::ReLU)>("ReLU activation function");
  76. behaviorContext
  77. ->Enum<static_cast<int>(LossFunctions::MeanSquaredError)>("Quadratic cost function")
  78. ->Enum<static_cast<int>(LossFunctions::CrossEntropyLoss)>("Cross entropy loss function");
  79. }
  80. }
  81. void MachineLearningSystemComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
  82. {
  83. provided.push_back(AZ_CRC_CE("MachineLearningService"));
  84. }
  85. void MachineLearningSystemComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatible)
  86. {
  87. incompatible.push_back(AZ_CRC_CE("MachineLearningService"));
  88. }
  89. void MachineLearningSystemComponent::GetRequiredServices([[maybe_unused]] AZ::ComponentDescriptor::DependencyArrayType& required)
  90. {
  91. }
  92. void MachineLearningSystemComponent::GetDependentServices([[maybe_unused]] AZ::ComponentDescriptor::DependencyArrayType& dependent)
  93. {
  94. }
  95. MachineLearningSystemComponent::MachineLearningSystemComponent()
  96. {
  97. if (MachineLearningInterface::Get() == nullptr)
  98. {
  99. MachineLearningInterface::Register(this);
  100. }
  101. }
  102. MachineLearningSystemComponent::~MachineLearningSystemComponent()
  103. {
  104. if (MachineLearningInterface::Get() == this)
  105. {
  106. MachineLearningInterface::Unregister(this);
  107. }
  108. }
  109. void MachineLearningSystemComponent::Init()
  110. {
  111. }
  112. void MachineLearningSystemComponent::Activate()
  113. {
  114. MachineLearningRequestBus::Handler::BusConnect();
  115. }
  116. void MachineLearningSystemComponent::Deactivate()
  117. {
  118. MachineLearningRequestBus::Handler::BusDisconnect();
  119. }
  120. } // namespace MachineLearning