Pārlūkot izejas kodu

Initial attempt at multithreaded training, still requires cleanup, plus numerous other changes and fixes. ImGui driven backprop is not converging correctly, requires more debugging.

Signed-off-by: kberg-amzn <[email protected]>
kberg-amzn 2 gadi atpakaļ
vecāks
revīzija
2ce44c0843
50 mainīti faili ar 1605 papildinājumiem un 320 dzēšanām
  1. 25 4
      Gems/MachineLearning/Code/CMakeLists.txt
  2. 23 0
      Gems/MachineLearning/Code/Include/MachineLearning/IInferenceContext.h
  3. 4 0
      Gems/MachineLearning/Code/Include/MachineLearning/ILabeledTrainingData.h
  4. 51 0
      Gems/MachineLearning/Code/Include/MachineLearning/IMachineLearning.h
  5. 29 8
      Gems/MachineLearning/Code/Include/MachineLearning/INeuralNetwork.h
  6. 22 0
      Gems/MachineLearning/Code/Include/MachineLearning/ITrainingContext.h
  7. 0 41
      Gems/MachineLearning/Code/Include/MachineLearning/MachineLearningBus.h
  8. 1 1
      Gems/MachineLearning/Code/Include/MachineLearning/MachineLearningTypeIds.h
  9. 8 0
      Gems/MachineLearning/Code/Include/MachineLearning/Types.h
  10. 20 2
      Gems/MachineLearning/Code/Source/Algorithms/Activations.cpp
  11. 4 0
      Gems/MachineLearning/Code/Source/Algorithms/Activations.h
  12. 103 38
      Gems/MachineLearning/Code/Source/Algorithms/Training.cpp
  13. 57 15
      Gems/MachineLearning/Code/Source/Algorithms/Training.h
  14. 7 4
      Gems/MachineLearning/Code/Source/Assets/MnistDataLoader.cpp
  15. 1 0
      Gems/MachineLearning/Code/Source/Assets/MnistDataLoader.h
  16. 12 1
      Gems/MachineLearning/Code/Source/Components/MultilayerPerceptronComponent.cpp
  17. 3 0
      Gems/MachineLearning/Code/Source/Components/MultilayerPerceptronComponent.h
  18. 31 0
      Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugModule.cpp
  19. 27 0
      Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugModule.h
  20. 148 0
      Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugSystemComponent.cpp
  21. 62 0
      Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugSystemComponent.h
  22. 262 0
      Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugTrainingWindow.cpp
  23. 65 0
      Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugTrainingWindow.h
  24. 17 6
      Gems/MachineLearning/Code/Source/MachineLearningSystemComponent.cpp
  25. 17 10
      Gems/MachineLearning/Code/Source/MachineLearningSystemComponent.h
  26. 63 26
      Gems/MachineLearning/Code/Source/Models/Layer.cpp
  27. 27 5
      Gems/MachineLearning/Code/Source/Models/Layer.h
  28. 229 30
      Gems/MachineLearning/Code/Source/Models/MultilayerPerceptron.cpp
  29. 60 13
      Gems/MachineLearning/Code/Source/Models/MultilayerPerceptron.h
  30. 0 18
      Gems/MachineLearning/Code/Source/Nodes/AccumulateTrainingGradients.ScriptCanvasNodeable.xml
  31. 0 19
      Gems/MachineLearning/Code/Source/Nodes/AccumulateTrainingGradients.cpp
  32. 15 0
      Gems/MachineLearning/Code/Source/Nodes/ArgMax.ScriptCanvasNodeable.xml
  33. 18 0
      Gems/MachineLearning/Code/Source/Nodes/ArgMax.cpp
  34. 3 3
      Gems/MachineLearning/Code/Source/Nodes/ArgMax.h
  35. 3 1
      Gems/MachineLearning/Code/Source/Nodes/ComputeCost.cpp
  36. 3 1
      Gems/MachineLearning/Code/Source/Nodes/FeedForward.cpp
  37. 0 16
      Gems/MachineLearning/Code/Source/Nodes/GradientDescent.ScriptCanvasNodeable.xml
  38. 15 0
      Gems/MachineLearning/Code/Source/Nodes/LoadModel.ScriptCanvasNodeable.xml
  39. 3 4
      Gems/MachineLearning/Code/Source/Nodes/LoadModel.cpp
  40. 3 3
      Gems/MachineLearning/Code/Source/Nodes/LoadModel.h
  41. 15 0
      Gems/MachineLearning/Code/Source/Nodes/SaveModel.ScriptCanvasNodeable.xml
  42. 18 0
      Gems/MachineLearning/Code/Source/Nodes/SaveModel.cpp
  43. 24 0
      Gems/MachineLearning/Code/Source/Nodes/SaveModel.h
  44. 11 1
      Gems/MachineLearning/Code/Source/Nodes/SupervisedLearning.cpp
  45. 23 0
      Gems/MachineLearning/Code/Tests/Algorithms/ActivationTests.cpp
  46. 7 7
      Gems/MachineLearning/Code/Tests/Models/LayerTests.cpp
  47. 38 36
      Gems/MachineLearning/Code/Tests/Models/MultilayerPerceptronTests.cpp
  48. 3 1
      Gems/MachineLearning/Code/machinelearning_api_files.cmake
  49. 16 0
      Gems/MachineLearning/Code/machinelearning_debug_files.cmake
  50. 9 6
      Gems/MachineLearning/Code/machinelearning_private_files.cmake

+ 25 - 4
Gems/MachineLearning/Code/CMakeLists.txt

@@ -85,12 +85,33 @@ ly_add_target(
             Gem::${gem_name}.Private.Object
             Gem::${gem_name}.Private.Object
 )
 )
 
 
+ly_add_target(
+    NAME ${gem_name}.Debug ${PAL_TRAIT_MONOLITHIC_DRIVEN_MODULE_TYPE}
+    NAMESPACE Gem
+    FILES_CMAKE
+        machinelearning_debug_files.cmake
+    INCLUDE_DIRECTORIES
+        PRIVATE
+            Source
+            .
+        PUBLIC
+            Include
+    BUILD_DEPENDENCIES
+        PUBLIC
+            Gem::${gem_name}.API
+        PRIVATE
+            Gem::${gem_name}.Private.Object
+            AZ::AtomCore
+            Gem::Atom_Feature_Common.Static
+            Gem::ImGui.Static
+)
+
 # By default, we will specify that the above target ${gem_name} would be used by
 # By default, we will specify that the above target ${gem_name} would be used by
 # Client and Server type targets when this gem is enabled.  If you don't want it
 # Client and Server type targets when this gem is enabled.  If you don't want it
 # active in Clients or Servers by default, delete one of both of the following lines:
 # active in Clients or Servers by default, delete one of both of the following lines:
-ly_create_alias(NAME ${gem_name}.Unified NAMESPACE Gem TARGETS Gem::${gem_name} Gem::ScriptCanvas)
-ly_create_alias(NAME ${gem_name}.Clients NAMESPACE Gem TARGETS Gem::${gem_name} Gem::ScriptCanvas)
-ly_create_alias(NAME ${gem_name}.Servers NAMESPACE Gem TARGETS Gem::${gem_name} Gem::ScriptCanvas)
+ly_create_alias(NAME ${gem_name}.Unified NAMESPACE Gem TARGETS Gem::${gem_name} Gem::${gem_name}.Debug Gem::ScriptCanvas)
+ly_create_alias(NAME ${gem_name}.Clients NAMESPACE Gem TARGETS Gem::${gem_name} Gem::${gem_name}.Debug Gem::ScriptCanvas)
+ly_create_alias(NAME ${gem_name}.Servers NAMESPACE Gem TARGETS Gem::${gem_name} Gem::${gem_name}.Debug Gem::ScriptCanvas)
 
 
 # For the Client and Server variants of ${gem_name} Gem, an alias to the ${gem_name}.API target will be made
 # For the Client and Server variants of ${gem_name} Gem, an alias to the ${gem_name}.API target will be made
 ly_create_alias(NAME ${gem_name}.Clients.API NAMESPACE Gem TARGETS Gem::${gem_name}.API)
 ly_create_alias(NAME ${gem_name}.Clients.API NAMESPACE Gem TARGETS Gem::${gem_name}.API)
@@ -159,7 +180,7 @@ if(PAL_TRAIT_BUILD_HOST_TOOLS)
     # By default, we will specify that the above target ${gem_name} would be used by
     # By default, we will specify that the above target ${gem_name} would be used by
     # Tool and Builder type targets when this gem is enabled.  If you don't want it
     # Tool and Builder type targets when this gem is enabled.  If you don't want it
     # active in Tools or Builders by default, delete one of both of the following lines:
     # active in Tools or Builders by default, delete one of both of the following lines:
-    ly_create_alias(NAME ${gem_name}.Tools NAMESPACE Gem TARGETS Gem::${gem_name} Gem::ScriptCanvas.Editor)
+    ly_create_alias(NAME ${gem_name}.Tools NAMESPACE Gem TARGETS Gem::${gem_name} Gem::${gem_name}.Debug Gem::ScriptCanvas.Editor)
     ly_create_alias(NAME ${gem_name}.Builders NAMESPACE Gem TARGETS Gem::${gem_name} Gem::ScriptCanvas.Editor)
     ly_create_alias(NAME ${gem_name}.Builders NAMESPACE Gem TARGETS Gem::${gem_name} Gem::ScriptCanvas.Editor)
 
 
     # For the Tools and Builders variants of ${gem_name} Gem, an alias to the ${gem_name}.Editor API target will be made
     # For the Tools and Builders variants of ${gem_name} Gem, an alias to the ${gem_name}.Editor API target will be made

+ 23 - 0
Gems/MachineLearning/Code/Include/MachineLearning/IInferenceContext.h

@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#pragma once
+
+#include <MachineLearning/Types.h>
+
+namespace MachineLearning
+{
+    //! This is a light-weight context suitable only for making inferences.
+    //! It will not work for training purposes
+    struct IInferenceContext
+    {
+        virtual ~IInferenceContext() = default;
+    };
+
+    using IInferenceContextPtr = IInferenceContext*;
+}

+ 4 - 0
Gems/MachineLearning/Code/Include/MachineLearning/ILabeledTrainingData.h

@@ -28,6 +28,10 @@ namespace MachineLearning
         //! Returns the index-th label in the training data set.
         //! Returns the index-th label in the training data set.
         virtual const AZ::VectorN& GetLabelByIndex(AZStd::size_t index) = 0;
         virtual const AZ::VectorN& GetLabelByIndex(AZStd::size_t index) = 0;
 
 
+        //! Returns the index-th label in the training data set as a value rather than a vector.
+        //! Note that not all data sets support this operation.
+        virtual AZStd::size_t GetLabelAsValueByIndex(AZStd::size_t index) = 0;
+
         //! Returns the index-th set of activations in the training data set.
         //! Returns the index-th set of activations in the training data set.
         virtual const AZ::VectorN& GetDataByIndex(AZStd::size_t index) = 0;
         virtual const AZ::VectorN& GetDataByIndex(AZStd::size_t index) = 0;
     };
     };

+ 51 - 0
Gems/MachineLearning/Code/Include/MachineLearning/IMachineLearning.h

@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#pragma once
+
+#include <MachineLearning/MachineLearningTypeIds.h>
+#include <MachineLearning/INeuralNetwork.h>
+#include <AzCore/Interface/Interface.h>
+#include <AzCore/EBus/EBus.h>
+
+namespace MachineLearning
+{
+    using ModelSet = AZStd::set<INeuralNetworkPtr>;
+
+    class IMachineLearning
+    {
+    public:
+
+        AZ_RTTI(IMachineLearning, IMachineLearningTypeId);
+
+        virtual ~IMachineLearning() = default;
+
+        //! Registers a model with the machine learning interface.
+        virtual void RegisterModel(INeuralNetworkPtr model) = 0;
+
+        //! Removes a model from the machine learning interface.
+        virtual void UnregisterModel(INeuralNetworkPtr model) = 0;
+
+        //! Retrieves the full set of registered models from the machine learning interface.
+        virtual ModelSet& GetModelSet() = 0;
+    };
+
+    class IMachineLearningBusTraits
+        : public AZ::EBusTraits
+    {
+    public:
+        //! EBusTraits overrides
+        //! @{
+        static constexpr AZ::EBusHandlerPolicy HandlerPolicy = AZ::EBusHandlerPolicy::Single;
+        static constexpr AZ::EBusAddressPolicy AddressPolicy = AZ::EBusAddressPolicy::Single;
+        //! @}
+    };
+
+    using MachineLearningRequestBus = AZ::EBus<IMachineLearning, IMachineLearningBusTraits>;
+    using MachineLearningInterface = AZ::Interface<IMachineLearning>;
+}

+ 29 - 8
Gems/MachineLearning/Code/Include/MachineLearning/INeuralNetwork.h

@@ -10,7 +10,10 @@
 
 
 #include <AzCore/Math/VectorN.h>
 #include <AzCore/Math/VectorN.h>
 #include <MachineLearning/Types.h>
 #include <MachineLearning/Types.h>
+#include <MachineLearning/IInferenceContext.h>
+#include <MachineLearning/ITrainingContext.h>
 #include <AzCore/EBus/EBus.h>
 #include <AzCore/EBus/EBus.h>
+#include <AzCore/std/string/string.h>
 
 
 namespace MachineLearning
 namespace MachineLearning
 {
 {
@@ -30,26 +33,44 @@ namespace MachineLearning
         INeuralNetwork& operator=(INeuralNetwork&&) = default;
         INeuralNetwork& operator=(INeuralNetwork&&) = default;
         INeuralNetwork& operator=(const INeuralNetwork&) = default;
         INeuralNetwork& operator=(const INeuralNetwork&) = default;
 
 
-        //! Adds a new layer to the network.
-        virtual void AddLayer([[maybe_unused]] AZStd::size_t layerDimensionality, [[maybe_unused]] ActivationFunctions activationFunction = ActivationFunctions::ReLU) {}
+        //! Returns a name for the model.
+        virtual AZStd::string GetName() const { return ""; }
+
+        //! Returns the file where model parameters are stored.
+        virtual AZStd::string GetAssetFile([[maybe_unused]] AssetTypes assetType) const { return ""; }
+
+        //! Returns the number of input neurons the model supports.
+        virtual AZStd::size_t GetInputDimensionality() const { return 0; }
+
+        //! Returns the number of output neurons the model supports.
+        virtual AZStd::size_t GetOutputDimensionality() const { return 0; }
 
 
         //! Returns the total number of layers in the network.
         //! Returns the total number of layers in the network.
         virtual AZStd::size_t GetLayerCount() const { return 0; }
         virtual AZStd::size_t GetLayerCount() const { return 0; }
 
 
-        //! Retrieves a specific layer from the network indexed by the layerIndex.
-        virtual Layer* GetLayer([[maybe_unused]] AZStd::size_t layerIndex) { return nullptr; }
-
         //! Returns the total number of parameters in the neural network.
         //! Returns the total number of parameters in the neural network.
         virtual AZStd::size_t GetParameterCount() const { return 0; }
         virtual AZStd::size_t GetParameterCount() const { return 0; }
 
 
+        // Returns a new inference context suitable for forward propagation operations.
+        virtual IInferenceContextPtr CreateInferenceContext() { return nullptr; }
+
+        // Returns a new training context suitable for back-propagation and gradient descent.
+        virtual ITrainingContextPtr CreateTrainingContext() { return nullptr; }
+
         //! Performs a basic feed-forward operation to compute the output from a set of activation values.
         //! Performs a basic feed-forward operation to compute the output from a set of activation values.
-        virtual const AZ::VectorN* Forward([[maybe_unused]] const AZ::VectorN& activations) { return nullptr; }
+        virtual const AZ::VectorN* Forward([[maybe_unused]] IInferenceContextPtr context, [[maybe_unused]] const AZ::VectorN& activations) { return nullptr; }
 
 
         //! Accumulates the loss gradients given a loss function, an activation vector and a corresponding label vector.
         //! Accumulates the loss gradients given a loss function, an activation vector and a corresponding label vector.
-        virtual void Reverse([[maybe_unused]] LossFunctions lossFunction, [[maybe_unused]] const AZ::VectorN& activations, [[maybe_unused]] const AZ::VectorN& expected) {}
+        virtual void Reverse([[maybe_unused]] ITrainingContextPtr context, [[maybe_unused]] LossFunctions lossFunction, [[maybe_unused]] const AZ::VectorN& activations, [[maybe_unused]] const AZ::VectorN& expected) {}
 
 
         //! Performs a gradient descent step and resets all gradient accumulators to zero.
         //! Performs a gradient descent step and resets all gradient accumulators to zero.
-        virtual void GradientDescent([[maybe_unused]] float learningRate) {}
+        virtual void GradientDescent([[maybe_unused]] ITrainingContextPtr context, [[maybe_unused]] float learningRate) {}
+
+        //! Loads the current model parameters from the associated asset file.
+        virtual bool LoadModel() { return false; }
+
+        //! Saves the current model parameters to the associated asset file.
+        virtual bool SaveModel() { return false; }
 
 
         //! For intrusive_ptr support
         //! For intrusive_ptr support
         //! @{
         //! @{

+ 22 - 0
Gems/MachineLearning/Code/Include/MachineLearning/ITrainingContext.h

@@ -0,0 +1,22 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#pragma once
+
+#include <MachineLearning/Types.h>
+
+namespace MachineLearning
+{
+    //! This is a heavier weight context suitable for backpropagation and training of models.
+    struct ITrainingContext
+    {
+        virtual ~ITrainingContext() = default;
+    };
+
+    using ITrainingContextPtr = ITrainingContext*;
+}

+ 0 - 41
Gems/MachineLearning/Code/Include/MachineLearning/MachineLearningBus.h

@@ -1,41 +0,0 @@
-/*
- * Copyright (c) Contributors to the Open 3D Engine Project.
- * For complete copyright and license terms please see the LICENSE at the root of this distribution.
- *
- * SPDX-License-Identifier: Apache-2.0 OR MIT
- *
- */
-
-#pragma once
-
-#include <MachineLearning/MachineLearningTypeIds.h>
-#include <MachineLearning/INeuralNetwork.h>
-
-#include <AzCore/EBus/EBus.h>
-#include <AzCore/Interface/Interface.h>
-
-namespace MachineLearning
-{
-    class MachineLearningRequests
-    {
-    public:
-        AZ_RTTI(MachineLearningRequests, MachineLearningRequestsTypeId);
-        virtual ~MachineLearningRequests() = default;
-        // Put your public methods here
-    };
-
-    class MachineLearningBusTraits
-        : public AZ::EBusTraits
-    {
-    public:
-        //////////////////////////////////////////////////////////////////////////
-        // EBusTraits overrides
-        static constexpr AZ::EBusHandlerPolicy HandlerPolicy = AZ::EBusHandlerPolicy::Single;
-        static constexpr AZ::EBusAddressPolicy AddressPolicy = AZ::EBusAddressPolicy::Single;
-        //////////////////////////////////////////////////////////////////////////
-    };
-
-    using MachineLearningRequestBus = AZ::EBus<MachineLearningRequests, MachineLearningBusTraits>;
-    using MachineLearningInterface = AZ::Interface<MachineLearningRequests>;
-
-} // namespace MachineLearning

+ 1 - 1
Gems/MachineLearning/Code/Include/MachineLearning/MachineLearningTypeIds.h

@@ -22,5 +22,5 @@ namespace MachineLearning
     inline constexpr const char* MachineLearningEditorModuleTypeId = MachineLearningModuleTypeId;
     inline constexpr const char* MachineLearningEditorModuleTypeId = MachineLearningModuleTypeId;
 
 
     // Interface TypeIds
     // Interface TypeIds
-    inline constexpr const char* MachineLearningRequestsTypeId = "{B65151FE-3588-432A-A0EE-1DB5BF5147CA}";
+    inline constexpr const char* IMachineLearningTypeId = "{B65151FE-3588-432A-A0EE-1DB5BF5147CA}";
 } // namespace MachineLearning
 } // namespace MachineLearning

+ 8 - 0
Gems/MachineLearning/Code/Include/MachineLearning/Types.h

@@ -23,4 +23,12 @@ namespace MachineLearning
         Softmax,
         Softmax,
         Linear
         Linear
     );
     );
+
+    AZ_ENUM_CLASS(AssetTypes,
+        Model,
+        TestData,
+        TestLabels,
+        TrainingData, 
+        TrainingLabels
+    );
 }
 }

+ 20 - 2
Gems/MachineLearning/Code/Source/Algorithms/Activations.cpp

@@ -30,6 +30,22 @@ namespace MachineLearning
         output.SetElement(value, 1.0f);
         output.SetElement(value, 1.0f);
     }
     }
 
 
+    AZStd::size_t ArgMaxDecode(const AZ::VectorN& vector)
+    {
+        const AZStd::size_t numElements = vector.GetDimensionality();
+        float maxValue = 0.0f;
+        AZStd::size_t maxIndex = 0;
+        for (AZStd::size_t iter = 0; iter < numElements; ++iter)
+        {
+            if (vector.GetElement(iter) > maxValue)
+            {
+                maxValue = vector.GetElement(iter);
+                maxIndex = iter;
+            }
+        }
+        return maxIndex;
+    }
+
     void Activate(ActivationFunctions activationFunction, const AZ::VectorN& sourceVector, AZ::VectorN& output)
     void Activate(ActivationFunctions activationFunction, const AZ::VectorN& sourceVector, AZ::VectorN& output)
     {
     {
         output.Resize(sourceVector.GetDimensionality());
         output.Resize(sourceVector.GetDimensionality());
@@ -69,13 +85,15 @@ namespace MachineLearning
     {
     {
         const AZ::Vector4 vecZero = AZ::Vector4::CreateZero();
         const AZ::Vector4 vecZero = AZ::Vector4::CreateZero();
         const AZ::Vector4 vecOne = AZ::Vector4::CreateOne();
         const AZ::Vector4 vecOne = AZ::Vector4::CreateOne();
+        const AZ::Vector4 epsilon = AZ::Vector4(AZ::Constants::Tolerance);
         const AZStd::size_t numElements = sourceVector.GetVectorValues().size();
         const AZStd::size_t numElements = sourceVector.GetVectorValues().size();
         output.Resize(sourceVector.GetDimensionality());
         output.Resize(sourceVector.GetDimensionality());
         for (AZStd::size_t iter = 0; iter < numElements; ++iter)
         for (AZStd::size_t iter = 0; iter < numElements; ++iter)
         {
         {
             const AZ::Vector4& sourceElement = sourceVector.GetVectorValues()[iter];
             const AZ::Vector4& sourceElement = sourceVector.GetVectorValues()[iter];
             AZ::Vector4& outputElement = output.GetVectorValues()[iter];
             AZ::Vector4& outputElement = output.GetVectorValues()[iter];
-            outputElement = vecOne / (vecOne + (-sourceElement).GetExpEstimate());
+            const AZ::Vector4 divisor = (vecOne + (-sourceElement).GetExpEstimate()).GetMax(epsilon);
+            outputElement = vecOne / divisor;
             outputElement = outputElement.GetClamp(vecZero, vecOne);
             outputElement = outputElement.GetClamp(vecZero, vecOne);
         }
         }
         output.FixLastVectorElement();
         output.FixLastVectorElement();
@@ -110,7 +128,7 @@ namespace MachineLearning
             partialSum += outputElement;
             partialSum += outputElement;
         }
         }
 
 
-        const float divisor = 1.0f / partialSum.Dot(vecOne);
+        const float divisor = AZ::GetMax(1.0f / partialSum.Dot(vecOne), AZ::Constants::Tolerance);
         for (AZ::Vector4& element : output.GetVectorValues())
         for (AZ::Vector4& element : output.GetVectorValues())
         {
         {
             element = element * divisor;
             element = element * divisor;

+ 4 - 0
Gems/MachineLearning/Code/Source/Algorithms/Activations.h

@@ -9,6 +9,7 @@
 #pragma once
 #pragma once
 
 
 #include <AzCore/Math/VectorN.h>
 #include <AzCore/Math/VectorN.h>
+#include <AzCore/Serialization/EditContext.h>
 #include <MachineLearning/INeuralNetwork.h>
 #include <MachineLearning/INeuralNetwork.h>
 
 
 namespace MachineLearning
 namespace MachineLearning
@@ -19,6 +20,9 @@ namespace MachineLearning
     //! One-hot encodes the provided value into the resulting vector output, which will have dimensionality maxValue.
     //! One-hot encodes the provided value into the resulting vector output, which will have dimensionality maxValue.
     void OneHotEncode(AZStd::size_t value, AZStd::size_t maxValue, AZ::VectorN& output);
     void OneHotEncode(AZStd::size_t value, AZStd::size_t maxValue, AZ::VectorN& output);
 
 
+    //! Reverses one-hot encoding, returns the index of the element with the largest value.
+    AZStd::size_t ArgMaxDecode(const AZ::VectorN& vector);
+
     //! Computes the requested activation function applied to all elements of the source vector.
     //! Computes the requested activation function applied to all elements of the source vector.
     void Activate(ActivationFunctions activationFunction, const AZ::VectorN& sourceVector, AZ::VectorN& output);
     void Activate(ActivationFunctions activationFunction, const AZ::VectorN& sourceVector, AZ::VectorN& output);
 
 

+ 103 - 38
Gems/MachineLearning/Code/Source/Algorithms/Training.cpp

@@ -10,72 +10,137 @@
 #include <Algorithms/LossFunctions.h>
 #include <Algorithms/LossFunctions.h>
 #include <AzCore/Math/SimdMath.h>
 #include <AzCore/Math/SimdMath.h>
 #include <AzCore/Console/ILogger.h>
 #include <AzCore/Console/ILogger.h>
+#include <AzCore/Jobs/JobCompletion.h>
+#include <AzCore/Jobs/JobFunction.h>
 #include <numeric>
 #include <numeric>
 #include <random>
 #include <random>
 
 
 namespace MachineLearning
 namespace MachineLearning
 {
 {
-    float ComputeCurrentCost(INeuralNetworkPtr Model, ILabeledTrainingDataPtr TestData, LossFunctions CostFunction)
+    SupervisedLearningCycle::SupervisedLearningCycle()
     {
     {
-        const AZStd::size_t totalTestSize = TestData->GetSampleCount();
-
-        double result = 0.0;
-        for (uint32_t iter = 0; iter < totalTestSize; ++iter)
-        {
-            const AZ::VectorN& activations = TestData->GetDataByIndex(iter);
-            const AZ::VectorN& label = TestData->GetLabelByIndex(iter);
-            const AZ::VectorN* output = Model->Forward(activations);
-            result += static_cast<double>(ComputeTotalCost(CostFunction, label, *output));
-        }
-        result /= static_cast<double>(totalTestSize);
-        return static_cast<float>(result);
+        AZ::JobManagerDesc jobDesc;
+        jobDesc.m_jobManagerName = "MachineLearning Training";
+        jobDesc.m_workerThreads.push_back(AZ::JobManagerThreadDesc()); // Just one thread
+        m_trainingJobManager = AZStd::make_unique<AZ::JobManager>(jobDesc);
+        m_trainingjobContext = AZStd::make_unique<AZ::JobContext>(*m_trainingJobManager);
     }
     }
 
 
-    void SupervisedLearningCycle
+    SupervisedLearningCycle::SupervisedLearningCycle
     (
     (
         INeuralNetworkPtr model,
         INeuralNetworkPtr model,
         ILabeledTrainingDataPtr trainingData,
         ILabeledTrainingDataPtr trainingData,
         ILabeledTrainingDataPtr testData,
         ILabeledTrainingDataPtr testData,
-        LossFunctions costFunction, 
+        LossFunctions costFunction,
         AZStd::size_t totalIterations,
         AZStd::size_t totalIterations,
         AZStd::size_t batchSize,
         AZStd::size_t batchSize,
         float learningRate,
         float learningRate,
         float learningRateDecay,
         float learningRateDecay,
         float earlyStopCost
         float earlyStopCost
-    )
+    ) : SupervisedLearningCycle()
+    {
+        m_model = model;
+        m_trainingData = trainingData;
+        m_testData = testData;
+        m_costFunction = costFunction;
+        m_totalIterations = totalIterations;
+        m_batchSize = batchSize;
+        m_learningRate = learningRate;
+        m_learningRateDecay = learningRateDecay;
+        m_earlyStopCost = earlyStopCost;
+    }
+
+    void SupervisedLearningCycle::InitializeContexts()
+    {
+        if (m_inferenceContext == nullptr)
+        {
+            m_inferenceContext.reset(m_model->CreateInferenceContext());
+            m_trainingContext.reset(m_model->CreateTrainingContext());
+        }
+    }
+
+    void SupervisedLearningCycle::StartTraining()
     {
     {
-        const AZStd::size_t totalTrainingSize = trainingData->GetSampleCount();
-        const float initialCost = ComputeCurrentCost(model, testData, costFunction);
-        AZLOG_INFO("Initial model cost prior to training: %f", initialCost);
+        InitializeContexts();
+
+        const AZStd::size_t totalTrainingSize = m_trainingData->GetSampleCount();
 
 
         // Generate a set of training indices that we can later shuffle
         // Generate a set of training indices that we can later shuffle
-        AZStd::vector<AZStd::size_t> indices;
-        indices.resize(totalTrainingSize);
-        std::iota(indices.begin(), indices.end(), 0);
+        m_indices.resize(totalTrainingSize);
+        std::iota(m_indices.begin(), m_indices.end(), 0);
+        std::shuffle(m_indices.begin(), m_indices.end(), std::mt19937(std::random_device{}()));
+
+        // Start training
+        m_currentEpoch = 0;
+        m_trainingComplete = false;
+        m_currentIndex = 0;
+
+        auto job = [this]()
+        {
+            ExecTraining();
+        };
+        AZ::Job* trainingJob = AZ::CreateJobFunction(job, true, m_trainingjobContext.get());
+        trainingJob->Start();
+    }
 
 
-        for (uint32_t epoch = 0; epoch < totalIterations; ++epoch)
+    void SupervisedLearningCycle::ExecTraining()
+    {
+        const AZStd::size_t totalTrainingSize = m_trainingData->GetSampleCount();
+        while (!m_trainingComplete)
         {
         {
-            // We reshuffle the training data indices each epoch to avoid patterns in the training data
-            std::shuffle(indices.begin(), indices.end(), std::mt19937(std::random_device{}()));
-            AZStd::size_t sampleCount = 0;
-            for (uint32_t batch = 0; (batch < batchSize) && (sampleCount < totalTrainingSize); ++batch, ++sampleCount)
+            if (m_currentIndex >= totalTrainingSize)
             {
             {
-                const AZ::VectorN& activations = trainingData->GetDataByIndex(indices[sampleCount]);
-                const AZ::VectorN& label = trainingData->GetLabelByIndex(indices[sampleCount]);
-                model->Reverse(costFunction, activations, label);
+                // If we run out of training samples, we increment our epoch and reset for a new pass of the training data
+                m_currentIndex = 0;
+                m_learningRate *= m_learningRateDecay;
+
+                // We reshuffle the training data indices each epoch to avoid patterns in the training data
+                std::shuffle(m_indices.begin(), m_indices.end(), std::mt19937(std::random_device{}()));
+                ++m_currentEpoch;
+
+                // Generally we want to keep monitoring the model's performance on both test and training data
+                // This allows us to detect if we're overfitting the model to the training data
+                float currentTestCost = ComputeCurrentCost(m_testData, m_costFunction);
+                if ((currentTestCost < m_earlyStopCost) || (m_currentEpoch >= m_totalIterations))
+                {
+                    m_trainingComplete = true;
+                    return;
+                }
             }
             }
-            model->GradientDescent(learningRate);
 
 
-            const float currentTestCost = ComputeCurrentCost(model, testData, costFunction);
-            const float currentTrainCost = ComputeCurrentCost(model, trainingData, costFunction);
-            AZLOG_INFO("Epoch %u, Test cost: %f, Train cost: %f, Learning rate: %f", epoch, currentTestCost, currentTrainCost, learningRate);
-            if (currentTestCost < earlyStopCost)
+            for (uint32_t batch = 0; (batch < m_batchSize) && (m_currentIndex < totalTrainingSize); ++batch, ++m_currentIndex)
             {
             {
-                AZLOG_INFO("Early stop threshold reached, exiting training loop: %f, %f", currentTestCost, earlyStopCost);
-                break;
+                const AZ::VectorN& activations = m_trainingData->GetDataByIndex(m_indices[m_currentIndex]);
+                const AZ::VectorN& label = m_trainingData->GetLabelByIndex(m_indices[m_currentIndex]);
+                m_model->Reverse(m_trainingContext.get(), m_costFunction, activations, label);
             }
             }
+            AZStd::lock_guard lock(m_mutex);
+            m_model->GradientDescent(m_trainingContext.get(), m_learningRate);
+        }
+    }
+
+    void SupervisedLearningCycle::StopTraining()
+    {
+        m_trainingComplete = true;
+    }
 
 
-            learningRate *= learningRateDecay;
+    float SupervisedLearningCycle::ComputeCurrentCost(ILabeledTrainingDataPtr TestData, LossFunctions CostFunction, AZStd::size_t maxSamples)
+    {
+        InitializeContexts();
+
+        const AZStd::size_t totalTestSize = TestData->GetSampleCount();
+        maxSamples = (maxSamples == 0) ? totalTestSize : AZStd::min(maxSamples, totalTestSize);
+
+        AZStd::lock_guard lock(m_mutex);
+        double result = 0.0;
+        for (uint32_t iter = 0; iter < maxSamples; ++iter)
+        {
+            const AZ::VectorN& activations = TestData->GetDataByIndex(iter);
+            const AZ::VectorN& label = TestData->GetLabelByIndex(iter);
+            const AZ::VectorN* output = m_model->Forward(m_inferenceContext.get(), activations);
+            result += static_cast<double>(ComputeTotalCost(CostFunction, label, *output));
         }
         }
+        result /= static_cast<double>(maxSamples);
+        return static_cast<float>(result);
     }
     }
 }
 }

+ 57 - 15
Gems/MachineLearning/Code/Source/Algorithms/Training.h

@@ -9,29 +9,71 @@
 #pragma once
 #pragma once
 
 
 #include <AzCore/Math/VectorN.h>
 #include <AzCore/Math/VectorN.h>
+#include <AzCore/Jobs/JobManager.h>
+#include <AzCore/Jobs/JobContext.h>
 #include <MachineLearning/INeuralNetwork.h>
 #include <MachineLearning/INeuralNetwork.h>
 #include <MachineLearning/ILabeledTrainingData.h>
 #include <MachineLearning/ILabeledTrainingData.h>
 
 
 namespace MachineLearning
 namespace MachineLearning
 {
 {
-    //! Calculates the average cost of the provided model on the set of labeled test data using the requested loss function.
-    float ComputeCurrentCost(INeuralNetworkPtr Model, ILabeledTrainingDataPtr TestData, LossFunctions CostFunction);
-
     //! Performs a supervised learning training cycle.
     //! Performs a supervised learning training cycle.
     //! Supervised learning is a form of machine learning where a model is provided a set of training data with expected output
     //! Supervised learning is a form of machine learning where a model is provided a set of training data with expected output
     //! Training then takes place in an iterative loop where the total error (cost, loss) of the model is minimized
     //! Training then takes place in an iterative loop where the total error (cost, loss) of the model is minimized
     //! This differs from unsupervised learning, where the training data lacks any form of labeling (expected correct output), and
     //! This differs from unsupervised learning, where the training data lacks any form of labeling (expected correct output), and
     //! the model is expected to learn the underlying structures of data on its own.
     //! the model is expected to learn the underlying structures of data on its own.
-    void SupervisedLearningCycle
-    (
-        INeuralNetworkPtr model,
-        ILabeledTrainingDataPtr trainingData,
-        ILabeledTrainingDataPtr testData,
-        LossFunctions costFunction,
-        AZStd::size_t totalIterations,
-        AZStd::size_t batchSize,
-        float learningRate,
-        float learningRateDecay,
-        float earlyStopCost
-    );
+    class SupervisedLearningCycle
+    {
+    public:
+
+        SupervisedLearningCycle();
+
+        SupervisedLearningCycle
+        (
+            INeuralNetworkPtr model,
+            ILabeledTrainingDataPtr trainingData,
+            ILabeledTrainingDataPtr testData,
+            LossFunctions costFunction,
+            AZStd::size_t totalIterations,
+            AZStd::size_t batchSize,
+            float learningRate,
+            float learningRateDecay,
+            float earlyStopCost
+        );
+
+        void InitializeContexts();
+
+        void StartTraining();
+        void StopTraining();
+
+        //! Calculates the average cost of the provided model on the set of labeled test data using the requested loss function.
+        float ComputeCurrentCost(ILabeledTrainingDataPtr TestData, LossFunctions CostFunction, AZStd::size_t maxSamples = 0);
+
+        AZStd::atomic<AZStd::size_t> m_currentEpoch = 0;
+        std::atomic<bool> m_trainingComplete = true;
+
+    //private:
+        void ExecTraining();
+
+        INeuralNetworkPtr m_model;
+        ILabeledTrainingDataPtr m_trainingData;
+        ILabeledTrainingDataPtr m_testData;
+        LossFunctions m_costFunction = LossFunctions::MeanSquaredError;
+        AZStd::size_t m_totalIterations = 0;
+        AZStd::size_t m_batchSize = 0;
+        float m_learningRate = 0.0f;
+        float m_learningRateDecay = 0.0f;
+        float m_earlyStopCost = 0.0f;
+
+        AZStd::vector<AZStd::size_t> m_indices;
+        AZStd::size_t m_currentIndex = 0;
+
+        AZStd::unique_ptr<IInferenceContext> m_inferenceContext;
+        AZStd::unique_ptr<ITrainingContext> m_trainingContext;
+
+        AZStd::unique_ptr<AZ::JobManager> m_trainingJobManager;
+        AZStd::unique_ptr<AZ::JobContext> m_trainingjobContext;
+
+        //! Guards model state.
+        mutable AZStd::recursive_mutex m_mutex;
+    };
 }
 }

+ 7 - 4
Gems/MachineLearning/Code/Source/Assets/MnistDataLoader.cpp

@@ -18,7 +18,7 @@
 #include <AzCore/RTTI/BehaviorContext.h>
 #include <AzCore/RTTI/BehaviorContext.h>
 #include <AzCore/Serialization/EditContext.h>
 #include <AzCore/Serialization/EditContext.h>
 #include <AzCore/Serialization/SerializeContext.h>
 #include <AzCore/Serialization/SerializeContext.h>
-#pragma optimize("", off)
+
 namespace MachineLearning
 namespace MachineLearning
 {
 {
     void MnistDataLoader::Reflect(AZ::ReflectContext* context)
     void MnistDataLoader::Reflect(AZ::ReflectContext* context)
@@ -66,6 +66,11 @@ namespace MachineLearning
         return m_labelVector;
         return m_labelVector;
     }
     }
 
 
+    AZStd::size_t MnistDataLoader::GetLabelAsValueByIndex(AZStd::size_t index)
+    {
+        return static_cast<AZStd::size_t>(m_labelBuffer[index]);
+    }
+
     const AZ::VectorN& MnistDataLoader::GetDataByIndex(AZStd::size_t index)
     const AZ::VectorN& MnistDataLoader::GetDataByIndex(AZStd::size_t index)
     {
     {
         const AZStd::size_t imageDataStride = m_dataHeader.m_height * m_dataHeader.m_width;
         const AZStd::size_t imageDataStride = m_dataHeader.m_height * m_dataHeader.m_width;
@@ -174,7 +179,6 @@ namespace MachineLearning
         labelHeader.m_labelCount = ntohl(labelHeader.m_labelCount);
         labelHeader.m_labelCount = ntohl(labelHeader.m_labelCount);
 
 
         constexpr uint32_t MnistLabelHeaderValue = 2049;
         constexpr uint32_t MnistLabelHeaderValue = 2049;
-
         if (labelHeader.m_labelHeader != MnistLabelHeaderValue)
         if (labelHeader.m_labelHeader != MnistLabelHeaderValue)
         {
         {
             // Invalid format
             // Invalid format
@@ -191,9 +195,8 @@ namespace MachineLearning
         }
         }
 
 
         m_labelBuffer.resize(labelHeader.m_labelCount);
         m_labelBuffer.resize(labelHeader.m_labelCount);
-        m_imageFile.Read(labelHeader.m_labelCount, m_labelBuffer.data());
+        m_labelFile.Read(labelHeader.m_labelCount, m_labelBuffer.data());
         AZLOG_INFO("Loaded MNIST archive %s containing %u samples", filePathFixed.c_str(), m_dataHeader.m_imageCount);
         AZLOG_INFO("Loaded MNIST archive %s containing %u samples", filePathFixed.c_str(), m_dataHeader.m_imageCount);
         return true;
         return true;
     }
     }
 }
 }
-#pragma optimize("", on)

+ 1 - 0
Gems/MachineLearning/Code/Source/Assets/MnistDataLoader.h

@@ -35,6 +35,7 @@ namespace MachineLearning
         bool LoadArchive(const AZStd::string& imageFilename, const AZStd::string& labelFilename) override;
         bool LoadArchive(const AZStd::string& imageFilename, const AZStd::string& labelFilename) override;
         AZStd::size_t GetSampleCount() const override;
         AZStd::size_t GetSampleCount() const override;
         const AZ::VectorN& GetLabelByIndex(AZStd::size_t index) override;
         const AZ::VectorN& GetLabelByIndex(AZStd::size_t index) override;
+        AZStd::size_t GetLabelAsValueByIndex(AZStd::size_t index) override;
         const AZ::VectorN& GetDataByIndex(AZStd::size_t index) override;
         const AZ::VectorN& GetDataByIndex(AZStd::size_t index) override;
         //! @}
         //! @}
 
 

+ 12 - 1
Gems/MachineLearning/Code/Source/Components/MultilayerPerceptronComponent.cpp

@@ -9,6 +9,7 @@
 #pragma once
 #pragma once
 
 
 #include <Components/MultilayerPerceptronComponent.h>
 #include <Components/MultilayerPerceptronComponent.h>
+#include <MachineLearning/IMachineLearning.h>
 #include <AzCore/RTTI/RTTI.h>
 #include <AzCore/RTTI/RTTI.h>
 #include <AzCore/RTTI/BehaviorContext.h>
 #include <AzCore/RTTI/BehaviorContext.h>
 #include <AzCore/Serialization/EditContext.h>
 #include <AzCore/Serialization/EditContext.h>
@@ -64,9 +65,19 @@ namespace MachineLearning
         provided.push_back(AZ_CRC("MultilayerPerceptronService"));
         provided.push_back(AZ_CRC("MultilayerPerceptronService"));
     }
     }
 
 
-    void MultilayerPerceptronComponent::Activate()
+    MultilayerPerceptronComponent::MultilayerPerceptronComponent()
     {
     {
         m_handle.reset(&m_model);
         m_handle.reset(&m_model);
+        MachineLearningInterface::Get()->RegisterModel(m_handle);
+    }
+
+    MultilayerPerceptronComponent::~MultilayerPerceptronComponent()
+    {
+        MachineLearningInterface::Get()->UnregisterModel(m_handle);
+    }
+
+    void MultilayerPerceptronComponent::Activate()
+    {
         MultilayerPerceptronComponentRequestBus::Handler::BusConnect(GetEntityId());
         MultilayerPerceptronComponentRequestBus::Handler::BusConnect(GetEntityId());
     }
     }
 
 

+ 3 - 0
Gems/MachineLearning/Code/Source/Components/MultilayerPerceptronComponent.h

@@ -36,6 +36,9 @@ namespace MachineLearning
         static void Reflect(AZ::ReflectContext* context);
         static void Reflect(AZ::ReflectContext* context);
         static void GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided);
         static void GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided);
 
 
+        MultilayerPerceptronComponent();
+        ~MultilayerPerceptronComponent();
+
         //! AZ::Component overrides
         //! AZ::Component overrides
         //! @{
         //! @{
         void Activate() override;
         void Activate() override;

+ 31 - 0
Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugModule.cpp

@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <Source/Debug/MachineLearningDebugModule.h>
+#include <Source/Debug/MachineLearningDebugSystemComponent.h>
+
+namespace MachineLearning
+{
+    MachineLearningDebugModule::MachineLearningDebugModule()
+        : AZ::Module()
+    {
+        m_descriptors.insert(m_descriptors.end(), {
+            MachineLearningDebugSystemComponent::CreateDescriptor()
+        });
+    }
+
+    AZ::ComponentTypeList MachineLearningDebugModule::GetRequiredSystemComponents() const
+    {
+        return AZ::ComponentTypeList
+        {
+            azrtti_typeid<MachineLearningDebugSystemComponent>()
+        };
+    }
+}
+
+AZ_DECLARE_MODULE_CLASS(Gem_MachineLearning_Debug, MachineLearning::MachineLearningDebugModule);

+ 27 - 0
Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugModule.h

@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#pragma once
+
+#include <AzCore/Module/Module.h>
+
+namespace MachineLearning
+{
+    class MachineLearningDebugModule
+        : public AZ::Module
+    {
+    public:
+        AZ_RTTI(MachineLearningDebugModule, "{79F87F90-31FC-4A5D-B0B3-9CB8F0CA93E2}", AZ::Module);
+        AZ_CLASS_ALLOCATOR(MachineLearningDebugModule, AZ::SystemAllocator);
+
+        MachineLearningDebugModule();
+        ~MachineLearningDebugModule() override = default;
+
+        AZ::ComponentTypeList GetRequiredSystemComponents() const override;
+    };
+}

+ 148 - 0
Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugSystemComponent.cpp

@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <Source/Debug/MachineLearningDebugSystemComponent.h>
+#include <AzCore/Component/ComponentApplicationBus.h>
+#include <AzCore/Interface/Interface.h>
+#include <Atom/Feature/ImGui/SystemBus.h>
+#include <ImGuiContextScope.h>
+#include <ImGui/ImGuiPass.h>
+#include <imgui/imgui.h>
+#include <imgui/imgui_internal.h>
+
+namespace MachineLearning
+{
+    void MachineLearningDebugSystemComponent::Reflect(AZ::ReflectContext* context)
+    {
+        if (AZ::SerializeContext* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
+        {
+            serializeContext->Class<MachineLearningDebugSystemComponent, AZ::Component>()
+                ->Version(1);
+        }
+    }
+
+    void MachineLearningDebugSystemComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
+    {
+        provided.push_back(AZ_CRC_CE("MachineLearningDebugSystemComponent"));
+    }
+
+    void MachineLearningDebugSystemComponent::GetRequiredServices([[maybe_unused]] AZ::ComponentDescriptor::DependencyArrayType& required)
+    {
+        ;
+    }
+
+    void MachineLearningDebugSystemComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatbile)
+    {
+        incompatbile.push_back(AZ_CRC_CE("MachineLearningDebugSystemComponent"));
+    }
+
+    void MachineLearningDebugSystemComponent::Activate()
+    {
+#ifdef IMGUI_ENABLED
+        ImGui::ImGuiUpdateListenerBus::Handler::BusConnect();
+#endif
+    }
+
+    void MachineLearningDebugSystemComponent::Deactivate()
+    {
+#ifdef IMGUI_ENABLED
+        ImGui::ImGuiUpdateListenerBus::Handler::BusDisconnect();
+#endif
+    }
+
+#ifdef IMGUI_ENABLED
+    void MachineLearningDebugSystemComponent::OnModelRegistryDisplay()
+    {
+        const float TEXT_BASE_WIDTH = ImGui::CalcTextSize("A").x;
+
+        const ImGuiTableFlags flags = ImGuiTableFlags_BordersV
+            | ImGuiTableFlags_BordersOuterH
+            | ImGuiTableFlags_Resizable
+            | ImGuiTableFlags_RowBg
+            | ImGuiTableFlags_NoBordersInBody;
+
+        const ImGuiTreeNodeFlags nodeFlags = (ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth);
+
+        IMachineLearning* machineLearning = MachineLearningInterface::Get();
+        const ModelSet& modelSet = machineLearning->GetModelSet();
+
+        ImGui::Text("Total registered models: %u", static_cast<uint32_t>(modelSet.size()));
+        ImGui::NewLine();
+
+        if (ImGui::BeginTable("Model Details", 6, flags))
+        {
+            ImGui::TableSetupColumn("Name", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 32.0f);
+            ImGui::TableSetupColumn("File", ImGuiTableColumnFlags_WidthStretch);
+            ImGui::TableSetupColumn("Input Neurons", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
+            ImGui::TableSetupColumn("Output Neurons", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
+            ImGui::TableSetupColumn("Layers", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
+            ImGui::TableSetupColumn("Parameters", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
+            ImGui::TableHeadersRow();
+
+            AZStd::size_t index = 0;
+            for (auto& neuralNetwork : modelSet)
+            {
+                ImGui::TableNextRow();
+                ImGui::TableNextColumn();
+                ImGui::Text(neuralNetwork->GetName().c_str());
+                ImGui::TableNextColumn();
+                ImGui::Text(neuralNetwork->GetAssetFile(AssetTypes::Model).c_str());
+                ImGui::TableNextColumn();
+                ImGui::Text("%lld", aznumeric_cast<AZ::s64>(neuralNetwork->GetInputDimensionality()));
+                ImGui::TableNextColumn();
+                ImGui::Text("%lld", aznumeric_cast<AZ::s64>(neuralNetwork->GetOutputDimensionality()));
+                ImGui::TableNextColumn();
+                ImGui::Text("%lld", aznumeric_cast<AZ::s64>(neuralNetwork->GetLayerCount()));
+                ImGui::TableNextColumn();
+                ImGui::Text("%llu", neuralNetwork->GetParameterCount());
+                ++index;
+            }
+            ImGui::EndTable();
+            ImGui::NewLine();
+        }
+
+        ImGui::End();
+    }
+
+    void MachineLearningDebugSystemComponent::OnModelTrainingDisplay()
+    {
+        m_trainingWindow.OnImGuiUpdate();
+    }
+
+    void MachineLearningDebugSystemComponent::OnImGuiMainMenuUpdate()
+    {
+        if (ImGui::BeginMenu("MachineLearning"))
+        {
+            ImGui::Checkbox("Model Registry", &m_displayModelRegistry);
+            ImGui::Checkbox("Model Training", &m_displayTrainingWindow);
+            ImGui::EndMenu();
+        }
+    }
+
+    void MachineLearningDebugSystemComponent::OnImGuiUpdate()
+    {
+        if (m_displayModelRegistry)
+        {
+            if (ImGui::Begin("Model Registry", &m_displayModelRegistry, ImGuiWindowFlags_None))
+            {
+                OnModelRegistryDisplay();
+            }
+            ImGui::End();
+        }
+
+        if (m_displayTrainingWindow)
+        {
+            if (ImGui::Begin("Model Training", &m_displayTrainingWindow, ImGuiWindowFlags_None))
+            {
+                OnModelTrainingDisplay();
+            }
+            ImGui::End();
+        }
+    }
+#endif
+}

+ 62 - 0
Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugSystemComponent.h

@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#pragma once
+
+#include <AzCore/Component/Component.h>
+#include <AzCore/Interface/Interface.h>
+#include <MachineLearning/IMachineLearning.h>
+#include <Debug/MachineLearningDebugTrainingWindow.h>
+
+#ifdef IMGUI_ENABLED
+#   include <imgui/imgui.h>
+#   include <ImGuiBus.h>
+#endif
+
+namespace MachineLearning
+{
+    class MachineLearningDebugSystemComponent final
+        : public AZ::Component
+#ifdef IMGUI_ENABLED
+        , public ImGui::ImGuiUpdateListenerBus::Handler
+#endif
+    {
+    public:
+
+        AZ_COMPONENT(MachineLearningDebugSystemComponent, "{44A3FACE-9808-4BAD-BC9C-6DB6AE0A9707}");
+
+        static void Reflect(AZ::ReflectContext* context);
+        static void GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided);
+        static void GetRequiredServices(AZ::ComponentDescriptor::DependencyArrayType& required);
+        static void GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatible);
+
+        ~MachineLearningDebugSystemComponent() override = default;
+
+        //! AZ::Component overrides
+        //! @{
+        void Activate() override;
+        void Deactivate() override;
+        //! @}
+
+#ifdef IMGUI_ENABLED
+        void OnModelRegistryDisplay();
+        void OnModelTrainingDisplay();
+
+        //! ImGui::ImGuiUpdateListenerBus overrides
+        //! @{
+        void OnImGuiMainMenuUpdate() override;
+        void OnImGuiUpdate() override;
+        //! @}
+    private:
+
+        bool m_displayModelRegistry = false;
+        bool m_displayTrainingWindow = false;
+        MachineLearningDebugTrainingWindow m_trainingWindow;
+#endif
+    };
+}

+ 262 - 0
Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugTrainingWindow.cpp

@@ -0,0 +1,262 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <Source/Debug/MachineLearningDebugTrainingWindow.h>
+#include <Source/Assets/MnistDataLoader.h>
+#include <Source/Algorithms/Activations.h>
+#include <ImGuiContextScope.h>
+#include <ImGui/ImGuiPass.h>
+#include <imgui/imgui.h>
+#include <imgui/imgui_internal.h>
+
+namespace MachineLearning
+{
+#ifdef IMGUI_ENABLED
+    int32_t AZStdStringResizeCallback(ImGuiInputTextCallbackData* data)
+    {
+        if (data->EventFlag == ImGuiInputTextFlags_CallbackResize)
+        {
+            AZStd::string* azString = (AZStd::string*)data->UserData;
+            AZ_Assert(azString->begin() == data->Buf, "Invalid type");
+            azString->resize(data->BufSize);
+            data->Buf = azString->begin();
+        }
+        return 0;
+    }
+
+    void TextInputHelper(const char* label, AZStd::string& data)
+    {
+        ImGui::InputText(label, data.begin(), data.size(), ImGuiInputTextFlags_CallbackResize, AZStdStringResizeCallback, (void*)(&data));
+    }
+
+    MachineLearningDebugTrainingWindow::~MachineLearningDebugTrainingWindow()
+    {
+        for (auto iter : m_trainingInstances)
+        {
+            delete iter.second;
+        }
+    }
+
+    TrainingInstance* MachineLearningDebugTrainingWindow::RetrieveTrainingInstance(INeuralNetworkPtr modelPtr)
+    {
+        TrainingInstance* trainingInstance = m_trainingInstances[m_selectedModel];
+        if (trainingInstance == nullptr)
+        {
+            m_trainingInstances[m_selectedModel] = new TrainingInstance();
+            trainingInstance = m_trainingInstances[m_selectedModel];
+            trainingInstance->m_trainingCycle.m_model = m_selectedModel;
+            trainingInstance->m_testHistogram.Init("Test Cost", 250, ImGui::LYImGuiUtils::HistogramContainer::ViewType::Histogram, true, 0.0f, 0.2f, ImGui::LYImGuiUtils::HistogramContainer::AutoExpand);
+            trainingInstance->m_testHistogram.SetMoveDirection(ImGui::LYImGuiUtils::HistogramContainer::PushRightMoveLeft);
+            trainingInstance->m_trainHistogram.Init("Train Cost", 250, ImGui::LYImGuiUtils::HistogramContainer::ViewType::Histogram, true, 0.0f, 0.2f, ImGui::LYImGuiUtils::HistogramContainer::AutoExpand);
+            trainingInstance->m_trainHistogram.SetMoveDirection(ImGui::LYImGuiUtils::HistogramContainer::PushRightMoveLeft);
+            trainingInstance->m_testDataName = m_selectedModel->GetAssetFile(AssetTypes::TestData);
+            trainingInstance->m_testLabelName = m_selectedModel->GetAssetFile(AssetTypes::TestLabels);
+            trainingInstance->m_trainDataName = m_selectedModel->GetAssetFile(AssetTypes::TrainingData);
+            trainingInstance->m_trainLabelName = m_selectedModel->GetAssetFile(AssetTypes::TrainingLabels);
+        }
+        return trainingInstance;
+    }
+
+    void MachineLearningDebugTrainingWindow::LoadTestTrainData(TrainingInstance* trainingInstance)
+    {
+        if (trainingInstance->m_trainingCycle.m_trainingData == nullptr)
+        {
+            trainingInstance->m_trainingCycle.m_trainingData = AZStd::make_shared<MnistDataLoader>();
+            trainingInstance->m_trainingCycle.m_trainingData->LoadArchive(trainingInstance->m_trainDataName, trainingInstance->m_trainLabelName);
+        }
+        if (trainingInstance->m_trainingCycle.m_testData == nullptr)
+        {
+            trainingInstance->m_trainingCycle.m_testData = AZStd::make_shared<MnistDataLoader>();
+            trainingInstance->m_trainingCycle.m_testData->LoadArchive(trainingInstance->m_testDataName, trainingInstance->m_testLabelName);
+        }
+    }
+
+    void MachineLearningDebugTrainingWindow::RecalculateAccuracy(TrainingInstance* trainingInstance, ILabeledTrainingDataPtr data)
+    {
+        trainingInstance->m_trainingCycle.InitializeContexts();
+        trainingInstance->m_totalSamples = static_cast<int32_t>(data->GetSampleCount());
+        trainingInstance->m_correctPredictions = 0;
+        trainingInstance->m_incorrectPredictions = 0;
+        for (int32_t iter = 0; iter < trainingInstance->m_totalSamples; ++iter)
+        {
+            const AZ::VectorN& activations = data->GetDataByIndex(iter);
+            const AZStd::size_t label = data->GetLabelAsValueByIndex(iter);
+            const AZ::VectorN* output = m_selectedModel->Forward(trainingInstance->m_trainingCycle.m_inferenceContext.get(), activations);
+            AZStd::size_t prediction = ArgMaxDecode(*output);
+            if (label == prediction)
+            {
+                ++trainingInstance->m_correctPredictions;
+            }
+            else
+            {
+                ++trainingInstance->m_incorrectPredictions;
+            }
+        }
+    }
+
+    void MachineLearningDebugTrainingWindow::OnImGuiUpdate()
+    {
+        const float TEXT_BASE_WIDTH = ImGui::CalcTextSize("A").x;
+
+        const ImGuiTableFlags flags = ImGuiTableFlags_BordersV
+            | ImGuiTableFlags_BordersOuterH
+            | ImGuiTableFlags_Resizable
+            | ImGuiTableFlags_RowBg
+            | ImGuiTableFlags_NoBordersInBody;
+
+        const ImGuiTreeNodeFlags nodeFlags = (ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth);
+
+        IMachineLearning* machineLearning = MachineLearningInterface::Get();
+        const ModelSet& modelSet = machineLearning->GetModelSet();
+
+        ImGui::PushStyleVar(ImGuiStyleVar_ItemSpacing, ImVec2(0.0f, 0.0f));
+        ImGui::BeginChild("LeftPanel", ImVec2(m_trainingSplitWidth, -1.0f), true);
+
+        if (ImGui::BeginTable("Models", 1, flags))
+        {
+            ImGui::TableSetupColumn("Name", ImGuiTableColumnFlags_WidthStretch);
+            ImGui::TableHeadersRow();
+
+            AZStd::size_t index = 0;
+            for (auto& neuralNetwork : modelSet)
+            {
+                ImGui::TableNextRow();
+                ImGui::TableNextColumn();
+                const bool isSelected = (m_selectedModelIndex == index);
+                if (ImGui::Selectable(neuralNetwork->GetName().c_str(), isSelected))
+                {
+                    m_selectedModel = neuralNetwork;
+                    m_selectedModelIndex = index;
+                }
+                ++index;
+            }
+            ImGui::EndTable();
+            ImGui::NewLine();
+        }
+
+        ImGui::EndChild();
+        ImGui::SameLine();
+        ImGui::InvisibleButton("vsplitter", ImVec2(8.0f, -1.0f));
+        if (ImGui::IsItemActive())
+        {
+            m_trainingSplitWidth += ImGui::GetIO().MouseDelta.x;
+        }
+        ImGui::SameLine();
+        ImGui::BeginChild("RightPanel", ImVec2(0.0f, -1.0f), true);
+
+        if (m_selectedModel != nullptr)
+        {
+            ImGui::PushStyleVar(ImGuiStyleVar_ItemSpacing, ImVec2(8.0f, 4.0f));
+            TrainingInstance* trainingInstance = RetrieveTrainingInstance(m_selectedModel);
+
+            float currentTestCost = 0.0f;
+            float currentTrainCost = 0.0f;
+            if (trainingInstance->m_trainingCycle.m_testData != nullptr && trainingInstance->m_trainingCycle.m_trainingData != nullptr)
+            {
+                currentTestCost = trainingInstance->m_trainingCycle.ComputeCurrentCost(trainingInstance->m_trainingCycle.m_testData, trainingInstance->m_trainingCycle.m_costFunction, m_costSampleSize);
+                currentTrainCost = trainingInstance->m_trainingCycle.ComputeCurrentCost(trainingInstance->m_trainingCycle.m_trainingData, trainingInstance->m_trainingCycle.m_costFunction, m_costSampleSize);
+            }
+
+            if (!trainingInstance->m_trainingCycle.m_trainingComplete)
+            {
+                trainingInstance->m_testHistogram.PushValue(currentTestCost);
+                trainingInstance->m_trainHistogram.PushValue(currentTrainCost);
+                if (ImGui::Button("Stop training"))
+                {
+                    trainingInstance->m_trainingCycle.StopTraining();
+                }
+                ImGui::SameLine();
+                int32_t epoch = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_currentEpoch);
+                ImGui::Text("Epoch: %d", epoch);
+            }
+            else
+            {
+                if (ImGui::Button("Start training"))
+                {
+                    LoadTestTrainData(trainingInstance);
+                    trainingInstance->m_trainingCycle.StartTraining();
+                }
+                ImGui::SameLine();
+                if (ImGui::Button("Save"))
+                {
+                    m_selectedModel->SaveModel();
+                }
+                ImGui::SameLine();
+                if (ImGui::Button("Load"))
+                {
+                    m_selectedModel->LoadModel();
+                }
+                ImGui::SameLine();
+                if (ImGui::Button("Recalculate accuracy on test data"))
+                {
+                    LoadTestTrainData(trainingInstance);
+                    RecalculateAccuracy(trainingInstance, trainingInstance->m_trainingCycle.m_testData);
+                }
+                ImGui::SameLine();
+                if (ImGui::Button("Recalculate accuracy on training data"))
+                {
+                    LoadTestTrainData(trainingInstance);
+                    RecalculateAccuracy(trainingInstance, trainingInstance->m_trainingCycle.m_trainingData);
+                }
+            }
+ 
+            ImGui::Text("Model Name: %s", m_selectedModel->GetName().c_str());
+            ImGui::NewLine();
+            ImGui::Text("Asset location: %s", m_selectedModel->GetAssetFile(AssetTypes::Model).c_str());
+            ImGui::NewLine();
+
+            ImGui::Text("Total samples: %d", trainingInstance->m_totalSamples);
+            ImGui::Text("Correct predictions: %d", trainingInstance->m_correctPredictions);
+            ImGui::Text("Incorrect predictions: %d", trainingInstance->m_incorrectPredictions);
+
+            const float accuracy = (static_cast<float>(trainingInstance->m_correctPredictions) * 100.0f) / static_cast<float>(trainingInstance->m_totalSamples);
+            ImGui::Text("Accuracy: %f", accuracy);
+
+            ImGui::Text("Test score: %f", currentTestCost);
+            trainingInstance->m_testHistogram.Draw(ImGui::GetColumnWidth(), 200.0f);
+            ImGui::Text("Train score: %f", currentTrainCost);
+            trainingInstance->m_trainHistogram.Draw(ImGui::GetColumnWidth(), 200.0f);
+            ImGui::SliderInt("Cost evaluation sample size", &m_costSampleSize, 10, 10000);
+            ImGui::NewLine();
+
+            TextInputHelper("Test data asset file", trainingInstance->m_testDataName);
+            TextInputHelper("Test data label file", trainingInstance->m_testLabelName);
+            ImGui::NewLine();
+            TextInputHelper("Train data asset file", trainingInstance->m_trainDataName);
+            TextInputHelper("Train data label file", trainingInstance->m_trainLabelName);
+            ImGui::NewLine();
+
+            //AZStd::fixed_string<64> valueString;
+            //valueString = AZStd::fixed_string<64>::format("%0.3f", trainingInstance->m_trainingCycle.m_learningRate);
+            //float logValue = log(trainingInstance->m_trainingCycle.m_learningRate);
+            //ImGui::SliderFloat("LearningRate", &logValue, log(0.0001f), log(1.0f), valueString.c_str());
+            //trainingInstance->m_trainingCycle.m_learningRate = exp(logValue);
+
+            ImGui::SliderFloat("LearningRate", &trainingInstance->m_trainingCycle.m_learningRate, 0.0f, 0.1f);
+            ImGui::SliderFloat("LearningRateDecay", &trainingInstance->m_trainingCycle.m_learningRateDecay, 0.0f, 1.0f);
+            ImGui::NewLine();
+
+            int32_t batchSize = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_batchSize);
+            ImGui::SliderInt("Batch size", &batchSize, 1, 1000);
+            trainingInstance->m_trainingCycle.m_batchSize = batchSize;
+            int32_t totalIterations = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_totalIterations);
+            ImGui::SliderInt("Number of iterations", &totalIterations, 1, 1000);
+            trainingInstance->m_trainingCycle.m_totalIterations = totalIterations;
+
+            int32_t costMetric = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_costFunction);
+            ImGui::Combo("Cost metric", &costMetric, "MeanSquaredError\0");
+            trainingInstance->m_trainingCycle.m_costFunction = static_cast<LossFunctions>(costMetric);
+
+            ImGui::PopStyleVar();
+        }
+
+        ImGui::EndChild();
+        ImGui::PopStyleVar();
+    }
+#endif
+}

+ 65 - 0
Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugTrainingWindow.h

@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#pragma once
+
+#include <AzCore/Component/Component.h>
+#include <AzCore/Interface/Interface.h>
+#include <AzCore/std/containers/map.h>
+#include <MachineLearning/IMachineLearning.h>
+#include <Algorithms/Training.h>
+
+#ifdef IMGUI_ENABLED
+#   include <imgui/imgui.h>
+#   include <ImGuiBus.h>
+#   include <LYImGuiUtils/HistogramContainer.h>
+#endif
+
+namespace MachineLearning
+{
+    struct TrainingInstance
+    {
+        SupervisedLearningCycle m_trainingCycle;
+
+        AZStd::string m_testDataName;
+        AZStd::string m_testLabelName;
+
+        AZStd::string m_trainDataName;
+        AZStd::string m_trainLabelName;
+
+        int32_t m_totalSamples = 0;
+        int32_t m_correctPredictions = 0;
+        int32_t m_incorrectPredictions = 0;
+
+#ifdef IMGUI_ENABLED
+        ImGui::LYImGuiUtils::HistogramContainer m_testHistogram;
+        ImGui::LYImGuiUtils::HistogramContainer m_trainHistogram;
+#endif
+    };
+
+    class MachineLearningDebugTrainingWindow
+    {
+    public:
+        ~MachineLearningDebugTrainingWindow();
+
+        TrainingInstance* RetrieveTrainingInstance(INeuralNetworkPtr modelPtr);
+        void LoadTestTrainData(TrainingInstance* trainingInstance);
+        void RecalculateAccuracy(TrainingInstance* trainingInstance, ILabeledTrainingDataPtr data);
+
+#ifdef IMGUI_ENABLED
+        void OnImGuiUpdate();
+#endif
+
+        AZStd::size_t m_selectedModelIndex = 0;
+        INeuralNetworkPtr m_selectedModel = nullptr;
+        float m_trainingSplitWidth = 400.0f;
+        int32_t m_costSampleSize = 1000;
+
+        AZStd::map<INeuralNetworkPtr, TrainingInstance*> m_trainingInstances;
+    };
+}

+ 17 - 6
Gems/MachineLearning/Code/Source/MachineLearningSystemComponent.cpp

@@ -54,11 +54,7 @@ namespace MachineLearning
                 Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)->
                 Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)->
                 Constructor<>()->
                 Constructor<>()->
                 Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)->
                 Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)->
-                Method("AddLayer", &INeuralNetwork::AddLayer)->
-                Method("GetLayerCount", &INeuralNetwork::GetLayerCount)->
-                Method("GetLayer", &INeuralNetwork::GetLayer)->
-                Method("Forward", &INeuralNetwork::Forward)->
-                Method("Reverse", &INeuralNetwork::Reverse)
+                Method("GetName", &INeuralNetwork::GetName)
                 ;
                 ;
         }
         }
 
 
@@ -114,4 +110,19 @@ namespace MachineLearning
     {
     {
         MachineLearningRequestBus::Handler::BusDisconnect();
         MachineLearningRequestBus::Handler::BusDisconnect();
     }
     }
-} // namespace MachineLearning
+
+    void MachineLearningSystemComponent::RegisterModel(INeuralNetworkPtr model)
+    {
+        m_registeredModels.emplace(model);
+    }
+
+    void MachineLearningSystemComponent::UnregisterModel(INeuralNetworkPtr model)
+    {
+        m_registeredModels.erase(model);
+    }
+
+    ModelSet& MachineLearningSystemComponent::GetModelSet()
+    {
+        return m_registeredModels;
+    }
+}

+ 17 - 10
Gems/MachineLearning/Code/Source/MachineLearningSystemComponent.h

@@ -9,13 +9,14 @@
 #pragma once
 #pragma once
 
 
 #include <AzCore/Component/Component.h>
 #include <AzCore/Component/Component.h>
-#include <MachineLearning/MachineLearningBus.h>
+#include <MachineLearning/IMachineLearning.h>
 
 
 namespace MachineLearning
 namespace MachineLearning
 {
 {
     class MachineLearningSystemComponent
     class MachineLearningSystemComponent
         : public AZ::Component
         : public AZ::Component
         , protected MachineLearningRequestBus::Handler
         , protected MachineLearningRequestBus::Handler
+//        , public AZ::Interface<IMachineLearning>::Registrar
     {
     {
     public:
     public:
         AZ_COMPONENT_DECL(MachineLearningSystemComponent);
         AZ_COMPONENT_DECL(MachineLearningSystemComponent);
@@ -31,17 +32,23 @@ namespace MachineLearning
         ~MachineLearningSystemComponent();
         ~MachineLearningSystemComponent();
 
 
     protected:
     protected:
-        ////////////////////////////////////////////////////////////////////////
-        // MachineLearningRequestBus interface implementation
 
 
-        ////////////////////////////////////////////////////////////////////////
-
-        ////////////////////////////////////////////////////////////////////////
-        // AZ::Component interface implementation
+        //! AZ::Component interface
+        //! @{
         void Init() override;
         void Init() override;
         void Activate() override;
         void Activate() override;
         void Deactivate() override;
         void Deactivate() override;
-        ////////////////////////////////////////////////////////////////////////
-    };
+        //! @}
+
+        //! IMachineLearning interface
+        //! @{
+        void RegisterModel(INeuralNetworkPtr model) override;
+        void UnregisterModel(INeuralNetworkPtr model) override;
+        ModelSet& GetModelSet() override;
+        //! @}
 
 
-} // namespace MachineLearning
+    private:
+
+        ModelSet m_registeredModels;
+    };
+}

+ 63 - 26
Gems/MachineLearning/Code/Source/Models/Layer.cpp

@@ -13,6 +13,7 @@
 #include <AzCore/RTTI/BehaviorContext.h>
 #include <AzCore/RTTI/BehaviorContext.h>
 #include <AzCore/Serialization/EditContext.h>
 #include <AzCore/Serialization/EditContext.h>
 #include <AzCore/Serialization/SerializeContext.h>
 #include <AzCore/Serialization/SerializeContext.h>
+#include <random>
 
 
 namespace MachineLearning
 namespace MachineLearning
 {
 {
@@ -65,64 +66,100 @@ namespace MachineLearning
         OnSizesChanged();
         OnSizesChanged();
     }
     }
 
 
-    const AZ::VectorN& Layer::Forward(const AZ::VectorN& activations)
+    const AZ::VectorN& Layer::Forward(LayerInferenceData& inferenceData, const AZ::VectorN& activations)
     {
     {
-        m_lastInput = activations;
-        m_output = m_biases;
-        AZ::VectorMatrixMultiply(m_weights, m_lastInput, m_output);
-        Activate(m_activationFunction, m_output, m_output);
-        return m_output;
+        inferenceData.m_output = m_biases;
+        AZ::VectorMatrixMultiply(m_weights, activations, inferenceData.m_output);
+        Activate(m_activationFunction, inferenceData.m_output, inferenceData.m_output);
+        return inferenceData.m_output;
     }
     }
 
 
-    void Layer::AccumulateGradients(const AZ::VectorN& previousLayerGradients)
+    void Layer::AccumulateGradients(LayerTrainingData& trainingData, LayerInferenceData& inferenceData, const AZ::VectorN& previousLayerGradients)
     {
     {
         // Ensure our bias gradient vector is appropriately sized
         // Ensure our bias gradient vector is appropriately sized
-        if (m_biasGradients.GetDimensionality() != m_outputSize)
+        if (trainingData.m_biasGradients.GetDimensionality() != m_outputSize)
         {
         {
-            m_biasGradients = AZ::VectorN::CreateZero(m_outputSize);
+            trainingData.m_biasGradients = AZ::VectorN::CreateZero(m_outputSize);
         }
         }
 
 
         // Ensure our weight gradient matrix is appropriately sized
         // Ensure our weight gradient matrix is appropriately sized
-        if ((m_weightGradients.GetRowCount() != m_outputSize) || (m_weightGradients.GetColumnCount() != m_inputSize))
+        if ((trainingData.m_weightGradients.GetRowCount() != m_outputSize) || (trainingData.m_weightGradients.GetColumnCount() != m_inputSize))
         {
         {
-            m_weightGradients = AZ::MatrixMxN::CreateZero(m_outputSize, m_inputSize);
+            trainingData.m_weightGradients = AZ::MatrixMxN::CreateZero(m_outputSize, m_inputSize);
         }
         }
 
 
         // Ensure our backpropagation gradient vector is appropriately sized
         // Ensure our backpropagation gradient vector is appropriately sized
-        if (m_backpropagationGradients.GetDimensionality() != m_inputSize)
+        if (trainingData.m_backpropagationGradients.GetDimensionality() != m_inputSize)
         {
         {
-            m_backpropagationGradients = AZ::VectorN::CreateZero(m_inputSize);
+            trainingData.m_backpropagationGradients = AZ::VectorN::CreateZero(m_inputSize);
         }
         }
 
 
         // Compute the partial derivatives of the output with respect to the activation function
         // Compute the partial derivatives of the output with respect to the activation function
-        Activate_Derivative(m_activationFunction, m_output, previousLayerGradients, m_activationGradients);
+        Activate_Derivative(m_activationFunction, inferenceData.m_output, previousLayerGradients, trainingData.m_activationGradients);
 
 
         // Accumulate the partial derivatives of the weight matrix with respect to the loss function
         // Accumulate the partial derivatives of the weight matrix with respect to the loss function
-        AZ::OuterProduct(m_activationGradients, m_lastInput, m_weightGradients);
+        AZ::OuterProduct(trainingData.m_activationGradients, *trainingData.m_lastInput, trainingData.m_weightGradients);
 
 
         // Accumulate the partial derivatives of the bias vector with respect to the loss function
         // Accumulate the partial derivatives of the bias vector with respect to the loss function
-        m_biasGradients += m_activationGradients;
+        trainingData.m_biasGradients += trainingData.m_activationGradients;
 
 
         // Accumulate the gradients to pass to the preceding layer for back-propagation
         // Accumulate the gradients to pass to the preceding layer for back-propagation
-        AZ::VectorMatrixMultiplyLeft(m_activationGradients, m_weights, m_backpropagationGradients);
+        AZ::VectorMatrixMultiplyLeft(trainingData.m_activationGradients, m_weights, trainingData.m_backpropagationGradients);
     }
     }
 
 
-    void Layer::ApplyGradients(float learningRate)
+    void Layer::ApplyGradients(LayerTrainingData& trainingData, float learningRate)
     {
     {
-        m_weights -= m_weightGradients * learningRate;
-        m_biases -= m_biasGradients * learningRate;
+        m_weights -= trainingData.m_weightGradients * learningRate;
+        m_biases -= trainingData.m_biasGradients * learningRate;
 
 
-        m_biasGradients.SetZero();
-        m_weightGradients.SetZero();
-        m_backpropagationGradients.SetZero();
+        trainingData.m_biasGradients.SetZero();
+        trainingData.m_weightGradients.SetZero();
+        trainingData.m_backpropagationGradients.SetZero();
+    }
+
+    bool Layer::Serialize(AzNetworking::ISerializer& serializer)
+    {
+        return serializer.Serialize(m_inputSize, "inputSize")
+            && serializer.Serialize(m_outputSize, "outputSize")
+            && serializer.Serialize(m_weights, "weights")
+            && serializer.Serialize(m_biases, "biases")
+            && serializer.Serialize(m_activationFunction, "activationFunction");
+    }
+
+    AZStd::size_t Layer::EstimateSerializeSize() const
+    {
+        const AZStd::size_t padding = 64; // 64 bytes of extra padding just in case
+        return padding
+             + sizeof(m_inputSize)
+             + sizeof(m_outputSize)
+             + sizeof(AZStd::size_t) // for m_weights row count
+             + sizeof(AZStd::size_t) // for m_weights column count
+             + sizeof(AZStd::size_t) // for m_weights vector size
+             + sizeof(float) * m_outputSize * m_inputSize // m_weights buffer
+             + sizeof(AZStd::size_t) // for m_biases dimensionality
+             + sizeof(AZStd::size_t) // for m_biases vector size
+             + sizeof(float) * m_outputSize // m_biases buffer
+             + sizeof(m_activationFunction);
     }
     }
 
 
     void Layer::OnSizesChanged()
     void Layer::OnSizesChanged()
     {
     {
-        m_weights = AZ::MatrixMxN::CreateRandom(m_outputSize, m_inputSize);
-        m_weights -= 0.5f; // It's preferable for efficient training to keep initial weights centered around zero
+        // Specifically for ReLU, we use Kaiming He initialization as this is proven optimal for convergence
+        // For other activation functions we just use a standard normal distribution
+        // std::normal_distribution expects a standard deviation; He (ReLU) and Xavier
+        // initialization specify variances of 2/fan_in and 1/fan_in respectively, so take the square root.
+        float standardDeviation = (m_activationFunction == ActivationFunctions::ReLU) ? std::sqrt(2.0f / m_inputSize)
+                                                                                      : std::sqrt(1.0f / m_inputSize);
+        std::random_device rd{};
+        std::mt19937 gen{ rd() };
+        auto dist = std::normal_distribution<float>{ 0.0f, standardDeviation };
+        m_weights.Resize(m_outputSize, m_inputSize);
+        for (AZStd::size_t row = 0; row < m_weights.GetRowCount(); ++row)
+        {
+            for (AZStd::size_t col = 0; col < m_weights.GetColumnCount(); ++col)
+            {
+                m_weights.SetElement(row, col, dist(gen));
+            }
+        }
 
 
         m_biases = AZ::VectorN(m_outputSize, 0.01f);
         m_biases = AZ::VectorN(m_outputSize, 0.01f);
-        m_output = AZ::VectorN::CreateZero(m_outputSize);
     }
     }
 }
 }

+ 27 - 5
Gems/MachineLearning/Code/Source/Models/Layer.h

@@ -9,10 +9,15 @@
 #pragma once
 #pragma once
 
 
 #include <AzCore/Math/MatrixMxN.h>
 #include <AzCore/Math/MatrixMxN.h>
+#include <AzNetworking/Serialization/ISerializer.h>
 #include <MachineLearning/INeuralNetwork.h>
 #include <MachineLearning/INeuralNetwork.h>
 
 
 namespace MachineLearning
 namespace MachineLearning
 {
 {
+    // We separate out inference and training data to make multithreading models easier and more efficient.
+    struct LayerInferenceData;
+    struct LayerTrainingData;
+
     //! A class representing a single layer within a neural network.
     //! A class representing a single layer within a neural network.
     class Layer
     class Layer
     {
     {
@@ -34,14 +39,22 @@ namespace MachineLearning
         Layer& operator=(const Layer&) = default;
         Layer& operator=(const Layer&) = default;
 
 
         //! Performs a basic forward pass on this layer, outputs are stored in m_output.
         //! Performs a basic forward pass on this layer, outputs are stored in m_output.
-        const AZ::VectorN& Forward(const AZ::VectorN& activations);
+        const AZ::VectorN& Forward(LayerInferenceData& inferenceData, const AZ::VectorN& activations);
 
 
         //! Performs a gradient computation against the provided expected output using the provided gradients from the previous layer.
         //! Performs a gradient computation against the provided expected output using the provided gradients from the previous layer.
         //! This method presumes that we've completed a forward pass immediately prior to fill all the relevant vectors
         //! This method presumes that we've completed a forward pass immediately prior to fill all the relevant vectors
-        void AccumulateGradients(const AZ::VectorN& expected);
+        void AccumulateGradients(LayerTrainingData& trainingData, LayerInferenceData& inferenceData, const AZ::VectorN& expected);
 
 
         //! Applies the current gradient values to the layers weights and biases and resets the gradient values for a new accumulation pass.
         //! Applies the current gradient values to the layers weights and biases and resets the gradient values for a new accumulation pass.
-        void ApplyGradients(float learningRate);
+        void ApplyGradients(LayerTrainingData& trainingData, float learningRate);
+
+        //! Base serialize method for all serializable structures or classes to implement.
+        //! @param serializer ISerializer instance to use for serialization
+        //! @return boolean true for success, false for serialization failure
+        bool Serialize(AzNetworking::ISerializer& serializer);
+
+        //! Returns the estimated size required to serialize this layer.
+        AZStd::size_t EstimateSerializeSize() const;
 
 
         //! Updates layer internals for it's requested dimensionalities.
         //! Updates layer internals for its requested dimensionalities.
         void OnSizesChanged();
         void OnSizesChanged();
@@ -51,11 +64,20 @@ namespace MachineLearning
         AZStd::size_t m_outputSize = 0;
         AZStd::size_t m_outputSize = 0;
         AZ::MatrixMxN m_weights;
         AZ::MatrixMxN m_weights;
         AZ::VectorN m_biases;
         AZ::VectorN m_biases;
-        AZ::VectorN m_output;
         ActivationFunctions m_activationFunction = ActivationFunctions::ReLU;
         ActivationFunctions m_activationFunction = ActivationFunctions::ReLU;
+    };
 
 
+    //! These values are written to during inference.
+    struct LayerInferenceData
+    {
+        AZ::VectorN m_output;
+    };
+
+    //! These values are read and written during training.
+    struct LayerTrainingData
+    {
         // These values will only be populated if backward propagation is performed
         // These values will only be populated if backward propagation is performed
-        AZ::VectorN m_lastInput;
+        const AZ::VectorN* m_lastInput;
         AZ::VectorN m_activationGradients;
         AZ::VectorN m_activationGradients;
         AZ::VectorN m_biasGradients;
         AZ::VectorN m_biasGradients;
         AZ::MatrixMxN m_weightGradients;
         AZ::MatrixMxN m_weightGradients;

+ 229 - 30
Gems/MachineLearning/Code/Source/Models/MultilayerPerceptron.cpp

@@ -12,6 +12,12 @@
 #include <AzCore/RTTI/BehaviorContext.h>
 #include <AzCore/RTTI/BehaviorContext.h>
 #include <AzCore/Serialization/EditContext.h>
 #include <AzCore/Serialization/EditContext.h>
 #include <AzCore/Serialization/SerializeContext.h>
 #include <AzCore/Serialization/SerializeContext.h>
+#include <AzCore/IO/FileIO.h>
+#include <AzCore/IO/FileReader.h>
+#include <AzCore/IO/Path/Path.h>
+#include <AzCore/Console/ILogger.h>
+#include <AzNetworking/Serialization/NetworkInputSerializer.h>
+#include <AzNetworking/Serialization/NetworkOutputSerializer.h>
 
 
 namespace MachineLearning
 namespace MachineLearning
 {
 {
@@ -21,6 +27,12 @@ namespace MachineLearning
         {
         {
             serializeContext->Class<MultilayerPerceptron>()
             serializeContext->Class<MultilayerPerceptron>()
                 ->Version(1)
                 ->Version(1)
+                ->Field("Name", &MultilayerPerceptron::m_name)
+                ->Field("ModelFile", &MultilayerPerceptron::m_modelFile)
+                ->Field("TestDataFile", &MultilayerPerceptron::m_testDataFile)
+                ->Field("TestLabelFile", &MultilayerPerceptron::m_testLabelFile)
+                ->Field("TrainDataFile", &MultilayerPerceptron::m_trainDataFile)
+                ->Field("TrainLabelFile", &MultilayerPerceptron::m_trainLabelFile)
                 ->Field("ActivationCount", &MultilayerPerceptron::m_activationCount)
                 ->Field("ActivationCount", &MultilayerPerceptron::m_activationCount)
                 ->Field("Layers", &MultilayerPerceptron::m_layers)
                 ->Field("Layers", &MultilayerPerceptron::m_layers)
                 ;
                 ;
@@ -29,6 +41,12 @@ namespace MachineLearning
             {
             {
                 editContext->Class<MultilayerPerceptron>("A basic multilayer perceptron class", "")
                 editContext->Class<MultilayerPerceptron>("A basic multilayer perceptron class", "")
                     ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
                     ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
+                    ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_name, "Name", "The name for this model")
+                    ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_modelFile, "ModelFile", "The file this model is saved to and loaded from")
+                    ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_testDataFile, "TestDataFile", "The file test data should be loaded from")
+                    ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_testLabelFile, "TestLabelFile", "The file test labels should be loaded from")
+                    ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_trainDataFile, "TrainDataFile", "The file training data should be loaded from")
+                    ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_trainLabelFile, "TrainLabelFile", "The file training labels should be loaded from")
                     ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_activationCount, "Activation Count", "The number of neurons in the activation layer")
                     ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_activationCount, "Activation Count", "The number of neurons in the activation layer")
                     ->Attribute(AZ::Edit::Attributes::ChangeNotify, &MultilayerPerceptron::OnActivationCountChanged)
                     ->Attribute(AZ::Edit::Attributes::ChangeNotify, &MultilayerPerceptron::OnActivationCountChanged)
                     ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_layers, "Layers", "The layers of the neural network")
                     ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_layers, "Layers", "The layers of the neural network")
@@ -46,44 +64,99 @@ namespace MachineLearning
                 Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)->
                 Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)->
                 Constructor<AZStd::size_t>()->
                 Constructor<AZStd::size_t>()->
                 Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)->
                 Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)->
-                Method("AddLayer", &MultilayerPerceptron::AddLayer)->
+                Method("GetName", &MultilayerPerceptron::GetName)->
                 Method("GetLayerCount", &MultilayerPerceptron::GetLayerCount)->
                 Method("GetLayerCount", &MultilayerPerceptron::GetLayerCount)->
-                Method("GetLayer", &MultilayerPerceptron::GetLayer)->
-                Method("Forward", &MultilayerPerceptron::Forward)->
-                Method("Reverse", &MultilayerPerceptron::Reverse)->
                 Property("ActivationCount", BehaviorValueProperty(&MultilayerPerceptron::m_activationCount))->
                 Property("ActivationCount", BehaviorValueProperty(&MultilayerPerceptron::m_activationCount))->
                 Property("Layers", BehaviorValueProperty(&MultilayerPerceptron::m_layers))
                 Property("Layers", BehaviorValueProperty(&MultilayerPerceptron::m_layers))
                 ;
                 ;
         }
         }
     }
     }
 
 
+    MultilayerPerceptron::MultilayerPerceptron()
+    {
+    }
+
+    MultilayerPerceptron::MultilayerPerceptron(const MultilayerPerceptron& rhs)
+        : m_name(rhs.m_name)
+        , m_modelFile(rhs.m_modelFile)
+        , m_testDataFile(rhs.m_testDataFile)
+        , m_testLabelFile(rhs.m_testLabelFile)
+        , m_trainDataFile(rhs.m_trainDataFile)
+        , m_trainLabelFile(rhs.m_trainLabelFile)
+        , m_activationCount(rhs.m_activationCount)
+        , m_layers(rhs.m_layers)
+    {
+    }
+
     MultilayerPerceptron::MultilayerPerceptron(AZStd::size_t activationCount)
     MultilayerPerceptron::MultilayerPerceptron(AZStd::size_t activationCount)
         : m_activationCount(activationCount)
         : m_activationCount(activationCount)
     {
     {
     }
     }
 
 
-    void MultilayerPerceptron::AddLayer(AZStd::size_t layerDimensionality, ActivationFunctions activationFunction)
+    MultilayerPerceptron::~MultilayerPerceptron()
     {
     {
-        AZStd::size_t lastLayerDimensionality = m_activationCount;
-        if (!m_layers.empty())
+    }
+
+    MultilayerPerceptron& MultilayerPerceptron::operator=(const MultilayerPerceptron& rhs)
+    {
+        m_name = rhs.m_name;
+        m_modelFile = rhs.m_modelFile;
+        m_testDataFile = rhs.m_testDataFile;
+        m_testLabelFile = rhs.m_testLabelFile;
+        m_trainDataFile = rhs.m_trainDataFile;
+        m_trainLabelFile = rhs.m_trainLabelFile;
+        m_activationCount = rhs.m_activationCount;
+        m_layers = rhs.m_layers;
+        return *this;
+    }
+
+    AZStd::string MultilayerPerceptron::GetName() const
+    {
+        return m_name;
+    }
+
+    AZStd::string MultilayerPerceptron::GetAssetFile(AssetTypes assetType) const
+    {
+        switch (assetType)
         {
         {
-            lastLayerDimensionality = m_layers.back().m_biases.GetDimensionality();
+        case AssetTypes::Model:
+            return m_modelFile;
+        case AssetTypes::TestData:
+            return m_testDataFile;
+        case AssetTypes::TestLabels:
+            return m_testLabelFile;
+        case AssetTypes::TrainingData:
+            return m_trainDataFile;
+        case AssetTypes::TrainingLabels:
+            return m_trainLabelFile;
         }
         }
-        m_layers.push_back(AZStd::move(Layer(activationFunction, lastLayerDimensionality, layerDimensionality)));
+        return "";
     }
     }
 
 
-    AZStd::size_t MultilayerPerceptron::GetLayerCount() const
+    AZStd::size_t MultilayerPerceptron::GetInputDimensionality() const
     {
     {
-        return m_layers.size();
+        return m_activationCount;
     }
     }
 
 
-    Layer* MultilayerPerceptron::GetLayer(AZStd::size_t layerIndex)
+    AZStd::size_t MultilayerPerceptron::GetOutputDimensionality() const
     {
     {
-        return &m_layers[layerIndex];
+        //AZStd::lock_guard lock(m_mutex);
+        if (!m_layers.empty())
+        {
+            return m_layers.back().m_biases.GetDimensionality();
+        }
+        return m_activationCount;
+    }
+
+    AZStd::size_t MultilayerPerceptron::GetLayerCount() const
+    {
+        //AZStd::lock_guard lock(m_mutex);
+        return m_layers.size();
     }
     }
 
 
     AZStd::size_t MultilayerPerceptron::GetParameterCount() const
     AZStd::size_t MultilayerPerceptron::GetParameterCount() const
     {
     {
+        //AZStd::lock_guard lock(m_mutex);
         AZStd::size_t parameterCount = 0;
         AZStd::size_t parameterCount = 0;
         for (const Layer& layer : m_layers)
         for (const Layer& layer : m_layers)
         {
         {
@@ -92,50 +165,81 @@ namespace MachineLearning
         return parameterCount;
         return parameterCount;
     }
     }
 
 
-    const AZ::VectorN* MultilayerPerceptron::Forward(const AZ::VectorN& activations)
+    IInferenceContextPtr MultilayerPerceptron::CreateInferenceContext()
     {
     {
+        return new MlpInferenceContext();
+    }
+
+    ITrainingContextPtr MultilayerPerceptron::CreateTrainingContext()
+    {
+        return new MlpTrainingContext();
+    }
+
+    const AZ::VectorN* MultilayerPerceptron::Forward(IInferenceContextPtr context, const AZ::VectorN& activations)
+    {
+        //AZStd::lock_guard lock(m_mutex);
+        MlpInferenceContext* forwardContext = static_cast<MlpInferenceContext*>(context);
+        forwardContext->m_layerData.resize(m_layers.size());
+
         const AZ::VectorN* lastLayerOutput = &activations;
         const AZ::VectorN* lastLayerOutput = &activations;
-        for (Layer& layer : m_layers)
+        for (AZStd::size_t iter = 0; iter < m_layers.size(); ++iter)
         {
         {
-            layer.Forward(*lastLayerOutput);
-            lastLayerOutput = &layer.m_output;
+            m_layers[iter].Forward(forwardContext->m_layerData[iter], *lastLayerOutput);
+            lastLayerOutput = &forwardContext->m_layerData[iter].m_output;
         }
         }
         return lastLayerOutput;
         return lastLayerOutput;
     }
     }
 
 
-    void MultilayerPerceptron::Reverse(LossFunctions lossFunction, const AZ::VectorN& activations, const AZ::VectorN& expected)
+    void MultilayerPerceptron::Reverse(ITrainingContextPtr context, LossFunctions lossFunction, const AZ::VectorN& activations, const AZ::VectorN& expected)
     {
     {
-        ++m_trainingSampleSize;
+        //AZStd::lock_guard lock(m_mutex);
+        MlpTrainingContext* reverseContext = static_cast<MlpTrainingContext*>(context);
+        MlpInferenceContext* forwardContext = &reverseContext->m_forward;
+        reverseContext->m_layerData.resize(m_layers.size());
+        forwardContext->m_layerData.resize(m_layers.size());
+
+        ++reverseContext->m_trainingSampleSize;
 
 
         // First feed-forward the activations to get our current model predictions
         // First feed-forward the activations to get our current model predictions
-        const AZ::VectorN* output = Forward(activations);
+        // We do additional book-keeping over a standard forward pass to make gradient calculations easier
+        const AZ::VectorN* lastLayerOutput = &activations;
+        for (AZStd::size_t iter = 0; iter < m_layers.size(); ++iter)
+        {
+            reverseContext->m_layerData[iter].m_lastInput = lastLayerOutput;
+            m_layers[iter].Forward(forwardContext->m_layerData[iter], *lastLayerOutput);
+            lastLayerOutput = &forwardContext->m_layerData[iter].m_output;
+        }
 
 
         // Compute the partial derivatives of the loss function with respect to the final layer output
         // Compute the partial derivatives of the loss function with respect to the final layer output
         AZ::VectorN costGradients;
         AZ::VectorN costGradients;
-        ComputeLoss_Derivative(lossFunction, *output, expected, costGradients);
+        ComputeLoss_Derivative(lossFunction, *lastLayerOutput, expected, costGradients);
 
 
-        for (auto iter = m_layers.rbegin(); iter != m_layers.rend(); ++iter)
+        AZ::VectorN* lossGradient = &costGradients;
+        for (int64_t iter = static_cast<int64_t>(m_layers.size()) - 1; iter >= 0; --iter)
         {
         {
-            iter->AccumulateGradients(costGradients);
-            costGradients = iter->m_backpropagationGradients;
+            m_layers[iter].AccumulateGradients(reverseContext->m_layerData[iter], forwardContext->m_layerData[iter], *lossGradient);
+            lossGradient = &reverseContext->m_layerData[iter].m_backpropagationGradients;
         }
         }
     }
     }
 
 
-    void MultilayerPerceptron::GradientDescent(float learningRate)
+    void MultilayerPerceptron::GradientDescent(ITrainingContextPtr context, float learningRate)
     {
     {
-        if (m_trainingSampleSize > 0)
+        //AZStd::lock_guard lock(m_mutex);
+        MlpTrainingContext* reverseContext = static_cast<MlpTrainingContext*>(context);
+        if (reverseContext->m_trainingSampleSize > 0)
         {
         {
-            const float adjustedLearningRate = learningRate / static_cast<float>(m_trainingSampleSize);
-            for (auto iter = m_layers.rbegin(); iter != m_layers.rend(); ++iter)
+            const float adjustedLearningRate = learningRate / static_cast<float>(reverseContext->m_trainingSampleSize);
+            for (AZStd::size_t iter = 0; iter < m_layers.size(); ++iter)
             {
             {
-                iter->ApplyGradients(adjustedLearningRate);
+                m_layers[iter].ApplyGradients(reverseContext->m_layerData[iter], adjustedLearningRate);
             }
             }
         }
         }
-        m_trainingSampleSize = 0;
+        reverseContext->m_trainingSampleSize = 0;
     }
     }
 
 
     void MultilayerPerceptron::OnActivationCountChanged()
     void MultilayerPerceptron::OnActivationCountChanged()
     {
     {
+        //AZStd::lock_guard lock(m_mutex);
         AZStd::size_t lastLayerDimensionality = m_activationCount;
         AZStd::size_t lastLayerDimensionality = m_activationCount;
         for (Layer& layer : m_layers)
         for (Layer& layer : m_layers)
         {
         {
@@ -144,4 +248,99 @@ namespace MachineLearning
             lastLayerDimensionality = layer.m_outputSize;
             lastLayerDimensionality = layer.m_outputSize;
         }
         }
     }
     }
+
+    bool MultilayerPerceptron::LoadModel()
+    {
+        AZ::IO::SystemFile modelFile;
+        AZ::IO::FixedMaxPath filePathFixed = m_modelFile.c_str();
+        if (AZ::IO::FileIOBase* fileIOBase = AZ::IO::FileIOBase::GetInstance())
+        {
+            fileIOBase->ResolvePath(filePathFixed, m_modelFile.c_str());
+        }
+
+        if (!modelFile.Open(filePathFixed.c_str(), AZ::IO::SystemFile::SF_OPEN_READ_ONLY))
+        {
+            AZLOG_ERROR("Failed to load '%s'. File could not be opened.", filePathFixed.c_str());
+            return false;
+        }
+
+        const AZ::IO::SizeType length = modelFile.Length();
+        if (length == 0)
+        {
+            AZLOG_ERROR("Failed to load '%s'. File is empty.", filePathFixed.c_str());
+            return false;
+        }
+
+        AZStd::vector<uint8_t> serializeBuffer;
+        serializeBuffer.resize(length);
+        modelFile.Seek(0, AZ::IO::SystemFile::SF_SEEK_BEGIN);
+        modelFile.Read(serializeBuffer.size(), serializeBuffer.data());
+        AzNetworking::NetworkOutputSerializer serializer(serializeBuffer.data(), static_cast<uint32_t>(serializeBuffer.size()));
+        return Serialize(serializer);
+    }
+
+    bool MultilayerPerceptron::SaveModel()
+    {
+        AZ::IO::SystemFile modelFile;
+        AZ::IO::FixedMaxPath filePathFixed = m_modelFile.c_str();
+        if (AZ::IO::FileIOBase* fileIOBase = AZ::IO::FileIOBase::GetInstance())
+        {
+            fileIOBase->ResolvePath(filePathFixed, m_modelFile.c_str());
+        }
+
+        if (!modelFile.Open(filePathFixed.c_str(), AZ::IO::SystemFile::SF_OPEN_READ_WRITE | AZ::IO::SystemFile::SF_OPEN_CREATE))
+        {
+            AZLOG_ERROR("Failed to save to '%s'. File could not be opened for writing.", filePathFixed.c_str());
+            return false;
+        }
+        modelFile.Seek(0, AZ::IO::SystemFile::SF_SEEK_BEGIN);
+
+        AZStd::vector<uint8_t> serializeBuffer;
+        serializeBuffer.resize(EstimateSerializeSize());
+        AzNetworking::NetworkInputSerializer serializer(serializeBuffer.data(), static_cast<uint32_t>(serializeBuffer.size()));
+        if (Serialize(serializer))
+        {
+            modelFile.Write(serializeBuffer.data(), serializeBuffer.size());
+            return true;
+        }
+
+        return false;
+    }
+
+    void MultilayerPerceptron::AddLayer(AZStd::size_t layerDimensionality, ActivationFunctions activationFunction)
+    {
+        // This is not thread safe, this should only be used during model configuration
+        const AZStd::size_t lastLayerDimensionality = GetOutputDimensionality();
+        m_layers.push_back(AZStd::move(Layer(activationFunction, lastLayerDimensionality, layerDimensionality)));
+    }
+
+    Layer* MultilayerPerceptron::GetLayer(AZStd::size_t layerIndex)
+    {
+        // This is not thread safe, this method should only be used by unit testing to inspect layer weights and biases for correctness
+        return &m_layers[layerIndex];
+    }
+
+    bool MultilayerPerceptron::Serialize(AzNetworking::ISerializer& serializer)
+    {
+        //AZStd::lock_guard lock(m_mutex);
+        return serializer.Serialize(m_name, "Name")
+            && serializer.Serialize(m_activationCount, "activationCount")
+            && serializer.Serialize(m_layers, "layers");
+    }
+
+    AZStd::size_t MultilayerPerceptron::EstimateSerializeSize() const
+    {
+        const AZStd::size_t padding = 64; // 64 bytes of extra padding just in case
+        AZStd::size_t estimatedSize = padding 
+            + sizeof(AZStd::size_t)
+            + m_name.size()
+            + sizeof(m_activationCount)
+            + sizeof(AZStd::size_t);
+        //AZStd::lock_guard lock(m_mutex);
+        for (const Layer& layer : m_layers)
+        {
+            estimatedSize += layer.EstimateSerializeSize();
+        }
+        return estimatedSize;
+    }
 }
 }

+ 60 - 13
Gems/MachineLearning/Code/Source/Models/MultilayerPerceptron.h

@@ -9,6 +9,7 @@
 #pragma once
 #pragma once
 
 
 #include <AzCore/Math/MatrixMxN.h>
 #include <AzCore/Math/MatrixMxN.h>
+#include <AzNetworking/Serialization/ISerializer.h>
 #include <MachineLearning/INeuralNetwork.h>
 #include <MachineLearning/INeuralNetwork.h>
 #include <Models/Layer.h>
 #include <Models/Layer.h>
 
 
@@ -26,37 +27,83 @@ namespace MachineLearning
         //! @param context reflection context
         //! @param context reflection context
         static void Reflect(AZ::ReflectContext* context);
         static void Reflect(AZ::ReflectContext* context);
 
 
-        MultilayerPerceptron() = default;
-        MultilayerPerceptron(MultilayerPerceptron&&) = default;
-        MultilayerPerceptron(const MultilayerPerceptron&) = default;
+        MultilayerPerceptron();
+        MultilayerPerceptron(const MultilayerPerceptron&);
         MultilayerPerceptron(AZStd::size_t activationCount);
         MultilayerPerceptron(AZStd::size_t activationCount);
-        virtual ~MultilayerPerceptron() = default;
+        virtual ~MultilayerPerceptron();
 
 
-        MultilayerPerceptron& operator=(MultilayerPerceptron&&) = default;
-        MultilayerPerceptron& operator=(const MultilayerPerceptron&) = default;
+        MultilayerPerceptron& operator=(const MultilayerPerceptron&);
 
 
         //! INeuralNetwork interface
         //! INeuralNetwork interface
         //! @{
         //! @{
-        void AddLayer(AZStd::size_t layerDimensionality, ActivationFunctions activationFunction = ActivationFunctions::ReLU) override;
+        AZStd::string GetName() const override;
+        AZStd::string GetAssetFile(AssetTypes assetType) const override;
+        AZStd::size_t GetInputDimensionality() const override;
+        AZStd::size_t GetOutputDimensionality() const override;
         AZStd::size_t GetLayerCount() const override;
         AZStd::size_t GetLayerCount() const override;
-        Layer* GetLayer(AZStd::size_t layerIndex) override;
         AZStd::size_t GetParameterCount() const override;
         AZStd::size_t GetParameterCount() const override;
-        const AZ::VectorN* Forward(const AZ::VectorN& activations) override;
-        void Reverse(LossFunctions lossFunction, const AZ::VectorN& activations, const AZ::VectorN& expected) override;
-        void GradientDescent(float learningRate) override;
+        IInferenceContextPtr CreateInferenceContext() override;
+        ITrainingContextPtr CreateTrainingContext() override;
+        const AZ::VectorN* Forward(IInferenceContextPtr context, const AZ::VectorN& activations) override;
+        void Reverse(ITrainingContextPtr context, LossFunctions lossFunction, const AZ::VectorN& activations, const AZ::VectorN& expected) override;
+        void GradientDescent(ITrainingContextPtr context, float learningRate) override;
+        bool LoadModel() override;
+        bool SaveModel() override;
         //! @}
         //! @}
 
 
+        //! Adds a new layer to the model.
+        void AddLayer(AZStd::size_t layerDimensionality, ActivationFunctions activationFunction = ActivationFunctions::ReLU);
+
+        //! Retrieves a specific layer from the model, this is not thread safe and should only be used during unit testing to validate model parameters.
+        Layer* GetLayer(AZStd::size_t layerIndex);
+
+        //! Base serialize method for all serializable structures or classes to implement.
+        //! @param serializer ISerializer instance to use for serialization
+        //! @return boolean true for success, false for serialization failure
+        bool Serialize(AzNetworking::ISerializer& serializer);
+
+        //! Returns the estimated size required to serialize this model.
+        AZStd::size_t EstimateSerializeSize() const;
+
     private:
     private:
 
 
         void OnActivationCountChanged();
         void OnActivationCountChanged();
 
 
+        //! The model name.
+        AZStd::string m_name;
+
+        //! The model asset file.
+        AZStd::string m_modelFile;
+
+        //! Optional test and train asset data files.
+        AZStd::string m_testDataFile;
+        AZStd::string m_testLabelFile;
+        AZStd::string m_trainDataFile;
+        AZStd::string m_trainLabelFile;
+
         //! The number of neurons in the activation layer.
         //! The number of neurons in the activation layer.
         AZStd::size_t m_activationCount = 0;
         AZStd::size_t m_activationCount = 0;
 
 
+        //! The set of layers in the network.
+        AZStd::vector<Layer> m_layers;
+    };
+
+    struct MlpInferenceContext
+        : public IInferenceContext
+    {
+        AZStd::vector<LayerInferenceData> m_layerData;
+    };
+
+    struct MlpTrainingContext
+        : public ITrainingContext
+    {
+        //! Used during the forward pass when calculating loss gradients.
+        MlpInferenceContext m_forward;
+
         //! The number of accumulated training samples.
         //! The number of accumulated training samples.
         AZStd::size_t m_trainingSampleSize = 0;
         AZStd::size_t m_trainingSampleSize = 0;
 
 
-        //! The set of layers in the network.
-        AZStd::vector<Layer> m_layers;
+        //! The set of layer training data.
+        AZStd::vector<LayerTrainingData> m_layerData;
     };
     };
 }
 }

+ 0 - 18
Gems/MachineLearning/Code/Source/Nodes/AccumulateTrainingGradients.ScriptCanvasNodeable.xml

@@ -1,18 +0,0 @@
-<?xml version="1.0" encoding="utf-8"?>
-
-<ScriptCanvas Include="Nodes/AccumulateTrainingGradients.h" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
-    <Class Name="AccumulateTrainingGradients"
-           QualifiedName="MachineLearning::AccumulateTrainingGradients"
-           PreferredClassName="Accumulate training gradient"
-           Category="MachineLearning"
-           Description="Accumulates the cost gradients for a machine learning model against a set of activations and a set of expected outputs.">
-
-        <Input Name="In" DisplayGroup="In" Description="Parameters controlling cost gradient calculation">
-            <Parameter Name="Model" Type="MachineLearning::INeuralNetworkPtr" Description="The model to accumulate a cost gradient for."/>
-            <Parameter Name="CostFunction" Type="MachineLearning::LossFunctions" Description="The loss function to use to compute the cost."/>
-            <Parameter Name="Activations" Type="AZ::VectorN" Description="The set of activation values to apply to the model (must match the models input count)."/>
-            <Parameter Name="ExpectedOutput" Type="AZ::VectorN" Description="The expected outputs given the provided inputs (must match the models output count)."/>
-            <Return Name="Model" Type="MachineLearning::INeuralNetworkPtr" Shared="true"/>
-        </Input>
-    </Class>
-</ScriptCanvas>

+ 0 - 19
Gems/MachineLearning/Code/Source/Nodes/AccumulateTrainingGradients.cpp

@@ -1,19 +0,0 @@
-/*
- * Copyright (c) Contributors to the Open 3D Engine Project.
- * For complete copyright and license terms please see the LICENSE at the root of this distribution.
- *
- * SPDX-License-Identifier: Apache-2.0 OR MIT
- *
- */
-
-#include <Nodes/AccumulateTrainingGradients.h>
-#include <Models/MultilayerPerceptron.h>
-
-namespace MachineLearning
-{
-    INeuralNetworkPtr AccumulateTrainingGradients::In(INeuralNetworkPtr Model, LossFunctions LossFunction, AZ::VectorN Activations, AZ::VectorN ExpectedOutput)
-    {
-        Model->Reverse(LossFunction, Activations, ExpectedOutput);
-        return Model;
-    }
-}

+ 15 - 0
Gems/MachineLearning/Code/Source/Nodes/ArgMax.ScriptCanvasNodeable.xml

@@ -0,0 +1,15 @@
+<?xml version="1.0" encoding="utf-8"?>
+
+<ScriptCanvas Include="Nodes/ArgMax.h" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
+    <Class Name="ArgMax"
+           QualifiedName="MachineLearning::ArgMax"
+           PreferredClassName="ArgMax"
+           Category="MachineLearning"
+           Description="Reverses one-hot encoding by returning the index of the element from a VectorN with the largest magnitude.">
+
+        <Input Name="In" DisplayGroup="In" Description="Parameters controlling the arg-max decoding">
+            <Parameter Name="oneHot" Type="AZ::VectorN" Description="The one-hot vector to decode."/>
+            <Return Name="argmax" Type="AZStd::size_t" Shared="true"/>
+        </Input>
+    </Class>
+</ScriptCanvas>

+ 18 - 0
Gems/MachineLearning/Code/Source/Nodes/ArgMax.cpp

@@ -0,0 +1,18 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <Nodes/ArgMax.h>
+#include <Algorithms/Activations.h>
+
+namespace MachineLearning
+{
+    AZStd::size_t ArgMax::In(AZ::VectorN Value)
+    {
+        return ArgMaxDecode(Value);
+    }
+}

+ 3 - 3
Gems/MachineLearning/Code/Source/Nodes/GradientDescent.h → Gems/MachineLearning/Code/Source/Nodes/ArgMax.h

@@ -12,13 +12,13 @@
 #include <ScriptCanvas/Core/Nodeable.h>
 #include <ScriptCanvas/Core/Nodeable.h>
 #include <ScriptCanvas/Core/NodeableNode.h>
 #include <ScriptCanvas/Core/NodeableNode.h>
 #include <MachineLearning/INeuralNetwork.h>
 #include <MachineLearning/INeuralNetwork.h>
-#include <Source/Nodes/GradientDescent.generated.h>
+#include <Source/Nodes/ArgMax.generated.h>
 
 
 namespace MachineLearning
 namespace MachineLearning
 {
 {
-    class GradientDescent
+    class ArgMax
         : public ScriptCanvas::Nodeable
         : public ScriptCanvas::Nodeable
     {
     {
-        SCRIPTCANVAS_NODE_GradientDescent;
+        SCRIPTCANVAS_NODE_ArgMax;
     };
     };
 }
 }

+ 3 - 1
Gems/MachineLearning/Code/Source/Nodes/ComputeCost.cpp

@@ -14,7 +14,9 @@ namespace MachineLearning
 {
 {
     float ComputeCost::In(INeuralNetworkPtr Model, LossFunctions LossFunction, AZ::VectorN Activations, AZ::VectorN ExpectedOutput)
     float ComputeCost::In(INeuralNetworkPtr Model, LossFunctions LossFunction, AZ::VectorN Activations, AZ::VectorN ExpectedOutput)
     {
     {
-        const AZ::VectorN* modelOutput = Model->Forward(Activations);
+        AZStd::unique_ptr<IInferenceContext> inferenceContext;
+        inferenceContext.reset(Model->CreateInferenceContext());
+        const AZ::VectorN* modelOutput = Model->Forward(inferenceContext.get(), Activations);
         return ComputeTotalCost(LossFunction, ExpectedOutput, *modelOutput);
         return ComputeTotalCost(LossFunction, ExpectedOutput, *modelOutput);
     }
     }
 }
 }

+ 3 - 1
Gems/MachineLearning/Code/Source/Nodes/FeedForward.cpp

@@ -13,7 +13,9 @@ namespace MachineLearning
 {
 {
     AZ::VectorN FeedForward::In(INeuralNetworkPtr Model, AZ::VectorN Activations)
     AZ::VectorN FeedForward::In(INeuralNetworkPtr Model, AZ::VectorN Activations)
     {
     {
-        AZ::VectorN results = *Model->Forward(Activations);
+        AZStd::unique_ptr<IInferenceContext> inferenceContext;
+        inferenceContext.reset(Model->CreateInferenceContext());
+        AZ::VectorN results = *Model->Forward(inferenceContext.get(), Activations);
         return results;
         return results;
     }
     }
 }
 }

+ 0 - 16
Gems/MachineLearning/Code/Source/Nodes/GradientDescent.ScriptCanvasNodeable.xml

@@ -1,16 +0,0 @@
-<?xml version="1.0" encoding="utf-8"?>
-
-<ScriptCanvas Include="Nodes/GradientDescent.h" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
-    <Class Name="GradientDescent"
-           QualifiedName="MachineLearning::GradientDescent"
-           PreferredClassName="Performs a gradient descent step"
-           Category="MachineLearning"
-           Description="Performs a gradient descent step on a model and zeroes all gradient vectors.">
-
-        <Input Name="In" DisplayGroup="In" Description="Parameters controlling gradient descent">
-            <Parameter Name="Model" Type="MachineLearning::INeuralNetworkPtr" Description="The model to perform a gradient descent step on."/>
-            <Parameter Name="LearningRate" Type="float" Description="The learning rate to use."/>
-            <Return Name="Model" Type="MachineLearning::INeuralNetworkPtr" Shared="true"/>
-        </Input>
-    </Class>
-</ScriptCanvas>

+ 15 - 0
Gems/MachineLearning/Code/Source/Nodes/LoadModel.ScriptCanvasNodeable.xml

@@ -0,0 +1,15 @@
+<?xml version="1.0" encoding="utf-8"?>
+
+<ScriptCanvas Include="Nodes/LoadModel.h" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
+    <Class Name="LoadModel"
+           QualifiedName="MachineLearning::LoadModel"
+           PreferredClassName="Load model"
+           Category="MachineLearning"
+           Description="Loads a model from an asset file.">
+
+        <Input Name="In" DisplayGroup="In" Description="Parameters controlling the model to load">
+            <Parameter Name="Model" Type="MachineLearning::INeuralNetworkPtr" Description="The model to load parameters to."/>
+            <Return Name="Model" Type="MachineLearning::INeuralNetworkPtr" Shared="true"/>
+        </Input>
+    </Class>
+</ScriptCanvas>

+ 3 - 4
Gems/MachineLearning/Code/Source/Nodes/GradientDescent.cpp → Gems/MachineLearning/Code/Source/Nodes/LoadModel.cpp

@@ -6,14 +6,13 @@
  *
  *
  */
  */
 
 
-#include <Nodes/GradientDescent.h>
-#include <Models/MultilayerPerceptron.h>
+#include <Nodes/LoadModel.h>
 
 
 namespace MachineLearning
 namespace MachineLearning
 {
 {
-    INeuralNetworkPtr GradientDescent::In(INeuralNetworkPtr Model, float LearningRate)
+    INeuralNetworkPtr LoadModel::In(INeuralNetworkPtr Model)
     {
     {
-        Model->GradientDescent(LearningRate);
+        Model->LoadModel();
         return Model;
         return Model;
     }
     }
 }
 }

+ 3 - 3
Gems/MachineLearning/Code/Source/Nodes/AccumulateTrainingGradients.h → Gems/MachineLearning/Code/Source/Nodes/LoadModel.h

@@ -12,13 +12,13 @@
 #include <ScriptCanvas/Core/Nodeable.h>
 #include <ScriptCanvas/Core/Nodeable.h>
 #include <ScriptCanvas/Core/NodeableNode.h>
 #include <ScriptCanvas/Core/NodeableNode.h>
 #include <MachineLearning/INeuralNetwork.h>
 #include <MachineLearning/INeuralNetwork.h>
-#include <Source/Nodes/AccumulateTrainingGradients.generated.h>
+#include <Source/Nodes/LoadModel.generated.h>
 
 
 namespace MachineLearning
 namespace MachineLearning
 {
 {
-    class AccumulateTrainingGradients
+    class LoadModel
         : public ScriptCanvas::Nodeable
         : public ScriptCanvas::Nodeable
     {
     {
-        SCRIPTCANVAS_NODE_AccumulateTrainingGradients;
+        SCRIPTCANVAS_NODE_LoadModel;
     };
     };
 }
 }

+ 15 - 0
Gems/MachineLearning/Code/Source/Nodes/SaveModel.ScriptCanvasNodeable.xml

@@ -0,0 +1,15 @@
+<?xml version="1.0" encoding="utf-8"?>
+
+<ScriptCanvas Include="Nodes/SaveModel.h" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
+    <Class Name="SaveModel"
+           QualifiedName="MachineLearning::SaveModel"
+           PreferredClassName="Save model"
+           Category="MachineLearning"
+           Description="Saves a model to an asset file.">
+
+        <Input Name="In" DisplayGroup="In" Description="Parameters controlling the model to save">
+            <Parameter Name="Model" Type="MachineLearning::INeuralNetworkPtr" Description="The model to save."/>
+            <Return Name="Model" Type="MachineLearning::INeuralNetworkPtr" Shared="true"/>
+        </Input>
+    </Class>
+</ScriptCanvas>

+ 18 - 0
Gems/MachineLearning/Code/Source/Nodes/SaveModel.cpp

@@ -0,0 +1,18 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <Nodes/SaveModel.h>
+
+namespace MachineLearning
+{
+    INeuralNetworkPtr SaveModel::In(INeuralNetworkPtr Model)
+    {
+        Model->SaveModel();
+        return Model;
+    }
+}

+ 24 - 0
Gems/MachineLearning/Code/Source/Nodes/SaveModel.h

@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#pragma once
+
+#include <ScriptCanvas/CodeGen/NodeableCodegen.h>
+#include <ScriptCanvas/Core/Nodeable.h>
+#include <ScriptCanvas/Core/NodeableNode.h>
+#include <MachineLearning/INeuralNetwork.h>
+#include <Source/Nodes/SaveModel.generated.h>
+
+namespace MachineLearning
+{
+    class SaveModel
+        : public ScriptCanvas::Nodeable
+    {
+        SCRIPTCANVAS_NODE_SaveModel;
+    };
+}

+ 11 - 1
Gems/MachineLearning/Code/Source/Nodes/SupervisedLearning.cpp

@@ -10,6 +10,8 @@
 #include <Models/MultilayerPerceptron.h>
 #include <Models/MultilayerPerceptron.h>
 #include <MachineLearning/ILabeledTrainingData.h>
 #include <MachineLearning/ILabeledTrainingData.h>
 #include <Algorithms/Training.h>
 #include <Algorithms/Training.h>
+#include <AzCore/Console/ILogger.h>
+#include <AzCore/std/chrono/chrono.h>
 
 
 namespace MachineLearning
 namespace MachineLearning
 {
 {
@@ -27,7 +29,15 @@ namespace MachineLearning
         float EarlyStopCost
         float EarlyStopCost
     )
     )
     {
     {
-        SupervisedLearningCycle(Model, TrainingData, TestData, static_cast<LossFunctions>(CostFunction), TotalIterations, BatchSize, LearningRate, LearningRateDecay, EarlyStopCost);
+        SupervisedLearningCycle trainingInstance(Model, TrainingData, TestData, static_cast<LossFunctions>(CostFunction), TotalIterations, BatchSize, LearningRate, LearningRateDecay, EarlyStopCost);
+
+        trainingInstance.StartTraining();
+        while (!trainingInstance.m_trainingComplete)
+        {
+            AZStd::this_thread::sleep_for(AZStd::chrono::milliseconds(1));
+        }
+        trainingInstance.StopTraining();
+
         return Model;
         return Model;
     }
     }
 }
 }

+ 23 - 0
Gems/MachineLearning/Code/Tests/Algorithms/ActivationTests.cpp

@@ -17,6 +17,29 @@ namespace UnitTest
     {
     {
     };
     };
 
 
+    TEST_F(MachineLearning_Activations, OneHotArgMax)
+    {
+        AZStd::size_t testValue = 1;
+        AZ::VectorN testVector;
+
+        MachineLearning::OneHotEncode(testValue, 10, testVector);
+        EXPECT_FLOAT_EQ(testVector.GetElement(0), 0.0f);
+        EXPECT_FLOAT_EQ(testVector.GetElement(1), 1.0f);
+        EXPECT_EQ(MachineLearning::ArgMaxDecode(testVector), testValue);
+
+        testValue = 3;
+        MachineLearning::OneHotEncode(testValue, 10, testVector);
+        EXPECT_EQ(MachineLearning::ArgMaxDecode(testVector), testValue);
+
+        testValue = 7;
+        MachineLearning::OneHotEncode(testValue, 10, testVector);
+        EXPECT_EQ(MachineLearning::ArgMaxDecode(testVector), testValue);
+
+        testValue = 8;
+        MachineLearning::OneHotEncode(testValue, 10, testVector);
+        EXPECT_EQ(MachineLearning::ArgMaxDecode(testVector), testValue);
+    }
+
     TEST_F(MachineLearning_Activations, TestRelu)
     TEST_F(MachineLearning_Activations, TestRelu)
     {
     {
         AZ::VectorN output = AZ::VectorN::CreateZero(1024);
         AZ::VectorN output = AZ::VectorN::CreateZero(1024);

+ 7 - 7
Gems/MachineLearning/Code/Tests/Models/LayerTests.cpp

@@ -26,29 +26,29 @@ namespace UnitTest
         EXPECT_EQ(testLayer.m_weights.GetColumnCount(), 8);
         EXPECT_EQ(testLayer.m_weights.GetColumnCount(), 8);
         EXPECT_EQ(testLayer.m_weights.GetRowCount(), 4);
         EXPECT_EQ(testLayer.m_weights.GetRowCount(), 4);
         EXPECT_EQ(testLayer.m_biases.GetDimensionality(), 4);
         EXPECT_EQ(testLayer.m_biases.GetDimensionality(), 4);
-        EXPECT_EQ(testLayer.m_output.GetDimensionality(), 4);
     }
     }
 
 
     TEST_F(MachineLearning_Layers, TestForward)
     TEST_F(MachineLearning_Layers, TestForward)
     {
     {
         // Construct a layer that takes 8 inputs and generates 4 outputs
         // Construct a layer that takes 8 inputs and generates 4 outputs
         MachineLearning::Layer testLayer(MachineLearning::ActivationFunctions::Linear, 8, 4);
         MachineLearning::Layer testLayer(MachineLearning::ActivationFunctions::Linear, 8, 4);
+        MachineLearning::LayerInferenceData inferenceData;
         testLayer.m_biases = AZ::VectorN::CreateOne(testLayer.m_biases.GetDimensionality());
         testLayer.m_biases = AZ::VectorN::CreateOne(testLayer.m_biases.GetDimensionality());
         testLayer.m_weights = AZ::MatrixMxN::CreateZero(testLayer.m_weights.GetRowCount(), testLayer.m_weights.GetColumnCount());
         testLayer.m_weights = AZ::MatrixMxN::CreateZero(testLayer.m_weights.GetRowCount(), testLayer.m_weights.GetColumnCount());
         testLayer.m_weights += 1.0f;
         testLayer.m_weights += 1.0f;
 
 
         const AZ::VectorN ones = AZ::VectorN::CreateOne(8); // Input of all ones
         const AZ::VectorN ones = AZ::VectorN::CreateOne(8); // Input of all ones
-        testLayer.Forward(ones);
-        for (AZStd::size_t iter = 0; iter < testLayer.m_output.GetDimensionality(); ++iter)
+        testLayer.Forward(inferenceData, ones);
+        for (AZStd::size_t iter = 0; iter < inferenceData.m_output.GetDimensionality(); ++iter)
         {
         {
-            ASSERT_FLOAT_EQ(testLayer.m_output.GetElement(iter), 9.0f); // 8 edges of 1's + 1 for the bias
+            ASSERT_FLOAT_EQ(inferenceData.m_output.GetElement(iter), 9.0f); // 8 edges of 1's + 1 for the bias
         }
         }
 
 
         const AZ::VectorN zeros = AZ::VectorN::CreateZero(8); // Input of all zeros
         const AZ::VectorN zeros = AZ::VectorN::CreateZero(8); // Input of all zeros
-        testLayer.Forward(zeros);
-        for (AZStd::size_t iter = 0; iter < testLayer.m_output.GetDimensionality(); ++iter)
+        testLayer.Forward(inferenceData, zeros);
+        for (AZStd::size_t iter = 0; iter < inferenceData.m_output.GetDimensionality(); ++iter)
         {
         {
-            ASSERT_FLOAT_EQ(testLayer.m_output.GetElement(iter), 1.0f); // Weights are all zero, leaving only the layer biases which are all set to 1
+            ASSERT_FLOAT_EQ(inferenceData.m_output.GetElement(iter), 1.0f); // Weights are all zero, leaving only the layer biases which are all set to 1
         }
         }
     }
     }
 }
 }

+ 38 - 36
Gems/MachineLearning/Code/Tests/Models/MultilayerPerceptronTests.cpp

@@ -36,16 +36,18 @@ namespace UnitTest
         const float layer1Biases[] = { 0.60f, 0.60f };
         const float layer1Biases[] = { 0.60f, 0.60f };
 
 
         MachineLearning::MultilayerPerceptron mlp(2);
         MachineLearning::MultilayerPerceptron mlp(2);
+        MachineLearning::MlpInferenceContext inferenceData;
+        MachineLearning::MlpTrainingContext trainingData;
         mlp.AddLayer(2, MachineLearning::ActivationFunctions::Sigmoid);
         mlp.AddLayer(2, MachineLearning::ActivationFunctions::Sigmoid);
         mlp.AddLayer(2, MachineLearning::ActivationFunctions::Sigmoid);
         mlp.AddLayer(2, MachineLearning::ActivationFunctions::Sigmoid);
 
 
-        MachineLearning::Layer& layer0 = mlp.GetLayer(0);
-        layer0.m_weights = AZ::MatrixMxN::CreateFromPackedFloats(2, 2, layer0Weights);
-        layer0.m_biases = AZ::VectorN::CreateFromFloats(2, layer0Biases);
+        MachineLearning::Layer* layer0 = mlp.GetLayer(0);
+        layer0->m_weights = AZ::MatrixMxN::CreateFromPackedFloats(2, 2, layer0Weights);
+        layer0->m_biases = AZ::VectorN::CreateFromFloats(2, layer0Biases);
 
 
-        MachineLearning::Layer& layer1 = mlp.GetLayer(1);
-        layer1.m_weights = AZ::MatrixMxN::CreateFromPackedFloats(2, 2, layer1Weights);
-        layer1.m_biases = AZ::VectorN::CreateFromFloats(2, layer1Biases);
+        MachineLearning::Layer* layer1 = mlp.GetLayer(1);
+        layer1->m_weights = AZ::MatrixMxN::CreateFromPackedFloats(2, 2, layer1Weights);
+        layer1->m_biases = AZ::VectorN::CreateFromFloats(2, layer1Biases);
 
 
         const float activations[] = { 0.05f, 0.10f };
         const float activations[] = { 0.05f, 0.10f };
         const float labels[] = { 0.01f, 0.99f };
         const float labels[] = { 0.01f, 0.99f };
@@ -53,59 +55,59 @@ namespace UnitTest
         const AZ::VectorN trainingInput = AZ::VectorN::CreateFromFloats(2, activations);
         const AZ::VectorN trainingInput = AZ::VectorN::CreateFromFloats(2, activations);
         const AZ::VectorN trainingOutput = AZ::VectorN::CreateFromFloats(2, labels);
         const AZ::VectorN trainingOutput = AZ::VectorN::CreateFromFloats(2, labels);
 
 
-        const AZ::VectorN& actualOutput = mlp.Forward(trainingInput);
+        const AZ::VectorN* actualOutput = mlp.Forward(&inferenceData, trainingInput);
 
 
         // Validate intermediate layer output given the initial weights and biases
         // Validate intermediate layer output given the initial weights and biases
-        EXPECT_TRUE(AZ::IsCloseMag(layer0.m_output.GetElement(0), 0.5933f, 0.01f));
-        EXPECT_TRUE(AZ::IsCloseMag(layer0.m_output.GetElement(1), 0.5969f, 0.01f));
+        EXPECT_TRUE(AZ::IsCloseMag(inferenceData.m_layerData[0].m_output.GetElement(0), 0.5933f, 0.01f));
+        EXPECT_TRUE(AZ::IsCloseMag(inferenceData.m_layerData[0].m_output.GetElement(1), 0.5969f, 0.01f));
 
 
         // Validate final model output given the initial weights and biases
         // Validate final model output given the initial weights and biases
-        EXPECT_TRUE(AZ::IsCloseMag(actualOutput.GetElement(0), 0.75f, 0.01f));
-        EXPECT_TRUE(AZ::IsCloseMag(actualOutput.GetElement(1), 0.77f, 0.01f));
+        EXPECT_TRUE(AZ::IsCloseMag(actualOutput->GetElement(0), 0.75f, 0.01f));
+        EXPECT_TRUE(AZ::IsCloseMag(actualOutput->GetElement(1), 0.77f, 0.01f));
 
 
-        float cost = MachineLearning::ComputeTotalCost(MachineLearning::LossFunctions::MeanSquaredError, trainingOutput, actualOutput);
+        float cost = MachineLearning::ComputeTotalCost(MachineLearning::LossFunctions::MeanSquaredError, trainingOutput, *actualOutput);
         EXPECT_TRUE(AZ::IsCloseMag(cost, 0.60f, 0.01f));
         EXPECT_TRUE(AZ::IsCloseMag(cost, 0.60f, 0.01f));
 
 
-        mlp.Reverse(MachineLearning::LossFunctions::MeanSquaredError, trainingInput, trainingOutput);
+        mlp.Reverse(&trainingData, MachineLearning::LossFunctions::MeanSquaredError, trainingInput, trainingOutput);
 
 
         // Check the activation gradients
         // Check the activation gradients
-        EXPECT_NEAR(layer1.m_activationGradients.GetElement(0),  0.1385f, 0.01f);
-        EXPECT_NEAR(layer1.m_activationGradients.GetElement(1), -0.0381f, 0.01f);
+        EXPECT_NEAR(trainingData.m_layerData[1].m_activationGradients.GetElement(0),  0.1385f, 0.01f);
+        EXPECT_NEAR(trainingData.m_layerData[1].m_activationGradients.GetElement(1), -0.0381f, 0.01f);
 
 
-        EXPECT_NEAR(layer1.m_weightGradients.GetElement(0, 0),  0.0822f, 0.01f);
-        EXPECT_NEAR(layer1.m_weightGradients.GetElement(0, 1),  0.0826f, 0.01f);
-        EXPECT_NEAR(layer1.m_weightGradients.GetElement(1, 0), -0.0226f, 0.01f);
-        EXPECT_NEAR(layer1.m_weightGradients.GetElement(1, 1), -0.0227f, 0.01f);
+        EXPECT_NEAR(trainingData.m_layerData[1].m_weightGradients.GetElement(0, 0),  0.0822f, 0.01f);
+        EXPECT_NEAR(trainingData.m_layerData[1].m_weightGradients.GetElement(0, 1),  0.0826f, 0.01f);
+        EXPECT_NEAR(trainingData.m_layerData[1].m_weightGradients.GetElement(1, 0), -0.0226f, 0.01f);
+        EXPECT_NEAR(trainingData.m_layerData[1].m_weightGradients.GetElement(1, 1), -0.0227f, 0.01f);
 
 
-        EXPECT_NEAR(layer1.m_backpropagationGradients.GetElement(0), 0.0364f, 0.01f);
-        EXPECT_NEAR(layer1.m_backpropagationGradients.GetElement(1), 0.0414f, 0.01f);
+        EXPECT_NEAR(trainingData.m_layerData[1].m_backpropagationGradients.GetElement(0), 0.0364f, 0.01f);
+        EXPECT_NEAR(trainingData.m_layerData[1].m_backpropagationGradients.GetElement(1), 0.0414f, 0.01f);
 
 
-        EXPECT_NEAR(layer0.m_weightGradients.GetElement(0, 0),  0.0004f, 0.01f);
-        EXPECT_NEAR(layer0.m_weightGradients.GetElement(0, 1),  0.0008f, 0.01f);
+        EXPECT_NEAR(trainingData.m_layerData[0].m_weightGradients.GetElement(0, 0),  0.0004f, 0.01f);
+        EXPECT_NEAR(trainingData.m_layerData[0].m_weightGradients.GetElement(0, 1),  0.0008f, 0.01f);
 
 
-        mlp.GradientDescent(0.5f);
+        mlp.GradientDescent(&trainingData, 0.5f);
 
 
-        EXPECT_NEAR(layer1.m_weights.GetElement(0, 0), 0.3590f, 0.01f);
-        EXPECT_NEAR(layer1.m_weights.GetElement(0, 1), 0.4087f, 0.01f);
-        EXPECT_NEAR(layer1.m_weights.GetElement(1, 0), 0.5113f, 0.01f);
-        EXPECT_NEAR(layer1.m_weights.GetElement(1, 1), 0.5614f, 0.01f);
+        EXPECT_NEAR(layer1->m_weights.GetElement(0, 0), 0.3590f, 0.01f);
+        EXPECT_NEAR(layer1->m_weights.GetElement(0, 1), 0.4087f, 0.01f);
+        EXPECT_NEAR(layer1->m_weights.GetElement(1, 0), 0.5113f, 0.01f);
+        EXPECT_NEAR(layer1->m_weights.GetElement(1, 1), 0.5614f, 0.01f);
 
 
-        EXPECT_NEAR(layer0.m_weights.GetElement(0, 0), 0.1498f, 0.01f);
-        EXPECT_NEAR(layer0.m_weights.GetElement(0, 1), 0.1996f, 0.01f);
-        EXPECT_NEAR(layer0.m_weights.GetElement(1, 0), 0.2495f, 0.01f);
-        EXPECT_NEAR(layer0.m_weights.GetElement(1, 1), 0.2995f, 0.01f);
+        EXPECT_NEAR(layer0->m_weights.GetElement(0, 0), 0.1498f, 0.01f);
+        EXPECT_NEAR(layer0->m_weights.GetElement(0, 1), 0.1996f, 0.01f);
+        EXPECT_NEAR(layer0->m_weights.GetElement(1, 0), 0.2495f, 0.01f);
+        EXPECT_NEAR(layer0->m_weights.GetElement(1, 1), 0.2995f, 0.01f);
 
 
         // Now lets evaluate a whole training cycle
         // Now lets evaluate a whole training cycle
         const AZStd::size_t numTrainingLoops = 10000;
         const AZStd::size_t numTrainingLoops = 10000;
         for (AZStd::size_t iter = 0; iter < numTrainingLoops; ++iter)
         for (AZStd::size_t iter = 0; iter < numTrainingLoops; ++iter)
         {
         {
-            mlp.Reverse(MachineLearning::LossFunctions::MeanSquaredError, trainingInput, trainingOutput);
-            mlp.GradientDescent(0.5f);
+            mlp.Reverse(&trainingData, MachineLearning::LossFunctions::MeanSquaredError, trainingInput, trainingOutput);
+            mlp.GradientDescent(&trainingData, 0.5f);
         }
         }
 
 
         // We expect the total cost of the network on the training sample to be much lower after training
         // We expect the total cost of the network on the training sample to be much lower after training
-        const AZ::VectorN& trainedOutput = mlp.Forward(trainingInput);
-        float trainedCost = MachineLearning::ComputeTotalCost(MachineLearning::LossFunctions::MeanSquaredError, trainingOutput, trainedOutput);
+        const AZ::VectorN* trainedOutput = mlp.Forward(&inferenceData, trainingInput);
+        float trainedCost = MachineLearning::ComputeTotalCost(MachineLearning::LossFunctions::MeanSquaredError, trainingOutput, *trainedOutput);
         EXPECT_LT(trainedCost, 5.0e-6f);
         EXPECT_LT(trainedCost, 5.0e-6f);
     }
     }
 }
 }

+ 3 - 1
Gems/MachineLearning/Code/machinelearning_api_files.cmake

@@ -7,9 +7,11 @@
 #
 #
 
 
 set(FILES
 set(FILES
+    Include/MachineLearning/IInferenceContext.h
     Include/MachineLearning/INeuralNetwork.h
     Include/MachineLearning/INeuralNetwork.h
     Include/MachineLearning/ILabeledTrainingData.h
     Include/MachineLearning/ILabeledTrainingData.h
-    Include/MachineLearning/MachineLearningBus.h
+    Include/MachineLearning/IMachineLearning.h
+    Include/MachineLearning/ITrainingContext.h
     Include/MachineLearning/MachineLearningTypeIds.h
     Include/MachineLearning/MachineLearningTypeIds.h
     Include/MachineLearning/Types.h
     Include/MachineLearning/Types.h
 )
 )

+ 16 - 0
Gems/MachineLearning/Code/machinelearning_debug_files.cmake

@@ -0,0 +1,16 @@
+#
+# Copyright (c) Contributors to the Open 3D Engine Project.
+# For complete copyright and license terms please see the LICENSE at the root of this distribution.
+#
+# SPDX-License-Identifier: Apache-2.0 OR MIT
+#
+#
+
+set(FILES
+    Source/Debug/MachineLearningDebugModule.cpp
+    Source/Debug/MachineLearningDebugModule.h
+    Source/Debug/MachineLearningDebugTrainingWindow.cpp
+    Source/Debug/MachineLearningDebugTrainingWindow.h
+    Source/Debug/MachineLearningDebugSystemComponent.cpp
+    Source/Debug/MachineLearningDebugSystemComponent.h
+)

+ 9 - 6
Gems/MachineLearning/Code/machinelearning_private_files.cmake

@@ -25,24 +25,27 @@ set(FILES
     Source/Models/Layer.h
     Source/Models/Layer.h
     Source/Models/MultilayerPerceptron.cpp
     Source/Models/MultilayerPerceptron.cpp
     Source/Models/MultilayerPerceptron.h
     Source/Models/MultilayerPerceptron.h
-    Source/Nodes/AccumulateTrainingGradients.ScriptCanvasNodeable.xml
-    Source/Nodes/AccumulateTrainingGradients.cpp
-    Source/Nodes/AccumulateTrainingGradients.h
+    Source/Nodes/ArgMax.ScriptCanvasNodeable.xml
+    Source/Nodes/ArgMax.cpp
+    Source/Nodes/ArgMax.h
     Source/Nodes/ComputeCost.ScriptCanvasNodeable.xml
     Source/Nodes/ComputeCost.ScriptCanvasNodeable.xml
     Source/Nodes/ComputeCost.cpp
     Source/Nodes/ComputeCost.cpp
     Source/Nodes/ComputeCost.h
     Source/Nodes/ComputeCost.h
     Source/Nodes/FeedForward.ScriptCanvasNodeable.xml
     Source/Nodes/FeedForward.ScriptCanvasNodeable.xml
     Source/Nodes/FeedForward.cpp
     Source/Nodes/FeedForward.cpp
     Source/Nodes/FeedForward.h
     Source/Nodes/FeedForward.h
-    Source/Nodes/GradientDescent.ScriptCanvasNodeable.xml
-    Source/Nodes/GradientDescent.cpp
-    Source/Nodes/GradientDescent.h
+    Source/Nodes/LoadModel.ScriptCanvasNodeable.xml
+    Source/Nodes/LoadModel.cpp
+    Source/Nodes/LoadModel.h
     Source/Nodes/LoadTrainingData.ScriptCanvasNodeable.xml
     Source/Nodes/LoadTrainingData.ScriptCanvasNodeable.xml
     Source/Nodes/LoadTrainingData.cpp
     Source/Nodes/LoadTrainingData.cpp
     Source/Nodes/LoadTrainingData.h
     Source/Nodes/LoadTrainingData.h
     Source/Nodes/OneHot.ScriptCanvasNodeable.xml
     Source/Nodes/OneHot.ScriptCanvasNodeable.xml
     Source/Nodes/OneHot.cpp
     Source/Nodes/OneHot.cpp
     Source/Nodes/OneHot.h
     Source/Nodes/OneHot.h
+    Source/Nodes/SaveModel.ScriptCanvasNodeable.xml
+    Source/Nodes/SaveModel.cpp
+    Source/Nodes/SaveModel.h
     Source/Nodes/SupervisedLearning.ScriptCanvasNodeable.xml
     Source/Nodes/SupervisedLearning.ScriptCanvasNodeable.xml
     Source/Nodes/SupervisedLearning.cpp
     Source/Nodes/SupervisedLearning.cpp
     Source/Nodes/SupervisedLearning.h
     Source/Nodes/SupervisedLearning.h