Mnist.h

/*
 * Copyright (c) Contributors to the Open 3D Engine Project.
 * For complete copyright and license terms please see the LICENSE at the root of this distribution.
 *
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 */

#pragma once

#include <AzCore/IO/SystemFile.h>
#include <ONNX/Model.h>

#include "upng/upng.h"

namespace Mnist
{
    //! Holds the digit that was inferred and the time taken for a single inference run.
    //! Only used by MnistExample().
    struct MnistReturnValues
    {
        int64_t m_inference = 0;
        float m_runtime = 0.0f;
    };
    //! Holds the data gathered from RunMnistSuite(), which tests the MNIST ONNX model against images from the MNIST dataset.
    struct InferenceData
    {
        float m_totalRuntimeInMs = 0.0f;
        float m_averageRuntimeInMs = 0.0f;
        int m_totalNumberOfInferences = 0;
        int64_t m_numberOfCorrectInferences = 0;
    };
    //! Extension of the ONNX Model used for the Mnist example.
    //! Implements additional functionality useful for the example, such as retaining the input and output vectors and the
    //! result (which the base model does not).
    struct Mnist
        : public ::ONNX::Model
    {
    public:
        //! Loads an image from file into the correct format in m_input.
        //! @path is the file location of the image to run inference on (this must be an 8-bit color depth png, otherwise it won't load).
        void LoadImage(const char* path);

        //! To be called after Model::Run(); uses softmax to turn the raw outputs into inference probabilities (see the note after this struct).
        //! Directly mutates m_output and m_result.
        void GetResult();

        //! The MNIST dataset images are all 28 x 28 px, so the images loaded into the example should also be 28 x 28.
        int m_imageWidth = 28;
        int m_imageHeight = 28;
        int m_imageSize = 784;

        AZStd::vector<AZStd::vector<float>> m_input; //!< The input that gets passed into Run(): a binary representation of the pixels in the image.
        int64_t m_result{ 0 }; //!< The digit with the highest probability from the inference (what the model thinks the input number was).

    private:
        // Converts the vector of output values into a vector of probabilities.
        template<typename T>
        static void Softmax(T& input);
    };
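
    // Illustrative worked example of the softmax step performed by GetResult(): softmax maps each raw output score
    // x_i to exp(x_i) / sum_j exp(x_j), so the ten per-digit scores become probabilities that sum to 1. Raw scores
    // of { 1.0, 3.0, 0.5 }, for instance, become roughly { 0.11, 0.82, 0.07 }; m_result is then the index of the
    // largest probability, i.e. the digit the model believes the image shows.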
    //! Runs a single inference on the passed in MNIST instance.
    //! @mnist should be in a ready to run state, i.e. Load() should have been called.
    //! @path is the file location of the image to run inference on (this must be an 8-bit color depth png, otherwise it won't load).
    //! Returns the inferred digit and the runtime.
    MnistReturnValues MnistExample(Mnist& mnist, const char* path);

    //! Runs through a library of test MNIST images in png format, calculating inference accuracy.
    //! @testsPerDigit specifies how many runs to do on each digit 0-9. Each run is done on a unique image of that digit. The limit is
    //! ~9,000.
    //! @cudaEnable specifies whether the inferences should run on the GPU using CUDA or on the default CPU.
    InferenceData RunMnistSuite(int testsPerDigit, bool cudaEnable);
} // namespace Mnist
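
A minimal usage sketch of the declarations above. The image path shown is hypothetical, and the exact parameters of ONNX::Model::Load() are defined by the ONNX Gem, so that call is only indicated in a comment:

#include "Mnist.h"

#include <AzCore/Debug/Trace.h>

void RunMnistExamples()
{
    // Single inference on one 28 x 28, 8-bit color depth PNG of a handwritten digit.
    Mnist::Mnist mnist;
    // Load() (inherited from ONNX::Model) must be called here so the model is in a ready-to-run state;
    // its parameters are defined by the ONNX Gem and are omitted from this sketch.
    Mnist::MnistReturnValues single = Mnist::MnistExample(mnist, "Assets/mnist/seven.png"); // hypothetical image path
    AZ_Printf("Mnist", "Inferred digit %lld in %f ms", static_cast<long long>(single.m_inference), single.m_runtime);

    // Accuracy sweep over the MNIST test images: 100 unique images per digit (0-9), run on the CPU
    // (pass true as the second argument to run on the GPU via CUDA instead).
    Mnist::InferenceData suite = Mnist::RunMnistSuite(100, false);
    AZ_Printf(
        "Mnist",
        "%lld / %d inferences correct, average %f ms per inference",
        static_cast<long long>(suite.m_numberOfCorrectInferences),
        suite.m_totalNumberOfInferences,
        suite.m_averageRuntimeInMs);
}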