2 #include <dlprim/operator.hpp> 5 namespace json {
class value; }
17 bool setup_kernel_params(
int sm_range);
20 int items_per_wi_ = 0;
35 virtual void setup(std::vector<TensorSpecs>
const &in,
36 std::vector<TensorSpecs> &out,
37 std::vector<TensorSpecs> &parameters,
40 virtual void reshape(std::vector<Shape>
const &in,
41 std::vector<Shape> &out,
44 virtual void forward(std::vector<Tensor> &input,
45 std::vector<Tensor> &output,
46 std::vector<Tensor> &parameters,
50 virtual void backward( std::vector<TensorAndGradient> &input,
51 std::vector<TensorAndGradient> &output,
52 std::vector<TensorAndGradient> &,
71 return "SoftmaxWithLoss";
74 virtual void setup(std::vector<TensorSpecs>
const &in,
75 std::vector<TensorSpecs> &out,
76 std::vector<TensorSpecs> &parameters,
79 virtual void reshape(std::vector<Shape>
const &in,
80 std::vector<Shape> &out,
83 virtual void forward(std::vector<Tensor> &input,
84 std::vector<Tensor> &output,
85 std::vector<Tensor> &parameters,
89 virtual void backward( std::vector<TensorAndGradient> &input,
90 std::vector<TensorAndGradient> &output,
91 std::vector<TensorAndGradient> &,
96 template<
typename IndexType>
98 template<
typename IndexType>
102 void setup_kernel(
int sm_range);
105 cl::Kernel kernel_,kernel_bwd_;
106 std::unique_ptr<Scal> scal_;
Definition: softmax.hpp:64
virtual char const * operator_type() const
name of the operator type
Definition: softmax.hpp:69
Definition: softmax.hpp:25
Base class for backward/forward propagation calculations for internal network.
Definition: operator.hpp:15
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
Definition: softmax.hpp:8
DataType
type definition
Definition: definitions.hpp:70
This class is central representation of json objects.
Definition: json.hpp:652
Main namespace.
Definition: context.hpp:9
Central Data Container - Tensor.
Definition: tensor.hpp:99
virtual char const * operator_type() const
name of the operator type
Definition: softmax.hpp:30
Definition: softmax.hpp:13
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121