/*
 * Copyright (c) Contributors to the Open 3D Engine Project.
 * For complete copyright and license terms please see the LICENSE at the root of this distribution.
 *
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 */

#include <ONNX/Model.h>

namespace ONNX
{
  11. void Model::Load(const InitSettings& initSettings)
  12. {
  13. // Get the FileIOBase to resolve the path to the ONNX gem
  14. AZ::IO::FixedMaxPath onnxModelPath;
  15. // If no filepath provided for onnx model, set default to a model.onnx file in the Assets folder.
  16. if (initSettings.m_modelFile.empty())
  17. {
  18. AZ::IO::FileIOBase* fileIo = AZ::IO::FileIOBase::GetInstance();
  19. fileIo->ResolvePath(onnxModelPath, "@gemroot:ONNX@/Assets/model.onnx");
  20. }
  21. else
  22. {
  23. onnxModelPath = initSettings.m_modelFile;
  24. }
  25. // If no model name is provided, will default to the name of the onnx model file.
  26. if (initSettings.m_modelName.empty())
  27. {
  28. m_modelName = onnxModelPath.Filename().Stem().FixedMaxPathString();
  29. }
  30. else
  31. {
  32. m_modelName = initSettings.m_modelName;
  33. }
  34. m_modelColor = initSettings.m_modelColor;
  35. // Grabs environment created on init of system component.
  36. Ort::Env* env = nullptr;
  37. ONNXRequestBus::BroadcastResult(env, &ONNXRequestBus::Events::GetEnv);
  38. #ifdef ENABLE_CUDA
  39. // OrtCudaProviderOptions must be added to the session options to specify execution on CUDA.
  40. // Can specify a number of parameters about the CUDA execution here - currently all left at default.
  41. Ort::SessionOptions sessionOptions;
  42. if (initSettings.m_cudaEnable)
  43. {
  44. OrtCUDAProviderOptions cuda_options;
  45. sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
  46. }
  47. m_cudaEnable = initSettings.m_cudaEnable;
  48. #endif
  49. // The model_path provided to Ort::Session needs to be const wchar_t*, even though the docs state const char* - doesn't work otherwise.
  50. AZStd::string onnxModelPathString = onnxModelPath.String();
  51. m_session = Ort::Session(*env, AZStd::wstring(onnxModelPathString.cbegin(), onnxModelPathString.cend()).c_str(), sessionOptions);
  52. m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
  53. // Grabs memory allocator created on init of system component.
  54. Ort::AllocatorWithDefaultOptions* m_allocator;
  55. ONNXRequestBus::BroadcastResult(m_allocator, &ONNXRequestBus::Events::GetAllocator);
  56. // Extract input names from model file and put into const char* vectors.
  57. // Extract input shapes from model file and put into AZStd::vector<int64_t>.
  58. m_inputCount = m_session.GetInputCount();
  59. for (size_t i = 0; i < m_inputCount; i++)
  60. {
  61. const char* inName = m_session.GetInputName(i, *m_allocator);
  62. m_inputNames.push_back(inName);
  63. std::vector<int64_t> inputShape = m_session.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
  64. AZStd::vector<int64_t> azInputShape(inputShape.begin(), inputShape.end());
  65. for (int index = 0; index < azInputShape.size(); index++)
  66. {
  67. if (azInputShape[index] == -1)
  68. {
  69. azInputShape[index] = 1;
  70. }
  71. }
  72. m_inputShapes.push_back(azInputShape);
  73. }
  74. // Extract output names from model file and put into const char* vectors.
  75. // Extract output shapes from model file and put into AZStd::vector<int64_t>.
  76. // Initialize m_outputs vector using output shape and count.
  77. m_outputCount = m_session.GetOutputCount();
  78. AZStd::vector<AZStd::vector<float>> outputs(m_outputCount);
  79. for (size_t i = 0; i < m_outputCount; i++)
  80. {
  81. const char* outName = m_session.GetOutputName(i, *m_allocator);
  82. m_outputNames.push_back(outName);
  83. std::vector<int64_t> outputShape = m_session.GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
  84. AZStd::vector<int64_t> azOutputShape(outputShape.begin(), outputShape.end());
  85. for (int index = 0; index < azOutputShape.size(); index++)
  86. {
  87. if (azOutputShape[index] == -1)
  88. {
  89. azOutputShape[index] = 1;
  90. }
  91. }
  92. m_outputShapes.push_back(azOutputShape);
  93. int64_t outputSize = 1;
  94. for (int j = 0; j < m_outputShapes[i].size(); j++)
  95. {
  96. // The size of each output is simply all the magnitudes of the shape dimensions multiplied together.
  97. if (m_outputShapes[i][j] > 0)
  98. {
  99. outputSize *= m_outputShapes[i][j];
  100. }
  101. }
  102. AZStd::vector<float> output(outputSize);
  103. outputs[i] = output;
  104. }
  105. m_outputs = outputs;
  106. }
  107. void Model::Run(AZStd::vector<AZStd::vector<float>>& inputs)
  108. {
  109. m_timer.Stamp(); // Start timing of inference.
  110. // Tensor creation is lightweight, and a tensor is just a wrapper around the memory owned by the vector passed in as data during creation.
  111. // As such, creating input and output tensors in each run call does not adversely affect performance.
  112. AZStd::vector<Ort::Value> inputTensors;
  113. for (int i = 0; i < m_inputCount; i++)
  114. {
  115. Ort::Value inputTensor =
  116. Ort::Value::CreateTensor<float>(m_memoryInfo, inputs[i].data(), inputs[i].size(), m_inputShapes[i].data(), m_inputShapes[i].size());
  117. inputTensors.push_back(AZStd::move(inputTensor));
  118. }
  119. AZStd::vector<Ort::Value> outputTensors;
  120. for (int i = 0; i < m_outputCount; i++)
  121. {
  122. Ort::Value outputTensor =
  123. Ort::Value::CreateTensor<float>(m_memoryInfo, m_outputs[i].data(), m_outputs[i].size(), m_outputShapes[i].data(), m_outputShapes[i].size());
  124. outputTensors.push_back(AZStd::move(outputTensor));
  125. }
  126. Ort::RunOptions runOptions;
  127. runOptions.SetRunLogVerbosityLevel(ORT_LOGGING_LEVEL_VERBOSE); // Gives more useful logging info if m_session.Run() fails.
  128. m_session.Run(runOptions, m_inputNames.data(), inputTensors.data(), m_inputCount, m_outputNames.data(), outputTensors.data(), m_outputCount);
  129. float delta = 1000.f * m_timer.GetDeltaTimeInSeconds(); // Finish timing of inference and get time in milliseconds.
  130. m_delta = delta;
  131. ONNXRequestBus::Broadcast(&::ONNX::ONNXRequestBus::Events::AddTimingSample, m_modelName.c_str(), m_delta, m_modelColor);
  132. }
} // namespace ONNX