DLPrimitives
inner_product.hpp
1 #pragma once
2 #include <dlprim/operator.hpp>
3 namespace dlprim {
4  namespace gpu { class GEMM; }
5  namespace json { class value; }
6  namespace core { class IPForward; class IPBackwardData; class IPBackwardFilter; }
7  class BWBias;
8 
9  struct InnerProductConfig { // configuration consumed by the InnerProduct constructor
10  int inputs = -1; // input size; -1 presumably means "not set / deduce from input" - confirm
11  int outputs = -1; // output size; -1 presumably means "not set" - confirm
12  bool bias = true; // presumably enables a bias term; enabled by default
13  StandardActivations activation = StandardActivations::identity; // parameterless activation for the result; identity (none) by default
14  static InnerProductConfig from_json(json::value const &v); // builds a config from a json::value
15  };
16 
17 
18  class InnerProduct : public Operator { // inner-product operator (presumably a fully-connected layer); Operator is the forward/backward base
19  public:
20 
21  InnerProduct(Context &ctx,InnerProductConfig const &cfg); // builds the operator for context ctx from configuration cfg
22  virtual ~InnerProduct(); // declared out of line: unique_ptr members hold forward-declared core::IP* types, which must be complete where they are destroyed
23 
24  virtual char const *operator_type() const // reports the operator type name "InnerProduct"
25  {
26  return "InnerProduct";
27  }
28  void initialize_params(std::vector<Tensor> &parameters,ExecutionContext const &e); // fills the parameter tensors with initial values (scheme not visible in this header)
29 
30  virtual void setup(std::vector<TensorSpecs> const &in, // derives output specs, parameter specs and workspace size from the input specs
31  std::vector<TensorSpecs> &out,
32  std::vector<TensorSpecs> &parameters,
33  size_t &workspace);
34 
35  virtual void reshape(std::vector<Shape> const &in, // recomputes output shapes and required workspace size for new input shapes
36  std::vector<Shape> &out,
37  size_t &ws);
38 
39  virtual void forward(std::vector<Tensor> &input, // forward pass: input -> output using parameters and workspace
40  std::vector<Tensor> &output,
41  std::vector<Tensor> &parameters,
42  Tensor &workspace,
43  ExecutionContext const &ctx);
44 
45  virtual void backward(std::vector<TensorAndGradient> &input, // backward pass over (tensor, gradient) pairs of inputs, outputs and parameters
46  std::vector<TensorAndGradient> &output,
47  std::vector<TensorAndGradient> &parameters,
48  Tensor &workspace,
49  ExecutionContext const &ctx);
50 
51 
52  protected:
53  void forward_cpu(Tensor &in,Tensor &out,Tensor &M,Tensor *bias); // CPU forward path; M presumably the weight matrix, bias passed by pointer (presumably null when no bias)
54  void backward_filter_cpu(Tensor &dy,Tensor &x,Tensor &dM,float factor); // CPU weight-gradient path (dM from dy and x); factor presumably scales/accumulates into dM - confirm
55  void backward_data_cpu(Tensor &dy,Tensor &dx,Tensor &M,float factor); // CPU input-gradient path (dx from dy and M); factor presumably scales/accumulates into dx - confirm
56 
57 
58  InnerProductConfig config_; // configuration this operator was built with (presumably a copy of cfg)
59  DataType dtype_; // data type of the processed tensors (source not visible here - confirm)
60  std::unique_ptr<core::IPForward> ip_; // forward implementation helper
61  std::unique_ptr<core::IPBackwardData> bwd_ip_; // backward-data (input gradient) implementation helper
62  std::unique_ptr<core::IPBackwardFilter> bwd_weights_ip_; // backward-filter (weight gradient) implementation helper
63  std::unique_ptr<Operator> activation_; // separate activation operator; presumably used when the activation is not fused into ip_ - confirm
64  std::unique_ptr<BWBias> bwd_bias_; // presumably computes the bias gradient
65  };
66 
67 }
Definition: inner_product.hpp:9
Definition: inner_product.hpp:18
Base class for backward/forward propagation calculations in the internal network.
Definition: operator.hpp:15
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
DataType
type definition
Definition: definitions.hpp:70
This class is the central representation of JSON objects.
Definition: json.hpp:652
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
StandardActivations
Parameterless activations that can be embedded into general kernels like inner product or convolution...
Definition: definitions.hpp:266
virtual char const * operator_type() const
name of the operator type
Definition: inner_product.hpp:24
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121