2 #include <dlprim/operator.hpp> 4 namespace json {
class value; }
/// Statistics-mode flag of the layer configuration; defaults to false.
/// NOTE(review): presumably, when true, the layer uses stored (global/running)
/// mean and variance instead of per-batch statistics — confirm against the
/// implementation.
bool use_global_stats = false;
28 virtual void initialize_params(std::vector<Tensor> ¶meters,
ExecutionContext const &e);
32 virtual void setup(std::vector<TensorSpecs>
const &in,
33 std::vector<TensorSpecs> &out,
34 std::vector<TensorSpecs> ¶meters,
37 virtual void reshape(std::vector<Shape>
const &in,
38 std::vector<Shape> &out,
41 virtual void forward(std::vector<Tensor> &input,
42 std::vector<Tensor> &output,
43 std::vector<Tensor> ¶meters,
47 virtual void backward(std::vector<TensorAndGradient> &input,
48 std::vector<TensorAndGradient> &output,
49 std::vector<TensorAndGradient> ¶meters,
55 void backward_cpu(std::vector<TensorAndGradient> &input,
56 std::vector<TensorAndGradient> &output,
57 std::vector<TensorAndGradient> ¶meters,
59 void forward_cpu(std::vector<Tensor> &input,
60 std::vector<Tensor> &output,
61 std::vector<Tensor> ¶meters,
63 void cpu_backward_data(
Tensor &x,
Tensor &dx,
Tensor &dy,
float *mean,
float *var,
float *dy_sum,
float *dyx_sum,
float *gamma_in);
70 static int plane_size(
Shape const &s);
72 Tensor current_mean_,current_var_;
73 Tensor combined_scale_,combined_bias_;
78 std::unique_ptr<core::BatchNormFwdBwd> bn_gpu_;
Performs batch normalization computations over channel #1 (when #0 is batch)
Definition: bn.hpp:65
Tensor shape.
Definition: shape.hpp:18
CalculationsMode
Operation mode of layers - inference or training.
Definition: definitions.hpp:283
virtual CalculationsMode mode()
get current mode
Definition: batch_normalization.hpp:30
Base class for backward/forward propagation calculations for the internal network.
Definition: operator.hpp:15
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
DataType
type definition
Definition: definitions.hpp:70
This class is central representation of json objects.
Definition: json.hpp:652
Definition: batch_normalization.hpp:7
virtual CalculationsMode mode()
get current mode
Definition: operator.hpp:62
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
Definition: batch_normalization.hpp:17
virtual char const * operator_type() const
name of the operator type
Definition: batch_normalization.hpp:23
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121