mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
[CK_TILE] Grouped GEMM tile loop (#2146)
* Add trait to use a persistent kernel and split the entrypoints in grouped gemm
* Some helper functions for persistent kernel case
* Get max occupancy grid using device properties
* Implement tile loop in main entry point to grouped gemm
* Enable GridSize() on device
* Handle offset tile index using real current block index
* Add persistent kernel choice to grouped gemm example
* Use a for-loop for iterating over the group
* Reduce VGPR spills by early-exit
* Enable persistent kernel choice in grouped_gemm example
* Add persistent kernel option to grouped_gemm test
* Fix formatting with remod.py
* Remove GridUpdateBlocks as blocks are now iteratively computed
* Add comment about VGPR spilling
* Fix formatting
* Use CK_TILE_HOST instead of __host__
* Enable all Row/Col combinations in grouped gemm unit test
* Add some KBatch=2 cases to grouped gemm tests
* Fix SplitK for grouped gemm
* Enable pipeline hotloop/tailnumber selection in-kernel for grouped gemm
* Add type traits
* Split examples to regular and tileloop
* Formatting
* Use hipExtStreamGetCUMask to get current active CUs for the given stream
* Align test and example kernel config, and disable validation for splitk repeats
* Remove debug options from CMakeLists.txt
* Separate the code paths for persistent/non-persistent in test
* Fix formatting
* Address review comments
---------
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: d1e6f0982d]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
@@ -8,19 +8,27 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
|
||||
// std::tuple< Row, Row, Row, F16, F16, F32, F16>,
|
||||
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
|
||||
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, False>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, False>,
|
||||
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, False>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, False>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
TYPED_TEST(TestCkTileGroupedGemm, Basic)
|
||||
{
|
||||
const int group_count = 8;
|
||||
const int kbatch = 1;
|
||||
std::vector<int> Ms;
|
||||
std::vector<int> Ns;
|
||||
std::vector<int> Ks;
|
||||
@@ -14,12 +15,37 @@ TYPED_TEST(TestCkTileGroupedGemm, Basic)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(256 + 64 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
stride_Cs.push_back(Ns[i]);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, group_count);
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGroupedGemm, SplitK)
|
||||
{
|
||||
const int group_count = 8;
|
||||
const int kbatch = 2;
|
||||
std::vector<int> Ms;
|
||||
std::vector<int> Ns;
|
||||
std::vector<int> Ks;
|
||||
std::vector<int> stride_As;
|
||||
std::vector<int> stride_Bs;
|
||||
std::vector<int> stride_Cs;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
stride_Cs.push_back(Ns[i]);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
|
||||
}
|
||||
|
||||
@@ -24,6 +24,10 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
|
||||
// Get the persistent value from ck_tile::bool_constant
|
||||
using PersistentType = std::tuple_element_t<7, Tuple>;
|
||||
static constexpr bool Persistent = PersistentType::value;
|
||||
|
||||
struct GroupedGemKernelParam
|
||||
{
|
||||
static const bool kPadM = false;
|
||||
@@ -31,9 +35,9 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
static const bool kPadK = false;
|
||||
|
||||
static const int kBlockPerCu = 1;
|
||||
static const ck_tile::index_t M_Tile = 128;
|
||||
static const ck_tile::index_t N_Tile = 128;
|
||||
static const ck_tile::index_t K_Tile = 32;
|
||||
static const ck_tile::index_t M_Tile = 256;
|
||||
static const ck_tile::index_t N_Tile = 256;
|
||||
static const ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static const ck_tile::index_t M_Warp = 2;
|
||||
static const ck_tile::index_t N_Warp = 2;
|
||||
@@ -41,7 +45,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 32;
|
||||
static const ck_tile::index_t N_Warp_Tile = 32;
|
||||
static const ck_tile::index_t K_Warp_Tile = 8;
|
||||
static const ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
|
||||
@@ -53,7 +57,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* p_workspace_)
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
constexpr bool TransposeC = false;
|
||||
@@ -138,11 +142,12 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
|
||||
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_,
|
||||
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
@@ -163,7 +168,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
return ave_time;
|
||||
};
|
||||
@@ -171,6 +176,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
{
|
||||
std::cout << "Run without SplitK" << std::endl;
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
@@ -178,6 +184,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Run using SplitK" << std::endl;
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
@@ -213,6 +220,135 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
{
|
||||
constexpr bool TransposeC = false;
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
|
||||
GroupedGemKernelParam::N_Tile,
|
||||
GroupedGemKernelParam::K_Tile>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::K_Warp>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::PersistentTileGemmUniversalTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
};
|
||||
|
||||
if(splitk)
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
}
|
||||
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
public:
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
@@ -220,6 +356,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
std::vector<int>& stride_As,
|
||||
std::vector<int>& stride_Bs,
|
||||
std::vector<int>& stride_Cs,
|
||||
const int kbatch = 1,
|
||||
const int group_count = 16)
|
||||
{
|
||||
using namespace ck_tile::literals;
|
||||
@@ -294,10 +431,10 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
std::cout << "gemm[" << i << "]"
|
||||
<< " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
|
||||
<< " c_m_n: " << c_m_n_tensors[i].mDesc << " KBatch: " << kbatch << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
@@ -315,18 +452,51 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
// TODO add support for kbatch > 1
|
||||
static constexpr ck_tile::index_t k_batch = 1;
|
||||
gemm_descs.push_back(
|
||||
{p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
|
||||
|
||||
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
gemm_descs, ck_tile::stream_config{nullptr, false}, gemm_workspace.GetDeviceBuffer());
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
// Generate kernel arguments
|
||||
std::vector<ck_tile::GemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
const bool splitk = gemm_descs[0].k_batch > 1;
|
||||
for(const auto& arg : gemm_descs)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::GemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.c_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_C,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, false, 1};
|
||||
ck_tile::hip_check_error(
|
||||
hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
invoke_grouped_gemm_persistent<ALayout, BLayout, CLayout>(
|
||||
stream, group_count, kargs_ptr, splitk);
|
||||
}
|
||||
else
|
||||
{
|
||||
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
gemm_descs,
|
||||
ck_tile::stream_config{nullptr, false, 1},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
|
||||
// Copy results back to host for validation
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
|
||||
@@ -340,7 +510,14 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
c_m_n_host_ref.SetZero();
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol(Ks[i], kbatch, max_accumulated_value);
|
||||
pass &= ck_tile::check_err(c_m_n_tensors[i],
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user