DLPrimitives
sgd.hpp
1 #include <dlprim/net.hpp>
2 #include <iostream>
3 #include <dlprim/ops/scal.hpp>
4 #include <dlprim/ops/axpby.hpp>
5 #include <dlprim/solvers/solver_base.hpp>
6 namespace dlprim {
7  namespace solvers {
8  class SGD : public SolverBase {
9  public:
10  float lr = 0.1;
11  float momentum = 0.9;
12  float weight_decay = 0.0005;
13  SGD(Context &ctx,DataType dtype = float_data) :
14  ctx_(ctx),
15  scal_(ctx,dtype),axpby_(ctx,dtype)
16  {
17  }
18  void init(Net &n,ExecutionContext const &q)
19  {
20  for(auto &p : n.param_diffs()) {
21  auto &t = vel_[p.first] = Tensor(ctx_,p.second.shape(),p.second.dtype());
22  scal_.scale(0,t,q);
23  }
24  }
25  void zero_grad(Net &n,ExecutionContext const &e)
26  {
27  for(auto &p : n.param_diffs()) {
28  scal_.scale(0,p.second,e);
29  }
30  }
31  void apply(Net &n,ExecutionContext const &e)
32  {
33  for(auto &item : vel_) {
34  std::string const &name = item.first;
35  Tensor &v = item.second;
36  Tensor &p = n.param(name);
37  Tensor &g = n.param_diff(name);
38  axpby_.apply(1.0,g,momentum,v,v,e); // v = momentum * v - lr * gr
39  axpby_.apply((1.0f-weight_decay),p,-lr,v,p,e);
40  }
41  }
42  private:
43  Context ctx_;
44  Scal scal_;
45  AXPBY axpby_;
46 
47  std::map<std::string,Tensor> vel_;
48  };
49  } // solvers
50 }
Base class for SGD based optimizers.
Definition: solver_base.hpp:9
Definition: axpby.hpp:5
void zero_grad(Net &n, ExecutionContext const &e)
Zero all gradients before accumulating them for the next batch.
Definition: sgd.hpp:25
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
DataType
type definition
Definition: definitions.hpp:70
Tensor & param(std::string const &name)
Get parameter by name, throws ValidationError if does not exist.
Definition: net.hpp:202
Definition: sgd.hpp:8
void apply(Net &n, ExecutionContext const &e)
apply solver updates
Definition: sgd.hpp:31
Major object used for inference.
Definition: net.hpp:14
Main namespace.
Definition: context.hpp:9
Tensor & param_diff(std::string const &name)
Get parameter gradient by name, throws ValidationError if does not exist.
Definition: net.hpp:210
Central Data Container - Tensor.
Definition: tensor.hpp:99
std::map< std::string, Tensor > & param_diffs()
All operator parameter gradients, trainable and not trainable.
Definition: net.hpp:179
Definition: scal.hpp:6
void init(Net &n, ExecutionContext const &q)
Prepare solver - takes all parameters that need to be trained and prepares buffers.
Definition: sgd.hpp:18
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121