Fix and improve the gemm quant pipeline infrastructure (#3245)

This commit is contained in:
Thomas Ning
2025-11-26 18:04:27 -08:00
committed by GitHub
parent 79aae7c7f7
commit a38aeceb21
11 changed files with 96 additions and 272 deletions

View File

@@ -30,7 +30,7 @@ struct BaseGemmPipelineAgBgCrCompV3
{
if(BlockHasHotloop(num_loop))
{
return TailNumber::Full;
return TailNumber::Odd;
}
else
{
@@ -52,23 +52,27 @@ struct BaseGemmPipelineAgBgCrCompV3
// Handle all the valid cases.
if(has_hot_loop)
{
if(tail_number == TailNumber::Full)
if(tail_number == ck_tile::TailNumber::Odd)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Full>{});
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
}
else
{
if(tail_number == TailNumber::Odd)
if(tail_number == ck_tile::TailNumber::Odd)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Odd>{});
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_number == TailNumber::Even)
else if(tail_number == ck_tile::TailNumber::Even)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Even>{});
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
}
#if defined(__HIP_DEVICE_COMPILE__)
@@ -76,16 +80,8 @@ struct BaseGemmPipelineAgBgCrCompV3
__builtin_unreachable();
#else
// If execution reaches here, it's an invalid combination of arguments.
if(has_hot_loop)
{
throw std::logic_error("Invalid TailNumber: If has_hot_loop is true, tail_number must "
"be TailNumber::Full.");
}
else
{
throw std::logic_error("Invalid TailNumber: If has_hot_loop is false, tail_number must "
"be TailNumber::Odd or TailNumber::Even.");
}
throw std::logic_error("Invalid TailNumber value: must be "
"TailNumber::Odd or TailNumber::Even");
#endif
}
};
@@ -588,7 +584,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
} while(i < (num_loop - 1));
}
// tail
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
if constexpr(TailNum == TailNumber::Odd)
{
// Leak the last MFMA block into the epilogue region to cover the potential
// LDS-shuffle latency.