Group Normalization Backward

This example demonstrates the backward pass of Group Normalization. This operation computes the gradients of the loss with respect to the input and to the gamma (scale) and beta (shift) parameters of a group normalization layer. These gradients are required for training networks that use group normalization, which is common in computer vision applications where independence from the batch size is important.

Mathematical Formulation

The backward pass of group normalization involves computing gradients for three components: input X, scale parameter gamma, and shift parameter beta.

Given:

  • Input tensor X with shape [N, C, H, W]
  • Number of groups G, where C must be divisible by G (for example, C = 6 with G = 3 gives C/G = 2 channels per group)
  • Scale parameter gamma with shape [C]
  • Shift parameter beta with shape [C]
  • Output gradients dL/dY with shape [N, C, H, W]

From the forward pass, for each batch item n and group g (a minimal CPU sketch of these statistics follows this list):

  • Channels in group: S_g = \{c : c \text{ belongs to group } g\} where |S_g| = C/G
  • Mean: \mu_{ng} = \frac{1}{(C/G) \cdot H \cdot W} \sum_{c \in S_g} \sum_{h,w} X_{nchw}
  • Variance: \sigma_{ng}^2 = \frac{1}{(C/G) \cdot H \cdot W} \sum_{c \in S_g} \sum_{h,w} (X_{nchw} - \mu_{ng})^2
  • Normalized: \hat{X}_{nchw} = \frac{X_{nchw} - \mu_{ng}}{\sqrt{\sigma_{ng}^2 + \epsilon}} for c \in S_g
  • Output: Y_{nchw} = \gamma_c \cdot \hat{X}_{nchw} + \beta_c
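
As a point of reference, here is a minimal CPU sketch, not the CK kernel, that recomputes these per-(n, g) statistics for a dense NCHW tensor; the names (x, mean, inv_std, x_hat) are illustrative only:

#include <cmath>
#include <vector>

// Minimal CPU sketch (not the CK kernel; names are illustrative) that
// recomputes the per-(n, g) statistics the backward pass consumes.
// Tensors are dense NCHW in row-major order.
void groupnorm_forward_stats(const std::vector<float>& x,
                             int N, int C, int H, int W, int G, float eps,
                             std::vector<float>& mean,     // size N * G
                             std::vector<float>& inv_std,  // size N * G
                             std::vector<float>& x_hat)    // size N * C * H * W
{
    const int cpg = C / G;                           // channels per group, |S_g|
    const long m  = static_cast<long>(cpg) * H * W;  // elements per (n, g) group
    auto idx = [&](int n, int c, int h, int w) {
        return ((static_cast<long>(n) * C + c) * H + h) * W + w;
    };

    for (int n = 0; n < N; ++n)
        for (int g = 0; g < G; ++g) {
            double sum = 0.0, sq_sum = 0.0;
            for (int c = g * cpg; c < (g + 1) * cpg; ++c)
                for (int h = 0; h < H; ++h)
                    for (int w = 0; w < W; ++w) {
                        const double v = x[idx(n, c, h, w)];
                        sum    += v;
                        sq_sum += v * v;
                    }
            const double mu  = sum / m;
            const double var = sq_sum / m - mu * mu;   // E[x^2] - E[x]^2
            mean[n * G + g]    = static_cast<float>(mu);
            inv_std[n * G + g] = static_cast<float>(1.0 / std::sqrt(var + eps));

            // Normalized values x_hat = (x - mean) * inv_std for this (n, g) group.
            for (int c = g * cpg; c < (g + 1) * cpg; ++c)
                for (int h = 0; h < H; ++h)
                    for (int w = 0; w < W; ++w)
                        x_hat[idx(n, c, h, w)] =
                            (x[idx(n, c, h, w)] - mean[n * G + g]) * inv_std[n * G + g];
        }
}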

Gradient Computations

Gradient w.r.t. beta: \frac{\partial L}{\partial \beta_c} = \sum_{n,h,w} \frac{\partial L}{\partial Y_{nchw}}

Gradient w.r.t. gamma: \frac{\partial L}{\partial \gamma_c} = \sum_{n,h,w} \frac{\partial L}{\partial Y_{nchw}} \cdot \hat{X}_{nchw}

Gradient w.r.t. input (most complex): For channel c in group g, with group size m = |S_g| \cdot H \cdot W = (C/G) \cdot H \cdot W: \frac{\partial L}{\partial X_{nchw}} = \frac{1}{\sqrt{\sigma_{ng}^2 + \epsilon}} \left[ \gamma_c \frac{\partial L}{\partial Y_{nchw}} - \frac{1}{m}\left(\sum_{c' \in S_g} \sum_{h',w'} \gamma_{c'} \frac{\partial L}{\partial Y_{nc'h'w'}} + \hat{X}_{nchw} \sum_{c' \in S_g} \sum_{h',w'} \gamma_{c'} \frac{\partial L}{\partial Y_{nc'h'w'}} \hat{X}_{nc'h'w'}\right) \right]

Note that, unlike the gamma and beta gradients, the two inner sums run only over the channels and spatial positions of group g for the same sample n; they are not reductions over the batch.
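
The following minimal CPU sketch, again with illustrative names rather than CK identifiers, implements the three gradient formulas above; the normalized input x_hat and the reciprocal standard deviation inv_std = 1/sqrt(variance + epsilon) are assumed to come from the forward pass:

#include <vector>

// Minimal CPU reference sketch for the three gradients above (not the CK kernels).
// All buffer names are illustrative; layout is dense NCHW, row-major.
void groupnorm_bwd_reference(const std::vector<float>& dy,      // N*C*H*W, dL/dY
                             const std::vector<float>& x_hat,   // N*C*H*W, normalized input
                             const std::vector<float>& inv_std, // N*G, 1/sqrt(var + eps)
                             const std::vector<float>& gamma,   // C
                             int N, int C, int H, int W, int G,
                             std::vector<float>& dx,            // N*C*H*W
                             std::vector<float>& dgamma,        // C
                             std::vector<float>& dbeta)         // C
{
    const int cpg = C / G;
    const long m  = static_cast<long>(cpg) * H * W;
    auto idx = [&](int n, int c, int h, int w) {
        return ((static_cast<long>(n) * C + c) * H + h) * W + w;
    };

    // dbeta[c] and dgamma[c]: reduce over n, h, w for each channel.
    for (int c = 0; c < C; ++c) {
        double db = 0.0, dg = 0.0;
        for (int n = 0; n < N; ++n)
            for (int h = 0; h < H; ++h)
                for (int w = 0; w < W; ++w) {
                    const long i = idx(n, c, h, w);
                    db += dy[i];
                    dg += dy[i] * x_hat[i];
                }
        dbeta[c]  = static_cast<float>(db);
        dgamma[c] = static_cast<float>(dg);
    }

    // dx: per-(n, g) sums of gamma*dy and gamma*dy*x_hat, then an elementwise pass.
    for (int n = 0; n < N; ++n)
        for (int g = 0; g < G; ++g) {
            double s1 = 0.0, s2 = 0.0;
            for (int c = g * cpg; c < (g + 1) * cpg; ++c)
                for (int h = 0; h < H; ++h)
                    for (int w = 0; w < W; ++w) {
                        const long i = idx(n, c, h, w);
                        s1 += gamma[c] * dy[i];
                        s2 += gamma[c] * dy[i] * x_hat[i];
                    }
            for (int c = g * cpg; c < (g + 1) * cpg; ++c)
                for (int h = 0; h < H; ++h)
                    for (int w = 0; w < W; ++w) {
                        const long i = idx(n, c, h, w);
                        dx[i] = static_cast<float>(
                            inv_std[n * G + g] *
                            (gamma[c] * dy[i] - (s1 + x_hat[i] * s2) / m));
                    }
        }
}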

Algorithmic Strategy: Multi-Stage Group-wise Gradient Computation

The backward pass requires coordinated computation across groups with multiple reduction operations.

  1. Pass 1: Compute Gamma and Beta Gradients

    • Grid Scheduling: Parallelize over channels (C dimension).
    • Reduction per Channel: For each channel c, reduce across N, H, W dimensions:
      • grad_beta[c] = sum(grad_output[n, c, h, w]) over all n, h, w
      • grad_gamma[c] = sum(grad_output[n, c, h, w] * x_normalized[n, c, h, w]) over all n, h, w
  2. Pass 2: Compute Group-wise Intermediate Values

    • Grid Scheduling: Parallelize over (N, G) pairs.
    • Group Reduction: For each (n, g) pair, reduce over the group's channels and spatial positions of that sample:
      • Sum gamma[c'] * grad_output[n, c', h, w] over all c' in group g and all h, w
      • Sum gamma[c'] * grad_output[n, c', h, w] * x_normalized[n, c', h, w] over the same range
      • These two per-(n, g) sums are the bracketed reductions needed by the input gradient formula
  3. Pass 3: Compute Input Gradients

    • Grid Scheduling: Parallelize over input tensor elements.
    • Per-Element Computation: For each (n, c, h, w):
      • Identify which group g channel c belongs to
      • Read the group-wise intermediate sums computed in Pass 2 (see the index sketch after this list)
      • Apply the input gradient formula above
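
The group lookup and the addressing of the per-(n, g) intermediates in Pass 3 reduce to simple index arithmetic; a short sketch with hypothetical helper names:

#include <cassert>

// Illustrative helpers (not CK code): map a channel to its group and locate
// that group's Pass-2 intermediates in a buffer of N * G entries.
inline int group_of_channel(int c, int C, int G) {
    assert(C % G == 0);
    return c / (C / G);   // channels [g*C/G, (g+1)*C/G) belong to group g
}

inline long intermediate_index(int n, int c, int C, int G) {
    return static_cast<long>(n) * G + group_of_channel(c, C, G);
}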

Source Code Organization

  • groupnorm_bwd_xdl.cpp: The main example file. It sets up the forward-pass results, the output gradients, and the group configuration, and instantiates the DeviceGroupnormBwd operation.
  • ../../include/ck/tensor_operation/gpu/device/device_groupnorm_bwd.hpp: The high-level device interface for group normalization backward operations.
  • The underlying implementation coordinates multiple reduction and computation stages to efficiently handle the group-wise structure of the gradients.

Build and Run

Prerequisites

Ensure the Composable Kernel library is built and installed.

cd /path/to/composable_kernel/build
make -j install

Build the Example

cd /path/to/composable_kernel/example/54_groupnorm_bwd
mkdir build && cd build

cmake \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
  ..

make -j

Run the Example

# Run the example with default settings
./groupnorm_bwd_xdl

# Run with verification, data initialization, and timing
./groupnorm_bwd_xdl 1 2 1

Comparison with Other Normalization Backward Passes

Normalization Type   Gradient Scope                         Complexity   Memory Pattern
BatchNorm            Across batch for each channel          Medium       Channel-wise reductions
LayerNorm            Across features for each item          Medium       Per-sample reductions
GroupNorm            Across group for each (batch, group)   High         Group-wise reductions
InstanceNorm         Per channel per sample                 Low          Independent computations

Applications in Computer Vision

Group normalization backward is particularly important for:

  • Small Batch Training: When batch sizes are too small for effective batch normalization
  • Transfer Learning: Fine-tuning pre-trained models with different batch sizes
  • Object Detection: Models like YOLO and R-CNN that benefit from batch-size independent normalization
  • Segmentation Networks: Dense prediction tasks where normalization stability is crucial
  • Style Transfer: Applications where group-wise feature normalization helps preserve style information

The group-wise structure provides a balance between the stability of batch normalization and the flexibility of layer normalization, making it valuable for many computer vision applications.