composable_kernel/include/ck/tensor_operation/gpu
Johannes Graner 3727d5220a [rocm-libraries] ROCm/rocm-libraries#5652 (commit 7dc7d1d)
[CK Conv] Wavelet gemm pipeline for bwd_weight convolution (#5652)

## Motivation

In the current CShuffleV3 backward weight kernel, the in-kernel
conv-to-GEMM transform generates significant INT32 VALU pressure per
MFMA instruction. On VALU-heavy shapes (e.g., G=1, 3×3, C=256), these
index computation ops compete with MFMA for VALU issue slots, creating a
bottleneck that cannot be resolved by pipeline prefetching alone.
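
To make the source of that integer work concrete, here is a minimal sketch of the index arithmetic an in-kernel conv-to-GEMM transform implies for one input element. It is not CK's descriptor-based transform. It assumes the common implicit-GEMM mapping for backward weight (reduction index = N·Ho·Wo, GEMM N index = Y·X·C), a 2D NHWC input with G=1, and the hypothetical names `ConvDesc2D` and `InputOffset`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical problem description for a 2D, G=1 convolution (illustrative names).
struct ConvDesc2D
{
    int32_t Hi, Wi, C;             // input height, width, channels
    int32_t Y, X;                  // filter height, width
    int32_t Ho, Wo;                // output height, width
    int32_t stride, dilation, pad; // same value used for both spatial dims
};

// Maps one GEMM coordinate (k_gemm, n_gemm) to an NHWC input offset.
// Returns -1 when the coordinate lands in the zero-padding region.
inline int64_t InputOffset(const ConvDesc2D& d, int32_t k_gemm, int32_t n_gemm)
{
    // Split the reduction index into (n, ho, wo): two divisions, two remainders.
    const int32_t n  = k_gemm / (d.Ho * d.Wo);
    const int32_t ho = (k_gemm / d.Wo) % d.Ho;
    const int32_t wo = k_gemm % d.Wo;

    // Split the GEMM N index into (y, x, c): again division and remainder.
    const int32_t y = n_gemm / (d.X * d.C);
    const int32_t x = (n_gemm / d.C) % d.X;
    const int32_t c = n_gemm % d.C;

    // Recover input coordinates and check the padding bounds.
    const int32_t hi = ho * d.stride + y * d.dilation - d.pad;
    const int32_t wi = wo * d.stride + x * d.dilation - d.pad;
    if(hi < 0 || hi >= d.Hi || wi < 0 || wi >= d.Wi)
        return -1;

    return (static_cast<int64_t>(n * d.Hi + hi) * d.Wi + wi) * d.C + c;
}
```

Every element address costs several INT32 multiplies, divides, and compares, all of which are VALU work when done per thread inside the kernel.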

This PR adds a wave-specialized ("wavelet") convolution backward weight
kernel that splits workgroup threads into two roles:
- **Load waves**: conv-to-GEMM address computation + global memory loads + LDS writes (all VALU/VMEM)
- **Math waves**: LDS reads + MFMA + CShuffle epilogue (no index computation)

By physically separating the two instruction classes onto different
waves, VALU and MFMA execute on different hardware functional units
without contention.
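
The role split can be sketched on the host as a branch on the thread ID. This is only an illustration: the ordering (load threads first, math threads after) and the names `WaveRole`, `RoleOfThread`, `IndexInRole`, and `RunWorkgroupThread` are assumptions, not taken from the kernel source.

```cpp
#include <cassert>

enum class WaveRole { Load, Math };

// Assumption (not confirmed by the PR text): load threads take the low
// thread IDs and math threads follow them.
constexpr WaveRole RoleOfThread(int thread_id, int load_threads)
{
    return thread_id < load_threads ? WaveRole::Load : WaveRole::Math;
}

// Index of a thread inside its own role group (0-based).
constexpr int IndexInRole(int thread_id, int load_threads)
{
    return thread_id < load_threads ? thread_id : thread_id - load_threads;
}

// Shape of the kernel body: the two callables stand in for the load-wave
// pipeline (address computation, global loads, LDS writes) and the math-wave
// pipeline (LDS reads, MFMA, epilogue).
template <typename LoadFn, typename MathFn>
void RunWorkgroupThread(int thread_id, int load_threads, LoadFn&& load, MathFn&& math)
{
    if(RoleOfThread(thread_id, load_threads) == WaveRole::Load)
        load(IndexInRole(thread_id, load_threads));
    else
        math(IndexInRole(thread_id, load_threads));
}
```

Because the branch depends only on the thread ID, every thread in a wave takes the same path as long as the group sizes are multiples of the wave size.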

## Technical Details

**Core kernel (new files):**
- `gridwise_gemm_xdl_waveletmodel_cshuffle_conv_v3.hpp` —
wave-specialized gridwise GEMM for conv bwd weight (2-way split: load +
math)
- `device_grouped_conv_bwd_weight_xdl_waveletmodel_cshuffle_v3.hpp` —
device op following CShuffleV3 patterns; `BlockSize =
TileMathThreadGroupSize` for MFMA wave assignment, `LaunchBlockSize =
TileLoad + TileMath` for kernel launch

**Wave pipeline (modified):**
- `gridwise_gemm_waveletmodel.hpp` — load/math wave pipeline structs
with `sched_group_barrier` scheduling hints to front-load VMEM reads
before address-advance VALU

**Two wave ratios:**
- **(4,4)**: 256 load + 256 math = 512 threads (8 waves). Best on large
shapes.
- **(4,2)**: 256 load + 128 math = 384 threads (6 waves). Best on small
shapes (fewer sync barriers, denser MFMA per math wave).
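
The thread counts above follow from the wave ratio and a 64-lane wave, which is what "256 threads = 4 waves" implies. A small sketch of that arithmetic, with the hypothetical helper names `WaveRatio`, `ThreadCounts`, and `CountsFor`:

```cpp
#include <cassert>

constexpr int kLanesPerWave = 64;

struct WaveRatio
{
    int load_waves;
    int math_waves;
};

struct ThreadCounts
{
    int tile_load;         // threads in the load group (TileLoad)
    int tile_math;         // threads in the math group (TileMath)
    int block_size;        // BlockSize = TileMathThreadGroupSize
    int launch_block_size; // LaunchBlockSize = TileLoad + TileMath
};

constexpr ThreadCounts CountsFor(WaveRatio r)
{
    return ThreadCounts{r.load_waves * kLanesPerWave,
                        r.math_waves * kLanesPerWave,
                        r.math_waves * kLanesPerWave,
                        (r.load_waves + r.math_waves) * kLanesPerWave};
}

// The two ratios listed in this PR.
static_assert(CountsFor({4, 4}).launch_block_size == 512, "(4,4) launches 512 threads");
static_assert(CountsFor({4, 2}).launch_block_size == 384, "(4,2) launches 384 threads");
```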

**Instance coverage (F16 and BF16 symmetric):**

| Ratio | Tiles | Layouts | ConvSpecs |
|-------|-------|---------|-----------|
| (4,4) | M128×N128, M64×N64, M128×N64, M64×N128 | 2D NHWGC, 3D NDHWGC | Default, Filter1x1Stride1Pad0 |
| (4,2) | M64×N64, M128×N64, M64×N128 | 2D NHWGC | Default, Filter1x1Stride1Pad0 |

**Existing wavelet model fixes:**
- `BlockSize` corrected from `math::max(TileLoad, TileMath)` to
`TileMathThreadGroupSize` in the flat-GEMM wavelet device op and
gridwise kernel
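
A short sketch of why the two expressions differ. The function names are hypothetical, and the thread-group sizes are borrowed from the ratios in this PR purely as illustrative inputs; the PR does not say which sizes the flat-GEMM instances use.

```cpp
#include <cassert>

// Previous definition: the larger of the two thread-group sizes.
constexpr int OldBlockSize(int tile_load, int tile_math)
{
    return tile_load > tile_math ? tile_load : tile_math;
}

// Corrected definition: only the math thread group, since BlockSize drives
// the MFMA wave assignment.
constexpr int NewBlockSize(int /*tile_load*/, int tile_math)
{
    return tile_math;
}

// Equal group sizes: both definitions agree.
static_assert(OldBlockSize(256, 256) == NewBlockSize(256, 256), "no change at 256/256");
// More load threads than math threads: the old value overstates the math group.
static_assert(OldBlockSize(256, 128) == 256 && NewBlockSize(256, 128) == 128,
              "old definition is 2x too large at 256/128");
```

The two definitions only diverge when the load group is larger than the math group, so the old expression would go unnoticed with symmetric group sizes.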

## Test Plan

- `test_grouped_convnd_bwd_weight` GTest: 34 hardcoded test cases
covering 1D/2D/3D, F16/BF16, G=1/2/16, various spatial sizes
- Performance benchmark: all 37 RetinaNet bwd_weight shapes on gfx950

```bash
ninja -C build test_grouped_convnd_bwd_weight
./build/bin/test_grouped_convnd_bwd_weight
```

## Test Result

**Correctness:** 34/34 GTest cases passed (F16/BF16 × 1D/2D/3D ×
Default/Filter1x1Stride1Pad0 × various G/N/K/C combinations).

**Performance:** Wavelet is the fastest overall instance on 12/37
RetinaNet shapes — all G=1, 3×3 convolutions with C=256 (the VALU-heavy
target shapes):

| Shape | Uplift vs best baseline |
|-------|------------------------|
| K=36, 7×7 | 1.91x |
| K=36, 100×100 | 1.60x |
| K=36, 13×13 | 1.43x |
| K=36, 25×25 | 1.38x |
| K=36, 50×50 | 1.38x |
| K=256, 100×100 | 1.24x |
| K=256, 13×13, s=2 | 1.20x |
| K=256, 25×25, s=2 | 1.20x |
| K=256, 7×7 | 1.17x |
| K=256, 13×13 | 1.13x |
| K=2376, 50×50 | 1.05x |
| K=2376, 100×100 | 1.06x |

Wavelet does not win on the remaining 25/37 shapes: 1×1 convolutions
(where the explicit kernel does a host-side transform), grouped
convolutions with small per-group channels, and shapes where standard
CShuffleV3 already amortizes the VALU overhead.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: jakpiase <jakpia21@gmail.com>
2026-05-18 17:46:01 +02:00