DLPrimitives
activation.hpp
1 #pragma once
2 #include <dlprim/operator.hpp>
3 namespace dlprim {
4  namespace json { class value; }
5  struct ActivationConfig {
6  StandardActivations activation = StandardActivations::identity;
7  static ActivationConfig from_json(json::value const &v);
8  };
9 
10 
11 
12  class Activation : public Operator {
13  public:
14 
16  virtual ~Activation();
17 
18  virtual char const *operator_type() const
19  {
20  return "Activation";
21  }
22 
23  virtual void setup(std::vector<TensorSpecs> const &in,
24  std::vector<TensorSpecs> &out,
25  std::vector<TensorSpecs> &parameters,
26  size_t &workspace);
27 
28  virtual void reshape(std::vector<Shape> const &in,
29  std::vector<Shape> &out,
30  size_t &ws);
31 
32  virtual void forward(std::vector<Tensor> &input,
33  std::vector<Tensor> &output,
34  std::vector<Tensor> &parameters,
35  Tensor &workspace,
36  ExecutionContext const &ctx);
37 
38  virtual void backward(std::vector<TensorAndGradient> &input,
39  std::vector<TensorAndGradient> &output,
40  std::vector<TensorAndGradient> &parameters,
41  Tensor &workspace,
42  ExecutionContext const &ctx);
43 
44  static std::unique_ptr<Activation> get_bwd_op(Context &ctx,StandardActivations act,TensorSpecs spec);
45 
46  private:
47  void forward_cpu(Tensor &a,Tensor &output);
48  void backward_cpu(Tensor &y,Tensor &dy,Tensor &dx,float beta);
49  ActivationConfig config_;
50  DataType dtype_;
51  };
52 } // namespace
53 
Definition of Tensor without actual memory/object.
Definition: tensor.hpp:11
Definition: activation.hpp:5
virtual char const * operator_type() const
name of the operator type
Definition: activation.hpp:18
Base class for forward/backward propagation calculations in the internal network.
Definition: operator.hpp:15
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
DataType
type definition
Definition: definitions.hpp:70
This class is the central representation of JSON objects.
Definition: json.hpp:652
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
StandardActivations
Parameterless activations that can be embedded into general kernels like inner product or convolution...
Definition: definitions.hpp:266
Definition: activation.hpp:12
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121