mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Update in using masking for the case where kMasking is false and kPadSeqLenK is true
This commit is contained in:
@@ -362,15 +362,16 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
}
|
||||
else if constexpr(kPadSeqLenK)
|
||||
{
|
||||
set_tile_if(
|
||||
sacc_tiles[i_k1], type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
|
||||
if(q_origin.at(number<0>{}) + kM0 <= mask.max_uih_len &&
|
||||
i_loop < num_loops - 1)
|
||||
return false;
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
return !mask.IsTokenPairInsideMask(row, col);
|
||||
});
|
||||
if(i_loop >= num_loops - 1)
|
||||
{
|
||||
set_tile_if(
|
||||
sacc_tiles[i_k1], type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
return !mask.IsTokenPairInsideMask(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pcomp_tiles[i_k1] = cast_tile<CompDataType>(sacc_tiles[i_k1]);
|
||||
|
||||
Reference in New Issue
Block a user