MultilayerPerceptron.h 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #pragma once
  9. #include <AzCore/Math/MatrixMxN.h>
  10. #include <AzNetworking/Serialization/ISerializer.h>
  11. #include <MachineLearning/INeuralNetwork.h>
  12. #include <Models/Layer.h>
  13. #include <Assets/ModelAsset.h>
  14. namespace MachineLearning
  15. {
  16. //! This is a basic multilayer perceptron neural network capable of basic training and feed forward operations.
  17. class MultilayerPerceptron
  18. : public INeuralNetwork
  19. {
  20. public:
  21. AZ_RTTI(MultilayerPerceptron, "{E12EF761-41A5-48C3-BF55-7179B280D45F}", INeuralNetwork);
  22. //! AzCore Reflection.
  23. //! @param context reflection context
  24. static void Reflect(AZ::ReflectContext* context);
  25. MultilayerPerceptron();
  26. MultilayerPerceptron(const MultilayerPerceptron&);
  27. MultilayerPerceptron(AZStd::size_t activationCount);
  28. virtual ~MultilayerPerceptron();
  29. MultilayerPerceptron& operator=(const MultilayerPerceptron&);
  30. //! INeuralNetwork interface
  31. //! @{
  32. AZStd::string GetName() const override;
  33. AZStd::string GetAssetFile(AssetTypes assetType) const override;
  34. AZStd::size_t GetInputDimensionality() const override;
  35. AZStd::size_t GetOutputDimensionality() const override;
  36. AZStd::size_t GetLayerCount() const override;
  37. AZ::MatrixMxN GetLayerWeights(AZStd::size_t layerIndex) const override;
  38. AZ::VectorN GetLayerBiases(AZStd::size_t layerIndex) const override;
  39. AZStd::size_t GetParameterCount() const override;
  40. IInferenceContextPtr CreateInferenceContext() override;
  41. ITrainingContextPtr CreateTrainingContext() override;
  42. const AZ::VectorN* Forward(IInferenceContextPtr context, const AZ::VectorN& activations) override;
  43. void Reverse(ITrainingContextPtr context, LossFunctions lossFunction, const AZ::VectorN& activations, const AZ::VectorN& expected) override;
  44. void GradientDescent(ITrainingContextPtr context, float learningRate) override;
  45. bool LoadModel() override;
  46. bool SaveModel() override;
  47. //! @}
  48. //! Adds a new layer to the model.
  49. void AddLayer(AZStd::size_t layerDimensionality, ActivationFunctions activationFunction = ActivationFunctions::ReLU);
  50. //! Retrieves a specific layer from the model, this is not thread safe and should only be used during unit testing to validate model parameters.
  51. Layer* GetLayer(AZStd::size_t layerIndex);
  52. //! Base serialize method for all serializable structures or classes to implement.
  53. //! @param serializer ISerializer instance to use for serialization
  54. //! @return boolean true for success, false for serialization failure
  55. bool Serialize(AzNetworking::ISerializer& serializer);
  56. //! Returns the estimated size required to serialize this model.
  57. AZStd::size_t EstimateSerializeSize() const;
  58. private:
  59. void OnActivationCountChanged();
  60. //! The model asset.
  61. AZ::Data::Asset<ModelAsset> m_asset;
  62. //! The model name.
  63. AZStd::string m_name;
  64. //! The model asset file.
  65. AZStd::string m_modelFile;
  66. //! Optional test and train asset data files.
  67. AZStd::string m_testDataFile;
  68. AZStd::string m_testLabelFile;
  69. AZStd::string m_trainDataFile;
  70. AZStd::string m_trainLabelFile;
  71. //! The number of neurons in the activation layer.
  72. AZStd::size_t m_activationCount = 0;
  73. //! The set of layers in the network.
  74. AZStd::vector<Layer> m_layers;
  75. };
  76. struct MlpInferenceContext
  77. : public IInferenceContext
  78. {
  79. AZStd::vector<LayerInferenceData> m_layerData;
  80. };
  81. struct MlpTrainingContext
  82. : public ITrainingContext
  83. {
  84. //! Used during the forward pass when calculating loss gradients.
  85. MlpInferenceContext m_forward;
  86. //! The number of accumulated training samples.
  87. AZStd::size_t m_trainingSampleSize = 0;
  88. //! The set of layer training data.
  89. AZStd::vector<LayerTrainingData> m_layerData;
  90. };
  91. }