iVS3D v2.0.0
A class that implements the NeuralNet interface using ONNX Runtime.
#include <OrtNeuralNet.h>
Public Member Functions

| Return type | Member function | Description |
|---|---|---|
| | OrtNeuralNet (const std::string &modelPath, bool useCuda=false, int gpuId=0) | Construct a new OrtNeuralNet object. |
| tl::expected< Tensor, NeuralError > | infer (const Tensor &input) override | Perform inference on the given input Tensor using the ONNX model. |
| Shape | inputShape () const override | Get the input shape of the neural network. |
| Shape | outputShape () const override | Get the output shape of the neural network. |
| int | gpuId () const override | Get the GPU ID used by the neural network if it is configured to use a GPU. |

Public Member Functions inherited from NN::NeuralNet

| Return type | Member function | Description |
|---|---|---|
| tl::expected< Tensor, NeuralError > | operator() (const Tensor &input) | Call the infer method with the given input tensor. |
This class loads an ONNX model, runs inference on it, and converts between Tensor and ONNX Runtime's Ort::Value. It runs on the CPU by default and can run on a GPU through CUDA when constructed with useCuda = true, provided the environment (ONNX Runtime build and CUDA drivers) supports CUDA execution.
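To show how the documented interface fits together, here is a minimal, self-contained sketch of a NeuralNet-style base class in which operator() forwards to the pure virtual infer(), plus a toy subclass. It is not the iVS3D source: Tensor, Shape and NeuralError are simplified hypothetical stand-ins, std::variant replaces tl::expected to avoid the extra dependency, DoublingNet is a made-up example, and the assumption that the real operator() is non-virtual is not confirmed by this page.

```cpp
#include <cassert>
#include <cstdint>
#include <variant>
#include <vector>

// Hypothetical stand-ins for the library's Shape, Tensor and NeuralError types.
using Shape = std::vector<std::int64_t>;

struct Tensor {
    Shape shape;
    std::vector<float> data;
};

enum class NeuralError { ShapeMismatch, InferenceFailed };

// std::variant stands in for tl::expected<Tensor, NeuralError> in this sketch.
using InferResult = std::variant<Tensor, NeuralError>;

// Sketch of an interface shaped like NN::NeuralNet as listed on this page.
class NeuralNet {
public:
    virtual ~NeuralNet() = default;

    virtual InferResult infer(const Tensor &input) = 0;
    virtual Shape inputShape() const = 0;
    virtual Shape outputShape() const = 0;
    virtual int gpuId() const = 0;

    // Assumed non-virtual: simply forwards to the subclass's infer().
    InferResult operator()(const Tensor &input) { return infer(input); }
};

// Hypothetical toy implementation: doubles every element of a 1x4 tensor.
class DoublingNet : public NeuralNet {
public:
    InferResult infer(const Tensor &input) override {
        if (input.shape != inputShape()) {
            return NeuralError::ShapeMismatch;  // reject inputs of the wrong shape
        }
        Tensor out{outputShape(), input.data};
        for (float &v : out.data) {
            v *= 2.0f;
        }
        return out;
    }

    Shape inputShape() const override { return {1, 4}; }
    Shape outputShape() const override { return {1, 4}; }
    int gpuId() const override { return 0; }  // toy value; GPU selection is not modeled
};
```

In this sketch, calling net(input) on any subclass dispatches to that subclass's infer(), so an implementation only needs to override infer(), inputShape(), outputShape() and gpuId(), which matches the set of members this page marks as "Implements NN::NeuralNet".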
Usage:
NN::OrtNeuralNet::OrtNeuralNet (const std::string &modelPath, bool useCuda = false, int gpuId = 0)
Construct a new OrtNeuralNet object.
Parameters

| Parameter | Description |
|---|---|
| modelPath | The path to the ONNX model file. |
| useCuda | Whether to use CUDA for GPU execution. Default is false (CPU). |
| gpuId | The ID of the GPU to use if CUDA is enabled. Default is 0. |
int NN::OrtNeuralNet::gpuId () const [override, virtual]
Get the GPU ID used by the neural network if it is configured to use a GPU.
Implements NN::NeuralNet.
tl::expected< Tensor, NeuralError > NN::OrtNeuralNet::infer (const Tensor &input) [override, virtual]
Perform inference on the given input Tensor using the ONNX model.
Parameters

| Parameter | Description |
|---|---|
| input | The input Tensor to the neural network. It must have the shape and data type the model expects. |
Implements NN::NeuralNet.
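The requirement that the input match the model's expected shape can be checked before calling infer(). The sketch below compares a candidate shape against an expected one and treats -1 as a dynamic dimension. That is the value ONNX Runtime reports for symbolic dimensions, but whether OrtNeuralNet::inputShape() passes such values through unchanged is an assumption. shapeMatches is a hypothetical helper, Shape is modeled as std::vector<std::int64_t>, and the data type is not checked.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<std::int64_t>;  // hypothetical stand-in for Shape

// Hypothetical helper: returns true if a tensor of shape `actual` fits a model
// input declared as `expected`. A -1 in `expected` is treated as a dynamic
// dimension that accepts any non-negative extent.
bool shapeMatches(const Shape &expected, const Shape &actual) {
    if (expected.size() != actual.size()) {
        return false;  // rank mismatch
    }
    for (std::size_t i = 0; i < expected.size(); ++i) {
        if (actual[i] < 0) {
            return false;  // a concrete tensor cannot have a negative extent
        }
        if (expected[i] != -1 && expected[i] != actual[i]) {
            return false;  // fixed dimension differs
        }
    }
    return true;
}
```

A caller would typically compare net.inputShape() with the tensor's shape this way and skip infer() on a mismatch, instead of relying on the error path.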
Shape NN::OrtNeuralNet::inputShape () const [override, virtual]
Get the input shape of the neural network.
Implements NN::NeuralNet.
Shape NN::OrtNeuralNet::outputShape () const [override, virtual]
Get the output shape of the neural network.
Implements NN::NeuralNet.
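A shape such as the one returned by outputShape() can be turned into an element count, for example to size a buffer. The hypothetical helper below multiplies the dimensions and returns std::nullopt when any dimension is negative, again assuming -1 marks a dynamic dimension whose size is only known once inference produces a concrete tensor. Shape is modeled as std::vector<std::int64_t>, and overflow is not handled.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<std::int64_t>;  // hypothetical stand-in for Shape

// Hypothetical helper: number of elements described by `shape`.
// Returns std::nullopt if any dimension is dynamic (-1) or otherwise negative.
// An empty shape describes a scalar and yields 1.
std::optional<std::int64_t> elementCount(const Shape &shape) {
    std::int64_t count = 1;
    for (std::int64_t dim : shape) {
        if (dim < 0) {
            return std::nullopt;  // size unknown until a concrete tensor exists
        }
        count *= dim;
    }
    return count;
}
```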