2 #include <dlprim/operator.hpp> 4 namespace json {
class value; }
15 virtual void setup(std::vector<TensorSpecs>
const &in,
16 std::vector<TensorSpecs> &out,
17 std::vector<TensorSpecs> ¶meters,
20 virtual void reshape(std::vector<Shape>
const &in,
21 std::vector<Shape> &out,
25 virtual void forward(std::vector<Tensor> &,
26 std::vector<Tensor> &,
27 std::vector<Tensor> &,
34 virtual void backward(std::vector<TensorAndGradient> &,
35 std::vector<TensorAndGradient> &t,
36 std::vector<TensorAndGradient> &,
68 std::vector<int> dims;
94 std::vector<int> dims;
Definition: reshape.hpp:67
Shape squeeze(std::vector< int > dims) const
Remove dimensions of size 1 located at the indices given in dims; for example, Shape(4,5,1,1).squeeze({2,3}) = Shape(4,5)
Definition: reshape.hpp:47
Tensor shape.
Definition: shape.hpp:18
virtual void backward(std::vector< TensorAndGradient > &, std::vector< TensorAndGradient > &t, std::vector< TensorAndGradient > &, Tensor &, ExecutionContext const &)
Enqueue backward propagation computations.
Definition: reshape.hpp:34
Shape reshape(std::vector< int > const &dims) const
Reshape to dims: if dims[i] == 0 the dimension is preserved; if dims[i] == -1 it is calculated from the res...
Base class for backward/forward propagation calculations for the internal network.
Definition: operator.hpp:15
Definition: reshape.hpp:73
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
virtual void forward(std::vector< Tensor > &, std::vector< Tensor > &, std::vector< Tensor > &, Tensor &, ExecutionContext const &)
Enqueue forward propagation computations.
Definition: reshape.hpp:25
virtual char const * operator_type() const
name of the operator type
Definition: reshape.hpp:57
This class is central representation of json objects.
Definition: json.hpp:652
size_t size_no_batch() const
Total number of elements in the shape, excluding the first (batch) dimension.
Definition: shape.hpp:59
Definition: reshape.hpp:6
Definition: reshape.hpp:51
Definition: reshape.hpp:98
virtual char const * operator_type() const
name of the operator type
Definition: reshape.hpp:79
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
Definition: reshape.hpp:93
virtual bool alias_generator()
Returns true if the operator is an alias generator - it only changes the shape of the tensor but not its ...
Definition: reshape.hpp:10
virtual char const * operator_type() const
name of the operator type
Definition: reshape.hpp:104
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121