/*
 * Copyright (c) Contributors to the Open 3D Engine Project.
 * For complete copyright and license terms please see the LICENSE at the root of this distribution.
 *
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 */

#include <Models/MultilayerPerceptron.h>
#include <Algorithms/Activations.h>
#include <Algorithms/LossFunctions.h>
#include <AzCore/RTTI/RTTI.h>
#include <AzCore/RTTI/BehaviorContext.h>
#include <AzCore/Serialization/EditContext.h>
#include <AzCore/Serialization/SerializeContext.h>
#include <cmath> // std::sqrt, used during weight initialization
#include <random>

namespace MachineLearning
{
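    // Reflect registers Layer with O3DE's serialization system, exposes its fields to the
    // editor's property grid, and makes the class scriptable through the behavior context.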
    void Layer::Reflect(AZ::ReflectContext* context)
    {
        if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
        {
            serializeContext->Class<Layer>()
                ->Version(1)
                ->Field("InputSize", &Layer::m_inputSize)
                ->Field("OutputSize", &Layer::m_outputSize)
                ->Field("Weights", &Layer::m_weights)
                ->Field("Biases", &Layer::m_biases)
                ->Field("ActivationFunction", &Layer::m_activationFunction)
                ;

            if (AZ::EditContext* editContext = serializeContext->GetEditContext())
            {
                editContext->Class<Layer>("A single layer of a neural network", "")
                    ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
                    ->DataElement(AZ::Edit::UIHandlers::Default, &Layer::m_outputSize, "Layer Size", "The number of neurons the layer should have")
                        ->Attribute(AZ::Edit::Attributes::ChangeNotify, &Layer::OnSizesChanged)
                    ->DataElement(AZ::Edit::UIHandlers::ComboBox, &Layer::m_activationFunction, "Activation Function", "The activation function applied to this layer")
                        ->Attribute(AZ::Edit::Attributes::EnumValues, &GetActivationEnumValues)
                    ;
            }
        }

        if (auto behaviorContext = azrtti_cast<AZ::BehaviorContext*>(context))
        {
            behaviorContext->Class<Layer>()->
                Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Common)->
                Attribute(AZ::Script::Attributes::Module, "machineLearning")->
                Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)->
                Constructor<ActivationFunctions, AZStd::size_t, AZStd::size_t>()->
                Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)->
                Property("InputSize", BehaviorValueProperty(&Layer::m_inputSize))->
                Property("OutputSize", BehaviorValueProperty(&Layer::m_outputSize))->
                Property("ActivationFunction", BehaviorValueProperty(&Layer::m_activationFunction))
                ;
        }
    }
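
    // Constructs a layer and immediately initializes its weights and biases.
    // activationDimensionality is the number of inputs the layer consumes (its fan-in),
    // while layerDimensionality is the number of neurons, and hence outputs, it produces.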
    Layer::Layer(ActivationFunctions activationFunction, AZStd::size_t activationDimensionality, AZStd::size_t layerDimensionality)
        : m_activationFunction(activationFunction)
        , m_inputSize(activationDimensionality)
        , m_outputSize(layerDimensionality)
    {
        OnSizesChanged();
    }
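
    // Forward computes this layer's contribution to inference: output = f(W * x + b).
    // The output is seeded with the bias vector, the weighted inputs are accumulated on
    // top of it, and the activation function is then applied in place.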
    const AZ::VectorN& Layer::Forward(LayerInferenceData& inferenceData, const AZ::VectorN& activations)
    {
        inferenceData.m_output = m_biases;
        AZ::VectorMatrixMultiply(m_weights, activations, inferenceData.m_output);
        Activate(m_activationFunction, inferenceData.m_output, inferenceData.m_output);
        return inferenceData.m_output;
    }
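
    // AccumulateGradients performs one backpropagation step for this layer via the chain rule:
    //   activation gradients  = f'(output) * incoming gradients
    //   weight gradients     += activation gradients (outer product) last input
    //   bias gradients       += activation gradients
    //   output gradients      = W^T * activation gradients   (passed to the preceding layer)
    // Gradients accumulate across calls until ApplyGradients is invoked, enabling mini-batch training.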
    void Layer::AccumulateGradients(LayerTrainingData& trainingData, LayerInferenceData& inferenceData, const AZ::VectorN& previousLayerGradients)
    {
        // Ensure our bias gradient vector is appropriately sized
        if (trainingData.m_biasGradients.GetDimensionality() != m_outputSize)
        {
            trainingData.m_biasGradients = AZ::VectorN::CreateZero(m_outputSize);
        }

        // Ensure our weight gradient matrix is appropriately sized
        if ((trainingData.m_weightGradients.GetRowCount() != m_outputSize) || (trainingData.m_weightGradients.GetColumnCount() != m_inputSize))
        {
            trainingData.m_weightGradients = AZ::MatrixMxN::CreateZero(m_outputSize, m_inputSize);
        }

        // Ensure our backpropagation gradient vector is appropriately sized
        if (trainingData.m_backpropagationGradients.GetDimensionality() != m_inputSize)
        {
            trainingData.m_backpropagationGradients = AZ::VectorN::CreateZero(m_inputSize);
        }

        // Compute the partial derivatives of the output with respect to the activation function
        Activate_Derivative(m_activationFunction, inferenceData.m_output, previousLayerGradients, trainingData.m_activationGradients);

        // Accumulate the partial derivatives of the weight matrix with respect to the loss function
        AZ::OuterProduct(trainingData.m_activationGradients, *trainingData.m_lastInput, trainingData.m_weightGradients);

        // Accumulate the partial derivatives of the bias vector with respect to the loss function
        trainingData.m_biasGradients += trainingData.m_activationGradients;

        // Accumulate the gradients to pass to the preceding layer for back-propagation
        AZ::VectorMatrixMultiplyLeft(trainingData.m_activationGradients, m_weights, trainingData.m_backpropagationGradients);
    }
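
    // ApplyGradients performs a plain stochastic gradient descent update
    // (parameter -= gradient * learningRate) and then zeroes the accumulators
    // so the next batch starts from a clean slate.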
    void Layer::ApplyGradients(LayerTrainingData& trainingData, float learningRate)
    {
        m_weights -= trainingData.m_weightGradients * learningRate;
        m_biases -= trainingData.m_biasGradients * learningRate;

        trainingData.m_biasGradients.SetZero();
        trainingData.m_weightGradients.SetZero();
        trainingData.m_backpropagationGradients.SetZero();
    }
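
    // Serialize writes or reads the complete layer state through an AzNetworking
    // serializer; it returns false as soon as any individual field fails.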
    bool Layer::Serialize(AzNetworking::ISerializer& serializer)
    {
        return serializer.Serialize(m_inputSize, "inputSize")
            && serializer.Serialize(m_outputSize, "outputSize")
            && serializer.Serialize(m_weights, "weights")
            && serializer.Serialize(m_biases, "biases")
            && serializer.Serialize(m_activationFunction, "activationFunction");
    }
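
    // EstimateSerializeSize returns a conservative upper bound on the bytes Serialize will
    // produce: the fixed-size fields, the weight and bias payloads, and a safety margin.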
    AZStd::size_t Layer::EstimateSerializeSize() const
    {
        const AZStd::size_t padding = 64; // 64 bytes of extra padding just in case
        return padding
             + sizeof(m_inputSize)
             + sizeof(m_outputSize)
             + sizeof(AZStd::size_t) // for m_weights row count
             + sizeof(AZStd::size_t) // for m_weights column count
             + sizeof(AZStd::size_t) // for m_weights vector size
             + sizeof(float) * m_outputSize * m_inputSize // m_weights buffer
             + sizeof(AZStd::size_t) // for m_biases dimensionality
             + sizeof(AZStd::size_t) // for m_biases vector size
             + sizeof(float) * m_outputSize // m_biases buffer
             + sizeof(m_activationFunction);
    }
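
    // OnSizesChanged reinitializes the weights and biases whenever the layer's dimensions
    // change, including edits made through the ChangeNotify hook registered in Reflect.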
    void Layer::OnSizesChanged()
    {
        // For ReLU we use Kaiming He initialization, which preserves activation variance across layers;
        // other activation functions fall back to a zero-mean normal scaled by the layer's fan-in.
        // He initialization specifies a *variance* of 2/fanIn, so the standard deviation is its square root.
        const float variance = (m_activationFunction == ActivationFunctions::ReLU)
            ? 2.0f / static_cast<float>(m_inputSize)
            : 1.0f / static_cast<float>(m_inputSize);
        const float standardDeviation = std::sqrt(variance);

        std::random_device rd{};
        std::mt19937 gen{ rd() };
        auto dist = std::normal_distribution<float>{ 0.0f, standardDeviation };

        m_weights.Resize(m_outputSize, m_inputSize);
        for (AZStd::size_t row = 0; row < m_weights.GetRowCount(); ++row)
        {
            for (AZStd::size_t col = 0; col < m_weights.GetColumnCount(); ++col)
            {
                m_weights.SetElement(row, col, dist(gen));
            }
        }

        // A small positive bias keeps ReLU units in their active region at the start of training
        m_biases = AZ::VectorN(m_outputSize, 0.01f);
    }
} // namespace MachineLearning