2 #include <dlprim/tensor.hpp> 3 #include <dlprim/context.hpp> 91 float batch_var_factor,
float running_var_factor,
134 Tensor *dx,
float fx_factor,
135 Tensor *dgamma,
float dgamma_factor,
136 Tensor *dbeta,
float dbeta_factor,
153 Tensor &dx,
float dx_factor,
157 static std::unique_ptr<BatchNormFwdBwd> create(
Context &ctx,
Shape const &s,
DataType dt=float_data);
virtual void enqueue_backward_affine(bool training_mode, Tensor &x, Tensor &dy, Tensor &mean, Tensor &var, Tensor &gamma, Tensor *dx, float fx_factor, Tensor *dgamma, float dgamma_factor, Tensor *dbeta, float dbeta_factor, float eps, Tensor &ws, ExecutionContext const &e)=0
Perform backpropagation calculations.
Performs batch normalization computations over channel dimension #1 (when dimension #0 is the batch).
Definition: bn.hpp:65
Tensor shape.
Definition: shape.hpp:18
virtual void enqueue_calculate_batch_stats(Tensor &x, Tensor &mean, Tensor &var, Tensor &ws, ExecutionContext const &e)=0
Compute batch mean and variance for input x.
virtual size_t workspace()=0
Workspace size needed for intermediate results of computations.
virtual void enqueue_backward_direct(bool training_mode, Tensor &x, Tensor &dy, Tensor &mean, Tensor &var, Tensor &dx, float dx_factor, float eps, Tensor &ws, ExecutionContext const &e)=0
Perform backpropagation calculations for BN without the affine Gamma/Beta addition.
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
DataType
type definition
Definition: definitions.hpp:70
virtual void enqueue_update_running_stats(float batch_mean_factor, float running_mean_factor, Tensor &batch_mean, Tensor &running_mean, float batch_var_factor, float running_var_factor, Tensor &batch_var, Tensor &running_var, Tensor &ws, ExecutionContext const &e)=0
Update the running sums (running mean and running variance) from the batch statistics.
virtual void enqueue_forward_direct(Tensor &x, Tensor &y, Tensor &mean, Tensor &var, float eps, Tensor &ws, ExecutionContext const &e)=0
Perform forward computation as y = (x - mean) / sqrt(var + eps)
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
virtual void enqueue_forward_affine(Tensor &x, Tensor &y, Tensor &gamma, Tensor &beta, Tensor &mean, Tensor &var, float eps, Tensor &ws, ExecutionContext const &e)=0
Perform forward computation as y = (x - mean) / sqrt(var + eps) * gamma + beta.
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121