Bläddra i källkod

Addresses code review feedback and starts conversion to the AZ::Asset system. Asset system work is incomplete at this time

Signed-off-by: kberg-amzn <[email protected]>
kberg-amzn 1 år sedan
förälder
incheckning
027a8d5c6b

+ 0 - 0
Gems/MachineLearning/Assets/Models/NumberClassifier → Gems/MachineLearning/Assets/Models/NumberClassifier.mlmodel


+ 6 - 7
Gems/MachineLearning/Code/Source/Algorithms/Training.h

@@ -52,12 +52,6 @@ namespace MachineLearning
         AZ::ThreadSafeDeque<float> m_testCosts;
         AZ::ThreadSafeDeque<float> m_trainCosts;
 
-    //private:
-        //! Calculates the average cost of the provided model on the set of labeled test data using the requested loss function.
-        float ComputeCurrentCost(ILabeledTrainingData& testData, LossFunctions costFunction);
-
-        void ExecTraining();
-
         INeuralNetworkPtr m_model;
         bool m_shuffleTrainingData = true;
         TrainingDataView m_trainData;
@@ -69,10 +63,15 @@ namespace MachineLearning
         float m_learningRateDecay = 0.0f;
         float m_earlyStopCost = 0.0f;
         AZStd::size_t m_currentIndex = 0;
-
         AZStd::unique_ptr<IInferenceContext> m_inferenceContext;
         AZStd::unique_ptr<ITrainingContext> m_trainingContext;
 
+    private:
+
+        //! Calculates the average cost of the provided model on the set of labeled test data using the requested loss function.
+        float ComputeCurrentCost(ILabeledTrainingData& testData, LossFunctions costFunction);
+        void ExecTraining();
+
         AZStd::unique_ptr<AZ::JobManager> m_trainingJobManager;
         AZStd::unique_ptr<AZ::JobContext> m_trainingjobContext;
 

+ 82 - 0
Gems/MachineLearning/Code/Source/Assets/ModelAsset.cpp

@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <Assets/ModelAsset.h>
+#include <AzNetworking/Serialization/NetworkInputSerializer.h>
+#include <AzNetworking/Serialization/NetworkOutputSerializer.h>
+
+namespace MachineLearning
+{
+    bool ModelAsset::Serialize(AzNetworking::ISerializer& serializer)
+    {
+        return serializer.Serialize(m_name, "Name")
+            && serializer.Serialize(m_activationCount, "activationCount")
+            && serializer.Serialize(m_layers, "layers");
+    }
+
+    AZStd::size_t ModelAsset::EstimateSerializeSize() const
+    {
+        const AZStd::size_t padding = 64; // 64 bytes of extra padding just in case
+        AZStd::size_t estimatedSize = padding
+            + sizeof(AZStd::size_t)
+            + m_name.size()
+            + sizeof(m_activationCount)
+            + sizeof(AZStd::size_t);
+        for (const Layer& layer : m_layers)
+        {
+            estimatedSize += layer.EstimateSerializeSize();
+        }
+        return estimatedSize;
+    }
+
+    ModelAssetHandler::ModelAssetHandler()
+        : AzFramework::GenericAssetHandler<ModelAsset>(ModelAsset::DisplayName, ModelAsset::Group, ModelAsset::Extension)
+    {
+    }
+
+    AZ::Data::AssetHandler::LoadResult ModelAssetHandler::LoadAssetData
+    (
+        const AZ::Data::Asset<AZ::Data::AssetData>& asset, 
+        AZStd::shared_ptr<AZ::Data::AssetDataStream> stream,
+        [[maybe_unused]]const AZ::Data::AssetFilterCB& assetLoadFilterCB
+    )
+    {
+        ModelAsset* assetData = asset.GetAs<ModelAsset>();
+        AZ_Assert(assetData, "Asset is of the wrong type.");
+
+        const AZ::IO::SizeType length = stream->GetLength();
+
+        AZStd::vector<uint8_t> serializeBuffer;
+        serializeBuffer.resize(length);
+        stream->Read(length, serializeBuffer.data());
+        AzNetworking::NetworkOutputSerializer serializer(serializeBuffer.data(), static_cast<uint32_t>(serializeBuffer.size()));
+        if (assetData->Serialize(serializer))
+        {
+            return AZ::Data::AssetHandler::LoadResult::LoadComplete;
+        }
+
+        return AZ::Data::AssetHandler::LoadResult::Error;
+    }
+
+    bool ModelAssetHandler::SaveAssetData(const AZ::Data::Asset<AZ::Data::AssetData>& asset, AZ::IO::GenericStream* stream)
+    {
+        ModelAsset* assetData = asset.GetAs<ModelAsset>();
+        AZ_Assert(assetData, "Asset is of the wrong type.");
+
+        AZStd::vector<uint8_t> serializeBuffer;
+        serializeBuffer.resize(assetData->EstimateSerializeSize());
+        AzNetworking::NetworkInputSerializer serializer(serializeBuffer.data(), static_cast<uint32_t>(serializeBuffer.size()));
+        if (assetData->Serialize(serializer))
+        {
+            stream->Write(serializer.GetSize(), serializeBuffer.data());
+            return true;
+        }
+
+        return false;
+    }
+}

+ 66 - 0
Gems/MachineLearning/Code/Source/Assets/ModelAsset.h

@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#pragma once
+
+#include <AzCore/Asset/AssetCommon.h>
+#include <AzCore/Asset/AssetSerializer.h>
+#include <AzNetworking/Serialization/ISerializer.h>
+#include <AzFramework/Asset/GenericAssetHandler.h>
+#include <Models/Layer.h>
+
+namespace MachineLearning
+{
+    class ModelAsset final
+        : public AZ::Data::AssetData
+    {
+    public:
+        static constexpr inline const char* DisplayName = "ModelAsset";
+        static constexpr inline const char* Extension = "mlmodel";
+        static constexpr inline const char* Group = "MachineLearning";
+
+        AZ_RTTI(ModelAsset, "{4D8D3782-DC3A-499A-A59D-542B85F5EDE9}", AZ::Data::AssetData);
+        AZ_CLASS_ALLOCATOR(ModelAsset, AZ::SystemAllocator);
+
+        ~ModelAsset() = default;
+
+        //! Base serialize method for all serializable structures or classes to implement.
+        //! @param serializer ISerializer instance to use for serialization
+        //! @return boolean true for success, false for serialization failure
+        bool Serialize(AzNetworking::ISerializer& serializer);
+
+        //! Returns the estimated size required to serialize this model.
+        AZStd::size_t EstimateSerializeSize() const;
+
+        //! The model name.
+        AZStd::string m_name;
+
+        //! The number of neurons in the activation layer.
+        AZStd::size_t m_activationCount = 0;
+
+        //! The set of layers in the network.
+        AZStd::vector<Layer> m_layers;
+    };
+
+    class ModelAssetHandler final
+        : public AzFramework::GenericAssetHandler<ModelAsset>
+    {
+    public:
+        ModelAssetHandler();
+
+    private:
+        AZ::Data::AssetHandler::LoadResult LoadAssetData
+        (
+            const AZ::Data::Asset<AZ::Data::AssetData>& asset, 
+            AZStd::shared_ptr<AZ::Data::AssetDataStream> stream,
+            const AZ::Data::AssetFilterCB& assetLoadFilterCB
+        ) override;
+
+        bool SaveAssetData(const AZ::Data::Asset<AZ::Data::AssetData>& asset, AZ::IO::GenericStream* stream) override;
+    };
+}

+ 0 - 30
Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugTrainingWindow.cpp

@@ -126,26 +126,6 @@ namespace MachineLearning
         }
     }
 
-    void MachineLearningDebugTrainingWindow::DrawLayerParameters(TrainingInstance* trainingInstance, AZStd::size_t layerIndex)
-    {
-        if (trainingInstance->m_layerWeights.size() < layerIndex)
-        {
-            trainingInstance->m_layerWeights.resize(layerIndex + 1);
-            trainingInstance->m_layerWeights[layerIndex].Init("Weights", 250, ImGui::LYImGuiUtils::HistogramContainer::ViewType::Histogram, true, 0.0f, 1.0f, ImGui::LYImGuiUtils::HistogramContainer::AutoExpand);
-        }
-
-        if (trainingInstance->m_layerBiases.size() < layerIndex)
-        {
-            trainingInstance->m_layerBiases.resize(layerIndex + 1);
-            trainingInstance->m_layerBiases[layerIndex].Init("Biases", 250, ImGui::LYImGuiUtils::HistogramContainer::ViewType::Histogram, true, 0.0f, 1.0f, ImGui::LYImGuiUtils::HistogramContainer::AutoExpand);
-        }
-
-        //m_selectedModel->GetLayerWeights(layerIter), m_selectedModel->GetLayerBiases(layerIter)
-
-        //trainingInstance->m_layerWeights[layerIndex].Draw(ImGui::GetColumnWidth(), 200.0f);
-        //trainingInstance->m_layerBiases[layerIndex].Draw(ImGui::GetColumnWidth(), 200.0f);
-    }
-
     void DrawDataPanel(TrainingDataView& data, AZStd::string& dataName, AZStd::string& labelName)
     {
         ImGui::PushID(&data);
@@ -357,16 +337,6 @@ namespace MachineLearning
                 DrawDataPanel(trainingInstance->m_trainingCycle.m_trainData, trainingInstance->m_trainDataName, trainingInstance->m_trainLabelName);
             }
 
-            //for (AZStd::size_t layerIter = 0; layerIter < m_selectedModel->GetLayerCount(); ++layerIter)
-            //{
-            //    AZStd::fixed_string<64> name;
-            //    name = AZStd::string::format("Layer %u Parameters", static_cast<uint32_t>(layerIter));
-            //    if (ImGui::CollapsingHeader(name.c_str()))
-            //    {
-            //        DrawLayerParameters(trainingInstance, layerIter);
-            //    }
-            //}
-
             ImGui::PopItemWidth();
             ImGui::PopStyleVar();
         }

+ 0 - 4
Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugTrainingWindow.h

@@ -39,9 +39,6 @@ namespace MachineLearning
 #ifdef IMGUI_ENABLED
         ImGui::LYImGuiUtils::HistogramContainer m_testHistogram;
         ImGui::LYImGuiUtils::HistogramContainer m_trainHistogram;
-
-        AZStd::vector<ImGui::LYImGuiUtils::HistogramContainer> m_layerWeights;
-        AZStd::vector<ImGui::LYImGuiUtils::HistogramContainer> m_layerBiases;
 #endif
     };
 
@@ -55,7 +52,6 @@ namespace MachineLearning
         void RecalculateAccuracy(TrainingInstance* trainingInstance, ILabeledTrainingData& data);
 
 #ifdef IMGUI_ENABLED
-        void DrawLayerParameters(TrainingInstance* trainingInstance, AZStd::size_t layerIndex);
         void OnImGuiUpdate();
 #endif
 

+ 2 - 0
Gems/MachineLearning/Code/Source/MachineLearningSystemComponent.cpp

@@ -104,10 +104,12 @@ namespace MachineLearning
     void MachineLearningSystemComponent::Activate()
     {
         MachineLearningRequestBus::Handler::BusConnect();
+        m_assetHandler.Register();
     }
 
     void MachineLearningSystemComponent::Deactivate()
     {
+        m_assetHandler.Unregister();
         MachineLearningRequestBus::Handler::BusDisconnect();
     }
 

+ 2 - 1
Gems/MachineLearning/Code/Source/MachineLearningSystemComponent.h

@@ -10,13 +10,13 @@
 
 #include <AzCore/Component/Component.h>
 #include <MachineLearning/IMachineLearning.h>
+#include <Assets/ModelAsset.h>
 
 namespace MachineLearning
 {
     class MachineLearningSystemComponent
         : public AZ::Component
         , protected MachineLearningRequestBus::Handler
-//        , public AZ::Interface<IMachineLearning>::Registrar
     {
     public:
         AZ_COMPONENT_DECL(MachineLearningSystemComponent);
@@ -50,5 +50,6 @@ namespace MachineLearning
     private:
 
         ModelSet m_registeredModels;
+        ModelAssetHandler m_assetHandler;
     };
 }

+ 30 - 11
Gems/MachineLearning/Code/Source/Models/Layer.cpp

@@ -20,6 +20,31 @@
 namespace MachineLearning
 {
     AZ_CVAR(bool, ml_logGradients, false, nullptr, AZ::ConsoleFunctorFlags::Null, "Dumps some gradient metrics so they can be monitored during training");
+    AZ_CVAR(bool, ml_logGradientsVerbose, false, nullptr, AZ::ConsoleFunctorFlags::Null, "Dumps complete gradient values to the console for examination, this can be a significant amount of data");
+
+    void DumpVectorGradients(const AZ::VectorN& value, const char* label)
+    {
+        AZStd::string vectorString(label);
+        for (AZStd::size_t iter = 0; iter < value.GetDimensionality(); ++iter)
+        {
+            vectorString += AZStd::string::format(" %.02f", value.GetElement(iter));
+        }
+        AZLOG_INFO(vectorString.c_str());
+    }
+
+    void DumpMatrixGradients(const AZ::MatrixMxN& value, const char* label)
+    {
+        for (AZStd::size_t i = 0; i < value.GetRowCount(); ++i)
+        {
+            AZStd::string rowString(label);
+            rowString += AZStd::string::format(":%u", static_cast<uint32_t>(i));
+            for (AZStd::size_t j = 0; j < value.GetColumnCount(); ++j)
+            {
+                rowString += AZStd::string::format(" %.02f", value.GetElement(i, j));
+            }
+            AZLOG_INFO(rowString.c_str());
+        }
+    }
 
     void AccumulateBiasGradients(AZ::VectorN& biasGradients, const AZ::VectorN& activationGradients, AZStd::size_t currentSamples)
     {
@@ -203,18 +228,12 @@ namespace MachineLearning
 
             GetMinMaxElements(trainingData.m_backpropagationGradients, min, max);
             AZLOG_INFO("Back-propagation gradients: min value %f, max value %f", min, max);
+        }
 
-            //for (AZStd::size_t i = 0; i < trainingData.m_weightGradients.GetRowCount(); ++i)
-            //{
-            //    for (AZStd::size_t j = 0; j < trainingData.m_weightGradients.GetColumnCount(); ++j)
-            //    {
-            //        AZLOG_INFO("Weight %ux%u : %f", i, j, trainingData.m_weightGradients.GetElement(i, j));
-            //    }
-            //}
-            //for (AZStd::size_t i = 0; i < trainingData.m_biasGradients.GetDimensionality(); ++i)
-            //{
-            //    AZLOG_INFO("Bias %u : %f", i, trainingData.m_biasGradients.GetElement(i));
-            //}
+        if (ml_logGradientsVerbose)
+        {
+            DumpMatrixGradients(trainingData.m_weightGradients, "WeightGradients");
+            DumpVectorGradients(trainingData.m_biasGradients, "BiasGradients");
         }
     }
 

+ 2 - 12
Gems/MachineLearning/Code/Source/Models/MultilayerPerceptron.cpp

@@ -27,6 +27,7 @@ namespace MachineLearning
         {
             serializeContext->Class<MultilayerPerceptron>()
                 ->Version(1)
+                ->Field("ModelAsset", &MultilayerPerceptron::m_asset)
                 ->Field("Name", &MultilayerPerceptron::m_name)
                 ->Field("ModelFile", &MultilayerPerceptron::m_modelFile)
                 ->Field("TestDataFile", &MultilayerPerceptron::m_testDataFile)
@@ -41,6 +42,7 @@ namespace MachineLearning
             {
                 editContext->Class<MultilayerPerceptron>("A basic multilayer perceptron class", "")
                     ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
+                    ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_asset, "ModelAsset", "The model asset")
                     ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_name, "Name", "The name for this model")
                     ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_modelFile, "ModelFile", "The file this model is saved to and loaded from")
                     ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_testDataFile, "TestDataFile", "The file test data should be loaded from")
@@ -140,7 +142,6 @@ namespace MachineLearning
 
     AZStd::size_t MultilayerPerceptron::GetOutputDimensionality() const
     {
-        //AZStd::lock_guard lock(m_mutex);
         if (!m_layers.empty())
         {
             return m_layers.back().m_biases.GetDimensionality();
@@ -150,25 +151,21 @@ namespace MachineLearning
 
     AZStd::size_t MultilayerPerceptron::GetLayerCount() const
     {
-        //AZStd::lock_guard lock(m_mutex);
         return m_layers.size();
     }
 
     AZ::MatrixMxN MultilayerPerceptron::GetLayerWeights(AZStd::size_t layerIndex) const
     {
-        //AZStd::lock_guard lock(m_mutex);
         return m_layers[layerIndex].m_weights;
     }
 
     AZ::VectorN MultilayerPerceptron::GetLayerBiases(AZStd::size_t layerIndex) const
     {
-        //AZStd::lock_guard lock(m_mutex);
         return m_layers[layerIndex].m_biases;
     }
 
     AZStd::size_t MultilayerPerceptron::GetParameterCount() const
     {
-        //AZStd::lock_guard lock(m_mutex);
         AZStd::size_t parameterCount = 0;
         for (const Layer& layer : m_layers)
         {
@@ -189,7 +186,6 @@ namespace MachineLearning
 
     const AZ::VectorN* MultilayerPerceptron::Forward(IInferenceContextPtr context, const AZ::VectorN& activations)
     {
-        //AZStd::lock_guard lock(m_mutex);
         MlpInferenceContext* forwardContext = static_cast<MlpInferenceContext*>(context);
         forwardContext->m_layerData.resize(m_layers.size());
 
@@ -204,7 +200,6 @@ namespace MachineLearning
 
     void MultilayerPerceptron::Reverse(ITrainingContextPtr context, LossFunctions lossFunction, const AZ::VectorN& activations, const AZ::VectorN& expected)
     {
-        //AZStd::lock_guard lock(m_mutex);
         MlpTrainingContext* reverseContext = static_cast<MlpTrainingContext*>(context);
         MlpInferenceContext* forwardContext = &reverseContext->m_forward;
         reverseContext->m_layerData.resize(m_layers.size());
@@ -236,11 +231,9 @@ namespace MachineLearning
 
     void MultilayerPerceptron::GradientDescent(ITrainingContextPtr context, float learningRate)
     {
-        //AZStd::lock_guard lock(m_mutex);
         MlpTrainingContext* reverseContext = static_cast<MlpTrainingContext*>(context);
         if (reverseContext->m_trainingSampleSize > 0)
         {
-            //const float adjustedLearningRate = learningRate / static_cast<float>(reverseContext->m_trainingSampleSize);
             for (AZStd::size_t iter = 0; iter < m_layers.size(); ++iter)
             {
                 m_layers[iter].ApplyGradients(reverseContext->m_layerData[iter], learningRate);
@@ -251,7 +244,6 @@ namespace MachineLearning
 
     void MultilayerPerceptron::OnActivationCountChanged()
     {
-        //AZStd::lock_guard lock(m_mutex);
         AZStd::size_t lastLayerDimensionality = m_activationCount;
         for (Layer& layer : m_layers)
         {
@@ -334,7 +326,6 @@ namespace MachineLearning
 
     bool MultilayerPerceptron::Serialize(AzNetworking::ISerializer& serializer)
     {
-        //AZStd::lock_guard lock(m_mutex);
         return serializer.Serialize(m_name, "Name")
             && serializer.Serialize(m_activationCount, "activationCount")
             && serializer.Serialize(m_layers, "layers");
@@ -348,7 +339,6 @@ namespace MachineLearning
             + m_name.size()
             + sizeof(m_activationCount)
             + sizeof(AZStd::size_t);
-        //AZStd::lock_guard lock(m_mutex);
         for (const Layer& layer : m_layers)
         {
             estimatedSize += layer.EstimateSerializeSize();

+ 4 - 0
Gems/MachineLearning/Code/Source/Models/MultilayerPerceptron.h

@@ -12,6 +12,7 @@
 #include <AzNetworking/Serialization/ISerializer.h>
 #include <MachineLearning/INeuralNetwork.h>
 #include <Models/Layer.h>
+#include <Assets/ModelAsset.h>
 
 namespace MachineLearning
 {
@@ -71,6 +72,9 @@ namespace MachineLearning
 
         void OnActivationCountChanged();
 
+        //! The model asset.
+        AZ::Data::Asset<ModelAsset> m_asset;
+
         //! The model name.
         AZStd::string m_name;
 

+ 0 - 1
Gems/MachineLearning/Code/Source/Nodes/SupervisedLearning.cpp

@@ -20,7 +20,6 @@ namespace MachineLearning
         INeuralNetworkPtr Model,
         ILabeledTrainingDataPtr TrainingData, 
         ILabeledTrainingDataPtr TestData,
-        //LossFunctions CostFunction, 
         AZStd::size_t CostFunction,
         AZStd::size_t TotalIterations,
         AZStd::size_t BatchSize,

+ 2 - 0
Gems/MachineLearning/Code/Tests/Algorithms/ActivationTests.cpp

@@ -87,8 +87,10 @@ namespace UnitTest
 
         // Additionally, the sum of all the elements should be <= 1, as softmax returns a probability distribution
         const float totalSum = output.L1Norm();
+
         // Between floating point precision and the estimates we use for exp(x), the total sum probability can be slightly greater than one
         // We add a small epsilon to account for this error
+        ASSERT_GE(totalSum, 1.0f - AZ::Constants::Tolerance);
         ASSERT_LE(totalSum, 1.0f + AZ::Constants::Tolerance);
     }
 

+ 2 - 0
Gems/MachineLearning/Code/machinelearning_private_files.cmake

@@ -19,6 +19,8 @@ set(FILES
     Source/Algorithms/Training.h
     Source/Assets/MnistDataLoader.cpp
     Source/Assets/MnistDataLoader.h
+    Source/Assets/ModelAsset.cpp
+    Source/Assets/ModelAsset.h
     Source/Assets/TrainingDataView.cpp
     Source/Assets/TrainingDataView.h
     Source/Components/MultilayerPerceptronComponent.cpp

+ 5 - 0
Gems/MachineLearning/Registry/assetprocessor_settings.setreg

@@ -11,6 +11,11 @@
                     "watch": "@GEMROOT:MachineLearning@/Registry",
                     "recursive": 1,
                     "order": 102
+                },
+                "RC mlasset": {
+                    "glob": "*.mlasset",
+                    "params": "copy",
+                    "productAssetType": "{4D8D3782-DC3A-499A-A59D-542B85F5EDE9}"
                 }
             }
         }