mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Updates based on PR feedback 3
@@ -1,4 +1,4 @@
-# Element-wise Normalization
+# Elementwise Normalization
This example demonstrates a fused **element-wise operation followed by normalization**. This pattern combines element-wise tensor arithmetic with a normalization operation in a single kernel, which is particularly useful for implementing custom normalization layers or fused activation-normalization blocks.
@@ -6,7 +6,7 @@ This example demonstrates a fused **element-wise operation followed by normaliza
The operation performs an element-wise computation followed by a normalization operation.
-1. **Element-wise Stage**: An element-wise operation is applied to one or more input tensors.
+1. **Elementwise Stage**: An element-wise operation is applied to one or more input tensors.
$C_{temp} = f(A, B, \dots)$
Where `f` is a user-defined element-wise function that operates on corresponding elements of the input tensors.
@@ -21,14 +21,14 @@ The operation performs an element-wise computation followed by a normalization o
The key optimization is that the intermediate tensor `C_temp` is **never written to global memory**. The element-wise computation feeds directly into the normalization calculation.
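
As a rough CPU-side reference for the fusion described above, the sketch below applies `f` and normalizes one group without ever returning the intermediate. It is not the library's kernel: the function name `fused_elementwise_normalize`, the default `f` (addition), the `eps` parameter, and the layer-norm-style formula without scale or shift are all assumptions made for illustration.

```python
import math


def fused_elementwise_normalize(a, b, f=lambda x, y: x + y, eps=1e-5):
    """Normalize f(a, b) over a single normalization group.

    Hypothetical reference helper; names and defaults are illustrative only.
    """
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    # Element-wise stage: c_temp lives only inside this function, standing in
    # for the registers/shared memory that hold it on the GPU.
    c_temp = [f(x, y) for x, y in zip(a, b)]
    # Normalization stage (assumed form): zero mean, unit variance per group.
    mean = sum(c_temp) / len(c_temp)
    var = sum((c - mean) ** 2 for c in c_temp) / len(c_temp)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(c - mean) * inv_std for c in c_temp]


out = fused_elementwise_normalize([1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5])
```

The caller only ever sees the normalized output, which mirrors the point that `C_temp` has no global-memory footprint.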
-## Algorithmic Strategy: Fused Element-wise with Online Normalization
+## Algorithmic Strategy: Fused Elementwise with Online Normalization
The implementation combines element-wise computation with an online normalization algorithm.
1. **Grid Scheduling**: The normalization groups are distributed among thread blocks. Each block handles one or more normalization groups.
2. **Fused Two-Pass Algorithm**:
-   - **Pass 1 - Compute Element-wise and Moments**:
+   - **Pass 1 - Compute Elementwise and Moments**:
- Threads cooperatively load input tensors and apply the element-wise function `f`.
- The element-wise results are kept in registers/shared memory.
- **Welford's Algorithm**: Threads use Welford's online algorithm to compute the mean and variance of the element-wise results within their normalization group.
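
The moment computation in Pass 1 can be sketched on the CPU as follows. This is a minimal illustration, not the kernel's code: each simulated "thread" runs Welford's update over a strided slice of the group, and the partial results are then combined with the standard pairwise merge for running moments. The helper names (`welford_update`, `welford_merge`, `group_moments`), the strided work split, and the `num_threads` parameter are assumptions.

```python
def welford_update(count, mean, m2, x):
    """Fold one value x into a running (count, mean, M2) state."""
    count += 1
    delta = x - mean
    mean += delta / count
    m2 += delta * (x - mean)  # uses the already-updated mean
    return count, mean, m2


def welford_merge(a, b):
    """Combine two partial (count, mean, M2) states into one."""
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    if n == 0:
        return (0, 0.0, 0.0)
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return (n, mean, m2)


def group_moments(values, num_threads=4):
    """Mean and population variance of one non-empty normalization group."""
    partials = []
    for t in range(num_threads):
        state = (0, 0.0, 0.0)
        # Each simulated thread visits every num_threads-th element.
        for x in values[t::num_threads]:
            state = welford_update(*state, x)
        partials.append(state)
    total = (0, 0.0, 0.0)
    for p in partials:
        total = welford_merge(total, p)
    n, mean, m2 = total
    return mean, m2 / n
```

Keeping a `(count, mean, M2)` triple per thread means each value is visited once and no sum of squares is accumulated, which is the usual motivation for Welford's method over the naive two-sum formula.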