Add v_permlaneb32 for block_reduce. Disable it for now, as it produces packed math that cannot be co-executed in FA

This commit is contained in:
aska-0096
2025-08-04 10:27:42 +00:00
parent 4f31847de1
commit 0d12fc944f
5 changed files with 101 additions and 52 deletions

View File

@@ -14,10 +14,14 @@ namespace ck_tile {
 * Y dim must have at least one dim that has not been reduced
*/
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
template <typename AccDistributedTensor_,
typename ReduceFunc,
bool WithBroadcast = true,
bool CrossWarp = true>
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
const ReduceFunc& reduce_func,
bool_constant<WithBroadcast> = {})
bool_constant<WithBroadcast> = {},
bool_constant<CrossWarp> = {})
{
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
@@ -56,14 +60,24 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
// reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) {
constexpr index_t lid_delta =
lid_over_rid_derivative * (1 << (nstage - istage - 1));
if constexpr(CrossWarp)
{
constexpr index_t lid_delta =
lid_over_rid_derivative * (1 << (nstage - istage - 1));
// pull data from remote lane
const auto v_remote = warp_shuffle_down(v_local, lid_delta);
// pull data from remote lane
const auto v_remote = warp_shuffle_down(v_local, lid_delta);
// reduce
v_local = reduce_func(v_local, v_remote);
// reduce
v_local = reduce_func(v_local, v_remote);
}
else
{
// pull data from remote lane
const auto v_swapped_regs = warp_shuffle_down_pair(v_local);
// reduce
v_local = reduce_func(v_swapped_regs.at(0), v_swapped_regs.at(1));
}
});
}
});