# Group Normalization Backward

This example demonstrates the **backward pass of Group Normalization**. The operation computes the gradients of the loss with respect to the input and to the `gamma` and `beta` parameters of a group normalization layer. These gradients are required to train networks that use group normalization. Such networks are common in computer vision, where independence from the batch size matters.

## Mathematical Formulation

The backward pass of group normalization computes gradients for three components: the input `X`, the scale parameter `gamma`, and the shift parameter `beta`.

Given:

- Input tensor `X` with shape `[N, C, H, W]`
- Number of groups `G` (`C` must be divisible by `G`)
- Scale parameter `gamma` with shape `[C]`
- Shift parameter `beta` with shape `[C]`
- Output gradients `dL/dY` with shape `[N, C, H, W]`

From the forward pass, for each batch item `n` and group `g`:

- Channels in group: $S_g = \{c : c \text{ belongs to group } g\}$, where $|S_g| = C/G$ (conventionally the contiguous channels $g \cdot C/G, \dots, (g+1) \cdot C/G - 1$)
- Mean: $\mu_{ng} = \frac{1}{(C/G) \cdot H \cdot W} \sum_{c \in S_g} \sum_{h,w} X_{nchw}$
- Variance: $\sigma_{ng}^2 = \frac{1}{(C/G) \cdot H \cdot W} \sum_{c \in S_g} \sum_{h,w} (X_{nchw} - \mu_{ng})^2$
- Normalized: $\hat{X}_{nchw} = \frac{X_{nchw} - \mu_{ng}}{\sqrt{\sigma_{ng}^2 + \epsilon}}$ for $c \in S_g$
- Output: $Y_{nchw} = \gamma_c \cdot \hat{X}_{nchw} + \beta_c$

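
The statistics above can be made concrete with a small single-threaded reference sketch. It makes three assumptions:

- the tensor is a contiguous NCHW `float` tensor, with `H` and `W` flattened into `HW`;
- channel groups are contiguous;
- `GroupStats` and `groupnorm_forward_stats` are hypothetical names used for illustration, not part of the Composable Kernel API.

The sketch returns the per-`(n, g)` mean and `rstd` $= 1/\sqrt{\sigma_{ng}^2 + \epsilon}$ that the backward formulas reuse.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Per-(n, g) statistics produced by the forward pass (hypothetical type).
struct GroupStats {
    std::vector<float> mean;  // size N * G
    std::vector<float> rstd;  // size N * G, holds 1 / sqrt(var + eps)
};

// Reference (single-threaded) group-norm statistics for a contiguous NCHW
// tensor. HW = H * W; group g holds channels [g * C / G, (g + 1) * C / G).
GroupStats groupnorm_forward_stats(const std::vector<float>& x,
                                   std::size_t N, std::size_t C,
                                   std::size_t HW, std::size_t G, float eps) {
    assert(G > 0 && C % G == 0 && x.size() == N * C * HW);
    const std::size_t cpg = C / G;   // channels per group, |S_g|
    const std::size_t m = cpg * HW;  // elements reduced per (n, g)
    GroupStats st{std::vector<float>(N * G), std::vector<float>(N * G)};
    for (std::size_t n = 0; n < N; ++n) {
        for (std::size_t g = 0; g < G; ++g) {
            // In NCHW the channels of one group are adjacent, so the group is
            // a single contiguous run of m values.
            const float* p = x.data() + (n * C + g * cpg) * HW;
            double sum = 0.0;
            for (std::size_t i = 0; i < m; ++i) sum += p[i];
            const double mu = sum / static_cast<double>(m);
            double sq = 0.0;
            for (std::size_t i = 0; i < m; ++i) sq += (p[i] - mu) * (p[i] - mu);
            const double var = sq / static_cast<double>(m);  // biased variance
            st.mean[n * G + g] = static_cast<float>(mu);
            st.rstd[n * G + g] = static_cast<float>(1.0 / std::sqrt(var + eps));
        }
    }
    return st;
}
```

With contiguous groups in NCHW, each `(n, g)` group is one contiguous slab of `(C/G) * H * W` values. That is why the loop can walk a single pointer instead of gathering strided channels.
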
### Gradient Computations

**Gradient w.r.t. beta**:

$\frac{\partial L}{\partial \beta_c} = \sum_{n,h,w} \frac{\partial L}{\partial Y_{nchw}}$

**Gradient w.r.t. gamma**:

$\frac{\partial L}{\partial \gamma_c} = \sum_{n,h,w} \frac{\partial L}{\partial Y_{nchw}} \cdot \hat{X}_{nchw}$

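
A minimal reference sketch of these two per-channel reductions follows. It assumes:

- a contiguous NCHW `float` tensor, with `H` and `W` flattened into `HW`;
- contiguous channel groups;
- the per-`(n, g)` `mean` and `rstd` $= 1/\sqrt{\sigma_{ng}^2+\epsilon}$ saved by the forward pass.

`groupnorm_bwd_gamma_beta` is a hypothetical name, not a CK function. It recomputes $\hat{X}$ from `X` and the saved statistics instead of reading a stored copy.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference (single-threaded) per-channel reductions for dgamma and dbeta on a
// contiguous NCHW tensor. HW = H * W; mean/rstd are the per-(n, g) forward
// statistics (size N * G, rstd = 1 / sqrt(var + eps)).
void groupnorm_bwd_gamma_beta(const std::vector<float>& dy,
                              const std::vector<float>& x,
                              const std::vector<float>& mean,
                              const std::vector<float>& rstd,
                              std::size_t N, std::size_t C, std::size_t HW,
                              std::size_t G, std::vector<float>& dgamma,
                              std::vector<float>& dbeta) {
    assert(G > 0 && C % G == 0 && x.size() == N * C * HW && dy.size() == x.size());
    assert(mean.size() == N * G && rstd.size() == N * G);
    const std::size_t cpg = C / G;
    dgamma.assign(C, 0.0f);
    dbeta.assign(C, 0.0f);
    for (std::size_t c = 0; c < C; ++c) {
        const std::size_t g = c / cpg;  // group that channel c belongs to
        double sum_dy = 0.0, sum_dy_xhat = 0.0;
        for (std::size_t n = 0; n < N; ++n) {
            const float mu = mean[n * G + g];
            const float rs = rstd[n * G + g];
            for (std::size_t s = 0; s < HW; ++s) {
                const std::size_t i = (n * C + c) * HW + s;
                const float x_hat = (x[i] - mu) * rs;  // recomputed, not stored
                sum_dy += dy[i];
                sum_dy_xhat += dy[i] * x_hat;
            }
        }
        dbeta[c] = static_cast<float>(sum_dy);
        dgamma[c] = static_cast<float>(sum_dy_xhat);
    }
}
```

Accumulating in `double` is a choice made for the reference only. It keeps the check against a GPU result from being dominated by host-side rounding.
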
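**Gradient w.r.t. input** (most complex):

For channel `c` in group `g`, let $M = |S_g| \cdot H \cdot W$ be the number of elements that share the statistics $\mu_{ng}$ and $\sigma_{ng}^2$. Every element of the group contributes to those statistics. The gradient therefore includes two sums over $c' \in S_g$ and all $(h', w')$ of the same sample `n`:

$\frac{\partial L}{\partial X_{nchw}} = \frac{1}{\sqrt{\sigma_{ng}^2 + \epsilon}} \left[ \gamma_c \frac{\partial L}{\partial Y_{nchw}} - \frac{1}{M}\left(\sum_{c' \in S_g} \sum_{h',w'} \gamma_{c'} \frac{\partial L}{\partial Y_{nc'h'w'}} + \hat{X}_{nchw} \sum_{c' \in S_g} \sum_{h',w'} \gamma_{c'} \frac{\partial L}{\partial Y_{nc'h'w'}} \hat{X}_{nc'h'w'}\right) \right]$

These sums are taken per sample and weighted by $\gamma_{c'}$. They therefore cannot be replaced by $\sum_{c' \in S_g} \frac{\partial L}{\partial \beta_{c'}}$ and $\sum_{c' \in S_g} \frac{\partial L}{\partial \gamma_{c'}}$, which also sum over the batch dimension `N`.
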
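
The input gradient can also be written in a form that is affine in $X_{nchw}$. Start from $\partial L / \partial \hat{X}_{nchw} = \gamma_c\, \partial L / \partial Y_{nchw}$ and the standard normalization gradient. Write $r_{ng} = 1/\sqrt{\sigma_{ng}^2 + \epsilon}$ and $M = |S_g| \cdot H \cdot W$, and expand $\hat{X}_{nchw} = (X_{nchw} - \mu_{ng})\, r_{ng}$. The result needs only two per-`(n, g)` sums over raw inputs.

This is a derivation sketch. The symbols $D_{ng}$, $E_{ng}$, $A_{ng}$ and $B_{ng}$ are introduced here for illustration, and nothing is implied about which form a particular kernel evaluates. All sums run over $c \in S_g$ and all $(h, w)$ of sample $n$.

$$
D_{ng} = \sum_{c \in S_g} \sum_{h,w} \gamma_c \frac{\partial L}{\partial Y_{nchw}}, \qquad
E_{ng} = \sum_{c \in S_g} \sum_{h,w} \gamma_c \frac{\partial L}{\partial Y_{nchw}} X_{nchw}, \qquad
\sum_{c \in S_g} \sum_{h,w} \gamma_c \frac{\partial L}{\partial Y_{nchw}} \hat{X}_{nchw} = r_{ng}\left(E_{ng} - \mu_{ng} D_{ng}\right)
$$

$$
\frac{\partial L}{\partial X_{nchw}} = r_{ng}\left[\gamma_c \frac{\partial L}{\partial Y_{nchw}} - \frac{1}{M}\left(D_{ng} + \hat{X}_{nchw}\, r_{ng}\left(E_{ng} - \mu_{ng} D_{ng}\right)\right)\right]
= \gamma_c\, r_{ng} \frac{\partial L}{\partial Y_{nchw}} + A_{ng} X_{nchw} + B_{ng}
$$

$$
A_{ng} = \frac{r_{ng}^3 \left(\mu_{ng} D_{ng} - E_{ng}\right)}{M}, \qquad
B_{ng} = -A_{ng}\,\mu_{ng} - \frac{r_{ng} D_{ng}}{M}
$$

$A_{ng}$ and $B_{ng}$ are per-group scalars. Once they are known, the element-wise stage needs only $X$, $\partial L / \partial Y$, $\gamma_c$ and $r_{ng}$ for each element.
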
## Algorithmic Strategy: Multi-Stage Group-wise Gradient Computation
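
The backward pass requires coordinated computation across groups with multiple reduction operations. Pass 1 and Pass 2 read only `X`, `dL/dY`, `gamma`, and the forward-pass statistics, so neither depends on the other. Pass 3 consumes the output of Pass 2.
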
1. **Pass 1: Compute Gamma and Beta Gradients**
   - **Grid Scheduling**: Parallelize over channels (`C` dimension).
   - **Reduction per Channel**: For each channel `c`, reduce across the `N`, `H`, `W` dimensions:
     - `grad_beta[c] = sum(grad_output[n, c, h, w])` over all `n, h, w`
     - `grad_gamma[c] = sum(grad_output[n, c, h, w] * x_normalized[n, c, h, w])` over all `n, h, w`

2. **Pass 2: Compute Group-wise Intermediate Values**
   - **Grid Scheduling**: Parallelize over `(N, G)` pairs.
   - **Group Reduction**: For each `(n, g)` pair, reduce over the channels of group `g` and all `h, w` of sample `n`:
     - `sum(gamma[c] * grad_output[n, c, h, w])`
     - `sum(gamma[c] * grad_output[n, c, h, w] * x_normalized[n, c, h, w])`
   - The input gradient needs these per-sample sums. The batch-reduced `grad_beta` and `grad_gamma` from Pass 1 cannot be used in their place.

3. **Pass 3: Compute Input Gradients**
   - **Grid Scheduling**: Parallelize over input tensor elements.
   - **Per-Element Computation**: For each `(n, c, h, w)`:
     - Identify the group `g` that channel `c` belongs to
     - Read the two group-wise sums for `(n, g)` produced by Pass 2
     - Apply the input gradient formula $\frac{\partial L}{\partial X_{nchw}}$ from the Gradient Computations section

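
Passes 2 and 3 can be sketched together as a single-threaded reference. It assumes a contiguous NCHW `float` tensor (`HW` = `H * W`), contiguous channel groups, and the per-`(n, g)` `mean` and `rstd` $= 1/\sqrt{\sigma_{ng}^2+\epsilon}$ saved by the forward pass.

`groupnorm_bwd_data` is a hypothetical name, not the CK device operation. For readability it fuses both stages inside one `(n, g)` loop. A GPU implementation would more likely run them as separate reduction and element-wise stages, as described above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference (single-threaded) Pass 2 + Pass 3 for a contiguous NCHW tensor.
// HW = H * W; mean/rstd are per-(n, g) forward statistics (size N * G).
std::vector<float> groupnorm_bwd_data(const std::vector<float>& dy,
                                      const std::vector<float>& x,
                                      const std::vector<float>& gamma,
                                      const std::vector<float>& mean,
                                      const std::vector<float>& rstd,
                                      std::size_t N, std::size_t C,
                                      std::size_t HW, std::size_t G) {
    assert(G > 0 && C % G == 0 && x.size() == N * C * HW && dy.size() == x.size());
    assert(gamma.size() == C && mean.size() == N * G && rstd.size() == N * G);
    const std::size_t cpg = C / G;  // |S_g|
    const double inv_m = 1.0 / static_cast<double>(cpg * HW);  // 1 / M
    std::vector<float> dx(x.size());
    for (std::size_t n = 0; n < N; ++n) {
        for (std::size_t g = 0; g < G; ++g) {
            const std::size_t ng = n * G + g;
            const double mu = mean[ng], rs = rstd[ng];
            // Pass 2: per-(n, g) sums of gamma * dY and gamma * dY * x_hat.
            double sum_t = 0.0, sum_t_xhat = 0.0;
            for (std::size_t c = g * cpg; c < (g + 1) * cpg; ++c)
                for (std::size_t s = 0; s < HW; ++s) {
                    const std::size_t i = (n * C + c) * HW + s;
                    const double t = gamma[c] * dy[i];  // dL/dx_hat
                    sum_t += t;
                    sum_t_xhat += t * (x[i] - mu) * rs;
                }
            // Pass 3: element-wise input gradient using the two group sums.
            for (std::size_t c = g * cpg; c < (g + 1) * cpg; ++c)
                for (std::size_t s = 0; s < HW; ++s) {
                    const std::size_t i = (n * C + c) * HW + s;
                    const double x_hat = (x[i] - mu) * rs;
                    const double t = gamma[c] * dy[i];
                    dx[i] = static_cast<float>(
                        rs * (t - inv_m * (sum_t + x_hat * sum_t_xhat)));
                }
        }
    }
    return dx;
}
```

A useful sanity check on any implementation follows from the invariance of $\hat{X}$ to shifting a group by a constant. With $\epsilon = 0$, $\frac{\partial L}{\partial X}$ summed over each `(n, g)` group should be numerically zero.
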
## Source Code Organization

- [`groupnorm_bwd_xdl.cpp`](./groupnorm_bwd_xdl.cpp): The main example file. It sets up the forward pass results, output gradients, and group configuration, and instantiates the `DeviceGroupnormBwd` operation.
- [`../../include/ck/tensor_operation/gpu/device/device_groupnorm_bwd.hpp`](../../include/ck/tensor_operation/gpu/device/device_groupnorm_bwd.hpp): The high-level device interface for group normalization backward operations.
- The underlying implementation coordinates multiple reduction and computation stages to handle the group-wise structure of the gradients efficiently.

## Build and Run

### Prerequisites

Ensure the Composable Kernel library is built and installed.

```bash
cd /path/to/composable_kernel/build
make -j install
```

### Build the Example

```bash
cd /path/to/composable_kernel/example/54_groupnorm_bwd
mkdir build && cd build

# CK_INSTALL_PATH is the prefix where Composable Kernel was installed
cmake \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
  ..

make -j
```

### Run the Example

```bash
# Run the example with default settings
./groupnorm_bwd_xdl

# Run with verification, data initialization, and timing
./groupnorm_bwd_xdl 1 2 1
```

## Comparison with Other Normalization Backward Passes

| Normalization Type | Reduction Scope | Complexity | Memory Pattern |
|--------------------|-----------------|------------|----------------|
| **BatchNorm** | Over `N, H, W` for each channel | Medium | Channel-wise reductions |
| **LayerNorm** | Over the normalized features of each sample | Medium | Per-sample reductions |
| **GroupNorm** | Over the channels of one group and `H, W`, for each (sample, group) pair | High | Group-wise reductions |
| **InstanceNorm** | Over `H, W` for each (sample, channel) pair | Low | Independent per-(sample, channel) reductions |

## Applications in Computer Vision

Group normalization backward is particularly important for:

- **Small Batch Training**: When batch sizes are too small for batch normalization to estimate reliable statistics
- **Transfer Learning**: Fine-tuning pre-trained models with different batch sizes
- **Object Detection**: Detectors such as Mask R-CNN, where memory limits force small per-GPU batches and batch-size-independent normalization helps
- **Segmentation Networks**: Dense prediction tasks where normalization stability is crucial
- **Style Transfer**: Applications where group-wise feature normalization helps preserve style information

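
Group normalization sits between layer normalization (`G = 1`) and instance normalization (`G = C`). Its statistics never depend on the batch size, and the number of groups controls how many channels share each mean and variance. This makes it a practical alternative to batch normalization in many computer vision models.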