Преглед на файлове

Alter .mlfiles to be recognized azassets

Signed-off-by: kberg-amzn <[email protected]>
kberg-amzn преди 1 година
родител
ревизия
4ad254e90d
променени са 22 файла, в които са добавени 457 реда и са изтрити 137 реда
  1. 3 3
      Gems/MachineLearning/Code/CMakeLists.txt
  2. 8 1
      Gems/MachineLearning/Code/Include/MachineLearning/Types.h
  3. 0 2
      Gems/MachineLearning/Code/Source/Algorithms/Training.cpp
  4. 2 1
      Gems/MachineLearning/Code/Source/Assets/MnistDataLoader.cpp
  5. 15 0
      Gems/MachineLearning/Code/Source/Assets/ModelAsset.cpp
  6. 2 0
      Gems/MachineLearning/Code/Source/Assets/ModelAsset.h
  7. 57 19
      Gems/MachineLearning/Code/Source/Components/MultilayerPerceptronComponent.cpp
  8. 18 0
      Gems/MachineLearning/Code/Source/Components/MultilayerPerceptronComponent.h
  9. 1 6
      Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugSystemComponent.cpp
  10. 0 3
      Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugTrainingWindow.cpp
  11. 4 2
      Gems/MachineLearning/Code/Source/MachineLearningSystemComponent.cpp
  12. 1 1
      Gems/MachineLearning/Code/Source/MachineLearningSystemComponent.h
  13. 0 4
      Gems/MachineLearning/Code/Source/Models/Layer.cpp
  14. 15 75
      Gems/MachineLearning/Code/Source/Models/MultilayerPerceptron.cpp
  15. 4 15
      Gems/MachineLearning/Code/Source/Models/MultilayerPerceptron.h
  16. 9 0
      Gems/MachineLearning/Code/Source/Tools/MachineLearningEditorModule.cpp
  17. 7 2
      Gems/MachineLearning/Code/Source/Tools/MachineLearningEditorSystemComponent.cpp
  18. 7 1
      Gems/MachineLearning/Code/Source/Tools/MachineLearningEditorSystemComponent.h
  19. 223 0
      Gems/MachineLearning/Code/Source/Tools/MultilayerPerceptronEditorComponent.cpp
  20. 77 0
      Gems/MachineLearning/Code/Source/Tools/MultilayerPerceptronEditorComponent.h
  21. 2 0
      Gems/MachineLearning/Code/machinelearning_editor_private_files.cmake
  22. 2 2
      Gems/MachineLearning/Registry/assetprocessor_settings.setreg

+ 3 - 3
Gems/MachineLearning/Code/CMakeLists.txt

@@ -156,7 +156,7 @@ if(PAL_TRAIT_BUILD_HOST_TOOLS)
         BUILD_DEPENDENCIES
             PUBLIC
                 AZ::AzToolsFramework
-                $<TARGET_OBJECTS:Gem::${gem_name}.Private.Object>
+                Gem::${gem_name}.Private.Object
     )
 
     ly_add_target(
@@ -180,8 +180,8 @@ if(PAL_TRAIT_BUILD_HOST_TOOLS)
     # By default, we will specify that the above target ${gem_name} would be used by
     # Tool and Builder type targets when this gem is enabled.  If you don't want it
     # active in Tools or Builders by default, delete one of both of the following lines:
-    ly_create_alias(NAME ${gem_name}.Tools NAMESPACE Gem TARGETS Gem::${gem_name} Gem::${gem_name}.Debug Gem::ScriptCanvas.Editor)
-    ly_create_alias(NAME ${gem_name}.Builders NAMESPACE Gem TARGETS Gem::${gem_name} Gem::ScriptCanvas.Editor)
+    ly_create_alias(NAME ${gem_name}.Tools NAMESPACE Gem TARGETS Gem::${gem_name}.Editor Gem::${gem_name}.Debug Gem::ScriptCanvas.Editor)
+    ly_create_alias(NAME ${gem_name}.Builders NAMESPACE Gem TARGETS Gem::${gem_name}.Editor Gem::ScriptCanvas.Editor)
 
     # For the Tools and Builders variants of ${gem_name} Gem, an alias to the ${gem_name}.Editor API target will be made
     ly_create_alias(NAME ${gem_name}.Tools.API NAMESPACE Gem TARGETS Gem::${gem_name}.Editor.API)

+ 8 - 1
Gems/MachineLearning/Code/Include/MachineLearning/Types.h

@@ -25,10 +25,17 @@ namespace MachineLearning
     );
 
     AZ_ENUM_CLASS(AssetTypes,
-        Model,
         TestData,
         TestLabels,
         TrainingData, 
         TrainingLabels
     );
+
+    class IAssetPersistenceProxy
+    {
+    public:
+        virtual ~IAssetPersistenceProxy() = default;
+        virtual bool SaveAsset() = 0;
+        virtual bool LoadAsset() = 0;
+    };
 }

+ 0 - 2
Gems/MachineLearning/Code/Source/Algorithms/Training.cpp

@@ -63,8 +63,6 @@ namespace MachineLearning
     {
         InitializeContexts();
 
-        const AZStd::size_t totalTrainingSize = m_trainData.GetSampleCount();
-
         // Start training
         m_currentEpoch = 0;
         m_trainingComplete = false;

+ 2 - 1
Gems/MachineLearning/Code/Source/Assets/MnistDataLoader.cpp

@@ -18,7 +18,7 @@
 #include <AzCore/RTTI/BehaviorContext.h>
 #include <AzCore/Serialization/EditContext.h>
 #include <AzCore/Serialization/SerializeContext.h>
-
+#pragma optimize("", off)
 namespace MachineLearning
 {
     void MnistDataLoader::Reflect(AZ::ReflectContext* context)
@@ -195,3 +195,4 @@ namespace MachineLearning
         return true;
     }
 }
+#pragma optimize("", on)

+ 15 - 0
Gems/MachineLearning/Code/Source/Assets/ModelAsset.cpp

@@ -12,6 +12,21 @@
 
 namespace MachineLearning
 {
+    void ModelAsset::Reflect(AZ::ReflectContext* context)
+    {
+        if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
+        {
+            serializeContext->Class<ModelAsset>()
+                ->Version(1);
+
+            if (AZ::EditContext* editContext = serializeContext->GetEditContext())
+            {
+                editContext->Class<ModelAsset>("ML Model Asset", "ML Model Asset")
+                    ->ClassElement(AZ::Edit::ClassElements::EditorData, "");
+            }
+        }
+    }
+
     bool ModelAsset::Serialize(AzNetworking::ISerializer& serializer)
     {
         return serializer.Serialize(m_name, "Name")

+ 2 - 0
Gems/MachineLearning/Code/Source/Assets/ModelAsset.h

@@ -27,6 +27,8 @@ namespace MachineLearning
         AZ_RTTI(ModelAsset, "{4D8D3782-DC3A-499A-A59D-542B85F5EDE9}", AZ::Data::AssetData);
         AZ_CLASS_ALLOCATOR(ModelAsset, AZ::SystemAllocator);
 
+        static void Reflect(AZ::ReflectContext* context);
+
         ~ModelAsset() = default;
 
         //! Base serialize method for all serializable structures or classes to implement.

+ 57 - 19
Gems/MachineLearning/Code/Source/Components/MultilayerPerceptronComponent.cpp

@@ -14,6 +14,7 @@
 #include <AzCore/RTTI/BehaviorContext.h>
 #include <AzCore/Serialization/EditContext.h>
 #include <AzCore/Serialization/SerializeContext.h>
+#include <AzCore/Console/ILogger.h>
 
 namespace MachineLearning
 {
@@ -23,32 +24,21 @@ namespace MachineLearning
         {
             serializeContext->Class<MultilayerPerceptronComponent>()
                 ->Version(0)
+                ->Field("Asset", &MultilayerPerceptronComponent::m_asset)
                 ->Field("Model", &MultilayerPerceptronComponent::m_model)
                 ;
-
-            if (AZ::EditContext* editContext = serializeContext->GetEditContext())
-            {
-                editContext->Class<MultilayerPerceptronComponent>("Multilayer Perceptron", "")
-                    ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
-                    ->Attribute(AZ::Edit::Attributes::Category, "MachineLearning")
-                    ->Attribute(AZ::Edit::Attributes::Icon, "Editor/Icons/Components/NeuralNetwork.svg")
-                    ->Attribute(AZ::Edit::Attributes::ViewportIcon, "Editor/Icons/Components/Viewport/NeuralNetwork.svg")
-                    ->Attribute(AZ::Edit::Attributes::AppearsInAddComponentMenu, AZ_CRC_CE("Game"))
-                    ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptronComponent::m_model, "Model", "This is the machine-learning model provided by this component")
-                    ;
-            }
         }
 
         auto behaviorContext = azrtti_cast<AZ::BehaviorContext*>(context);
         if (behaviorContext)
         {
-            behaviorContext->Class<MultilayerPerceptronComponent>("MultilayerPerceptron Component")->
-                Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Common)->
-                Attribute(AZ::Script::Attributes::Module, "machineLearning")->
-                Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)->
-                Constructor<>()->
-                Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)->
-                Property("Model", BehaviorValueProperty(&MultilayerPerceptronComponent::m_model))
+            behaviorContext->Class<MultilayerPerceptronComponent>("MultilayerPerceptron Component")
+                ->Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Common)
+                ->Attribute(AZ::Script::Attributes::Module, "machineLearning")
+                ->Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)
+                ->Constructor<>()
+                ->Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)
+                ->Property("Model", BehaviorValueProperty(&MultilayerPerceptronComponent::m_model))
                 ;
 
             behaviorContext->EBus<MultilayerPerceptronComponentRequestBus>("Multilayer perceptron requests")
@@ -79,10 +69,12 @@ namespace MachineLearning
     void MultilayerPerceptronComponent::Activate()
     {
         MultilayerPerceptronComponentRequestBus::Handler::BusConnect(GetEntityId());
+        AssetChanged();
     }
 
     void MultilayerPerceptronComponent::Deactivate()
     {
+        AZ::Data::AssetBus::Handler::BusDisconnect();
         MultilayerPerceptronComponentRequestBus::Handler::BusDisconnect();
     }
 
@@ -90,4 +82,50 @@ namespace MachineLearning
     {
         return m_handle;
     }
+
+    void MultilayerPerceptronComponent::AssetChanged()
+    {
+        AZ::Data::AssetBus::Handler::BusDisconnect();
+        if (m_asset.GetStatus() == AZ::Data::AssetData::AssetStatus::Error ||
+            m_asset.GetStatus() == AZ::Data::AssetData::AssetStatus::NotLoaded)
+        {
+            m_asset.QueueLoad();
+        }
+        AZ::Data::AssetBus::Handler::BusConnect(m_asset.GetId());
+    }
+
+    void MultilayerPerceptronComponent::AssetCleared()
+    {
+        ;
+    }
+
+    void MultilayerPerceptronComponent::OnAssetReady(AZ::Data::Asset<AZ::Data::AssetData> asset)
+    {
+        ModelAsset* modelAsset = asset.GetAs<ModelAsset>();
+        if ((asset == m_asset) && (modelAsset != nullptr))
+        {
+            m_model = *modelAsset;
+        }
+    }
+
+    void MultilayerPerceptronComponent::OnAssetReloaded(AZ::Data::Asset<AZ::Data::AssetData> asset)
+    {
+        OnAssetReady(asset);
+    }
+
+    void MultilayerPerceptronComponent::OnAssetError(AZ::Data::Asset<AZ::Data::AssetData> asset)
+    {
+        if (asset == m_asset)
+        {
+            AZLOG_WARN("OnAssetError: %s", asset.GetHint().c_str());
+        }
+    }
+
+    void MultilayerPerceptronComponent::OnAssetReloadError(AZ::Data::Asset<AZ::Data::AssetData> asset)
+    {
+        if (asset == m_asset)
+        {
+            AZLOG_WARN("OnAssetReloadError: %s", asset.GetHint().c_str());
+        }
+    }
 }

+ 18 - 0
Gems/MachineLearning/Code/Source/Components/MultilayerPerceptronComponent.h

@@ -9,7 +9,9 @@
 #pragma once
 
 #include <AzCore/Component/Component.h>
+#include <AzCore/Asset/AssetCommon.h>
 #include <Models/MultilayerPerceptron.h>
+#include <Assets/ModelAsset.h>
 
 namespace MachineLearning
 {
@@ -25,6 +27,7 @@ namespace MachineLearning
 
     class MultilayerPerceptronComponent
         : public AZ::Component
+        , private AZ::Data::AssetBus::Handler
         , public MultilayerPerceptronComponentRequestBus::Handler
     {
     public:
@@ -52,7 +55,22 @@ namespace MachineLearning
 
     private:
 
+        // Edit context callbacks
+        void AssetChanged();
+        void AssetCleared();
+
+        // AZ::Data::AssetBus ...
+        void OnAssetReady(AZ::Data::Asset<AZ::Data::AssetData> asset) override;
+        void OnAssetReloaded(AZ::Data::Asset<AZ::Data::AssetData> asset) override;
+        void OnAssetError(AZ::Data::Asset<AZ::Data::AssetData> asset) override;
+        void OnAssetReloadError(AZ::Data::Asset<AZ::Data::AssetData> asset) override;
+
+        //! The model asset.
+        AZ::Data::Asset<ModelAsset> m_asset;
+
         MultilayerPerceptron m_model;
         INeuralNetworkPtr m_handle;
+
+        friend class MultilayerPerceptronEditorComponent;
     };
 }

+ 1 - 6
Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugSystemComponent.cpp

@@ -66,18 +66,15 @@ namespace MachineLearning
             | ImGuiTableFlags_RowBg
             | ImGuiTableFlags_NoBordersInBody;
 
-        const ImGuiTreeNodeFlags nodeFlags = (ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth);
-
         IMachineLearning* machineLearning = MachineLearningInterface::Get();
         const ModelSet& modelSet = machineLearning->GetModelSet();
 
         ImGui::Text("Total registered models: %u", static_cast<uint32_t>(modelSet.size()));
         ImGui::NewLine();
 
-        if (ImGui::BeginTable("Model Details", 6, flags))
+        if (ImGui::BeginTable("Model Details", 5, flags))
         {
             ImGui::TableSetupColumn("Name", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 32.0f);
-            ImGui::TableSetupColumn("File", ImGuiTableColumnFlags_WidthStretch);
             ImGui::TableSetupColumn("Input Neurons", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
             ImGui::TableSetupColumn("Output Neurons", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
             ImGui::TableSetupColumn("Layers", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
@@ -91,8 +88,6 @@ namespace MachineLearning
                 ImGui::TableNextColumn();
                 ImGui::Text(neuralNetwork->GetName().c_str());
                 ImGui::TableNextColumn();
-                ImGui::Text(neuralNetwork->GetAssetFile(AssetTypes::Model).c_str());
-                ImGui::TableNextColumn();
                 ImGui::Text("%lld", aznumeric_cast<AZ::s64>(neuralNetwork->GetInputDimensionality()));
                 ImGui::TableNextColumn();
                 ImGui::Text("%lld", aznumeric_cast<AZ::s64>(neuralNetwork->GetOutputDimensionality()));

+ 0 - 3
Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugTrainingWindow.cpp

@@ -151,8 +151,6 @@ namespace MachineLearning
             | ImGuiTableFlags_RowBg
             | ImGuiTableFlags_NoBordersInBody;
 
-        const ImGuiTreeNodeFlags nodeFlags = (ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth);
-
         IMachineLearning* machineLearning = MachineLearningInterface::Get();
         const ModelSet& modelSet = machineLearning->GetModelSet();
 
@@ -262,7 +260,6 @@ namespace MachineLearning
  
             ImGui::NewLine();
             ImGui::Text("Model Name: %s", m_selectedModel->GetName().c_str());
-            ImGui::Text("Asset location: %s", m_selectedModel->GetAssetFile(AssetTypes::Model).c_str());
 
             if (ImGui::BeginTable("Accuracy", 2, flags))
             {

+ 4 - 2
Gems/MachineLearning/Code/Source/MachineLearningSystemComponent.cpp

@@ -59,6 +59,7 @@ namespace MachineLearning
         }
 
         Layer::Reflect(context);
+        ModelAsset::Reflect(context);
         MnistDataLoader::Reflect(context);
         MultilayerPerceptron::Reflect(context);
     }
@@ -104,12 +105,13 @@ namespace MachineLearning
     void MachineLearningSystemComponent::Activate()
     {
         MachineLearningRequestBus::Handler::BusConnect();
-        m_assetHandler.Register();
+        m_assetHandler = AZStd::make_unique<ModelAssetHandler>();
+        m_assetHandler->Register();
     }
 
     void MachineLearningSystemComponent::Deactivate()
     {
-        m_assetHandler.Unregister();
+        m_assetHandler->Unregister();
         MachineLearningRequestBus::Handler::BusDisconnect();
     }
 

+ 1 - 1
Gems/MachineLearning/Code/Source/MachineLearningSystemComponent.h

@@ -50,6 +50,6 @@ namespace MachineLearning
     private:
 
         ModelSet m_registeredModels;
-        ModelAssetHandler m_assetHandler;
+        AZStd::unique_ptr<ModelAssetHandler> m_assetHandler;
     };
 }

+ 0 - 4
Gems/MachineLearning/Code/Source/Models/Layer.cpp

@@ -133,10 +133,7 @@ namespace MachineLearning
         {
             serializeContext->Class<Layer>()
                 ->Version(1)
-                ->Field("InputSize", &Layer::m_inputSize)
                 ->Field("OutputSize", &Layer::m_outputSize)
-                ->Field("Weights", &Layer::m_weights)
-                ->Field("Biases", &Layer::m_biases)
                 ->Field("ActivationFunction", &Layer::m_activationFunction)
                 ;
 
@@ -161,7 +158,6 @@ namespace MachineLearning
                 Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)->
                 Constructor<ActivationFunctions, AZStd::size_t, AZStd::size_t>()->
                 Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)->
-                Property("InputSize", BehaviorValueProperty(&Layer::m_inputSize))->
                 Property("OutputSize", BehaviorValueProperty(&Layer::m_outputSize))->
                 Property("ActivationFunction", BehaviorValueProperty(&Layer::m_activationFunction))
                 ;

+ 15 - 75
Gems/MachineLearning/Code/Source/Models/MultilayerPerceptron.cpp

@@ -27,9 +27,7 @@ namespace MachineLearning
         {
             serializeContext->Class<MultilayerPerceptron>()
                 ->Version(1)
-                ->Field("ModelAsset", &MultilayerPerceptron::m_asset)
                 ->Field("Name", &MultilayerPerceptron::m_name)
-                ->Field("ModelFile", &MultilayerPerceptron::m_modelFile)
                 ->Field("TestDataFile", &MultilayerPerceptron::m_testDataFile)
                 ->Field("TestLabelFile", &MultilayerPerceptron::m_testLabelFile)
                 ->Field("TrainDataFile", &MultilayerPerceptron::m_trainDataFile)
@@ -42,9 +40,7 @@ namespace MachineLearning
             {
                 editContext->Class<MultilayerPerceptron>("A basic multilayer perceptron class", "")
                     ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
-                    ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_asset, "ModelAsset", "The model asset")
                     ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_name, "Name", "The name for this model")
-                    ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_modelFile, "ModelFile", "The file this model is saved to and loaded from")
                     ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_testDataFile, "TestDataFile", "The file test data should be loaded from")
                     ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_testLabelFile, "TestLabelFile", "The file test labels should be loaded from")
                     ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptron::m_trainDataFile, "TrainDataFile", "The file training data should be loaded from")
@@ -80,7 +76,6 @@ namespace MachineLearning
 
     MultilayerPerceptron::MultilayerPerceptron(const MultilayerPerceptron& rhs)
         : m_name(rhs.m_name)
-        , m_modelFile(rhs.m_modelFile)
         , m_testDataFile(rhs.m_testDataFile)
         , m_testLabelFile(rhs.m_testLabelFile)
         , m_trainDataFile(rhs.m_trainDataFile)
@@ -102,13 +97,22 @@ namespace MachineLearning
     MultilayerPerceptron& MultilayerPerceptron::operator=(const MultilayerPerceptron& rhs)
     {
         m_name = rhs.m_name;
-        m_modelFile = rhs.m_modelFile;
         m_testDataFile = rhs.m_testDataFile;
         m_testLabelFile = rhs.m_testLabelFile;
         m_trainDataFile = rhs.m_trainDataFile;
         m_trainLabelFile = rhs.m_trainLabelFile;
         m_activationCount = rhs.m_activationCount;
         m_layers = rhs.m_layers;
+        OnActivationCountChanged();
+        return *this;
+    }
+
+    MultilayerPerceptron& MultilayerPerceptron::operator=(const ModelAsset& asset)
+    {
+        m_name = asset.m_name;
+        m_activationCount = asset.m_activationCount;
+        m_layers = asset.m_layers;
+        OnActivationCountChanged();
         return *this;
     }
 
@@ -121,8 +125,6 @@ namespace MachineLearning
     {
         switch (assetType)
         {
-        case AssetTypes::Model:
-            return m_modelFile;
         case AssetTypes::TestData:
             return m_testDataFile;
         case AssetTypes::TestLabels:
@@ -255,59 +257,19 @@ namespace MachineLearning
 
     bool MultilayerPerceptron::LoadModel()
     {
-        AZ::IO::SystemFile modelFile;
-        AZ::IO::FixedMaxPath filePathFixed = m_modelFile.c_str();
-        if (AZ::IO::FileIOBase* fileIOBase = AZ::IO::FileIOBase::GetInstance())
-        {
-            fileIOBase->ResolvePath(filePathFixed, m_modelFile.c_str());
-        }
-
-        if (!modelFile.Open(filePathFixed.c_str(), AZ::IO::SystemFile::SF_OPEN_READ_ONLY))
+        if (m_proxy)
         {
-            AZLOG_ERROR("Failed to load '%s'. File could not be opened.", filePathFixed.c_str());
-            return false;
+            return m_proxy->LoadAsset();
         }
-
-        const AZ::IO::SizeType length = modelFile.Length();
-        if (length == 0)
-        {
-            AZLOG_ERROR("Failed to load '%s'. File is empty.", filePathFixed.c_str());
-            return false;
-        }
-
-        AZStd::vector<uint8_t> serializeBuffer;
-        serializeBuffer.resize(length);
-        modelFile.Seek(0, AZ::IO::SystemFile::SF_SEEK_BEGIN);
-        modelFile.Read(serializeBuffer.size(), serializeBuffer.data());
-        AzNetworking::NetworkOutputSerializer serializer(serializeBuffer.data(), static_cast<uint32_t>(serializeBuffer.size()));
-        return Serialize(serializer);
+        return false;
     }
 
     bool MultilayerPerceptron::SaveModel()
     {
-        AZ::IO::SystemFile modelFile;
-        AZ::IO::FixedMaxPath filePathFixed = m_modelFile.c_str();
-        if (AZ::IO::FileIOBase* fileIOBase = AZ::IO::FileIOBase::GetInstance())
-        {
-            fileIOBase->ResolvePath(filePathFixed, m_modelFile.c_str());
-        }
-
-        if (!modelFile.Open(filePathFixed.c_str(), AZ::IO::SystemFile::SF_OPEN_READ_WRITE | AZ::IO::SystemFile::SF_OPEN_CREATE))
-        {
-            AZLOG_ERROR("Failed to save to '%s'. File could not be opened for writing.", filePathFixed.c_str());
-            return false;
-        }
-        modelFile.Seek(0, AZ::IO::SystemFile::SF_SEEK_BEGIN);
-
-        AZStd::vector<uint8_t> serializeBuffer;
-        serializeBuffer.resize(EstimateSerializeSize());
-        AzNetworking::NetworkInputSerializer serializer(serializeBuffer.data(), static_cast<uint32_t>(serializeBuffer.size()));
-        if (Serialize(serializer))
+        if (m_proxy)
         {
-            modelFile.Write(serializeBuffer.data(), serializeBuffer.size());
-            return true;
+            return m_proxy->SaveAsset();
         }
-
         return false;
     }
 
@@ -323,26 +285,4 @@ namespace MachineLearning
         // This is not thread safe, this method should only be used by unit testing to inspect layer weights and biases for correctness
         return &m_layers[layerIndex];
     }
-
-    bool MultilayerPerceptron::Serialize(AzNetworking::ISerializer& serializer)
-    {
-        return serializer.Serialize(m_name, "Name")
-            && serializer.Serialize(m_activationCount, "activationCount")
-            && serializer.Serialize(m_layers, "layers");
-    }
-
-    AZStd::size_t MultilayerPerceptron::EstimateSerializeSize() const
-    {
-        const AZStd::size_t padding = 64; // 64 bytes of extra padding just in case
-        AZStd::size_t estimatedSize = padding 
-            + sizeof(AZStd::size_t)
-            + m_name.size()
-            + sizeof(m_activationCount)
-            + sizeof(AZStd::size_t);
-        for (const Layer& layer : m_layers)
-        {
-            estimatedSize += layer.EstimateSerializeSize();
-        }
-        return estimatedSize;
-    }
 }

+ 4 - 15
Gems/MachineLearning/Code/Source/Models/MultilayerPerceptron.h

@@ -9,7 +9,6 @@
 #pragma once
 
 #include <AzCore/Math/MatrixMxN.h>
-#include <AzNetworking/Serialization/ISerializer.h>
 #include <MachineLearning/INeuralNetwork.h>
 #include <Models/Layer.h>
 #include <Assets/ModelAsset.h>
@@ -34,6 +33,7 @@ namespace MachineLearning
         virtual ~MultilayerPerceptron();
 
         MultilayerPerceptron& operator=(const MultilayerPerceptron&);
+        MultilayerPerceptron& operator=(const ModelAsset&);
 
         //! INeuralNetwork interface
         //! @{
@@ -60,27 +60,13 @@ namespace MachineLearning
         //! Retrieves a specific layer from the model, this is not thread safe and should only be used during unit testing to validate model parameters.
         Layer* GetLayer(AZStd::size_t layerIndex);
 
-        //! Base serialize method for all serializable structures or classes to implement.
-        //! @param serializer ISerializer instance to use for serialization
-        //! @return boolean true for success, false for serialization failure
-        bool Serialize(AzNetworking::ISerializer& serializer);
-
-        //! Returns the estimated size required to serialize this model.
-        AZStd::size_t EstimateSerializeSize() const;
-
     private:
 
         void OnActivationCountChanged();
 
-        //! The model asset.
-        AZ::Data::Asset<ModelAsset> m_asset;
-
         //! The model name.
         AZStd::string m_name;
 
-        //! The model asset file.
-        AZStd::string m_modelFile;
-
         //! Optional test and train asset data files.
         AZStd::string m_testDataFile;
         AZStd::string m_testLabelFile;
@@ -92,6 +78,9 @@ namespace MachineLearning
 
         //! The set of layers in the network.
         AZStd::vector<Layer> m_layers;
+
+        IAssetPersistenceProxy* m_proxy = nullptr;
+        friend class MultilayerPerceptronEditorComponent;
     };
 
     struct MlpInferenceContext

+ 9 - 0
Gems/MachineLearning/Code/Source/Tools/MachineLearningEditorModule.cpp

@@ -1,6 +1,14 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
 
 #include <MachineLearning/MachineLearningTypeIds.h>
 #include <MachineLearningModuleInterface.h>
+#include <Tools/MultilayerPerceptronEditorComponent.h>
 #include "MachineLearningEditorSystemComponent.h"
 
 namespace MachineLearning
@@ -20,6 +28,7 @@ namespace MachineLearning
             // This happens through the [MyComponent]::Reflect() function.
             m_descriptors.insert(m_descriptors.end(), {
                 MachineLearningEditorSystemComponent::CreateDescriptor(),
+                MultilayerPerceptronEditorComponent::CreateDescriptor()
             });
         }
 

+ 7 - 2
Gems/MachineLearning/Code/Source/Tools/MachineLearningEditorSystemComponent.cpp

@@ -1,7 +1,13 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
 
 #include <AzCore/Serialization/SerializeContext.h>
 #include "MachineLearningEditorSystemComponent.h"
-
 #include <MachineLearning/MachineLearningTypeIds.h>
 
 namespace MachineLearning
@@ -55,5 +61,4 @@ namespace MachineLearning
         AzToolsFramework::EditorEvents::Bus::Handler::BusDisconnect();
         MachineLearningSystemComponent::Deactivate();
     }
-
 } // namespace MachineLearning

+ 7 - 1
Gems/MachineLearning/Code/Source/Tools/MachineLearningEditorSystemComponent.h

@@ -1,8 +1,14 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
 
 #pragma once
 
 #include <AzToolsFramework/API/ToolsApplicationAPI.h>
-
 #include <MachineLearningSystemComponent.h>
 
 namespace MachineLearning

+ 223 - 0
Gems/MachineLearning/Code/Source/Tools/MultilayerPerceptronEditorComponent.cpp

@@ -0,0 +1,223 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#pragma once
+
+#include <Tools/MultilayerPerceptronEditorComponent.h>
+#include <Components/MultilayerPerceptronComponent.h>
+#include <MachineLearning/IMachineLearning.h>
+#include <AzCore/RTTI/RTTI.h>
+#include <AzCore/RTTI/BehaviorContext.h>
+#include <AzCore/Serialization/EditContext.h>
+#include <AzCore/Serialization/SerializeContext.h>
+#include <AzCore/Settings/SettingsRegistryMergeUtils.h>
+#include <AzCore/Console/ILogger.h>
+#include <AzToolsFramework/API/ToolsApplicationAPI.h>
+#include <AzToolsFramework/API/EditorAssetSystemAPI.h>
+#include <AzToolsFramework/UI/UICore/WidgetHelpers.h>
+#include <AzQtComponents/Components/Widgets/FileDialog.h>
+#include <QMessageBox>
+
+namespace MachineLearning
+{
+    void MultilayerPerceptronEditorComponent::Reflect(AZ::ReflectContext* context)
+    {
+        if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
+        {
+            serializeContext->Class<MultilayerPerceptronEditorComponent>()
+                ->Version(0)
+                ->Field("Asset", &MultilayerPerceptronEditorComponent::m_asset)
+                ->Field("Model", &MultilayerPerceptronEditorComponent::m_model)
+                ;
+
+            if (AZ::EditContext* editContext = serializeContext->GetEditContext())
+            {
+                editContext
+                    ->Class<MultilayerPerceptronEditorComponent>("Multilayer Perceptron", "")
+                        ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
+                            ->Attribute(AZ::Edit::Attributes::Category, "MachineLearning")
+                            ->Attribute(AZ::Edit::Attributes::Icon, "Editor/Icons/Components/NeuralNetwork.svg")
+                            ->Attribute(AZ::Edit::Attributes::ViewportIcon, "Editor/Icons/Components/Viewport/NeuralNetwork.svg")
+                            ->Attribute(AZ::Edit::Attributes::AppearsInAddComponentMenu, AZ_CRC_CE("Game"))
+                        ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptronEditorComponent::m_asset, "Asset", "This is the asset file the model is persisted to")
+                            ->Attribute(AZ::Edit::Attributes::ChangeNotify, &MultilayerPerceptronEditorComponent::AssetChanged)
+                            ->Attribute(AZ::Edit::Attributes::ClearNotify, &MultilayerPerceptronEditorComponent::AssetCleared)
+                        ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptronEditorComponent::m_model, "Model", "This is the machine-learning model provided by this component");
+            }
+        }
+    }
+
+    void MultilayerPerceptronEditorComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
+    {
+        provided.push_back(AZ_CRC("MultilayerPerceptronService"));
+    }
+
+    MultilayerPerceptronEditorComponent::MultilayerPerceptronEditorComponent()
+    {
+        m_model.m_proxy = this;
+        m_handle.reset(&m_model);
+        MachineLearningInterface::Get()->RegisterModel(m_handle);
+    }
+
+    MultilayerPerceptronEditorComponent::~MultilayerPerceptronEditorComponent()
+    {
+        MachineLearningInterface::Get()->UnregisterModel(m_handle);
+    }
+
+    void MultilayerPerceptronEditorComponent::Activate()
+    {
+        AssetChanged();
+    }
+
+    void MultilayerPerceptronEditorComponent::Deactivate()
+    {
+        AZ::Data::AssetBus::Handler::BusDisconnect();
+    }
+
+    void MultilayerPerceptronEditorComponent::BuildGameEntity(AZ::Entity* gameEntity)
+    {
+        MultilayerPerceptronComponent* component = gameEntity->CreateComponent<MultilayerPerceptronComponent>();
+        component->m_asset = m_asset;
+    }
+
+    bool MultilayerPerceptronEditorComponent::SaveAsset()
+    {
+        return SaveAsAsset();
+    }
+
+    bool MultilayerPerceptronEditorComponent::LoadAsset()
+    {
+        m_asset.QueueLoad();
+        return true;
+    }
+
+    void MultilayerPerceptronEditorComponent::AssetChanged()
+    {
+        AZ::Data::AssetBus::Handler::BusDisconnect();
+        if (m_asset.GetStatus() == AZ::Data::AssetData::AssetStatus::Error ||
+            m_asset.GetStatus() == AZ::Data::AssetData::AssetStatus::NotLoaded)
+        {
+            m_asset.QueueLoad();
+        }
+        AZ::Data::AssetBus::Handler::BusConnect(m_asset.GetId());
+    }
+
+    void MultilayerPerceptronEditorComponent::AssetCleared()
+    {
+        ;
+    }
+
+    void MultilayerPerceptronEditorComponent::OnAssetReady(AZ::Data::Asset<AZ::Data::AssetData> asset)
+    {
+        ModelAsset* modelAsset = asset.GetAs<ModelAsset>();
+        if ((asset == m_asset) && (modelAsset != nullptr))
+        {
+            m_model = *modelAsset;
+            AzToolsFramework::ToolsApplicationNotificationBus::Broadcast
+            (
+                &AzToolsFramework::ToolsApplicationNotificationBus::Events::InvalidatePropertyDisplay, 
+                AzToolsFramework::Refresh_EntireTree
+            );
+        }
+    }
+
+    void MultilayerPerceptronEditorComponent::OnAssetReloaded(AZ::Data::Asset<AZ::Data::AssetData> asset)
+    {
+        OnAssetReady(asset);
+    }
+
+    void MultilayerPerceptronEditorComponent::OnAssetError(AZ::Data::Asset<AZ::Data::AssetData> asset)
+    {
+        if (asset == m_asset)
+        {
+            AZLOG_WARN("OnAssetError: %s", asset.GetHint().c_str());
+        }
+    }
+
+    void MultilayerPerceptronEditorComponent::OnAssetReloadError(AZ::Data::Asset<AZ::Data::AssetData> asset)
+    {
+        if (asset == m_asset)
+        {
+            AZLOG_WARN("OnAssetReloadError: %s", asset.GetHint().c_str());
+        }
+    }
+
+    static AZStd::string PathAtProjectRoot(const AZStd::string_view name, const AZStd::string_view extension)
+    {
+        AZ::IO::Path projectPath;
+        if (auto settingsRegistry = AZ::SettingsRegistry::Get(); settingsRegistry != nullptr)
+        {
+            settingsRegistry->Get(projectPath.Native(), AZ::SettingsRegistryMergeUtils::FilePathKey_ProjectPath);
+        }
+        projectPath /= AZ::IO::FixedMaxPathString::format("%.*s.%.*s", AZ_STRING_ARG(name), AZ_STRING_ARG(extension));
+        return projectPath.Native();
+    }
+
+    template <typename T>
+    AZ::Data::Asset<T> CreateOrFindAsset(const AZStd::string& assetPath, AZ::Data::AssetLoadBehavior loadBehavior)
+    {
+        AZ::Data::AssetId generatedAssetId;
+        AZ::Data::AssetCatalogRequestBus::BroadcastResult
+        (
+            generatedAssetId, 
+            &AZ::Data::AssetCatalogRequests::GenerateAssetIdTEMP, 
+            assetPath.c_str()
+        );
+        return AZ::Data::AssetManager::Instance().FindOrCreateAsset(generatedAssetId, azrtti_typeid<T>(), loadBehavior);
+    }
+
+    bool MultilayerPerceptronEditorComponent::SaveAsAsset()
+    {
+        if (m_asset.Get() != nullptr)
+        {
+            m_asset->m_name = m_model.m_name;
+            m_asset->m_activationCount = m_model.m_activationCount;
+            m_asset->m_layers = m_model.m_layers;
+            return m_asset.Save();
+        }
+
+        const AZStd::string initialAbsolutePathToSave = PathAtProjectRoot(m_model.GetName().c_str(), ModelAsset::Extension);
+        const QString fileFilter = AZStd::string::format("Model (*.%s)", ModelAsset::Extension).c_str();
+        const QString absolutePathQt = AzQtComponents::FileDialog::GetSaveFileName(nullptr, "Save As Asset...", QString(initialAbsolutePathToSave.c_str()), fileFilter);
+        const AZStd::string absolutePath = AZStd::string(absolutePathQt.toUtf8());
+
+        // User cancelled
+        if (absolutePathQt.isEmpty())
+        {
+            return false;
+        }
+
+        // Copy m_model to m_asset so we can save latest data
+        m_asset = CreateOrFindAsset<ModelAsset>(absolutePath, m_asset.GetAutoLoadBehavior());
+        m_asset->m_name = m_model.m_name;
+        m_asset->m_activationCount = m_model.m_activationCount;
+        m_asset->m_layers = m_model.m_layers;
+
+        AZ::Data::AssetBus::Handler::BusDisconnect();
+        AZ::Data::AssetBus::Handler::BusConnect(m_asset.GetId());
+
+        bool result = false;
+        const auto assetType = AZ::AzTypeInfo<ModelAsset>::Uuid();
+        if (auto assetHandler = AZ::Data::AssetManager::Instance().GetHandler(assetType))
+        {
+            if (AZ::IO::FileIOStream fileStream(absolutePath.c_str(), AZ::IO::OpenMode::ModeWrite); fileStream.IsOpen())
+            {
+                result = assetHandler->SaveAssetData(m_asset, &fileStream);
+                AZLOG_INFO("Save %s. Location: %s", result ? "succeeded" : "failed", absolutePath.c_str());
+            }
+        }
+
+        AzToolsFramework::ToolsApplicationNotificationBus::Broadcast
+        (
+            &AzToolsFramework::ToolsApplicationNotificationBus::Events::InvalidatePropertyDisplay,
+            AzToolsFramework::Refresh_EntireTree
+        );
+
+        return result;
+    }
+}

+ 77 - 0
Gems/MachineLearning/Code/Source/Tools/MultilayerPerceptronEditorComponent.h

@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#pragma once
+
+#include <AzToolsFramework/ToolsComponents/EditorComponentBase.h>
+#include <AzCore/Asset/AssetCommon.h>
+#include <Models/MultilayerPerceptron.h>
+#include <Assets/ModelAsset.h>
+#include <MachineLearning/Types.h>
+
+namespace MachineLearning
+{
+    class MultilayerPerceptronEditorComponent
+        : public AzToolsFramework::Components::EditorComponentBase
+        , private AZ::Data::AssetBus::Handler
+        , public IAssetPersistenceProxy
+    {
+    public:
+
+        AZ_COMPONENT(MultilayerPerceptronEditorComponent, "{E33802A1-18E8-4CBE-A45B-7D7C979B1027}");
+
+        //! AzCore Reflection.
+        //! @param context reflection context
+        static void Reflect(AZ::ReflectContext* context);
+        static void GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided);
+
+        MultilayerPerceptronEditorComponent();
+        ~MultilayerPerceptronEditorComponent();
+
+        //! AZ::Component overrides
+        //! @{
+        void Activate() override;
+        void Deactivate() override;
+        //! @}
+
+        //! EditorComponentBase
+        //! @{
+        void BuildGameEntity(AZ::Entity* gameEntity) override;
+        //! @}
+
+        //! IAssetPersistenceProxy overrides
+        //! @{
+        bool SaveAsset() override;
+        bool LoadAsset() override;
+        //! @}
+
+    private:
+
+        //! Edit context callbacks
+        //! @{
+        void AssetChanged();
+        void AssetCleared();
+        //! @}
+
+        //! AZ::Data::AssetBus overrides
+        //! @{
+        void OnAssetReady(AZ::Data::Asset<AZ::Data::AssetData> asset) override;
+        void OnAssetReloaded(AZ::Data::Asset<AZ::Data::AssetData> asset) override;
+        void OnAssetError(AZ::Data::Asset<AZ::Data::AssetData> asset) override;
+        void OnAssetReloadError(AZ::Data::Asset<AZ::Data::AssetData> asset) override;
+        //! @}
+
+        bool SaveAsAsset();
+
+        //! The model asset.
+        AZ::Data::Asset<ModelAsset> m_asset;
+
+        MultilayerPerceptron m_model;
+        INeuralNetworkPtr m_handle;
+    };
+}

+ 2 - 0
Gems/MachineLearning/Code/machinelearning_editor_private_files.cmake

@@ -7,6 +7,8 @@
 #
 
 set(FILES
+    Source/Tools/MultilayerPerceptronEditorComponent.cpp
+    Source/Tools/MultilayerPerceptronEditorComponent.h
     Source/Tools/MachineLearningEditorSystemComponent.cpp
     Source/Tools/MachineLearningEditorSystemComponent.h
 )

+ 2 - 2
Gems/MachineLearning/Registry/assetprocessor_settings.setreg

@@ -12,8 +12,8 @@
                     "recursive": 1,
                     "order": 102
                 },
-                "RC mlasset": {
-                    "glob": "*.mlasset",
+                "RC MachineLearning Model": {
+                    "glob": "*.mlmodel",
                     "params": "copy",
                     "productAssetType": "{4D8D3782-DC3A-499A-A59D-542B85F5EDE9}"
                 }