mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Update IsFulleTileInsideMask() for kUseLocal is true situtation
This commit is contained in:
@@ -406,21 +406,27 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
{
|
||||
if constexpr(HstuMask::kUseLocal)
|
||||
{
|
||||
constexpr auto s_spans = SaccBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
if(!mask.IsFullTileInsideMask(q_origin.at(number<0>{}),
|
||||
seqlen_k_curr,
|
||||
number<kK1>{},
|
||||
number<kM0>{}))
|
||||
{
|
||||
constexpr auto s_spans = SaccBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
sacc_tile(i_j_idx) *= static_cast<GemmAccDataType>(
|
||||
mask.IsTokenPairInsideMask(row, col));
|
||||
sacc_tile(i_j_idx) *= static_cast<GemmAccDataType>(
|
||||
mask.IsTokenPairInsideMask(row, col));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
else // kUseCausal=true, kUseLocal=false
|
||||
{
|
||||
|
||||
@@ -265,10 +265,11 @@ struct HstuBlockMaskWithLocal
|
||||
number<TileWidth>,
|
||||
number<TileHeight>) const
|
||||
{
|
||||
// when local masking used, we assume all tiles need pixel-by-pixel checking
|
||||
std::ignore = i_tile_top;
|
||||
std::ignore = i_tile_left;
|
||||
|
||||
if(min_full_attn_seqlen > 0 && i_tile_top >= max_uih_len - min_full_attn_seqlen)
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user