iVS3D v2.0.0
NeuralNet.h
```
 1 #pragma once
 2
10 #include "Tensor.h"
11 #include "NeuralError.h"
12
13 #include <tl/expected.hpp>
14 #include <string>
15 #include <vector>
16 #include <memory>
17
18 namespace NN
19 {
31     class NeuralNet {
32     public:
33         virtual ~NeuralNet() = default;
34
46         virtual tl::expected<Tensor, NeuralError> infer(const Tensor& input) = 0;
47
56         tl::expected<Tensor, NeuralError> operator()(const Tensor& input) {
57             return infer(input);
58         }
59
65         virtual Shape inputShape() const = 0;
66
72         virtual Shape outputShape() const = 0;
73
83         virtual int gpuId() const = 0;
84     };
85
95     using NeuralNetPtr = std::shared_ptr<NeuralNet>;
96 }
```
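To illustrate how the interface is meant to be implemented and called, here is a minimal sketch of a concrete subclass. It is not part of iVS3D and makes several assumptions: `Tensor`, `NeuralError` and `tl::expected` are replaced by simplified stand-ins (the hypothetical `Expected` wraps a `std::variant`), `ScaleNet` is a hypothetical network that multiplies every element by a constant, and returning -1 from `gpuId()` for a CPU-only model is a guess, since the header does not say what is returned when no GPU is configured. With the real `tl::expected`, the error branch would typically be written as `return tl::make_unexpected(error);`.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Simplified stand-ins for the iVS3D types (assumptions, not the real API).
using Shape = std::vector<int64_t>;

struct Tensor {
    Shape shape;              // size per dimension
    std::vector<float> data;  // flat element storage (assumed float)
};

struct NeuralError {
    std::string message;
};

// Minimal stand-in for tl::expected<T, E>: holds either a value or an error.
template <typename T, typename E>
class Expected {
public:
    Expected(T value) : v_(std::move(value)) {}
    Expected(E error) : v_(std::move(error)) {}
    bool has_value() const { return std::holds_alternative<T>(v_); }
    const T& value() const { return std::get<T>(v_); }
    const E& error() const { return std::get<E>(v_); }

private:
    std::variant<T, E> v_;
};

// Mirrors the NN::NeuralNet interface, using the stand-ins above.
class NeuralNet {
public:
    virtual ~NeuralNet() = default;
    virtual Expected<Tensor, NeuralError> infer(const Tensor& input) = 0;
    Expected<Tensor, NeuralError> operator()(const Tensor& input) { return infer(input); }
    virtual Shape inputShape() const = 0;
    virtual Shape outputShape() const = 0;
    virtual int gpuId() const = 0;
};

using NeuralNetPtr = std::shared_ptr<NeuralNet>;

// Hypothetical CPU-only network that multiplies every element by a constant.
class ScaleNet : public NeuralNet {
public:
    explicit ScaleNet(float factor) : factor_(factor) {}

    Expected<Tensor, NeuralError> infer(const Tensor& input) override {
        if (input.shape.size() != 1) {  // only flat (rank-1) tensors accepted
            return NeuralError{"expected a rank-1 tensor"};
        }
        Tensor out = input;
        for (float& x : out.data) x *= factor_;
        return out;
    }

    Shape inputShape() const override { return {-1}; }   // one dynamic dimension
    Shape outputShape() const override { return {-1}; }
    int gpuId() const override { return -1; }            // assumption: -1 means CPU only

private:
    float factor_;
};

// Usage: call through the shared base-class pointer; operator() forwards to infer().
bool demo() {
    NeuralNetPtr net = std::make_shared<ScaleNet>(2.0f);
    auto result = (*net)(Tensor{{3}, {1.0f, 2.0f, 3.0f}});
    if (!result.has_value()) return false;  // inference failed
    return result.value().data == std::vector<float>{2.0f, 4.0f, 6.0f};
}
```

Calling `(*net)(input)` and `net->infer(input)` are equivalent, because `operator()` only forwards to `infer()`. Returning the error as a value rather than throwing lets callers decide at each call site how to react to a failed inference.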
NeuralError.h
Defines error handling classes for the neural network module.
Tensor.h
Contains the Tensor class for representing N-dimensional arrays with various data types.
NN::NeuralNet
Abstract base class for neural networks.
Definition NeuralNet.h:31
virtual int gpuId() const = 0
Get the GPU ID used by the neural network if it is configured to use a GPU.
virtual tl::expected<Tensor, NeuralError> infer(const Tensor& input) = 0
Perform inference on the given input tensor.
virtual Shape outputShape() const = 0
Get the output shape of the neural network. This might contain dynamic dimensions (e....
virtual Shape inputShape() const = 0
Get the input shape of the neural network. This might contain dynamic dimensions (e....
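Because `inputShape()` and `outputShape()` may report dynamic dimensions, a caller cannot validate an input with a plain `==` comparison. The sketch below shows one way to check a concrete shape against a model shape, relying on the `Shape` documentation's statement that -1 marks a dynamic dimension. `isCompatible` is a hypothetical helper, not an iVS3D function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Same definition as NN::Shape: size per dimension, -1 for a dynamic dimension.
using Shape = std::vector<int64_t>;

// Hypothetical helper: true if `actual` fits `expected`, where -1 in
// `expected` accepts any size in that dimension. Ranks must match exactly.
bool isCompatible(const Shape& expected, const Shape& actual) {
    if (expected.size() != actual.size()) return false;
    for (std::size_t i = 0; i < expected.size(); ++i) {
        if (expected[i] != -1 && expected[i] != actual[i]) return false;
    }
    return true;
}
```

For example, if a model reported an input shape of `{-1, 3, 224, 224}` (illustrative values with a dynamic batch dimension), a tensor of shape `{8, 3, 224, 224}` would be accepted and one of shape `{8, 1, 224, 224}` rejected.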
tl::expected<Tensor, NeuralError> operator()(const Tensor& input)
Call the infer method with the given input tensor.
Definition NeuralNet.h:56
NN::Tensor
A Tensor represents an N-dimensional array containing elements of the same type. Can be used as input ...
Definition Tensor.h:201
std::vector<int64_t> Shape
Shape of an N-dimensional Tensor, represented as the size in each dimension. Can be -1 in case of dynam...
Definition Tensor.h:75
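A common use of a `Shape` is computing how many elements a tensor of that shape holds, for example to size an output buffer from `outputShape()`. With dynamic dimensions that count is unknown until every -1 is resolved, so the hypothetical helper below returns `std::nullopt` in that case. Treating an empty shape as a scalar with one element is an assumed convention; the header does not define it.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Same definition as NN::Shape: size per dimension, -1 for a dynamic dimension.
using Shape = std::vector<int64_t>;

// Hypothetical helper: number of elements described by `shape`, or
// std::nullopt if any dimension is negative (dynamic, hence still unknown).
// Overflow of the product is not checked in this sketch.
std::optional<int64_t> elementCount(const Shape& shape) {
    int64_t count = 1;  // empty shape treated as a scalar (assumption)
    for (int64_t dim : shape) {
        if (dim < 0) return std::nullopt;
        count *= dim;
    }
    return count;
}
```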
std::shared_ptr<NeuralNet> NeuralNetPtr
Smart pointer type for managing NeuralNet instances.
Definition NeuralNet.h:95
namespace NN
Neural Network library containing the Tensor and NeuralNet classes for inference.
Definition NeuralError.h:13