mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK_TILE] FMHA BWD Add D96 Instances (#2916)
This commit is contained in:
@@ -408,8 +408,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
|
||||
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kNPerBlock * kKPerBlock)
|
||||
if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2
|
||||
{
|
||||
return dstr;
|
||||
}
|
||||
@@ -457,8 +456,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>, // N0 K1
|
||||
sequence<0, 1>>{});
|
||||
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kNPerBlock * kKPerBlock)
|
||||
if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2
|
||||
{
|
||||
return dstr;
|
||||
}
|
||||
@@ -507,8 +505,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
|
||||
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock * kKPerBlock)
|
||||
if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2
|
||||
{
|
||||
return dstr;
|
||||
}
|
||||
@@ -558,8 +555,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
|
||||
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock * kKPerBlock)
|
||||
if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2
|
||||
{
|
||||
return dstr;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user