DLPrimitives
concat.hpp
1 #pragma once
2 #include <dlprim/operator.hpp>
3 namespace dlprim {
4  namespace json { class value; }
5  namespace core { class SliceCopy; class Scale; }
6 
7  struct ConcatConfig {
8  int dim = 1;
9  static ConcatConfig from_json(json::value const &v);
10  };
11 
12 
13  struct SliceConfig {
14  int dim = 1;
15  int begin = 0;
16  int end = -1;
17  static SliceConfig from_json(json::value const &v);
18  };
19 
21  protected:
22  template<typename Cp>
23  static void copy_cpu(size_t &offset,int dim,Tensor &in,Tensor &out,Cp cp);
24 
25  };
26 
27  class Concat : public Operator, public ConcatSliceBase {
28  public:
29  Concat(Context &ctx,ConcatConfig const &config);
30  virtual ~Concat();
31  virtual char const *operator_type() const
32  {
33  return "Concat";
34  }
35 
36  virtual void setup(std::vector<TensorSpecs> const &in,
37  std::vector<TensorSpecs> &out,
38  std::vector<TensorSpecs> &parameters,
39  size_t &workspace);
40 
41  virtual void reshape(std::vector<Shape> const &in,
42  std::vector<Shape> &out,
43  size_t &ws);
44 
45  virtual void forward(std::vector<Tensor> &input,
46  std::vector<Tensor> &output,
47  std::vector<Tensor> &parameters,
48  Tensor &workspace,
49  ExecutionContext const &ctx);
50 
51  virtual void backward(std::vector<TensorAndGradient> &input,
52  std::vector<TensorAndGradient> &output,
53  std::vector<TensorAndGradient> &parameters,
54  Tensor &workspace,
55  ExecutionContext const &ctx);
56  private:
57 
58  ConcatConfig cfg_;
59  DataType dtype_;
60  std::unique_ptr<core::SliceCopy> copy_;
61  };
62 
63  class Slice : public Operator, public ConcatSliceBase {
64  public:
65  Slice(Context &ctx,SliceConfig const &config);
66  virtual ~Slice();
67  virtual char const *operator_type() const
68  {
69  return "Slice";
70  }
71 
72  virtual void setup(std::vector<TensorSpecs> const &in,
73  std::vector<TensorSpecs> &out,
74  std::vector<TensorSpecs> &parameters,
75  size_t &workspace);
76 
77  virtual void reshape(std::vector<Shape> const &in,
78  std::vector<Shape> &out,
79  size_t &ws);
80 
81  virtual void forward(std::vector<Tensor> &input,
82  std::vector<Tensor> &output,
83  std::vector<Tensor> &parameters,
84  Tensor &workspace,
85  ExecutionContext const &ctx);
86 
87  virtual void backward(std::vector<TensorAndGradient> &input,
88  std::vector<TensorAndGradient> &output,
89  std::vector<TensorAndGradient> &parameters,
90  Tensor &workspace,
91  ExecutionContext const &ctx);
92  private:
93 
94  SliceConfig cfg_;
95  DataType dtype_;
96  std::unique_ptr<core::SliceCopy> copy_;
97  std::unique_ptr<core::Scale> scale_;
98  };
99 }
virtual char const * operator_type() const
name of the operator type
Definition: concat.hpp:31
Definition: concat.hpp:13
Definition: concat.hpp:7
virtual char const * operator_type() const
name of the operator type
Definition: concat.hpp:67
Definition: concat.hpp:63
Definition: concat.hpp:20
Base class for backward/forward propagation calculations of the internal network.
Definition: operator.hpp:15
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
DataType
type definition
Definition: definitions.hpp:70
This class is the central representation of JSON objects.
Definition: json.hpp:652
Definition: concat.hpp:27
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121