DLPrimitives
onnx.hpp
1 #pragma once
2 #include <dlprim/model.hpp>
// Forward declarations of the ONNX message types this header mentions.
// The declarations below only take them by reference, so the full ONNX
// definitions are not needed here.
namespace onnx {
    class NodeProto;
    class TensorProto;
}
7 namespace dlprim {
8 
19  class DLPRIM_API ONNXModel : public ModelBase {
20  public:
21  ONNXModel();
22  virtual ~ONNXModel();
26  void load(std::string const &file_name);
30  virtual json::value const &network();
34  virtual Tensor get_parameter(std::string const &name);
35  private:
36  void load_proto(std::string const &file_name);
37  void prepare_network();
38  void prepare_inputs_outputs();
39  void parse_operators();
40  void validate_outputs();
41  void add_conv(onnx::NodeProto const &node);
42  void add_ip(onnx::NodeProto const &node);
43  void add_matmul(onnx::NodeProto const &node);
44  void add_bn(onnx::NodeProto const &node);
45  void add_operator(onnx::NodeProto const &node,json::value &v,bool add_outputs = true);
46  void add_standard_activation(onnx::NodeProto const &node,std::string const &name,bool validate_inputs = true);
47  void add_concat(onnx::NodeProto const &node);
48  void add_softmax(onnx::NodeProto const &node);
49  void add_elementwise(onnx::NodeProto const &node,std::string const &operation);
50  void add_global_pooling(onnx::NodeProto const &node,std::string const &operation);
51  void add_pool2d(onnx::NodeProto const &node,std::string const &operation);
52  void add_flatten(onnx::NodeProto const &node);
53  void add_clip(onnx::NodeProto const &node);
54  void add_pad(onnx::NodeProto const &node);
55  void add_bias(onnx::NodeProto const &node);
56  void add_squeeze(onnx::NodeProto const &node);
57  void add_reshape(onnx::NodeProto const &node);
58  void handle_constant(onnx::NodeProto const &node);
59  std::pair<std::string,Tensor> transpose_parameter(std::string const &name);
60 
61 
62  template<typename T>
63  T get_scalar_constant(std::string const &name);
64 
65  void check_outputs(onnx::NodeProto const &node,int minv,int maxv=-1);
66  void check_inputs(onnx::NodeProto const &node,int inputs_min,int inputs_max=-1,int params_min=0,int params_max=-1);
67 
68  std::vector<int> tensor_to_intvec(Tensor t);
69  void check_pad_op_2d(onnx::NodeProto const &node,std::vector<int> &pads);
70  std::vector<int> get_pads_2d(onnx::NodeProto const &node);
71  bool has_attr(onnx::NodeProto const &node,std::string const &name);
72  template<typename T>
73  T get_attr(onnx::NodeProto const &node,std::string const &name,T default_value);
74  template<typename T>
75  T get_attr(onnx::NodeProto const &node,std::string const &name);
76  struct Data;
77  std::unique_ptr<Data> d;
78  };
79 
80 }
This class is the central representation of JSON objects.
Definition: json.hpp:652
Definition: onnx.hpp:3
External model for loading ONNX models for inference with dlprim.
Definition: onnx.hpp:19
Base class used for loading non-native model formats into dlprimitives.
Definition: model.hpp:11
Main namespace.
Definition: context.hpp:9
Central data container - Tensor.
Definition: tensor.hpp:99