Fix the use of KV LdsBuffers to avoid unexpected overwriting that causes non-deterministic results

This commit is contained in:
Qianfeng Zhang
2025-06-21 13:48:14 +00:00
parent a5f24d7470
commit 4fa6474254
2 changed files with 27 additions and 23 deletions

View File

@@ -487,18 +487,18 @@ struct HstuAttentionFwdPipelineQRKSVS
shuffle_tile(v_shuffle_tmp, v_tile);
// if K in this unroll uses Lds-buffer i, then V in this unroll uses Lds-buffer
// i+1, No overlap occurs between V and K in the same unroll, and V in current
// unroll and K in next unroll or first unrool in next iteration
// i+2, No overlap occurs between V and K in the same unroll, and V in current
// unroll and K in next unroll or first unroll in next iteration
store_tile(
v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}],
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
// if K in this unroll uses Lds-buffer i, then V in this unroll uses Lds-buffer
// i+1, No overlap occurs between V and K in the same unroll, and V in current
// unroll and K in next unroll or first unrool in next iteration
store_tile(v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}],
// i+2, No overlap occurs between V and K in the same unroll, and V in current
// unroll and K in next unroll or first unroll in next iteration
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tile)); // store the prefetch
};
@@ -520,14 +520,29 @@ struct HstuAttentionFwdPipelineQRKSVS
block_sync_lds();
gemm_1(o_acc, p, v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}]);
gemm_1(o_acc, p, v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
seqlen_k_curr += kK1;
});
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 3
if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer<Problem>())
__builtin_amdgcn_s_barrier();
if constexpr(i_k1 < k1_loops - 1)
{
// check whether the current V-LdsBuffer overlaps with the next K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((i_k1 + 2) % NumKVLdsBuffers == (i_k1 + 1) % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
}
else
{
// check whether the last V-LdsBuffer overlaps with the first K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((i_k1 + 2) % NumKVLdsBuffers == 0)
{
__builtin_amdgcn_s_barrier();
};
}
});
} while(i_loop++ < num_loops);
tile_elementwise_inout(

View File

@@ -20,7 +20,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetNumKVLdsBuffers()
{
return 3;
return 4;
}
template <typename Problem>
@@ -787,17 +787,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
}
// Compile-time predicate derived from the problem's BlockFmhaShape and the policy's
// Lds-buffer count: computes k1_loops = kN0 / kK1 and reports whether
// (k1_loops - 1 + 1), i.e. k1_loops, is an exact multiple of
// GetNumKVLdsBuffers<Problem>().
// Going by the name, a true result presumably means the V Lds-buffer used by the last
// unroll has the same index as the first K Lds-buffer (index 0) — TODO confirm against
// the pipeline's buffer indexing.
// NOTE(review): the declared return type is ck_tile::index_t although the returned
// expression is a boolean comparison, so the result is delivered as 0 or 1.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer()
{
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
// integer quotient of the two tile extents taken from BlockFmhaShape
constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1;
constexpr index_t num_kv_lds_buffers = GetNumKVLdsBuffers<Problem>();
// (k1_loops - 1) is presumably the last unroll index and the +1 the V-buffer offset
// relative to K; the sum wraps to 0 exactly when k1_loops divides evenly.
return (k1_loops - 1 + 1) % num_kv_lds_buffers == 0;
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
{