mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-07-02 21:27:45 +00:00
Updates based on PR feedback 4
@@ -1,14 +1,14 @@
# Elementwise Normalization
- This example demonstrates a fused **element-wise operation followed by normalization**. This pattern combines element-wise tensor arithmetic with a normalization operation in a single kernel, which is particularly useful for implementing custom normalization layers or fused activation-normalization blocks.
+ This example demonstrates a fused **elementwise operation followed by normalization**. This pattern combines elementwise tensor arithmetic with a normalization operation in a single kernel, which is particularly useful for implementing custom normalization layers or fused activation-normalization blocks.
## Mathematical Formulation
- The operation performs an element-wise computation followed by a normalization operation.
+ The operation performs an elementwise computation followed by a normalization operation.
- 1. **Elementwise Stage**: An element-wise operation is applied to one or more input tensors.
+ 1. **Elementwise Stage**: An elementwise operation is applied to one or more input tensors.
$C_{temp} = f(A, B, \dots)$
- Where `f` is a user-defined element-wise function that operates on corresponding elements of the input tensors.
+ Where `f` is a user-defined elementwise function that operates on corresponding elements of the input tensors.
2. **Normalization Stage**: The result is then normalized. The normalization can be performed along specified dimensions.
- **Compute Statistics**: For each normalization group, compute the mean and variance.
@@ -19,31 +19,31 @@ The operation performs an element-wise computation followed by a normalization o
- **Scale and Shift**: Apply learnable parameters.
$D = \gamma \cdot \hat{C} + \beta$
- The key optimization is that the intermediate tensor `C_temp` is **never written to global memory**. The element-wise computation feeds directly into the normalization calculation.
+ The key optimization is that the intermediate tensor `C_temp` is **never written to global memory**. The elementwise computation feeds directly into the normalization calculation.
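As a point of reference for the formulation above, the two stages can be written out unfused in plain Python. This is a minimal sketch, not the CK kernel: `f` is taken to be addition, one normalization group is a flat list, and the normalize step (not visible in this diff) is assumed to be the usual `(c - mean) / sqrt(var + eps)` with biased variance. The names `elementwise_normalize` and `eps` are hypothetical.

```python
import math

def elementwise_normalize(a, b, gamma, beta, eps=1e-5, f=lambda x, y: x + y):
    """Unfused reference for one normalization group (flat lists of floats)."""
    # Elementwise stage: C_temp = f(A, B)
    c_temp = [f(x, y) for x, y in zip(a, b)]
    # Compute statistics: mean and biased variance of the group
    n = len(c_temp)
    mean = sum(c_temp) / n
    var = sum((c - mean) ** 2 for c in c_temp) / n
    # Normalize (assumed form), then scale and shift: D = gamma * C_hat + beta
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (c - mean) * inv_std + bt
            for c, g, bt in zip(c_temp, gamma, beta)]

a = [1.0, 2.0, 3.0, 4.0]
b = [0.5, 0.5, 0.5, 0.5]
d = elementwise_normalize(a, b, gamma=[1.0] * 4, beta=[0.0] * 4)
```

Here `c_temp` is materialized as a list purely for readability; that intermediate is exactly what the fused kernel avoids writing to global memory.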
## Algorithmic Strategy: Fused Elementwise with Online Normalization
- The implementation combines element-wise computation with an online normalization algorithm.
+ The implementation combines elementwise computation with an online normalization algorithm.
1. **Grid Scheduling**: The normalization groups are distributed among thread blocks. Each block handles one or more normalization groups.
2. **Fused Two-Pass Algorithm**:
- **Pass 1 - Compute Elementwise and Moments**:
- - Threads cooperatively load input tensors and apply the element-wise function `f`.
- - The element-wise results are kept in registers/shared memory.
- - **Welford's Algorithm**: Threads use Welford's online algorithm to compute the mean and variance of the element-wise results within their normalization group.
+ - Threads cooperatively load input tensors and apply the elementwise function `f`.
+ - The elementwise results are kept in registers/shared memory.
+ - **Welford's Algorithm**: Threads use Welford's online algorithm to compute the mean and variance of the elementwise results within their normalization group.
- **Intra-Block Reduction**: A parallel reduction in shared memory computes the final statistics for the group.
- **Pass 2 - Normalize and Store**:
- - Using the computed statistics, threads apply the normalization formula to their element-wise results.
+ - Using the computed statistics, threads apply the normalization formula to their elementwise results.
- The final normalized result is written to the output tensor `D`.
- This approach ensures that the element-wise computation is performed only once, and the results are immediately consumed by the normalization process without requiring additional memory bandwidth.
+ This approach ensures that the elementwise computation is performed only once, and the results are immediately consumed by the normalization process without requiring additional memory bandwidth.
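The two-pass flow described above can be simulated on the CPU to see how the pieces fit together. The sketch below is an illustration under stated assumptions, not the CK implementation: "threads" are plain loop iterations with a strided access pattern, `f` is addition, the `kept` lists stand in for registers/shared memory, and the scale and shift by `gamma` and `beta` are left out for brevity. All function names are hypothetical.

```python
def welford_update(count, mean, m2, x):
    """Welford's online update: fold one value into (count, mean, M2)."""
    count += 1
    delta = x - mean
    mean += delta / count
    m2 += delta * (x - mean)
    return count, mean, m2

def welford_merge(a, b):
    """Combine two partial (count, mean, M2) states, as a block reduction would."""
    (na, mean_a, m2a), (nb, mean_b, m2b) = a, b
    n = na + nb
    if n == 0:
        return (0, 0.0, 0.0)
    delta = mean_b - mean_a
    mean = mean_a + delta * nb / n
    m2 = m2a + m2b + delta * delta * na * nb / n
    return (n, mean, m2)

def fused_group(a, b, num_threads=4, eps=1e-5):
    """Process one non-empty normalization group with simulated threads."""
    # Pass 1: each simulated thread applies f once and accumulates moments.
    kept = [[] for _ in range(num_threads)]
    partials = []
    for t in range(num_threads):
        state = (0, 0.0, 0.0)
        for i in range(t, len(a), num_threads):
            c = a[i] + b[i]              # elementwise stage, computed only once
            kept[t].append((i, c))       # result stays "on chip"
            state = welford_update(*state, c)
        partials.append(state)
    # Intra-block reduction: merge the per-thread partial statistics.
    total = (0, 0.0, 0.0)
    for p in partials:
        total = welford_merge(total, p)
    n, mean, m2 = total
    var = m2 / n                         # biased variance
    # Pass 2: normalize the kept results and store them.
    inv_std = (var + eps) ** -0.5
    d = [0.0] * len(a)
    for items in kept:
        for i, c in items:
            d[i] = (c - mean) * inv_std
    return d, mean, var

d, mean, var = fused_group([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], [0.0] * 8)
```

The merge step is what lets each thread keep its own running statistics and still arrive at the exact group mean and variance, so the inputs are read and `f` is evaluated only once.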
## Source Code Organization
- - [`elementwise_normalization_xdl.cpp`](./elementwise_normalization_xdl.cpp): The main example file. It sets up the input tensors, defines the element-wise operation and normalization parameters, and instantiates the `DeviceElementwiseNormalization` operation.
- - [`../../include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp`](../../include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp): The high-level device interface for the fused element-wise normalization operation.
- - The underlying grid-wise kernel implements the complex fusion of element-wise operations with the two-pass normalization algorithm.
+ - [`elementwise_normalization_xdl.cpp`](./elementwise_normalization_xdl.cpp): The main example file. It sets up the input tensors, defines the elementwise operation and normalization parameters, and instantiates the `DeviceElementwiseNormalization` operation.
+ - [`../../include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp`](../../include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp): The high-level device interface for the fused elementwise normalization operation.
+ - The underlying grid-wise kernel implements the complex fusion of elementwise operations with the two-pass normalization algorithm.
## Build and Run
@@ -81,6 +81,6 @@ make -j
This fused operation is valuable for implementing custom normalization layers and optimizing activation-normalization sequences.
- **Custom Activation-Normalization Blocks**: Some architectures use non-standard activation functions followed by normalization. For example, a Swish activation followed by layer normalization can be fused into a single kernel using this pattern.
- - **Residual Connection with Normalization**: In some variants of residual networks, the residual addition is immediately followed by normalization. This can be expressed as an element-wise addition (residual) followed by normalization.
- - **Preprocessing Pipelines**: In data preprocessing, tensors might need element-wise transformations (e.g., color space conversion) followed by normalization (e.g., standardization). This kernel can fuse these operations.
- - **Research Architectures**: Novel normalization techniques often involve custom element-wise operations before the normalization step. This kernel provides a flexible foundation for implementing such research ideas efficiently.
+ - **Residual Connection with Normalization**: In some variants of residual networks, the residual addition is immediately followed by normalization. This can be expressed as an elementwise addition (residual) followed by normalization.
+ - **Preprocessing Pipelines**: In data preprocessing, tensors might need elementwise transformations (e.g., color space conversion) followed by normalization (e.g., standardization). This kernel can fuse these operations.
+ - **Research Architectures**: Novel normalization techniques often involve custom elementwise operations before the normalization step. This kernel provides a flexible foundation for implementing such research ideas efficiently.
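To make the first use case concrete, a Swish activation followed by layer normalization over a single row might look like the following in plain Python. This is a hedged reference for what a fused kernel would compute, not code from the example: `swish`, `swish_layernorm`, and `eps` are hypothetical names, and the normalization is assumed to use biased variance.

```python
import math

def swish(x):
    """Swish (SiLU) activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def swish_layernorm(row, gamma, beta, eps=1e-5):
    """Apply Swish to each element, then layer-normalize the row."""
    acts = [swish(x) for x in row]                  # elementwise stage
    n = len(acts)
    mean = sum(acts) / n
    var = sum((v - mean) ** 2 for v in acts) / n    # biased variance
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (v - mean) * inv_std + b
            for v, g, b in zip(acts, gamma, beta)]

row = [-2.0, -1.0, 0.0, 1.0, 2.0]
out = swish_layernorm(row, gamma=[1.0] * 5, beta=[0.0] * 5)
```

Swapping `swish` for a two-input addition gives the residual-connection case from the list above; only the elementwise stage changes.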