Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-02 12:41:26 +00:00)
feat: add support for bf16 for grouped_gemm & grouped_gemm_preshuffle… (#3225)
* feat: add support for bf16 for grouped_gemm & grouped_gemm_preshuffle kernel(s) along with unit test
* docs: Update CHANGELOG.MD
@@ -26,7 +26,7 @@ Multi-D operations extend the standard GEMM operation by supporting additional e
 - **Implementation**: Available in `grouped_gemm_multi_d.cpp`
 - **Operation**: E = C × D₀ × D₁ (where C = A × B is the standard GEMM result)
 - **Configuration**: Uses `GemmConfigV3`, `GemmConfigV4`, `GemmConfigMemory` template configuration with 2 D tensors
-- **Data Types**: Supports fp16, fp8
+- **Data Types**: Supports fp16, bf16, fp8
 - **Benefits**: Enables complex operations like scaling, activation functions, or other elementwise transformations in a single kernel call
 - **Build Target**: `make tile_example_grouped_gemm_multi_d -j`
 
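The hunk above documents the operation E = C × D₀ × D₁ with C = A × B. As a conceptual sketch in plain NumPy (not the composable_kernel implementation; reading × over the D tensors as an elementwise product is an assumption based on the "elementwise transformations" wording):

```python
import numpy as np

# Sketch of the multi-D GEMM epilogue: E = C * D0 * D1, where C = A @ B.
# Illustrative only; the CK kernel fuses this into a single GPU kernel call.
rng = np.random.default_rng(0)
M, N, K = 4, 3, 5
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))
D0 = rng.standard_normal((M, N))   # first elementwise D tensor
D1 = rng.standard_normal((M, N))   # second elementwise D tensor

C = A @ B          # standard GEMM result
E = C * D0 * D1    # fused elementwise epilogue over the two D tensors
```

The point of the fused form is that C never needs to be written out and re-read: the D tensors are applied while the GEMM result is still in registers.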
@@ -61,7 +61,7 @@ args:
 -a_layout A tensor data layout - (Default: Row).
 -b_layout B tensor data layout - (Default: Col).
 -c_layout C tensor data layout - (Default: Row).
--prec data type. fp16/fp8 - (Default: fp16).
+-prec data type. fp16/bf16/fp8 - (Default: fp16).
 -validate 0. No validation, 1. Validation on CPU. (Default: 1).
 -warmup Number of iterations before benchmark the kernel. (Default: 10).
 -repeat Number of iterations to benchmark the kernel. (Default: 100).