iVS3D v2.0.0
Loading...
Searching...
No Matches
OrtNeuralNet.h
Go to the documentation of this file.
1#pragma once
2
9#include "NeuralNet.h"
10
11#include <string>
12#include <memory>
13#include <functional>
14#include <vector>
15#include <any>
16#include <iostream>
17#include <optional>
18
19#include <onnxruntime_cxx_api.h>
20
21namespace NN
22{
57 class OrtNeuralNet : public NeuralNet
58 {
59 public:
69 OrtNeuralNet(const std::string& modelPath, bool useCuda = false, int gpuId = 0);
70
76 tl::expected<Tensor, NeuralError> infer(const Tensor& input) override;
77
82 Shape inputShape() const override;
83
88 Shape outputShape() const override;
89
94 int gpuId() const override;
95
96 private:
97 Ort::Env m_env;
98 Ort::SessionOptions m_sessionOptions;
99 Ort::Session m_session{ nullptr };
100 Ort::AllocatorWithDefaultOptions m_allocator;
101
102 std::vector<int64_t> m_inputShape;
103 std::vector<int64_t> m_outputShape;
104 std::string m_inputName;
105 std::string m_outputName;
106 int m_gpuId;
107
108 tl::expected<Ort::Value, NeuralError> tensorToOrtValue(const Tensor& tensor, std::optional<std::vector<int64_t>> shapeOverride = std::nullopt) const;
109 tl::expected<Tensor, NeuralError> ortValueToTensor(const Ort::Value& value) const;
110 };
111}
112
Contains the NeuralNet interface for neural network inference.
Abstract base class for neural networks.
Definition NeuralNet.h:31
A class that implements the NeuralNet interface using ONNX Runtime.
Definition OrtNeuralNet.h:58
int gpuId() const override
Get the GPU ID used by the neural network if it is configured to use GPU.
Definition OrtNeuralNet.cpp:113
tl::expected< Tensor, NeuralError > infer(const Tensor &input) override
Perform inference on the given input Tensor using the ONNX model.
Definition OrtNeuralNet.cpp:58
Shape outputShape() const override
Get the output shape of the neural network.
Definition OrtNeuralNet.cpp:109
Shape inputShape() const override
Get the input shape of the neural network.
Definition OrtNeuralNet.cpp:105
A Tensor represents an N-dimensional array containing elements of the same type. Can be used as input ...
Definition Tensor.h:201
std::vector< int64_t > Shape
Shape of an N-dimensional Tensor, represented as the size in each dimension. Can be -1 in case of dynam...
Definition Tensor.h:75
NN Neural Network Library containing Tensor and NeuralNet classes for inference.
Definition NeuralError.h:13