DLPrimitives
common.hpp
1 #pragma once
2 #include <dlprim/tensor.hpp>
3 #include <dlprim/context.hpp>
4 namespace dlprim {
5 namespace core {
6 
/// Scale tensor by factor in place; if s==0, fills with zero so that NaN
/// is not propagated.
11  class Scale {
12  public:
/// Create the scaler for a context and element type (presumably builds k_).
/// Declared explicit: with the defaulted dtype this is callable with a single
/// Context&, which would otherwise allow an implicit Context -> Scale conversion.
13  explicit Scale(Context &ctx,DataType dtype=float_data);
/// Enqueue in-place scaling of tensor t by factor s.
/// @param s  scale factor (s==0 fills with zero instead of multiplying)
/// @param t  tensor modified in place
/// @param ec execution context carrying the events the kernel waits for / signals
14  void enqueue(float s,Tensor &t,ExecutionContext const &ec);
15  private:
16  cl::Kernel k_;
17  };
18 
/// NOTE(review): presumably computes sum = a + b element-wise; the original
/// brief is not visible in this listing — confirm against the implementation.
/// @param a   input tensor (presumably first addend)
/// @param b   input tensor (presumably second addend)
/// @param sum output tensor
/// @param ec  execution context carrying the events the kernel waits for / signals
19  void add_tensors(Tensor &a,Tensor &b,Tensor &sum,ExecutionContext const &ec);
/// Scale tensor t by factor s in place; if s==0, fills with zero so that NaN
/// is not propagated.
/// @param s  scale factor
/// @param t  tensor modified in place
/// @param ec execution context carrying the events the kernel waits for / signals
24  void scale_tensor(float s,Tensor &t,ExecutionContext const &ec);
25 
/// Fill tensor t with value (presumably every element) - OpenCL only.
/// NOTE(review): the upstream brief reads "Set to zero tensor", which looks
/// stale given the explicit value parameter — confirm.
/// @param t     tensor to fill
/// @param value fill value
/// @param e     execution context carrying the events the kernel waits for / signals
29  void fill_tensor(Tensor &t,double value,ExecutionContext const &e);
30 
/// Type of random distribution.
/// (Restored opening line: the listing had lost "enum RandomDistribution {",
/// which Doxygen places at common.hpp:34.)
34  enum RandomDistribution {
35  rnd_uniform = 0,
36  rnd_normal = 1,
37  rnd_bernoulli = 2
38  };
/// Fill tensor with random numbers using the provided distribution.
/// @param t           tensor to fill
/// @param philox_seed generator seed (name suggests a Philox counter-based RNG)
/// @param philox_seq  generator sequence/offset (presumably selects the stream position)
/// @param dist        distribution to draw from (see RandomDistribution)
/// @param p1          first distribution parameter — NOTE(review): per-distribution meaning not visible here, confirm
/// @param p2          second distribution parameter — NOTE(review): confirm meaning
/// @param e           execution context carrying the events the kernel waits for / signals
51  void fill_random(Tensor &t,cl_ulong philox_seed,cl_ulong philox_seq,RandomDistribution dist,float p1,float p2,ExecutionContext const &e);
52 
53 
54 
/// Class for copying a slice of a tensor.
58  class SliceCopy {
59  public:
/// Create the slice copier for a context and element type (presumably builds kernel_).
/// Declared explicit: with the defaulted dtype this is callable with a single
/// Context&, which would otherwise allow an implicit Context -> SliceCopy conversion.
60  explicit SliceCopy(Context &ctx,DataType dtype=float_data);
61  ~SliceCopy();
62 
/// Copy a slice of source into target along dimension dim.
/// NOTE(review): the semantics below are inferred from parameter names —
/// confirm against the implementation.
/// @param dim           dimension along which the slice is taken
/// @param slice         slice length along dim (presumably)
/// @param target        destination tensor
/// @param target_offset start index in target along dim (presumably)
/// @param source        source tensor
/// @param source_offset start index in source along dim (presumably)
/// @param target_scale  factor applied to existing target values (presumably target = target*target_scale + source)
/// @param e             execution context carrying the events the kernel waits for / signals
72  void tensor_slice_copy(int dim,size_t slice,
73  Tensor &target,size_t target_offset,
74  Tensor &source,size_t source_offset,
75  float target_scale,ExecutionContext const &e);
76  private:
77  cl::Kernel kernel_;
78  DataType dtype_;
79  };
80 
81 
82 
83 } // core
84 } // dlprim
void fill_tensor(Tensor &t, double value, ExecutionContext const &e)
Fill tensor with the given value - OpenCL only.
Class for copying a slice of a tensor.
Definition: common.hpp:58
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
RandomDistribution
Type of random distribution.
Definition: common.hpp:34
DataType
type definition
Definition: definitions.hpp:70
Scale tensor by factor in place; if s==0, fills with zero so that NaN is not propagated...
Definition: common.hpp:11
void fill_random(Tensor &t, cl_ulong philox_seed, cl_ulong philox_seq, RandomDistribution dist, float p1, float p2, ExecutionContext const &e)
Fill tensor with random numbers using the provided distribution.
void scale_tensor(float s, Tensor &t, ExecutionContext const &ec)
Scale tensor by factor in place; if s==0, fills with zero so that NaN is not propagated...
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121