Updates based on PR feedback 3

Vidyasagar
2025-10-02 11:28:02 -07:00
parent 79d37b4d0b
commit 384dddddfe
11 changed files with 31 additions and 34 deletions

@@ -1,4 +1,4 @@
-# Element-wise Normalization
+# Elementwise Normalization
This example demonstrates a fused **element-wise operation followed by normalization**. This pattern combines element-wise tensor arithmetic with a normalization operation in a single kernel, which is particularly useful for implementing custom normalization layers or fused activation-normalization blocks.
@@ -6,7 +6,7 @@ This example demonstrates a fused **element-wise operation followed by normaliza
The operation performs an element-wise computation followed by a normalization operation.
-1. **Element-wise Stage**: An element-wise operation is applied to one or more input tensors.
+1. **Elementwise Stage**: An element-wise operation is applied to one or more input tensors.
$C_{temp} = f(A, B, \dots)$
Where `f` is a user-defined element-wise function that operates on corresponding elements of the input tensors.
@@ -21,14 +21,14 @@ The operation performs an element-wise computation followed by a normalization o
The key optimization is that the intermediate tensor `C_temp` is **never written to global memory**. The element-wise computation feeds directly into the normalization calculation.
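For comparison, here is a minimal unfused reference in Python that does materialize `C_temp`. A fused kernel would be expected to reproduce its output. This is a sketch under assumptions, not the example's actual code: the function name `reference_elementwise_norm`, the `eps` parameter, and the normalization formula (zero mean and unit variance over a single group, with no scale or shift) are all hypothetical, since this excerpt does not show the normalization stage.

```python
import math

def reference_elementwise_norm(a, b, f, eps=1e-5):
    """Unfused reference for ONE normalization group (hypothetical helper).

    Assumption: normalization is (c - mean) / sqrt(var + eps) with the
    biased (population) variance; the README's exact formula may differ.
    """
    # Stage 1: element-wise function; C_temp is materialized here,
    # which is exactly what the fused kernel avoids.
    c_temp = [f(x, y) for x, y in zip(a, b)]

    # Stage 2: moments over the group, then normalize.
    mean = sum(c_temp) / len(c_temp)
    var = sum((c - mean) ** 2 for c in c_temp) / len(c_temp)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(c - mean) * inv_std for c in c_temp]

# Example: f(A, B) = A + B over a single group of four elements.
out = reference_elementwise_norm(
    [1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 1.0, 1.0], lambda x, y: x + y
)
```

Under these assumptions the normalized group has mean close to 0 and variance close to 1, which gives a simple check for any fused implementation.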
-## Algorithmic Strategy: Fused Element-wise with Online Normalization
+## Algorithmic Strategy: Fused Elementwise with Online Normalization
The implementation combines element-wise computation with an online normalization algorithm.
1. **Grid Scheduling**: The normalization groups are distributed among thread blocks. Each block handles one or more normalization groups.
2. **Fused Two-Pass Algorithm**:
-- **Pass 1 - Compute Element-wise and Moments**:
+- **Pass 1 - Compute Elementwise and Moments**:
- Threads cooperatively load input tensors and apply the element-wise function `f`.
- The element-wise results are kept in registers/shared memory.
- **Welford's Algorithm**: Threads use Welford's online algorithm to compute the mean and variance of the element-wise results within their normalization group.
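
Pass 1 can be sketched on the CPU in Python as follows. This is an illustration under assumptions rather than the kernel's actual code: the names `welford_update`, `welford_merge`, and `group_moments` are hypothetical, the strided loop only simulates how threads might split a group, and the biased (population) variance is assumed. Each simulated thread applies `f` and folds the result straight into its running moments, so no intermediate tensor is stored. The partial states are then combined with the standard pairwise merge formula for Welford-style statistics.

```python
def welford_update(state, x):
    """Fold one value into a running (count, mean, M2) state."""
    count, mean, m2 = state
    count += 1
    delta = x - mean
    mean += delta / count
    m2 += delta * (x - mean)  # uses the already-updated mean
    return count, mean, m2

def welford_merge(s1, s2):
    """Combine two partial (count, mean, M2) states."""
    n1, mean1, m2_1 = s1
    n2, mean2, m2_2 = s2
    n = n1 + n2
    if n == 0:
        return (0, 0.0, 0.0)
    delta = mean2 - mean1
    mean = mean1 + delta * n2 / n
    m2 = m2_1 + m2_2 + delta * delta * n1 * n2 / n
    return n, mean, m2

def group_moments(a, b, f, num_threads=4):
    """Mean and biased variance of f(a, b) over one non-empty group."""
    partials = []
    for t in range(num_threads):
        state = (0, 0.0, 0.0)
        # Simulated thread t handles elements t, t + num_threads, ...
        for i in range(t, len(a), num_threads):
            state = welford_update(state, f(a[i], b[i]))
        partials.append(state)

    # Stand-in for the block-level reduction across threads.
    total = (0, 0.0, 0.0)
    for p in partials:
        total = welford_merge(total, p)
    count, mean, m2 = total
    return mean, m2 / count

mean, var = group_moments(
    [float(i) for i in range(1, 9)], [1.0] * 8, lambda x, y: x + y
)
```

The merge step is what makes the computation parallel-friendly: each thread needs only three scalars of state, and partial results can be combined in any order.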