draft for 16*16*128 fp8

This commit is contained in:
solin
2025-05-14 09:56:13 +00:00
parent 41c17d0a95
commit c2d17ec24f
7 changed files with 69 additions and 39 deletions

View File

@@ -27,7 +27,19 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 2;
#if defined(USING_MFMA_16x16x128)
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 256;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 128;
#endif
// This part comes from the Codegen
#if defined(USING_MFMA_16x16x32) || defined(ENABLE_FP16)
constexpr ck_tile::index_t M_Tile = 128;
@@ -151,11 +163,11 @@ int run_flatmm_example(int argc, char* argv[])
{
if(data_type == "fp16")
{
run_flatmm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
//run_flatmm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
run_flatmm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
//run_flatmm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
@@ -163,7 +175,7 @@ int run_flatmm_example(int argc, char* argv[])
}
else if(data_type == "bf8")
{
run_flatmm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
//run_flatmm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
}
else
{