Extend Grouped GEMM with MultiD (Single & Double Shared Memory) feature to use persistent kernel option (#2933)

* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature

* refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel

* tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments

* fix: segfault fix by passing correct parameters for d tensors

* style: clang format

* WIP: host code for grouped_gemm_multi_d persistent kernel compiles but segfaults

* feat(grouped_gemm_multi_d): add functionality to run persistant kernel

* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature

* refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel

* tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments

* fix: segfault fix by passing correct parameters for d tensors

* style: clang format

* fix: incorrect validation method and Dtensor layout in test suite

* docs: improved README text based on review comments

* fix: parameterize NumDTensor in GroupedGemmHostArgs and remove lint
This commit is contained in:
Aviral Goel
2025-09-29 18:03:56 -04:00
committed by GitHub
parent 243118c275
commit bebf0e9d15
5 changed files with 163 additions and 19 deletions

View File

@@ -10,16 +10,15 @@ The grouped GEMM examples include two advanced optimization features:
Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches.
- **Implementation**: Available in `grouped_gemm_preshuffle.cpp`
- **Configuration**: Uses `GemmConfigPreshuffleDecode` template configuration
- **Configuration**: Uses `GemmConfigPreshuffleDecode` and `GemmConfigPreshufflePrefill` template configuration
- **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts
- **Benefits**: Improved memory efficiency and reduced data movement
#### Persistence Mode
Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy.
- **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm`
- **Usage**: `invoke_gemm<ALayout, BLayout, CLayout, true>` enables persistence
- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes
#### Multi-D Operations
Multi-D operations extend the standard GEMM operation by supporting additional element-wise operations on the result tensor. This feature is particularly useful for workloads that require post-processing of the GEMM output.
@@ -31,7 +30,8 @@ Multi-D operations extend the standard GEMM operation by supporting additional e
- **Benefits**: Enables complex operations like scaling, activation functions, or other element-wise transformations in a single kernel call
- **Build Target**: `make tile_example_grouped_gemm_multi_d -j`
Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.
Multi-D operations supports both persistence and non-persistence modes.
Weight preshuffle supports only on non-persistence mode.
## Build
```
@@ -48,7 +48,7 @@ make tile_example_grouped_gemm_multi_d -j
# The quant grouped gemm fp8 example
make tile_example_quant_grouped_gemm -j
```
This will result in an executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`.
Each example will result in an corresponding executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`.
## example