2 #include <dlprim/operator.hpp> 5 namespace json {
class value; }
7 class Pooling2DBackwardBase;
8 class Pooling2DForward;
22 int kernel[2] = {1,1};
26 bool count_include_pad =
false;
39 virtual void setup(std::vector<TensorSpecs>
const &in,
40 std::vector<TensorSpecs> &out,
41 std::vector<TensorSpecs> &parameters,
44 virtual void reshape(std::vector<Shape>
const &in,
45 std::vector<Shape> &out,
48 virtual void forward(std::vector<Tensor> &input,
49 std::vector<Tensor> &output,
50 std::vector<Tensor> &parameters,
54 virtual void backward(std::vector<TensorAndGradient> &input,
55 std::vector<TensorAndGradient> &output,
56 std::vector<TensorAndGradient> &parameters,
62 int calc_output_size(
int in_size,
int dim);
66 template<
typename Dtype,
typename ReduceOpts>
67 void forward_cpu(
Tensor &in,
Tensor &output,ReduceOpts rop);
70 template<
typename ReduceOpts>
71 void backward_cpu_ave(
Tensor &dx,
Tensor &dy,
float factor,ReduceOpts rop);
78 struct AveReduceValid;
91 std::unique_ptr<core::Pooling2DForward> fwd_;
92 std::unique_ptr<core::Pooling2DBackwardBase> bwd_;
99 static_cast<PoolingBase &
>(cfg) = PoolingBase::from_json(v);
111 return "GlobalPooling";
113 virtual void setup(std::vector<TensorSpecs>
const &in,
114 std::vector<TensorSpecs> &out,
115 std::vector<TensorSpecs> &parameters,
118 virtual void reshape(std::vector<Shape>
const &in,
119 std::vector<Shape> &out,
122 virtual void forward(std::vector<Tensor> &input,
123 std::vector<Tensor> &output,
124 std::vector<Tensor> &parameters,
128 virtual void backward(std::vector<TensorAndGradient> &input,
129 std::vector<TensorAndGradient> &output,
130 std::vector<TensorAndGradient> &parameters,
139 size_t setup_kernel(
Shape const &sp);
143 std::unique_ptr<core::Pooling2DForward> fwd_;
144 std::unique_ptr<core::Pooling2DBackwardBase> bwd_;
Definition: pooling.hpp:21
Tensor shape.
Definition: shape.hpp:18
Definition: pooling.hpp:104
Base class for backward/forward propagation calculations for internal network.
Definition: operator.hpp:15
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
DataType
type definition
Definition: definitions.hpp:70
virtual char const * operator_type() const
name of the operator type
Definition: pooling.hpp:109
Definition: pooling.hpp:30
Definition: pooling.hpp:95
This class is the central representation of JSON objects.
Definition: json.hpp:652
virtual char const * operator_type() const
name of the operator type
Definition: pooling.hpp:35
Main namespace.
Definition: context.hpp:9
Definition: pooling.hpp:10
Central Data Container - Tensor.
Definition: tensor.hpp:99
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121