DLPrimitives
solver_base.hpp
1 #pragma once
2 #include <dlprim/net.hpp>
3 namespace dlprim {
7  namespace solvers {
9  class SolverBase {
10  public:
12  virtual void init(Net &n,ExecutionContext const &q) = 0;
14  virtual void zero_grad(Net &n,ExecutionContext const &e) = 0;
16  virtual void apply(Net &n,ExecutionContext const &e) = 0;
17  virtual ~SolverBase() {}
21  void step(Net &n,ExecutionContext const &e)
22  {
23  zero_grad(n,e);
24  n.forward(e);
25  n.backward(e);
26  apply(n,e);
27  }
28  };
29  } // solvers
30 } // dlprim
Base class for SGD based optimizers.
Definition: solver_base.hpp:9
virtual void init(Net &n, ExecutionContext const &q)=0
Prepare the solver: takes all parameters that need to be trained and prepares their buffers.
virtual void zero_grad(Net &n, ExecutionContext const &e)=0
Zero all gradients before accumulating them for the next batch.
virtual void apply(Net &n, ExecutionContext const &e)=0
Apply solver updates.
Major object used for inference.
Definition: net.hpp:14
Main namespace.
Definition: context.hpp:9
void step(Net &n, ExecutionContext const &e)
Shortcut for a single training step: zero_grad, forward, backward, apply.
Definition: solver_base.hpp:21
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121