Mnist.cpp 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  #include "Mnist.h"

  #include <cmath>
  9. namespace Mnist
  10. {
  11. template<typename T>
  12. static void Mnist::Softmax(T& input)
  13. {
  14. const float rowmax = *AZStd::ranges::max_element(input.begin(), input.end());
  15. AZStd::vector<float> y(input.size());
  16. float sum = 0.0f;
  17. for (size_t i = 0; i != input.size(); ++i)
  18. {
  19. sum += y[i] = AZStd::exp2(input[i] - rowmax);
  20. }
  21. for (size_t i = 0; i != input.size(); ++i)
  22. {
  23. input[i] = y[i] / sum;
  24. }
  25. }
  26. void Mnist::GetResult()
  27. {
  28. Softmax(m_outputs[0]);
  29. m_result = AZStd::distance(m_outputs[0].begin(), AZStd::ranges::max_element(m_outputs[0].begin(), m_outputs[0].end()));
  30. }
  31. void Mnist::LoadImage(const char* path)
  32. {
  33. // Gets the png image from file and decodes using upng library.
  34. upng_t* upng = upng_new_from_file(path);
  35. upng_decode(upng);
  36. const unsigned char* buffer = upng_get_buffer(upng);
  37. // Converts image from buffer into binary greyscale representation.
  38. // i.e. a pure black pixel is a 0, anything else is a 1.
  39. // Bear in mind that the images in the dataset are flipped compared to how we'd usually think,
  40. // so the background is black and the actual digit is white.
  41. for (int y = 0; y < m_imageHeight; y++)
  42. {
  43. for (int x = 0; x < m_imageWidth; x++)
  44. {
  45. int content = static_cast<int>(buffer[(y)*m_imageWidth + x]);
  46. if (content == 0)
  47. {
  48. m_input[0][m_imageWidth * y + x] = 0.0f;
  49. }
  50. else
  51. {
  52. m_input[0][m_imageHeight * y + x] = 1.0f;
  53. }
  54. }
  55. }
  56. }
  57. MnistReturnValues MnistExample(Mnist& mnist, const char* path)
  58. {
  59. mnist.LoadImage(path);
  60. mnist.Run(mnist.m_input);
  61. mnist.GetResult();
  62. MnistReturnValues returnValues;
  63. returnValues.m_inference = mnist.m_result;
  64. returnValues.m_runtime = mnist.m_delta;
  65. return (returnValues);
  66. }
  67. InferenceData RunMnistSuite(int testsPerDigit, bool cudaEnable)
  68. {
  69. // Initialises and loads the mnist model.
  70. // The same instance of the model is used for all runs.
  71. Mnist mnist;
  72. AZStd::vector<AZStd::vector<float>> input(0);
  73. AZStd::vector<float> image(mnist.m_imageSize);
  74. input.push_back(image);
  75. mnist.m_input = input;
  76. Mnist::InitSettings modelInitSettings;
  77. if (cudaEnable)
  78. {
  79. modelInitSettings.m_modelName = "MNIST CUDA (Precomputed)";
  80. modelInitSettings.m_modelColor = AZ::Color::CreateFromRgba(56, 229, 59, 255);
  81. modelInitSettings.m_cudaEnable = true;
  82. }
  83. else
  84. {
  85. modelInitSettings.m_modelName = "MNIST (Precomputed)";
  86. }
  87. mnist.Load(modelInitSettings);
  88. int numOfEach = testsPerDigit;
  89. int totalFiles = 0;
  90. int64_t numOfCorrectInferences = 0;
  91. float totalRuntimeInMilliseconds = 0;
  92. AZ::IO::FixedMaxPath mnistTestImageRoot;
  93. // This bit cycles through the folder with all the mnist test images, calling MnistExample() for the specified number of each digit.
  94. // The structure of the folder is as such: /testing/{digit}/{random_integer}.png e.g /testing/3/10.png
  95. auto TestImage = [&mnist, &numOfCorrectInferences, &totalFiles, &totalRuntimeInMilliseconds, &mnistTestImageRoot, numOfEach](AZ::IO::Path digitFilePath, bool isFile) -> bool
  96. {
  97. if (!isFile)
  98. {
  99. AZ::IO::FixedMaxPath directoryName = digitFilePath.Filename();
  100. char* directoryEnd;
  101. AZ::s64 digit = strtoll(directoryName.c_str(), &directoryEnd, 10);
  102. if (directoryName.c_str() != directoryEnd)
  103. {
  104. // How many files of that digit have been tested
  105. int version = 0;
  106. // Search for any png files
  107. auto RunMnistExample = [&mnist, &numOfCorrectInferences, &totalFiles, &totalRuntimeInMilliseconds, &mnistTestImageRoot, &directoryName, &version, &digit, &numOfEach](AZ::IO::Path pngFilePath, bool) -> bool
  108. {
  109. // Stop running examples once version limit for that digit has been reached
  110. if ((version < numOfEach))
  111. {
  112. MnistReturnValues returnedValues = MnistExample(mnist, (mnistTestImageRoot / directoryName / pngFilePath).c_str());
  113. if (returnedValues.m_inference == digit)
  114. {
  115. numOfCorrectInferences += 1;
  116. }
  117. totalRuntimeInMilliseconds += returnedValues.m_runtime;
  118. totalFiles++;
  119. version++;
  120. }
  121. return true;
  122. };
  123. AZ::IO::SystemFile::FindFiles((mnistTestImageRoot / directoryName / "*.png").c_str(), RunMnistExample);
  124. }
  125. }
  126. return true;
  127. };
  128. // Get the FileIOBase to resolve the path to the MNIST testing image folder in the onnx gem
  129. AZ::IO::FileIOBase* fileIo = AZ::IO::FileIOBase::GetInstance();
  130. if (fileIo->ResolvePath(mnistTestImageRoot, "@gemroot:ONNX@/Assets/mnist_png/testing"))
  131. {
  132. // mnistTestImageRoot is set to the root folder of the MNIST testing images
  133. AZ::IO::SystemFile::FindFiles((mnistTestImageRoot / "*").c_str(), TestImage);
  134. }
  135. float accuracy = ((float)numOfCorrectInferences / (float)totalFiles) * 100.0f;
  136. float avgRuntimeInMilliseconds = totalRuntimeInMilliseconds / (totalFiles);
  137. AZ_Printf("ONNX", " Run Type: %s\n", cudaEnable ? "CUDA" : "CPU");
  138. AZ_Printf("ONNX", " Evaluated: %d Correct: %d Accuracy: %f%%\n", totalFiles, numOfCorrectInferences, accuracy);
  139. AZ_Printf("ONNX", " Total Runtime: %fms Avg Runtime: %fms\n", totalRuntimeInMilliseconds, avgRuntimeInMilliseconds);
  140. InferenceData result;
  141. result.m_averageRuntimeInMs = avgRuntimeInMilliseconds;
  142. result.m_totalRuntimeInMs = totalRuntimeInMilliseconds;
  143. result.m_totalNumberOfInferences = totalFiles;
  144. result.m_numberOfCorrectInferences = numOfCorrectInferences;
  145. return result;
  146. }
  147. } // namespace Mnist