mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
Extend Grouped GEMM with MultiD (Single & Double Shared Memory) feature to use persistent kernel option (#2933)
* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel * tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments * fix: segfault fix by passing correct parameters for d tensors * style: clang format * WIP: host code for grouped_gemm_multi_d persistent kernel compiles but segfaults * feat(grouped_gemm_multi_d): add functionality to run persistant kernel * feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel * tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments * fix: segfault fix by passing correct parameters for d tensors * style: clang format * fix: incorrect validation method and Dtensor layout in test suite * docs: improved README text based on review comments * fix: parameterize NumDTensor in GroupedGemmHostArgs and remove lint
This commit is contained in:
@@ -10,16 +10,15 @@ The grouped GEMM examples include two advanced optimization features:
|
||||
Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches.
|
||||
|
||||
- **Implementation**: Available in `grouped_gemm_preshuffle.cpp`
|
||||
- **Configuration**: Uses `GemmConfigPreshuffleDecode` template configuration
|
||||
- **Configuration**: Uses `GemmConfigPreshuffleDecode` and `GemmConfigPreshufflePrefill` template configuration
|
||||
- **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts
|
||||
- **Benefits**: Improved memory efficiency and reduced data movement
|
||||
|
||||
|
||||
#### Persistence Mode
|
||||
Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy.
|
||||
|
||||
- **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm`
|
||||
- **Usage**: `invoke_gemm<ALayout, BLayout, CLayout, true>` enables persistence
|
||||
- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes
|
||||
|
||||
#### Multi-D Operations
|
||||
Multi-D operations extend the standard GEMM operation by supporting additional element-wise operations on the result tensor. This feature is particularly useful for workloads that require post-processing of the GEMM output.
|
||||
@@ -31,7 +30,8 @@ Multi-D operations extend the standard GEMM operation by supporting additional e
|
||||
- **Benefits**: Enables complex operations like scaling, activation functions, or other element-wise transformations in a single kernel call
|
||||
- **Build Target**: `make tile_example_grouped_gemm_multi_d -j`
|
||||
|
||||
Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.
|
||||
Multi-D operations supports both persistence and non-persistence modes.
|
||||
Weight preshuffle supports only on non-persistence mode.
|
||||
|
||||
## Build
|
||||
```
|
||||
@@ -48,7 +48,7 @@ make tile_example_grouped_gemm_multi_d -j
|
||||
# The quant grouped gemm fp8 example
|
||||
make tile_example_quant_grouped_gemm -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`.
|
||||
Each example will result in an corresponding executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`.
|
||||
|
||||
|
||||
## example
|
||||
|
||||
Reference in New Issue
Block a user