Johannes Graner
3727d5220a
[rocm-libraries] ROCm/rocm-libraries#5652 (commit 7dc7d1d)
[CK Conv] Wavelet gemm pipeline for bwd_weight convolution (#5652)
## Motivation
In the current CShuffleV3 backward weight kernel, the in-kernel
conv-to-GEMM transform generates significant INT32 VALU pressure per
MFMA instruction. On VALU-heavy shapes (e.g., G=1, 3×3, C=256), these
index computation ops compete with MFMA for VALU issue slots, creating a
bottleneck that cannot be resolved by pipeline prefetching alone.
This PR adds a wave-specialized ("wavelet") convolution backward weight
kernel that splits workgroup threads into two roles:
- **Load waves**: conv-to-GEMM address computation + global memory loads
+ LDS writes (all VALU/VMEM)
- **Math waves**: LDS reads + MFMA + CShuffle epilogue (no index
computation)
By separating the two instruction classes onto different waves,
index-computation VALU work and MFMA are issued from independent
instruction streams and run on their respective functional units
without contending with each other.
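The role split described above can be sketched on the host as a simple mapping from thread ID to role. This is a minimal illustration, not the PR's kernel code: the names `WaveRole`, `role_of`, and `work_for` are hypothetical, and the thread-ID layout (math threads first, load threads after) is an assumption.

```cpp
#include <cassert>

// Illustrative only: names and thread-ID layout are assumptions,
// not taken from the PR.
enum class WaveRole { Load, Math };

// Assumed layout: math threads take IDs [0, math_threads),
// load threads take the IDs that follow.
constexpr WaveRole role_of(int thread_id, int math_threads)
{
    assert(thread_id >= 0 && math_threads > 0);
    return thread_id < math_threads ? WaveRole::Math : WaveRole::Load;
}

// Host-side stand-in for the kernel body: returns a label naming the
// instruction classes each role would issue, per the description above.
constexpr const char* work_for(WaveRole role)
{
    return role == WaveRole::Load
               ? "index computation + global loads + LDS writes"
               : "LDS reads + MFMA + CShuffle epilogue";
}

static_assert(role_of(0, 256) == WaveRole::Math, "first thread is a math thread");
static_assert(role_of(255, 256) == WaveRole::Math, "last math thread");
static_assert(role_of(256, 256) == WaveRole::Load, "first load thread");
```

In a real kernel the branch on the role would be uniform within a wave, since each group size is a multiple of the wavefront size.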
## Technical Details
**Core kernel (new files):**
- `gridwise_gemm_xdl_waveletmodel_cshuffle_conv_v3.hpp` —
wave-specialized gridwise GEMM for conv bwd weight (2-way split: load +
math)
- `device_grouped_conv_bwd_weight_xdl_waveletmodel_cshuffle_v3.hpp` —
device op following CShuffleV3 patterns; `BlockSize =
TileMathThreadGroupSize` for MFMA wave assignment, `LaunchBlockSize =
TileLoad + TileMath` for kernel launch
**Wave pipeline (modified):**
- `gridwise_gemm_waveletmodel.hpp` — load/math wave pipeline structs
with `sched_group_barrier` scheduling hints to front-load VMEM reads
before address-advance VALU
**Two wave ratios:**
- **(4,4)**: 256 load + 256 math = 512 threads (8 waves). Best on large
shapes.
- **(4,2)**: 256 load + 128 math = 384 threads (6 waves). Best on small
shapes (fewer sync barriers, denser MFMA per math wave).
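The thread and wave counts for the two ratios follow from the `BlockSize` and `LaunchBlockSize` definitions above. The sketch below reproduces that arithmetic; `WaveRatio` and the helper functions are hypothetical names, and the wavefront size of 64 is inferred from the PR's own numbers (512 threads = 8 waves).

```cpp
#include <cassert>

// Inferred from the PR's numbers: 512 threads correspond to 8 waves.
constexpr int kWaveSize = 64;

// Hypothetical helper type: thread counts for the two roles.
struct WaveRatio
{
    int load_threads;
    int math_threads;
};

// LaunchBlockSize = TileLoad + TileMath (threads actually launched).
constexpr int launch_block_size(WaveRatio r) { return r.load_threads + r.math_threads; }

// BlockSize = TileMathThreadGroupSize (threads used for MFMA wave assignment).
constexpr int math_block_size(WaveRatio r) { return r.math_threads; }

constexpr int wave_count(WaveRatio r)
{
    assert(launch_block_size(r) % kWaveSize == 0);
    return launch_block_size(r) / kWaveSize;
}

constexpr WaveRatio kRatio44{256, 256}; // (4,4)
constexpr WaveRatio kRatio42{256, 128}; // (4,2)

static_assert(launch_block_size(kRatio44) == 512 && wave_count(kRatio44) == 8, "(4,4)");
static_assert(launch_block_size(kRatio42) == 384 && wave_count(kRatio42) == 6, "(4,2)");
```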
**Instance coverage (F16 and BF16 symmetric):**
| Ratio | Tiles | Layouts | ConvSpecs |
|-------|-------|---------|-----------|
| (4,4) | M128×N128, M64×N64, M128×N64, M64×N128 | 2D NHWGC, 3D NDHWGC | Default, Filter1x1Stride1Pad0 |
| (4,2) | M64×N64, M128×N64, M64×N128 | 2D NHWGC | Default, Filter1x1Stride1Pad0 |
**Existing wavelet model fixes:**
- `BlockSize` corrected from `math::max(TileLoad, TileMath)` to
`TileMathThreadGroupSize` in the flat-GEMM wavelet device op and
gridwise kernel
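To see why the two `BlockSize` definitions differ, compare them when the load and math groups have unequal sizes. This is only an arithmetic sketch with hypothetical function names; the PR does not state which thread counts the flat-GEMM wavelet instances use, so 256 load and 128 math threads are taken as an example.

```cpp
#include <algorithm>
#include <cassert>

// Previous definition: math::max(TileLoad, TileMath).
constexpr int old_block_size(int load_threads, int math_threads)
{
    assert(load_threads > 0 && math_threads > 0);
    return std::max(load_threads, math_threads);
}

// Corrected definition: TileMathThreadGroupSize only.
constexpr int new_block_size(int load_threads, int math_threads)
{
    assert(load_threads > 0 && math_threads > 0);
    return math_threads;
}

// Equal group sizes: both definitions agree.
static_assert(old_block_size(256, 256) == new_block_size(256, 256), "no change");

// More load than math threads: the old definition overcounts the math group.
static_assert(old_block_size(256, 128) == 256, "old value follows the load group");
static_assert(new_block_size(256, 128) == 128, "new value follows the math group");
```

The two definitions diverge only when there are more load threads than math threads, which is the case where MFMA work would otherwise be laid out for threads that are not math threads.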
## Test Plan
- `test_grouped_convnd_bwd_weight` GTest: 34 hardcoded test cases
covering 1D/2D/3D, F16/BF16, G=1/2/16, various spatial sizes
- Performance benchmark: all 37 RetinaNet bwd_weight shapes on gfx950
```bash
ninja -C build test_grouped_convnd_bwd_weight
./build/bin/test_grouped_convnd_bwd_weight
```
## Test Result
**Correctness:** 34/34 GTest cases passed (F16/BF16 × 1D/2D/3D ×
Default/Filter1x1Stride1Pad0 × various G/N/K/C combinations).
**Performance:** Wavelet is the fastest overall instance on 12/37
RetinaNet shapes — all G=1, 3×3 convolutions with C=256 (the VALU-heavy
target shapes):
| Shape | Uplift vs best baseline |
|-------|------------------------|
| K=36, 7×7 | 1.91x |
| K=36, 100×100 | 1.60x |
| K=36, 13×13 | 1.43x |
| K=36, 25×25 | 1.38x |
| K=36, 50×50 | 1.38x |
| K=256, 100×100 | 1.24x |
| K=256, 13×13, s=2 | 1.20x |
| K=256, 25×25, s=2 | 1.20x |
| K=256, 7×7 | 1.17x |
| K=256, 13×13 | 1.13x |
| K=2376, 50×50 | 1.05x |
| K=2376, 100×100 | 1.06x |
Wavelet does not win on the remaining 25/37 shapes: 1×1 convolutions
(where the explicit kernel does a host-side transform), grouped
convolutions with small per-group channels, and shapes where standard
CShuffleV3 already amortizes VALU overhead.
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
---------
Co-authored-by: jakpiase <jakpia21@gmail.com>
2026-05-18 17:46:01 +02:00