feat: add support for bf16 for grouped_gemm & grouped_gemm_preshuffle… (#3225)

* feat: add support for bf16 for grouped_gemm & grouped_gemm_preshuffle kernel(s) along with unit test

* docs: Update CHANGELOG.MD

[ROCm/composable_kernel commit: ac70206b2c]
This commit is contained in:
Aviral Goel
2025-11-18 09:32:27 -05:00
committed by GitHub
parent acb3b43bc0
commit a07cd6bc71
7 changed files with 48 additions and 4 deletions

View File

@@ -26,7 +26,7 @@ Multi-D operations extend the standard GEMM operation by supporting additional e
- **Implementation**: Available in `grouped_gemm_multi_d.cpp`
- **Operation**: E = C × D₀ × D₁ (where C = A × B is the standard GEMM result)
- **Configuration**: Uses `GemmConfigV3`, `GemmConfigV4`, `GemmConfigMemory` template configuration with 2 D tensors
- **Data Types**: Supports fp16, fp8
- **Data Types**: Supports fp16, bf16, fp8
- **Benefits**: Enables complex operations like scaling, activation functions, or other elementwise transformations in a single kernel call
- **Build Target**: `make tile_example_grouped_gemm_multi_d -j`
@@ -61,7 +61,7 @@ args:
-a_layout A tensor data layout - (Default: Row).
-b_layout B tensor data layout - (Default: Col).
-c_layout C tensor data layout - (Default: Row).
-prec data type. fp16/fp8 - (Default: fp16).
-prec data type. fp16/bf16/fp8 - (Default: fp16).
-validate 0. No validation, 1. Validation on CPU. (Default: 1).
-warmup Number of iterations before benchmark the kernel. (Default: 10).
-repeat Number of iterations to benchmark the kernel. (Default: 100).

View File

@@ -333,6 +333,11 @@ int run_grouped_gemm_example(int argc, char* argv[])
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(

View File

@@ -66,6 +66,15 @@ struct GemmTypeConfig<ck_tile::fp8_t>
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
struct GemmConfigBase
{
static constexpr bool kPadM = false;

View File

@@ -321,6 +321,11 @@ int run_grouped_gemm_example(int argc, char* argv[])
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(