Add IsFullTileInsideMask() to avoid pixel-by-pixel checking when kUseCausal=true but kUseLocal=false

This commit is contained in:
Qianfeng Zhang
2025-04-27 09:31:38 +00:00
parent 054c397e05
commit 1af27022ef
2 changed files with 72 additions and 11 deletions

View File

@@ -358,20 +358,48 @@ struct HstuAttentionFwdPipelineQRKSVS
if constexpr(HstuMask::IsMasking)
{
constexpr auto s_spans = SaccBlockTileType::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
sacc_tiles[i_k1].get_tile_distribution(), make_tuple(idx0, idx1));
if constexpr(HstuMask::kUseLocal)
{
constexpr auto s_spans = SaccBlockTileType::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
sacc_tiles[i_k1].get_tile_distribution(),
make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
sacc_tiles[i_k1](i_j_idx) *=
static_cast<GemmAccDataType>(mask.IsTokenPairInsideMask(row, col));
sacc_tiles[i_k1](i_j_idx) *= static_cast<GemmAccDataType>(
mask.IsTokenPairInsideMask(row, col));
});
});
});
}
else // kUseCausal=true, kUseLocal=false
{
if(!mask.IsFullTileInsideMask(
q_origin.at(number<0>{}), seqlen_k_curr, number<kK1>{}))
{
constexpr auto s_spans = SaccBlockTileType::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
sacc_tiles[i_k1].get_tile_distribution(),
make_tuple(idx0, idx1));
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
sacc_tiles[i_k1](i_j_idx) *= static_cast<GemmAccDataType>(
mask.IsTokenPairInsideMask(row, col));
});
});
}
};
}
else if constexpr(kPadSeqLenK)
{

View File

@@ -115,6 +115,18 @@ struct HstuBlockMaskWithLocal
return static_cast<int>(result);
}
};
// Tile-level shortcut query: tells the caller whether a whole tile lies inside the mask,
// in which case element-by-element mask checking can be skipped for that tile.
//   i_tile_top  - top coordinate of the tile (unused in this variant)
//   i_tile_left - left coordinate of the tile (unused in this variant)
//   TileWidth   - compile-time tile width, carried by the number<> tag (unused in this variant)
// This variant always answers false.
template <index_t TileWidth>
CK_TILE_DEVICE constexpr bool
IsFullTileInsideMask(index_t i_tile_top, index_t i_tile_left, number<TileWidth>) const
{
    // With local masking in use, every tile is assumed to need element-by-element
    // checking, so the tile coordinates are deliberately discarded.
    static_cast<void>(i_tile_top);
    static_cast<void>(i_tile_left);

    return false;
}
};
template <bool kUseCausal>
@@ -185,6 +197,27 @@ struct HstuBlockMaskNoLocal
return 1;
};
// Tile-level shortcut query: tells the caller whether a whole tile lies inside the mask,
// in which case element-by-element mask checking can be skipped for that tile.
//   i_tile_top  - top coordinate of the tile
//   i_tile_left - left coordinate of the tile
//   TileWidth   - compile-time tile width, carried by the number<> tag
// Returns true when the tile is considered fully inside the mask, false otherwise.
template <index_t TileWidth>
CK_TILE_DEVICE constexpr bool
IsFullTileInsideMask(index_t i_tile_top, index_t i_tile_left, number<TileWidth>) const
{
    if constexpr(!kUseCausal)
    {
        // No causal restriction: the tile is reported as fully inside. kPadSeqLenK still
        // needs a further check in the masking context.
        return true;
    }
    else
    {
        // Causal case: the tile counts as fully inside only when its right edge
        // (left + width) does not go past the top coordinate.
        const index_t tile_right_edge = i_tile_left + TileWidth;
        return tile_right_edge <= i_tile_top;
    }
}
};
template <bool kUseCausal, bool kUseLocal>