# [CK Tile] StreamK support for Bwd Weight grouped convolutions (#5393)

## Motivation

Add StreamK work distribution to the CK Tile grouped convolution backward weight kernel. Split-K divides the K-dimension uniformly across a fixed `k_batch`, which causes load imbalance when the number of output tiles doesn't evenly fill the GPU. StreamK distributes total K-iterations evenly across workgroups, improving utilization on these shapes.

## Technical Details

StreamK is added as an `if constexpr` branch in the existing kernel, selected by the `TilePartitioner_` template parameter. Two reduction strategies are supported:

- **Linear**: the tile-starter sequentially accumulates partials from contributing CTAs
- **Tree**: pairwise binary tree reduction (O(log n) depth, faster for many contributors)

Both persistent and non-persistent data-parallel (DP) sections are supported. A conceptual sketch of the work distribution appears after the performance table below.

Key changes:

- `grouped_convolution_backward_weight_kernel.hpp`: StreamK execution path with `RunStreamK`/`RunStreamKLoop`, partial store/load via workspace, flag-based cross-CTA synchronization, `GridSize`/`MakeKernelArgs`/`GetWorkSpaceSize` extensions
- `streamk_common.hpp`: shared `StreamKReductionOps` (reduction helpers) and `StreamKDispatch` (persistent/non-persistent DP dispatch), used by both the GEMM and Conv StreamK kernels
- `streamk_gemm_kernel.hpp`: refactored to use the shared helpers
- Merged split-K and StreamK example invokers via a `PartitionerPolicy` template parameter
- StreamK example binary with `--streamk_reduction=linear|tree` and `--streamk_persistent=0|1`
- CK Builder integration: `SpecifiesStreamK` concept, `TilePartitionerType` factory helper, `InstanceTraits` with StreamK fields
- 30 tests: host-side, GPU end-to-end (Linear + Tree + Persistent DP), negative, builder regression

### Performance (MI355X, gfx950)

Speedup relative to the best split-K result (sweep over k_batch={1,2,4,8,16,32}):

| Shape | Split-K (16x64 tiles) | StreamK (16x64 tiles) | Split-K (128x128 tiles) | StreamK (128x128 tiles) |
|---|---|---|---|---|
| 1x1 128x128 N=32 28x28 | 1.00x | 0.54x | 1.00x | 0.81x |
| 3x3 128x128 N=32 14x14 | 1.00x | 0.59x | 1.00x | 0.62x |
| 1x1 256x64 N=32 56x56 | 1.00x | 0.83x | 1.00x | 1.83x |
| 3x3 512x512 N=2 7x7 | 1.00x | 1.12x | 1.00x | 0.62x |
| 1x1 1024x1024 N=4 7x7 | 1.00x | 1.09x | 1.00x | 0.60x |
| 3x3 128x128 N=32 28x28 | 1.00x | 0.44x | 1.00x | 0.96x |
| 3x3 256x256 N=32 14x14 | 1.00x | 0.67x | 1.00x | 0.93x |
| 3x3 512x512 N=32 7x7 | 1.00x | 0.98x | 1.00x | 1.16x |

StreamK's value depends on the tile configuration: with larger tiles (fewer output tiles), StreamK delivers up to 1.83x speedup on bottleneck shapes and up to 1.16x on typical large-channel convolutions. Tree reduction consistently outperforms Linear when multiple CTAs contribute to the same tile (up to 2.87x faster), due to its O(log n) reduction depth vs O(n) sequential accumulation. The table reports the best of Linear and Tree for each shape.
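For intuition, here is a minimal host-side sketch of the even-iteration partitioning idea behind StreamK. This is not the kernel's actual `TilePartitioner_`; the names `StreamKSlice` and `assign_slice` are hypothetical, and the real kernel resolves workspace partials and cross-CTA flags that this sketch omits.

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical illustration of StreamK work assignment: every workgroup
// receives a near-equal, contiguous slice of the global iteration space
// num_tiles * iters_per_tile, so a workgroup may start and finish mid-tile.
struct StreamKSlice
{
    uint32_t iter_begin; // first global K-iteration owned by this workgroup
    uint32_t iter_end;   // one past the last owned K-iteration
};

StreamKSlice assign_slice(uint32_t wg_id, uint32_t num_wgs,
                          uint32_t num_tiles, uint32_t iters_per_tile)
{
    const uint32_t total = num_tiles * iters_per_tile;
    // Spread the remainder so slice sizes differ by at most one iteration.
    const uint32_t base  = total / num_wgs;
    const uint32_t rem   = total % num_wgs;
    const uint32_t begin = wg_id * base + (wg_id < rem ? wg_id : rem);
    const uint32_t size  = base + (wg_id < rem ? 1u : 0u);
    return {begin, begin + size};
}

int main()
{
    // 10 output tiles x 8 K-iterations over 4 workgroups: each gets 20
    // iterations and crosses tile boundaries, so tiles touched by several
    // workgroups need a partial reduction (Linear or Tree) at the end.
    for(uint32_t wg = 0; wg < 4; ++wg)
    {
        const StreamKSlice s = assign_slice(wg, 4, 10, 8);
        std::printf("wg %u: tiles [%u..%u], iters [%u, %u)\n",
                    wg, s.iter_begin / 8, (s.iter_end - 1) / 8,
                    s.iter_begin, s.iter_end);
    }
}
```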
## Test Plan

```bash
ninja -C build test_ck_tile_grouped_conv_bwd_weight_streamk
./build/bin/test_ck_tile_grouped_conv_bwd_weight_streamk

# Builder tests (requires CK_EXPERIMENTAL_BUILDER=ON)
ninja -C build check-builder
```

30 tests covering:

- Host-side: type traits, kernel args construction, grid size, workspace size
- GPU end-to-end (Linear + Tree): small/medium shapes, multi-group, stride>1, pure-DP degeneration, single-tile all-SK, large GemmK, higher occupancy
- Persistent DP: Linear + Tree with persistent data-parallel dispatch
- Negative: `IsSupportedArgument` rejects unaligned K and C
- Builder: Create (instance string validation) + Execution (reference comparison) + instance string regression

## Test Result

All 30 conv StreamK tests pass on MI355X (gfx950). 64/64 GEMM StreamK tests pass. The full `check-builder` suite passes. Tolerances are computed dynamically using the `calculate_rtol_atol` pattern (fp16 ULP-aware).

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
# Convolution Builder Factory Directory
This directory implements compile-time dispatch from high-level signature and algorithm descriptors to our existing specialized convolution kernel implementations.
See the main builder documentation for an overview.
## Design Overview
The factory system operates in two phases:
1. **Algorithm Classification**: Predicate concepts in `conv_dispatcher.hpp` inspect the algorithm descriptor to determine which kernel variant it satisfies. The predicates are evaluated in a specific order using `if constexpr` (a minimal sketch of this pattern follows the list):
   - Cross-direction (checked first, supports all convolution directions):
     - `ReferenceAlgorithm`: simple reference implementation for validation
     - `TileAlgorithm`: CK Tile backend, dispatches via `ConvTileFactory`
   - Forward direction (old CK):
     - `FwdXdlV3Algorithm`: newer XDL pipeline using the block GEMM structure
     - `FwdXdlAlgorithm`: standard XDL using AMD XDLops instructions
     - `FwdWmmaAlgorithm`: WMMA variant for gfx11/gfx12 hardware
     - `FwdDlAlgorithm`: vectorized dot-product kernel (non-XDLops)
     - `LargeTensorAlgorithm`: XDL with extended tensor support
   - Backward weight direction (old CK): `BwdXdlAlgorithm`, `BwdXdlV3Algorithm`, `BwdTwoStageXdlAlgorithm`, `BwdDlAlgorithm`, `BwdMultiDXdlAlgorithm`, `BwdWmmaV3Algorithm`, `BwdTwoStageWmmaV3Algorithm`, `BwdWmmaAlgorithm`, `BwdMultiDWmmaV3Algorithm`
   - Backward data direction: currently supports only the Reference and Tile algorithms. Optimized old CK kernels are not yet implemented.
2. **Factory Instantiation**: Each factory transforms builder descriptors into backend-specific template parameters and instantiates the corresponding kernel.
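A minimal sketch of the classification pattern, assuming hypothetical descriptor fields and predicate names (the real concepts in `conv_dispatcher.hpp` inspect much richer builder descriptors):

```cpp
// Hypothetical algorithm descriptors; real builder descriptors carry many
// more fields (tile shapes, pipelines, data types, ...).
struct ReferenceDesc { static constexpr bool is_reference = true; };
struct TileDesc      { static constexpr bool is_ck_tile   = true; };

// Hypothetical predicate concepts: each checks the descriptor for a
// distinguishing marker and is false on any descriptor lacking it.
template <typename A>
concept IsReferenceAlgorithm = requires { requires A::is_reference; };

template <typename A>
concept IsTileAlgorithm = requires { requires A::is_ck_tile; };

// Evaluation order matters: cross-direction predicates are checked before
// direction-specific ones, mirroring the if constexpr chain described above.
template <typename Signature, typename Algorithm>
auto make_instance()
{
    if constexpr(IsReferenceAlgorithm<Algorithm>)
        return /* reference factory result */ 0;
    else if constexpr(IsTileAlgorithm<Algorithm>)
        return /* CK Tile factory result */ 1;
    else
        static_assert(sizeof(Algorithm) == 0, "unsupported algorithm descriptor");
}
```

One property of this sketch worth noting: because each predicate is a concept, an unsupported descriptor fails the final `static_assert` with a readable message instead of a deep template instantiation error.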
## Key Files
- `conv_dispatcher.hpp`: entry point with the `make_conv_instance()` function. Contains the dispatch logic and algorithm classification predicates. Start here to understand the overall flow.
- Forward factories (old CK): `conv_fwd_v3_factory.hpp`, `conv_fwd_xdl_factory.hpp`, `conv_fwd_wmma_factory.hpp`, `conv_fwd_dl_factory.hpp`, `conv_fwd_large_tensor_factory.hpp`
- Backward weight factories (old CK): `conv_bwd_weight_xdl_factory.hpp`, `conv_bwd_weight_xdl_v3_factory.hpp`, `conv_bwd_weight_two_stage_xdl_factory.hpp`, `conv_bwd_weight_dl_factory.hpp`, `conv_bwd_weight_multi_d_xdl_factory.hpp`, `conv_bwd_weight_wmma_v3_factory.hpp`, `conv_bwd_weight_two_stage_wmma_v3_factory.hpp`, `conv_bwd_weight_wmma_factory.hpp`, `conv_bwd_weight_multi_d_wmma_v3_factory.hpp`
- Cross-direction factories: `reference_factory.hpp` (reference implementation), `conv_tile_factory.hpp` (CK Tile backend)
- `helpers/`: transformation utilities that map builder types to backend-specific parameters, organized into `helpers/ck/` (old CK mappings) and `helpers/ck_tile/` (CK Tile mappings). A sketch of such a mapping follows this list.
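For flavor, a hedged sketch of one such mapping. `TensorLayout`, the tag types, and `ToBackendLayout` are hypothetical stand-ins, not names from `helpers/`:

```cpp
#include <type_traits>

// Hypothetical builder-side enum and backend-side tag types; the actual
// names under helpers/ck/ and helpers/ck_tile/ differ.
enum class TensorLayout { NHWGC, NGCHW };

struct NHWGC_Tag {};
struct NGCHW_Tag {};

// Compile-time mapping from a builder enum value to a backend layout type,
// in the spirit of the transformation utilities under helpers/.
template <TensorLayout L>
struct ToBackendLayout;

template <> struct ToBackendLayout<TensorLayout::NHWGC> { using type = NHWGC_Tag; };
template <> struct ToBackendLayout<TensorLayout::NGCHW> { using type = NGCHW_Tag; };

template <TensorLayout L>
using to_backend_layout_t = typename ToBackendLayout<L>::type;

// Usage: the enum value resolves to the backend tag at compile time.
static_assert(std::is_same_v<to_backend_layout_t<TensorLayout::NHWGC>, NHWGC_Tag>);
```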
## Usage
#include "ck_tile/builder/factory/conv_dispatcher.hpp"
// Uses latest version by default (currently "0.1.0")
auto kernel = make_conv_instance<SIGNATURE, ALGORITHM>();
// Or pin to a specific version
auto kernel_v0 = make_conv_instance<SIGNATURE, ALGORITHM, "0.0.0">();
The dispatcher automatically selects the appropriate factory at compile time.
## Factory Architecture and the Unification Gap
Each factory is a self-contained facade: it accepts builder descriptors and produces a kernel instance, but it does so with its own algorithm descriptor shape and its own parameter mapping logic. The 16+ factories share no common infrastructure for parameter transformation.
Old CK factories (e.g., ConvFwdXdlV3Factory) flatten all algorithm parameters into a single device operation template instantiation with approximately 49 template arguments. The factory's primary job is mapping builder enum values (layouts, data types, elementwise ops) to CK's internal types. Within old CK, the XDL and WMMA factories duplicate much of this mapping logic despite sharing the same underlying parameter concepts.
The CK Tile factory (ConvTileFactory) composes modern objects — a traits type, a tile partitioner, a GEMM pipeline, and an epilogue pipeline — each with its own configuration. This results in approximately 31 parameters distributed across four composed types rather than one flat template.
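To make the contrast concrete, here is a schematic sketch. The template names and the handful of parameters shown are placeholders, not the real signatures (which carry roughly 49 and 31 parameters respectively):

```cpp
// Schematic only: names and parameter counts are illustrative placeholders,
// not the real device operation or CK Tile kernel signatures.

// Old CK style: one flat template, every knob is a direct argument.
template <typename InLayout, typename WeiLayout, typename OutLayout,
          typename InDataType, typename WeiDataType, typename OutDataType,
          int BlockSize, int MPerBlock, int NPerBlock, int KPerBlock
          /* ... dozens more ... */>
struct DeviceConvFwdFlat {};

// CK Tile style: configuration grouped into four composed types.
template <typename Traits,          // data types, layouts, specialization
          typename TilePartitioner, // work distribution (e.g. split-K, StreamK)
          typename GemmPipeline,    // main-loop schedule
          typename EpiloguePipeline> // store / reduction stage
struct ConvTileKernel {};
```

The composed form trades a longer type-construction chain for locality: parameters that change together (for example, everything the epilogue needs) live in one type, which is why the two factory families cannot currently share a parameter-mapping layer.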
Both factory paths produce a kernel Instance type that satisfies the same usage interface (construction, argument setup, invocation). The dispatcher abstracts this difference from the caller. However, the algorithm descriptor accepted by each factory is different — the unification burden currently falls on the caller (MIOpen), not the dispatcher. Collapsing these per-variant descriptors into a single algorithm format that the dispatcher decomposes internally is the key step toward making the builder a true unified facade.