diff --git a/CHANGELOG.md b/CHANGELOG.md index 1223b63be0..7b9ecfcef4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added +* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. * Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM * Added a compute async pipeline in the CK TILE universal GEMM on gfx950 * Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM. diff --git a/example/ck_tile/17_grouped_gemm/README.md b/example/ck_tile/17_grouped_gemm/README.md index 32d7d0516f..c3fd7d4f82 100644 --- a/example/ck_tile/17_grouped_gemm/README.md +++ b/example/ck_tile/17_grouped_gemm/README.md @@ -26,7 +26,7 @@ Multi-D operations extend the standard GEMM operation by supporting additional e - **Implementation**: Available in `grouped_gemm_multi_d.cpp` - **Operation**: E = C × D₀ × D₁ (where C = A × B is the standard GEMM result) - **Configuration**: Uses `GemmConfigV3`, `GemmConfigV4`, `GemmConfigMemory` template configuration with 2 D tensors -- **Data Types**: Supports fp16, fp8 +- **Data Types**: Supports fp16, bf16, fp8 - **Benefits**: Enables complex operations like scaling, activation functions, or other elementwise transformations in a single kernel call - **Build Target**: `make tile_example_grouped_gemm_multi_d -j` @@ -61,7 +61,7 @@ args: -a_layout A tensor data layout - (Default: Row). -b_layout B tensor data layout - (Default: Col). -c_layout C tensor data layout - (Default: Row). - -prec data type. fp16/fp8 - (Default: fp16). + -prec data type. fp16/bf16/fp8 - (Default: fp16). -validate 0. No validation, 1. Validation on CPU. (Default: 1). -warmup Number of iterations before benchmark the kernel. (Default: 10). -repeat Number of iterations to benchmark the kernel. (Default: 100). diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index f5335c3ec0..52391cde62 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -333,6 +333,11 @@ int run_grouped_gemm_example(int argc, char* argv[]) return run_gemm_example_prec_type, ck_tile::half_t>( a_layout, b_layout, argc, argv); } + else if(data_type == "bf16") + { + return run_gemm_example_prec_type, ck_tile::bf16_t>( + a_layout, b_layout, argc, argv); + } else if(data_type == "fp8") { return run_gemm_example_prec_type, ck_tile::fp8_t>( diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 9b14efb561..df8fa8fde0 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -66,6 +66,15 @@ struct GemmTypeConfig using CDataType = ck_tile::half_t; }; +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + struct GemmConfigBase { static constexpr bool kPadM = false; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp index 52b84737cc..0c0e895cd5 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp @@ -321,6 +321,11 @@ int run_grouped_gemm_example(int argc, char* argv[]) return run_gemm_example_prec_type, ck_tile::half_t>( a_layout, b_layout, argc, argv); } + else if(data_type == "bf16") + { + return run_gemm_example_prec_type, ck_tile::bf16_t>( + a_layout, b_layout, argc, argv); + } else if(data_type == "fp8") { return run_gemm_example_prec_type, ck_tile::fp8_t>( diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp b/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp index 7d71f9f927..8d05c78e14 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm.cpp @@ -10,6 +10,7 @@ using F16 = ck_tile::half_t; using F32 = float; +using BF16 = ck_tile::bf16_t; using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using True = ck_tile::bool_constant; @@ -28,7 +29,19 @@ using KernelTypes = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, True>, std::tuple< Row, Row, Row, F16, F16, F32, F16, False>, std::tuple< Col, Row, Row, F16, F16, F32, F16, True>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, False> + std::tuple< Col, Row, Row, F16, F16, F32, F16, False>, + + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, True>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, False>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, True>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, False>, + + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, True>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, False>, + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, True>, + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, False>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, True>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp index a9b61ac7de..082c67d0f2 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp @@ -10,6 +10,7 @@ using F16 = ck_tile::half_t; using F8 = ck_tile::fp8_t; +using BF16 = ck_tile::bf16_t; using F32 = float; using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -57,7 +58,17 @@ using KernelTypes = ::testing::Types< KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>, KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 16, 64, 256, 1>, KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>, - KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2> + KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2>, + + KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 128, 128, 128, 2>, + KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 128, 128, 128, 2>, + + KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, True, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, True, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, True, 128, 128, 128, 2>, + KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, True, 128, 128, 128, 2> >; // clang-format on