mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
feat: add support for bf16 for grouped_gemm & grouped_gemm_preshuffle… (#3225)
* feat: add support for bf16 for grouped_gemm & grouped_gemm_preshuffle kernel(s) along with unit test * docs: Update CHANGELOG.MD
This commit is contained in:
@@ -5,6 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
## Composable Kernel 1.2.0 for ROCm 7.2.0
|
||||
|
||||
### Added
|
||||
* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
|
||||
* Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM
|
||||
* Added a compute async pipeline in the CK TILE universal GEMM on gfx950
|
||||
* Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM.
|
||||
|
||||
@@ -26,7 +26,7 @@ Multi-D operations extend the standard GEMM operation by supporting additional e
|
||||
- **Implementation**: Available in `grouped_gemm_multi_d.cpp`
|
||||
- **Operation**: E = C × D₀ × D₁ (where C = A × B is the standard GEMM result)
|
||||
- **Configuration**: Uses `GemmConfigV3`, `GemmConfigV4`, `GemmConfigMemory` template configuration with 2 D tensors
|
||||
- **Data Types**: Supports fp16, fp8
|
||||
- **Data Types**: Supports fp16, bf16, fp8
|
||||
- **Benefits**: Enables complex operations like scaling, activation functions, or other elementwise transformations in a single kernel call
|
||||
- **Build Target**: `make tile_example_grouped_gemm_multi_d -j`
|
||||
|
||||
@@ -61,7 +61,7 @@ args:
|
||||
-a_layout A tensor data layout - (Default: Row).
|
||||
-b_layout B tensor data layout - (Default: Col).
|
||||
-c_layout C tensor data layout - (Default: Row).
|
||||
-prec data type. fp16/fp8 - (Default: fp16).
|
||||
-prec data type. fp16/bf16/fp8 - (Default: fp16).
|
||||
-validate 0. No validation, 1. Validation on CPU. (Default: 1).
|
||||
-warmup Number of iterations before benchmark the kernel. (Default: 10).
|
||||
-repeat Number of iterations to benchmark the kernel. (Default: 100).
|
||||
|
||||
@@ -333,6 +333,11 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
|
||||
|
||||
@@ -66,6 +66,15 @@ struct GemmTypeConfig<ck_tile::fp8_t>
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
|
||||
@@ -321,6 +321,11 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
@@ -28,7 +29,19 @@ using KernelTypes = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, False>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, False>
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, False>,
|
||||
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, True>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, False>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, True>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, False>,
|
||||
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, True>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, False>,
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, True>,
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, False>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, True>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F32 = float;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -57,7 +58,17 @@ using KernelTypes = ::testing::Types<
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>,
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2>
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2>,
|
||||
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 128, 128, 128, 2>,
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 128, 128, 128, 2>,
|
||||
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, True, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, True, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, True, 128, 128, 128, 2>,
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, True, 128, 128, 128, 2>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user