TrainingDataView.h 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #pragma once
  9. #include <MachineLearning/INeuralNetwork.h>
  10. #include <MachineLearning/ILabeledTrainingData.h>
  11. #include <AzCore/std/string/string.h>
  12. #include <AzCore/IO/FileIO.h>
  13. namespace MachineLearning
  14. {
  15. //! This wraps any training data set to restrict the range of samples to a subset of the total.
  16. class TrainingDataView
  17. : public ILabeledTrainingData
  18. {
  19. public:
  20. AZ_TYPE_INFO(TrainingDataView, "{BF396C77-4348-46BA-9606-275A3454738E}", ILabeledTrainingData);
  21. //! AzCore Reflection.
  22. //! @param context reflection context
  23. static void Reflect(AZ::ReflectContext* context);
  24. TrainingDataView() = default;
  25. TrainingDataView(ILabeledTrainingDataPtr sourceData);
  26. bool IsValid() const;
  27. void SetSourceData(ILabeledTrainingDataPtr sourceData);
  28. void SetRange(AZStd::size_t first, AZStd::size_t last);
  29. AZStd::size_t GetOriginalSize() const;
  30. void ShuffleSamples();
  31. //! ILabeledTrainingData interface
  32. //! @{
  33. bool LoadArchive(const AZ::IO::Path& imageFilename, const AZ::IO::Path& labelFilename) override;
  34. AZStd::size_t GetSampleCount() const override;
  35. const AZ::VectorN& GetLabelByIndex(AZStd::size_t index) override;
  36. const AZ::VectorN& GetDataByIndex(AZStd::size_t index) override;
  37. //! @}
  38. AZStd::size_t m_first = 0;
  39. AZStd::size_t m_last = 0;
  40. private:
  41. void FillIndicies();
  42. AZStd::size_t m_firstCache = 0;
  43. AZStd::size_t m_lastCache = 0;
  44. AZStd::vector<AZStd::size_t> m_indices;
  45. ILabeledTrainingDataPtr m_sourceData;
  46. };
  47. }