mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-28 18:56:59 +00:00
[CK Tile] WAVELET pipeline for backward-data grouped convolution (#8220)

## Motivation

On the RetinaNet shapes (gfx950, fp16) CK Tile backward-data conv was ~18% behind classic CK, with the gap concentrated in the K=2376 3x3 detection-head family where bwd_data spends most of its time. The WAVELET GEMM pipeline already gives uplift for forward and backward-weight conv; this ports it to backward-data and consolidates the now-shared machinery across all three directions.

## Technical Details

- Backward-data wavelet support in the tile kernel: launch extra load waves when the pipeline exposes `LaunchBlockSize`, and split the epilogue into math waves (run the CShuffle epilogue) and load waves (`RunBarrierStub`).
- Register 7 WAVELET instances (fp16 and bf16), tuned for backward-data's tall-skinny GEMM rather than the forward tile shapes: a big-M `256/128/64` workhorse, a `VecA=4` variant for the `K % 8 != 0` shapes, and a `NumGroupsToMerge=32` variant for grouped (depthwise-style) shapes.
- Implement the native backward-data instance parser in `generate_instances.py`.
- Deduplicate the wavelet machinery shared by forward, backward-data, and backward-weight: `GroupedConvLaunchBlockSize`, `is_wavelet_pipeline`, and `RunWaveletAwareEpilogue` in `grouped_convolution_utils.hpp`; the three native instance parsers collapse to one parameterized parser. The three kernels now call the shared helpers.

## Test Plan

- Rebuild the full profiler instance pools for all three directions (fp16/bf16/fp32, nhwgc/ndhwgc) to exercise the shared helpers across every instantiation.
- Tile GTests on gfx950: `test_grouped_convnd_fwd_tile`, `test_grouped_convnd_bwd_data_tile`, `test_grouped_convnd_bwd_weight_tile`.
- Per-shape sweep of the 35 RetinaNet backward-data shapes vs classic CK and the non-wavelet tile pool (`profile_wavelet_bwd_data.py`); correctness spot-checked with GPU-reference verification on the new big-M and NumGroupsToMerge instances.
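
For readers unfamiliar with the "geomean classic/tile" metric used to summarize the per-shape sweep, here is a minimal sketch of how such a figure can be computed. It is not taken from `profile_wavelet_bwd_data.py`: the function name `geomean_ratio` is hypothetical, the timings are made-up placeholders rather than measurements, and it assumes the ratio is classic kernel time divided by tile kernel time, so values above 1 mean the tile path is faster.

```python
import math


def geomean_ratio(classic_ms, tile_ms):
    """Geometric mean of per-shape classic/tile time ratios.

    Hypothetical helper: assumes both lists hold kernel times in
    milliseconds, one entry per shape, in the same order.
    """
    if not classic_ms or len(classic_ms) != len(tile_ms):
        raise ValueError("need two non-empty lists of equal length")
    # Average in log space so no single large shape dominates the summary.
    log_sum = sum(math.log(c / t) for c, t in zip(classic_ms, tile_ms))
    return math.exp(log_sum / len(classic_ms))


# Placeholder timings (not measured data): per-shape ratios are 2 and 8.
example = geomean_ratio([2.0, 8.0], [1.0, 1.0])
print(f"geomean classic/tile = {example:.2f}")
```

A geometric mean is the usual choice for summarizing ratios because a 2x gain on one shape and a 2x loss on another cancel out to exactly 1.0, which an arithmetic mean would not report.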

## Test Result

- GTests pass: forward 9/9, backward-data 6/6, backward-weight 6/6.
- Backward-data perf (3x3 g=1 region, geomean classic/tile): 0.88 -> 1.11, i.e. the tile path goes from ~12% slower than classic to ~11% faster. The largest single backward-data shape (256x100x100->2376) moves from 11% slower than classic to 12.5% faster.
- The dedup refactor preserves behavior (net -174 lines across the kernels/generator), confirmed by the full rebuild and the GTests above.

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.