/*
 * Copyright (c) Contributors to the Open 3D Engine Project.
 * For complete copyright and license terms please see the LICENSE at the root of this distribution.
 *
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 */

#pragma once

#include <AzCore/Math/MatrixMxN.h>
#include <AzNetworking/Serialization/ISerializer.h>
#include <MachineLearning/INeuralNetwork.h>

namespace MachineLearning
{
    // We separate out inference and training data to make multithreading models easier and more efficient.
    // Forward declarations; the full definitions appear below, after Layer.
    struct LayerInferenceData;
    struct LayerTrainingData;
    //! A class representing a single layer within a neural network.
    //! Persistent layer state (weights, biases, activation choice) lives here, while per-pass
    //! scratch state lives in LayerInferenceData / LayerTrainingData so that multiple threads
    //! can run the same Layer concurrently with independent scratch data.
    class Layer
    {
    public:

        AZ_TYPE_INFO(Layer, "{FB91E0A7-86C0-4431-83A8-04F8D8E1C9E2}");

        //! AzCore Reflection.
        //! @param context reflection context
        static void Reflect(AZ::ReflectContext* context);

        Layer() = default;
        Layer(Layer&&) = default;
        Layer(const Layer&) = default;

        //! Constructs a layer with the requested activation function and dimensions.
        //! @param activationFunction       the activation function this layer applies
        //! @param activationDimensionality dimensionality of the incoming activations
        //! @param layerDimensionality      dimensionality of this layer's output
        Layer(ActivationFunctions activationFunction, AZStd::size_t activationDimensionality, AZStd::size_t layerDimensionality);

        ~Layer() = default;

        Layer& operator=(Layer&&) = default;
        Layer& operator=(const Layer&) = default;

        //! Performs a basic forward pass on this layer; outputs are stored in the inference data's m_output.
        //! @param inferenceData inference scratch data that receives this layer's output
        //! @param activations   the input activation vector for this layer
        //! @return reference to the computed output vector (stored in inferenceData)
        const AZ::VectorN& Forward(LayerInferenceData& inferenceData, const AZ::VectorN& activations);

        //! Performs a gradient computation against the provided expected output using the provided gradients from the previous layer.
        //! This method presumes that we've completed a forward pass immediately prior to fill all the relevant vectors.
        //! @param samples       sample count used during accumulation — presumably the size of the current batch; confirm against callers
        //! @param trainingData  training scratch data where the computed gradients are accumulated
        //! @param inferenceData inference scratch data populated by the preceding Forward call
        //! @param expected      the expected (target) output to compute gradients against
        void AccumulateGradients(AZStd::size_t samples, LayerTrainingData& trainingData, LayerInferenceData& inferenceData, const AZ::VectorN& expected);

        //! Applies the current gradient values to the layer's weights and biases and resets the gradient values for a new accumulation pass.
        //! @param trainingData training scratch data holding the accumulated gradients to apply (and then reset)
        //! @param learningRate scale factor applied to the gradients during the update
        void ApplyGradients(LayerTrainingData& trainingData, float learningRate);

        //! Base serialize method for all serializable structures or classes to implement.
        //! @param serializer ISerializer instance to use for serialization
        //! @return boolean true for success, false for serialization failure
        bool Serialize(AzNetworking::ISerializer& serializer);

        //! Returns the estimated size required to serialize this layer.
        //! @return estimated serialized size in bytes
        AZStd::size_t EstimateSerializeSize() const;

        //! Updates layer internals for its requested dimensionalities.
        void OnSizesChanged();

        // These are intentionally left public so that unit testing can exhaustively examine all layer state
        AZStd::size_t m_inputSize = 0;  // dimensionality of the incoming activations
        AZStd::size_t m_outputSize = 0; // dimensionality of this layer's output
        AZ::MatrixMxN m_weights;        // layer weight matrix
        AZ::VectorN m_biases;           // layer bias vector
        ActivationFunctions m_activationFunction = ActivationFunctions::ReLU; // activation applied by this layer
    };
    //! These values are written to during inference.
    //! Kept separate from Layer so concurrent inference passes can each own their own scratch data.
    struct LayerInferenceData
    {
        AZ::VectorN m_output; // output activations written by Layer::Forward
    };
  59. //! These values are read and written during training.
  60. struct LayerTrainingData
  61. {
  62. // These values will only be populated if backward propagation is performed
  63. const AZ::VectorN* m_lastInput;
  64. AZ::VectorN m_activationGradients;
  65. AZ::VectorN m_biasGradients;
  66. AZ::MatrixMxN m_weightGradients;
  67. AZ::VectorN m_backpropagationGradients;
  68. };
} // namespace MachineLearning