mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Basic docs for universal gemm & ck-tile gemm. (#2014)
* Basic docs for universal gemm & ck-tile gemm. * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Reviewers suggestions. * Align tparam names in doc with class tparams. * More reviewers fine tuning ;) --------- Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
This commit is contained in:
@@ -12,6 +12,11 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The GEMM problem definition.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure defines the GEMM problem configuration by stating all required information
|
||||
/// like M,N,K sizes and respective strides.
|
||||
struct GemmProblem
|
||||
{
|
||||
CK_TILE_HOST GemmProblem() = default;
|
||||
@@ -29,6 +34,12 @@ struct GemmProblem
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
/// @brief The GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref GemmKernel "GemmKernel" when creating kernel arguments
|
||||
/// object. It contain all necessary information required to build proper kernel argument
|
||||
/// and launch kernel on GPU.
|
||||
struct GemmHostArgs : public GemmProblem
|
||||
{
|
||||
CK_TILE_HOST GemmHostArgs() = default;
|
||||
@@ -56,20 +67,69 @@ struct GemmHostArgs : public GemmProblem
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
/// @brief The GEMM kernel device arguments.
|
||||
struct GemmKernelArgs
|
||||
{
|
||||
/// @brief The A input tensor's pointer to device memory.
|
||||
const void* a_ptr;
|
||||
/// @brief The B input tensor's pointer to device memory.
|
||||
const void* b_ptr;
|
||||
/// @brief The C output tensor's pointer to device memory.
|
||||
void* c_ptr;
|
||||
/// @brief GEMM's M dimension size.
|
||||
index_t M;
|
||||
/// @brief GEMM's N dimension size.
|
||||
index_t N;
|
||||
/// @brief GEMM's K dimension size.
|
||||
index_t K;
|
||||
/// @brief The distance between consecutive elements of non-contiguous dimension
|
||||
/// (in memory) of A tensor.
|
||||
index_t stride_A;
|
||||
/// @brief The distance between consecutive elements of non-contiguous dimension
|
||||
/// (in memory) of B tensor.
|
||||
index_t stride_B;
|
||||
/// @brief The distance between consecutive elements of non-contiguous dimension
|
||||
/// (in memory) of C tensor.
|
||||
index_t stride_C;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
/// @brief The GEMM kernel template.
|
||||
///
|
||||
/// @paragraph Overview Overview
|
||||
/// This class provides the generic matrix multiplication kernel template. By semantic
|
||||
/// division of GEMM algorithm into following parts we achieve flexible, versatile
|
||||
/// and robust kernel implementation.
|
||||
///
|
||||
/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator()
|
||||
/// function call operator" which determines the work scope of each workgroup.
|
||||
/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm.
|
||||
/// This is the place where each workgroup is loading data from global memory and
|
||||
/// carrying out dot products.
|
||||
/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation
|
||||
/// responsible for storing results to global memory. This is also the place where
|
||||
/// any additional operator fusion may take place.
|
||||
///
|
||||
/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_
|
||||
/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all
|
||||
/// internal details of those functional parts. You can think of it like both gemm and
|
||||
/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover
|
||||
/// the policy is responsible for definition of all necessary data layouts and thread's
|
||||
/// work distribution.
|
||||
///
|
||||
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into the
|
||||
/// output data tile to be calculated. It determines the workgroup to
|
||||
/// data relationship (or in other words - which data would be
|
||||
/// processed and calculated by which workgroup).
|
||||
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
|
||||
/// multiplication. This class should provide implementation of data
|
||||
/// loading from global memory and performing block-wise matrix
|
||||
/// multiplication. You can think of it as a work done by single
|
||||
/// workgroup point of view.
|
||||
/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix
|
||||
/// multiplication implementation. It is responsible for storing
|
||||
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
|
||||
/// the output C tensor in global memory.
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct GemmKernel
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user