[CK TILE] Fix grouped conv kernels splitk and double lds (#3527)

This commit is contained in:
Bartłomiej Kocot
2026-01-08 07:59:38 +01:00
committed by GitHub
parent f449a5faaa
commit bc497beffb
3 changed files with 53 additions and 329 deletions

View File

@@ -1036,84 +1036,16 @@ struct GroupedConvolutionBackwardDataKernel
}
else
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note RunGemm2LDS in with two shared memory buffers using the ping pong buffer mechanism.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param kargs Grouped Convolution Backward Data kernel arguments
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
const InDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
WeiDataType* c_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
const index_t splitted_k,
const index_t block_idx_m,
const index_t block_idx_n,
const index_t block_idx_k,
const index_t group_id)
{
// Create block windows using specialized methods
const auto& a_block_window =
MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k);
const auto& b_block_window =
MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k);
const auto& d_block_window =
MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n);
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
b_block_window,
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
// Run Epilogue Pipeline with k_batch dispatch
if(k_batch == 1)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
}
}
@@ -1195,46 +1127,18 @@ struct GroupedConvolutionBackwardDataKernel
static_cast<InDataType*>(kargs.in_ptr) + group_offset_c + input_batch_offset;
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitted_k,
i_m,
i_n,
i_k,
group_id);
}
}
else
{
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitted_k,
i_m,
i_n,
i_k,
group_id);
}
}
__shared__ char smem_ptr[GetSmemSize()];
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr,
kargs,
splitted_k,
i_m,
i_n,
i_k,
group_id);
}
};

View File

@@ -829,66 +829,14 @@ struct GroupedConvolutionBackwardWeightKernel
}
else
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param kargs Grouped Convolution Backward Weight kernel arguments
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
const InDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
WeiDataType* c_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const index_t num_loop,
const index_t block_idx_m,
const index_t block_idx_n,
const index_t block_idx_k)
{
// Create block windows using helper methods
const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k);
const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k);
const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
// Run Epilogue Pipeline with k_batch dispatching
if(kargs.k_batch == 1)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else
{
#if defined(__gfx11__)
return;
#endif
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
}
}
@@ -949,44 +897,9 @@ struct GroupedConvolutionBackwardWeightKernel
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
__shared__ char smem_ptr_0[GetSmemSize()];
__shared__ char smem_ptr[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
num_loop,
i_m,
i_n,
i_k);
}
}
else
{
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
kargs,
num_loop,
i_m,
i_n,
i_k);
}
}
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
}
}
};

View File

@@ -954,80 +954,16 @@ struct GroupedConvolutionForwardKernel
}
else
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, c_desc, block_idx_m, block_idx_n);
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, c_desc, block_idx_m, block_idx_n);
EpiloguePipeline{elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param ds_ptr input D tensors pointer array
* @param c_ptr output C pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param a_desc Input tensor A descriptor
* @param b_desc Weight tensor B descriptor
* @param c_desc Output tensor C descriptor
* @param gemm_k The GEMM K dimension
* @param k_batch The K batch parameter for split-K
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
template <typename ADescType, typename BDescType, typename CDescType>
CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr,
const WeiDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
OutDataType* c_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const ADescType& a_desc,
const BDescType& b_desc,
const CDescType& c_desc,
const index_t gemm_k,
const index_t k_batch,
const index_t block_idx_m,
const index_t block_idx_n,
const CDElementwise& elfunc)
{
// Create block windows using specialized methods
const auto& a_block_window = MakeABlockWindow(a_ptr, a_desc, block_idx_m);
const auto& b_block_window = MakeBBlockWindow(b_ptr, b_desc, block_idx_n);
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, c_desc, block_idx_m, block_idx_n);
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
// Run Epilogue Pipeline with k_batch dispatching
if(k_batch == 1)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, c_desc, block_idx_m, block_idx_n);
EpiloguePipeline{elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
else
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, c_desc, block_idx_m, block_idx_n);
EpiloguePipeline{elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
EpiloguePipeline{elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
}
}
@@ -1177,50 +1113,21 @@ struct GroupedConvolutionForwardKernel
const auto& c_desc = kargs.c_grid_desc_m_n;
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
__shared__ char smem_ptr[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr_0,
smem_ptr_1,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
kargs.k_batch,
i_m,
i_n,
kargs.elfunc);
}
}
else
{
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr_0,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
kargs.k_batch,
i_m,
i_n,
kargs.elfunc);
}
}
RunGemm(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
kargs.k_batch,
i_m,
i_n,
kargs.elfunc);
}
}
};