[CK_TILE] FMHA BWD Add D96 Instances (#2916)

This commit is contained in:
Yi DING
2025-09-24 17:04:23 +08:00
committed by GitHub
parent 15fff74503
commit fe0a47a011
5 changed files with 35 additions and 28 deletions

View File

@@ -408,8 +408,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
sequence<1, 2>,
sequence<2, 1>>{});
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kNPerBlock * kKPerBlock)
if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2
{
return dstr;
}
@@ -457,8 +456,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>, // N0 K1
sequence<0, 1>>{});
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kNPerBlock * kKPerBlock)
if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2
{
return dstr;
}
@@ -507,8 +505,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
sequence<1, 2>,
sequence<2, 1>>{});
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kMPerBlock * kKPerBlock)
if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2
{
return dstr;
}
@@ -558,8 +555,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
sequence<1, 2>,
sequence<2, 1>>{});
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kMPerBlock * kKPerBlock)
if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2
{
return dstr;
}