Updates based on PR feedback 3

This commit is contained in:
Vidyasagar
2025-10-02 11:28:02 -07:00
parent 79d37b4d0b
commit 384dddddfe
11 changed files with 31 additions and 34 deletions

View File

@@ -2,7 +2,7 @@
## Theory
This client example demonstrates **softmax computation over 4D tensors**. Softmax is a key operation in deep learning, especially in attention mechanisms and classification, converting logits into normalized probabilities.
This client example demonstrates **Softmax computation over 4D tensors**. Softmax is a key operation in deep learning, especially in attention mechanisms and classification, converting logits into normalized probabilities.
**Mathematical Formulation:**
Given input $X$ and axis $a$:
@@ -15,8 +15,8 @@ $$
1. Subtract the maximum value for numerical stability.
2. Exponentiate and sum.
3. Normalize by the sum.
- Efficient parallel softmax requires careful reduction and memory access patterns.
- This example demonstrates softmax over a 4D tensor, as used in attention and vision models.
- Efficient parallel Softmax requires careful reduction and memory access patterns.
- This example demonstrates Softmax over a 4D tensor, as used in attention and vision models.
## How to Run
@@ -47,8 +47,8 @@ client_example/06_softmax/
### Key Functions
- **main()** (in `softmax4d.cpp`):
Sets up input tensors, configures softmax parameters, launches the softmax kernel, and verifies the result.
Sets up input tensors, configures Softmax parameters, launches the Softmax kernel, and verifies the result.
- **Softmax kernel invocation**:
Uses the Composable Kernel device API to launch the softmax operation.
Uses the Composable Kernel device API to launch the Softmax operation.
This client example provides a demonstration of efficient, numerically stable softmax for 4D tensors in deep learning models.
This client example provides a demonstration of efficient, numerically stable Softmax for 4D tensors in deep learning models.

View File

@@ -1,19 +1,19 @@
# Binary Element-wise Operations with Broadcasting
# Binary Elementwise Operations with Broadcasting
This example demonstrates a generic binary element-wise operation, a fundamental building block in numerical computing. It covers two important cases:
1. **Simple Element-wise**: Applying a binary function to two input tensors of the *same* shape.
2. **Element-wise with Broadcasting**: Applying a binary function to two input tensors of *different but compatible* shapes.
1. **Simple Elementwise**: Applying a binary function to two input tensors of the *same* shape.
2. **Elementwise with Broadcasting**: Applying a binary function to two input tensors of *different but compatible* shapes.
Broadcasting defines a set of rules for applying element-wise operations on tensors of different sizes, and it is a cornerstone of libraries like NumPy and TensorFlow.
## Mathematical Formulation
### Simple Element-wise
### Simple Elementwise
Given two input tensors, A and B, of the same rank and dimensions, and a binary operator $\odot$, the operation computes an output tensor C where each element is:
$C_{i,j,k,\dots} = A_{i,j,k,\dots} \odot B_{i,j,k,\dots}$
### Element-wise with Broadcasting
### Elementwise with Broadcasting
Broadcasting allows element-wise operations on tensors with different shapes, provided they are compatible. Two dimensions are compatible if they are equal, or if one of them is 1. The operation implicitly "stretches" or "duplicates" the tensor with the dimension of size 1 to match the other tensor's shape.
For example, adding a bias vector `B` of shape `(1, N)` to a matrix `A` of shape `(M, N)`:
@@ -65,11 +65,8 @@ This example contains multiple files to demonstrate different scenarios:
## Build and Run
### Prerequisites
Ensure the Composable Kernel library is built and installed.
```bash
cd /path/to/composable_kernel/build
make -j install
```
Please follow the instructions in the main [Build Guide](../../README.md#building-ck) section as a prerequisite to building and running this example.
### Build the Example
```bash

View File

@@ -1,4 +1,4 @@
# Grouped GEMM with Bias, Element-wise Operation, and Permutation
# Grouped GEMM with Bias, Elementwise Operation, and Permutation
This example demonstrates a highly complex and specialized fusion: a **Grouped GEMM** where each individual GEMM operation is fused with a bias addition, a second element-wise operation, and a final permutation of the output. This kernel is designed to accelerate layers that have a group-parallel structure, such as depthwise separable convolutions or multi-head attention, when they are part of a larger fused computational graph.
@@ -12,7 +12,7 @@ This operation performs `G` independent fused GEMM operations in parallel, where
2. **Bias Addition Stage**: A bias vector `D_[g]` is broadcast and added.
$C_{temp2[g]} = C_{temp1[g]} + D_{[g]}$
3. **Element-wise Stage**: A second element-wise operation is performed with tensor `E_[g]`.
3. **Elementwise Stage**: A second element-wise operation is performed with tensor `E_[g]`.
$C_{temp3[g]} = C_{temp2[g]} \odot E_{[g]}$
4. **Permutation Stage**: The final result for the group is permuted.

View File

@@ -1,4 +1,4 @@
# Batched GEMM with Bias, Element-wise Operation, and Permutation
# Batched GEMM with Bias, Elementwise Operation, and Permutation
This example demonstrates a **Batched GEMM** where each individual GEMM operation is fused with a bias addition, a second element-wise operation, and a final permutation of the output. This kernel is designed to accelerate layers that have a batch-parallel structure, such as the dense layers in a Transformer's feed-forward network, when they are part of a larger fused computational graph.
@@ -12,7 +12,7 @@ This operation performs `B` independent fused GEMM operations in parallel, where
2. **Bias Addition Stage**: A bias vector `D_[b]` is broadcast and added.
$C_{temp2[b]} = C_{temp1[b]} + D_{[b]}$
3. **Element-wise Stage**: A second element-wise operation is performed with tensor `E_[b]`.
3. **Elementwise Stage**: A second element-wise operation is performed with tensor `E_[b]`.
$C_{temp3[b]} = C_{temp2[b]} \odot E_{[b]}$
4. **Permutation Stage**: The final result for the batch item is permuted.

View File

@@ -1,4 +1,4 @@
# Grouped Convolution Forward with Multiple Element-wise Inputs
# Grouped Convolution Forward with Multiple Elementwise Inputs
This example demonstrates a **Grouped Convolution Forward Pass** fused with an element-wise operation that takes multiple auxiliary input tensors (`D` tensors). This is a powerful fusion that combines the parallel structure of grouped convolutions with the ability to merge subsequent element-wise layers, such as custom activations or residual connections, into a single kernel.
@@ -9,7 +9,7 @@ This operation performs `G` independent fused convolution operations in parallel
1. **Convolution Stage**: A standard N-dimensional forward convolution is performed for the group.
$C_{out[g]} = \text{Conv}(\text{In}_{[g]}, \text{W}_{[g]})$
2. **Element-wise Stage**: The result of the convolution is combined with one or more auxiliary tensors ($D_{0[g]}, D_{1[g]}, \dots$) using a user-defined element-wise function `f`.
2. **Elementwise Stage**: The result of the convolution is combined with one or more auxiliary tensors ($D_{0[g]}, D_{1[g]}, \dots$) using a user-defined element-wise function `f`.
$E_{[g]} = f(C_{out[g]}, D_{0[g]}, D_{1[g]}, \dots)$
The key optimization is that the intermediate convolution result, $C_{out[g]}$, is never written to global memory. It is computed and held in registers, then immediately consumed by the element-wise part of the kernel's epilogue before the final result `E` is stored.

View File

@@ -1,4 +1,4 @@
# Grouped Convolution Backward Data with Multiple Element-wise Inputs
# Grouped Convolution Backward Data with Multiple Elementwise Inputs
This example demonstrates a **Grouped Convolution Backward Data Pass** fused with an element-wise operation that takes multiple auxiliary input tensors (`D` tensors). The backward data pass (also known as a transposed convolution or deconvolution) computes the gradient of the loss with respect to the convolution's *input* tensor. Fusing it with other operations is a powerful way to optimize the backward pass of a neural network.
@@ -10,7 +10,7 @@ The operation computes the gradient with respect to the input (`GradIn`) of a gr
$GradIn_{temp[g]} = \text{ConvBwdData}(\text{GradOut}_{[g]}, \text{W}_{[g]})$
Where `GradOut` is the gradient from the subsequent layer and `W` is the weight tensor from the forward pass.
2. **Element-wise Stage**: The result of the backward convolution is combined with one or more auxiliary tensors ($D_{0[g]}, D_{1[g]}, \dots$) using a user-defined element-wise function `f`.
2. **Elementwise Stage**: The result of the backward convolution is combined with one or more auxiliary tensors ($D_{0[g]}, D_{1[g]}, \dots$) using a user-defined element-wise function `f`.
$GradIn_{[g]} = f(GradIn_{temp[g]}, D_{0[g]}, D_{1[g]}, \dots)$
This fusion is particularly useful for operations like adding the gradient from a residual "skip" connection, which is a common pattern in modern network architectures. By fusing the addition, we avoid a separate kernel launch and a full read/write pass of the `GradIn` tensor.

View File

@@ -1,4 +1,4 @@
# Split-K GEMM with Bias, Element-wise Operation, and Permutation
# Split-K GEMM with Bias, Elementwise Operation, and Permutation
This example demonstrates a highly complex fusion: a **Split-K GEMM** where the final result is fused with a bias addition, a second element-wise operation, and a final permutation. This kernel combines the parallelism-enhancing Split-K strategy with a multi-stage epilogue, making it suitable for accelerating very large or "skinny" GEMMs that are part of a more complex computational graph.
@@ -12,7 +12,7 @@ The operation first computes a GEMM using the Split-K algorithm and then applies
2. **Bias Addition Stage**: A bias vector `D` is broadcast and added.
$C_{temp2} = C_{temp1} + D$
3. **Element-wise Stage**: A second element-wise operation is performed with tensor `E`.
3. **Elementwise Stage**: A second element-wise operation is performed with tensor `E`.
$C_{temp3} = C_{temp2} \odot E$
4. **Permutation Stage**: The final result is permuted.

View File

@@ -1,4 +1,4 @@
# Element-wise Normalization
# Elementwise Normalization
This example demonstrates a fused **element-wise operation followed by normalization**. This pattern combines element-wise tensor arithmetic with a normalization operation in a single kernel, which is particularly useful for implementing custom normalization layers or fused activation-normalization blocks.
@@ -6,7 +6,7 @@ This example demonstrates a fused **element-wise operation followed by normaliza
The operation performs an element-wise computation followed by a normalization operation.
1. **Element-wise Stage**: An element-wise operation is applied to one or more input tensors.
1. **Elementwise Stage**: An element-wise operation is applied to one or more input tensors.
$C_{temp} = f(A, B, \dots)$
Where `f` is a user-defined element-wise function that operates on corresponding elements of the input tensors.
@@ -21,14 +21,14 @@ The operation performs an element-wise computation followed by a normalization o
The key optimization is that the intermediate tensor `C_temp` is **never written to global memory**. The element-wise computation feeds directly into the normalization calculation.
## Algorithmic Strategy: Fused Element-wise with Online Normalization
## Algorithmic Strategy: Fused Elementwise with Online Normalization
The implementation combines element-wise computation with an online normalization algorithm.
1. **Grid Scheduling**: The normalization groups are distributed among thread blocks. Each block handles one or more normalization groups.
2. **Fused Two-Pass Algorithm**:
- **Pass 1 - Compute Element-wise and Moments**:
- **Pass 1 - Compute Elementwise and Moments**:
- Threads cooperatively load input tensors and apply the element-wise function `f`.
- The element-wise results are kept in registers/shared memory.
- **Welford's Algorithm**: Threads use Welford's online algorithm to compute the mean and variance of the element-wise results within their normalization group.

View File

@@ -31,7 +31,7 @@ The implementation uses the implicit GEMM algorithm for convolution with the act
- **Output Accumulation**: Results are accumulated in registers as standard GEMM tiles.
2. **Fused Activation Epilogue**: Before storing results to global memory:
- **Element-wise Activation**: Apply the activation function to each element in the accumulated tile.
- **Elementwise Activation**: Apply the activation function to each element in the accumulated tile.
- **Vectorized Operations**: Use vectorized instructions where possible for activation computation.
- **Store Activated Result**: Write the final activated output directly to global memory.

View File

@@ -37,7 +37,7 @@ The implementation treats this as a parallel reduction problem with spatial aggr
- **Intra-Block Reduction**: Threads perform parallel reduction using shared memory to compute the final statistics for each batch item.
3. **Normalization and Scale/Shift**:
- **Element-wise Processing**: Each thread processes one or more elements of the batch item.
- **Elementwise Processing**: Each thread processes one or more elements of the batch item.
- **Apply Normalization**: Use the computed mean and variance to normalize each element.
- **Apply Scale/Shift**: Apply the appropriate `gamma` and `beta` values based on the parameterization choice.
- **Store Result**: Write the final normalized result to the output tensor.

View File

@@ -9,10 +9,10 @@ The operation performs a matrix multiplication followed by two sequential elemen
1. **GEMM Stage**: A standard matrix multiplication.
$C_{temp1} = A \times B$
2. **First Multiplication**: Element-wise multiplication with tensor `D`.
2. **First Multiplication**: Elementwise multiplication with tensor `D`.
$C_{temp2} = C_{temp1} \odot D$
3. **Second Multiplication**: Element-wise multiplication with tensor `E`.
3. **Second Multiplication**: Elementwise multiplication with tensor `E`.
$F = C_{temp2} \odot E$
The key optimization is that the intermediate tensors `C_temp1` and `C_temp2` are **never written to global memory**. All operations are fused into the GEMM's epilogue, operating on data held in registers.