DLPrimitives
#include <include/dlprim/shape.hpp>
Public Member Functions | |
Shape (size_t b) | |
Shape (size_t b, size_t c) | |
Shape (size_t b, size_t c, size_t h) | |
Shape (size_t b, size_t c, size_t h, size_t w) | |
Shape (size_t b, size_t c, size_t d, size_t h, size_t w) | |
bool | operator== (Shape const &other) const |
bool | operator!= (Shape const &other) const |
size_t | size_no_batch () const |
Total number of elements in the shape, excluding the first (batch) dimension. | |
size_t | total_size () const |
Total number of elements: the product of all dimensions. | |
int | size () const |
Number of dimensions of the shape. | |
size_t & | operator[] (int i) |
size_t | operator[] (int i) const |
Access a specific dimension by index. | |
Shape | split_and_merge_over_axis (int axis) const |
Split the shape according to axis, into the dimensions before the axis and those after it. | |
Shape | unsqueeze (int axis) const |
Add a dimension of size 1 at the axis location, for example Shape(2,3).unsqueeze(0) == Shape(1,2,3) | |
Shape | squeeze (std::vector< int > dims) const |
Remove the dimensions of size 1 listed in dims, for example Shape(4,5,1,1).squeeze({2,3}) == Shape(4,5) | |
Shape | squeeze () const |
Remove all dimensions of size 1, for example Shape(4,5,1,1).squeeze() == Shape(4,5) | |
Shape | broadcast_strides (Shape const &target) const |
Compute strides needed for broadcasting this shape to target shape. | |
Shape | reshape (std::vector< int > const &dims) const |
Reshape to dims: if dims[i] == 0 the dimension is preserved; if dims[i] == -1 it is calculated from the remaining dimensions. | |
size_t const * | begin () const |
size_t const * | end () const |
Static Public Member Functions | |
template<typename It > | |
static Shape | from_range (It begin, It end) |
Initialize from a pair of iterators. | |
Tensor shape.
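The reshape() rules listed in the member table (0 preserves the source dimension, -1 is calculated from the rest) can be illustrated with a small stand-alone sketch. It works on plain std::vector values rather than dlprim::Shape, the name reshape_sketch is hypothetical, and the error handling (throwing standard exceptions, allowing only one -1) is an assumption, not the library's confirmed behaviour.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the documented reshape rules; not DLPrimitives code.
std::vector<std::size_t> reshape_sketch(std::vector<std::size_t> const &shape,
                                        std::vector<int> const &dims)
{
    std::size_t total = 1;            // element count of the source shape
    for(std::size_t v : shape)
        total *= v;

    std::vector<std::size_t> result(dims.size());
    std::size_t known = 1;            // product of all resolved dimensions
    int inferred = -1;                // index of the single -1 entry, if any
    for(std::size_t i = 0; i < dims.size(); i++) {
        if(dims[i] == -1) {
            if(inferred != -1)        // assumption: at most one -1 is allowed
                throw std::invalid_argument("only one dimension can be -1");
            inferred = static_cast<int>(i);
            continue;
        }
        if(dims[i] == 0) {            // 0: preserve the source dimension
            if(i >= shape.size())
                throw std::invalid_argument("no source dimension to preserve");
            result[i] = shape[i];
        }
        else if(dims[i] > 0) {
            result[i] = static_cast<std::size_t>(dims[i]);
        }
        else {
            throw std::invalid_argument("invalid dimension");
        }
        known *= result[i];
    }
    if(inferred != -1) {              // -1: calculate from the remaining dimensions
        if(known == 0 || total % known != 0)
            throw std::invalid_argument("cannot infer dimension");
        result[inferred] = total / known;
    }
    else if(known != total) {
        throw std::invalid_argument("element count mismatch");
    }
    return result;
}
```

Under these assumptions, reshape_sketch({2,3,4}, {0,-1}) keeps the batch dimension and folds the rest, giving {2,12}.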
Shape dlprim::Shape::split_and_merge_over_axis | ( | int | axis | ) | const |
inline |
Split the shape according to axis, before the axis and after it, for example:
References dlprim::broadcast(), broadcast_strides(), reshape(), dlprim::shrink_broadcast_ranges(), squeeze(), and unsqueeze().
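For broadcast_strides(), this page only says that it computes the strides needed to broadcast this shape to a target shape. The sketch below shows one common way to do that, assuming the usual NumPy-style convention: shapes are aligned from the last dimension, a broadcast dimension gets stride 0, and every other dimension gets its dense row-major stride in the source. The function name, the use of std::vector in place of dlprim::Shape, and the thrown exceptions are all assumptions made for illustration.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical illustration of broadcast strides; not DLPrimitives code.
std::vector<std::size_t> broadcast_strides_sketch(
    std::vector<std::size_t> const &source,
    std::vector<std::size_t> const &target)
{
    if(source.size() > target.size())
        throw std::invalid_argument("source has more dimensions than target");

    std::vector<std::size_t> strides(target.size(), 0);
    std::size_t stride = 1;   // running dense stride inside the source
    int pos = static_cast<int>(source.size()) - 1;
    for(int i = static_cast<int>(target.size()) - 1; i >= 0; i--, pos--) {
        if(pos < 0)
            continue;         // leading dimension missing in source: stride 0
        if(source[pos] == target[i]) {
            strides[i] = stride;
            stride *= source[pos];
        }
        else if(source[pos] != 1) {
            throw std::invalid_argument("shapes are not broadcastable");
        }
        // source dimension of 1 against a larger target: stride stays 0
    }
    return strides;
}
```

With this convention, a source of {3,1} broadcast to {2,3,4} yields strides {0,1,0}: only the middle dimension advances through the source data.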
Shape dlprim::Shape::squeeze | ( | std::vector< int > | dims | ) | const |
Remove the dimensions of size 1 listed in dims, for example Shape(4,5,1,1).squeeze({2,3}) == Shape(4,5)
If a value in dims is negative, it is counted from the end.
Note: for all i in [0, dims.size()), it is required that shape[dims[i]] == 1.
Referenced by dlprim::Squeeze::operator_type().
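The squeeze(dims) rules above (negative values counted from the end, every listed dimension required to be 1) can be sketched without the library as follows. The function squeeze_sketch is hypothetical and works on std::vector rather than dlprim::Shape; reporting violations by throwing is an assumption, since this page does not say how the library reports them.

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the documented squeeze rules; not DLPrimitives code.
std::vector<std::size_t> squeeze_sketch(std::vector<std::size_t> const &shape,
                                        std::vector<int> dims)
{
    int const n = static_cast<int>(shape.size());
    for(int &d : dims) {
        if(d < 0)
            d += n;                   // negative values count from the end
        if(d < 0 || d >= n)
            throw std::out_of_range("axis out of range");
        if(shape[d] != 1)             // documented requirement: shape[dims[i]] == 1
            throw std::invalid_argument("squeezed dimension must be 1");
    }
    std::vector<std::size_t> result;
    for(int i = 0; i < n; i++) {
        // keep every dimension that is not listed in dims
        if(std::find(dims.begin(), dims.end(), i) == dims.end())
            result.push_back(shape[i]);
    }
    return result;
}
```

For example, squeeze_sketch({4,5,1,1}, {2,3}) and squeeze_sketch({4,5,1,1}, {-1,-2}) both give {4,5}.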