[CK_TILE] Add FP8xF4 Flatmm (#3401)

* Refactor policy

* fix a bank conflict

* Enable mixed mx flatmm

* Update
This commit is contained in:
Yi DING
2025-12-17 10:01:48 +08:00
committed by GitHub
parent 3dfa794fab
commit 57e1e4a848
9 changed files with 231 additions and 223 deletions

View File

@@ -148,7 +148,7 @@ auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "32", "m dimension")
-.insert("n", "128", "n dimension")
+.insert("n", "512", "n dimension")
.insert("k", "256", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
@@ -308,6 +308,28 @@ int run_mx_flatmm_example(int argc, char* argv[])
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else if(mx_prec == "fp8xfp4")
{
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
ck_tile::pk_fp4_t,
ck_tile::fp16_t,
MXf8f4_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else if(mx_prec == "fp4xfp8")
{
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
ck_tile::fp8_t,
ck_tile::fp16_t,
MXf4f8_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else
{
throw std::runtime_error("Unsupported data_type!");