19#include <onnxruntime_cxx_api.h>
/// Perform inference on the given input Tensor using the ONNX model.
///
/// Overrides the corresponding NeuralNet interface method.
/// @param input  Tensor fed to the model. Required shape/element type are not
///               visible here — presumably it must match inputShape() (TODO confirm).
/// @return The resulting output Tensor, or a NeuralError describing the failure.
76 tl::expected<Tensor, NeuralError>
infer(
const Tensor& input)
override;
/// Get the GPU ID used by the neural network if it is configured to use a GPU.
///
/// Overrides the corresponding NeuralNet interface method.
/// NOTE(review): the value returned when no GPU is configured is not visible
/// here — confirm in the implementation before relying on a sentinel.
94 int gpuId()
const override;
/// ONNX Runtime session options; presumably applied when m_session is created (TODO confirm).
98 Ort::SessionOptions m_sessionOptions;
/// ONNX Runtime session. Brace-initialized from nullptr, so it starts as an empty
/// handle until a real session is assigned (presumably in the constructor — confirm).
99 Ort::Session m_session{
nullptr };
/// ONNX Runtime allocator with default options; its exact use is not visible here (TODO confirm).
100 Ort::AllocatorWithDefaultOptions m_allocator;
/// Model input / output shapes, one int64_t per dimension. Presumably what
/// inputShape() / outputShape() report; a -1 entry may mark a dynamic dimension (TODO confirm).
102 std::vector<int64_t> m_inputShape;
103 std::vector<int64_t> m_outputShape;
/// Model input / output node names; presumably passed to the session when running (TODO confirm).
104 std::string m_inputName;
105 std::string m_outputName;
/// Convert a Tensor into an ONNX Runtime Ort::Value.
///
/// @param tensor         Source Tensor to convert.
/// @param shapeOverride  Optional shape (defaults to std::nullopt). When set, presumably
///                       used instead of the tensor's own shape (TODO confirm).
/// @return The converted Ort::Value, or a NeuralError describing the failure.
/// NOTE(review): whether the Ort::Value copies or aliases the Tensor's data is not
/// visible here — confirm the lifetime requirements in the implementation.
108 tl::expected<Ort::Value, NeuralError> tensorToOrtValue(
const Tensor& tensor, std::optional<std::vector<int64_t>> shapeOverride = std::nullopt)
const;
/// Convert an ONNX Runtime Ort::Value back into a Tensor.
///
/// @param value  Ort::Value to convert (presumably an inference output — confirm).
/// @return The converted Tensor, or a NeuralError describing the failure.
109 tl::expected<Tensor, NeuralError> ortValueToTensor(
const Ort::Value& value)
const;
Contains the NeuralNet interface for neural network inference.
Abstract base class for neural networks.
Definition NeuralNet.h:31
A class that implements the NeuralNet interface using ONNX Runtime.
Definition OrtNeuralNet.h:58
int gpuId() const override
Get the GPU ID used by the neural network if it is configured to use a GPU.
Definition OrtNeuralNet.cpp:113
tl::expected< Tensor, NeuralError > infer(const Tensor &input) override
Perform inference on the given input Tensor using the ONNX model.
Definition OrtNeuralNet.cpp:58
Shape outputShape() const override
Get the output shape of the neural network.
Definition OrtNeuralNet.cpp:109
Shape inputShape() const override
Get the input shape of the neural network.
Definition OrtNeuralNet.cpp:105
A Tensor represents an N-dimensional array whose elements all have the same type. It can be used as input ...
Definition Tensor.h:201
std::vector< int64_t > Shape
Shape of an N-dimensional Tensor, given as the size of each dimension. A size can be -1 if that dimension is dynamic.
Definition Tensor.h:75
NN: Neural Network library containing the Tensor and NeuralNet classes for inference.
Definition NeuralError.h:13