1. Fix loop_num=odd bug. 2. Optimize MI300 performance for large MNK (tile size 128x128x128). 3. Optimize decode performance on MI300.

This commit is contained in:
root
2025-06-25 20:41:25 -05:00
parent 40f1d5829e
commit 56f84349ca
3 changed files with 271 additions and 68 deletions

View File

@@ -145,8 +145,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
}
};
if(has_hot_loop)
{
// if(has_hot_loop)
// {
if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<true>{},
@@ -165,28 +165,28 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
else
{
if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
std::ostringstream err;
err << "Num K loop must be larger than number of prefetech stages."
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
// }
// else
// {
// if(tail_num == ck_tile::TailNumber::Odd)
// {
// RunSplitk(ck_tile::bool_constant<false>{},
// ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
// }
// else if(tail_num == ck_tile::TailNumber::Even)
// {
// RunSplitk(ck_tile::bool_constant<false>{},
// ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
// }
// else
// {
// std::ostringstream err;
// err << "Num K loop must be larger than number of prefetech stages."
// << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
// << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
// throw std::runtime_error(err.str());
// }
// }
return ave_time;
}