mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Add splitk support to ck tile conv bwd data (#3353)
* add splitk support to ck tile conv bwd data * apply reviewers' suggestions * minor fix * remove SplitKBatchOffset struct
This commit is contained in:
@@ -542,9 +542,6 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
static constexpr index_t MaxGroupedGemmGroupsNum =
|
||||
GroupedConvBwdDataKernelArgsSpecialized::MaxGroupedGemmGroupsNum;
|
||||
|
||||
// TODO: Enable this
|
||||
static constexpr bool IsSplitKSupported = false;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
@@ -623,9 +620,8 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
|
||||
!IsSplitKSupported)
|
||||
if constexpr(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
@@ -772,8 +768,8 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
}();
|
||||
|
||||
const auto& c_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(c_ptr,
|
||||
kargs.c_grid_descs_m_n[group_id]);
|
||||
return make_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr, kargs.c_grid_descs_m_n[group_id]);
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
@@ -837,7 +833,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
|
||||
const index_t i_m,
|
||||
const index_t i_n,
|
||||
const index_t i_k = 0)
|
||||
const index_t i_k)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_pad_view = views.at(I1);
|
||||
@@ -893,20 +889,24 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
WeiDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t splitted_k,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n,
|
||||
const index_t block_idx_k,
|
||||
const index_t group_id)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(
|
||||
gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1)));
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
@@ -914,7 +914,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
@@ -945,8 +945,10 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t splitted_k,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n,
|
||||
const index_t block_idx_k,
|
||||
const index_t group_id)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
@@ -954,18 +956,25 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(
|
||||
TilePartitioner::GetLoopNum(gemm_tile_windows.at(I0).get_length(I1)));
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
@@ -1031,9 +1040,17 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
static_cast<long_index_t>(kargs.input_batch_stride);
|
||||
|
||||
// SplitK
|
||||
// TODO: Implement SplitK support
|
||||
// const index_t split_k_idx =
|
||||
// __builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
|
||||
const index_t split_k_idx =
|
||||
__builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
|
||||
|
||||
const index_t gemm_k = kargs.a_grid_descs_m_k[group_id].get_length(I1);
|
||||
|
||||
constexpr auto K1 = TilePartitioner::KPerBlock;
|
||||
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
|
||||
const index_t KRead = amd_wave_read_first_lane((gemm_k + K_t - 1) / K_t * K1);
|
||||
|
||||
const index_t i_k = amd_wave_read_first_lane(split_k_idx * KRead);
|
||||
const index_t splitted_k = amd_wave_read_first_lane(KRead);
|
||||
|
||||
// options
|
||||
// conv_bwd_data = Out * Weight = In
|
||||
@@ -1060,8 +1077,10 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitted_k,
|
||||
i_m,
|
||||
i_n,
|
||||
i_k,
|
||||
group_id);
|
||||
}
|
||||
}
|
||||
@@ -1071,7 +1090,17 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id);
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitted_k,
|
||||
i_m,
|
||||
i_n,
|
||||
i_k,
|
||||
group_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -505,33 +505,6 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
// Per-split bookkeeping for split-K: for the split identified by k_id, stores
// the value k_id * KRead in both a_k_split_offset and b_k_split_offset, and the
// K length assigned to that split in splitted_k.
struct SplitKBatchOffset
|
||||
{
|
||||
// kargs: kernel arguments; only kargs.k_batch and kargs.GemmK are read here.
// k_id:  index of the split to describe; defaults to blockIdx.z.
__device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
// K1 is entry 2 of the warp tile (presumably its K extent -- TODO confirm).
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
// K_t = k_batch * K1: the K span covered when every split advances by one K1 step.
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
|
||||
// KRead = ceil(GemmK / K_t) * K1: per-split K length, rounded up to a multiple of K1.
const index_t KRead = amd_wave_read_first_lane((kargs.GemmK + K_t - 1) / K_t * K1);
|
||||
|
||||
// A and B receive the same offset value.
a_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
|
||||
b_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
|
||||
|
||||
// Every split except the last one is assigned a full KRead chunk.
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
{
|
||||
splitted_k = amd_wave_read_first_lane(KRead);
|
||||
}
|
||||
else
|
||||
{
|
||||
// The last split takes what remains of GemmK after the (k_batch - 1) full chunks.
splitted_k = amd_wave_read_first_lane(kargs.GemmK - KRead * (kargs.k_batch - 1));
|
||||
}
|
||||
}
|
||||
|
||||
// Set to k_id * KRead by the constructor.
index_t a_k_split_offset;
|
||||
// Set to k_id * KRead by the constructor (same value as a_k_split_offset).
index_t b_k_split_offset;
|
||||
// K length assigned to this split: KRead, or the remainder for the last split.
index_t splitted_k;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
{
|
||||
@@ -763,12 +736,12 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views, const index_t k_batch)
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{} * k_batch,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
@@ -776,7 +749,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{} * k_batch,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
@@ -882,7 +855,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
|
||||
|
||||
@@ -932,7 +905,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user