Added Support for tile_grouped_gemm_preshuffle example (#2993)

* Added Support for tile_grouped_gemm_preshuffle example

* Resolved PR comments + Added unit tests for preshuffle with persistent

* Fixed CMake Build config error

* Fix clang error that caused CI to fail

* Fix clang formatting

* Fix clang issue

* Fix errors causing test cases to fail

* Fix grouped_gemm_preshuffle unit test failure

* Resolve PR comments

* Cleaned code + removed unnecassary changes

* Update test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp

Co-authored-by: Aviral Goel <aviral.goel@amd.com>

* Fix clang formatting

* Made changes to improve code readability

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: Aviral Goel <aviral.goel@amd.com>

[ROCm/composable_kernel commit: a46b725992]
This commit is contained in:
mkumar16-amd
2025-10-28 00:01:19 +05:30
committed by GitHub
parent e1e96b89fa
commit 5fa81082e4
5 changed files with 278 additions and 33 deletions

View File

@@ -182,6 +182,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool Persistent = true;
static constexpr bool DoubleSmemBuffer = true;
};

View File

@@ -167,6 +167,113 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
template <typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
GemmConfig::Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
float ave_time{0};
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType (empty for no D tensors)
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout (empty for no D tensors)
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
return ave_time;
};
if(splitk)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
return ave_time;
}
#include "run_grouped_gemm_example.inc"
template <typename GemmConfig, typename PrecType>

View File

@@ -70,23 +70,13 @@ float invoke_gemm(int n_warmup,
}
else
{
if(GemmConfig::Preshuffle)
{
// not supported yet
throw std::runtime_error(
"Persistent grouped gemm with preshuffle is not supported yet");
}
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to haveCollapse
// commentComment on line L74tenpercent commented on Sep 5, 2025 tenpercenton Sep 5,
// 2025ContributorMore actionsdid you intend to remove the comment?Write a replyResolve
// commentCode has comments. Press enter to view. the gemm problems known on the host.
// Instead, we can just pass the pointer to the kernel and let the workgroups figure out
// which tiles to work on. This is useful when the gemm problems are generated dynamically.
// In this example however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have the gemm
// problems known on the host. Instead, we can just pass the pointer to the kernel and let
// the workgroups figure out which tiles to work on. This is useful when the gemm problems
// are generated dynamically. In this example however, we generate the `kargs` using the
// known gemm_descs, and copy the gemm descriptions to the device memory. The contents of
// the memory pointed to by `kargs_ptr` pointer could be written by e.g. another kernel from
// earlier stage.
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();