Added Support for tile_grouped_gemm_preshuffle example (#2993)

* Added Support for tile_grouped_gemm_preshuffle example

* Resolved PR comments + Added unit tests for preshuffle with persistent

* Fixed CMake Build config error

* Fix clang error that caused CI to fail

* Fix clang formatting

* Fix clang issue

* Fix errors causing test cases to fail

* Fix grouped_gemm_preshuffle unit test failure

* Resolve PR comments

* Cleaned code + removed unnecassary changes

* Update test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp

Co-authored-by: Aviral Goel <aviral.goel@amd.com>

* Fix clang formatting

* Made changes to improve code readability

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: Aviral Goel <aviral.goel@amd.com>
This commit is contained in:
mkumar16-amd
2025-10-28 00:01:19 +05:30
committed by GitHub
parent 6c2ca1211a
commit a46b725992
5 changed files with 278 additions and 33 deletions

View File

@@ -182,6 +182,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool Persistent = true;
static constexpr bool DoubleSmemBuffer = true;
};

View File

@@ -167,6 +167,113 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
template <typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
GemmConfig::Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
float ave_time{0};
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType (empty for no D tensors)
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout (empty for no D tensors)
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
return ave_time;
};
if(splitk)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
return ave_time;
}
#include "run_grouped_gemm_example.inc"
template <typename GemmConfig, typename PrecType>

View File

@@ -70,23 +70,13 @@ float invoke_gemm(int n_warmup,
}
else
{
if(GemmConfig::Preshuffle)
{
// not supported yet
throw std::runtime_error(
"Persistent grouped gemm with preshuffle is not supported yet");
}
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to haveCollapse
// commentComment on line L74tenpercent commented on Sep 5, 2025 tenpercenton Sep 5,
// 2025ContributorMore actionsdid you intend to remove the comment?Write a replyResolve
// commentCode has comments. Press enter to view. the gemm problems known on the host.
// Instead, we can just pass the pointer to the kernel and let the workgroups figure out
// which tiles to work on. This is useful when the gemm problems are generated dynamically.
// In this example however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have the gemm
// problems known on the host. Instead, we can just pass the pointer to the kernel and let
// the workgroups figure out which tiles to work on. This is useful when the gemm problems
// are generated dynamically. In this example however, we generate the `kargs` using the
// known gemm_descs, and copy the gemm descriptions to the device memory. The contents of
// the memory pointed to by `kargs_ptr` pointer could be written by e.g. another kernel from
// earlier stage.
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();

View File

@@ -8,11 +8,13 @@
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_preshuffle_util.hpp"
using F16 = ck_tile::half_t;
using F8 = ck_tile::fp8_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using F16 = ck_tile::half_t;
using F8 = ck_tile::fp8_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using False = std::false_type;
using True = std::true_type;
// Custom tuple-like structure for kernel configuration
template <typename ALayout_,
@@ -22,6 +24,7 @@ template <typename ALayout_,
typename BDataType_,
typename AccDataType_,
typename CDataType_,
typename Persistent_,
int M_Tile_val_,
int N_Tile_val_,
int K_Tile_val_,
@@ -35,6 +38,7 @@ struct KernelConfig
using BDataType = BDataType_;
using AccDataType = AccDataType_;
using CDataType = CDataType_;
using Persistent = Persistent_;
static constexpr int M_Tile_ = M_Tile_val_;
static constexpr int N_Tile_ = N_Tile_val_;
@@ -44,11 +48,16 @@ struct KernelConfig
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_Tile, N_Tile, K_Tile, BlockPerCu
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 128, 128, 2>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, 128, 128, 128, 2>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent ,M_Tile, N_Tile, K_Tile, BlockPerCu
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 128, 128, 128, 2>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2>
>;
// clang-format on

View File

@@ -39,9 +39,13 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
using BDataType = typename Tuple::BDataType;
using AccDataType = typename Tuple::AccDataType;
using CDataType = typename Tuple::CDataType;
using PrecType = BDataType;
using DsLayout = ck_tile::tuple<>; // not used
using DsDataType = ck_tile::tuple<>; // not used
using DsLayout = ck_tile::tuple<>; // not used
using DsDataType = ck_tile::tuple<>; // not used
// Get the persistent value from ck_tile::bool_constant
using PersistentType = typename Tuple::Persistent;
static constexpr bool Persistent = PersistentType::value;
static const bool kPadM = false;
static const bool kPadN = false;
@@ -231,6 +235,129 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
private:
template <typename ALayout, typename BLayout, typename CLayout>
void invoke_grouped_gemm_persistent(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
{
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
// Enable persistent mode for preshuffle
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits</*kPadM*/ true,
/*kPadN*/ true,
/*kPadK*/ true,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC,
/*UseStructuredSparsity*/ false,
/*Persistent*/ true, // Enable persistent mode
/*NumWaveGroups*/ 1,
/*Preshuffle*/ true>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
TileParitionerGroupNum,
TileParitionerM01>::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
ck_tile::GemmPipelineScheduler::Default,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline =
ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(gemm_descs[0].k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
// EXPECT TO FAIL because splitk is not supported
EXPECT_FALSE(true);
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
public:
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
@@ -350,9 +477,20 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
invoke_grouped_gemm<ALayout, BLayout, CLayout>(gemm_descs,
ck_tile::stream_config{nullptr, false, 1},
gemm_workspace.GetDeviceBuffer());
if constexpr(Persistent)
{
invoke_grouped_gemm_persistent<ALayout, BLayout, CLayout>(
gemm_descs,
ck_tile::stream_config{nullptr, false, 1},
gemm_workspace.GetDeviceBuffer());
}
else
{
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
gemm_descs,
ck_tile::stream_config{nullptr, false, 1},
gemm_workspace.GetDeviceBuffer());
}
// Copy results back to host for validation
for(int i = 0; i < group_count; i++)