[CK_TILE] Add splitk support to ck tile conv bwd data (#3353)

* add splitk support to ck tile conv bwd data

* apply reviewers' suggestions

* minor fix

* removed splitkbatchoffset struct
This commit is contained in:
jakpiase
2025-12-23 10:03:42 +01:00
committed by GitHub
parent 8b73633e65
commit ead81d1b0b
2 changed files with 57 additions and 55 deletions

View File

@@ -542,9 +542,6 @@ struct GroupedConvolutionBackwardDataKernel
static constexpr index_t MaxGroupedGemmGroupsNum =
GroupedConvBwdDataKernelArgsSpecialized::MaxGroupedGemmGroupsNum;
// TODO: Enable this
static constexpr bool IsSplitKSupported = false;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
@@ -623,9 +620,8 @@ struct GroupedConvolutionBackwardDataKernel
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
{
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
!IsSplitKSupported)
if constexpr(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value)
{
if(kargs.k_batch != 1)
{
@@ -772,8 +768,8 @@ struct GroupedConvolutionBackwardDataKernel
}();
const auto& c_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(c_ptr,
kargs.c_grid_descs_m_n[group_id]);
return make_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr, kargs.c_grid_descs_m_n[group_id]);
}();
const auto& ds_tensor_view = generate_tuple(
@@ -837,7 +833,7 @@ struct GroupedConvolutionBackwardDataKernel
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
const index_t i_m,
const index_t i_n,
const index_t i_k = 0)
const index_t i_k)
{
const auto& a_pad_view = views.at(I0);
const auto& b_pad_view = views.at(I1);
@@ -893,20 +889,24 @@ struct GroupedConvolutionBackwardDataKernel
WeiDataType* c_ptr,
void* smem_ptr_0,
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
const index_t splitted_k,
const index_t block_idx_m,
const index_t block_idx_n,
const index_t block_idx_k,
const index_t group_id)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(
gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1)));
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
auto gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
@@ -914,7 +914,7 @@ struct GroupedConvolutionBackwardDataKernel
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0);
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
@@ -945,8 +945,10 @@ struct GroupedConvolutionBackwardDataKernel
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
const index_t splitted_k,
const index_t block_idx_m,
const index_t block_idx_n,
const index_t block_idx_k,
const index_t group_id)
{
// Create Gemm tensor views, pad views and tile windows
@@ -954,18 +956,25 @@ struct GroupedConvolutionBackwardDataKernel
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = amd_wave_read_first_lane(
TilePartitioner::GetLoopNum(gemm_tile_windows.at(I0).get_length(I1)));
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
auto gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
b_block_window,
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
@@ -1031,9 +1040,17 @@ struct GroupedConvolutionBackwardDataKernel
static_cast<long_index_t>(kargs.input_batch_stride);
// SplitK
// TODO: Implement SplitK support
// const index_t split_k_idx =
// __builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
const index_t split_k_idx =
__builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
const index_t gemm_k = kargs.a_grid_descs_m_k[group_id].get_length(I1);
constexpr auto K1 = TilePartitioner::KPerBlock;
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
const index_t KRead = amd_wave_read_first_lane((gemm_k + K_t - 1) / K_t * K1);
const index_t i_k = amd_wave_read_first_lane(split_k_idx * KRead);
const index_t splitted_k = amd_wave_read_first_lane(KRead);
// options
// conv_bwd_data = Out * Weight = In
@@ -1060,8 +1077,10 @@ struct GroupedConvolutionBackwardDataKernel
smem_ptr_0,
smem_ptr_1,
kargs,
splitted_k,
i_m,
i_n,
i_k,
group_id);
}
}
@@ -1071,7 +1090,17 @@ struct GroupedConvolutionBackwardDataKernel
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id);
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitted_k,
i_m,
i_n,
i_k,
group_id);
}
}
}

View File

@@ -505,33 +505,6 @@ struct GroupedConvolutionBackwardWeightKernel
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
const index_t KRead = amd_wave_read_first_lane((kargs.GemmK + K_t - 1) / K_t * K1);
a_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
b_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
{
splitted_k = amd_wave_read_first_lane(KRead);
}
else
{
splitted_k = amd_wave_read_first_lane(kargs.GemmK - KRead * (kargs.k_batch - 1));
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t splitted_k;
};
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
@@ -763,12 +736,12 @@ struct GroupedConvolutionBackwardWeightKernel
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views, const index_t k_batch)
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{} * k_batch,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<true, true>{});
}();
@@ -776,7 +749,7 @@ struct GroupedConvolutionBackwardWeightKernel
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I1);
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{} * k_batch,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
}();
@@ -882,7 +855,7 @@ struct GroupedConvolutionBackwardWeightKernel
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
@@ -932,7 +905,7 @@ struct GroupedConvolutionBackwardWeightKernel
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);