MultilayerPerceptronEditorComponent.cpp 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #pragma once
  9. #include <Tools/MultilayerPerceptronEditorComponent.h>
  10. #include <Components/MultilayerPerceptronComponent.h>
  11. #include <MachineLearning/IMachineLearning.h>
  12. #include <AzCore/RTTI/RTTI.h>
  13. #include <AzCore/RTTI/BehaviorContext.h>
  14. #include <AzCore/Serialization/EditContext.h>
  15. #include <AzCore/Serialization/SerializeContext.h>
  16. #include <AzCore/Settings/SettingsRegistryMergeUtils.h>
  17. #include <AzCore/Console/ILogger.h>
  18. #include <AzToolsFramework/API/ToolsApplicationAPI.h>
  19. #include <AzToolsFramework/API/EditorAssetSystemAPI.h>
  20. #include <AzToolsFramework/UI/UICore/WidgetHelpers.h>
  21. #include <AzQtComponents/Components/Widgets/FileDialog.h>
  22. #include <QMessageBox>
  23. namespace MachineLearning
  24. {
  25. void MultilayerPerceptronEditorComponent::Reflect(AZ::ReflectContext* context)
  26. {
  27. if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  28. {
  29. serializeContext->Class<MultilayerPerceptronEditorComponent>()
  30. ->Version(0)
  31. ->Field("Asset", &MultilayerPerceptronEditorComponent::m_asset)
  32. ->Field("Model", &MultilayerPerceptronEditorComponent::m_model)
  33. ;
  34. if (AZ::EditContext* editContext = serializeContext->GetEditContext())
  35. {
  36. editContext
  37. ->Class<MultilayerPerceptronEditorComponent>("Multilayer Perceptron", "")
  38. ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
  39. ->Attribute(AZ::Edit::Attributes::Category, "MachineLearning")
  40. ->Attribute(AZ::Edit::Attributes::Icon, "Editor/Icons/Components/NeuralNetwork.svg")
  41. ->Attribute(AZ::Edit::Attributes::ViewportIcon, "Editor/Icons/Components/Viewport/NeuralNetwork.svg")
  42. ->Attribute(AZ::Edit::Attributes::AppearsInAddComponentMenu, AZ_CRC_CE("Game"))
  43. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptronEditorComponent::m_asset, "Asset", "This is the asset file the model is persisted to")
  44. ->Attribute(AZ::Edit::Attributes::ChangeNotify, &MultilayerPerceptronEditorComponent::AssetChanged)
  45. ->Attribute(AZ::Edit::Attributes::ClearNotify, &MultilayerPerceptronEditorComponent::AssetCleared)
  46. ->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptronEditorComponent::m_model, "Model", "This is the machine-learning model provided by this component");
  47. }
  48. }
  49. }
  50. void MultilayerPerceptronEditorComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
  51. {
  52. provided.push_back(AZ_CRC("MultilayerPerceptronService"));
  53. }
  54. MultilayerPerceptronEditorComponent::MultilayerPerceptronEditorComponent()
  55. {
  56. m_model.m_proxy = this;
  57. m_handle.reset(&m_model);
  58. MachineLearningInterface::Get()->RegisterModel(m_handle);
  59. }
  60. MultilayerPerceptronEditorComponent::~MultilayerPerceptronEditorComponent()
  61. {
  62. MachineLearningInterface::Get()->UnregisterModel(m_handle);
  63. }
  64. void MultilayerPerceptronEditorComponent::Activate()
  65. {
  66. AssetChanged();
  67. }
  68. void MultilayerPerceptronEditorComponent::Deactivate()
  69. {
  70. AZ::Data::AssetBus::Handler::BusDisconnect();
  71. }
  72. void MultilayerPerceptronEditorComponent::BuildGameEntity(AZ::Entity* gameEntity)
  73. {
  74. MultilayerPerceptronComponent* component = gameEntity->CreateComponent<MultilayerPerceptronComponent>();
  75. component->m_asset = m_asset;
  76. }
  77. bool MultilayerPerceptronEditorComponent::SaveAsset()
  78. {
  79. return SaveAsAsset();
  80. }
  81. bool MultilayerPerceptronEditorComponent::LoadAsset()
  82. {
  83. m_asset.QueueLoad();
  84. return true;
  85. }
  86. void MultilayerPerceptronEditorComponent::AssetChanged()
  87. {
  88. AZ::Data::AssetBus::Handler::BusDisconnect();
  89. if (m_asset.GetStatus() == AZ::Data::AssetData::AssetStatus::Error ||
  90. m_asset.GetStatus() == AZ::Data::AssetData::AssetStatus::NotLoaded)
  91. {
  92. m_asset.QueueLoad();
  93. }
  94. AZ::Data::AssetBus::Handler::BusConnect(m_asset.GetId());
  95. }
  96. void MultilayerPerceptronEditorComponent::AssetCleared()
  97. {
  98. ;
  99. }
  100. void MultilayerPerceptronEditorComponent::OnAssetReady(AZ::Data::Asset<AZ::Data::AssetData> asset)
  101. {
  102. ModelAsset* modelAsset = asset.GetAs<ModelAsset>();
  103. if ((asset == m_asset) && (modelAsset != nullptr))
  104. {
  105. m_model = *modelAsset;
  106. AzToolsFramework::ToolsApplicationNotificationBus::Broadcast
  107. (
  108. &AzToolsFramework::ToolsApplicationNotificationBus::Events::InvalidatePropertyDisplay,
  109. AzToolsFramework::Refresh_EntireTree
  110. );
  111. }
  112. }
  113. void MultilayerPerceptronEditorComponent::OnAssetReloaded(AZ::Data::Asset<AZ::Data::AssetData> asset)
  114. {
  115. OnAssetReady(asset);
  116. }
  117. void MultilayerPerceptronEditorComponent::OnAssetError(AZ::Data::Asset<AZ::Data::AssetData> asset)
  118. {
  119. if (asset == m_asset)
  120. {
  121. AZLOG_WARN("OnAssetError: %s", asset.GetHint().c_str());
  122. }
  123. }
  124. void MultilayerPerceptronEditorComponent::OnAssetReloadError(AZ::Data::Asset<AZ::Data::AssetData> asset)
  125. {
  126. if (asset == m_asset)
  127. {
  128. AZLOG_WARN("OnAssetReloadError: %s", asset.GetHint().c_str());
  129. }
  130. }
  131. static AZStd::string PathAtProjectRoot(const AZStd::string_view name, const AZStd::string_view extension)
  132. {
  133. AZ::IO::Path projectPath;
  134. if (auto settingsRegistry = AZ::SettingsRegistry::Get(); settingsRegistry != nullptr)
  135. {
  136. settingsRegistry->Get(projectPath.Native(), AZ::SettingsRegistryMergeUtils::FilePathKey_ProjectPath);
  137. }
  138. projectPath /= AZ::IO::FixedMaxPathString::format("%.*s.%.*s", AZ_STRING_ARG(name), AZ_STRING_ARG(extension));
  139. return projectPath.Native();
  140. }
  141. template <typename T>
  142. AZ::Data::Asset<T> CreateOrFindAsset(const AZStd::string& assetPath, AZ::Data::AssetLoadBehavior loadBehavior)
  143. {
  144. AZ::Data::AssetId generatedAssetId;
  145. AZ::Data::AssetCatalogRequestBus::BroadcastResult
  146. (
  147. generatedAssetId,
  148. &AZ::Data::AssetCatalogRequests::GenerateAssetIdTEMP,
  149. assetPath.c_str()
  150. );
  151. return AZ::Data::AssetManager::Instance().FindOrCreateAsset(generatedAssetId, azrtti_typeid<T>(), loadBehavior);
  152. }
  153. bool MultilayerPerceptronEditorComponent::SaveAsAsset()
  154. {
  155. if (m_asset.Get() != nullptr)
  156. {
  157. m_asset->m_name = m_model.m_name;
  158. m_asset->m_activationCount = m_model.m_activationCount;
  159. m_asset->m_layers = m_model.m_layers;
  160. return m_asset.Save();
  161. }
  162. const AZStd::string initialAbsolutePathToSave = PathAtProjectRoot(m_model.GetName().c_str(), ModelAsset::Extension);
  163. const QString fileFilter = AZStd::string::format("Model (*.%s)", ModelAsset::Extension).c_str();
  164. const QString absolutePathQt = AzQtComponents::FileDialog::GetSaveFileName(nullptr, "Save As Asset...", QString(initialAbsolutePathToSave.c_str()), fileFilter);
  165. const AZStd::string absolutePath = AZStd::string(absolutePathQt.toUtf8());
  166. // User cancelled
  167. if (absolutePathQt.isEmpty())
  168. {
  169. return false;
  170. }
  171. // Copy m_model to m_asset so we can save latest data
  172. m_asset = CreateOrFindAsset<ModelAsset>(absolutePath, m_asset.GetAutoLoadBehavior());
  173. m_asset->m_name = m_model.m_name;
  174. m_asset->m_activationCount = m_model.m_activationCount;
  175. m_asset->m_layers = m_model.m_layers;
  176. AZ::Data::AssetBus::Handler::BusDisconnect();
  177. AZ::Data::AssetBus::Handler::BusConnect(m_asset.GetId());
  178. bool result = false;
  179. const auto assetType = AZ::AzTypeInfo<ModelAsset>::Uuid();
  180. if (auto assetHandler = AZ::Data::AssetManager::Instance().GetHandler(assetType))
  181. {
  182. if (AZ::IO::FileIOStream fileStream(absolutePath.c_str(), AZ::IO::OpenMode::ModeWrite); fileStream.IsOpen())
  183. {
  184. result = assetHandler->SaveAssetData(m_asset, &fileStream);
  185. AZLOG_INFO("Save %s. Location: %s", result ? "succeeded" : "failed", absolutePath.c_str());
  186. }
  187. }
  188. AzToolsFramework::ToolsApplicationNotificationBus::Broadcast
  189. (
  190. &AzToolsFramework::ToolsApplicationNotificationBus::Events::InvalidatePropertyDisplay,
  191. AzToolsFramework::Refresh_EntireTree
  192. );
  193. return result;
  194. }
  195. }