ModelAsset.cpp 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Assets/ModelAsset.h>
  9. #include <AzNetworking/Serialization/NetworkInputSerializer.h>
  10. #include <AzNetworking/Serialization/NetworkOutputSerializer.h>
  11. namespace MachineLearning
  12. {
  13. void ModelAsset::Reflect(AZ::ReflectContext* context)
  14. {
  15. if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  16. {
  17. serializeContext->Class<ModelAsset>()
  18. ->Version(1);
  19. if (AZ::EditContext* editContext = serializeContext->GetEditContext())
  20. {
  21. editContext->Class<ModelAsset>("ML Model Asset", "ML Model Asset")
  22. ->ClassElement(AZ::Edit::ClassElements::EditorData, "");
  23. }
  24. }
  25. }
  26. bool ModelAsset::Serialize(AzNetworking::ISerializer& serializer)
  27. {
  28. return serializer.Serialize(m_name, "Name")
  29. && serializer.Serialize(m_activationCount, "activationCount")
  30. && serializer.Serialize(m_layers, "layers");
  31. }
  32. AZStd::size_t ModelAsset::EstimateSerializeSize() const
  33. {
  34. const AZStd::size_t padding = 64; // 64 bytes of extra padding just in case
  35. AZStd::size_t estimatedSize = padding
  36. + sizeof(AZStd::size_t)
  37. + m_name.size()
  38. + sizeof(m_activationCount)
  39. + sizeof(AZStd::size_t);
  40. for (const Layer& layer : m_layers)
  41. {
  42. estimatedSize += layer.EstimateSerializeSize();
  43. }
  44. return estimatedSize;
  45. }
  46. ModelAssetHandler::ModelAssetHandler()
  47. : AzFramework::GenericAssetHandler<ModelAsset>(ModelAsset::DisplayName, ModelAsset::Group, ModelAsset::Extension)
  48. {
  49. }
  50. AZ::Data::AssetHandler::LoadResult ModelAssetHandler::LoadAssetData
  51. (
  52. const AZ::Data::Asset<AZ::Data::AssetData>& asset,
  53. AZStd::shared_ptr<AZ::Data::AssetDataStream> stream,
  54. [[maybe_unused]]const AZ::Data::AssetFilterCB& assetLoadFilterCB
  55. )
  56. {
  57. ModelAsset* assetData = asset.GetAs<ModelAsset>();
  58. AZ_Assert(assetData, "Asset is of the wrong type.");
  59. const AZ::IO::SizeType length = stream->GetLength();
  60. AZStd::vector<uint8_t> serializeBuffer;
  61. serializeBuffer.resize(length);
  62. stream->Read(length, serializeBuffer.data());
  63. AzNetworking::NetworkOutputSerializer serializer(serializeBuffer.data(), static_cast<uint32_t>(serializeBuffer.size()));
  64. if (assetData->Serialize(serializer))
  65. {
  66. return AZ::Data::AssetHandler::LoadResult::LoadComplete;
  67. }
  68. return AZ::Data::AssetHandler::LoadResult::Error;
  69. }
  70. bool ModelAssetHandler::SaveAssetData(const AZ::Data::Asset<AZ::Data::AssetData>& asset, AZ::IO::GenericStream* stream)
  71. {
  72. ModelAsset* assetData = asset.GetAs<ModelAsset>();
  73. AZ_Assert(assetData, "Asset is of the wrong type.");
  74. AZStd::vector<uint8_t> serializeBuffer;
  75. serializeBuffer.resize(assetData->EstimateSerializeSize());
  76. AzNetworking::NetworkInputSerializer serializer(serializeBuffer.data(), static_cast<uint32_t>(serializeBuffer.size()));
  77. if (assetData->Serialize(serializer))
  78. {
  79. stream->Write(serializer.GetSize(), serializeBuffer.data());
  80. return true;
  81. }
  82. return false;
  83. }
  84. }