Tiny fix in hstu attention IsFullTileInsideMask()

This commit is contained in:
root
2025-06-18 15:12:05 +00:00
committed by Qianfeng Zhang
parent 08886e99d5
commit 9e62359b59

View File

@@ -267,7 +267,11 @@ struct HstuBlockMaskWithLocal
{
std::ignore = i_tile_left;
if(min_full_attn_seqlen > 0 && i_tile_top >= max_uih_len - min_full_attn_seqlen)
index_t i_tile_bottom = i_tile_top + (TileHeight - 1);
// assume num_target > 0 with high probability, don't check whether num_target is 0;
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
if(i_tile_top >= max_uih_len - min_full_attn_seqlen && i_tile_bottom < max_uih_len)
return true;
return false;
@@ -440,26 +444,29 @@ struct HstuBlockMaskNoLocal
{
if constexpr(kUseCausal)
{
index_t i_tile_right = i_tile_left + TileWidth;
index_t i_tile_bottom = i_tile_top + TileHeight;
index_t i_tile_right = i_tile_left + (TileWidth - 1);
index_t i_tile_bottom = i_tile_top + (TileHeight - 1);
if(i_tile_right > i_tile_top ||
(i_tile_bottom > max_uih_len && i_tile_right > max_uih_len))
// assume num_target > 0 with high probability, don't check whether num_target is 0;
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
if(i_tile_bottom >= max_uih_len || i_tile_right > i_tile_top)
return false;
return true;
}
else
{
index_t i_tile_right = i_tile_left + TileWidth;
index_t i_tile_bottom = i_tile_top + TileHeight;
index_t i_tile_right = i_tile_left + (TileWidth - 1);
index_t i_tile_bottom = i_tile_top + (TileHeight - 1);
if(i_tile_bottom > max_uih_len && i_tile_right > max_uih_len)
// assume num_target > 0 with high probability, don't check whether num_target is 0;
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
if(i_tile_bottom >= max_uih_len || i_tile_right >= max_uih_len)
return false;
return true;
};
}
}
};
};
template <bool kUseCausal, bool kUseLocal>