mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
reduce instance
This commit is contained in:
@@ -393,7 +393,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
|
||||
.insert("scale", "1", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
|
||||
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
@@ -421,12 +421,7 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
int persistent_opt = arg_parser.get_int("persistent");
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
@@ -435,19 +430,20 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
if(scale_opt == 0)
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
-1,
|
||||
-1,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
throw std::runtime_error("scale_opt=0 no enabled to accelerate compiling");
|
||||
// if(persistent_opt == 0)
|
||||
// {
|
||||
// run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// run_flatmm_example_with_layouts<ck_tile::fp8_t,
|
||||
// FlatmmConfig<ck_tile::fp8_t>,
|
||||
// -1,
|
||||
// -1,
|
||||
// true>(argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -468,19 +464,6 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
if(scale_opt == 0)
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
|
||||
Reference in New Issue
Block a user