/* * Copyright (c) Contributors to the Open 3D Engine Project. * For complete copyright and license terms please see the LICENSE at the root of this distribution. * * SPDX-License-Identifier: Apache-2.0 OR MIT * */ #pragma once #include #include #include #include namespace MachineLearning { //! This is a basic multilayer perceptron neural network capable of basic training and feed forward operations. class MultilayerPerceptron : public INeuralNetwork { public: AZ_RTTI(MultilayerPerceptron, "{E12EF761-41A5-48C3-BF55-7179B280D45F}", INeuralNetwork); //! AzCore Reflection. //! @param context reflection context static void Reflect(AZ::ReflectContext* context); MultilayerPerceptron(); MultilayerPerceptron(const MultilayerPerceptron&); MultilayerPerceptron(AZStd::size_t activationCount); virtual ~MultilayerPerceptron(); MultilayerPerceptron& operator=(const MultilayerPerceptron&); MultilayerPerceptron& operator=(const ModelAsset&); //! INeuralNetwork interface //! @{ AZStd::string GetName() const override; AZStd::string GetAssetFile(AssetTypes assetType) const override; AZStd::size_t GetInputDimensionality() const override; AZStd::size_t GetOutputDimensionality() const override; AZStd::size_t GetLayerCount() const override; AZ::MatrixMxN GetLayerWeights(AZStd::size_t layerIndex) const override; AZ::VectorN GetLayerBiases(AZStd::size_t layerIndex) const override; AZStd::size_t GetParameterCount() const override; IInferenceContextPtr CreateInferenceContext() override; ITrainingContextPtr CreateTrainingContext() override; const AZ::VectorN* Forward(IInferenceContextPtr context, const AZ::VectorN& activations) override; void Reverse(ITrainingContextPtr context, LossFunctions lossFunction, const AZ::VectorN& activations, const AZ::VectorN& expected) override; void GradientDescent(ITrainingContextPtr context, float learningRate) override; bool LoadModel() override; bool SaveModel() override; //! @} //! Adds a new layer to the model. void AddLayer(AZStd::size_t layerDimensionality, ActivationFunctions activationFunction = ActivationFunctions::ReLU); //! Retrieves a specific layer from the model, this is not thread safe and should only be used during unit testing to validate model parameters. Layer* GetLayer(AZStd::size_t layerIndex); private: void OnActivationCountChanged(); //! The model name. AZStd::string m_name; //! Optional test and train asset data files. AZStd::string m_testDataFile; AZStd::string m_testLabelFile; AZStd::string m_trainDataFile; AZStd::string m_trainLabelFile; //! The number of neurons in the activation layer. AZStd::size_t m_activationCount = 0; //! The set of layers in the network. AZStd::vector m_layers; IAssetPersistenceProxy* m_proxy = nullptr; friend class MultilayerPerceptronEditorComponent; }; struct MlpInferenceContext : public IInferenceContext { AZStd::vector m_layerData; }; struct MlpTrainingContext : public ITrainingContext { //! Used during the forward pass when calculating loss gradients. MlpInferenceContext m_forward; //! The number of accumulated training samples. AZStd::size_t m_trainingSampleSize = 0; //! The set of layer training data. AZStd::vector m_layerData; }; }