DLPrimitives
dlprim::core::BatchNormFwdBwd Class Reference (abstract)

Performs batch normalization computations over channel #1 (when #0 is batch). More...

#include <include/dlprim/core/bn.hpp>

Public Member Functions

virtual size_t workspace ()=0
 Workspace size needed for intermediate results of computations.
 
virtual void enqueue_calculate_batch_stats (Tensor &x, Tensor &mean, Tensor &var, Tensor &ws, ExecutionContext const &e)=0
 Compute batch mean and variance for input x.
 
virtual void enqueue_update_running_stats (float batch_mean_factor, float running_mean_factor, Tensor &batch_mean, Tensor &running_mean, float batch_var_factor, float running_var_factor, Tensor &batch_var, Tensor &running_var, Tensor &ws, ExecutionContext const &e)=0
 Update running sums.
 
virtual void enqueue_forward_direct (Tensor &x, Tensor &y, Tensor &mean, Tensor &var, float eps, Tensor &ws, ExecutionContext const &e)=0
 Perform forward computation as y = (x-mean) / sqrt(var + eps).
 
virtual void enqueue_forward_affine (Tensor &x, Tensor &y, Tensor &gamma, Tensor &beta, Tensor &mean, Tensor &var, float eps, Tensor &ws, ExecutionContext const &e)=0
 Perform forward computation as y = (x-mean) / sqrt(var + eps) * gamma + beta.
 
virtual void enqueue_backward_affine (bool training_mode, Tensor &x, Tensor &dy, Tensor &mean, Tensor &var, Tensor &gamma, Tensor *dx, float fx_factor, Tensor *dgamma, float dgamma_factor, Tensor *dbeta, float dbeta_factor, float eps, Tensor &ws, ExecutionContext const &e)=0
 Perform backpropagation calculations.
 
virtual void enqueue_backward_direct (bool training_mode, Tensor &x, Tensor &dy, Tensor &mean, Tensor &var, Tensor &dx, float dx_factor, float eps, Tensor &ws, ExecutionContext const &e)=0
 Perform backpropagation calculations for BN without the affine Gamma/Beta addition.
 

Static Public Member Functions

static std::unique_ptr< BatchNormFwdBwd > create (Context &ctx, Shape const &s, DataType dt=float_data)
 

Detailed Description

Performs batch normalization computations over channel #1 (when #0 is batch).

Pseudo code parameters:

// Layer Data
Tensor running_mean,running_var,gamma,beta;
// Temporary Data kept between FW and BW
Tensor mean,var;
// Workspace
Tensor ws;

Actual pseudo-code calculations: Affine, Train

// Forward Pass
enqueue_calculate_batch_stats(x,mean,var,ws);
enqueue_update_running_stats(0.1,0.9,mean,running_mean,
0.1 * m/(m-1),0.9,var,running_var,ws);
enqueue_forward_affine(x,y, gamma,beta, mean, var,ws);
// Backward pass
enqueue_backward_affine(true,x,dy,mean,var,gamma,&dx,&dgamma,&dbeta,ws);

Affine, Test (fixed batch)

// Forward Pass
enqueue_forward_affine(x,y, gamma,beta, running_mean, running_var,ws);
// Backward pass
enqueue_backward_affine(false,x,dy,running_mean,running_var,gamma,&dx,&dgamma,&dbeta,ws);

Without affine, Train

// Forward Pass
enqueue_calculate_batch_stats(x,mean,var,ws);
enqueue_update_running_stats(0.1,0.9,mean,running_mean,
0.1 * m/(m-1),0.9,var,running_var,ws);
enqueue_forward_direct(x,y, mean, var,ws);
// Backward pass
enqueue_backward_direct(true,x,dy,mean,var,dx,ws);

Without affine, Test (fixed batch)

// Forward Pass
enqueue_forward_direct(x,y, running_mean, running_var,ws);
// Backward pass
enqueue_backward_direct(false,x,dy,running_mean,running_var,dx,ws);


Member Function Documentation

virtual void dlprim::core::BatchNormFwdBwd::enqueue_backward_affine ( bool  training_mode,
Tensor &  x,
Tensor &  dy,
Tensor &  mean,
Tensor &  var,
Tensor &  gamma,
Tensor *  dx,
float  fx_factor,
Tensor *  dgamma,
float  dgamma_factor,
Tensor *  dbeta,
float  dbeta_factor,
float  eps,
Tensor &  ws,
ExecutionContext const &  e
)
pure virtual

Perform backpropagation calculations.

  • training_mode - assumes that mean/var were calculated on batches of x and kept from the forward stage; otherwise mean/var are considered constant values
  • gamma/beta - affine transformation applied after BN
  • dy - top gradient for backpropagation
  • dx - calculate the gradient with respect to x
  • dgamma - calculate the gradient for the gamma scale
  • dbeta - calculate the gradient for the beta scale
  • ws - workspace
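The gradients themselves follow the standard batch-norm backward formulas. Below is a single-channel, host-side C++ sketch of that math (the function name and flat layout are illustrative, not the actual OpenCL kernel; the *_factor accumulation parameters of the real API are ignored here). With xhat = (x - mean)/sqrt(var + eps): dbeta = sum(dy), dgamma = sum(dy * xhat), and in training mode dx = gamma/sqrt(var+eps) * (dy - dbeta/m - xhat * dgamma/m).

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference single-channel BN backward in training mode.
// Illustrative sketch of the standard formulas only.
void bn_backward_channel(std::vector<float> const &x, std::vector<float> const &dy,
                         float mean, float var, float gamma, float eps,
                         std::vector<float> &dx, float &dgamma, float &dbeta)
{
    std::size_t m = x.size();
    float inv_std = 1.0f / std::sqrt(var + eps);
    dgamma = dbeta = 0.0f;
    // First pass: reduce dbeta = sum(dy) and dgamma = sum(dy * xhat)
    for(std::size_t i = 0; i < m; i++) {
        float xhat = (x[i] - mean) * inv_std;
        dbeta  += dy[i];
        dgamma += dy[i] * xhat;
    }
    // Second pass: dx using the training-mode formula
    dx.resize(m);
    for(std::size_t i = 0; i < m; i++) {
        float xhat = (x[i] - mean) * inv_std;
        dx[i] = gamma * inv_std * (dy[i] - dbeta / m - xhat * dgamma / m);
    }
}
```

Note that a uniform top gradient produces dx = 0, since normalization is invariant to a constant shift of the batch.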

virtual void dlprim::core::BatchNormFwdBwd::enqueue_backward_direct ( bool  training_mode,
Tensor &  x,
Tensor &  dy,
Tensor &  mean,
Tensor &  var,
Tensor &  dx,
float  dx_factor,
float  eps,
Tensor &  ws,
ExecutionContext const &  e
)
pure virtual

Perform backpropagation calculations for BN without the affine Gamma/Beta addition.

  • training_mode - assumes that mean/var were calculated on batches of x and kept from the forward stage; otherwise mean/var are considered constant values
  • dy - top gradient for backpropagation
  • dx - calculate the gradient with respect to x
  • ws - workspace

virtual void dlprim::core::BatchNormFwdBwd::enqueue_calculate_batch_stats ( Tensor &  x,
Tensor &  mean,
Tensor &  var,
Tensor &  ws,
ExecutionContext const &  e
)
pure virtual

Compute batch mean and variance for input x.

Note: mean and var should have Shape(features), where features is x.shape()[1].
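As a host-side illustration of this contract, here is a plain C++ sketch of per-channel batch statistics for an NCHW tensor (reference math only, not the actual device kernel; the name batch_stats is hypothetical). mean and var each get one entry per channel, i.e. Shape(features):

```cpp
#include <cstddef>
#include <vector>

// Per-channel mean and (biased) variance over an NCHW tensor
// flattened row-major into x. mean/var end up with C entries.
void batch_stats(std::vector<float> const &x,
                 std::size_t N, std::size_t C, std::size_t HW,
                 std::vector<float> &mean, std::vector<float> &var)
{
    mean.assign(C, 0.0f);
    var.assign(C, 0.0f);
    std::size_t m = N * HW; // samples reduced per channel
    for(std::size_t n = 0; n < N; n++)
        for(std::size_t c = 0; c < C; c++)
            for(std::size_t i = 0; i < HW; i++)
                mean[c] += x[(n * C + c) * HW + i];
    for(std::size_t c = 0; c < C; c++)
        mean[c] /= m;
    for(std::size_t n = 0; n < N; n++)
        for(std::size_t c = 0; c < C; c++)
            for(std::size_t i = 0; i < HW; i++) {
                float d = x[(n * C + c) * HW + i] - mean[c];
                var[c] += d * d;
            }
    for(std::size_t c = 0; c < C; c++)
        var[c] /= m; // population variance, as used in the forward pass
}
```

The m/(m-1) rescaling for the running (unbiased) variance estimate is applied later, via batch_var_factor in enqueue_update_running_stats.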

virtual void dlprim::core::BatchNormFwdBwd::enqueue_forward_affine ( Tensor &  x,
Tensor &  y,
Tensor &  gamma,
Tensor &  beta,
Tensor &  mean,
Tensor &  var,
float  eps,
Tensor &  ws,
ExecutionContext const &  e
)
pure virtual

Perform forward computation as y = (x-mean) / sqrt(var + eps) * gamma + beta.

Notes:

  • mean/var can be taken from the batch or from the global running stats, as requested by the user
  • mean/var and gamma/beta are folded into a single y = a*x + b, so the computation is done in a single step
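The fold in the second note can be sketched as follows (illustrative helper, not part of the API): per channel, a = gamma / sqrt(var + eps) and b = beta - mean * a, after which the whole normalization is a single fused multiply-add per element.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Fold BN parameters into per-channel (a, b) so that
// y = (x - mean)/sqrt(var + eps) * gamma + beta  ==  y = a*x + b.
void fold_bn_affine(std::vector<float> const &mean, std::vector<float> const &var,
                    std::vector<float> const &gamma, std::vector<float> const &beta,
                    float eps,
                    std::vector<float> &a, std::vector<float> &b)
{
    std::size_t C = mean.size();
    a.resize(C);
    b.resize(C);
    for(std::size_t c = 0; c < C; c++) {
        a[c] = gamma[c] / std::sqrt(var[c] + eps);
        b[c] = beta[c] - mean[c] * a[c];
    }
}
```

enqueue_forward_direct is the special case gamma = 1, beta = 0 of the same fold.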
virtual void dlprim::core::BatchNormFwdBwd::enqueue_forward_direct ( Tensor &  x,
Tensor &  y,
Tensor &  mean,
Tensor &  var,
float  eps,
Tensor &  ws,
ExecutionContext const &  e
)
pure virtual

Perform forward computation as y = (x-mean) / sqrt(var + eps).

Note: mean/var can be taken from the batch or from the global running stats, as requested by the user.

virtual void dlprim::core::BatchNormFwdBwd::enqueue_update_running_stats ( float  batch_mean_factor,
float  running_mean_factor,
Tensor &  batch_mean,
Tensor &  running_mean,
float  batch_var_factor,
float  running_var_factor,
Tensor &  batch_var,
Tensor &  running_var,
Tensor &  ws,
ExecutionContext const &  e
)
pure virtual

Update running sums as:

running_mean = running_mean_factor * running_mean + batch_mean_factor * batch_mean;
running_var = running_var_factor * running_var + batch_var_factor * batch_var;
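A minimal host-side sketch of this update (illustrative only; the real method runs on the device through the ExecutionContext). With momentum 0.1, the pseudocode above uses batch_mean_factor = 0.1 and running_mean_factor = 0.9, with the variance factor additionally scaled by m/(m-1) to form an unbiased estimate:

```cpp
#include <cstddef>
#include <vector>

// Blend a batch statistic into a running statistic:
// running = running_factor * running + batch_factor * batch
void update_running(float batch_factor, float running_factor,
                    std::vector<float> const &batch, std::vector<float> &running)
{
    for(std::size_t i = 0; i < running.size(); i++)
        running[i] = running_factor * running[i] + batch_factor * batch[i];
}
```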

The documentation for this class was generated from the following file: include/dlprim/core/bn.hpp