Layer.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
#include <Models/MultilayerPerceptron.h>

#include <Algorithms/Activations.h>
#include <Algorithms/LossFunctions.h>
#include <AzCore/Console/IConsole.h>
#include <AzCore/Console/ILogger.h>
#include <AzCore/RTTI/BehaviorContext.h>
#include <AzCore/RTTI/RTTI.h>
#include <AzCore/Serialization/EditContext.h>
#include <AzCore/Serialization/SerializeContext.h>

#include <cmath>
#include <random>
  18. namespace MachineLearning
  19. {
// Console variables gating gradient diagnostics during training; both default off
// because the verbose path in particular can emit a large volume of log output.
AZ_CVAR(bool, ml_logGradients, false, nullptr, AZ::ConsoleFunctorFlags::Null, "Dumps some gradient metrics so they can be monitored during training");
AZ_CVAR(bool, ml_logGradientsVerbose, false, nullptr, AZ::ConsoleFunctorFlags::Null, "Dumps complete gradient values to the console for examination, this can be a significant amount of data");
  22. void DumpVectorGradients(const AZ::VectorN& value, const char* label)
  23. {
  24. AZStd::string vectorString(label);
  25. for (AZStd::size_t iter = 0; iter < value.GetDimensionality(); ++iter)
  26. {
  27. vectorString += AZStd::string::format(" %.02f", value.GetElement(iter));
  28. }
  29. AZLOG_INFO(vectorString.c_str());
  30. }
  31. void DumpMatrixGradients(const AZ::MatrixMxN& value, const char* label)
  32. {
  33. for (AZStd::size_t i = 0; i < value.GetRowCount(); ++i)
  34. {
  35. AZStd::string rowString(label);
  36. rowString += AZStd::string::format(":%u", static_cast<uint32_t>(i));
  37. for (AZStd::size_t j = 0; j < value.GetColumnCount(); ++j)
  38. {
  39. rowString += AZStd::string::format(" %.02f", value.GetElement(i, j));
  40. }
  41. AZLOG_INFO(rowString.c_str());
  42. }
  43. }
  44. void AccumulateBiasGradients(AZ::VectorN& biasGradients, const AZ::VectorN& activationGradients, AZStd::size_t currentSamples)
  45. {
  46. AZ::Vector4 divisor(static_cast<float>(currentSamples));
  47. AZStd::vector<AZ::Vector4>& biasValues = biasGradients.GetVectorValues();
  48. const AZStd::vector<AZ::Vector4>& activationValues = activationGradients.GetVectorValues();
  49. for (AZStd::size_t iter = 0; iter < biasValues.size(); ++iter)
  50. {
  51. // average += (next - average) / samples
  52. biasValues[iter] += (activationValues[iter] - biasValues[iter]) / divisor;
  53. }
  54. }
//! Folds one sample's outer product (activationGradients x lastInput) into the running
//! weight-gradient average using the streaming form `average += (next - average) / samples`.
//! @param activationGradients the current sample's activation gradients (rows of the outer product)
//! @param lastInput the input vector that produced this sample's activations (columns of the outer product)
//! @param weightGradients the accumulated weight gradients, updated in place
//! @param currentSamples the number of samples folded in so far, used as the averaging divisor
void AccumulateWeightGradients(const AZ::VectorN& activationGradients, const AZ::VectorN& lastInput, AZ::MatrixMxN& weightGradients, AZStd::size_t currentSamples)
{
    // The following performs an outer product between activationGradients and lastInput
    // The reason we're not simply iteratively invoking OuterProduct is so that we can compute a more numerically stable average and preserve our gradients better over large batch sizes
    const AZ::Simd::Vec4::FloatType divisor = AZ::Simd::Vec4::Splat(static_cast<float>(currentSamples));
    for (AZStd::size_t colIter = 0; colIter < weightGradients.GetColumnGroups(); ++colIter)
    {
        // Broadcast each of the four input lanes in this column group so every lane
        // can be multiplied against the full activation-gradient column vector
        AZ::Simd::Vec4::FloatType rhsElement = lastInput.GetVectorValues()[colIter].GetSimdValue();
        AZ::Simd::Vec4::FloatType splat0 = AZ::Simd::Vec4::SplatIndex0(rhsElement);
        AZ::Simd::Vec4::FloatType splat1 = AZ::Simd::Vec4::SplatIndex1(rhsElement);
        AZ::Simd::Vec4::FloatType splat2 = AZ::Simd::Vec4::SplatIndex2(rhsElement);
        AZ::Simd::Vec4::FloatType splat3 = AZ::Simd::Vec4::SplatIndex3(rhsElement);
        for (AZStd::size_t rowIter = 0; rowIter < weightGradients.GetRowGroups(); ++rowIter)
        {
            AZ::Simd::Vec4::FloatType lhsElement = activationGradients.GetVectorValues()[rowIter].GetSimdValue();
            AZ::Matrix4x4& outputElement = weightGradients.GetSubmatrix(rowIter, colIter);
            // nextN = (outer-product term) - (current average); one column of the 4x4 submatrix per N
            AZ::Simd::Vec4::FloatType next0 = AZ::Simd::Vec4::Sub(AZ::Simd::Vec4::Mul(lhsElement, splat0), outputElement.GetSimdValues()[0]);
            AZ::Simd::Vec4::FloatType next1 = AZ::Simd::Vec4::Sub(AZ::Simd::Vec4::Mul(lhsElement, splat1), outputElement.GetSimdValues()[1]);
            AZ::Simd::Vec4::FloatType next2 = AZ::Simd::Vec4::Sub(AZ::Simd::Vec4::Mul(lhsElement, splat2), outputElement.GetSimdValues()[2]);
            AZ::Simd::Vec4::FloatType next3 = AZ::Simd::Vec4::Sub(AZ::Simd::Vec4::Mul(lhsElement, splat3), outputElement.GetSimdValues()[3]);
            // average += (next - average) / samples
            outputElement.GetSimdValues()[0] = AZ::Simd::Vec4::Add(outputElement.GetSimdValues()[0], AZ::Simd::Vec4::Div(next0, divisor));
            outputElement.GetSimdValues()[1] = AZ::Simd::Vec4::Add(outputElement.GetSimdValues()[1], AZ::Simd::Vec4::Div(next1, divisor));
            outputElement.GetSimdValues()[2] = AZ::Simd::Vec4::Add(outputElement.GetSimdValues()[2], AZ::Simd::Vec4::Div(next2, divisor));
            outputElement.GetSimdValues()[3] = AZ::Simd::Vec4::Add(outputElement.GetSimdValues()[3], AZ::Simd::Vec4::Div(next3, divisor));
        }
    }
    // Zero out any padding lanes beyond the logical matrix dimensions so the
    // accumulation above cannot leak garbage into later reductions
    weightGradients.FixUnusedElements();
}
  84. void GetMinMaxElements(const AZ::VectorN& source, float& min, float& max)
  85. {
  86. const AZStd::vector<AZ::Vector4>& elements = source.GetVectorValues();
  87. if (!elements.empty())
  88. {
  89. AZ::Vector4 minimum = elements[0];
  90. AZ::Vector4 maximum = elements[0];
  91. for (AZStd::size_t i = 1; i < elements.size(); ++i)
  92. {
  93. minimum.GetMin(elements[i]);
  94. maximum.GetMax(elements[i]);
  95. }
  96. min = AZ::GetMin(AZ::GetMin(minimum.GetX(), minimum.GetY()), AZ::GetMin(minimum.GetZ(), minimum.GetW()));
  97. max = AZ::GetMax(AZ::GetMax(minimum.GetX(), minimum.GetY()), AZ::GetMax(minimum.GetZ(), minimum.GetW()));
  98. }
  99. }
  100. void GetMinMaxElements(AZ::MatrixMxN& source, float& min, float& max)
  101. {
  102. AZStd::vector<AZ::Matrix4x4>& elements = source.GetMatrixElements();
  103. if (!elements.empty())
  104. {
  105. AZ::Vector4 minimum = elements[0].GetRow(0);
  106. AZ::Vector4 maximum = elements[0].GetRow(0);
  107. for (AZStd::size_t i = 1; i < elements.size(); ++i)
  108. {
  109. for (int32_t j = 0; j < 4; ++j)
  110. {
  111. minimum = minimum.GetMin(elements[i].GetRow(j));
  112. maximum = maximum.GetMax(elements[i].GetRow(j));
  113. }
  114. }
  115. min = AZ::GetMin(AZ::GetMin(minimum.GetX(), minimum.GetY()), AZ::GetMin(minimum.GetZ(), minimum.GetW()));
  116. max = AZ::GetMax(AZ::GetMax(minimum.GetX(), minimum.GetY()), AZ::GetMax(minimum.GetZ(), minimum.GetW()));
  117. }
  118. }
//! Registers the Layer type with O3DE's serialization, edit, and behavior contexts.
//! @param context the reflection context being populated; may be a serialize or behavior context
void Layer::Reflect(AZ::ReflectContext* context)
{
    if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
    {
        // Persisted fields; bump Version if the field set or layout changes
        serializeContext->Class<Layer>()
            ->Version(1)
            ->Field("InputSize", &Layer::m_inputSize)
            ->Field("OutputSize", &Layer::m_outputSize)
            ->Field("Weights", &Layer::m_weights)
            ->Field("Biases", &Layer::m_biases)
            ->Field("ActivationFunction", &Layer::m_activationFunction)
            ;
        if (AZ::EditContext* editContext = serializeContext->GetEditContext())
        {
            // Editor UI: only the layer size and activation function are user-editable;
            // weights/biases are regenerated via OnSizesChanged when the size changes
            editContext->Class<Layer>("A single layer of a neural network", "")
                ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
                ->DataElement(AZ::Edit::UIHandlers::Default, &Layer::m_outputSize, "Layer Size", "The number of neurons the layer should have")
                    ->Attribute(AZ::Edit::Attributes::ChangeNotify, &Layer::OnSizesChanged)
                ->DataElement(AZ::Edit::UIHandlers::ComboBox, &Layer::m_activationFunction, "Activation Function", "The activation function applied to this layer")
                    ->Attribute(AZ::Edit::Attributes::EnumValues, &GetActivationEnumValues)
                ;
        }
    }
    auto behaviorContext = azrtti_cast<AZ::BehaviorContext*>(context);
    if (behaviorContext)
    {
        // Script-facing registration (e.g. Lua / Script Canvas) under the "machineLearning" module
        behaviorContext->Class<Layer>()->
            Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Common)->
            Attribute(AZ::Script::Attributes::Module, "machineLearning")->
            Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)->
            Constructor<ActivationFunctions, AZStd::size_t, AZStd::size_t>()->
            Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)->
            Property("InputSize", BehaviorValueProperty(&Layer::m_inputSize))->
            Property("OutputSize", BehaviorValueProperty(&Layer::m_outputSize))->
            Property("ActivationFunction", BehaviorValueProperty(&Layer::m_activationFunction))
            ;
    }
}
//! Constructs a layer and initializes its weights and biases for the given dimensions.
//! @param activationFunction the activation applied to this layer's output
//! @param activationDimensionality the size of the input (previous layer's output)
//! @param layerDimensionality the number of neurons in this layer
Layer::Layer(ActivationFunctions activationFunction, AZStd::size_t activationDimensionality, AZStd::size_t layerDimensionality)
    : m_activationFunction(activationFunction)
    , m_inputSize(activationDimensionality)
    , m_outputSize(layerDimensionality)
{
    // Allocates and randomly initializes the weight matrix and bias vector for the sizes above
    OnSizesChanged();
}
//! Runs a forward pass through this layer: activation(weights * activations + biases).
//! @param inferenceData scratch storage for this inference; m_output receives the result
//! @param activations the previous layer's output (or the network input)
//! @return reference to inferenceData.m_output holding this layer's activated output
const AZ::VectorN& Layer::Forward(LayerInferenceData& inferenceData, const AZ::VectorN& activations)
{
    // Seed the output with the bias terms before the matrix multiply
    // NOTE(review): this assumes AZ::VectorMatrixMultiply accumulates into its output
    // (output += weights * activations) rather than overwriting it — confirm against AZ::MatrixMxN
    inferenceData.m_output = m_biases;
    AZ::VectorMatrixMultiply(m_weights, activations, inferenceData.m_output);
    // Apply the activation function in place
    Activate(m_activationFunction, inferenceData.m_output, inferenceData.m_output);
    return inferenceData.m_output;
}
//! Back-propagation step for this layer: folds one sample's gradients into the running
//! averages for the weights and biases, and computes the gradients to pass upstream.
//! @param samples the number of samples accumulated so far, used as the averaging divisor
//! @param trainingData per-layer training scratch (gradient accumulators, cached last input)
//! @param inferenceData per-layer inference scratch holding this layer's forward output
//! @param previousLayerGradients gradients flowing back from the next layer (or the loss)
void Layer::AccumulateGradients(AZStd::size_t samples, LayerTrainingData& trainingData, LayerInferenceData& inferenceData, const AZ::VectorN& previousLayerGradients)
{
    // Ensure our bias gradient vector is appropriately sized
    if (trainingData.m_biasGradients.GetDimensionality() != m_outputSize)
    {
        trainingData.m_biasGradients = AZ::VectorN::CreateZero(m_outputSize);
    }
    // Ensure our weight gradient matrix is appropriately sized
    if ((trainingData.m_weightGradients.GetRowCount() != m_outputSize) || (trainingData.m_weightGradients.GetColumnCount() != m_inputSize))
    {
        trainingData.m_weightGradients = AZ::MatrixMxN::CreateZero(m_outputSize, m_inputSize);
    }
    // Ensure our backpropagation gradient vector is appropriately sized
    if (trainingData.m_backpropagationGradients.GetDimensionality() != m_inputSize)
    {
        trainingData.m_backpropagationGradients = AZ::VectorN::CreateZero(m_inputSize);
    }
    // Compute the partial derivatives of the output with respect to the activation function
    Activate_Derivative(m_activationFunction, inferenceData.m_output, previousLayerGradients, trainingData.m_activationGradients);
    // Accumulate the partial derivatives of the weight matrix with respect to the loss function
    // (m_lastInput is a pointer cached during the forward pass; assumed valid here)
    AccumulateWeightGradients(trainingData.m_activationGradients, *trainingData.m_lastInput, trainingData.m_weightGradients, samples);
    // Accumulate the partial derivatives of the bias vector with respect to the loss function
    AccumulateBiasGradients(trainingData.m_biasGradients, trainingData.m_activationGradients, samples);
    // Accumulate the gradients to pass to the preceding layer for back-propagation
    AZ::VectorMatrixMultiplyLeft(trainingData.m_activationGradients, m_weights, trainingData.m_backpropagationGradients);
    // Optional diagnostics, gated by console variables so training isn't slowed by default
    if (ml_logGradients)
    {
        float min = 0.f;
        float max = 0.f;
        GetMinMaxElements(trainingData.m_weightGradients, min, max);
        AZLOG_INFO("Weight gradients: min value %f, max value %f", min, max);
        GetMinMaxElements(trainingData.m_biasGradients, min, max);
        AZLOG_INFO("Bias gradients: min value %f, max value %f", min, max);
        GetMinMaxElements(trainingData.m_backpropagationGradients, min, max);
        AZLOG_INFO("Back-propagation gradients: min value %f, max value %f", min, max);
    }
    if (ml_logGradientsVerbose)
    {
        DumpMatrixGradients(trainingData.m_weightGradients, "WeightGradients");
        DumpVectorGradients(trainingData.m_biasGradients, "BiasGradients");
    }
}
  213. void Layer::ApplyGradients(LayerTrainingData& trainingData, float learningRate)
  214. {
  215. m_weights -= trainingData.m_weightGradients * learningRate;
  216. m_biases -= trainingData.m_biasGradients * learningRate;
  217. trainingData.m_biasGradients.SetZero();
  218. trainingData.m_weightGradients.SetZero();
  219. trainingData.m_backpropagationGradients.SetZero();
  220. }
  221. bool Layer::Serialize(AzNetworking::ISerializer& serializer)
  222. {
  223. return serializer.Serialize(m_inputSize, "inputSize")
  224. && serializer.Serialize(m_outputSize, "outputSize")
  225. && serializer.Serialize(m_weights, "weights")
  226. && serializer.Serialize(m_biases, "biases")
  227. && serializer.Serialize(m_activationFunction, "activationFunction");
  228. }
  229. AZStd::size_t Layer::EstimateSerializeSize() const
  230. {
  231. const AZStd::size_t padding = 64; // 64 bytes of extra padding just in case
  232. return padding
  233. + sizeof(m_inputSize)
  234. + sizeof(m_outputSize)
  235. + sizeof(AZStd::size_t) // for m_weights row count
  236. + sizeof(AZStd::size_t) // for m_weights column count
  237. + sizeof(AZStd::size_t) // for m_weights vector size
  238. + sizeof(float) * m_outputSize * m_inputSize // m_weights buffer
  239. + sizeof(AZStd::size_t) // for m_biases dimensionality
  240. + sizeof(AZStd::size_t) // for m_biases vector size
  241. + sizeof(float) * m_outputSize // m_biases buffer
  242. + sizeof(m_activationFunction);
  243. }
  244. void Layer::OnSizesChanged()
  245. {
  246. // Specifically for ReLU, we use Kaiming He initialization as this is optimal for convergence
  247. // For other activation functions we just use a standard normal distribution
  248. float standardDeviation = (m_activationFunction == ActivationFunctions::ReLU) ? 2.0f / m_inputSize
  249. : 1.0f / m_inputSize;
  250. std::random_device rd{};
  251. std::mt19937 gen{ rd() };
  252. auto dist = std::normal_distribution<float>{ 0.0f, standardDeviation };
  253. m_weights.Resize(m_outputSize, m_inputSize);
  254. for (AZStd::size_t row = 0; row < m_weights.GetRowCount(); ++row)
  255. {
  256. for (AZStd::size_t col = 0; col < m_weights.GetRowCount(); ++col)
  257. {
  258. m_weights.SetElement(row, col, dist(gen));
  259. }
  260. }
  261. m_biases = AZ::VectorN(m_outputSize, 0.01f);
  262. }
  263. }