Update how masking is applied for the case where kMasking is false and kPadSeqLenK is true

This commit is contained in:
Qianfeng Zhang
2025-04-23 10:47:27 +00:00
parent 8dcde8d10f
commit 2d2e1941a8

View File

@@ -362,15 +362,16 @@ struct HstuAttentionFwdPipelineQRKSVS
}
else if constexpr(kPadSeqLenK)
{
set_tile_if(
sacc_tiles[i_k1], type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
if(q_origin.at(number<0>{}) + kM0 <= mask.max_uih_len &&
i_loop < num_loops - 1)
return false;
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
return !mask.IsTokenPairInsideMask(row, col);
});
if(i_loop >= num_loops - 1)
{
set_tile_if(
sacc_tiles[i_k1], type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
return !mask.IsTokenPairInsideMask(row, col);
});
}
}
pcomp_tiles[i_k1] = cast_tile<CompDataType>(sacc_tiles[i_k1]);