# Group Normalization Backward
This example demonstrates the **backward pass of Group Normalization**. It computes the gradients of the loss with respect to the input and the `gamma` and `beta` parameters of a group normalization layer. These gradients are needed to train networks that use group normalization, which is common in computer vision because its statistics do not depend on the batch size.
## Mathematical Formulation
The backward pass of group normalization involves computing gradients for three components: input `X`, scale parameter `gamma`, and shift parameter `beta`.
Given:
- Input tensor `X` with shape `[N, C, H, W]`
- Number of groups `G` (where `C` must be divisible by `G`)
- Scale parameter `gamma` with shape `[C]`
- Shift parameter `beta` with shape `[C]`
- Output gradients `dL/dY` with shape `[N, C, H, W]`
From the forward pass, for each batch item `n` and group `g`:
- Channels in group: $S_g = \{c : c \text{ belongs to group } g\}$ where $|S_g| = C/G$
- Mean: $\mu_{ng} = \frac{1}{(C/G) \cdot H \cdot W} \sum_{c \in S_g} \sum_{h,w} X_{nchw}$
- Variance: $\sigma_{ng}^2 = \frac{1}{(C/G) \cdot H \cdot W} \sum_{c \in S_g} \sum_{h,w} (X_{nchw} - \mu_{ng})^2$
- Normalized: $\hat{X}_{nchw} = \frac{X_{nchw} - \mu_{ng}}{\sqrt{\sigma_{ng}^2 + \epsilon}}$ for $c \in S_g$
- Output: $Y_{nchw} = \gamma_c \cdot \hat{X}_{nchw} + \beta_c$
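The forward definitions above can be sketched as a plain-Python reference. This is an illustrative sketch, not the Composable Kernel implementation: the function name `groupnorm_forward` is hypothetical, tensors are nested lists indexed `x[n][c][h][w]`, and channels are assumed to be assigned to groups in contiguous blocks of `C/G`.

```python
import math

def groupnorm_forward(x, gamma, beta, G, eps=1e-5):
    """Reference forward pass on nested lists x[n][c][h][w] (hypothetical helper)."""
    N, C, H, W = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    cpg = C // G                 # channels per group, |S_g| (assumes C % G == 0)
    M = cpg * H * W              # number of elements reduced per (n, g)
    mean = [[0.0] * G for _ in range(N)]
    inv_std = [[0.0] * G for _ in range(N)]
    y = [[[[0.0] * W for _ in range(H)] for _ in range(C)] for _ in range(N)]
    for n in range(N):
        for g in range(G):
            chans = range(g * cpg, (g + 1) * cpg)   # assumed contiguous grouping
            vals = [x[n][c][h][w] for c in chans for h in range(H) for w in range(W)]
            mu = sum(vals) / M
            var = sum((v - mu) ** 2 for v in vals) / M   # biased variance (1/M)
            mean[n][g] = mu
            inv_std[n][g] = 1.0 / math.sqrt(var + eps)
            for c in chans:
                for h in range(H):
                    for w in range(W):
                        x_hat = (x[n][c][h][w] - mu) * inv_std[n][g]
                        y[n][c][h][w] = gamma[c] * x_hat + beta[c]
    return y, mean, inv_std

# N=1, C=2, H=1, W=2, with a single group covering both channels.
x = [[[[1.0, 2.0]], [[3.0, 4.0]]]]
y, mean, inv_std = groupnorm_forward(x, gamma=[1.0, 1.0], beta=[0.0, 0.0], G=1, eps=0.0)
```

Returning `mean` and `inv_std` alongside `y` mirrors the usual practice of saving the per-group statistics for reuse in the backward pass.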
### Gradient Computations
**Gradient w.r.t. beta**:
$\frac{\partial L}{\partial \beta_c} = \sum_{n,h,w} \frac{\partial L}{\partial Y_{nchw}}$
**Gradient w.r.t. gamma**:
$\frac{\partial L}{\partial \gamma_c} = \sum_{n,h,w} \frac{\partial L}{\partial Y_{nchw}} \cdot \hat{X}_{nchw}$
**Gradient w.r.t. input** (most complex):
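For channel `c` in group `g`, the scale $\gamma_{c'}$ varies across the channels of the group, so it must stay inside the group sums:
$\frac{\partial L}{\partial X_{nchw}} = \frac{1}{\sqrt{\sigma_{ng}^2 + \epsilon}} \left[ \gamma_c \frac{\partial L}{\partial Y_{nchw}} - \frac{1}{|S_g| \cdot H \cdot W}\left(\sum_{c' \in S_g} \sum_{h',w'} \gamma_{c'} \frac{\partial L}{\partial Y_{nc'h'w'}} + \hat{X}_{nchw} \sum_{c' \in S_g} \sum_{h',w'} \gamma_{c'} \frac{\partial L}{\partial Y_{nc'h'w'}} \hat{X}_{nc'h'w'}\right) \right]$
Both sums run over a single sample `n`. They are therefore not the `beta` and `gamma` gradients above, which also sum over `N` and carry no $\gamma$ weighting.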
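The input gradient can be checked numerically on a single `(n, g)` slab, since each slab is normalized independently. The sketch below is an assumption-laden illustration rather than library code: `group_dx` and `loss` are hypothetical helpers, the slab is flattened to a list of length `M`, and each upstream gradient is weighted by its channel's `gamma` before the two group reductions. The scalar loss `sum(dy * y)` is chosen only so that its derivative with respect to `y` equals `dy`.

```python
import math

def group_dx(x, gamma, dy, eps=1e-5):
    """Input gradient for ONE (n, g) slab flattened to length M (hypothetical helper).

    gamma[i] is the scale of the channel that element i belongs to.
    """
    M = len(x)
    mu = sum(x) / M
    var = sum((v - mu) ** 2 for v in x) / M
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(v - mu) * inv_std for v in x]
    d_xhat = [gamma[i] * dy[i] for i in range(M)]          # dL/dx_hat
    mean_d = sum(d_xhat) / M                               # first group reduction
    mean_dx = sum(d_xhat[i] * x_hat[i] for i in range(M)) / M   # second group reduction
    return [inv_std * (d_xhat[i] - mean_d - x_hat[i] * mean_dx) for i in range(M)]

def loss(x, gamma, beta, dy, eps=1e-5):
    """Scalar L = sum(dy * y), so that dL/dy == dy."""
    M = len(x)
    mu = sum(x) / M
    var = sum((v - mu) ** 2 for v in x) / M
    inv_std = 1.0 / math.sqrt(var + eps)
    return sum(dy[i] * (gamma[i] * (x[i] - mu) * inv_std + beta[i]) for i in range(M))

# Two channels with two elements each, sharing one group.
x = [0.5, -1.0, 2.0, 0.3]
gamma = [1.5, 1.5, 0.7, 0.7]
beta = [0.1, 0.1, -0.2, -0.2]
dy = [0.2, -0.4, 0.1, 0.9]
dx = group_dx(x, gamma, dy)

# Central finite differences as an independent check.
h = 1e-6
numeric = []
for i in range(len(x)):
    xp = list(x); xp[i] += h
    xm = list(x); xm[i] -= h
    numeric.append((loss(xp, gamma, beta, dy) - loss(xm, gamma, beta, dy)) / (2 * h))
max_err = max(abs(a - b) for a, b in zip(dx, numeric))
```

A useful side property: the entries of `dx` within one slab sum to zero, because shifting every input in a group by the same constant leaves the normalized output unchanged.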
## Algorithmic Strategy: Multi-Stage Group-wise Gradient Computation
The backward pass combines per-channel reductions (for the `gamma` and `beta` gradients) with per-group reductions (for the input gradient), so it is organized as three passes.
1. **Pass 1: Compute Gamma and Beta Gradients**
   - **Grid Scheduling**: Parallelize over channels (`C` dimension).
   - **Reduction per Channel**: For each channel `c`, reduce across `N`, `H`, `W` dimensions:
     - `grad_beta[c] = sum(grad_output[n, c, h, w])` over all `n, h, w`
     - `grad_gamma[c] = sum(grad_output[n, c, h, w] * x_normalized[n, c, h, w])` over all `n, h, w`
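2. **Pass 2: Compute Group-wise Intermediate Values**
   - **Grid Scheduling**: Parallelize over `(N, G)` pairs.
   - **Group Reduction**: For each `(n, g)` pair, reduce over the channels in group `g` and over `H`, `W`:
     - Sum `gamma[c] * grad_output[n, c, h, w]`
     - Sum `gamma[c] * grad_output[n, c, h, w] * x_normalized[n, c, h, w]`
     - These two per-sample sums feed the input gradient computation. The `grad_beta` and `grad_gamma` values from Pass 1 cannot stand in for them, because they are already summed over `N` and are not weighted by `gamma`.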
3. **Pass 3: Compute Input Gradients**
   - **Grid Scheduling**: Parallelize over input tensor elements.
   - **Per-Element Computation**: For each `(n, c, h, w)`:
     - Identify which group `g` channel `c` belongs to
     - Read the group-wise intermediate values from Pass 2
     - Apply the complex input gradient formula
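The three passes can be sketched on the CPU as follows. This is a minimal reference under stated assumptions, not the Composable Kernel kernels: `normalize` and `groupnorm_backward` are hypothetical names, `H` and `W` are flattened into one spatial axis `S`, channel groups are contiguous, and the Pass 2 reductions are taken per sample over `gamma`-weighted gradients, as the chain rule requires. Whether the library splits the work in exactly this way is not confirmed here.

```python
import math

def normalize(x, G, eps=1e-5):
    """Recompute x_hat and inv_std for x shaped [N][C][S] (normally saved by forward)."""
    N, C, S = len(x), len(x[0]), len(x[0][0])
    cpg = C // G
    M = cpg * S
    x_hat = [[[0.0] * S for _ in range(C)] for _ in range(N)]
    inv_std = [[0.0] * G for _ in range(N)]
    for n in range(N):
        for g in range(G):
            chans = range(g * cpg, (g + 1) * cpg)
            vals = [x[n][c][s] for c in chans for s in range(S)]
            mu = sum(vals) / M
            var = sum((v - mu) ** 2 for v in vals) / M
            inv_std[n][g] = 1.0 / math.sqrt(var + eps)
            for c in chans:
                for s in range(S):
                    x_hat[n][c][s] = (x[n][c][s] - mu) * inv_std[n][g]
    return x_hat, inv_std

def groupnorm_backward(dy, x_hat, inv_std, gamma, G):
    """Three-pass backward on nested lists shaped [N][C][S]."""
    N, C, S = len(dy), len(dy[0]), len(dy[0][0])
    cpg = C // G
    M = cpg * S

    # Pass 1: one reduction per channel over (n, s).
    dbeta = [sum(dy[n][c][s] for n in range(N) for s in range(S)) for c in range(C)]
    dgamma = [sum(dy[n][c][s] * x_hat[n][c][s] for n in range(N) for s in range(S))
              for c in range(C)]

    # Pass 2: two reductions per (n, g) over the group's channels and s.
    mean_d = [[0.0] * G for _ in range(N)]
    mean_dx = [[0.0] * G for _ in range(N)]
    for n in range(N):
        for g in range(G):
            chans = range(g * cpg, (g + 1) * cpg)
            mean_d[n][g] = sum(gamma[c] * dy[n][c][s]
                               for c in chans for s in range(S)) / M
            mean_dx[n][g] = sum(gamma[c] * dy[n][c][s] * x_hat[n][c][s]
                                for c in chans for s in range(S)) / M

    # Pass 3: element-wise, reading the two scalars of the element's group.
    dx = [[[inv_std[n][c // cpg] * (gamma[c] * dy[n][c][s]
                                    - mean_d[n][c // cpg]
                                    - x_hat[n][c][s] * mean_dx[n][c // cpg])
            for s in range(S)] for c in range(C)] for n in range(N)]
    return dx, dgamma, dbeta

# N=2, C=2, S=2, one group per sample.
x = [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [2.0, 2.0]]]
dy = [[[0.1, -0.2], [0.3, 0.4]], [[0.5, 0.0], [-0.1, 0.2]]]
gamma = [1.5, 0.7]
x_hat, inv_std = normalize(x, G=1)
dx, dgamma, dbeta = groupnorm_backward(dy, x_hat, inv_std, gamma, G=1)
```

Pass 1 and Pass 2 are independent of each other, so only Pass 3 has a true data dependency on an earlier stage.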
## Source Code Organization
- [`groupnorm_bwd_xdl.cpp`](./groupnorm_bwd_xdl.cpp): The main example file. It sets up the forward pass results, output gradients, group configuration, and instantiates the `DeviceGroupnormBwd` operation.
- [`../../include/ck/tensor_operation/gpu/device/device_groupnorm_bwd.hpp`](../../include/ck/tensor_operation/gpu/device/device_groupnorm_bwd.hpp): The high-level device interface for group normalization backward operations.
- The underlying implementation coordinates multiple reduction and computation stages to efficiently handle the group-wise structure of the gradients.
## Build and Run
### Prerequisites
Ensure the Composable Kernel library is built and installed.
```bash
cd /path/to/composable_kernel/build
make -j install
```
### Build the Example
```bash
cd /path/to/composable_kernel/example/54_groupnorm_bwd
mkdir build && cd build
cmake \
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
..
make -j
```
### Run the Example
```bash
# Run the example with default settings
./groupnorm_bwd_xdl
# Run with verification, data initialization, and timing
./groupnorm_bwd_xdl 1 2 1
```
## Comparison with Other Normalization Backward Passes
| Normalization Type | Gradient Scope | Complexity | Memory Pattern |
|-------------------|----------------|------------|----------------|
| **BatchNorm** | Across batch for each channel | Medium | Channel-wise reductions |
| **LayerNorm** | Across features for each item | Medium | Per-sample reductions |
| **GroupNorm** | Across group for each (batch, group) | High | Group-wise reductions |
| **InstanceNorm** | Per channel per sample | Low | Independent computations |
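The rows of the table differ mainly in which elements share statistics. As a small illustration (with a hypothetical helper `group_means` and tensors flattened to `[N][C][S]`), setting `G = 1` makes every channel of a sample share one mean, matching LayerNorm-style statistics over `(C, H, W)`, while `G = C` gives one mean per channel per sample, matching InstanceNorm.

```python
def group_means(x, G):
    """Per-(n, g) mean of x shaped [N][C][S], assuming contiguous channel groups."""
    N, C, S = len(x), len(x[0]), len(x[0][0])
    cpg = C // G
    return [[sum(x[n][c][s] for c in range(g * cpg, (g + 1) * cpg) for s in range(S))
             / (cpg * S) for g in range(G)] for n in range(N)]

x = [[[1.0, 3.0], [5.0, 7.0]]]        # N=1, C=2, S=2
one_group = group_means(x, G=1)       # one mean over all channels of the sample
per_channel = group_means(x, G=2)     # one mean per channel (G == C)
```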
## Applications in Computer Vision
Group normalization backward is particularly important for:
- **Small Batch Training**: When batch sizes are too small for effective batch normalization
- **Transfer Learning**: Fine-tuning pre-trained models with different batch sizes
- **Object Detection**: Models like YOLO and R-CNN that benefit from batch-size independent normalization
- **Segmentation Networks**: Dense prediction tasks where normalization stability is crucial
- **Style Transfer**: Applications where group-wise feature normalization helps preserve style information
The group-wise structure provides a balance between the stability of batch normalization and the flexibility of layer normalization, making it valuable for many computer vision applications.