MnistDataLoader.cpp 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #pragma once
  9. #include <Assets/MnistDataLoader.h>
  10. #include <Algorithms/Activations.h>
  11. #include <AzCore/IO/FileReader.h>
  12. #include <AzCore/IO/Path/Path.h>
  13. #include <AzCore/Console/ILogger.h>
  14. #include <AzNetworking/Utilities/Endian.h>
  15. #include <AzCore/RTTI/RTTI.h>
  16. #include <AzCore/RTTI/BehaviorContext.h>
  17. #include <AzCore/Serialization/EditContext.h>
  18. #include <AzCore/Serialization/SerializeContext.h>
  19. namespace MachineLearning
  20. {
  21. void MnistDataLoader::Reflect(AZ::ReflectContext* context)
  22. {
  23. if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  24. {
  25. serializeContext->Class<MnistDataLoader>()
  26. ->Version(1)
  27. ;
  28. if (AZ::EditContext* editContext = serializeContext->GetEditContext())
  29. {
  30. editContext->Class<MnistDataLoader>("Parameters defining a single training data instance", "")
  31. ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
  32. ;
  33. }
  34. }
  35. auto behaviorContext = azrtti_cast<AZ::BehaviorContext*>(context);
  36. if (behaviorContext)
  37. {
  38. behaviorContext->Class<MnistDataLoader>()->
  39. Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Common)->
  40. Attribute(AZ::Script::Attributes::Module, "machineLearning")->
  41. Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)->
  42. Constructor<>()->
  43. Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)
  44. ;
  45. }
  46. }
  47. bool MnistDataLoader::LoadArchive(const AZ::IO::Path& imageFilename, const AZ::IO::Path& labelFilename)
  48. {
  49. return LoadImageFile(imageFilename) && LoadLabelFile(labelFilename);
  50. }
  51. AZStd::size_t MnistDataLoader::GetSampleCount() const
  52. {
  53. return m_dataHeader.m_imageCount;
  54. }
  55. const AZ::VectorN& MnistDataLoader::GetLabelByIndex(AZStd::size_t index)
  56. {
  57. OneHotEncode(m_labelBuffer[index], 10, m_labelVector);
  58. return m_labelVector;
  59. }
  60. const AZ::VectorN& MnistDataLoader::GetDataByIndex(AZStd::size_t index)
  61. {
  62. const AZStd::size_t imageDataStride = m_dataHeader.m_height * m_dataHeader.m_width;
  63. m_imageVector.Resize(imageDataStride);
  64. for (AZStd::size_t iter = 0; iter < imageDataStride; ++iter)
  65. {
  66. m_imageVector.SetElement(iter, static_cast<float>(m_imageBuffer[index * imageDataStride + iter]) / 255.0f);
  67. }
  68. return m_imageVector;
  69. }
  70. bool MnistDataLoader::LoadImageFile(const AZ::IO::Path& imageFilename)
  71. {
  72. AZ::IO::FixedMaxPath filePathFixed = imageFilename.c_str();
  73. if (AZ::IO::FileIOBase* fileIOBase = AZ::IO::FileIOBase::GetInstance())
  74. {
  75. fileIOBase->ResolvePath(filePathFixed, imageFilename.c_str());
  76. }
  77. if (!m_imageFile.Open(filePathFixed.c_str(), AZ::IO::SystemFile::SF_OPEN_READ_ONLY))
  78. {
  79. AZLOG_ERROR("Failed to load '%s'. File could not be opened.", filePathFixed.c_str());
  80. return false;
  81. }
  82. const AZ::IO::SizeType length = m_imageFile.Length();
  83. if (length == 0)
  84. {
  85. AZLOG_ERROR("Failed to load '%s'. File is empty.", filePathFixed.c_str());
  86. return false;
  87. }
  88. m_imageFile.Seek(0, AZ::IO::SystemFile::SF_SEEK_BEGIN);
  89. AZ::IO::SizeType bytesRead = m_imageFile.Read(sizeof(MnistDataHeader), &m_dataHeader);
  90. if (bytesRead != sizeof(MnistDataHeader))
  91. {
  92. // Failed to read the whole header
  93. AZLOG_ERROR("Failed to load '%s', failed to read archive header.", filePathFixed.c_str());
  94. m_imageFile.Close();
  95. return false;
  96. }
  97. m_dataHeader.m_imageHeader = ntohl(m_dataHeader.m_imageHeader);
  98. m_dataHeader.m_imageCount = ntohl(m_dataHeader.m_imageCount);
  99. m_dataHeader.m_height = ntohl(m_dataHeader.m_height);
  100. m_dataHeader.m_width = ntohl(m_dataHeader.m_width);
  101. constexpr uint32_t MnistImageHeaderValue = 2051;
  102. if (m_dataHeader.m_imageHeader != MnistImageHeaderValue)
  103. {
  104. // Invalid format
  105. AZLOG_ERROR("Failed to load '%s', file is not an MNIST archive (expected %u, encountered %u).", filePathFixed.c_str(), MnistImageHeaderValue, m_dataHeader.m_imageHeader);
  106. m_imageFile.Close();
  107. return false;
  108. }
  109. const AZStd::size_t imageDataStride = m_dataHeader.m_height * m_dataHeader.m_width;
  110. m_imageBuffer.resize(m_dataHeader.m_imageCount * imageDataStride);
  111. m_imageFile.Read(m_dataHeader.m_imageCount * imageDataStride, m_imageBuffer.data());
  112. return true;
  113. }
  114. bool MnistDataLoader::LoadLabelFile(const AZ::IO::Path& labelFilename)
  115. {
  116. AZ::IO::FixedMaxPath filePathFixed = labelFilename.c_str();
  117. if (AZ::IO::FileIOBase* fileIOBase = AZ::IO::FileIOBase::GetInstance())
  118. {
  119. fileIOBase->ResolvePath(filePathFixed, labelFilename.c_str());
  120. }
  121. if (!m_labelFile.Open(filePathFixed.c_str(), AZ::IO::SystemFile::SF_OPEN_READ_ONLY))
  122. {
  123. AZLOG_ERROR("Failed to load '%s'. File could not be opened.", filePathFixed.c_str());
  124. return false;
  125. }
  126. const AZ::IO::SizeType length = m_labelFile.Length();
  127. if (length == 0)
  128. {
  129. AZLOG_ERROR("Failed to load '%s'. File is empty.", filePathFixed.c_str());
  130. return false;
  131. }
  132. m_labelFile.Seek(0, AZ::IO::SystemFile::SF_SEEK_BEGIN);
  133. struct MnistLabelHeader
  134. {
  135. uint32_t m_labelHeader = 0;
  136. uint32_t m_labelCount = 0;
  137. };
  138. MnistLabelHeader labelHeader;
  139. AZ::IO::SizeType bytesRead = m_labelFile.Read(sizeof(MnistLabelHeader), &labelHeader);
  140. if (bytesRead != sizeof(MnistLabelHeader))
  141. {
  142. // Failed to read the whole header
  143. AZLOG_ERROR("Failed to load '%s', failed to read label header.", filePathFixed.c_str());
  144. m_labelFile.Close();
  145. return false;
  146. }
  147. labelHeader.m_labelHeader = ntohl(labelHeader.m_labelHeader);
  148. labelHeader.m_labelCount = ntohl(labelHeader.m_labelCount);
  149. constexpr uint32_t MnistLabelHeaderValue = 2049;
  150. if (labelHeader.m_labelHeader != MnistLabelHeaderValue)
  151. {
  152. // Invalid format
  153. AZLOG_ERROR("Failed to load '%s', file is not an MNIST archive (expected %u, encountered %u).", filePathFixed.c_str(), MnistLabelHeaderValue, labelHeader.m_labelHeader);
  154. m_labelFile.Close();
  155. return false;
  156. }
  157. if (m_dataHeader.m_imageCount != labelHeader.m_labelCount)
  158. {
  159. AZLOG_ERROR("Failed to load '%s', mismatch between image count (%u) and label count (%u).", filePathFixed.c_str(), m_dataHeader.m_imageCount, labelHeader.m_labelCount);
  160. m_labelFile.Close();
  161. return false;
  162. }
  163. m_labelBuffer.resize(labelHeader.m_labelCount);
  164. m_labelFile.Read(labelHeader.m_labelCount, m_labelBuffer.data());
  165. AZLOG_INFO("Loaded MNIST archive %s containing %u samples", filePathFixed.c_str(), m_dataHeader.m_imageCount);
  166. return true;
  167. }
  168. }