mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Fix the use of KV LdsBuffers to avoid unexpected overwriting that causes non-deterministic results
This commit is contained in:
@@ -487,18 +487,18 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
shuffle_tile(v_shuffle_tmp, v_tile);
|
||||
|
||||
// if K in this unroll uses Lds-buffer i, then V in this unroll uses Lds-buffer
|
||||
// i+1, No overlap occurs between V and K in the same unroll, and V in current
|
||||
// unroll and K in next unroll or first unrool in next iteration
|
||||
// i+2, No overlap occurs between V and K in the same unroll, and V in current
|
||||
// unroll and K in next unroll or first unroll in next iteration
|
||||
store_tile(
|
||||
v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}],
|
||||
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
// if K in this unroll uses Lds-buffer i, then V in this unroll uses Lds-buffer
|
||||
// i+1, No overlap occurs between V and K in the same unroll, and V in current
|
||||
// unroll and K in next unroll or first unrool in next iteration
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}],
|
||||
// i+2, No overlap occurs between V and K in the same unroll, and V in current
|
||||
// unroll and K in next unroll or first unroll in next iteration
|
||||
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tile)); // store the prefetch
|
||||
};
|
||||
|
||||
@@ -520,14 +520,29 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
gemm_1(o_acc, p, v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}]);
|
||||
gemm_1(o_acc, p, v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
|
||||
|
||||
seqlen_k_curr += kK1;
|
||||
});
|
||||
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 3
|
||||
if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer<Problem>())
|
||||
__builtin_amdgcn_s_barrier();
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
{
|
||||
// check whether the current V-LdsBuffer overlaps with the next K-LdsBuffer,
|
||||
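// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 (NOTE(review): the check below compares (i_k1 + 2) % N with (i_k1 + 1) % N, which can only be equal when N == 1 -- confirm this is the intended condition)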
|
||||
if constexpr((i_k1 + 2) % NumKVLdsBuffers == (i_k1 + 1) % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
// check whether the last V-LdsBuffer overlaps with the first K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((i_k1 + 2) % NumKVLdsBuffers == 0)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
}
|
||||
});
|
||||
} while(i_loop++ < num_loops);
|
||||
|
||||
tile_elementwise_inout(
|
||||
|
||||
@@ -20,7 +20,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto GetNumKVLdsBuffers()
|
||||
{
|
||||
return 3;
|
||||
return 4;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -787,17 +787,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer()
|
||||
{
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1;
|
||||
constexpr index_t num_kv_lds_buffers = GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
return (k1_loops - 1 + 1) % num_kv_lds_buffers == 0;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user