composable_kernel/example/20_grouped_conv_bwd_weight
Johannes Graner 3727d5220a [rocm-libraries] ROCm/rocm-libraries#5652 (commit 7dc7d1d)
[CK Conv] Wavelet gemm pipeline for bwd_weight convolution (#5652)

## Motivation

In the current CShuffleV3 backward weight kernel, the in-kernel
conv-to-GEMM transform generates significant INT32 VALU pressure per
MFMA instruction. On VALU-heavy shapes (e.g., G=1, 3×3, C=256), these
index computation ops compete with MFMA for VALU issue slots, creating a
bottleneck that cannot be resolved by pipeline prefetching alone.

This PR adds a wave-specialized ("wavelet") convolution backward weight
kernel that splits workgroup threads into two roles:
- **Load waves**: conv-to-GEMM address computation, global memory loads,
and LDS writes (all VALU/VMEM)
- **Math waves**: LDS reads, MFMA, and the CShuffle epilogue (no index
computation)

By physically separating the two instruction classes onto different
waves, VALU and MFMA execute on different hardware functional units
without contention.

## Technical Details

**Core kernel (new files):**
- `gridwise_gemm_xdl_waveletmodel_cshuffle_conv_v3.hpp` —
wave-specialized gridwise GEMM for conv bwd weight (2-way split: load +
math)
- `device_grouped_conv_bwd_weight_xdl_waveletmodel_cshuffle_v3.hpp` —
device op following CShuffleV3 patterns; `BlockSize =
TileMathThreadGroupSize` for MFMA wave assignment, `LaunchBlockSize =
TileLoad + TileMath` for kernel launch

**Wave pipeline (modified):**
- `gridwise_gemm_waveletmodel.hpp` — load/math wave pipeline structs
with `sched_group_barrier` scheduling hints to front-load VMEM reads
before address-advance VALU

**Two wave ratios:**
- **(4,4)**: 256 load + 256 math = 512 threads (8 waves). Best on large
shapes.
- **(4,2)**: 256 load + 128 math = 384 threads (6 waves). Best on small
shapes (fewer sync barriers, denser MFMA per math wave).

**Instance coverage (F16 and BF16 symmetric):**

| Ratio | Tiles | Layouts | ConvSpecs |
|-------|-------|---------|-----------|
| (4,4) | M128×N128, M64×N64, M128×N64, M64×N128 | 2D NHWGC, 3D NDHWGC | Default, Filter1x1Stride1Pad0 |
| (4,2) | M64×N64, M128×N64, M64×N128 | 2D NHWGC | Default, Filter1x1Stride1Pad0 |

**Existing wavelet model fixes:**
- `BlockSize` corrected from `math::max(TileLoad, TileMath)` to
`TileMathThreadGroupSize` in the flat-GEMM wavelet device op and
gridwise kernel

## Test Plan

- `test_grouped_convnd_bwd_weight` GTest: 34 hardcoded test cases
covering 1D/2D/3D, F16/BF16, G=1/2/16, various spatial sizes
- Performance benchmark: all 37 RetinaNet bwd_weight shapes on gfx950

```bash
ninja -C build test_grouped_convnd_bwd_weight
./build/bin/test_grouped_convnd_bwd_weight
```

## Test Result

**Correctness:** 34/34 GTest cases passed (F16/BF16 × 1D/2D/3D ×
Default/Filter1x1Stride1Pad0 × various G/N/K/C combinations).

**Performance:** Wavelet is the fastest overall instance on 12/37
RetinaNet shapes — all G=1, 3×3 convolutions with C=256 (the VALU-heavy
target shapes):

| Shape | Uplift vs best baseline |
|-------|------------------------|
| K=36, 7×7 | 1.91x |
| K=36, 100×100 | 1.60x |
| K=36, 13×13 | 1.43x |
| K=36, 25×25 | 1.38x |
| K=36, 50×50 | 1.38x |
| K=256, 100×100 | 1.24x |
| K=256, 13×13, s=2 | 1.20x |
| K=256, 25×25, s=2 | 1.20x |
| K=256, 7×7 | 1.17x |
| K=256, 13×13 | 1.13x |
| K=2376, 50×50 | 1.05x |
| K=2376, 100×100 | 1.06x |

Where wavelet does not win (25/37): 1×1 convolutions (explicit kernel
does host-side transform), grouped convolutions with small per-group
channels, and shapes where standard CShuffleV3 already amortizes VALU
overhead.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: jakpiase <jakpia21@gmail.com>
2026-05-18 17:46:01 +02:00

Grouped Convolution Backward Pass for Weights

This example demonstrates the backward weight pass for a grouped convolution, often denoted as grouped_conv_bwd_weight. This operation is essential for training neural networks that use grouped or depthwise convolutions, such as ResNeXt, MobileNets, and EfficientNets. Its purpose is to compute the gradient of the loss function with respect to the convolution's filter weights, which is then used by an optimizer (like SGD or Adam) to update the model's parameters.

Mathematical Formulation

The backward weight pass computes the gradient $\frac{\partial L}{\partial W}$, given the input tensor from the forward pass, In, and the gradient from the subsequent layer, dL/dOut.

For a single group $g$, the operation is mathematically equivalent to a convolution between the input tensor for that group, $\text{In}_{[g]}$, and the output gradient tensor for that group, $\frac{\partial L}{\partial \text{Out}_{[g]}}$.

$$\frac{\partial L}{\partial W_{[g]}} = \text{In}_{[g]} \star \frac{\partial L}{\partial \text{Out}_{[g]}}$$

This operation correlates the input activations with the output error signals to determine how each weight should be adjusted to reduce the overall loss. The total gradient dL/dW is the collection of gradients for all G groups.
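
Written out per element, the correlation can be sketched as follows. The formula assumes a 2D convolution; the stride $s$, dilation $d$, padding $p$, and the index names $n, k, c, y, x, h_o, w_o$ are symbols introduced here for illustration and do not appear elsewhere in this document. Input positions that fall outside the tensor are treated as zero.

```latex
\frac{\partial L}{\partial W_{[g]}}[k, c, y, x]
  = \sum_{n} \sum_{h_o} \sum_{w_o}
    \frac{\partial L}{\partial \text{Out}_{[g]}}[n, k, h_o, w_o] \;
    \text{In}_{[g]}[n, c,\; h_o s + y d - p,\; w_o s + x d - p]
```

Here $k$ and $c$ range over the output and input channels of group $g$ only, which is why the $G$ groups can be processed independently.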

Algorithmic Strategy: Implicit Grouped GEMM

This operation is a perfect candidate for the Grouped GEMM primitive. The convolution for each of the G groups is independently transformed into a GEMM problem, and all G GEMMs are executed in a single kernel launch.

For each group g:

  1. Input to Columns (im2col): The input tensor $\text{In}_{[g]}$ is logically unrolled into a matrix $\text{In}'_{[g]}$. This is the same im2col transformation used in the forward pass. This matrix becomes the "B" matrix in the GEMM, the operand that appears transposed in step 3.

  2. Output Gradient Reshaping: The output gradient tensor $\text{dL/dOut}_{[g]}$ is logically reshaped into a matrix $(\text{dL/dOut})'_{[g]}$. This matrix becomes the "A" matrix in the GEMM.

  3. Implicit Grouped GEMM: The weight gradient $\text{dL/dW}_{[g]}$ is computed by a single GEMM: $(\text{dL/dW})'_{[g]} = (\text{dL/dOut})'_{[g]} \times (\text{In}'_{[g]})^T$

The key to performance is that this is executed as a Grouped GEMM. The DeviceGroupedConvBwdWeight interface takes the G independent problems and maps them to a DeviceGroupedGemm kernel. This kernel schedules the G independent GEMMs across the GPU's compute units. The im2col transformation is performed implicitly; the GEMM kernel reads data directly from the original In and dL/dOut tensors in the correct pattern, avoiding the materialization of large intermediate matrices.

This approach is highly efficient as it leverages the task-parallel nature of the grouped convolution and the computational efficiency of highly optimized GEMM kernels.

Source Code Organization

Build and Run

Prerequisites

Ensure the Composable Kernel library is built and installed.

    cd /path/to/composable_kernel/build
    make -j install

Build the Example

    cd /path/to/composable_kernel/example/20_grouped_conv_bwd_weight
    mkdir build && cd build

    cmake \
      -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
      -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
      ..

    make -j

Run the Example

    # Run the example with default settings
    ./grouped_conv_bwd_weight_xdl

    # Run with verification, data initialization, and timing
    ./grouped_conv_bwd_weight_xdl 1 2 1

Importance in Modern CNNs

Grouped and depthwise convolutions are the cornerstone of many efficient, state-of-the-art CNN architectures.

  • Parameter Efficiency: By not connecting every input channel to every output channel, grouped convolutions significantly reduce the number of weights in a layer, leading to smaller and faster models.
  • Depthwise Separable Convolutions: Used in MobileNets, EfficientNets, and Xception, these layers factorize a standard convolution into a depthwise convolution (a grouped convolution with G = C) and a pointwise convolution (1x1 conv). The backward pass for the depthwise part requires an efficient grouped_conv_bwd_weight implementation.
  • ResNeXt: This architecture introduced the "cardinality" dimension, which is simply the number of groups in a grouped convolution, demonstrating that increasing the number of groups can be more effective than increasing layer depth or width.

An optimized grouped_conv_bwd_weight kernel is therefore not an exotic feature but a critical requirement for training a wide range of modern and efficient deep learning models.