iVS3D v2.0.9
Loading...
Searching...
No Matches
OrtNeuralNet.h
Go to the documentation of this file.
1#pragma once
2
9#include "NeuralNet.h"
10
11#include <string>
12#include <memory>
13#include <functional>
14#include <vector>
15#include <any>
16#include <iostream>
17#include <optional>
18
19#include <onnxruntime_cxx_api.h>
20
21namespace NN
22{
57 class OrtNeuralNet : public NeuralNet
58 {
59 public:
69 OrtNeuralNet(const std::string& modelPath, bool useCuda = false, int gpuId = 0);
70
76 tl::expected<std::vector<Tensor>, NeuralError> infer(const Tensor& input) override;
77
82 size_t inputCount() const override;
83
88 size_t outputCount() const override;
89
95 Shape inputShape(size_t idx = 0) const override;
96
102 Shape outputShape(size_t idx = 0) const override;
103
109 std::string inputName(size_t idx = 0) const override;
110
116 std::string outputName(size_t idx = 0) const override;
121 int gpuId() const override;
122
123 private:
124 Ort::Env m_env;
125 Ort::SessionOptions m_sessionOptions;
126 Ort::Session m_session{ nullptr };
127 Ort::AllocatorWithDefaultOptions m_allocator;
128
129 std::vector<int64_t> m_inputShape; // for now only one input is supported!
130 std::vector<std::vector<int64_t>> m_outputShapes;
131 std::string m_inputName;
132 std::vector<std::string> m_outputNames;
133 int m_gpuId;
134
135 tl::expected<Ort::Value, NeuralError> tensorToOrtValue(const Tensor& tensor, std::optional<std::vector<int64_t>> shapeOverride = std::nullopt) const;
136 tl::expected<Tensor, NeuralError> ortValueToTensor(const Ort::Value& value) const;
137 };
138}
139
Contains the NeuralNet interface for neural network inference.
Represents an error that occurred in the neural network module; contains the error type and message.
Definition NeuralError.h:48
Abstract base class for neural networks.
Definition NeuralNet.h:31
A class that implements the NeuralNet interface using ONNX Runtime.
Definition OrtNeuralNet.h:58
int gpuId() const override
Get the GPU ID used by the neural network if it is configured to use a GPU.
Definition OrtNeuralNet.cpp:163
std::string outputName(size_t idx=0) const override
Get the name of the output tensor.
Definition OrtNeuralNet.cpp:156
Shape outputShape(size_t idx=0) const override
Get the output shape of the neural network.
Definition OrtNeuralNet.cpp:145
Shape inputShape(size_t idx=0) const override
Get the input shape of the neural network.
Definition OrtNeuralNet.cpp:143
tl::expected< std::vector< Tensor >, NeuralError > infer(const Tensor &input) override
Perform inference on the given input Tensor using the ONNX model.
Definition OrtNeuralNet.cpp:71
std::string inputName(size_t idx=0) const override
Get the name of the input tensor.
Definition OrtNeuralNet.cpp:152
size_t outputCount() const override
Get the number of outputs of the neural network.
Definition OrtNeuralNet.cpp:141
size_t inputCount() const override
Get the number of inputs of the neural network.
Definition OrtNeuralNet.cpp:139
A Tensor represents an N-dimensional array whose elements all have the same type. Can be used as input ...
Definition Tensor.h:201
std::vector< int64_t > Shape
Shape of an N-dimensional Tensor, given as the size of each dimension. A size can be -1 for dynamic dimensions.
Definition Tensor.h:75
NN: neural network library containing the Tensor and NeuralNet classes for inference.
Definition NeuralError.h:13