mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
Fix and improve the gemm quant pipeline infrastructure (#3245)
This commit is contained in:
@@ -30,7 +30,7 @@ struct BaseGemmPipelineAgBgCrCompV3
|
||||
{
|
||||
if(BlockHasHotloop(num_loop))
|
||||
{
|
||||
return TailNumber::Full;
|
||||
return TailNumber::Odd;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -52,23 +52,27 @@ struct BaseGemmPipelineAgBgCrCompV3
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_number == TailNumber::Full)
|
||||
if(tail_number == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Full>{});
|
||||
return run_func(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
|
||||
if(tail_number == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Even)
|
||||
else if(tail_number == ck_tile::TailNumber::Even)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
}
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
@@ -76,16 +80,8 @@ struct BaseGemmPipelineAgBgCrCompV3
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
// If execution reaches here, it's an invalid combination of arguments.
|
||||
if(has_hot_loop)
|
||||
{
|
||||
throw std::logic_error("Invalid TailNumber: If has_hot_loop is true, tail_number must "
|
||||
"be TailNumber::Full.");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::logic_error("Invalid TailNumber: If has_hot_loop is false, tail_number must "
|
||||
"be TailNumber::Odd or TailNumber::Even.");
|
||||
}
|
||||
throw std::logic_error("Invalid TailNumber value: must be "
|
||||
"TailNumber::Odd or TailNumber::Even");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -588,7 +584,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
// tail
|
||||
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
|
||||
if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
// Leak last MFMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
|
||||
Reference in New Issue
Block a user