[CK_TILE] Grouped GEMM tile loop (#2146)

* Add trait to use a persistent kernel and split the entrypoints in grouped gemm

* Some helper functions for persistent kernel case

* Get max occupancy grid using device properties

* Implement tile loop in main entry point to grouped gemm

* Enable GridSize() on device

* Handle offset tile index using real current block index

* Add persistent kernel choice to grouped gemm example

* Use a for-loop for iterating over the group

* Reduce VGPR spills by early-exit

* Enable persistent kernel choice in grouped_gemm example

* Add persistent kernel option to grouped_gemm test

* Fix formatting with remod.py

* Remove GridUpdateBlocks as blocks are now iteratively computed

* Add comment about VGPR spilling

* Fix formatting

* Use CK_TILE_HOST instead of __host__

* Enable all Row/Col combinations in grouped gemm unit test

* Add some KBatch=2 cases to grouped gemm tests

* Fix SplitK for grouped gemm

* Enable pipeline hotloop/tailnumber selection in-kernel for grouped gemm

* Add type traits

* Split examples to regular and tileloop

* Formatting

* Use hipExtStreamGetCUMask to get current active CUs for the given stream

* Align test and example kernel config, and disable validation for splitk repeats

* Remove debug options from CMakeLists.txt

* Separate the code paths for persistent/non-persistent in test

* Fix formatting

* Address review comments

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

[ROCm/composable_kernel commit: d1e6f0982d]
This commit is contained in:
Sami Remes
2025-05-20 17:18:57 +03:00
committed by GitHub
parent bb8bb3d5c1
commit 038417bd3d
15 changed files with 908 additions and 146 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
@@ -8,19 +8,27 @@
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_util.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using F16 = ck_tile::half_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using True = ck_tile::bool_constant<true>;
using False = ck_tile::bool_constant<false>;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
// std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent
std::tuple< Row, Col, Row, F16, F16, F32, F16, True>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, False>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, True>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, False>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, True>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, False>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, True>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, False>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, True>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, False>
>;
// clang-format on

View File

@@ -3,6 +3,7 @@
TYPED_TEST(TestCkTileGroupedGemm, Basic)
{
const int group_count = 8;
const int kbatch = 1;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;
@@ -14,12 +15,37 @@ TYPED_TEST(TestCkTileGroupedGemm, Basic)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(256 + 64 * i);
Ks.push_back(512 + 128 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]);
}
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, group_count);
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
}
TYPED_TEST(TestCkTileGroupedGemm, SplitK)
{
const int group_count = 8;
const int kbatch = 2;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;
std::vector<int> stride_As;
std::vector<int> stride_Bs;
std::vector<int> stride_Cs;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(512 + 128 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]);
}
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
}

View File

@@ -24,6 +24,10 @@ class TestCkTileGroupedGemm : public ::testing::Test
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
// Get the persistent value from ck_tile::bool_constant
using PersistentType = std::tuple_element_t<7, Tuple>;
static constexpr bool Persistent = PersistentType::value;
struct GroupedGemKernelParam
{
static const bool kPadM = false;
@@ -31,9 +35,9 @@ class TestCkTileGroupedGemm : public ::testing::Test
static const bool kPadK = false;
static const int kBlockPerCu = 1;
static const ck_tile::index_t M_Tile = 128;
static const ck_tile::index_t N_Tile = 128;
static const ck_tile::index_t K_Tile = 32;
static const ck_tile::index_t M_Tile = 256;
static const ck_tile::index_t N_Tile = 256;
static const ck_tile::index_t K_Tile = 64;
static const ck_tile::index_t M_Warp = 2;
static const ck_tile::index_t N_Warp = 2;
@@ -41,7 +45,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
static const ck_tile::index_t M_Warp_Tile = 32;
static const ck_tile::index_t N_Warp_Tile = 32;
static const ck_tile::index_t K_Warp_Tile = 8;
static const ck_tile::index_t K_Warp_Tile = 16;
};
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
@@ -53,7 +57,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
template <typename ALayout, typename BLayout, typename CLayout>
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_)
void* kargs_ptr)
{
constexpr bool DoubleSmemBuffer = false;
constexpr bool TransposeC = false;
@@ -138,11 +142,12 @@ class TestCkTileGroupedGemm : public ::testing::Test
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
const dim3 grids = Kernel::GridSize(gemm_descs);
constexpr dim3 blocks = Kernel::BlockSize();
ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_,
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
@@ -163,7 +168,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
return ave_time;
};
@@ -171,6 +176,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(gemm_descs[0].k_batch == 1)
{
std::cout << "Run without SplitK" << std::endl;
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
@@ -178,6 +184,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
}
else
{
std::cout << "Run using SplitK" << std::endl;
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
@@ -213,6 +220,135 @@ class TestCkTileGroupedGemm : public ::testing::Test
}
}
template <typename ALayout, typename BLayout, typename CLayout>
void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk)
{
constexpr bool TransposeC = false;
constexpr bool DoubleSmemBuffer = false;
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
GroupedGemKernelParam::N_Tile,
GroupedGemKernelParam::K_Tile>,
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::K_Warp>,
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
ALayout,
BLayout,
CLayout>;
using GemmUniversalTraits =
ck_tile::PersistentTileGemmUniversalTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
constexpr auto memory_operation = memory_operation_.value;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
};
if(splitk)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
}
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
public:
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
@@ -220,6 +356,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
std::vector<int>& stride_As,
std::vector<int>& stride_Bs,
std::vector<int>& stride_Cs,
const int kbatch = 1,
const int group_count = 16)
{
using namespace ck_tile::literals;
@@ -294,10 +431,10 @@ class TestCkTileGroupedGemm : public ::testing::Test
std::cout << "gemm[" << i << "]"
<< " a_m_k: " << a_m_k_tensors[i].mDesc
<< " b_k_n: " << b_k_n_tensors[i].mDesc
<< " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
<< " c_m_n: " << c_m_n_tensors[i].mDesc << " KBatch: " << kbatch << std::endl;
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_tensors[i]);
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
a_m_k_tensors[i].get_element_space_size_in_bytes()));
@@ -315,18 +452,51 @@ class TestCkTileGroupedGemm : public ::testing::Test
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
// TODO add support for kbatch > 1
static constexpr ck_tile::index_t k_batch = 1;
gemm_descs.push_back(
{p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
}
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
gemm_descs, ck_tile::stream_config{nullptr, false}, gemm_workspace.GetDeviceBuffer());
if constexpr(Persistent)
{
// Generate kernel arguments
std::vector<ck_tile::GemmTransKernelArg> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
const bool splitk = gemm_descs[0].k_batch > 1;
for(const auto& arg : gemm_descs)
{
kargs.emplace_back(ck_tile::GemmKernelArgs{arg.a_ptr,
arg.b_ptr,
arg.c_ptr,
arg.M,
arg.N,
arg.K,
arg.stride_A,
arg.stride_B,
arg.stride_C,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, false, 1};
ck_tile::hip_check_error(
hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
invoke_grouped_gemm_persistent<ALayout, BLayout, CLayout>(
stream, group_count, kargs_ptr, splitk);
}
else
{
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
gemm_descs,
ck_tile::stream_config{nullptr, false, 1},
gemm_workspace.GetDeviceBuffer());
}
// Copy results back to host for validation
for(int i = 0; i < group_count; i++)
{
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
@@ -340,7 +510,14 @@ class TestCkTileGroupedGemm : public ::testing::Test
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(Ks[i], kbatch, max_accumulated_value);
pass &= ck_tile::check_err(c_m_n_tensors[i],
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
}
EXPECT_TRUE(pass);
}