mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
[rocm-libraries] ROCm/rocm-libraries#8220 (commit 4c04a3a)
[CK Tile] WAVELET pipeline for backward-data grouped convolution (#8220)

## Motivation

On the RetinaNet shapes (gfx950, fp16) CK Tile backward-data conv was ~18% behind classic CK, with the gap concentrated in the K=2376 3x3 detection-head family where bwd_data spends most of its time. The WAVELET GEMM pipeline already gives uplift for forward and backward-weight conv; this ports it to backward-data and consolidates the now-shared machinery across all three directions.

## Technical Details

- Backward-data wavelet support in the tile kernel: launch extra load waves when the pipeline exposes `LaunchBlockSize`, and split the epilogue into math waves (run the CShuffle epilogue) and load waves (`RunBarrierStub`).
- Register 7 WAVELET instances (fp16 and bf16), tuned for backward-data's tall-skinny GEMM rather than the forward tile shapes: a big-M `256/128/64` workhorse, a `VecA=4` variant for the `K % 8 != 0` shapes, and a `NumGroupsToMerge=32` variant for grouped (depthwise-style) shapes.
- Implement the native backward-data instance parser in `generate_instances.py`.
- Deduplicate the wavelet machinery shared by forward, backward-data, and backward-weight: `GroupedConvLaunchBlockSize`, `is_wavelet_pipeline`, and `RunWaveletAwareEpilogue` in `grouped_convolution_utils.hpp`; the three native instance parsers collapse to one parameterized parser. The three kernels now call the shared helpers.

## Test Plan

- Rebuild the full profiler instance pools for all three directions (fp16/bf16/fp32, nhwgc/ndhwgc) to exercise the shared helpers across every instantiation.
- Tile GTests on gfx950: `test_grouped_convnd_fwd_tile`, `test_grouped_convnd_bwd_data_tile`, `test_grouped_convnd_bwd_weight_tile`.
- Per-shape sweep of the 35 RetinaNet backward-data shapes vs classic CK and the non-wavelet tile pool (`profile_wavelet_bwd_data.py`); correctness spot-checked with GPU-reference verification on the new big-M and NumGroupsToMerge instances.
## Test Result

- GTests pass: forward 9/9, backward-data 6/6, backward-weight 6/6.
- Backward-data perf (3x3 g=1 region, geomean classic/tile): 0.88 -> 1.11, i.e. the tile path goes from ~12% slower than classic to ~8% faster. The largest single backward-data shape (256x100x100->2376) moves from 11% slower than classic to 12.5% faster.
- The dedup refactor preserves behavior (net -174 lines across the kernels/generator), confirmed by the full rebuild and the GTests above.

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
329e589840
commit
01cca38c8e
@@ -529,7 +529,9 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
// Wavelet pipelines launch extra load waves (LaunchBlockSize > BlockSize); others use
|
||||
// BlockSize. See GroupedConvLaunchBlockSize in grouped_convolution_utils.hpp.
|
||||
static constexpr index_t kBlockSize = GroupedConvLaunchBlockSize<GemmPipeline>;
|
||||
|
||||
using OutDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
@@ -934,29 +936,31 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
|
||||
// Run Epilogue Pipeline with k_batch dispatch
|
||||
if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<InDataType, fp16_t, bf16_t>::value))
|
||||
// Run the epilogue with split-K dispatch, wrapped for wavelet load/math waves.
|
||||
RunWaveletAwareEpilogue<GemmPipeline, EpiloguePipeline>([&]() {
|
||||
if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<InDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
|
||||
@@ -456,21 +456,9 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
// For wavelet, LaunchBlockSize > BlockSize. Use LaunchBlockSize for kernel launch.
|
||||
template <typename T, typename = void>
|
||||
struct has_launch_block_size : std::false_type
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
struct has_launch_block_size<T, std::void_t<decltype(T::LaunchBlockSize)>> : std::true_type
|
||||
{
|
||||
};
|
||||
static constexpr index_t kBlockSize = []() {
|
||||
if constexpr(has_launch_block_size<GemmPipeline>::value)
|
||||
return GemmPipeline::LaunchBlockSize;
|
||||
else
|
||||
return GemmPipeline::BlockSize;
|
||||
}();
|
||||
// Wavelet pipelines launch extra load waves (LaunchBlockSize > BlockSize); others use
|
||||
// BlockSize. See GroupedConvLaunchBlockSize in grouped_convolution_utils.hpp.
|
||||
static constexpr index_t kBlockSize = GroupedConvLaunchBlockSize<GemmPipeline>;
|
||||
|
||||
using OutDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using InDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
@@ -1061,22 +1049,6 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
{block_idx_k, block_idx_m});
|
||||
}
|
||||
|
||||
// SFINAE helper: detect GemmPipeline::IsWavelet
|
||||
template <typename T, typename = void>
|
||||
struct has_is_wavelet : std::false_type
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
struct has_is_wavelet<T, std::void_t<decltype(T::IsWavelet)>> : std::true_type
|
||||
{
|
||||
};
|
||||
static constexpr bool kIsWavelet = []() {
|
||||
if constexpr(has_is_wavelet<GemmPipeline>::value)
|
||||
return GemmPipeline::IsWavelet;
|
||||
else
|
||||
return false;
|
||||
}();
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
@@ -1109,38 +1081,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
if constexpr(kIsWavelet)
|
||||
{
|
||||
// Wavelet: math waves run the epilogue, load waves run matching barriers
|
||||
if(GemmPipeline::IsMathWave())
|
||||
{
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Load waves: match epilogue barrier count to avoid deadlock
|
||||
EpiloguePipeline::RunBarrierStub();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Standard (non-wavelet) path
|
||||
// Run the epilogue with split-K dispatch, wrapped for wavelet load/math waves.
|
||||
RunWaveletAwareEpilogue<GemmPipeline, EpiloguePipeline>([&]() {
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
@@ -1159,7 +1101,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
|
||||
|
||||
@@ -572,38 +572,9 @@ struct GroupedConvolutionForwardKernel
|
||||
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline_::DsLayout>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
// For wavelet, LaunchBlockSize > BlockSize (extra load-only waves). Use
|
||||
// LaunchBlockSize for the kernel launch; non-wavelet pipelines fall back to BlockSize.
|
||||
template <typename T, typename = void>
|
||||
struct has_launch_block_size : std::false_type
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
struct has_launch_block_size<T, std::void_t<decltype(T::LaunchBlockSize)>> : std::true_type
|
||||
{
|
||||
};
|
||||
static constexpr index_t kBlockSize = []() {
|
||||
if constexpr(has_launch_block_size<Pipeline>::value)
|
||||
return Pipeline::LaunchBlockSize;
|
||||
else
|
||||
return Pipeline::BlockSize;
|
||||
}();
|
||||
|
||||
// SFINAE helper: detect Pipeline::IsWavelet (load/math wave specialization).
|
||||
template <typename T, typename = void>
|
||||
struct has_is_wavelet : std::false_type
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
struct has_is_wavelet<T, std::void_t<decltype(T::IsWavelet)>> : std::true_type
|
||||
{
|
||||
};
|
||||
static constexpr bool kIsWavelet = []() {
|
||||
if constexpr(has_is_wavelet<Pipeline>::value)
|
||||
return Pipeline::IsWavelet;
|
||||
else
|
||||
return false;
|
||||
}();
|
||||
// Wavelet pipelines launch extra load waves (LaunchBlockSize > BlockSize); others use
|
||||
// BlockSize. See GroupedConvLaunchBlockSize in grouped_convolution_utils.hpp.
|
||||
static constexpr index_t kBlockSize = GroupedConvLaunchBlockSize<Pipeline>;
|
||||
|
||||
using InDataType = remove_cvref_t<typename Pipeline::ADataType>;
|
||||
using WeiDataType = remove_cvref_t<typename Pipeline::BDataType>;
|
||||
@@ -1375,14 +1346,11 @@ struct GroupedConvolutionForwardKernel
|
||||
const auto& c_block_tile =
|
||||
Pipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline with k_batch dispatching
|
||||
if constexpr(kIsWavelet)
|
||||
{
|
||||
// Wavelet: only math waves hold accumulators and run the epilogue.
|
||||
// Load waves run a matching barrier sequence to avoid LDS-sync deadlock.
|
||||
// Forward has no split-K (IsSplitKSupported == false), so only the
|
||||
// memory_operation_enum::set path is reachable.
|
||||
if(Pipeline::IsMathWave())
|
||||
// Run the epilogue with k_batch dispatch, wrapped for wavelet load/math waves.
|
||||
// Forward has no split-K (IsSplitKSupported == false), so the atomic_add branch
|
||||
// compiles out and only the set path is reachable.
|
||||
RunWaveletAwareEpilogue<Pipeline, EpiloguePipeline>([&]() {
|
||||
if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
@@ -1393,32 +1361,19 @@ struct GroupedConvolutionForwardKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
EpiloguePipeline::RunBarrierStub();
|
||||
}
|
||||
}
|
||||
else if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value) &&
|
||||
IsSplitKSupported)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value) &&
|
||||
IsSplitKSupported)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvFwdKernelArgsSpecialized& kargs) const
|
||||
|
||||
@@ -21,6 +21,71 @@ enum class GroupedConvDirection
|
||||
BACKWARD_WEIGHT
|
||||
};
|
||||
|
||||
// Wavelet pipeline support shared by all three grouped-conv directions. The wavelet GEMM
|
||||
// pipeline launches extra load-only waves (LaunchBlockSize > BlockSize) and splits the
|
||||
// workgroup into math waves (hold accumulators, run the epilogue) and load waves (run a
|
||||
// matching barrier sequence). Non-wavelet pipelines expose neither member; these helpers
|
||||
// detect that via SFINAE so each kernel dispatches without duplicating the machinery.
|
||||
namespace impl {
|
||||
// Detection trait for a `LaunchBlockSize` member. The primary template is the fallback
// (false_type); it is selected whenever the partial specialization below is not viable.
template <typename T, typename = void>
|
||||
struct has_launch_block_size : std::false_type
|
||||
{
|
||||
};
|
||||
// Selected only when `decltype(T::LaunchBlockSize)` is well-formed; otherwise the
// std::void_t argument fails substitution and the primary template is used instead.
template <typename T>
|
||||
struct has_launch_block_size<T, std::void_t<decltype(T::LaunchBlockSize)>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
// Detection trait for an `IsWavelet` member. Same pattern as has_launch_block_size: the
// primary template reports false_type.
template <typename T, typename = void>
|
||||
struct has_is_wavelet : std::false_type
|
||||
{
|
||||
};
|
||||
// Selected only when `decltype(T::IsWavelet)` is well-formed. Note this reports that the
// member exists, not that its value is true.
template <typename T>
|
||||
struct has_is_wavelet<T, std::void_t<decltype(T::IsWavelet)>> : std::true_type
|
||||
{
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
// Block size to launch with: wavelet pipelines need LaunchBlockSize (load + math waves);
|
||||
// all others fall back to BlockSize.
|
||||
// Implemented as a variable template initialized by an immediately-invoked lambda. The
// `if constexpr` keeps `Pipeline::LaunchBlockSize` from being named for pipelines that do
// not declare it; `Pipeline::BlockSize` is required only in that fallback case.
template <typename Pipeline>
|
||||
inline constexpr index_t GroupedConvLaunchBlockSize = []() {
|
||||
if constexpr(impl::has_launch_block_size<Pipeline>::value)
|
||||
return Pipeline::LaunchBlockSize;
|
||||
else
|
||||
return Pipeline::BlockSize;
|
||||
}();
|
||||
|
||||
// True when the pipeline uses wavelet load/math wave specialization.
|
||||
// Yields the value of `Pipeline::IsWavelet` when that member exists and `false` when it does
// not, so a pipeline that declares `IsWavelet = false` is also reported as non-wavelet.
template <typename Pipeline>
|
||||
inline constexpr bool is_wavelet_pipeline = []() {
|
||||
if constexpr(impl::has_is_wavelet<Pipeline>::value)
|
||||
return Pipeline::IsWavelet;
|
||||
else
|
||||
return false;
|
||||
}();
|
||||
|
||||
// Run the CShuffle epilogue with wavelet load/math wave dispatch. For wavelet pipelines only
|
||||
// the math waves run @p epilogue_body (which writes the C tile); load waves run a matching
|
||||
// barrier sequence (RunBarrierStub) to avoid an LDS-sync deadlock. Non-wavelet pipelines run
|
||||
// @p epilogue_body directly. The body is direction-specific (split-K dispatch, window
|
||||
// construction), so it is passed in rather than shared.
|
||||
//
// @tparam GemmPipeline     Pipeline type; must provide a static IsMathWave() when
//                          is_wavelet_pipeline<GemmPipeline> is true.
// @tparam EpiloguePipeline Epilogue type; must provide a static RunBarrierStub() when
//                          is_wavelet_pipeline<GemmPipeline> is true.
// @tparam EpilogueBody     Callable type invoked with no arguments.
// @param  epilogue_body    Callable invoked at most once per call; its result is discarded.
template <typename GemmPipeline, typename EpiloguePipeline, typename EpilogueBody>
|
||||
CK_TILE_DEVICE void RunWaveletAwareEpilogue(EpilogueBody&& epilogue_body)
|
||||
{
|
||||
// Compile-time split: the wavelet branch (and its IsMathWave/RunBarrierStub requirements)
// is only instantiated for wavelet pipelines.
if constexpr(is_wavelet_pipeline<GemmPipeline>)
|
||||
{
|
||||
// Runtime split between math waves and the remaining (load) waves.
if(GemmPipeline::IsMathWave())
|
||||
epilogue_body();
|
||||
else
|
||||
EpiloguePipeline::RunBarrierStub();
|
||||
}
|
||||
else
|
||||
{
|
||||
epilogue_body();
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief The Grouped Conv kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
|
||||
Reference in New Issue
Block a user