Persistent grouped gemm CompV4 Enablement & Polish (#2605)

* enable the persistent kernel for CompV4

* polish the example and clang format

* fix the non-persistent kernel error

---------

Co-authored-by: ThomasNing <thomasning@amd.com>
This commit is contained in:
Thomas Ning
2025-08-04 23:43:01 -07:00
committed by GitHub
parent 2a78da4708
commit cbfecf8d7a
6 changed files with 148 additions and 289 deletions

View File

@@ -252,13 +252,6 @@ struct GroupedGemmKernel
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs,
const tuple<index_t, index_t>& block_idx_2d,
const index_t block_idx_z) const
{
Run(kargs.group_karg, block_idx_2d, block_idx_z);
}
CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<>& kargs,
const tuple<index_t, index_t>& block_idx_2d,
const index_t block_idx_z) const
@@ -277,24 +270,56 @@ struct GroupedGemmKernel
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
__shared__ char smem_ptr_0[GetSmemSize()];
if constexpr(UsePersistentKernel)
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
RunGemmWithPipelineSelection(
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(UsePersistentKernel)
{
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
Base::RunGemm2LDS({a_ptr},
{b_ptr},
{/*ds_ptr*/},
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
else
{
Base::RunGemm({a_ptr},
{b_ptr},
{/*ds_ptr*/},
c_ptr,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
i_n);
if constexpr(UsePersistentKernel)
{
RunGemmWithPipelineSelection(
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
else
{
Base::RunGemm({a_ptr},
{b_ptr},
{/*ds_ptr*/},
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
}
@@ -358,6 +383,69 @@ struct GroupedGemmKernel
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
/**
* @brief Runs a single GEMM problem cooperatively by the whole workgroup,
* using two LDS buffers (double shared-memory buffering).
*
* @note The GEMM pipeline is selected in-kernel based on the number of K-loops
* and the tail-number. This is needed for the persistent tile-loop when
* we didn't have access to the K dimension on the host.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the first shared memory block.
* @param smem_ptr_1 The start memory pointer of the second shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate the k
* batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
CK_TILE_DEVICE static void
RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const UniversalGemmKernelArgs<>& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Tile-window tuple layout: I0 = A, I1 = B, I2 = D (epilogue extra input), I3 = C.
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
const auto& b_block_window = gemm_tile_windows.at(Base::I1);
const auto& d_block_window = gemm_tile_windows.at(Base::I2);
// Get hot-loop and tail configuration.
// readfirstlane broadcasts lane 0's value so the loop count is wavefront-uniform
// (kept in a scalar register rather than per-lane vector registers).
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline: in-kernel dispatch on has_hot_loop/tail_num (see @note above);
// both LDS buffers are handed to the pipeline for double buffering.
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
// Run Epilogue Pipeline.
// Only smem_ptr_0 is passed: by this point the pipeline is done with it,
// so the epilogue reuses the first LDS buffer as scratch.
auto& c_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr,
index_t block_id,
index_t group_count) const
@@ -401,7 +489,7 @@ struct GroupedGemmKernel
kargs.group_karg.M,
kargs.group_karg.N,
(block_id - kargs.block_start) % grid_size_2d);
Run(kargs, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d);
Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d);
}
// For persistent kernels