mirror of https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00

feat(grouped_gemm): add preshuffle v2 support to grouped gemm example (#2721)

* docs(README): update readme with new build instructions
* feat(grouped_gemm): add support back for the non-persistent kernel
* refactor(grouped_gemm): simplify tensor creation
* refactor(grouped_gemm): Persistence is now a GemmConfig value for easier management
* chore(grouped_gemm): add print statements to ease debugging
* WIP(grouped_gemm): add grouped_gemm_preshuffle example and update CMake configuration
* fix(tile_gemm_traits): change default value of Preshuffle_ from 0 to false for clarity
* WIP(grouped_gemm): add dummy variables to compile the preshuffle pipelines
* chore(grouped_gemm): add print statements and variables to debug numerical errors with preshuffle
* style: clang-format the work so far
* BUG!(grouped_gemm_kernel.hpp): identified a potential source of the numerical errors in the preshuffle pipeline
* fix(grouped_gemm_kernel): add a function in the kernel code to dynamically calculate tail_number, resolving the numerical errors
* refactor(gemm_preshuffle): make preshuffle pipeline v2 compatible with operator() calls from grouped gemm
* chore(grouped_gemm): add/remove debug comments and debug print statements
* feat(grouped_gemm): integrate preshuffle pipeline v2 into grouped gemm for all supported shapes
* chore(gemm_profile): add new argument combinations
* fix: branch cleanup, formatting, refactoring
* chore(changelog): update changelog to reflect the new feature
* address review comments & nits
@@ -8,11 +8,11 @@ The `Grouped GEMM` operators are versions of GEMM that run multiple GEMM operati
Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function.
### Key Arguments

The example takes several arguments, including `group_count`, `repeat`, and `warmup`:

- `group_count`: the number of GEMM operations in the group
- `repeat`: the number of times to repeat the kernel for benchmarking
- `warmup`: the number of warm-up iterations run before the kernel time is measured

```cpp
// Example
```
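
For context, a minimal sketch of how such options can be declared, assuming ck_tile's `ArgParser` helper used across the ck_tile examples; this is not the verbatim code of this example, and values are read back afterwards with getters such as `get_int`:

```cpp
#include <tuple>
#include "ck_tile/host.hpp"

// Sketch only: option names and defaults mirror the argument list in
// this README, but the actual declarations in the example may differ.
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("group_count", "16", "group count");
    arg_parser.insert("warmup", "10", "number of warm-up iterations before timing");
    arg_parser.insert("repeat", "100", "number of timed iterations");
    arg_parser.insert("validate", "1", "0: no validation, 1: validation on CPU");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
```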
@@ -133,6 +133,28 @@ float invoke_gemm(int n_warmup,

Inside `invoke_gemm`, a device workspace sized by `GetWorkspaceSize(args)` is allocated:

```cpp
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(args));
```

### Advanced Features: Preshuffle and Persistence

The grouped GEMM examples include two advanced optimization features:

#### Weight Preshuffle

Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches. A configuration sketch follows the list below.

- **Implementation**: Available in `grouped_gemm_preshuffle.cpp`
- **Configuration**: Uses the `GemmConfigPreshuffleDecode` template configuration
- **Constraints**: Currently supports only A (Row major) + B (Column major) → C (Row major) layouts
- **Benefits**: Improved memory efficiency and reduced data movement
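
For reference, a hypothetical sketch of what such a configuration can look like, modeled on the `GemmConfig` structs used by the ck_tile gemm examples; the tile sizes and member names here are illustrative assumptions, not the verbatim `GemmConfigPreshuffleDecode` definition:

```cpp
#include "ck_tile/core.hpp"

// Illustrative only: member names and tile sizes are assumptions.
struct GemmConfigPreshuffleDecode
{
    static constexpr ck_tile::index_t M_Tile = 16;  // small M tile for decode-style shapes
    static constexpr ck_tile::index_t N_Tile = 64;
    static constexpr ck_tile::index_t K_Tile = 128;

    static constexpr bool Preshuffle = true;   // B (weights) is reordered ahead of time
    static constexpr bool Persistent = false;  // persistence is also a config value
};
```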

#### Persistence Mode

Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy. A usage sketch follows the list below.

- **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm`
- **Usage**: `invoke_gemm<ALayout, BLayout, CLayout, true>` enables persistence
- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes
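
As a call-shape sketch: only the `Persistent` template parameter usage comes from this example; the layout tag aliases are ck_tile's standard gemm layout tags, and the trailing arguments of `invoke_gemm` are elided:

```cpp
using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
using CLayout = ck_tile::tensor_layout::gemm::RowMajor;

// Persistent kernel: thread blocks stay resident and loop over work items.
float time_persistent =
    invoke_gemm<ALayout, BLayout, CLayout, true>(n_warmup /* , ... */);

// Non-persistent kernel: the grid is sized to cover all groups' tiles up front.
float time_default =
    invoke_gemm<ALayout, BLayout, CLayout, false>(n_warmup /* , ... */);
```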

Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.

Finally, the arguments are passed to `group_gemm` and the kernel is launched.

```cpp
// API
```
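
A rough sketch of that launch path, combining the pieces shown earlier; apart from `ck_tile::DeviceMem`, `GetWorkspaceSize`, and the `group_gemm` name itself, the argument list and `stream_config` usage here are illustrative assumptions:

```cpp
// Illustrative call shape only; see grouped_gemm.cpp for the real API.
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(gemm_descs));  // gemm_descs: one descriptor per group

float avg_ms = group_gemm(gemm_descs,
                          ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
                          gemm_workspace.GetDeviceBuffer());
```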
@@ -151,26 +173,42 @@ mkdir build && cd build

```bash
../script/cmake-ck-dev.sh ../ <arch>
# Build the example using the basic GEMM pipeline
make tile_example_grouped_gemm -j
# Build the preshuffle example
make tile_example_grouped_gemm_preshuffle -j
```

This will produce the executables `build/bin/tile_example_grouped_gemm` and `build/bin/tile_example_grouped_gemm_preshuffle`.

## Example

```
args:
-Ms            M dimensions - (Default: empty).
-Ns            N dimensions - (Default: empty).
-Ks            K dimensions - (Default: empty).
-stride_As     Tensor A strides - (Default: empty).
-stride_Bs     Tensor B strides - (Default: empty).
-stride_Cs     Tensor C strides - (Default: empty).
-a_layout      A tensor data layout - (Default: Row).
-b_layout      B tensor data layout - (Default: Col).
-c_layout      C tensor data layout - (Default: Row).
-prec          Data type: fp16/fp8 - (Default: fp16).
-validate      0: no validation, 1: validation on CPU. (Default: 1).
-warmup        Number of iterations before benchmarking the kernel. (Default: 10).
-repeat        Number of iterations to benchmark the kernel. (Default: 100).
-group_count   Group count. (Default: 16).
-kbatch        kbatch for SplitK. (Default: 1).
-json          0: no JSON, 1: dump results in JSON format. (Default: 0).
-jsonfile      JSON file name for dumped results. (Default: grouped_gemm.json).
```

If any of `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs`, or `stride_Cs` are missing or their sizes don't match `group_count`, the example generates defaults per group index `i` (0-based):

```text
M[i] = 256 + 256 * i
N[i] = 256 + 512 * i
K[i] = 512 + 384 * i

stride_A[i] = K[i]
stride_B[i] = K[i]
stride_C[i] = N[i]
```
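
For example, a hypothetical run that relies on these generated defaults (assuming the `-name=value` option syntax used by the ck_tile examples):

```bash
./build/bin/tile_example_grouped_gemm -group_count=4 -prec=fp16 -validate=1
```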