DLPrimitives
mse_loss.hpp
1 #pragma once
2 #include <dlprim/operator.hpp>
3 
4 namespace dlprim {
5  namespace json { class value; }
6  namespace core { class PointwiseOperationBroadcastReduce; }
7 
8  struct MSELossConfig {
9  enum Reduction {
10  reduce_none,
11  reduce_sum,
12  reduce_mean
13  };
14  Reduction reduce = reduce_mean;
15  static MSELossConfig from_json(json::value const &v);
16  };
17 
18  class MSELoss : public Operator {
19  public:
20  MSELoss(Context &ctx,MSELossConfig const &cfg=MSELossConfig());
21  virtual ~MSELoss();
22  virtual char const *operator_type() const
23  {
24  return "MSELoss";
25  }
26 
27  virtual void setup(std::vector<TensorSpecs> const &in,
28  std::vector<TensorSpecs> &out,
29  std::vector<TensorSpecs> &parameters,
30  size_t &workspace);
31 
32  virtual void reshape(std::vector<Shape> const &in,
33  std::vector<Shape> &out,
34  size_t &ws);
35 
36  virtual void forward(std::vector<Tensor> &input,
37  std::vector<Tensor> &output,
38  std::vector<Tensor> &parameters,
39  Tensor &workspace,
40  ExecutionContext const &ctx);
41 
42  virtual void backward( std::vector<TensorAndGradient> &input,
43  std::vector<TensorAndGradient> &output,
44  std::vector<TensorAndGradient> &,
45  Tensor &,
46  ExecutionContext const &e);
47 
48  private:
49  void setup_gpu(std::vector<TensorSpecs> in,std::vector<TensorSpecs> out,size_t &workspace);
50  void forward_cpu(Tensor &a,Tensor &b,Tensor &y);
51  void backward_cpu(Tensor &dy,Tensor &a,Tensor &b,Tensor &dx,float scale,float accum);
52 
53  MSELossConfig cfg_;
54  DataType dtype_;
55  std::unique_ptr<core::PointwiseOperationBroadcastReduce> fwd_;
56  };
57 
58 }// dlprim
59 
Definition: mse_loss.hpp:18
Definition: mse_loss.hpp:8
Base class for backward/forward propagation calculations for the internal network.
Definition: operator.hpp:15
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
DataType
type definition
Definition: definitions.hpp:70
This class is the central representation of JSON objects.
Definition: json.hpp:652
virtual char const * operator_type() const
name of the operator type
Definition: mse_loss.hpp:22
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121