MnistDataLoader.h 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #pragma once
  9. #include <MachineLearning/INeuralNetwork.h>
  10. #include <MachineLearning/ILabeledTrainingData.h>
  11. #include <AzCore/std/string/string.h>
  12. #include <AzCore/IO/FileIO.h>
  13. namespace MachineLearning
  14. {
  15. //! A class that can load the MNIST training data set.
  16. //! https://en.wikipedia.org/wiki/MNIST_database
  17. class MnistDataLoader
  18. : public ILabeledTrainingData
  19. {
  20. public:
  21. AZ_TYPE_INFO(MnistDataLoader, "{3F4C0F29-4E7E-4CAF-A331-EAC3D2D9409E}", ILabeledTrainingData);
  22. //! AzCore Reflection.
  23. //! @param context reflection context
  24. static void Reflect(AZ::ReflectContext* context);
  25. MnistDataLoader() = default;
  26. //! ILabeledTrainingData interface
  27. //! @{
  28. bool LoadArchive(const AZ::IO::Path& imageFilename, const AZ::IO::Path& labelFilename) override;
  29. AZStd::size_t GetSampleCount() const override;
  30. const AZ::VectorN& GetLabelByIndex(AZStd::size_t index) override;
  31. const AZ::VectorN& GetDataByIndex(AZStd::size_t index) override;
  32. //! @}
  33. private:
  34. bool LoadImageFile(const AZ::IO::Path& imageFilename);
  35. bool LoadLabelFile(const AZ::IO::Path& labelFilename);
  36. struct MnistDataHeader
  37. {
  38. uint32_t m_imageHeader = 0;
  39. uint32_t m_imageCount = 0;
  40. uint32_t m_height = 0;
  41. uint32_t m_width = 0;
  42. };
  43. MnistDataHeader m_dataHeader;
  44. AZ::IO::SystemFile m_imageFile;
  45. AZ::IO::SystemFile m_labelFile;
  46. AZStd::size_t m_imageDataStart = 0;
  47. AZStd::size_t m_labelDataStart = 0;
  48. AZStd::size_t m_currentIndex = 0xFFFFFFFF;
  49. AZStd::vector<uint8_t> m_imageBuffer;
  50. AZStd::vector<uint8_t> m_labelBuffer;
  51. AZ::VectorN m_imageVector;
  52. AZ::VectorN m_labelVector;
  53. };
  54. }