1 #include <dlprim/net.hpp> 3 #include <dlprim/ops/scal.hpp> 4 #include <dlprim/ops/axpby.hpp> 5 #include <dlprim/solvers/solver_base.hpp> 12 float weight_decay = 0.0005;
15 scal_(ctx,dtype),axpby_(ctx,dtype)
21 auto &t = vel_[p.first] =
Tensor(ctx_,p.second.shape(),p.second.dtype());
28 scal_.scale(0,p.second,e);
33 for(
auto &item : vel_) {
34 std::string
const &name = item.first;
38 axpby_.apply(1.0,g,momentum,v,v,e);
39 axpby_.apply((1.0f-weight_decay),p,-lr,v,p,e);
47 std::map<std::string,Tensor> vel_;
Base class for SGD-based optimizers.
Definition: solver_base.hpp:9
void zero_grad(Net &n, ExecutionContext const &e)
Zero all gradients before accumulating them for the next batch.
Definition: sgd.hpp:25
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
DataType
type definition
Definition: definitions.hpp:70
Tensor & param(std::string const &name)
Get parameter by name; throws ValidationError if it does not exist.
Definition: net.hpp:202
void apply(Net &n, ExecutionContext const &e)
apply solver updates
Definition: sgd.hpp:31
Major object used for inference.
Definition: net.hpp:14
Main namespace.
Definition: context.hpp:9
Tensor & param_diff(std::string const &name)
Get parameter gradient by name; throws ValidationError if it does not exist.
Definition: net.hpp:210
Central Data Container - Tensor.
Definition: tensor.hpp:99
std::map< std::string, Tensor > & param_diffs()
Gradients of all operator parameters, both trainable and non-trainable.
Definition: net.hpp:179
void init(Net &n, ExecutionContext const &q)
Prepare solver - takes all parameters that need to be trained and prepares buffers.
Definition: sgd.hpp:18
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121