DLPrimitives
bn.hpp
1 #pragma once
2 #include <dlprim/tensor.hpp>
3 #include <dlprim/context.hpp>
4 namespace dlprim {
5 namespace core {
63 
64 
66  public:
// Virtual so concrete implementations can be destroyed through a base-class
// pointer, e.g. the std::unique_ptr<BatchNormFwdBwd> returned by create().
67  virtual ~BatchNormFwdBwd() {}
68 
// Size of the scratch workspace needed for intermediate results of computations.
// NOTE(review): presumably the `ws` tensor passed to the enqueue_* methods must
// provide at least this much space, and the unit is bytes — confirm.
72  virtual size_t workspace() = 0;
73 
74 
// Enqueue computation of the batch mean and variance of input x into `mean` and
// `var`. `ws` is scratch space (see workspace()); `e` carries the events the
// kernel waits on and/or signals.
80  virtual void enqueue_calculate_batch_stats(Tensor &x,Tensor &mean,Tensor &var,Tensor &ws,ExecutionContext const &e) = 0;
81 
// Enqueue an update of the running mean/variance from the batch statistics,
// weighted by the *_factor arguments.
// NOTE(review): the exact formula is in the header's doc comment, which this
// listing omits. Presumably:
//   running_mean = running_mean_factor * running_mean + batch_mean_factor * batch_mean
//   (and likewise for var) — confirm against the implementation.
89  virtual void enqueue_update_running_stats(float batch_mean_factor,float running_mean_factor,
90  Tensor &batch_mean,Tensor &running_mean,
91  float batch_var_factor,float running_var_factor,
92  Tensor &batch_var,Tensor &running_var,
93  Tensor &ws,ExecutionContext const &e) = 0;
94 
// Enqueue the forward pass without affine parameters:
//   y = (x-mean) / sqrt(var + eps)
100  virtual void enqueue_forward_direct(Tensor &x,Tensor &y,
101  Tensor &mean,Tensor &var,float eps,
102  Tensor &ws,ExecutionContext const &e) = 0;
// Enqueue the forward pass with affine parameters:
//   y = (x-mean) / sqrt(var + eps) * gamma + beta
110  virtual void enqueue_forward_affine(Tensor &x,Tensor &y,
111  Tensor &gamma,Tensor &beta,
112  Tensor &mean,Tensor &var,
113  float eps,
114  Tensor &ws,ExecutionContext const &e) = 0;
115 
// Enqueue backpropagation for the affine (gamma/beta) variant of batch norm.
// NOTE(review): dx, dgamma and dbeta are taken by pointer — presumably nullptr
// skips that gradient. Each *_factor presumably scales the output's existing
// contents before the new gradient is added. `training_mode` presumably selects
// whether gradients also flow through the batch mean/var. Confirm all three.
// NOTE(review): `fx_factor` looks like a typo for `dx_factor` (the name used by
// enqueue_backward_direct). Renaming a parameter in a declaration does not
// affect callers, but check the doc comment's \param names first.
130  virtual void enqueue_backward_affine(bool training_mode,
131  Tensor &x,Tensor &dy,
132  Tensor &mean,Tensor &var,
133  Tensor &gamma,
134  Tensor *dx,float fx_factor,
135  Tensor *dgamma,float dgamma_factor,
136  Tensor *dbeta,float dbeta_factor,
137  float eps,
138  Tensor &ws,ExecutionContext const &e) = 0;
139 
// Enqueue backpropagation for batch norm without the affine gamma/beta step.
// NOTE(review): `dx_factor` presumably scales dx's existing contents before the
// new gradient is added, and `training_mode` presumably selects whether
// gradients also flow through the batch mean/var — confirm.
150  virtual void enqueue_backward_direct(bool training_mode,
151  Tensor &x,Tensor &dy,
152  Tensor &mean,Tensor &var,
153  Tensor &dx,float dx_factor,
154  float eps,
155  Tensor &ws,ExecutionContext const &e) = 0;
156 
// Factory: creates an implementation for the given context, shape and data type
// (float_data by default). The caller owns the returned object.
// NOTE(review): `s` is presumably the input tensor shape, with batch in dim 0 and
// channels in dim 1 — confirm.
157  static std::unique_ptr<BatchNormFwdBwd> create(Context &ctx,Shape const &s,DataType dt=float_data);
158 
159  };
160 
161 } // core
162 } // dlprim
virtual void enqueue_backward_affine(bool training_mode, Tensor &x, Tensor &dy, Tensor &mean, Tensor &var, Tensor &gamma, Tensor *dx, float fx_factor, Tensor *dgamma, float dgamma_factor, Tensor *dbeta, float dbeta_factor, float eps, Tensor &ws, ExecutionContext const &e)=0
Perform backpropagation calculations.
Performs batch normalization computations over dimension #1 (channels), where dimension #0 is the batch.
Definition: bn.hpp:65
Tensor shape.
Definition: shape.hpp:18
virtual void enqueue_calculate_batch_stats(Tensor &x, Tensor &mean, Tensor &var, Tensor &ws, ExecutionContext const &e)=0
Compute batch mean and variance for input x.
virtual size_t workspace()=0
Workspace size needed for intermediate results of computations.
virtual void enqueue_backward_direct(bool training_mode, Tensor &x, Tensor &dy, Tensor &mean, Tensor &var, Tensor &dx, float dx_factor, float eps, Tensor &ws, ExecutionContext const &e)=0
Perform backpropagation calculations for BN without the affine Gamma/Beta addition.
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
DataType
type definition
Definition: definitions.hpp:70
virtual void enqueue_update_running_stats(float batch_mean_factor, float running_mean_factor, Tensor &batch_mean, Tensor &running_mean, float batch_var_factor, float running_var_factor, Tensor &batch_var, Tensor &running_var, Tensor &ws, ExecutionContext const &e)=0
Update the running mean and variance from the batch statistics, weighted by the given batch and running factors.
virtual void enqueue_forward_direct(Tensor &x, Tensor &y, Tensor &mean, Tensor &var, float eps, Tensor &ws, ExecutionContext const &e)=0
Perform forward computation as y = (x-mean) / sqrt(var + eps).
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
virtual void enqueue_forward_affine(Tensor &x, Tensor &y, Tensor &gamma, Tensor &beta, Tensor &mean, Tensor &var, float eps, Tensor &ws, ExecutionContext const &e)=0
Perform forward computation as y = (x-mean) / sqrt(var + eps) * gamma + beta.
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121