Update to use BottomRight-Diagonal masking when seqlen_kv is bigger than seqlen_q

This commit is contained in:
Qianfeng Zhang
2026-01-25 14:51:53 +00:00
parent 1d4d925ba3
commit 749e83f2fd
5 changed files with 189 additions and 45 deletions

View File

@@ -99,6 +99,7 @@ auto create_args(int argc, char* argv[])
.insert("seqlens", "400", "uih seqlen of single or all batches for query tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("seqlens_kv", "", "uih seqlen of single or all batches for key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("max_seqlen_kv", "0", "max uih_seqlen_kv, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("max_target", "0", "max target, can be ignored, or else must be equal or bigger than the maximum of all targets")
.insert("softmax", "0", "use softmax or not")
@@ -253,11 +254,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::string str_of_lengths_kv = arg_parser.get_str("seqlens_kv");
std::vector<int> seq_lengths_kv = get_integers_from_string(str_of_lengths_kv);
int input_max_uih_seqlen = arg_parser.get_int("max_seqlen");
int input_max_target = arg_parser.get_int("max_target");
int input_max_uih_seqlen_q = arg_parser.get_int("max_seqlen");
int input_max_uih_seqlen_kv = arg_parser.get_int("max_seqlen_kv");
int input_max_target = arg_parser.get_int("max_target");
int max_uih_seqlen = 0;
int max_target = 0;
int max_uih_seqlen_q = 0;
int max_uih_seqlen_kv = 0;
int max_target = 0;
if(!num_targets.empty())
{
@@ -275,6 +279,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(seq_lengths_kv.empty())
seq_lengths_kv = seq_lengths_q;
// assume input_max_uih_seqlen_kv is same as input_max_uih_seqlen_q if not strictly defined
if(input_max_uih_seqlen_kv <= 0)
input_max_uih_seqlen_kv = input_max_uih_seqlen_q;
if(is_jagged)
{
// supplement seq_lengths_q using the last input value if user-provided lengths not enough
@@ -286,19 +294,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
// only consider num_batch values even if more values are provided by the user
for(int i = 0; i < num_batch; i++)
{
max_uih_seqlen = max(max_uih_seqlen, max(seq_lengths_q[i], seq_lengths_kv[i]));
max_uih_seqlen_q = max(max_uih_seqlen_q, seq_lengths_q[i]);
max_uih_seqlen_kv = max(max_uih_seqlen_kv, seq_lengths_kv[i]);
};
}
else
{
HSTU_CHECK(1 == seq_lengths_q.size() && 1 == seq_lengths_kv.size(),
"sequence lengths for batched mode should have single element!");
max_uih_seqlen = max(seq_lengths_q[0], seq_lengths_kv[0]);
max_uih_seqlen_q = seq_lengths_q[0];
max_uih_seqlen_kv = seq_lengths_kv[0];
};
// the user input of max_uih_seqlen can either be ignored or be bigger than all uih_seqlens
// the user input of max_target can either be ignored or be bigger than all targets
HSTU_CHECK(input_max_uih_seqlen <= 0 || input_max_uih_seqlen >= max_uih_seqlen,
HSTU_CHECK(input_max_uih_seqlen_q <= 0 || input_max_uih_seqlen_q >= max_uih_seqlen_q,
"the user input of max_uih_seqlen can either be ignored or be bigger than all "
"uih_seqlens!");
HSTU_CHECK(input_max_uih_seqlen_kv <= 0 || input_max_uih_seqlen_kv >= max_uih_seqlen_kv,
"the user input of max_uih_seqlen can either be ignored or be bigger than all "
"uih_seqlens!");
HSTU_CHECK(input_max_target <= 0 || input_max_target >= max_target,
@@ -306,12 +319,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
HSTU_CHECK(contextual_seqlen >= 0, "contextual_seqlen should be non-negative!");
max_uih_seqlen = (input_max_uih_seqlen > 0) ? input_max_uih_seqlen : max_uih_seqlen;
max_target = (input_max_target > 0) ? input_max_target : max_target;
max_uih_seqlen_q = (input_max_uih_seqlen_q > 0) ? input_max_uih_seqlen_q : max_uih_seqlen_q;
max_uih_seqlen_kv = (input_max_uih_seqlen_kv > 0) ? input_max_uih_seqlen_kv : max_uih_seqlen_kv;
max_target = (input_max_target > 0) ? input_max_target : max_target;
int phy_seqlen_q = 0;
int phy_seqlen_kv = 0;
int max_seqlen = max_uih_seqlen + max_target + contextual_seqlen;
int max_seqlen_q = max_uih_seqlen_q + max_target + contextual_seqlen;
int max_seqlen_kv = max_uih_seqlen_kv + max_target + contextual_seqlen;
std::vector<int> seq_offsets_q;
std::vector<int> seq_offsets_kv;
@@ -344,8 +359,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
else
{
phy_seqlen_q = max_seqlen;
phy_seqlen_kv = max_seqlen;
phy_seqlen_q = max_seqlen_q;
phy_seqlen_kv = max_seqlen_kv;
};
long total_flops = 0;
@@ -384,8 +399,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
ck_tile::HostTensor<int8_t> mask_host(
save_mask ? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_seqlen, max_seqlen}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
save_mask
? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_seqlen_q, max_seqlen_kv}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
if(!initialize_qkv)
{
@@ -440,7 +456,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
params.num_batch = num_batch;
params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer();
params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer();
params.max_seqlen = max_seqlen;
params.max_seqlen = max(max_seqlen_q, max_seqlen_kv);
params.q_ptr = q_dev.GetDeviceBuffer();
params.k_ptr = k_dev.GetDeviceBuffer();
params.v_ptr = v_dev.GetDeviceBuffer();
@@ -558,7 +574,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_batch,
scale_s,
attn_scale,
max_seqlen,
max_seqlen_q,
max_seqlen_kv,
seq_offsets_q,
seq_offsets_kv,
num_targets,

View File

@@ -29,6 +29,7 @@ struct HstuBlockMaskWithLocal
int max_k_uih_len;
int max_row_id;
int max_col_id;
int diff_q_kv_len;
CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(bool is_tile_in_first_split_,
int seqlen_q_,
@@ -63,6 +64,9 @@ struct HstuBlockMaskWithLocal
max_row_id = max_q_uih_len;
max_col_id = max_k_uih_len;
}
diff_q_kv_len = max_k_uih_len - max_q_uih_len;
max_row_id += diff_q_kv_len;
};
// to get the loop length along X axis, return index:[start, end), end-start=length
@@ -77,7 +81,7 @@ struct HstuBlockMaskWithLocal
{
if constexpr(kUseCausal)
{
index_t x_end = min(i_y + YTile, seqlen_k);
index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k);
return ck_tile::make_tuple(0, x_end);
}
else
@@ -90,7 +94,7 @@ struct HstuBlockMaskWithLocal
}
else // tile is completely inside [max_q_uih_len, seqlen_q)
{
index_t x_end = min(i_y + YTile, seqlen_k);
index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k);
return ck_tile::make_tuple(0, x_end);
};
};
@@ -105,7 +109,7 @@ struct HstuBlockMaskWithLocal
// some row of the tile in [contextual_seqlen+max_attn_len, max_q_uih_len)
if(i_y < max_q_uih_len)
{
index_t x_start = i_y - max_attn_len;
index_t x_start = i_y + diff_q_kv_len - max_attn_len;
index_t x_start_aligned = x_start - x_start % XTile;
// some rows of the tile in [max_q_uih_len - max_attn_len, max_q_uih_len)
@@ -116,14 +120,14 @@ struct HstuBlockMaskWithLocal
else // whole tile in [contextual_seqlen+max_attn_len, max_q_uih_len
// -max_attn_len)
{
index_t x_end = i_y + YTile + max_attn_len;
index_t x_end = i_y + YTile + diff_q_kv_len + max_attn_len;
return ck_tile::make_tuple(x_start_aligned, x_end);
};
}
else // whole tile in [max_uih_len, seqlen)
{
index_t x_start = max_k_uih_len - max_attn_len;
index_t x_end = min(i_y + YTile, seqlen_k);
index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k);
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
}
@@ -132,12 +136,13 @@ struct HstuBlockMaskWithLocal
{
if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen)
{
index_t x_end = min(max(i_y + YTile + max_attn_len, max_k_uih_len), seqlen_k);
index_t x_end = min(
max(i_y + YTile + diff_q_kv_len + max_attn_len, max_k_uih_len), seqlen_k);
return ck_tile::make_tuple(0, x_end);
}
else // whole tile in [contextual_seqlen, seqlen)
{
index_t x_end = min(i_y + YTile + max_attn_len, seqlen_k);
index_t x_end = min(i_y + YTile + diff_q_kv_len + max_attn_len, seqlen_k);
return ck_tile::make_tuple(0, x_end);
}
}
@@ -146,15 +151,15 @@ struct HstuBlockMaskWithLocal
{
if(i_y >= min(contextual_seqlen, 1) + max_attn_len)
{
index_t x_end = min(i_y + YTile, seqlen_k);
index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k);
// some row of the tile in [contextual_seqlen+max_attn_len, max_q_uih_len)
if(i_y < max_q_uih_len)
{
index_t x_start = i_y - max_attn_len;
index_t x_start = i_y + diff_q_kv_len - max_attn_len;
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
}
else // whole tile in [max_uih_len, seqlen)
else // whole tile in [max_q_uih_len, seqlen_q)
{
index_t x_start = max_k_uih_len - max_attn_len;
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
@@ -164,12 +169,12 @@ struct HstuBlockMaskWithLocal
{
if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen)
{
index_t x_end = min(max(i_y + YTile, max_k_uih_len), seqlen_k);
index_t x_end = min(max(i_y + YTile + diff_q_kv_len, max_k_uih_len), seqlen_k);
return ck_tile::make_tuple(0, x_end);
}
else // whole tile in [contextual_seqlen, seqlen)
{
index_t x_end = min(i_y + YTile, seqlen_k);
index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k);
return ck_tile::make_tuple(0, x_end);
}
}
@@ -181,17 +186,19 @@ struct HstuBlockMaskWithLocal
int row_id;
int col_id;
row += diff_q_kv_len;
if(contextual_seqlen > 0)
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
row_id = max(row - contextual_seqlen + 1, 0);
row_id = max(row - contextual_seqlen + 1, diff_q_kv_len);
col_id = max(col - contextual_seqlen + 1, 0);
row_id = min(row_id, max_row_id);
col_id = min(col_id, max_col_id);
if(row_id == 0 && col_id < max_col_id)
if(row_id == diff_q_kv_len && col_id < max_col_id)
return true;
}
else
@@ -227,17 +234,19 @@ struct HstuBlockMaskWithLocal
int row_id;
int col_id;
row += diff_q_kv_len;
if(contextual_seqlen > 0)
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
row_id = max(row - contextual_seqlen + 1, 0);
row_id = max(row - contextual_seqlen + 1, diff_q_kv_len);
col_id = max(col - contextual_seqlen + 1, 0);
row_id = min(row_id, max_row_id);
col_id = min(col_id, max_col_id);
if(row_id == 0 && col_id < max_col_id)
if(row_id == diff_q_kv_len && col_id < max_col_id)
return true;
}
else
@@ -281,7 +290,8 @@ struct HstuBlockMaskWithLocal
{
index_t i_tile_right = i_tile_left + TileWidth;
if(!is_tile_in_first_split && i_tile_right <= min(i_tile_top + 1, max_k_uih_len))
if(!is_tile_in_first_split &&
i_tile_right <= min(i_tile_top + diff_q_kv_len + 1, max_k_uih_len))
return true;
}
else
@@ -315,6 +325,7 @@ struct HstuBlockMaskNoLocal
int max_k_uih_len;
int max_row_id;
int max_col_id;
int diff_q_kv_len;
CK_TILE_HOST_DEVICE
HstuBlockMaskNoLocal(int seqlen_q_, int seqlen_k_, int contextual_seqlen_, int num_target_)
@@ -333,6 +344,9 @@ struct HstuBlockMaskNoLocal
max_row_id = max_q_uih_len;
max_col_id = max_k_uih_len;
}
diff_q_kv_len = max_k_uih_len - max_q_uih_len;
max_row_id += diff_q_kv_len;
};
// to get the loop length along X axis, return index:[start, end), end-start=length
@@ -348,11 +362,11 @@ struct HstuBlockMaskNoLocal
}
else
{
index_t x_end = min(i_y + YTile, seqlen_k);
index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k);
if(i_y < contextual_seqlen)
{
if(i_y + YTile > max_k_uih_len)
if(i_y + YTile + diff_q_kv_len > max_k_uih_len)
{
return ck_tile::make_tuple(0, x_end);
}
@@ -373,17 +387,19 @@ struct HstuBlockMaskNoLocal
int row_id;
int col_id;
row += diff_q_kv_len;
if(contextual_seqlen > 0)
{
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
// max_uih_len
row_id = max(row - contextual_seqlen + 1, 0);
row_id = max(row - contextual_seqlen + 1, diff_q_kv_len);
col_id = max(col - contextual_seqlen + 1, 0);
row_id = min(row_id, max_row_id);
col_id = min(col_id, max_col_id);
if(row_id == 0 && col_id < max_col_id)
if(row_id == diff_q_kv_len && col_id < max_col_id)
return true;
}
else
@@ -420,7 +436,7 @@ struct HstuBlockMaskNoLocal
// assume num_target > 0 with high probability, don't check whether num_target is 0;
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
if(i_tile_bottom >= max_q_uih_len || i_tile_right > i_tile_top)
if(i_tile_bottom >= max_q_uih_len || i_tile_right > i_tile_top + diff_q_kv_len)
return false;
return true;

View File

@@ -41,7 +41,8 @@ struct reference_hstu_attention
int num_batch,
float alpha,
float attn_scale,
int max_seqlen,
int max_seqlen_q,
int max_seqlen_kv,
std::vector<int> seq_q_offsets,
std::vector<int> seq_kv_offsets,
std::vector<int> num_targets, // define masking length at the end of token
@@ -93,8 +94,8 @@ struct reference_hstu_attention
if(static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[0]) == num_batch &&
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[1]) == num_head &&
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_seqlen &&
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_seqlen)
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_seqlen_q &&
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_seqlen_kv)
save_mask = true;
// check num_tagets
@@ -114,7 +115,9 @@ struct reference_hstu_attention
int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
float scale_p = attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen);
float scale_p = attn_scale
? attn_scale
: 1.0f / static_cast<float>(max(max_seqlen_q, max_seqlen_kv));
BOOL_SWITCH(window_size > 0, kHasLocal, [&] {
using HstuMaskType = typename HstuBlockMasking<kUseCausal, kHasLocal>::Type;
@@ -148,9 +151,8 @@ struct reference_hstu_attention
if(save_mask)
{
// initialize the mask
for(int sq = 0; sq < max_seqlen; sq++)
for(int sk = 0; sk < max_seqlen; sk++)
for(int sq = 0; sq < max_seqlen_q; sq++)
for(int sk = 0; sk < max_seqlen_kv; sk++)
mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) =
static_cast<int8_t>(mask.IsTokenPairInsideMask(sq, sk));
}

View File

@@ -49,3 +49,6 @@ for T in "fp16" "bf16"; do
$EXE -v=1 -prec=$T -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -seqlens_kv=300 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=0 -norm_dist=0
set +x
done
## This case is used to verify the masking when seqlen_kv > seqlen_q by comparing the saved mask tensor with the output of test_pytorch_hstu_mask_v2.py
$EXE -v=1 -prec=bf16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=49,52,55 -seqlens_kv=63,68,71 -causal=1 -local_len=0 -context_len=3 -minfull_len=0 -targets=4,5,6 -attn_scale=0 -norm_dist=0 -save_mask=1

View File

@@ -0,0 +1,106 @@
import math
import torch
from typing import Optional
def get_valid_attn_mask_v2(
    device: torch.device,
    causal: bool,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    seq_lengths_q: torch.Tensor,
    seq_lengths_kv: torch.Tensor,
    num_targets: Optional[torch.Tensor] = None,
    max_attn_len: int = 0,
    contextual_seq_len: int = 0,
    min_full_attn_seq_len: int = 0,
) -> torch.Tensor:
    """Build the reference HSTU attention mask with BottomRight-Diagonal
    alignment for seqlen_kv >= seqlen_q.

    Returns an int8 tensor of shape (batch, max_seqlen_q, N) where
    N = max(max_seqlen_q, max_seqlen_kv); entry 1 means the (q, k) token
    pair is attended.

    Args:
        device: device the mask is built on.
        causal: if True, a query only attends to keys at or before its
            bottom-right-aligned diagonal position; otherwise the row/col
            distance is taken as an absolute value.
        max_seqlen_q / max_seqlen_kv: padded maximum sequence lengths.
        seq_lengths_q / seq_lengths_kv: per-batch actual lengths, shape (B,).
        num_targets: per-batch count of target tokens at the end of each
            sequence (excluded from the uih region), or None.
        max_attn_len: local-attention window (0 disables windowing).
        contextual_seq_len: number of leading contextual tokens that attend
            to (and are attended by) the whole uih region.
        min_full_attn_seq_len: trailing span that keeps full attention even
            when max_attn_len windowing is active.
    """
    N = max(max_seqlen_q, max_seqlen_kv)
    ids = torch.arange(0, N, device=device).view(1, N)
    max_ids_q = seq_lengths_q.view(-1, 1, 1)
    max_ids_kv = seq_lengths_kv.view(-1, 1, 1)
    # per-batch kv-vs-q length difference; shifts q rows onto the
    # bottom-right-aligned diagonal
    diff_q_kv = max_ids_kv - max_ids_q
    if contextual_seq_len > 0:
        # collapse the contextual prefix onto logical position 0
        ids = ids - contextual_seq_len + 1
        ids = torch.clamp(ids, min=0)
        max_ids_q = max_ids_q - contextual_seq_len + 1
        max_ids_kv = max_ids_kv - contextual_seq_len + 1
    if num_targets is not None:
        # targets sit past the uih region; clamp them to its boundary
        max_ids_q = max_ids_q - num_targets.view(-1, 1, 1)
        max_ids_kv = max_ids_kv - num_targets.view(-1, 1, 1)
    raw_row_ids = torch.clamp(
        ids,
        max=max_ids_q,
    )
    raw_row_ids = raw_row_ids + diff_q_kv
    max_ids_q = max_ids_q + diff_q_kv
    raw_col_ids = torch.clamp(
        ids,
        max=max_ids_kv,
    )
    row_ids = raw_row_ids.view(-1, N, 1).expand(-1, N, N)
    col_ids = raw_col_ids.view(-1, 1, N).expand(-1, N, N)
    row_col_dist = row_ids - col_ids
    ## ensure mask value in diagonal is always 1
    ##valid_attn_mask = torch.eye(N, device=device, dtype=torch.bool).view(1, N, N)
    valid_attn_mask = torch.zeros_like(row_col_dist, dtype=torch.bool)
    # mark the bottom-right-aligned diagonal (q row i maps to kv column
    # i + diff_q_kv[b]); vectorized over rows instead of a per-element loop
    q_rows = torch.arange(max_seqlen_q, device=device)
    for idx0 in range(valid_attn_mask.size(0)):
        valid_attn_mask[idx0, q_rows, q_rows + int(diff_q_kv[idx0])] = True
    if not causal:
        row_col_dist = torch.where(row_col_dist > 0, row_col_dist, -row_col_dist)
    ## 1) for token pair in [seqlen-num_target, N) x [seqlen-num_target, N), row_col_dist is 0
    ## 2) for token pair in [seqlen-num_target, N) x [0, seqlen-num_target), row_col_dist > 0
    ## 3) for token_pair in [0, seqlen-num_target) x [seqlen-num_target, N), row_col_dist < 0 if causal, else row_col_dist > 0
    valid_attn_mask = torch.logical_or(valid_attn_mask, row_col_dist > 0)
    if max_attn_len > 0:
        if min_full_attn_seq_len > 0:
            valid_attn_mask = torch.logical_and(
                valid_attn_mask,
                torch.logical_or(
                    row_col_dist <= max_attn_len,
                    row_ids >= max_ids_q - min_full_attn_seq_len,
                ),
            )
        else:
            valid_attn_mask = torch.logical_and(
                valid_attn_mask, row_col_dist <= max_attn_len
            )
    if contextual_seq_len > 0:
        ## ensure first contextual_seqlen rows (where row_ids==diff_q_kv) attend to all cols less than max_ids_kv
        valid_attn_mask = torch.logical_or(
            valid_attn_mask, torch.logical_and(row_ids == diff_q_kv, col_ids < max_ids_kv)
        )
    # keep only the max_seqlen_q query rows that physically exist
    fit_valid_attn_mask = valid_attn_mask[:, :max_seqlen_q, :]
    return fit_valid_attn_mask.to(torch.int8)
def main():
    """Generate reference HSTU masks for the seqlen_kv > seqlen_q case and
    save them as .pt files for comparison against the C++ example's
    save_mask output."""
    dev_type = torch.device("cpu")
    max_seqlen_q = 64
    max_seqlen_kv = 80
    contextual_seq_len = 3
    causal = True
    seq_lengths_q = torch.tensor((56, 60, 64), device=dev_type, dtype=torch.int32)
    seq_lengths_kv = torch.tensor((70, 76, 80), device=dev_type, dtype=torch.int32)
    num_targets = torch.tensor((4, 5, 6), device=dev_type, dtype=torch.int32)
    # (max_attn_len, min_full_attn_seq_len, output file) per case:
    # case 0 has no local window; case 1 enables windowing + a full-attn tail
    cases = (
        (0, 0, "torch_hstu_mask_0.pt"),
        (4, 6, "torch_hstu_mask_1.pt"),
    )
    for max_attn_len, min_full_attn_seq_len, out_file in cases:
        valid_attn_mask = get_valid_attn_mask_v2(
            dev_type,
            causal,
            max_seqlen_q,
            max_seqlen_kv,
            seq_lengths_q,
            seq_lengths_kv,
            num_targets,
            max_attn_len,
            contextual_seq_len,
            min_full_attn_seq_len,
        )
        torch.save(valid_attn_mask, out_file)


if __name__ == "__main__":
    main()