Allow problem types without define kHasDropout attr

This commit is contained in:
PoYen, Chen
2024-08-08 10:53:42 +00:00
parent a0d2163045
commit 2f42e4460f

View File

@@ -707,19 +707,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{
if constexpr(AsyncCopyK)
{
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>();
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(0);
}
else
{
return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>());
return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>(0));
}
}
// this method is only available when Problem::kHasDropout is present
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr std::
enable_if_t<std::is_same_v<decltype(Problem::kHasDropout), bool>, ck_tile::index_t>
GetSmemSizeDropout()
enable_if_t<std::is_convertible_v<decltype(Problem::kHasDropout), bool>, ck_tile::index_t>
GetSmemSizeDropout(int)
{
if constexpr(Problem::kHasDropout)
{