Let causal == 0 cases to do IsFullTileInsideMask() checking before calling IsTokenPairInsideMask()

This commit is contained in:
Qianfeng Zhang
2025-06-26 10:23:03 +00:00
parent 3c300d3069
commit 5451912526

View File

@@ -397,77 +397,23 @@ struct HstuAttentionFwdPipelineQRKSVS
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, sacc_tile);
}
if constexpr(HstuMask::IsMasking)
if(!mask.IsFullTileInsideMask(
q_origin.at(number<0>{}), seqlen_k_curr, number<kK1>{}, number<kM0>{}))
{
if constexpr(HstuMask::kUseLocal)
{
if(!mask.IsFullTileInsideMask(q_origin.at(number<0>{}),
seqlen_k_curr,
number<kK1>{},
number<kM0>{}))
{
constexpr auto s_spans = SaccBlockTileType::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1));
constexpr auto s_spans = SaccBlockTileType::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
sacc_tile(i_j_idx) *= static_cast<GemmAccDataType>(
mask.IsTokenPairInsideMask(row, col));
});
});
}
}
else // kUseCausal=true, kUseLocal=false
{
if(!mask.IsFullTileInsideMask(q_origin.at(number<0>{}),
seqlen_k_curr,
number<kK1>{},
number<kM0>{}))
{
constexpr auto s_spans = SaccBlockTileType::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
sacc_tile(i_j_idx) *= static_cast<GemmAccDataType>(
mask.IsTokenPairInsideMask(row, col));
});
});
}
};
}
else if constexpr(kPadSeqLenK)
{
if(i_loop >= num_loops - 1)
{
constexpr auto s_spans = SaccBlockTileType::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
sacc_tile(i_j_idx) *= static_cast<GemmAccDataType>(
mask.IsTokenPairInsideMask(row, col));
});
sacc_tile(i_j_idx) *=
static_cast<GemmAccDataType>(mask.IsTokenPairInsideMask(row, col));
});
}
});
}
pcomp_tile = cast_tile<CompDataType>(sacc_tile);