TrainingDataView.cpp 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #pragma once
  9. #include <Assets/TrainingDataView.h>
  10. #include <numeric>
  11. #include <random>
  12. namespace MachineLearning
  13. {
  14. TrainingDataView::TrainingDataView(ILabeledTrainingDataPtr sourceData)
  15. : m_sourceData(sourceData)
  16. {
  17. FillIndicies();
  18. }
  19. bool TrainingDataView::IsValid() const
  20. {
  21. return m_sourceData != nullptr;
  22. }
  23. void TrainingDataView::SetSourceData(ILabeledTrainingDataPtr sourceData)
  24. {
  25. m_sourceData = sourceData;
  26. FillIndicies();
  27. }
  28. void TrainingDataView::SetRange(AZStd::size_t first, AZStd::size_t last)
  29. {
  30. m_first = first;
  31. m_last = last;
  32. FillIndicies();
  33. }
  34. AZStd::size_t TrainingDataView::GetOriginalSize() const
  35. {
  36. if (m_sourceData)
  37. {
  38. return m_sourceData->GetSampleCount();
  39. }
  40. return 0;
  41. }
  42. void TrainingDataView::ShuffleSamples()
  43. {
  44. std::shuffle(m_indices.begin(), m_indices.end(), std::mt19937(std::random_device{}()));
  45. }
  46. bool TrainingDataView::LoadArchive(const AZ::IO::Path& imageFilename, const AZ::IO::Path& labelFilename)
  47. {
  48. AZ_Assert(m_sourceData, "No datasource assigned to view");
  49. bool result = m_sourceData->LoadArchive(imageFilename, labelFilename);
  50. FillIndicies();
  51. return result;
  52. }
  53. AZStd::size_t TrainingDataView::GetSampleCount() const
  54. {
  55. return m_last - m_first;
  56. }
  57. const AZ::VectorN& TrainingDataView::GetLabelByIndex(AZStd::size_t index)
  58. {
  59. AZ_Assert(m_sourceData, "No datasource assigned to view");
  60. AZ_Assert(index + m_first < m_last, "Out of range index requested");
  61. if (m_firstCache != m_first || m_lastCache != m_last)
  62. {
  63. FillIndicies();
  64. }
  65. return m_sourceData->GetLabelByIndex(m_indices[index]);
  66. }
  67. const AZ::VectorN& TrainingDataView::GetDataByIndex(AZStd::size_t index)
  68. {
  69. AZ_Assert(m_sourceData, "No datasource assigned to view");
  70. AZ_Assert(index + m_first < m_last, "Out of range index requested");
  71. if (m_firstCache != m_first || m_lastCache != m_last)
  72. {
  73. FillIndicies();
  74. }
  75. return m_sourceData->GetDataByIndex(m_indices[index]);
  76. }
  77. void TrainingDataView::FillIndicies()
  78. {
  79. // Generate a set of training indices that we can later optionally shuffle
  80. m_indices.resize(m_last);
  81. std::iota(m_indices.begin(), m_indices.end(), m_first);
  82. m_firstCache = m_first;
  83. m_lastCache = m_last;
  84. }
  85. }