Add v_permlane32_swap for block_reduce. Disable it as it will cause un-coexecutable packed math in FA

This commit is contained in:
aska-0096
2025-08-04 10:27:42 +00:00
parent 4f31847de1
commit 0d12fc944f
5 changed files with 101 additions and 52 deletions

View File

@@ -95,7 +95,15 @@ CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
// Block-wide LDS synchronization: waits until at most `lgkmcnt` LDS (lgkm)
// operations remain in flight, then issues s_barrier so every wave in the
// workgroup observes the completed LDS writes.
//
// Values above 15 are clamped to 15 — presumably because the hardware
// lgkmcnt field of s_waitcnt is only 4 bits wide, so passing a larger value
// straight through would encode an unintended (stricter) wait count.
//
// NOTE: the diff rendering had left the old unconditional
// `s_waitcnt(... CK_TILE_LGKMCNT(lgkmcnt))` call in front of the clamped
// version; that stale line is removed here, since it executed first and
// defeated the clamp.
template <index_t lgkmcnt = 0>
CK_TILE_DEVICE void block_sync_lds()
{
    if constexpr(lgkmcnt > 15)
    {
        __builtin_amdgcn_s_waitcnt(CK_TILE_S_CNT_MAX & CK_TILE_LGKMCNT(15));
    }
    else
    {
        __builtin_amdgcn_s_waitcnt(CK_TILE_S_CNT_MAX & CK_TILE_LGKMCNT(lgkmcnt));
    }
    __builtin_amdgcn_s_barrier();
}

View File

@@ -59,6 +59,21 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
#endif
}
// Exchanges a 32-bit value between the two 32-lane halves of the wave using
// a single permlane32-swap and returns both values as a 2-element buffer.
//
// NOTE(review): exact result ordering follows __builtin_amdgcn_permlane32_swap
// (gfx950) — element 0 appears to be the lane's own value and element 1 the
// value pulled from the opposite 32-lane half; confirm against the ISA
// reference before relying on the ordering.
//
// T must be exactly 32 bits wide: the payload travels through a single VGPR
// via bit_cast<int32_t>.
template <typename T>
CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local)
{
    static_assert(sizeof(T) == sizeof(int32_t),
                  "warp_shuffle_down_pair requires a 32-bit payload type");
    const int32x2_t x = __builtin_amdgcn_permlane32_swap(
        bit_cast<int32_t>(v_local), bit_cast<int32_t>(v_local), false, false);
    thread_buffer<T, 2> v;
    v(0) = bit_cast<T>(x[0]);
    v(1) = bit_cast<T>(x[1]);
    return v;
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
{

View File

@@ -227,14 +227,20 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
auto q_dram_window = make_tile_window(
q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution<Problem>());
auto q_lds = make_tensor_view<address_space_enum::lds>(
auto q_lds_write_view = make_tensor_view<address_space_enum::lds>(
static_cast<QDataType*>(smem_ptr), Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_store_window = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
auto q_lds_read_view = make_tensor_view<address_space_enum::lds>(
static_cast<QDataType*>(smem_ptr),
Policy::template MakeQLdsBlockDescriptor<Problem, true>());
auto q_lds_store_window =
make_tile_window(q_lds_write_view,
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0});
auto q_lds_read_window =
make_tile_window(q_lds,
make_tile_window(q_lds_read_view,
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0},
Policy::template MakeQRegTileDistribution<Problem>());
@@ -452,7 +458,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
// Setting CrossWarp to false triggers a better strategy on gfx950, but causes a
// performance regression due to un-coexecutable packed math; silence it for now
block_tile_reduce_sync(
m_local, f_max, bool_constant<false>{} /*, bool_constant<false>{}*/);
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
@@ -505,7 +514,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
block_tile_reduce_sync(
rowsum_p, f_sum, bool_constant<false>{} /*, bool_constant<false>{}*/);
auto p_tile = make_static_distributed_tensor<PDataType>(
Policy::template MakePRegTileDistribution<Problem>());
@@ -964,7 +974,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
block_tile_reduce_sync(
m_local, f_max, bool_constant<false>{} /*, bool_constant<false>{}*/);
static_for<0, 12, 1>{}([&](auto i) {
ignore = i;
@@ -1029,7 +1040,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
block_tile_reduce_sync(
rowsum_p, f_sum, bool_constant<false>{} /*, bool_constant<false>{}*/);
auto p_tile = make_static_distributed_tensor<PDataType>(
Policy::template MakePRegTileDistribution<Problem>());

View File

@@ -14,10 +14,14 @@ namespace ck_tile {
* Y dim must have at least one dim not been reduced
*/
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
template <typename AccDistributedTensor_,
typename ReduceFunc,
bool WithBroadcast = true,
bool CrossWarp = true>
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
const ReduceFunc& reduce_func,
bool_constant<WithBroadcast> = {})
bool_constant<WithBroadcast> = {},
bool_constant<CrossWarp> = {})
{
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
@@ -56,14 +60,24 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
// reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) {
constexpr index_t lid_delta =
lid_over_rid_derivative * (1 << (nstage - istage - 1));
if constexpr(CrossWarp)
{
constexpr index_t lid_delta =
lid_over_rid_derivative * (1 << (nstage - istage - 1));
// pull data from remote lane
const auto v_remote = warp_shuffle_down(v_local, lid_delta);
// pull data from remote lane
const auto v_remote = warp_shuffle_down(v_local, lid_delta);
// reduce
v_local = reduce_func(v_local, v_remote);
// reduce
v_local = reduce_func(v_local, v_remote);
}
else
{
// pull data from remote lane
const auto v_swapped_regs = warp_shuffle_down_pair(v_local);
// reduce
v_local = reduce_func(v_swapped_regs.at(0), v_swapped_regs.at(1));
}
});
}
});