mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] Add FP8xF4 Flatmm (#3401)
* Refactor policy * fix a bank conflict * Enable mixed mx flatmm * Update
This commit is contained in:
@@ -148,7 +148,7 @@ auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "32", "m dimension")
|
||||
.insert("n", "128", "n dimension")
|
||||
.insert("n", "512", "n dimension")
|
||||
.insert("k", "256", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
@@ -308,6 +308,28 @@ int run_mx_flatmm_example(int argc, char* argv[])
|
||||
else
|
||||
throw std::runtime_error("Only support non-persistent kernel now!");
|
||||
}
|
||||
else if(mx_prec == "fp8xfp4")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::fp16_t,
|
||||
MXf8f4_FlatmmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only support non-persistent kernel now!");
|
||||
}
|
||||
else if(mx_prec == "fp4xfp8")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp16_t,
|
||||
MXf4f8_FlatmmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only support non-persistent kernel now!");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
|
||||
Reference in New Issue
Block a user