Kernel Function Declaration
PaddlePaddle publishes its kernel function declarations through header files, and the same declarations are used both inside the framework and by external custom kernels.
A custom kernel must be written against a specific kernel function declaration. The header files are located under include/paddle/phi/kernels/.
The format of the declaration is as follows:
template <typename T, typename Context>
void KernelNameKernel(const Context& dev_ctx,
                      InputTensor(s),
                      Attribute(s),
                      OutTensor(s));
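To make the parameter order concrete, here is a minimal, self-contained sketch. DenseTensor and CustomContext below are simplified hypothetical stand-ins, not the real phi classes, and ScaleKernel is a made-up kernel name used only to illustrate the format: two template parameters, a void return, the "Kernel" suffix, and the order dev_ctx, inputs, attributes, outputs.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical stand-in for phi::DenseTensor: a shape plus a type-erased buffer.
class DenseTensor {
 public:
  std::vector<int64_t> dims;

  int64_t numel() const {
    int64_t n = 1;
    for (int64_t d : dims) n *= d;
    return n;
  }

  template <typename T>
  const T* data() const { return static_cast<const T*>(holder_.get()); }

  // Allocates zero-initialized storage for numel() elements of type T.
  template <typename T>
  T* AllocBuffer() {
    holder_ = std::shared_ptr<void>(new T[numel()](), std::default_delete<T[]>());
    return static_cast<T*>(holder_.get());
  }

 private:
  std::shared_ptr<void> holder_;
};

// Hypothetical stand-in for the device context of a custom device.
struct CustomContext {
  template <typename T>
  T* Alloc(DenseTensor* tensor) const { return tensor->AllocBuffer<T>(); }
};

// Hypothetical kernel following the declaration format:
// dev_ctx first, then input tensors, then attributes, then output tensors.
template <typename T, typename Context>
void ScaleKernel(const Context& dev_ctx,
                 const DenseTensor& x,  // InputTensor
                 float scale,           // Attribute
                 DenseTensor* out) {    // OutTensor
  assert(out != nullptr);
  out->dims = x.dims;  // the output has the same shape as the input
  T* out_data = dev_ctx.template Alloc<T>(out);
  const T* x_data = x.data<T>();
  for (int64_t i = 0; i < x.numel(); ++i) {
    out_data[i] = x_data[i] * static_cast<T>(scale);
  }
}
```

With a three-element float tensor x holding 1, 2 and -4, the call ScaleKernel<float, CustomContext>(ctx, x, 0.5f, &out) leaves out with the same shape and the values 0.5, 1 and -2.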
Conventions:
- Template parameters: fixed in format. The first template parameter is the data type T, and the second is the device context type Context.
- Return type: always void.
- Naming: camel case, the kernel name followed by "Kernel", such as SoftmaxKernel.
- Parameters: the Context parameter, the InputTensor(s), the Attribute(s), and the OutTensor(s), arranged in that order:
  - Context parameter: of type const Context&, corresponding to the CustomContext of the custom kernel. Please refer to custom_context.h.
  - InputTensor: zero or more, of the following types:
    - const DenseTensor& (please refer to dense_tensor.h)
    - const SelectedRows& (please refer to selected_rows.h)
    - const SparseCooTensor& (please refer to sparse_coo_tensor.h)
    - const SparseCsrTensor& (please refer to sparse_csr_tensor.h)
    - const std::vector<DenseTensor*>&
    - const std::vector<SparseCooTensor*>&
    - const std::vector<SparseCsrTensor*>&
  - Attribute: zero or more, of the following types:
    - bool
    - float
    - double
    - int
    - int64_t
    - phi::dtype::float16 (please refer to float16.h)
    - const Scalar& (please refer to scalar.h)
    - DataType (please refer to data_type.h)
    - DataLayout (please refer to layout.h)
    - Place (please refer to place.h)
    - const std::vector<int64_t>&
    - const ScalarArray& (please refer to scalar_array.h)
    - const std::vector<int>&
    - const std::string&
    - const std::vector<bool>&
    - const std::vector<float>&
    - const std::vector<double>&
    - const std::vector<std::string>&
  - OutTensor: one or more, of the following types:
    - DenseTensor*
    - SelectedRows*
    - SparseCooTensor*
    - SparseCsrTensor*
    - std::vector<DenseTensor*>
    - std::vector<SparseCooTensor*>
    - std::vector<SparseCsrTensor*>
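List-valued types from the tables above can be mixed in one signature as long as the order is kept. The sketch below uses a made-up WeightedSumKernel with a const std::vector<DenseTensor*>& input, a const std::vector<float>& attribute, and a single DenseTensor* output. DenseTensor and CustomContext are again simplified hypothetical stand-ins rather than the real phi classes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical stand-in for phi::DenseTensor: a shape plus a type-erased buffer.
class DenseTensor {
 public:
  std::vector<int64_t> dims;

  int64_t numel() const {
    int64_t n = 1;
    for (int64_t d : dims) n *= d;
    return n;
  }

  template <typename T>
  const T* data() const { return static_cast<const T*>(holder_.get()); }

  // Allocates zero-initialized storage for numel() elements of type T.
  template <typename T>
  T* AllocBuffer() {
    holder_ = std::shared_ptr<void>(new T[numel()](), std::default_delete<T[]>());
    return static_cast<T*>(holder_.get());
  }

 private:
  std::shared_ptr<void> holder_;
};

// Hypothetical stand-in for the device context of a custom device.
struct CustomContext {
  template <typename T>
  T* Alloc(DenseTensor* tensor) const { return tensor->AllocBuffer<T>(); }
};

// Hypothetical kernel: out = sum_k weights[k] * xs[k], element-wise.
template <typename T, typename Context>
void WeightedSumKernel(const Context& dev_ctx,
                       const std::vector<DenseTensor*>& xs,  // InputTensor (list)
                       const std::vector<float>& weights,    // Attribute (list)
                       DenseTensor* out) {                   // OutTensor
  assert(!xs.empty() && xs.size() == weights.size());
  out->dims = xs[0]->dims;
  // The mock Alloc zero-initializes, so accumulation can start right away.
  T* out_data = dev_ctx.template Alloc<T>(out);
  for (std::size_t k = 0; k < xs.size(); ++k) {
    assert(xs[k]->numel() == out->numel());
    const T* x_data = xs[k]->data<T>();
    for (int64_t i = 0; i < out->numel(); ++i) {
      out_data[i] += x_data[i] * static_cast<T>(weights[k]);
    }
  }
}
```

Passing tensors holding {1, 2} and {3, 4} with weights {1.0f, 0.5f} produces {2.5, 4}. A real kernel would allocate and fill memory through the actual device context API rather than through this mock.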
For example, the softmax kernel function is declared in softmax_kernel.h as follows:
// Softmax
// Template parameters: T       - data type
//                      Context - the device context
// Parameters: dev_ctx - object of the Context
//             x       - DenseTensor object
//             axis    - int type
//             dtype   - DataType type
//             out     - DenseTensor pointer
// Return: None
template <typename T, typename Context>
void SoftmaxKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   int axis,
                   DataType dtype,
                   DenseTensor* out);
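The sketch below gives a body that matches the SoftmaxKernel declaration above, under simplifying assumptions: DenseTensor, CustomContext, and the DataType enum are hypothetical stand-ins for the real phi types, only the last axis is supported, and dtype is accepted but ignored. A static_assert checks at compile time that the definition's function type is exactly the declared one, since the declaration is what registration and framework invocation are built on.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <memory>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for phi::DataType (only two values for the sketch).
enum class DataType { FLOAT32, FLOAT64 };

// Hypothetical stand-in for phi::DenseTensor: a shape plus a type-erased buffer.
class DenseTensor {
 public:
  std::vector<int64_t> dims;

  int64_t numel() const {
    int64_t n = 1;
    for (int64_t d : dims) n *= d;
    return n;
  }

  template <typename T>
  const T* data() const { return static_cast<const T*>(holder_.get()); }

  // Allocates zero-initialized storage for numel() elements of type T.
  template <typename T>
  T* AllocBuffer() {
    holder_ = std::shared_ptr<void>(new T[numel()](), std::default_delete<T[]>());
    return static_cast<T*>(holder_.get());
  }

 private:
  std::shared_ptr<void> holder_;
};

// Hypothetical stand-in for the device context of a custom device.
struct CustomContext {
  template <typename T>
  T* Alloc(DenseTensor* tensor) const { return tensor->AllocBuffer<T>(); }
};

// Sketch: numerically stable softmax over the last axis only.
template <typename T, typename Context>
void SoftmaxKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   int axis,
                   DataType dtype,
                   DenseTensor* out) {
  (void)dtype;  // not used by this sketch
  const int rank = static_cast<int>(x.dims.size());
  if (axis < 0) axis += rank;
  assert(rank >= 1 && axis == rank - 1);  // only the last axis is handled
  const int64_t cols = x.dims.back();
  assert(cols > 0);
  const int64_t rows = x.numel() / cols;
  out->dims = x.dims;
  T* y = dev_ctx.template Alloc<T>(out);
  const T* in = x.data<T>();
  for (int64_t r = 0; r < rows; ++r) {
    const T* row = in + r * cols;
    T* out_row = y + r * cols;
    // Subtract the row maximum before exponentiating to avoid overflow.
    const T max_v = *std::max_element(row, row + cols);
    T sum = 0;
    for (int64_t c = 0; c < cols; ++c) {
      out_row[c] = std::exp(row[c] - max_v);
      sum += out_row[c];
    }
    for (int64_t c = 0; c < cols; ++c) out_row[c] /= sum;
  }
}

// Compile-time check that the definition matches the published declaration.
static_assert(
    std::is_same<decltype(&SoftmaxKernel<float, CustomContext>),
                 void (*)(const CustomContext&, const DenseTensor&, int,
                          DataType, DenseTensor*)>::value,
    "SoftmaxKernel must match its declaration");
```

Applied with axis = -1 to a 2 x 2 float tensor whose rows are {0, 0} and {0, ln 3}, the sketch yields rows {0.5, 0.5} and approximately {0.25, 0.75}. Subtracting the row maximum does not change the result, because the common factor cancels during normalization.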
Note:
The kernel function declaration is the basis for registering a custom kernel and for the framework to invoke it. It is published by the framework and must be followed.
Kernel function declarations do not always map one-to-one to header file names, so a declaration may not be in the header file you expect. You can find the declaration you need by searching for the function name.
