DLPrimitives
reshape.hpp
1 #pragma once
2 #include <dlprim/operator.hpp>
3 namespace dlprim {
4  namespace json { class value; }
5 
6  class ReshapeBase : public Operator {
7  public:
8  ReshapeBase(Context const &);
9  virtual ~ReshapeBase();
10  virtual bool alias_generator()
11  {
12  return true;
13  }
14 
15  virtual void setup(std::vector<TensorSpecs> const &in,
16  std::vector<TensorSpecs> &out,
17  std::vector<TensorSpecs> &parameters,
18  size_t &workspace);
19 
20  virtual void reshape(std::vector<Shape> const &in,
21  std::vector<Shape> &out,
22  size_t &ws);
23 
24 
25  virtual void forward(std::vector<Tensor> &,
26  std::vector<Tensor> &,
27  std::vector<Tensor> &,
28  Tensor &,
29  ExecutionContext const &)
30  {
31  return;
32  }
33 
34  virtual void backward(std::vector<TensorAndGradient> &,
35  std::vector<TensorAndGradient> &t,
36  std::vector<TensorAndGradient> &,
37  Tensor &,
38  ExecutionContext const &)
39  {
40  return;
41  }
42 
43  virtual Shape new_shape(Shape const &in) = 0;
44 
45  };
46 
47  struct FlattenConfig {
48  static FlattenConfig from_json(json::value const &) { return FlattenConfig(); }
49  };
50 
51  class Flatten : public ReshapeBase {
52  public:
53 
54  Flatten(Context &ctx,FlattenConfig const &/*config*/ = FlattenConfig()) : ReshapeBase(ctx) {}
55  virtual ~Flatten() {}
56 
57  virtual char const *operator_type() const
58  {
59  return "Flatten";
60  }
61  virtual Shape new_shape(Shape const &in)
62  {
63  Shape r(in[0],in.size_no_batch());
64  return r;
65  }
66  };
67  struct SqueezeConfig {
68  std::vector<int> dims;
69  bool all=false;
70  static SqueezeConfig from_json(json::value const &);
71  };
72 
73  class Squeeze : public ReshapeBase {
74  public:
75 
76  Squeeze(Context &ctx,SqueezeConfig const &config = SqueezeConfig()) : ReshapeBase(ctx), cfg_(config) {}
77  virtual ~Squeeze() {}
78 
79  virtual char const *operator_type() const
80  {
81  return "Squeeze";
82  }
83  virtual Shape new_shape(Shape const &in)
84  {
85  if(cfg_.all)
86  return in.squeeze();
87  else
88  return in.squeeze(cfg_.dims);
89  }
90  SqueezeConfig cfg_;
91  };
92 
93  struct ReshapeConfig {
94  std::vector<int> dims;
95  static ReshapeConfig from_json(json::value const &);
96  };
97 
98  class Reshape : public ReshapeBase {
99  public:
100 
101  Reshape(Context &ctx,ReshapeConfig const &config = ReshapeConfig()) : ReshapeBase(ctx), cfg_(config) {}
102  virtual ~Reshape() {}
103 
104  virtual char const *operator_type() const
105  {
106  return "Reshape";
107  }
108  virtual Shape new_shape(Shape const &in)
109  {
110  return in.reshape(cfg_.dims);
111  }
112  ReshapeConfig cfg_;
113  };
114 
115 } // namespace
116 
Definition: reshape.hpp:67
Shape squeeze(std::vector< int > dims) const
Remove the size-1 dimensions listed in dims; for example, Shape(4,5,1,1).squeeze({2,3}) == Shape(4,5)
Definition: reshape.hpp:47
Tensor shape.
Definition: shape.hpp:18
virtual void backward(std::vector< TensorAndGradient > &, std::vector< TensorAndGradient > &t, std::vector< TensorAndGradient > &, Tensor &, ExecutionContext const &)
Enqueue backward propagation computations.
Definition: reshape.hpp:34
Shape reshape(std::vector< int > const &dims) const
Reshape to dims: if dims[i] == 0 the dimension is preserved; if dims[i] == -1 it is calculated from the res...
Base class for backward/forward propagation calculations for the internal network.
Definition: operator.hpp:15
Definition: reshape.hpp:73
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
virtual void forward(std::vector< Tensor > &, std::vector< Tensor > &, std::vector< Tensor > &, Tensor &, ExecutionContext const &)
Enqueue forward propagation computations.
Definition: reshape.hpp:25
virtual char const * operator_type() const
name of the operator type
Definition: reshape.hpp:57
This class is the central representation of JSON objects.
Definition: json.hpp:652
size_t size_no_batch() const
Total number of elements in the shape, excluding the first (batch) dimension.
Definition: shape.hpp:59
Definition: reshape.hpp:6
Definition: reshape.hpp:51
Definition: reshape.hpp:98
virtual char const * operator_type() const
name of the operator type
Definition: reshape.hpp:79
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
Definition: reshape.hpp:93
virtual bool alias_generator()
returns true if the operator is an alias generator - it only changes the shape of the tensor but not its ...
Definition: reshape.hpp:10
virtual char const * operator_type() const
name of the operator type
Definition: reshape.hpp:104
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121