[CK_TILE] Multiple-ABD GEMM example (#2788)

* Multi ABD - initial commit

* Clang-format fix

* block gemm, unify the name of CDataType

* Apply changes to mem-pipeline

* Rollback prefix for DType and Layout

* Gemm Kernel Basic, rename

* WMMA config

* Grouped GEMM

* Clang-format

* Dropout, name

* Review v2

* Move element_wise fn to unary, remove old ones fn

* clang-format

* Fix issue review

* WP operator adjust to universal gemm

* v2 prepare

* Remove unused comment

* Remove vectorsize

* Rollback

* Adjust pipeline for abd

* Shuffle argument

* CI-fail fix quant

* Fix ag_br pipeline

* Failing tests

* Typo

* Single argument support
This commit is contained in:
Mateusz Ozga
2025-09-19 01:14:11 +02:00
committed by GitHub
parent 14bbc545ea
commit 30ab1d6a71
41 changed files with 3603 additions and 552 deletions

View File

@@ -90,10 +90,10 @@ struct BatchedGemmKernel
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, CDataType>::value,
"C/ELayout and C/EDataType must be scalars.");
"C/CLayout and C/EDataType must be scalars.");
struct BatchedGemmKernelArgs : ck_tile::UniversalGemmKernelArgs<>
{

View File

@@ -89,7 +89,7 @@ struct GemmKernel
/// @brief Specify the layout configurations for A, B, E and D
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
/// @brief Specify the data type configurations for A, B, E and D
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
@@ -106,10 +106,10 @@ struct GemmKernel
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, ELayout>::value &&
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, EDataType>::value,
"C/ELayout and C/EDataType must be scalars.");
"C/CLayout and C/EDataType must be scalars.");
static constexpr index_t NumATensor = 1;
static constexpr index_t NumBTensor = 1;

View File

@@ -0,0 +1,193 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
/// @brief The MultiABD GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref GemmKernelMultiABD "GemmKernelMultiABD" when creating
/// kernel arguments object. It contains all necessary information required to build a proper
/// kernel argument and launch the kernel on the GPU. This structure defines the GEMM problem
/// configuration by stating all required information like M,N,K sizes and respective strides.
/// NumATensor describes the number of A tensors. The minimum number of tensors is 1(required).
/// NumBTensor describes the number of B tensors. The minimum number of tensors is 1(required).
/// NumDTensor describes the number of D tensors. The minimum number of tensors is 0(not
/// required).
template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
struct GemmMultiABDHostArgs
{
    /// @brief Bundle all device pointers, problem sizes and strides for a multi-A/B/D GEMM.
    ///
    /// @note k_batch_ is the 5th constructor parameter but is declared as the last member;
    ///       the member-initializer list below follows member declaration order, so every
    ///       member is initialized before use regardless of parameter position.
    CK_TILE_HOST GemmMultiABDHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
                                      const std::array<const void*, NumBTensor>& bs_ptr_,
                                      const std::array<const void*, NumDTensor>& ds_ptr_,
                                      void* e_ptr_,
                                      index_t k_batch_,
                                      index_t M_,
                                      index_t N_,
                                      index_t K_,
                                      const std::array<index_t, NumATensor>& stride_As_,
                                      const std::array<index_t, NumBTensor>& stride_Bs_,
                                      const std::array<index_t, NumDTensor>& stride_Ds_,
                                      index_t stride_E_)
        : as_ptr(as_ptr_),
          bs_ptr(bs_ptr_),
          ds_ptr(ds_ptr_),
          e_ptr(e_ptr_),
          M(M_),
          N(N_),
          K(K_),
          stride_As(stride_As_),
          stride_Bs(stride_Bs_),
          stride_Ds(stride_Ds_),
          stride_E(stride_E_),
          k_batch(k_batch_)
    {
    }

    const std::array<const void*, NumATensor> as_ptr; ///< Pointers to the A input tensors.
    const std::array<const void*, NumBTensor> bs_ptr; ///< Pointers to the B input tensors.
    const std::array<const void*, NumDTensor> ds_ptr; ///< Pointers to the D auxiliary tensors.
    /// The output pointer is accessible under both the E and the C naming convention;
    /// both union members alias the same storage.
    union
    {
        void* e_ptr;
        void* c_ptr;
    };
    index_t M; ///< Number of rows of the output.
    index_t N; ///< Number of columns of the output.
    index_t K; ///< Reduction (inner) dimension.
    const std::array<index_t, NumATensor> stride_As; ///< Leading-dimension stride per A tensor.
    const std::array<index_t, NumBTensor> stride_Bs; ///< Leading-dimension stride per B tensor.
    const std::array<index_t, NumDTensor> stride_Ds; ///< Leading-dimension stride per D tensor.
    /// Output stride, aliased under both the E and the C naming convention.
    union
    {
        index_t stride_E;
        index_t stride_C;
    };
    index_t k_batch; ///< Number of K-dimension splits (split-K factor).
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernelMultiABD
{
    /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
    /// functions.
    using UniversalGemmKernel =
        UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;

    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    /// @brief Specify the layout configurations for A, B, C/E and D.
    using AsLayout = remove_cvref_t<typename GemmPipeline::AsLayout>;
    using BsLayout = remove_cvref_t<typename GemmPipeline::BsLayout>;
    using CLayout  = remove_cvref_t<typename GemmPipeline::CLayout>;
    using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;

    /// @brief Specify the data type configurations for A, B, E and D.
    using AsDataType = remove_cvref_t<typename GemmPipeline::AsDataType>;
    using BsDataType = remove_cvref_t<typename GemmPipeline::BsDataType>;
    using EDataType  = remove_cvref_t<typename EpiloguePipeline::ODataType>;
    using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;

    /// @brief AsLayout and AsDataType are expected to be tuples, not scalars.
    static_assert(is_detected<is_tuple, AsLayout>::value &&
                      is_detected<is_tuple, AsDataType>::value,
                  "AsLayout and AsDataType must be tuples.");

    /// @brief BsLayout and BsDataType are expected to be tuples, not scalars.
    static_assert(is_detected<is_tuple, BsLayout>::value &&
                      is_detected<is_tuple, BsDataType>::value,
                  "BsLayout and BsDataType must be tuples.");

    /// @brief CLayout and EDataType are expected to be scalars, not a tuple.
    static_assert(!is_detected<is_tuple, CLayout>::value &&
                      !is_detected<is_tuple, EDataType>::value,
                  "CLayout and EDataType must be scalars.");

    /// @brief DsLayout and DsDataType are expected to be tuples, not scalars.
    /// NOTE(review): this assert also requires at least one D tensor (size() > 0), while the
    /// host-args documentation above states NumDTensor may be 0 (not required) — confirm
    /// which contract is intended.
    static_assert(is_detected<is_tuple, DsLayout>::value &&
                      is_detected<is_tuple, DsDataType>::value &&
                      DsLayout::size() == DsDataType::size() && DsLayout::size() > 0,
                  "DsLayout and DsDataType must be tuples and must have the same size.");

    /// @brief The sizes of NumATensor, NumBTensor and NumDTensor are set by the user.
    static constexpr index_t NumATensor = AsDataType::size();
    static constexpr index_t NumBTensor = BsDataType::size();
    static constexpr index_t NumDTensor = DsDataType::size();

    /// @brief Forward the kernel name from the underlying universal GEMM kernel.
    CK_TILE_HOST static auto GetName() -> const std::string
    {
        return UniversalGemmKernel::GetName();
    }

    /// @brief Compute the launch grid size for the given problem dimensions.
    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
    {
        return UniversalGemmKernel::GridSize(M, N, KBatch);
    }

    /// @brief Query the maximum grid size achievable at full occupancy on the given stream.
    CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
    {
        return UniversalGemmKernel::MaxOccupancyGridSize(s);
    }

    /// @brief Forward the thread-block size from the underlying universal GEMM kernel.
    CK_TILE_HOST static constexpr auto BlockSize() -> dim3
    {
        return UniversalGemmKernel::BlockSize();
    }

    /// @brief Translate the multi-ABD host arguments into universal GEMM kernel arguments.
    CK_TILE_HOST static constexpr auto
    MakeKernelArgs(const GemmMultiABDHostArgs<NumATensor, NumBTensor, NumDTensor>& hostArgs) ->
        typename UniversalGemmKernel::KernelArgs
    {
        /// @brief Universal GEMM requires array objects and corresponding stride information for
        /// matrices A, B, and D.
        return UniversalGemmKernel::MakeKernelArgs(
            UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>(hostArgs.as_ptr,
                                                                      hostArgs.bs_ptr,
                                                                      hostArgs.ds_ptr,
                                                                      hostArgs.e_ptr,
                                                                      hostArgs.k_batch,
                                                                      hostArgs.M,
                                                                      hostArgs.N,
                                                                      hostArgs.K,
                                                                      hostArgs.stride_As,
                                                                      hostArgs.stride_Bs,
                                                                      hostArgs.stride_Ds,
                                                                      hostArgs.stride_E));
    }

    /// @brief Check whether the given kernel arguments can be executed by this kernel.
    /// @return false for split-K configurations; otherwise defers to the universal kernel.
    CK_TILE_HOST static auto
    IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
    {
        // Currently MultiABD kernel doesn't support k_batch > 1
        if(kargs.k_batch > 1)
        {
            return false;
        }
        return UniversalGemmKernel::IsSupportedArgument(kargs);
    }

    /// @brief Device entry point — delegates the whole GEMM to the universal kernel.
    CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
    {
        UniversalGemmKernel{}.template operator()(kargs);
    }
};
} // namespace ck_tile

View File

@@ -95,7 +95,7 @@ struct GemmKernelMultiD
/// @brief Specify the layout configurations for A, B, E and D
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
/// @brief Specify the data type configurations for A, B, E and D
@@ -114,10 +114,10 @@ struct GemmKernelMultiD
!is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars.");
/// @brief ELayout and EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, ELayout>::value &&
/// @brief CLayout and EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, EDataType>::value,
"ELayout and EDataType must be scalars.");
"CLayout and EDataType must be scalars.");
/// @brief DsLayout and DsDataType are expected to be tuple, not a scalar.
static_assert(is_detected<is_tuple, DsLayout>::value &&

View File

@@ -120,10 +120,10 @@ struct GroupedGemmKernel
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, CDataType>::value,
"C/ELayout and C/EDataType must be scalars.");
"C/CLayout and C/EDataType must be scalars.");
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
using Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
@@ -364,12 +364,8 @@ struct GroupedGemmKernel
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template

View File

@@ -157,23 +157,23 @@ struct UniversalGemmKernel
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
static constexpr bool ADataTypeIsTuple =
is_detected<is_tuple, typename GemmPipeline::ADataType>::value;
is_detected<is_tuple, typename GemmPipeline::AsDataType>::value;
static constexpr bool BDataTypeIsTuple =
is_detected<is_tuple, typename GemmPipeline::BDataType>::value;
is_detected<is_tuple, typename GemmPipeline::BsDataType>::value;
static constexpr bool DDataTypeIsTuple =
is_detected<is_tuple, typename EpiloguePipeline::DsDataType>::value;
static constexpr bool ALayoutIsTuple =
is_detected<is_tuple, typename GemmPipeline::ALayout>::value;
is_detected<is_tuple, typename GemmPipeline::AsLayout>::value;
static constexpr bool BLayoutIsTuple =
is_detected<is_tuple, typename GemmPipeline::BLayout>::value;
is_detected<is_tuple, typename GemmPipeline::BsLayout>::value;
static constexpr bool DLayoutIsTuple =
is_detected<is_tuple, typename EpiloguePipeline::DsLayout>::value;
using AsLayout = std::conditional_t<ALayoutIsTuple,
remove_cvref_t<typename GemmPipeline::ALayout>,
remove_cvref_t<typename GemmPipeline::AsLayout>,
remove_cvref_t<tuple<typename GemmPipeline::ALayout>>>;
using BsLayout = std::conditional_t<BLayoutIsTuple,
remove_cvref_t<typename GemmPipeline::BLayout>,
remove_cvref_t<typename GemmPipeline::BsLayout>,
remove_cvref_t<tuple<typename GemmPipeline::BLayout>>>;
using DsLayout = std::conditional_t<DLayoutIsTuple,
@@ -181,11 +181,11 @@ struct UniversalGemmKernel
remove_cvref_t<tuple<typename EpiloguePipeline::DsLayout>>>;
using AsDataType = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<typename GemmPipeline::ADataType>,
remove_cvref_t<typename GemmPipeline::AsDataType>,
remove_cvref_t<tuple<typename GemmPipeline::ADataType>>>;
using BsDataType = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<typename GemmPipeline::BDataType>,
remove_cvref_t<typename GemmPipeline::BsDataType>,
remove_cvref_t<tuple<typename GemmPipeline::BDataType>>>;
using DsDataType =
@@ -193,9 +193,12 @@ struct UniversalGemmKernel
remove_cvref_t<typename EpiloguePipeline::DsDataType>,
remove_cvref_t<tuple<typename EpiloguePipeline::DsDataType>>>;
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
// Get the persistent kernel if the pipeline has it available
@@ -483,7 +486,7 @@ struct UniversalGemmKernel
bool DTesnorIsValid = {true};
static_for<0, NumDTensor, 1>{}([&](auto index) {
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
if(std::is_same_v<DiLayout, ELayout> == false)
if(std::is_same_v<DiLayout, CLayout> == false)
{
DTesnorIsValid = false;
}
@@ -529,7 +532,7 @@ struct UniversalGemmKernel
}
});
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
@@ -724,7 +727,7 @@ struct UniversalGemmKernel
// TODO: enable vector write for C in ColMajor
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
@@ -818,7 +821,7 @@ struct UniversalGemmKernel
// TODO vector write in for C in ColMajor
const auto& e_pad_view = [&]() {
const auto& e_tensor_view = views.at(I3);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
@@ -975,8 +978,8 @@ struct UniversalGemmKernel
const auto& bs_block_window = gemm_tile_windows.at(I1);
const auto& ds_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile =
GemmPipeline{}(as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
const auto& c_block_tile = GemmPipeline{}.template operator()(
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
if(UseDefaultScheduler || (get_warp_id() == 0))
{
@@ -1031,8 +1034,13 @@ struct UniversalGemmKernel
const auto& bs_block_window = gemm_tile_windows.at(I1);
const auto& ds_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}(
as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1);
const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window,
AElementWise{},
bs_block_window,
BElementWise{},
num_loop,
smem_ptr_0,
smem_ptr_1);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);