DLPrimitives
shape.hpp
1 #pragma once
2 #include <array>
3 #include <vector>
4 #include <dlprim/definitions.hpp>
5 #include <ostream>
6 #include <sstream>
7 
8 namespace dlprim {
9 
10  class Shape;
11 
12  std::ostream &operator<<(std::ostream &o,Shape const &s);
13 
14 
18  class Shape {
19  public:
20  Shape() : shape_{},size_(0) {}
21  Shape(size_t b): shape_({b}),size_(1) {}
22  Shape(size_t b,size_t c): shape_({b,c}),size_(2) {}
23  Shape(size_t b,size_t c,size_t h): shape_({b,c,h}),size_(3) {}
24  Shape(size_t b,size_t c,size_t h,size_t w): shape_({b,c,h,w}),size_(4) {}
25  Shape(size_t b,size_t c,size_t d,size_t h,size_t w): shape_({b,c,d,h,w}),size_(5) {}
26 
30  template<typename It>
31  static Shape from_range(It begin, It end)
32  {
33  Shape s;
34  while(begin!=end) {
35  if(s.size_ >= max_tensor_dim)
36  throw ValidationError("Unsupported tensor size");
37  s.shape_[s.size_++] = *begin++;
38  }
39  return s;
40  }
41 
42  bool operator==(Shape const &other) const
43  {
44  if(size_ != other.size_)
45  return false;
46  for(int i=0;i<size_;i++)
47  if(shape_[i] != other.shape_[i])
48  return false;
49  return true;
50  }
51  bool operator!=(Shape const &other) const
52  {
53  return !(*this == other);
54  }
55 
59  size_t size_no_batch() const
60  {
61  if(size_ <= 0)
62  return 0;
63  size_t r=1;
64  for(int i=1;i<size_;i++) {
65  r*=shape_[i];
66  }
67  return r;
68  }
72  size_t total_size() const
73  {
74  if(size_ == 0)
75  return 0;
76  size_t r=1;
77  for(int i=0;i<size_;i++) {
78  r*=size_t(shape_[i]);
79  }
80  return r;
81  }
85  int size() const
86  {
87  return size_;
88  }
89  size_t &operator[](int i)
90  {
91  return shape_[i];
92  }
96  size_t operator[](int i) const
97  {
98  return shape_[i];
99  }
107  {
108  size_t d0 = 1,d1 = 1,d2=1;
109  for(int i=0;i<size_;i++) {
110  if(i < axis)
111  d0*=shape_[i];
112  else if(i == axis)
113  d1*=shape_[i];
114  else
115  d2*=shape_[i];
116  }
117  return Shape(d0,d1,d2);
118  }
119 
123  Shape unsqueeze(int axis) const;
124 
133  Shape squeeze(std::vector<int> dims) const;
134 
138  Shape squeeze() const;
139 
143  Shape broadcast_strides(Shape const &target) const;
147  Shape reshape(std::vector<int> const &dims) const;
148 
149  size_t const *begin() const
150  {
151  return &shape_[0];
152  }
153  size_t const *end() const
154  {
155  return begin() + size_;
156  }
157 
158  private:
159  std::array<size_t,max_tensor_dim> shape_;
160  int size_;
161  };
162 
164  Shape broadcast(Shape const &ain,Shape const &bin);
165 
177  void shrink_broadcast_ranges(std::vector<Shape> &shapes);
178 
179 
180 };
Shape split_and_merge_over_axis(int axis) const
Split the shape according to axis — dimensions before the axis and after it — for example:
Definition: shape.hpp:106
int size() const
Number of dimensions of the shape.
Definition: shape.hpp:85
Shape broadcast_strides(Shape const &target) const
Compute strides needed for broadcasting this shape to target shape.
size_t total_size() const
Total number of elements - product of all items.
Definition: shape.hpp:72
std::ostream & operator<<(std::ostream &out, string_key const &s)
Write the string to the stream.
Definition: json.hpp:363
Shape broadcast(Shape const &ain, Shape const &bin)
Calculate the NumPy-style broadcast shape.
Tensor shape.
Definition: shape.hpp:18
Shape reshape(std::vector< int > const &dims) const
Reshape to dims: if dim[i] == 0 the dimension is preserved; if dim[i] == -1 it is calculated from the res...
Shape unsqueeze(int axis) const
Add a dimension of size 1 at the axis location, for example Shape(2,3).unsqueeze(0) == Shape(1,2,3)
void shrink_broadcast_ranges(std::vector< Shape > &shapes)
Broadcast shapes numpy style and remove planes that can be merged.
static constexpr int max_tensor_dim
Maximal number of dimensions in tensor.
Definition: definitions.hpp:254
size_t size_no_batch() const
Total number of elements in the shape, excluding the first (batch) dimension.
Definition: shape.hpp:59
Shape squeeze() const
Remove dimensions of size 1 that appear at dims, for example Shape(4,5,1,1).squeeze() = Shape(4...
size_t operator[](int i) const
specific dimension
Definition: shape.hpp:96
Thrown in case of invalid parameters.
Definition: definitions.hpp:46
static Shape from_range(It begin, It end)
Initialize from pair of iterators.
Definition: shape.hpp:31
Main namespace.
Definition: context.hpp:9