MultilayerPerceptron.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Models/MultilayerPerceptron.h>
  9. #include <Algorithms/LossFunctions.h>
  10. #include <AzCore/RTTI/RTTI.h>
  11. #include <AzCore/RTTI/BehaviorContext.h>
  12. #include <AzCore/Serialization/EditContext.h>
  13. #include <AzCore/Serialization/SerializeContext.h>
  14. #include <AzCore/IO/FileIO.h>
  15. #include <AzCore/IO/FileReader.h>
  16. #include <AzCore/IO/Path/Path.h>
  17. #include <AzCore/Console/ILogger.h>
  18. #include <AzNetworking/Serialization/NetworkInputSerializer.h>
  19. #include <AzNetworking/Serialization/NetworkOutputSerializer.h>
  20. namespace MachineLearning
  21. {
  22. void MultilayerPerceptron::Reflect(AZ::ReflectContext* context)
  23. {
  24. if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  25. {
  26. serializeContext->Class<MultilayerPerceptron>()
  27. ->Version(1)
  28. ->Field("Name", &MultilayerPerceptron::m_name)
  29. ->Field("ModelFile", &MultilayerPerceptron::m_modelFile)
  30. ->Field("TestDataFile", &MultilayerPerceptron::m_testDataFile)
  31. ->Field("TestLabelFile", &MultilayerPerceptron::m_testLabelFile)
  32. ->Field("TrainDataFile", &MultilayerPerceptron::m_trainDataFile)
  33. ->Field("TrainLabelFile", &MultilayerPerceptron::m_trainLabelFile)
  34. ->Field("ActivationCount", &MultilayerPerceptron::m_activationCount)
  35. ->Field("Layers", &MultilayerPerceptron::m_layers)
  36. ;
  37. if (AZ::EditContext* editContext = serializeContext->GetEditContext())
  38. {
  39. editContext->Class<MultilayerPerceptron>("A basic multilayer perceptron class", "")
  40. ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
  41. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_name, "Name", "The name for this model")
  42. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_modelFile, "ModelFile", "The file this model is saved to and loaded from")
  43. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_testDataFile, "TestDataFile", "The file test data should be loaded from")
  44. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_testLabelFile, "TestLabelFile", "The file test labels should be loaded from")
  45. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_trainDataFile, "TrainDataFile", "The file training data should be loaded from")
  46. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_trainLabelFile, "TrainLabelFile", "The file training labels should be loaded from")
  47. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_activationCount, "Activation Count", "The number of neurons in the activation layer")
  48. ->Attribute(AZ::Edit::Attributes::ChangeNotify, &MultilayerPerceptron::OnActivationCountChanged)
  49. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_layers, "Layers", "The layers of the neural network")
  50. ->Attribute(AZ::Edit::Attributes::ChangeNotify, &MultilayerPerceptron::OnActivationCountChanged)
  51. ;
  52. }
  53. }
  54. auto behaviorContext = azrtti_cast<AZ::BehaviorContext*>(context);
  55. if (behaviorContext)
  56. {
  57. behaviorContext->Class<MultilayerPerceptron>("Multilayer perceptron")->
  58. Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Common)->
  59. Attribute(AZ::Script::Attributes::Module, "machineLearning")->
  60. Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)->
  61. Constructor<AZStd::size_t>()->
  62. Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)->
  63. Method("GetName", &MultilayerPerceptron::GetName)->
  64. Method("GetLayerCount", &MultilayerPerceptron::GetLayerCount)->
  65. Property("ActivationCount", BehaviorValueProperty(&MultilayerPerceptron::m_activationCount))->
  66. Property("Layers", BehaviorValueProperty(&MultilayerPerceptron::m_layers))
  67. ;
  68. }
  69. }
  70. MultilayerPerceptron::MultilayerPerceptron()
  71. {
  72. }
  73. MultilayerPerceptron::MultilayerPerceptron(const MultilayerPerceptron& rhs)
  74. : m_name(rhs.m_name)
  75. , m_modelFile(rhs.m_modelFile)
  76. , m_testDataFile(rhs.m_testDataFile)
  77. , m_testLabelFile(rhs.m_testLabelFile)
  78. , m_trainDataFile(rhs.m_trainDataFile)
  79. , m_trainLabelFile(rhs.m_trainLabelFile)
  80. , m_activationCount(rhs.m_activationCount)
  81. , m_layers(rhs.m_layers)
  82. {
  83. }
  84. MultilayerPerceptron::MultilayerPerceptron(AZStd::size_t activationCount)
  85. : m_activationCount(activationCount)
  86. {
  87. }
  88. MultilayerPerceptron::~MultilayerPerceptron()
  89. {
  90. }
  91. MultilayerPerceptron& MultilayerPerceptron::operator=(const MultilayerPerceptron& rhs)
  92. {
  93. m_name = rhs.m_name;
  94. m_modelFile = rhs.m_modelFile;
  95. m_testDataFile = rhs.m_testDataFile;
  96. m_testLabelFile = rhs.m_testLabelFile;
  97. m_trainDataFile = rhs.m_trainDataFile;
  98. m_trainLabelFile = rhs.m_trainLabelFile;
  99. m_activationCount = rhs.m_activationCount;
  100. m_layers = rhs.m_layers;
  101. return *this;
  102. }
  103. AZStd::string MultilayerPerceptron::GetName() const
  104. {
  105. return m_name;
  106. }
  107. AZStd::string MultilayerPerceptron::GetAssetFile(AssetTypes assetType) const
  108. {
  109. switch (assetType)
  110. {
  111. case AssetTypes::Model:
  112. return m_modelFile;
  113. case AssetTypes::TestData:
  114. return m_testDataFile;
  115. case AssetTypes::TestLabels:
  116. return m_testLabelFile;
  117. case AssetTypes::TrainingData:
  118. return m_trainDataFile;
  119. case AssetTypes::TrainingLabels:
  120. return m_trainLabelFile;
  121. }
  122. return "";
  123. }
  124. AZStd::size_t MultilayerPerceptron::GetInputDimensionality() const
  125. {
  126. return m_activationCount;
  127. }
  128. AZStd::size_t MultilayerPerceptron::GetOutputDimensionality() const
  129. {
  130. //AZStd::lock_guard lock(m_mutex);
  131. if (!m_layers.empty())
  132. {
  133. return m_layers.back().m_biases.GetDimensionality();
  134. }
  135. return m_activationCount;
  136. }
  137. AZStd::size_t MultilayerPerceptron::GetLayerCount() const
  138. {
  139. //AZStd::lock_guard lock(m_mutex);
  140. return m_layers.size();
  141. }
  142. AZStd::size_t MultilayerPerceptron::GetParameterCount() const
  143. {
  144. //AZStd::lock_guard lock(m_mutex);
  145. AZStd::size_t parameterCount = 0;
  146. for (const Layer& layer : m_layers)
  147. {
  148. parameterCount += layer.m_inputSize * layer.m_outputSize + layer.m_outputSize;
  149. }
  150. return parameterCount;
  151. }
  152. IInferenceContextPtr MultilayerPerceptron::CreateInferenceContext()
  153. {
  154. return new MlpInferenceContext();
  155. }
  156. ITrainingContextPtr MultilayerPerceptron::CreateTrainingContext()
  157. {
  158. return new MlpTrainingContext();
  159. }
  160. const AZ::VectorN* MultilayerPerceptron::Forward(IInferenceContextPtr context, const AZ::VectorN& activations)
  161. {
  162. //AZStd::lock_guard lock(m_mutex);
  163. MlpInferenceContext* forwardContext = static_cast<MlpInferenceContext*>(context);
  164. forwardContext->m_layerData.resize(m_layers.size());
  165. const AZ::VectorN* lastLayerOutput = &activations;
  166. for (AZStd::size_t iter = 0; iter < m_layers.size(); ++iter)
  167. {
  168. m_layers[iter].Forward(forwardContext->m_layerData[iter], *lastLayerOutput);
  169. lastLayerOutput = &forwardContext->m_layerData[iter].m_output;
  170. }
  171. return lastLayerOutput;
  172. }
  173. void MultilayerPerceptron::Reverse(ITrainingContextPtr context, LossFunctions lossFunction, const AZ::VectorN& activations, const AZ::VectorN& expected)
  174. {
  175. //AZStd::lock_guard lock(m_mutex);
  176. MlpTrainingContext* reverseContext = static_cast<MlpTrainingContext*>(context);
  177. MlpInferenceContext* forwardContext = &reverseContext->m_forward;
  178. reverseContext->m_layerData.resize(m_layers.size());
  179. forwardContext->m_layerData.resize(m_layers.size());
  180. ++reverseContext->m_trainingSampleSize;
  181. // First feed-forward the activations to get our current model predictions
  182. // We do additional book-keeping over a standard forward pass to make gradient calculations easier
  183. const AZ::VectorN* lastLayerOutput = &activations;
  184. for (AZStd::size_t iter = 0; iter < m_layers.size(); ++iter)
  185. {
  186. reverseContext->m_layerData[iter].m_lastInput = lastLayerOutput;
  187. m_layers[iter].Forward(forwardContext->m_layerData[iter], *lastLayerOutput);
  188. lastLayerOutput = &forwardContext->m_layerData[iter].m_output;
  189. }
  190. // Compute the partial derivatives of the loss function with respect to the final layer output
  191. AZ::VectorN costGradients;
  192. ComputeLoss_Derivative(lossFunction, *lastLayerOutput, expected, costGradients);
  193. AZ::VectorN* lossGradient = &costGradients;
  194. for (int64_t iter = static_cast<int64_t>(m_layers.size()) - 1; iter >= 0; --iter)
  195. {
  196. m_layers[iter].AccumulateGradients(reverseContext->m_layerData[iter], forwardContext->m_layerData[iter], *lossGradient);
  197. lossGradient = &reverseContext->m_layerData[iter].m_backpropagationGradients;
  198. }
  199. }
  200. void MultilayerPerceptron::GradientDescent(ITrainingContextPtr context, float learningRate)
  201. {
  202. //AZStd::lock_guard lock(m_mutex);
  203. MlpTrainingContext* reverseContext = static_cast<MlpTrainingContext*>(context);
  204. if (reverseContext->m_trainingSampleSize > 0)
  205. {
  206. const float adjustedLearningRate = learningRate / static_cast<float>(reverseContext->m_trainingSampleSize);
  207. for (AZStd::size_t iter = 0; iter < m_layers.size(); ++iter)
  208. {
  209. m_layers[iter].ApplyGradients(reverseContext->m_layerData[iter], adjustedLearningRate);
  210. }
  211. }
  212. reverseContext->m_trainingSampleSize = 0;
  213. }
  214. void MultilayerPerceptron::OnActivationCountChanged()
  215. {
  216. //AZStd::lock_guard lock(m_mutex);
  217. AZStd::size_t lastLayerDimensionality = m_activationCount;
  218. for (Layer& layer : m_layers)
  219. {
  220. layer.m_inputSize = lastLayerDimensionality;
  221. layer.OnSizesChanged();
  222. lastLayerDimensionality = layer.m_outputSize;
  223. }
  224. }
  225. bool MultilayerPerceptron::LoadModel()
  226. {
  227. AZ::IO::SystemFile modelFile;
  228. AZ::IO::FixedMaxPath filePathFixed = m_modelFile.c_str();
  229. if (AZ::IO::FileIOBase* fileIOBase = AZ::IO::FileIOBase::GetInstance())
  230. {
  231. fileIOBase->ResolvePath(filePathFixed, m_modelFile.c_str());
  232. }
  233. if (!modelFile.Open(filePathFixed.c_str(), AZ::IO::SystemFile::SF_OPEN_READ_ONLY))
  234. {
  235. AZLOG_ERROR("Failed to load '%s'. File could not be opened.", filePathFixed.c_str());
  236. return false;
  237. }
  238. const AZ::IO::SizeType length = modelFile.Length();
  239. if (length == 0)
  240. {
  241. AZLOG_ERROR("Failed to load '%s'. File is empty.", filePathFixed.c_str());
  242. return false;
  243. }
  244. AZStd::vector<uint8_t> serializeBuffer;
  245. serializeBuffer.resize(length);
  246. modelFile.Seek(0, AZ::IO::SystemFile::SF_SEEK_BEGIN);
  247. modelFile.Read(serializeBuffer.size(), serializeBuffer.data());
  248. AzNetworking::NetworkOutputSerializer serializer(serializeBuffer.data(), static_cast<uint32_t>(serializeBuffer.size()));
  249. return Serialize(serializer);
  250. }
  251. bool MultilayerPerceptron::SaveModel()
  252. {
  253. AZ::IO::SystemFile modelFile;
  254. AZ::IO::FixedMaxPath filePathFixed = m_modelFile.c_str();
  255. if (AZ::IO::FileIOBase* fileIOBase = AZ::IO::FileIOBase::GetInstance())
  256. {
  257. fileIOBase->ResolvePath(filePathFixed, m_modelFile.c_str());
  258. }
  259. if (!modelFile.Open(filePathFixed.c_str(), AZ::IO::SystemFile::SF_OPEN_READ_WRITE | AZ::IO::SystemFile::SF_OPEN_CREATE))
  260. {
  261. AZLOG_ERROR("Failed to save to '%s'. File could not be opened for writing.", filePathFixed.c_str());
  262. return false;
  263. }
  264. modelFile.Seek(0, AZ::IO::SystemFile::SF_SEEK_BEGIN);
  265. AZStd::vector<uint8_t> serializeBuffer;
  266. serializeBuffer.resize(EstimateSerializeSize());
  267. AzNetworking::NetworkInputSerializer serializer(serializeBuffer.data(), static_cast<uint32_t>(serializeBuffer.size()));
  268. if (Serialize(serializer))
  269. {
  270. modelFile.Write(serializeBuffer.data(), serializeBuffer.size());
  271. return true;
  272. }
  273. return false;
  274. }
  275. void MultilayerPerceptron::AddLayer(AZStd::size_t layerDimensionality, ActivationFunctions activationFunction)
  276. {
  277. // This is not thread safe, this should only be used during model configuration
  278. const AZStd::size_t lastLayerDimensionality = GetOutputDimensionality();
  279. m_layers.push_back(AZStd::move(Layer(activationFunction, lastLayerDimensionality, layerDimensionality)));
  280. }
  281. Layer* MultilayerPerceptron::GetLayer(AZStd::size_t layerIndex)
  282. {
  283. // This is not thread safe, this method should only be used by unit testing to inspect layer weights and biases for correctness
  284. return &m_layers[layerIndex];
  285. }
  286. bool MultilayerPerceptron::Serialize(AzNetworking::ISerializer& serializer)
  287. {
  288. //AZStd::lock_guard lock(m_mutex);
  289. return serializer.Serialize(m_name, "Name")
  290. && serializer.Serialize(m_activationCount, "activationCount")
  291. && serializer.Serialize(m_layers, "layers");
  292. }
  293. AZStd::size_t MultilayerPerceptron::EstimateSerializeSize() const
  294. {
  295. const AZStd::size_t padding = 64; // 64 bytes of extra padding just in case
  296. AZStd::size_t estimatedSize = padding
  297. + sizeof(AZStd::size_t)
  298. + m_name.size()
  299. + sizeof(m_activationCount)
  300. + sizeof(AZStd::size_t);
  301. //AZStd::lock_guard lock(m_mutex);
  302. for (const Layer& layer : m_layers)
  303. {
  304. estimatedSize += layer.EstimateSerializeSize();
  305. }
  306. return estimatedSize;
  307. }
  308. }