DLPrimitives
ip.hpp
1 #pragma once
2 #include <dlprim/tensor.hpp>
3 #include <dlprim/context.hpp>
4 namespace dlprim {
8 namespace core {
12  struct IPSettings {
13  int inputs = -1;
14  int outputs = -1;
15  int optimal_batch_size = -1;
16  DataType dtype=float_data;
17  };
18 
23  class IPForward {
24  public:
25  virtual ~IPForward() {}
26  virtual void enqueue(Tensor &x,Tensor &w,Tensor *bias,Tensor &y,ExecutionContext const &e) = 0;
34  static std::unique_ptr<IPForward> create(Context &ctx,
35  IPSettings const &config,
36  bool bias,
37  StandardActivations activation = StandardActivations::identity);
38  };
39 
44  public:
45  virtual ~IPBackwardData() {}
46  virtual void enqueue(Tensor &dx,Tensor &w,Tensor &dy,float factor,ExecutionContext const &e) = 0;
51  static std::unique_ptr<IPBackwardData> create(Context &ctx,IPSettings const &config);
52  };
53 
58  public:
59  virtual ~IPBackwardFilter() {}
60  virtual void enqueue(Tensor &x,Tensor &dw,Tensor &dy,float factor,ExecutionContext const &e) = 0;
65  static std::unique_ptr<IPBackwardFilter> create(Context &ctx,IPSettings const &config);
66  };
67 
68 } // core
69 } // dlprim
int outputs
Number of output features.
Definition: ip.hpp:14
Configuration of the InnerProduct layer.
Definition: ip.hpp:12
int optimal_batch_size
Expected batch size the network is used with.
Definition: ip.hpp:15
DataType dtype
Element data type of the layer (defaults to float_data).
Definition: ip.hpp:16
Perform InnerProduct/FullyConnected/Dense forward calculations; allows fusing bias and activation into ...
Definition: ip.hpp:23
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
DataType
type definition
Definition: definitions.hpp:70
Perform InnerProduct/FullyConnected/Dense backward filter calculations.
Definition: ip.hpp:57
Perform InnerProduct/FullyConnected/Dense backward data calculations.
Definition: ip.hpp:43
Main namespace.
Definition: context.hpp:9
Central data container - Tensor.
Definition: tensor.hpp:99
StandardActivations
Parameterless activations that can be embedded into general kernels like inner product or convolution...
Definition: definitions.hpp:266
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121