mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Basic docs for universal gemm & ck-tile gemm. (#2014)
* Basic docs for universal gemm & ck-tile gemm. * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Update 
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Reviewers suggestions. * Align tparam names in doc with class tparams. * More reviewers fine tuning ;) --------- Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -21,6 +21,105 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/// @brief \"Universal\" GEMM operation with SplitK support.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This GEMM operation implements the following mathematical equation:
|
||||
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
|
||||
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
|
||||
/// elementwise operations applied to the A, B, and C tensors, respectively.
|
||||
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
|
||||
/// scenarios. That's why it's called \"universal\". It's universal through its design
|
||||
/// and versatility.
|
||||
///
|
||||
/// @note This kernel implementation supports the SplitK algorithm. It can be configured
|
||||
/// to split the dot product accumulated over the K dimension across multiple workgroups.
|
||||
/// The partial products of different workgroups are then reduced using the AtomicAdd
|
||||
/// operation.
|
||||
///
|
||||
/// @tparam ALayout A tensor data layout.
|
||||
/// @tparam BLayout B tensor data layout.
|
||||
/// @tparam CLayout C tensor data layout.
|
||||
/// @tparam ADataType A tensor data type.
|
||||
/// @tparam BDataType B tensor data type.
|
||||
/// @tparam CDataType C tensor data type.
|
||||
/// @tparam GemmAccDataType The accumulation data type related to the hardware
|
||||
/// matrix-multiplication instruction.
|
||||
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
|
||||
/// LDS memory during \"CShuffle\" data layout optimization.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
|
||||
/// (after GEMM).
|
||||
/// @tparam GemmSpec Determines which "padding" version is used.
|
||||
/// @tparam BlockSize The number of threads within a workgroup.
|
||||
/// @tparam MPerBlock The input/output data tile size in the M dimension.
|
||||
/// @tparam NPerBlock The input/output data tile size in the N dimension.
|
||||
/// @tparam KPerBlock The input data tile size in the K dimension.
|
||||
/// @tparam AK1 The vector load size from global memory for A tensor.
|
||||
/// @tparam BK1 The vector load size from global memory for B tensor.
|
||||
/// @tparam MPerXDL M size of matrix-fused-multiply-add instruction.
|
||||
/// @tparam NPerXDL N size of matrix-fused-multiply-add instruction.
|
||||
/// @tparam MXdlPerWave The number of iterations in the M dimension over output tile per wavefront.
|
||||
/// @tparam NXdlPerWave The number of iterations in the N dimension over output tile per wavefront.
|
||||
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
|
||||
/// data. Can be interpreted as the answer
|
||||
/// to the question, "How many threads can be
|
||||
/// arranged on each input data axis?"
|
||||
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
|
||||
/// the input tensor dimension. Can be interpreted
|
||||
/// as the answer to the question: "In which
|
||||
/// order to spread threads through tensor axes?".
|
||||
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
|
||||
/// interpreted as the answer to the question "Which dimension
|
||||
/// to read first? And which next?" etc.
|
||||
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
|
||||
/// access - the one with contiguous memory.
|
||||
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
|
||||
/// elements accessed per thread per instruction.
|
||||
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
|
||||
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
|
||||
/// universal GEMM there's no need for padding.
|
||||
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
|
||||
/// data. Can be interpreted as the answer
|
||||
/// to the question: "How many threads to
|
||||
/// arrange on each input data axis?"
|
||||
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
|
||||
/// the input tensor dimension. Can be interpreted
|
||||
/// as the answer to the question: "In which
|
||||
/// order to spread threads through tensor axes?".
|
||||
/// @tparam BBlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
|
||||
/// interpreted as the answer to the question "Which dimension
|
||||
/// to read first? And which next?" etc.
|
||||
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
|
||||
/// access - the one with contiguous memory.
|
||||
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
|
||||
/// elements accessed per thread per instruction.
|
||||
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
|
||||
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
|
||||
/// universal GEMM there's no need for padding.
|
||||
/// @tparam CShuffleMXdlPerWavePerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in M dimension.
|
||||
/// @tparam CShuffleNXdlPerWavePerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in N dimension.
|
||||
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// thread distribution used for storing data into output
|
||||
/// tensor across output data layout dimensions.
|
||||
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
|
||||
/// Used when storing data to output tensor.
|
||||
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
|
||||
/// intrawave).
|
||||
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
|
||||
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
|
||||
/// instructions.
|
||||
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
|
||||
/// instructions.
|
||||
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory. Currently not supported!
|
||||
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory (pre-shuffled).
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
@@ -130,9 +229,22 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
// Invoker
|
||||
/// @brief Helper structure responsible for kernel invocation.
|
||||
///
|
||||
/// @par Overview
/// The `Invoker` class is responsible for preparation and invocation of actual GPU
|
||||
/// kernel function. It usually determines the launched grid size, prepares kernel
|
||||
/// arguments, as well as performs specific kernel configuration selection based on
|
||||
/// runtime arguments.
|
||||
///
|
||||
/// @note If appropriately configured it may measure kernel execution time.
|
||||
///
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
/// @brief This function issues GPU kernel execution.
|
||||
/// @param arg The GPU kernel arguments.
|
||||
/// @param stream_config The HIP stream configuration helper structure.
|
||||
/// @return The kernel's average execution time (if time measurement is
|
||||
/// enabled).
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
|
||||
Reference in New Issue
Block a user