2 #include <dlprim/operator.hpp> 5 namespace json {
class value; }
6 namespace core {
class PointwiseOperationBroadcastReduce; }
14 Reduction reduce = reduce_mean;
27 virtual void setup(std::vector<TensorSpecs>
const &in,
28 std::vector<TensorSpecs> &out,
29 std::vector<TensorSpecs> ¶meters,
32 virtual void reshape(std::vector<Shape>
const &in,
33 std::vector<Shape> &out,
36 virtual void forward(std::vector<Tensor> &input,
37 std::vector<Tensor> &output,
38 std::vector<Tensor> ¶meters,
42 virtual void backward( std::vector<TensorAndGradient> &input,
43 std::vector<TensorAndGradient> &output,
44 std::vector<TensorAndGradient> &,
49 void setup_gpu(std::vector<TensorSpecs> in,std::vector<TensorSpecs> out,
size_t &workspace);
55 std::unique_ptr<core::PointwiseOperationBroadcastReduce> fwd_;
Definition: mse_loss.hpp:18
Definition: mse_loss.hpp:8
Base class for backward/forward propagation calculations for the internal network.
Definition: operator.hpp:15
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
DataType
type definition
Definition: definitions.hpp:70
This class is the central representation of JSON objects.
Definition: json.hpp:652
virtual char const * operator_type() const
name of the operator type
Definition: mse_loss.hpp:22
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121