DLPrimitives
adam.hpp
#include <dlprim/net.hpp>
#include <dlprim/ops/scal.hpp>
#include <dlprim/core/pointwise.hpp>
#include <dlprim/ops/initialization.hpp>
#include <dlprim/solvers/solver_base.hpp>
#include <cmath>
#include <iostream>
#include <map>
#include <string>
8 
9 namespace dlprim {
10  namespace solvers {
11  class Adam : public SolverBase {
12  public:
13  float lr = 0.001;
14  float beta1 = 0.9;
15  float beta2 = 0.999;
16  float eps = 1e-8;
17  float weight_decay = 0.0005;
18  Adam(Context &ctx) : ctx_(ctx)
19  {
20  }
21  void init(Net &n,ExecutionContext const &q)
22  {
23  for(auto &p : n.param_diffs()) {
24  auto &v = v_[p.first] = Tensor(ctx_,p.second.shape(),p.second.dtype());
25  auto &m = m_[p.first] = Tensor(ctx_,p.second.shape(),p.second.dtype());
26  set_to_zero(v,q);
27  set_to_zero(m,q);
28  }
29  t_ = 0;
30  }
31  void zero_grad(Net &n,ExecutionContext const &e)
32  {
33  for(auto &p : n.param_diffs()) {
34  set_to_zero(p.second,e);
35  }
36  }
37  void apply(Net &n,ExecutionContext const &e)
38  {
39  t_++;
40  inv_b1_ = 1 / (1 - std::pow(beta1,t_));
41  inv_b2_ = 1 / (1 - std::pow(beta2,t_));
42 
43  for(auto &item : n.param_diffs()) {
44  std::string const &name = item.first;
45  Tensor &v = v_[name];
46  Tensor &p = n.param(name);
47  Tensor &g = item.second;
48  Tensor &m = m_[name];
49  if(ctx_.is_cpu_context()) {
50  apply_cpu(p,g,m,v);
51  }
52  else {
53  apply_gpu(p,g,m,v,e);
54  }
55  }
56  }
57  private:
58  void apply_cpu(Tensor &p_t,Tensor &g_t,Tensor &m_t,Tensor &v_t)
59  {
60  size_t size = p_t.shape().total_size();
61  float *p = p_t.data<float>();
62  float *g = g_t.data<float>();
63  float *m = m_t.data<float>();
64  float *v = v_t.data<float>();
65  for(size_t i=0;i<size;i++) {
66  float grad = g[i] + weight_decay * p[i];
67  float m_next = beta1 * m[i] + (1-beta1) * grad;
68  float v_next = beta2 * v[i] + (1-beta2) * grad * grad;
69  float m_top = m_next * inv_b1_;
70  float v_top = v_next * inv_b2_;
71  float p_next = p[i] - lr * m_top / (std::sqrt(v_top) + eps);
72 
73  m[i] = m_next;
74  v[i] = v_next;
75  p[i] = p_next;
76  }
77  }
78  void apply_gpu(Tensor &p,Tensor &g,Tensor &m,Tensor &v,ExecutionContext const &e)
79  {
80  core::pointwise_operation({p,g,m,v},
81  {p,m,v},
82  {beta1,beta2,inv_b1_,inv_b2_,lr,weight_decay,eps},
83  R"xxx(
84  dtype p=x0, g=x1, m=x2, v=x3;
85  dtype beta1 = w0,beta2 = w1,inv_b1 = w2,inv_b2=w3,lr=w4,weight_decay=w5,eps=w6;
86  dtype grad = g + weight_decay * p;
87  dtype m_next = beta1 * m + (1-beta1) * grad;
88  dtype v_next = beta2 * v + (1-beta2) * grad * grad;
89  dtype m_top = m_next * inv_b1;
90  dtype v_top = v_next * inv_b2;
91  dtype p_next = p - lr * m_top / (sqrt(v_top) + eps);
92  y0 = p_next;
93  y1 = m_next;
94  y2 = v_next;
95  )xxx",
96  e);
97  }
98 
99  Context ctx_;
100  std::map<std::string,Tensor> m_;
101  std::map<std::string,Tensor> v_;
102  int t_;
103  float inv_b1_,inv_b2_;
104  };
105  } // solvers
106 }
Base class for SGD based optimizers.
Definition: solver_base.hpp:9
size_t total_size() const
Total number of elements - product of all items.
Definition: shape.hpp:72
void zero_grad(Net &n, ExecutionContext const &e)
zero all gradients before accumulating them for next batch
Definition: adam.hpp:31
void init(Net &n, ExecutionContext const &q)
Prepare solver - takes all parameters that need to be trained and prepares buffers.
Definition: adam.hpp:21
void apply(Net &n, ExecutionContext const &e)
apply solver updates
Definition: adam.hpp:37
T * data()
get pointer to the host pointer and cast to relevant type
Definition: tensor.hpp:246
void pointwise_operation(std::vector< Tensor > xs, std::vector< Tensor > ys, std::vector< double > ws, std::string const &code, ExecutionContext const &ec)
perform operations function(xs,ws)->ys such that each tensor in xs and ys has same shape...
void set_to_zero(Tensor &t, ExecutionContext const &e)
Set value of t to zero.
Shape const & shape() const
get tensor shape
Definition: tensor.hpp:134
This is main object that represent the pair of OpenCL platform and device all other objects use it...
Definition: context.hpp:302
Tensor & param(std::string const &name)
Get parameter by name, throws ValidationError if does not exist.
Definition: net.hpp:202
Major object used for inference.
Definition: net.hpp:14
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
bool is_cpu_context() const
Returns true if the context was created as CPU context.
Definition: context.hpp:348
std::map< std::string, Tensor > & param_diffs()
All operator parameters gradients trainable and not trainable.
Definition: net.hpp:179
Definition: adam.hpp:11
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121