DLPrimitives
conv.hpp
1 #pragma once
2 #include <dlprim/tensor.hpp>
3 #include <dlprim/context.hpp>
4 namespace dlprim {
5 namespace core {
6 
11  Conv2DSettings(Conv2DSettings const &) = default; // compiler-generated member-wise copy
14  shape(s),
15  dtype(dt)
16  {
17  }
18 
19  Shape shape; // input shape; the batch dimension is only a hint, not a requirement
20  DataType dtype=float_data; // element data type of the tensors, float_data by default
21  };
22 
23  class Conv2DBase { // common polymorphic base of the Conv2D algorithm objects (e.g. Conv2DForward)
24  public:
25  virtual ~Conv2DBase() {}; // virtual so derived algorithms are destroyed correctly through a base pointer
26  virtual char const *algo() const = 0; // identifier of the concrete algorithm; presumably matches the `algo` string accepted by create() - confirm
27  virtual size_t workspace() { return 0; } // scratch space the algorithm needs; base default is 0 (none). NOTE(review): unit presumably bytes - confirm
28  static Shape get_output_shape(Convolution2DConfigBase const &config,Shape const &in); // output shape for input shape `in` under `config` (definition not shown here)
29  static Shape get_output_shape_transposed(Convolution2DConfigBase const &config,Shape const &in,int output_pad[2]); // transposed-convolution output shape; NOTE(review): output_pad presumably holds extra padding per spatial dim - confirm
30  };
35  class Conv2DForward : public Conv2DBase { // forward 2D convolution; concrete implementations are obtained through create()
36  public:
37  virtual ~Conv2DForward() {} // virtual: instances are owned via std::unique_ptr<Conv2DForward>
38  virtual void enqueue(Tensor &x,Tensor &w,Tensor *bias,Tensor &y,Tensor &ws,float factor,ExecutionContext const &e) = 0; // schedule forward pass: x input, w weights, bias optional (presumably nullptr when unused), y output, ws workspace; NOTE(review): factor presumably scales/accumulates into y - confirm
46  static std::unique_ptr<Conv2DForward> create(Context &ctx, // factory: selects and builds a concrete forward implementation
47  Conv2DSettings const &config, // convolution configuration plus input shape and dtype
48  bool bias, // presumably whether enqueue() will receive a bias tensor - confirm
49  StandardActivations activation = StandardActivations::identity, // activation fused into the computation; identity (none) by default
50  std::string const &algo = std::string()); // requested algorithm; presumably empty means automatic selection - confirm
51  };
52 
57  public:
58  virtual ~Conv2DBackwardData() {} // virtual: instances are owned via std::unique_ptr<Conv2DBackwardData>
59  virtual void enqueue(Tensor &dx,Tensor &w,Tensor &dy,Tensor &ws,float factor,ExecutionContext const &e) = 0; // schedule backward-data pass: dx input gradient (presumably written), w weights, dy output gradient, ws workspace; NOTE(review): factor presumably scales/accumulates into dx - confirm
60  static std::unique_ptr<Conv2DBackwardData> create(Context &ctx,Conv2DSettings const &config,std::string const &algo = std::string()); // factory for a concrete backward-data implementation; empty algo presumably means automatic selection
61  };
62 
67  public:
68  virtual ~Conv2DBackwardFilter() {} // virtual: instances are owned via std::unique_ptr<Conv2DBackwardFilter>
69  virtual void enqueue(Tensor &x,Tensor &dw,Tensor &dy,Tensor &ws,float factor,ExecutionContext const &e) = 0; // schedule backward-filter pass: x forward input, dw weight gradient (presumably written), dy output gradient, ws workspace; NOTE(review): factor presumably scales/accumulates into dw - confirm
70  static std::unique_ptr<Conv2DBackwardFilter> create(Context &ctx,Conv2DSettings const &config,std::string const &algo = std::string()); // factory for a concrete backward-filter implementation; empty algo presumably means automatic selection
71  };
72 
73 } // core
74 } // dlprim
Tensor shape.
Definition: shape.hpp:18
Perform Conv2D forward calculations, allowing bias and activation to be fused into ...
Definition: conv.hpp:35
Configuration of a convolution.
Definition: conv.hpp:10
Definition: conv.hpp:23
Convolution settings.
Definition: definitions.hpp:301
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
DataType
type definition
Definition: definitions.hpp:70
Main namespace.
Definition: context.hpp:9
Central data container - Tensor.
Definition: tensor.hpp:99
StandardActivations
Parameterless activations that can be embedded into general kernels like inner product or convolution...
Definition: definitions.hpp:266
Perform Conv2D backward filter calculations.
Definition: conv.hpp:66
Perform Conv2D backward data calculations.
Definition: conv.hpp:56
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121