Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-04-19 22:39:03 +00:00)
[CK Tile] StreamK support for Bwd Weight grouped convolutions (#5393)

## Motivation

Add StreamK work distribution to the CK Tile grouped convolution backward weight kernel. Split-K divides the K-dimension uniformly across a fixed `k_batch`, which causes load imbalance when the number of output tiles doesn't evenly fill the GPU. StreamK distributes total K-iterations evenly across workgroups, improving utilization on these shapes.

## Technical Details

StreamK is added as an `if constexpr` branch in the existing kernel, selected by the `TilePartitioner_` template parameter.

Two reduction strategies are supported:
- **Linear**: tile-starter sequentially accumulates partials from contributing CTAs
- **Tree**: pairwise binary tree reduction (O(log n) depth, faster for many contributors)

Both persistent and non-persistent data-parallel (DP) sections are supported.

Key changes:
- `grouped_convolution_backward_weight_kernel.hpp`: StreamK execution path with `RunStreamK`/`RunStreamKLoop`, partial store/load via workspace, flag-based cross-CTA synchronization, `GridSize`/`MakeKernelArgs`/`GetWorkSpaceSize` extensions
- `streamk_common.hpp`: Shared `StreamKReductionOps` (reduction helpers) and `StreamKDispatch` (persistent/non-persistent DP dispatch), used by both GEMM and Conv StreamK kernels
- `streamk_gemm_kernel.hpp`: Refactored to use shared helpers
- Merged split-K and StreamK example invokers via `PartitionerPolicy` template parameter
- StreamK example binary with `--streamk_reduction=linear|tree` and `--streamk_persistent=0|1`
- CK Builder integration: `SpecifiesStreamK` concept, `TilePartitionerType` factory helper, `InstanceTraits` with StreamK fields
- 30 tests: host-side, GPU end-to-end (Linear + Tree + Persistent DP), negative, builder regression

### Performance (MI355X, gfx950)

Speedup relative to best split-K (sweep over k_batch={1,2,4,8,16,32}):

| Shape | 16x64 tiles | | 128x128 tiles | |
|---|---|---|---|---|
| | Split-K | StreamK | Split-K | StreamK |
| 1x1 128x128 N=32 28x28 | 1.00x | 0.54x | 1.00x | 0.81x |
| 3x3 128x128 N=32 14x14 | 1.00x | 0.59x | 1.00x | 0.62x |
| 1x1 256x64 N=32 56x56 | 1.00x | 0.83x | 1.00x | 1.83x |
| 3x3 512x512 N=2 7x7 | 1.00x | 1.12x | 1.00x | 0.62x |
| 1x1 1024x1024 N=4 7x7 | 1.00x | 1.09x | 1.00x | 0.60x |
| 3x3 128x128 N=32 28x28 | 1.00x | 0.44x | 1.00x | 0.96x |
| 3x3 256x256 N=32 14x14 | 1.00x | 0.67x | 1.00x | 0.93x |
| 3x3 512x512 N=32 7x7 | 1.00x | 0.98x | 1.00x | 1.16x |

StreamK's value depends on tile config: with larger tiles (fewer output tiles), StreamK delivers up to 1.83x speedup on bottleneck shapes and up to 1.16x on typical large-channel convolutions. Tree reduction consistently outperforms Linear when multiple CTAs contribute to the same tile (up to 2.87x faster), due to O(log n) reduction depth vs O(n) sequential accumulation. The table reports the best of Linear and Tree for each shape.

## Test Plan

```bash
ninja -C build test_ck_tile_grouped_conv_bwd_weight_streamk
./build/bin/test_ck_tile_grouped_conv_bwd_weight_streamk

# Builder tests (requires CK_EXPERIMENTAL_BUILDER=ON)
ninja -C build check-builder
```

30 tests covering:
- Host-side: type traits, kernel args construction, grid size, workspace size
- GPU end-to-end (Linear + Tree): small/medium shapes, multi-group, stride>1, pure-DP degeneration, single-tile all-SK, large GemmK, higher occupancy
- Persistent DP: Linear + Tree with persistent data-parallel dispatch
- Negative: `IsSupportedArgument` rejects unaligned K and C
- Builder: Create (instance string validation) + Execution (reference comparison) + instance string regression

## Test Result

All 30 conv StreamK tests pass on MI355X (gfx950). 64/64 GEMM StreamK tests pass. Full `check-builder` suite passes. Tolerances computed dynamically using `calculate_rtol_atol` pattern (fp16 ULP-aware).

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
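
The load-balancing argument in the Motivation section and the Linear vs. Tree trade-off in Technical Details can be illustrated with a small host-side model. This is only a sketch in Python, not CK Tile code: every function name below is hypothetical, and the model assumes one workgroup runs one work item at a time, ignores the cost of storing, loading, and synchronizing partials, and does not model the data-parallel (DP) sections or occupancy.

```python
from math import ceil


def splitk_makespan(num_tiles, iters_per_tile, k_batch, num_wgs):
    """Iterations executed by the busiest workgroup under split-K (toy model)."""
    iters_per_item = ceil(iters_per_tile / k_batch)  # K-slice of one work item
    work_items = num_tiles * k_batch                 # one item per (tile, k-slice)
    waves = ceil(work_items / num_wgs)               # rounds needed to run all items
    return waves * iters_per_item


def streamk_ranges(num_tiles, iters_per_tile, num_wgs):
    """Half-open [start, end) ranges over the flattened, tile-major iteration space."""
    total = num_tiles * iters_per_tile
    base, extra = divmod(total, num_wgs)
    ranges, start = [], 0
    for wg in range(num_wgs):
        length = base + (1 if wg < extra else 0)     # spread the remainder evenly
        ranges.append((start, start + length))
        start += length
    return ranges


def contributors_per_tile(ranges, num_tiles, iters_per_tile):
    """Number of workgroups that produce a partial result for each output tile."""
    counts = [0] * num_tiles
    for start, end in ranges:
        if start == end:
            continue                                  # workgroup received no work
        first_tile = start // iters_per_tile
        last_tile = (end - 1) // iters_per_tile
        for tile in range(first_tile, last_tile + 1):
            counts[tile] += 1
    return counts


def linear_reduce(partials):
    """The tile-starter adds the other partials one at a time: n - 1 sequential steps."""
    acc, steps = partials[0], 0
    for p in partials[1:]:
        acc += p
        steps += 1
    return acc, steps


def tree_reduce(partials):
    """Pairwise reduction; the additions within one level are independent."""
    level, depth = list(partials), 0
    while len(level) > 1:
        nxt = [level[i] + level[i + 1] for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:
            nxt.append(level[-1])                     # odd element carries over
        level, depth = nxt, depth + 1
    return level[0], depth


if __name__ == "__main__":
    tiles, iters, wgs = 9, 64, 8                      # 9 tiles on 8 workgroups
    print(splitk_makespan(tiles, iters, 1, wgs))      # 128: a second wave for one tile
    print(splitk_makespan(tiles, iters, 4, wgs))      # 80
    ranges = streamk_ranges(tiles, iters, wgs)
    print(max(end - start for start, end in ranges))  # 72
    print(contributors_per_tile(ranges, tiles, iters))
    print(linear_reduce(list(range(1, 9))))           # (36, 7)
    print(tree_reduce(list(range(1, 9))))             # (36, 3)
```

In this toy configuration, split-K with `k_batch=1` needs a second wave for the ninth tile, while the StreamK-style partition gives each workgroup 72 iterations. The price is that most tiles then have two contributors whose partials must be reduced, which is where the choice between sequential and pairwise reduction matters. Because the model omits reduction and synchronization cost, it cannot reproduce the measured slowdowns in the performance table.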