DLPrimitives
gemm.hpp
1 #pragma once
2 #include <dlprim/context.hpp>
3 #include <dlprim/definitions.hpp>
4 namespace dlprim {
6 namespace gpu {
7 
/// Abstract interface for a GEMM operation that works on OpenCL buffers.
///
/// The single operation, gemm(), is pure virtual, so this class cannot be
/// instantiated directly. The two static factories declared below hand back
/// an object through std::unique_ptr<GEMM>.
8  class GEMM {
9  public:
10 
/// Virtual destructor so that an object held through a GEMM pointer
/// (e.g. the std::unique_ptr<GEMM> returned by the factories) is destroyed correctly.
11  virtual ~GEMM()
12  {
13  }
14 
/// Bias mode constants. Their value range matches the `int bias` argument of
/// the factories below, whose default of 0 equals no_bias.
/// NOTE(review): bias_M / bias_N presumably select a bias applied along the
/// M or the N dimension of the result - confirm against the implementations.
15  static constexpr int no_bias = 0;
16  static constexpr int bias_M = 1;
17  static constexpr int bias_N = 2;
18 
/// Execute the GEMM; pure virtual, provided by the concrete implementation.
///
/// @param M,N,K        problem dimensions. NOTE(review): presumably the usual
///                     GEMM convention (C is M x N, K is the reduced
///                     dimension) - confirm.
/// @param a,b,c        OpenCL buffers holding the operands; `c` is presumably
///                     the output - confirm.
/// @param offset_a,offset_b,offset_c
///                     start offsets into the respective buffers. The unit
///                     (elements vs. bytes) is not visible in this header -
///                     TODO confirm.
/// @param lda,ldb,ldc  presumably BLAS-style leading dimensions of a, b and
///                     c - confirm.
/// @param bias         pointer to a bias buffer. It is passed as a pointer
///                     rather than a reference, so presumably it may be null
///                     when no bias is used - confirm.
/// @param bias_offset  start offset into the bias buffer (same unit caveat
///                     as the other offsets).
/// @param beta         presumably the BLAS-style factor applied to the
///                     existing contents of `c` - confirm.
/// @param size_of_c    NOTE(review): presumably the total size of `c`; the
///                     unit and purpose are not visible here - confirm.
/// @param e            execution context, used to pass the cl::Events the
///                     kernel should wait for and/or signal on completion.
19  virtual void gemm(int M,int N,int K,
20  cl::Buffer &a,
21  cl_ulong offset_a,
22  int lda,
23  cl::Buffer &b,
24  cl_ulong offset_b,
25  int ldb,
26  cl::Buffer &c,
27  cl_ulong offset_c,
28  int ldc,
29  cl::Buffer *bias,
30  cl_ulong bias_offset,
31  float beta,
32  int size_of_c,
33  ExecutionContext const &e) = 0;
34 
/// Factory: returns a GEMM object for the given problem description.
///
/// @param ctx          context object (the OpenCL platform/device pair).
/// @param dtype        data type of the operands.
/// @param trans_a,trans_b
///                     presumably whether A / B are used transposed - confirm.
/// @param M,N,K        problem dimensions the object is selected for.
/// @param bias         bias mode, defaults to 0 (the value of no_bias).
/// @param act          parameterless activation to embed into the kernel;
///                     defaults to StandardActivations::identity.
/// @param im2col_chan  defaults to 0. NOTE(review): meaning is not visible in
///                     this header - confirm against the implementation.
/// @return owning pointer to the created object.
35  static std::unique_ptr<GEMM> get_optimal_gemm(
36  Context &ctx,DataType dtype,
37  bool trans_a,bool trans_b,
38  int M,int N,int K,
39  int bias = 0,
40  StandardActivations act = StandardActivations::identity,
41  int im2col_chan = 0);
42 
/// Factory: convolution-oriented variant of get_optimal_gemm(). It takes the
/// same leading and trailing arguments plus an operation mode and the
/// convolution geometry.
///
/// @param op_mode      internal GEMM mode.
/// @param kernel,dilate,padding,stride
///                     two-element arrays describing the convolution.
///                     NOTE(review): element order (rows/cols vs. cols/rows)
///                     is not visible here - confirm. As array parameters they
///                     are passed as plain `int *`, so the length of 2 is not
///                     enforced by the compiler.
/// @param groups       number of convolution groups.
/// @param src_channels,src_rows,src_cols
///                     presumably the shape of the convolution input - confirm.
/// @param tgt_rows,tgt_cols
///                     presumably the spatial shape of the convolution output - confirm.
/// @param bias         bias mode, defaults to 0 (the value of no_bias).
/// @param act          parameterless activation to embed into the kernel;
///                     defaults to StandardActivations::identity.
/// @param im2col_chan  defaults to 0; meaning not visible in this header - confirm.
/// @return owning pointer to the created object.
43  static std::unique_ptr<GEMM> get_optimal_conv_gemm(
44  Context &ctx,DataType dtype,
45  GemmOpMode op_mode,
46  bool trans_a,bool trans_b,
47  int M,int N,int K,
48  int kernel[2],int dilate[2],int padding[2],int stride[2],int groups,
49  int src_channels,int src_rows,int src_cols,
50  int tgt_rows,int tgt_cols,
51  int bias = 0,
52  StandardActivations act = StandardActivations::identity,
53  int im2col_chan = 0);
54 
55  };
56 
57 }
58 }
Definition: gemm.hpp:8
This is the main object that represents the pair of OpenCL platform and device; all other objects use it...
Definition: context.hpp:302
DataType
type definition
Definition: definitions.hpp:70
Main namespace.
Definition: context.hpp:9
GemmOpMode
internal GEMM mode
Definition: definitions.hpp:292
StandardActivations
Parameterless activations that can be embedded into general kernels like inner product or convolution...
Definition: definitions.hpp:266
This class is used to pass cl::Events that the kernel should wait for and/or signal event completion...
Definition: context.hpp:121