Extend Grouped GEMM with MultiD (Single & Double Shared Memory) feature to use persistent kernel option (#2933)

* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature

* refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel

* tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments

* fix: segfault fix by passing correct parameters for d tensors

* style: clang format

* WIP: host code for grouped_gemm_multi_d persistent kernel compiles but segfaults

* feat(grouped_gemm_multi_d): add functionality to run persistent kernel

* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature

* refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel

* tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments

* fix: segfault fix by passing correct parameters for d tensors

* style: clang format

* fix: incorrect validation method and Dtensor layout in test suite

* docs: improved README text based on review comments

* fix: parameterize NumDTensor in GroupedGemmHostArgs and remove lint

[ROCm/composable_kernel commit: bebf0e9d15]
This commit is contained in:
Aviral Goel
2025-09-29 18:03:56 -04:00
committed by GitHub
parent 7a9bff148f
commit 7775768c88
5 changed files with 163 additions and 19 deletions

View File

@@ -10,16 +10,15 @@ The grouped GEMM examples include two advanced optimization features:
Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches.
- **Implementation**: Available in `grouped_gemm_preshuffle.cpp`
- **Configuration**: Uses `GemmConfigPreshuffleDecode` template configuration
- **Configuration**: Uses the `GemmConfigPreshuffleDecode` and `GemmConfigPreshufflePrefill` template configurations
- **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts
- **Benefits**: Improved memory efficiency and reduced data movement
#### Persistence Mode
Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy.
- **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm`
- **Usage**: `invoke_gemm<ALayout, BLayout, CLayout, true>` enables persistence
- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes
#### Multi-D Operations
Multi-D operations extend the standard GEMM operation by supporting additional element-wise operations on the result tensor. This feature is particularly useful for workloads that require post-processing of the GEMM output.
@@ -31,7 +30,8 @@ Multi-D operations extend the standard GEMM operation by supporting additional e
- **Benefits**: Enables complex operations like scaling, activation functions, or other element-wise transformations in a single kernel call
- **Build Target**: `make tile_example_grouped_gemm_multi_d -j`
Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.
Multi-D operations support both persistent and non-persistent modes.
Weight preshuffle is supported only in non-persistent mode.
## Build
```
@@ -48,7 +48,7 @@ make tile_example_grouped_gemm_multi_d -j
# The quant grouped gemm fp8 example
make tile_example_quant_grouped_gemm -j
```
This will result in an executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`.
Each example will result in a corresponding executable: `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`.
## example

View File

@@ -166,6 +166,112 @@ float grouped_gemm_multi_d(const std::vector<grouped_gemm_multi_d_kargs>& gemm_d
return ave_time;
}
/// @brief Launches the persistent ("tile-loop") grouped GEMM kernel with fused
/// multi-D element-wise epilogue operations.
///
/// The per-group kernel arguments are expected to already reside in device
/// memory at @p kargs_ptr; this function only instantiates the pipeline /
/// epilogue / kernel types from @p GemmConfig and launches the kernel.
///
/// @param s          Stream configuration (controls logging via log_level_ and
///                   is forwarded to ck_tile::launch_kernel for timing).
/// @param num_groups Number of GEMM groups described by the kargs buffer.
/// @param kargs_ptr  Device pointer to the packed per-group kernel arguments.
/// @param splitk     When true, K is split across workgroups, so the epilogue
///                   must combine partial results with atomic_add instead of set.
/// @return Average kernel time as reported by ck_tile::launch_kernel.
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename CDEElementWise>
float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk)
{
// Block/warp/warp-tile shape taken entirely from the GemmConfig policy type.
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
// Persistent-kernel traits: padding flags plus the single/double shared-memory
// buffering choice come from GemmConfig.
using GemmUniversalTraits =
ck_tile::PersistentTileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout>;
float ave_time{0};
// Run is parameterized on the epilogue memory operation (set vs atomic_add)
// because that choice is a compile-time template argument of the epilogue.
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
// CShuffle epilogue fuses the element-wise CDEElementWise op over the
// additional D tensors into the store of the E (output) tensor.
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
// Persistent kernel: grid size is bounded by occupancy, not by problem size;
// each workgroup loops over multiple tiles.
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
// kargs_ptr is cast to the constant address space so the kernel reads the
// per-group argument array through constant memory.
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
return ave_time;
};
// Non-split-k writes results directly; split-k partial results must be
// accumulated atomically across the K batches.
if(!splitk)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
return ave_time;
}
#include "run_grouped_gemm_multi_d_example.inc"
int main(int argc, char* argv[])

View File

@@ -95,6 +95,7 @@ struct GemmConfigV3 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool Persistent = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
@@ -170,7 +171,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<2>;
using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<DsDataType::size()>;
std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
{
@@ -201,7 +202,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_multi_d_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<2>);
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<DsDataType::size()>);
}
template <typename GemmConfig,

View File

@@ -86,9 +86,43 @@ float invoke_gemm(int n_warmup,
}
else
{
(void)group_count;
// not supported yet
throw std::runtime_error("Persistent grouped gemm multiple-d is not supported yet");
std::vector<ck_tile::GemmTransKernelArg<DsDataType::size()>> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
const bool splitk = args[0].k_batch > 1;
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, 2>{{arg.a_ptr},
{arg.b_ptr},
arg.ds_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
{arg.stride_A},
{arg.stride_B},
arg.stride_Ds,
arg.stride_E,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(
kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<DsDataType::size()>),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time =
grouped_gemm_multi_d_tileloop<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
CDEElementWise>(stream, group_count, kargs_ptr, splitk);
}
return ave_time;
}
@@ -322,12 +356,6 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
b_k_n_tensors[i],
{d0_m_n_tensors[i], d1_m_n_tensors[i]},
e_m_n_host_refs[i]);
std::cout << "e_m_n_host_refs[i]: " << std::endl;
e_m_n_host_refs[i].print_first_n(std::cout, 10);
std::cout << std::endl;
std::cout << "e_m_n_tensors[i]: " << std::endl;
e_m_n_tensors[i].print_first_n(std::cout, 10);
std::cout << std::endl;
const float max_accumulated_value =
*std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end());

View File

@@ -324,10 +324,18 @@ struct GroupedGemmKernel
}
else // SingleSmemBuffer
{
if constexpr(UsePersistentKernel)
{
RunGemmWithPipelineSelection(
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
RunGemmWithPipelineSelection(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else // Non-persistent kernel
{
@@ -365,6 +373,7 @@ struct GroupedGemmKernel
CK_TILE_DEVICE static void
RunGemmWithPipelineSelection(const ADataType* a_ptr,
const BDataType* b_ptr,
const std::array<const void*, NumDTensor_>& ds_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
@@ -375,7 +384,7 @@ struct GroupedGemmKernel
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k);
{a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =