mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Allow problem types without define kHasDropout attr
This commit is contained in:
@@ -707,19 +707,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
{
|
||||
if constexpr(AsyncCopyK)
|
||||
{
|
||||
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>();
|
||||
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>());
|
||||
return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>(0));
|
||||
}
|
||||
}
|
||||
|
||||
// this method is only available when Problem::kHasDropout is present
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr std::
|
||||
enable_if_t<std::is_same_v<decltype(Problem::kHasDropout), bool>, ck_tile::index_t>
|
||||
GetSmemSizeDropout()
|
||||
enable_if_t<std::is_convertible_v<decltype(Problem::kHasDropout), bool>, ck_tile::index_t>
|
||||
GetSmemSizeDropout(int)
|
||||
{
|
||||
if constexpr(Problem::kHasDropout)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user