DLPrimitives
nll_loss.hpp
1 #pragma once
2 #include <dlprim/operator.hpp>
3 
4 namespace dlprim {
5  namespace json { class value; }
6 
7  struct NLLLossConfig {
8  enum Reduction {
9  reduce_none,
10  reduce_sum,
11  reduce_mean
12  };
13  Reduction reduce = reduce_mean;
14  static NLLLossConfig from_json(json::value const &v);
15  };
16 
17  class NLLLoss : public Operator {
18  public:
19  NLLLoss(Context &ctx,NLLLossConfig const &cfg=NLLLossConfig());
20  virtual ~NLLLoss();
21  virtual char const *operator_type() const
22  {
23  return "NLLLoss";
24  }
25 
26  virtual void setup(std::vector<TensorSpecs> const &in,
27  std::vector<TensorSpecs> &out,
28  std::vector<TensorSpecs> &parameters,
29  size_t &workspace);
30 
31  virtual void reshape(std::vector<Shape> const &in,
32  std::vector<Shape> &out,
33  size_t &ws);
34 
35  virtual void forward(std::vector<Tensor> &input,
36  std::vector<Tensor> &output,
37  std::vector<Tensor> &parameters,
38  Tensor &workspace,
39  ExecutionContext const &ctx);
40 
41  virtual void backward( std::vector<TensorAndGradient> &input,
42  std::vector<TensorAndGradient> &output,
43  std::vector<TensorAndGradient> &,
44  Tensor &,
45  ExecutionContext const &e);
46 
47  private:
48  template<typename Index>
49  void forwad_cpu(Tensor &x,Tensor &lbl,Tensor &y);
50  template<typename Index>
51  void backward_cpu(Tensor &dx,Tensor &lbl,Tensor &dy,float accum);
52 
53  NLLLossConfig cfg_;
54  };
55 
56 }// dlprim
57 
Definition: nll_loss.hpp:7
virtual char const * operator_type() const
name of the operator type
Definition: nll_loss.hpp:21
Base class for backward/forward propagation calculations for internal network.
Definition: operator.hpp:15
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
This class is the central representation of JSON objects.
Definition: json.hpp:652
Definition: nll_loss.hpp:17
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121