Basic docs for universal gemm & ck-tile gemm. (#2014)

* Basic docs for universal gemm & ck-tile gemm.

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Reviewers suggestions.

* Align tparam names in doc with class tparams.

* More reviewers fine tuning ;)

---------

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
This commit is contained in:
Adam Osewski
2025-04-02 11:03:40 +02:00
committed by GitHub
parent 8c0ab61ece
commit e5ad48a784
3 changed files with 277 additions and 2 deletions

View File

@@ -82,6 +82,109 @@ __global__ void
#endif // end of if (defined(__gfx9__))
}
/// @brief \"Universal\" GEMM kernel with SplitK support.
///
/// @par Overview
/// This GEMM kernel is carrying out following mathematical equation:
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
/// elementwise operations that could be applied on each tensor respectively.
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
/// scenarios. That's why it's called \"universal\". It's universal through it's design
/// and versatilty.
///
/// @note This Kernel implementation supports SplitK algorithm. It can be configured
/// to split the dot product accumulated over the K dimension into multiple working groups.
/// The partial products of different workgroups are then reduced using the AtomicAdd
/// operation.
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam CLayout C tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during \"CShuffle\" data layout optimization.
/// @tparam CDataType C tensor data type.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
/// (after GEMM).
/// @tparam GemmSpec Determines used "padding" version.
/// @tparam BlockSize The number of threads within workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
/// @tparam NPerBlock The input/output data tile size in the N dimension.
/// @tparam KPerBlock The input data tile size in the K dimension.
/// @tparam AK1Value The vector load size from global memory for A tensor.
/// @tparam BK1Value The vector load size from global memory for B tensor.
/// @tparam MPerXdl M size of matrix-fused-multiply-add instruction.
/// @tparam NPerXdl N size of matrix-fused-multiply-add instruction.
/// @tparam MXdlPerWave The number of iterations in the M dimension over output tile per wavefront.
/// @tparam NXdlPerWave The number of iterations in the N dimension over output tile per wavefront.
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question, "How many threads can be
/// arranged on each input data axis?"
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
/// @tparam AThreadTransferSrcResetCoordinateAfterRun Decides whether we reset thread coordinate
/// (return back to the window origin) after all thread finish data copy.
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question: "How many threads to
/// arrange on each input data axis?"
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
/// @tparam BThreadTransferSrcResetCoordinateAfterRun Decides whether we reset thread coordinate
/// (return back to the window origin) after all thread finish data copy.
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam CShuffleMXdlPerWavePerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in M dimension.
/// @tparam CShuffleNXdlPerWavePerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
/// Used when storing data to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
/// instructions.
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
/// instructions.
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
/// in global memory. Currently not supported!
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
/// in global memory (pre-shuffled).
template <typename ALayout,
typename BLayout,
typename CLayout,