/* * Copyright (c) Contributors to the Open 3D Engine Project. * For complete copyright and license terms please see the LICENSE at the root of this distribution. * * SPDX-License-Identifier: Apache-2.0 OR MIT * */ #include "MachineLearningSystemComponent.h" #include #include #include #include #include #include #include static ScriptCanvas::MachineLearningPrivateObjectNodeableRegistry s_MachineLearningPrivateObjectNodeableRegistry; namespace AZ { AZ_TYPE_INFO_SPECIALIZE(MachineLearning::ActivationFunctions, "{2ABF758E-CA69-41AC-BC95-B47AD7DEA31B}"); AZ_TYPE_INFO_SPECIALIZE(MachineLearning::LossFunctions, "{18098C74-9AD0-4F1D-8093-545344620AD1}"); } namespace MachineLearning { AZ_COMPONENT_IMPL(MachineLearningSystemComponent, "MachineLearningSystemComponent", MachineLearningSystemComponentTypeId); void LayerParams::Reflect(AZ::ReflectContext* context) { if (auto serializeContext = azrtti_cast(context)) { serializeContext->Class() ->Version(1) ->Field("Size", &LayerParams::m_layerSize) ->Field("ActivationFunction", &LayerParams::m_activationFunction) ; if (AZ::EditContext* editContext = serializeContext->GetEditContext()) { editContext->Class("Parameters defining a single layer of a neural network", "") ->ClassElement(AZ::Edit::ClassElements::EditorData, "") ->DataElement(AZ::Edit::UIHandlers::Default, &LayerParams::m_layerSize, "Layer Size", "The number of neurons this layer should have") ->DataElement(AZ::Edit::UIHandlers::ComboBox, &LayerParams::m_activationFunction, "Activation Function", "The activation function applied to this layer") ; } } auto behaviorContext = azrtti_cast(context); if (behaviorContext) { behaviorContext->Class()-> Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Common)-> Attribute(AZ::Script::Attributes::Module, "machineLearning")-> Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)-> Constructor()-> Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value) ; } } void MachineLearningSystemComponent::Reflect(AZ::ReflectContext* context) { if (auto serializeContext = azrtti_cast(context)) { serializeContext->Class()->Version(0); serializeContext->Class()->Version(0); serializeContext->Class()->Version(0); serializeContext->Class()->Version(0); serializeContext->Class(); serializeContext->RegisterGenericType(); serializeContext->RegisterGenericType(); } if (auto behaviorContext = azrtti_cast(context)) { behaviorContext->Class(); behaviorContext->Class(); behaviorContext->Class(); behaviorContext->Class(); behaviorContext ->Enum(ActivationFunctions::Linear)>("Linear activation function") ->Enum(ActivationFunctions::ReLU)>("ReLU activation function"); behaviorContext ->Enum(LossFunctions::MeanSquaredError)>("Quadratic cost function") ->Enum(LossFunctions::CrossEntropyLoss)>("Cross entropy loss function"); } } void MachineLearningSystemComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided) { provided.push_back(AZ_CRC_CE("MachineLearningService")); } void MachineLearningSystemComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatible) { incompatible.push_back(AZ_CRC_CE("MachineLearningService")); } void MachineLearningSystemComponent::GetRequiredServices([[maybe_unused]] AZ::ComponentDescriptor::DependencyArrayType& required) { } void MachineLearningSystemComponent::GetDependentServices([[maybe_unused]] AZ::ComponentDescriptor::DependencyArrayType& dependent) { } MachineLearningSystemComponent::MachineLearningSystemComponent() { if (MachineLearningInterface::Get() == nullptr) { MachineLearningInterface::Register(this); } } MachineLearningSystemComponent::~MachineLearningSystemComponent() { if (MachineLearningInterface::Get() == this) { MachineLearningInterface::Unregister(this); } } void MachineLearningSystemComponent::Init() { } void MachineLearningSystemComponent::Activate() { MachineLearningRequestBus::Handler::BusConnect(); } void MachineLearningSystemComponent::Deactivate() { MachineLearningRequestBus::Handler::BusDisconnect(); } } // namespace MachineLearning