MultilayerPerceptron.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Models/MultilayerPerceptron.h>
  9. #include <Algorithms/LossFunctions.h>
  10. #include <AzCore/RTTI/RTTI.h>
  11. #include <AzCore/RTTI/BehaviorContext.h>
  12. #include <AzCore/Serialization/EditContext.h>
  13. #include <AzCore/Serialization/SerializeContext.h>
  14. #include <AzCore/IO/FileIO.h>
  15. #include <AzCore/IO/FileReader.h>
  16. #include <AzCore/IO/Path/Path.h>
  17. #include <AzCore/Console/ILogger.h>
  18. #include <AzNetworking/Serialization/NetworkInputSerializer.h>
  19. #include <AzNetworking/Serialization/NetworkOutputSerializer.h>
  20. namespace MachineLearning
  21. {
  22. void MultilayerPerceptron::Reflect(AZ::ReflectContext* context)
  23. {
  24. if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  25. {
  26. serializeContext->Class<MultilayerPerceptron>()
  27. ->Version(1)
  28. ->Field("Name", &MultilayerPerceptron::m_name)
  29. ->Field("TestDataFile", &MultilayerPerceptron::m_testDataFile)
  30. ->Field("TestLabelFile", &MultilayerPerceptron::m_testLabelFile)
  31. ->Field("TrainDataFile", &MultilayerPerceptron::m_trainDataFile)
  32. ->Field("TrainLabelFile", &MultilayerPerceptron::m_trainLabelFile)
  33. ->Field("ActivationCount", &MultilayerPerceptron::m_activationCount)
  34. ->Field("Layers", &MultilayerPerceptron::m_layers)
  35. ;
  36. if (AZ::EditContext* editContext = serializeContext->GetEditContext())
  37. {
  38. editContext->Class<MultilayerPerceptron>("A basic multilayer perceptron class", "")
  39. ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
  40. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_name, "Name", "The name for this model")
  41. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_testDataFile, "TestDataFile", "The file test data should be loaded from")
  42. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_testLabelFile, "TestLabelFile", "The file test labels should be loaded from")
  43. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_trainDataFile, "TrainDataFile", "The file training data should be loaded from")
  44. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_trainLabelFile, "TrainLabelFile", "The file training labels should be loaded from")
  45. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_activationCount, "Activation Count", "The number of neurons in the activation layer")
  46. ->Attribute(AZ::Edit::Attributes::ChangeNotify, &MultilayerPerceptron::OnActivationCountChanged)
  47. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_layers, "Layers", "The layers of the neural network")
  48. ->Attribute(AZ::Edit::Attributes::ChangeNotify, &MultilayerPerceptron::OnActivationCountChanged)
  49. ;
  50. }
  51. }
  52. auto behaviorContext = azrtti_cast<AZ::BehaviorContext*>(context);
  53. if (behaviorContext)
  54. {
  55. behaviorContext->Class<MultilayerPerceptron>("Multilayer perceptron")->
  56. Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Common)->
  57. Attribute(AZ::Script::Attributes::Module, "machineLearning")->
  58. Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)->
  59. Constructor<AZStd::size_t>()->
  60. Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)->
  61. Method("GetName", &MultilayerPerceptron::GetName)->
  62. Method("GetLayerCount", &MultilayerPerceptron::GetLayerCount)->
  63. Property("ActivationCount", BehaviorValueProperty(&MultilayerPerceptron::m_activationCount))->
  64. Property("Layers", BehaviorValueProperty(&MultilayerPerceptron::m_layers))
  65. ;
  66. }
  67. }
  68. MultilayerPerceptron::MultilayerPerceptron()
  69. {
  70. }
  71. MultilayerPerceptron::MultilayerPerceptron(const MultilayerPerceptron& rhs)
  72. : m_name(rhs.m_name)
  73. , m_testDataFile(rhs.m_testDataFile)
  74. , m_testLabelFile(rhs.m_testLabelFile)
  75. , m_trainDataFile(rhs.m_trainDataFile)
  76. , m_trainLabelFile(rhs.m_trainLabelFile)
  77. , m_activationCount(rhs.m_activationCount)
  78. , m_layers(rhs.m_layers)
  79. {
  80. }
  81. MultilayerPerceptron::MultilayerPerceptron(AZStd::size_t activationCount)
  82. : m_activationCount(activationCount)
  83. {
  84. }
  85. MultilayerPerceptron::~MultilayerPerceptron()
  86. {
  87. }
  88. MultilayerPerceptron& MultilayerPerceptron::operator=(const MultilayerPerceptron& rhs)
  89. {
  90. m_name = rhs.m_name;
  91. m_testDataFile = rhs.m_testDataFile;
  92. m_testLabelFile = rhs.m_testLabelFile;
  93. m_trainDataFile = rhs.m_trainDataFile;
  94. m_trainLabelFile = rhs.m_trainLabelFile;
  95. m_activationCount = rhs.m_activationCount;
  96. m_layers = rhs.m_layers;
  97. OnActivationCountChanged();
  98. return *this;
  99. }
  100. MultilayerPerceptron& MultilayerPerceptron::operator=(const ModelAsset& asset)
  101. {
  102. m_name = asset.m_name;
  103. m_activationCount = asset.m_activationCount;
  104. m_layers = asset.m_layers;
  105. OnActivationCountChanged();
  106. return *this;
  107. }
  108. AZStd::string MultilayerPerceptron::GetName() const
  109. {
  110. return m_name;
  111. }
  112. AZStd::string MultilayerPerceptron::GetAssetFile(AssetTypes assetType) const
  113. {
  114. switch (assetType)
  115. {
  116. case AssetTypes::TestData:
  117. return m_testDataFile;
  118. case AssetTypes::TestLabels:
  119. return m_testLabelFile;
  120. case AssetTypes::TrainingData:
  121. return m_trainDataFile;
  122. case AssetTypes::TrainingLabels:
  123. return m_trainLabelFile;
  124. }
  125. return "";
  126. }
  127. AZStd::size_t MultilayerPerceptron::GetInputDimensionality() const
  128. {
  129. return m_activationCount;
  130. }
  131. AZStd::size_t MultilayerPerceptron::GetOutputDimensionality() const
  132. {
  133. if (!m_layers.empty())
  134. {
  135. return m_layers.back().m_biases.GetDimensionality();
  136. }
  137. return m_activationCount;
  138. }
  139. AZStd::size_t MultilayerPerceptron::GetLayerCount() const
  140. {
  141. return m_layers.size();
  142. }
  143. AZ::MatrixMxN MultilayerPerceptron::GetLayerWeights(AZStd::size_t layerIndex) const
  144. {
  145. return m_layers[layerIndex].m_weights;
  146. }
  147. AZ::VectorN MultilayerPerceptron::GetLayerBiases(AZStd::size_t layerIndex) const
  148. {
  149. return m_layers[layerIndex].m_biases;
  150. }
  151. AZStd::size_t MultilayerPerceptron::GetParameterCount() const
  152. {
  153. AZStd::size_t parameterCount = 0;
  154. for (const Layer& layer : m_layers)
  155. {
  156. parameterCount += layer.m_inputSize * layer.m_outputSize + layer.m_outputSize;
  157. }
  158. return parameterCount;
  159. }
  160. IInferenceContextPtr MultilayerPerceptron::CreateInferenceContext()
  161. {
  162. return new MlpInferenceContext();
  163. }
  164. ITrainingContextPtr MultilayerPerceptron::CreateTrainingContext()
  165. {
  166. return new MlpTrainingContext();
  167. }
  168. const AZ::VectorN* MultilayerPerceptron::Forward(IInferenceContextPtr context, const AZ::VectorN& activations)
  169. {
  170. MlpInferenceContext* forwardContext = static_cast<MlpInferenceContext*>(context);
  171. forwardContext->m_layerData.resize(m_layers.size());
  172. const AZ::VectorN* lastLayerOutput = &activations;
  173. for (AZStd::size_t iter = 0; iter < m_layers.size(); ++iter)
  174. {
  175. m_layers[iter].Forward(forwardContext->m_layerData[iter], *lastLayerOutput);
  176. lastLayerOutput = &forwardContext->m_layerData[iter].m_output;
  177. }
  178. return lastLayerOutput;
  179. }
  180. void MultilayerPerceptron::Reverse(ITrainingContextPtr context, LossFunctions lossFunction, const AZ::VectorN& activations, const AZ::VectorN& expected)
  181. {
  182. MlpTrainingContext* reverseContext = static_cast<MlpTrainingContext*>(context);
  183. MlpInferenceContext* forwardContext = &reverseContext->m_forward;
  184. reverseContext->m_layerData.resize(m_layers.size());
  185. forwardContext->m_layerData.resize(m_layers.size());
  186. ++reverseContext->m_trainingSampleSize;
  187. // First feed-forward the activations to get our current model predictions
  188. // We do additional book-keeping over a standard forward pass to make gradient calculations easier
  189. const AZ::VectorN* lastLayerOutput = &activations;
  190. for (AZStd::size_t iter = 0; iter < m_layers.size(); ++iter)
  191. {
  192. reverseContext->m_layerData[iter].m_lastInput = lastLayerOutput;
  193. m_layers[iter].Forward(forwardContext->m_layerData[iter], *lastLayerOutput);
  194. lastLayerOutput = &forwardContext->m_layerData[iter].m_output;
  195. }
  196. // Compute the partial derivatives of the loss function with respect to the final layer output
  197. AZ::VectorN costGradients;
  198. ComputeLoss_Derivative(lossFunction, *lastLayerOutput, expected, costGradients);
  199. AZ::VectorN* lossGradient = &costGradients;
  200. for (int64_t iter = static_cast<int64_t>(m_layers.size()) - 1; iter >= 0; --iter)
  201. {
  202. m_layers[iter].AccumulateGradients(reverseContext->m_trainingSampleSize, reverseContext->m_layerData[iter], forwardContext->m_layerData[iter], *lossGradient);
  203. lossGradient = &reverseContext->m_layerData[iter].m_backpropagationGradients;
  204. }
  205. }
  206. void MultilayerPerceptron::GradientDescent(ITrainingContextPtr context, float learningRate)
  207. {
  208. MlpTrainingContext* reverseContext = static_cast<MlpTrainingContext*>(context);
  209. if (reverseContext->m_trainingSampleSize > 0)
  210. {
  211. for (AZStd::size_t iter = 0; iter < m_layers.size(); ++iter)
  212. {
  213. m_layers[iter].ApplyGradients(reverseContext->m_layerData[iter], learningRate);
  214. }
  215. }
  216. reverseContext->m_trainingSampleSize = 0;
  217. }
  218. void MultilayerPerceptron::OnActivationCountChanged()
  219. {
  220. AZStd::size_t lastLayerDimensionality = m_activationCount;
  221. for (Layer& layer : m_layers)
  222. {
  223. layer.m_inputSize = lastLayerDimensionality;
  224. layer.OnSizesChanged();
  225. lastLayerDimensionality = layer.m_outputSize;
  226. }
  227. }
  228. bool MultilayerPerceptron::LoadModel()
  229. {
  230. if (m_proxy)
  231. {
  232. return m_proxy->LoadAsset();
  233. }
  234. return false;
  235. }
  236. bool MultilayerPerceptron::SaveModel()
  237. {
  238. if (m_proxy)
  239. {
  240. return m_proxy->SaveAsset();
  241. }
  242. return false;
  243. }
  244. void MultilayerPerceptron::AddLayer(AZStd::size_t layerDimensionality, ActivationFunctions activationFunction)
  245. {
  246. // This is not thread safe, this should only be used during model configuration
  247. const AZStd::size_t lastLayerDimensionality = GetOutputDimensionality();
  248. m_layers.push_back(AZStd::move(Layer(activationFunction, lastLayerDimensionality, layerDimensionality)));
  249. }
  250. Layer* MultilayerPerceptron::GetLayer(AZStd::size_t layerIndex)
  251. {
  252. // This is not thread safe, this method should only be used by unit testing to inspect layer weights and biases for correctness
  253. return &m_layers[layerIndex];
  254. }
  255. }