feat: add bf16 support for the grouped_gemm & grouped_gemm_preshuffle kernels (#3225)

* feat: add bf16 support for the grouped_gemm & grouped_gemm_preshuffle kernels, along with unit tests

* docs: Update CHANGELOG.MD
This commit is contained in:
Aviral Goel
2025-11-18 09:32:27 -05:00
committed by GitHub
parent 3ede8e2a6e
commit ac70206b2c
7 changed files with 48 additions and 4 deletions

View File

@@ -10,6 +10,7 @@
// Short aliases for the element types that appear in the test type tuples.
using F16 = ck_tile::half_t;
using F32 = float;
using BF16 = ck_tile::bf16_t;
// Short aliases for the ck_tile GEMM tensor-layout tags.
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Compile-time boolean tag; appears as the last element of the test type tuples.
using True = ck_tile::bool_constant<true>;
@@ -28,7 +29,19 @@ using KernelTypes = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, True>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, False>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, True>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, False>
std::tuple< Col, Row, Row, F16, F16, F32, F16, False>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, True>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, False>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, True>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, False>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, True>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, False>,
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, True>,
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, False>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, True>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, False>
>;
// clang-format on