4 #include <dlprim/definitions.hpp> 12 std::ostream &
operator<<(std::ostream &o,Shape
const &s);
20 Shape() : shape_{},size_(0) {}
21 Shape(
size_t b): shape_({b}),size_(1) {}
22 Shape(
size_t b,
size_t c): shape_({b,c}),size_(2) {}
23 Shape(
size_t b,
size_t c,
size_t h): shape_({b,c,h}),size_(3) {}
24 Shape(
size_t b,
size_t c,
size_t h,
size_t w): shape_({b,c,h,w}),size_(4) {}
25 Shape(
size_t b,
size_t c,
size_t d,
size_t h,
size_t w): shape_({b,c,d,h,w}),size_(5) {}
37 s.shape_[s.size_++] = *begin++;
42 bool operator==(
Shape const &other)
const 44 if(size_ != other.size_)
46 for(
int i=0;i<size_;i++)
47 if(shape_[i] != other.shape_[i])
51 bool operator!=(
Shape const &other)
const 53 return !(*
this == other);
64 for(
int i=1;i<size_;i++) {
77 for(
int i=0;i<size_;i++) {
89 size_t &operator[](
int i)
108 size_t d0 = 1,d1 = 1,d2=1;
109 for(
int i=0;i<size_;i++) {
117 return Shape(d0,d1,d2);
149 size_t const *begin()
const 153 size_t const *end()
const 155 return begin() + size_;
159 std::array<size_t,max_tensor_dim> shape_;
Shape split_and_merge_over_axis(int axis) const
Split the shape according to axis - the part before the axis and the part after it, for example:
Definition: shape.hpp:106
int size() const
Number of dimensions of the shape.
Definition: shape.hpp:85
Shape broadcast_strides(Shape const &target) const
Compute strides needed for broadcasting this shape to target shape.
size_t total_size() const
Total number of elements - the product of all dimensions.
Definition: shape.hpp:72
std::ostream & operator<<(std::ostream &out, string_key const &s)
Write the string to the stream.
Definition: json.hpp:363
Shape broadcast(Shape const &ain, Shape const &bin)
Calculate the NumPy-style broadcast shape.
Tensor shape.
Definition: shape.hpp:18
Shape reshape(std::vector< int > const &dims) const
Reshape to dims: if dim[i] == 0 the dimension is preserved; if dim[i] == -1 it is calculated from the res...
Shape unsqueeze(int axis) const
Add a dimension of size 1 at the axis location; for example, Shape(2,3).unsqueeze(0) == Shape(1,2,3).
void shrink_broadcast_ranges(std::vector< Shape > &shapes)
Broadcast shapes numpy style and remove planes that can be merged.
static constexpr int max_tensor_dim
Maximum number of dimensions in a tensor.
Definition: definitions.hpp:254
size_t size_no_batch() const
Total number of elements in the shape, excluding the first (batch) dimension.
Definition: shape.hpp:59
Shape squeeze() const
Remove dimensions equal to 1 that appear at dims; for example, Shape(4,5,1,1).squeeze() = Shape(4...
size_t operator[](int i) const
Access a specific dimension.
Definition: shape.hpp:96
Thrown in case of invalid parameters.
Definition: definitions.hpp:46
static Shape from_range(It begin, It end)
Initialize from a pair of iterators.
Definition: shape.hpp:31
Main namespace.
Definition: context.hpp:9