DLPrimitives
dlprim::Shape Class Reference

Tensor shape.

#include <include/dlprim/shape.hpp>

Public Member Functions

 Shape (size_t b)
 
 Shape (size_t b, size_t c)
 
 Shape (size_t b, size_t c, size_t h)
 
 Shape (size_t b, size_t c, size_t h, size_t w)
 
 Shape (size_t b, size_t c, size_t d, size_t h, size_t w)
 
bool operator== (Shape const &other) const
 
bool operator!= (Shape const &other) const
 
size_t size_no_batch () const
 Total number of elements excluding the first (batch) dimension.
 
size_t total_size () const
 Total number of elements: the product of all dimensions.
 
int size () const
 Number of dimensions in the shape.
 
size_t & operator[] (int i)
 
size_t operator[] (int i) const
 Access a specific dimension.
 
Shape split_and_merge_over_axis (int axis) const
 Split the shape according to axis: the dimensions before the axis are merged, as are the dimensions after it.
 
Shape unsqueeze (int axis) const
 Add a dimension of size 1 at position axis, for example Shape(2,3).unsqueeze(0) == Shape(1,2,3)
 
Shape squeeze (std::vector< int > dims) const
 Remove the dimensions of size 1 at the positions listed in dims, for example Shape(4,5,1,1).squeeze({2,3}) == Shape(4,5)
 
Shape squeeze () const
 Remove all dimensions of size 1, for example Shape(4,5,1,1).squeeze() == Shape(4,5)
 
Shape broadcast_strides (Shape const &target) const
 Compute strides needed for broadcasting this shape to target shape.
 
Shape reshape (std::vector< int > const &dims) const
 Reshape to dims: if dims[i] == 0 the dimension is preserved; if dims[i] == -1 it is calculated from the remaining dimensions.
 
size_t const * begin () const
 
size_t const * end () const
 

Static Public Member Functions

template<typename It >
static Shape from_range (It begin, It end)
 Initialize from pair of iterators.
 

Detailed Description

Tensor shape.
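
The reshape() rule listed above (0 preserves a dimension, -1 is computed from the rest) can be illustrated with a small standalone sketch. It works on plain std::vector values and does not use dlprim itself; reshape_sketch is a hypothetical name, and two details are assumptions rather than documented behavior: that a 0 at position i copies the source dimension at the same position i, and that at most one entry may be -1.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper, NOT part of dlprim: applies the documented rule
// (0 = keep the source dimension, -1 = infer from the remaining elements)
// to a plain vector of extents.
std::vector<std::size_t> reshape_sketch(std::vector<std::size_t> const &shape,
                                        std::vector<int> const &dims)
{
    std::size_t total = 1;
    for (std::size_t d : shape)
        total *= d;

    std::vector<std::size_t> result(dims.size(), 0);
    std::size_t known = 1;  // product of every dimension except the inferred one
    int inferred = -1;      // position of the -1 entry, if any
    for (std::size_t i = 0; i < dims.size(); i++) {
        if (dims[i] == -1) {
            // Assumption: only one dimension can be inferred.
            if (inferred != -1)
                throw std::invalid_argument("at most one dimension can be -1");
            inferred = static_cast<int>(i);
            continue;
        }
        if (dims[i] == 0) {
            // Assumption: 0 copies the source dimension at the same index.
            if (i >= shape.size())
                throw std::invalid_argument("0 refers to a missing source dimension");
            result[i] = shape[i];
        } else if (dims[i] > 0) {
            result[i] = static_cast<std::size_t>(dims[i]);
        } else {
            throw std::invalid_argument("invalid dimension value");
        }
        known *= result[i];
    }
    if (inferred != -1) {
        if (known == 0 || total % known != 0)
            throw std::invalid_argument("cannot infer the -1 dimension");
        result[static_cast<std::size_t>(inferred)] = total / known;
        assert(known * result[static_cast<std::size_t>(inferred)] == total);
    } else if (known != total) {
        throw std::invalid_argument("element count mismatch");
    }
    return result;
}
```

Under these assumptions, reshape_sketch({2,3,4,5}, {0,-1}) keeps the batch dimension and folds the rest into one, giving {2,60}.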

Member Function Documentation

Shape dlprim::Shape::split_and_merge_over_axis ( int  axis) const
inline

Split the shape according to axis into three parts: the merged dimensions before the axis, the axis itself, and the merged dimensions after it. For example:

  • [2,3,4,5] split axis == 2 -> [6,4,5]
  • [2,3,4,5] split axis == 0 -> [1,2,60]
  • [2,3] split axis == 2 -> [6,1,1]

References dlprim::broadcast(), broadcast_strides(), reshape(), dlprim::shrink_broadcast_ranges(), squeeze(), and unsqueeze().
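
The three examples above are consistent with a result of the form [product of dimensions before axis, dimension at axis, product of dimensions after axis], where a missing part counts as 1. The following standalone sketch reproduces them on plain vectors; split_and_merge_sketch is a hypothetical name and this is an illustration of the documented behavior, not the dlprim implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, NOT the dlprim implementation: collapses a shape
// into {before, at, after} around the given axis.
std::vector<std::size_t> split_and_merge_sketch(std::vector<std::size_t> const &shape,
                                                int axis)
{
    assert(axis >= 0);  // negative axes are not handled in this sketch
    std::size_t const a = static_cast<std::size_t>(axis);
    std::size_t before = 1, at = 1, after = 1;
    for (std::size_t i = 0; i < shape.size(); i++) {
        if (i < a)
            before *= shape[i];   // merge everything before the axis
        else if (i == a)
            at = shape[i];        // keep the axis itself
        else
            after *= shape[i];    // merge everything after the axis
    }
    // If axis == shape.size(), both 'at' and 'after' stay 1.
    return {before, at, after};
}
```

For instance, split_and_merge_sketch({2,3,4,5}, 2) yields {6,4,5}, matching the first example in the list.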

Shape dlprim::Shape::squeeze ( std::vector< int >  dims) const

Remove the dimensions of size 1 at the positions listed in dims, for example Shape(4,5,1,1).squeeze({2,3}) == Shape(4,5)

If a value in dims is negative, it is counted from the end.

Note: for every i in [0, dims.size()) it is required that shape[dims[i]] == 1.

Referenced by dlprim::Squeeze::operator_type().
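
The rules above (negative indices count from the end, and every listed dimension must equal 1) can be sketched on plain vectors as follows. squeeze_sketch is a hypothetical name, and the choice to throw exceptions on invalid input is an assumption made for this illustration; the source does not say how dlprim reports such errors.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper, NOT the dlprim implementation: removes the size-1
// dimensions listed in dims from a plain vector of extents.
std::vector<std::size_t> squeeze_sketch(std::vector<std::size_t> const &shape,
                                        std::vector<int> const &dims)
{
    int const n = static_cast<int>(shape.size());
    std::vector<bool> drop(shape.size(), false);
    for (int d : dims) {
        int const axis = d < 0 ? d + n : d;  // negative values count from the end
        if (axis < 0 || axis >= n)
            throw std::out_of_range("squeeze axis out of range");
        if (shape[static_cast<std::size_t>(axis)] != 1)
            throw std::invalid_argument("squeezed dimension must be 1");
        drop[static_cast<std::size_t>(axis)] = true;
    }
    std::vector<std::size_t> result;
    for (std::size_t i = 0; i < shape.size(); i++) {
        if (!drop[i])
            result.push_back(shape[i]);
    }
    assert(result.size() <= shape.size());  // squeezing never adds dimensions
    return result;
}
```

With this sketch, squeeze_sketch({4,5,1,1}, {2,3}) gives {4,5}, and squeeze_sketch({4,5,1,1}, {-1}) removes only the last dimension, giving {4,5,1}.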


The documentation for this class was generated from the following file: