[flatmm] add bf8 datatype

This commit is contained in:
so
2025-03-25 08:01:59 +00:00
parent 390dd7001d
commit 41d466d93b
5 changed files with 35 additions and 3 deletions

View File

@@ -39,7 +39,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = std::is_same_v<ADataType, ck_tile::fp8_t> ? 32 : 16;
constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 16;
using CodegenFlatmmShape =
ck_tile::TileFlatmmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -132,6 +132,10 @@ int run_flatmm_example(int argc, char* argv[])
{
run_flatmm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
run_flatmm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");