DLPrimitives
operator.hpp
1 #pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <dlprim/definitions.hpp>
#include <dlprim/tensor.hpp>
#include <dlprim/context.hpp>
5 
6 namespace dlprim {
7  namespace json { class value; }
8 
9  class SharedResource;
10 
11 
15  class Operator {
16  public:
20  Operator(Context const &ctx) :
21  ctx_(ctx),
22  mode_(CalculationsMode::predict)
23  {
24  }
25 
31  {
32  DLPRIM_CHECK(shared_resource_);
33  return *shared_resource_;
34  }
35 
37  void shared_resource(std::shared_ptr<SharedResource> r)
38  {
39  shared_resource_ = r;
40  }
41 
42  virtual ~Operator()
43  {
44  }
45 
47  virtual char const *operator_type() const = 0;
48 
56  virtual void mode(CalculationsMode mode)
57  {
58  mode_ = mode;
59  }
60 
63  {
64  return mode_;
65  }
66 
72  Operator(Operator const &) = delete;
73  void operator=(Operator const &) = delete;
74  Operator(Operator &&) = delete;
75  void operator=(Operator &&) = delete;
76 
77 
83  virtual bool alias_generator()
84  {
85  return false;
86  }
87 
91  virtual void initialize_params(std::vector<Tensor> &/*parameters*/,ExecutionContext const &/*e*/)
92  {
93  }
94 
104  virtual void setup(std::vector<TensorSpecs> const &in,
105  std::vector<TensorSpecs> &out,
106  std::vector<TensorSpecs> &parameters,
107  size_t &workspace) = 0;
108 
116  virtual void reshape(std::vector<Shape> const &in,
117  std::vector<Shape> &out,
118  size_t &workspace) = 0;
119 
129  virtual void forward(std::vector<Tensor> &input,
130  std::vector<Tensor> &output,
131  std::vector<Tensor> &parameters,
132  Tensor &workspace,
133  ExecutionContext const &ctx) = 0;
134 
158  virtual void backward(std::vector<TensorAndGradient> & /*input*/,
159  std::vector<TensorAndGradient> & /*output*/,
160  std::vector<TensorAndGradient> & /*parameters*/,
161  Tensor &/*workspace*/,
162  ExecutionContext const &/*ctx*/)
163  {
164  throw NotImplementedError("backpropogation is not implemented for " + std::string(operator_type()));
165  }
166 
167  protected:
170  std::shared_ptr<SharedResource> shared_resource_;
171  };
172 
176  std::unique_ptr<Operator> create_by_name(Context &ctx,
177  std::string const &name,
178  json::value const &parameters);
179 
180 
181 } // dlprim
SharedResource & shared_resource()
Getter for the object that is shared between operators across the net, for example random numbers genera...
Definition: operator.hpp:30
CalculationsMode
Operation mode of layers - inference or training.
Definition: definitions.hpp:283
virtual void backward(std::vector< TensorAndGradient > &, std::vector< TensorAndGradient > &, std::vector< TensorAndGradient > &, Tensor &, ExecutionContext const &)
Enqueue backward propagation computations.
Definition: operator.hpp:158
virtual void mode(CalculationsMode mode)
Can be called with either train or predict before setup() is called. Afterwards, if the original mode was t...
Definition: operator.hpp:56
virtual void initialize_params(std::vector< Tensor > &, ExecutionContext const &)
Set default parameter initialization.
Definition: operator.hpp:91
Base class for backward/forward propagation calculations for the internal network.
Definition: operator.hpp:15
virtual bool alias_generator()
Returns true if the operator is an alias generator - it only changes the shape of the tensor but not its ...
Definition: operator.hpp:83
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
This class is central representation of json objects.
Definition: json.hpp:652
Operator(Context const &ctx)
Create operator for specific context (device/platform)
Definition: operator.hpp:20
void shared_resource(std::shared_ptr< SharedResource > r)
Setter of the shared resource.
Definition: operator.hpp:37
Context ctx_
OpenCL/CPU Context to work with.
Definition: operator.hpp:168
virtual CalculationsMode mode()
get current mode
Definition: operator.hpp:62
CalculationsMode mode_
Computations mode.
Definition: operator.hpp:169
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
Thrown if some stuff is not implemented yet.
Definition: definitions.hpp:38
std::unique_ptr< Operator > create_by_name(Context &ctx, std::string const &name, json::value const &parameters)
Factory - generate operator by its name (type) with parameters needed.
Resources shared by the entire network.
Definition: shared_resource.hpp:11
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121