feat(grouped_gemm): add preshuffle v2 support to grouped gemm example (#2721)

* docs(README): update readme with new build instructions

* feat(grouped_gemm): add back support for the non-persistent kernel

* refactor(grouped_gemm): simplify tensor creation

* refactor(grouped_gemm): Persistence is now a GemmConfig value for easier management

* chore(grouped_gemm): add print statements to ease debugging

* WIP(grouped_gemm): add grouped_gemm_preshuffle example and update CMake configuration

* fix(tile_gemm_traits): change default value of Preshuffle_ from 0 to false for clarity

* WIP(grouped_gemm): add dummy variables to compile the preshuffle pipelines

* chore(grouped_gemm): add print statements and variables to debug numerical error with preshuffle

* style: clang format work so far

* BUG!(grouped_gemm_kernel.hpp): identified a potential cause of numerical errors in the preshuffle pipeline

* fix(grouped_gemm_kernel): add function in the kernel code to dynamically calculate tail_number resolving numerical errors

* refactor(gemm_preshuffle): make preshuffle pipeline v2 compatible with operator () calls from grouped gemm

* chore(grouped_gemm): add/remove debug comments and debug print statements

* feat(grouped_gemm): integrate preshuffle pipeline v2 into grouped gemm for all supported shapes

* chore(gemm_profile): add new argument combinations

* fix: branch cleanup, formatting, refactoring

* fix: branch cleanup, formatting, refactoring

* chore(changelog): update changelog to reflect new feature

* address review comments & nit
This commit is contained in:
Aviral Goel
2025-09-07 17:18:35 -04:00
committed by GitHub
parent 5224d2ead3
commit e279e9420e
13 changed files with 808 additions and 178 deletions


@@ -8,11 +8,11 @@ The `Grouped GEMM` operators are versions of GEMM that run multiple GEMM operati
Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function.
### Key Arguments
The example takes several arguments including `group_count`, `repeat`, and `warmup`:
- `group_count`: the number of GEMM operations in the group
- `repeat`: the number of times to repeat the kernel for benchmarking
- `warmup`: the number of iterations before the actual kernel run time measure
```cpp
// Example
// ...
```
@@ -133,6 +133,28 @@ float invoke_gemm(int n_warmup,
```cpp
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(args));
```
### Advanced Features: Preshuffle and Persistence
The grouped GEMM examples include two advanced optimization features:
#### Weight Preshuffle
Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches.
- **Implementation**: Available in `grouped_gemm_preshuffle.cpp`
- **Configuration**: Uses `GemmConfigPreshuffleDecode` template configuration
- **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts
- **Benefits**: Improved memory efficiency and reduced data movement
#### Persistence Mode
Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy.
- **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm`
- **Usage**: `invoke_gemm<ALayout, BLayout, CLayout, true>` enables persistence
- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes
Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.
Finally, the arguments are passed to `group_gemm` and the kernel is launched.
```cpp
// API
// ...
```
@@ -151,26 +173,42 @@ mkdir build && cd build
```
../script/cmake-ck-dev.sh ../ <arch>
# Build the basic grouped GEMM example
make tile_example_grouped_gemm -j
# Build the preshuffle example
make tile_example_grouped_gemm_preshuffle -j
```
This will result in the executables `build/bin/tile_example_grouped_gemm` and `build/bin/tile_example_grouped_gemm_preshuffle`
## Example
```
args:
-Ms M dimensions - (Default: empty).
-Ns N dimensions - (Default: empty).
-Ks K dimensions - (Default: empty).
-stride_As Tensor A strides - (Default: empty).
-stride_Bs Tensor B strides - (Default: empty).
-stride_Cs Tensor C strides - (Default: empty).
-a_layout A tensor data layout - (Default: Row).
-b_layout B tensor data layout - (Default: Col).
-c_layout C tensor data layout - (Default: Row).
-prec data type. fp16/fp8 - (Default: fp16).
-validate 0. No validation, 1. Validation on CPU. (Default: 1).
-warmup Number of warmup iterations before benchmarking the kernel. (Default: 10).
-repeat Number of iterations to benchmark the kernel. (Default: 100).
-group_count Group count. (Default: 16).
-kbatch kbatch for SplitK (Default: 1).
-json 0: No Json, 1: Dump Results in Json format (Default: 0).
-jsonfile json file name to dump results (Default: grouped_gemm.json).
```
If any of `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs`, or `stride_Cs` are missing or their sizes
don't match `group_count`, the example generates defaults per group index `i` (0-based):
```text
M[i] = 256 + 256 * i
N[i] = 256 + 512 * i
K[i] = 512 + 384 * i
stride_A[i] = K[i]
stride_B[i] = K[i]
stride_C[i] = N[i]
```