Align the masking logic in HstuCrossAttentionBlockMask with pytorch mask_v2 scripts

This commit is contained in:
Qianfeng Zhang
2026-02-09 15:55:13 +00:00
parent 6f8b9548b5
commit f2a555dac7
5 changed files with 102 additions and 94 deletions

View File

@@ -284,9 +284,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
else
is_cross_attention = true;
// assume input_max_uih_seqlen_kv is same as input_max_uih_seqlen_q if not strictly defined
if(input_max_uih_seqlen_kv <= 0)
input_max_uih_seqlen_kv = input_max_uih_seqlen_q;
if(!is_cross_attention)
{
// assume input_max_uih_seqlen_kv is same as input_max_uih_seqlen_q if not strictly defined
if(input_max_uih_seqlen_kv <= 0)
input_max_uih_seqlen_kv = input_max_uih_seqlen_q;
};
if(is_jagged)
{
@@ -333,7 +336,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
int phy_seqlen_q = 0;
int phy_seqlen_kv = 0;
int max_seqlen_q = max_uih_seqlen_q + max_target + contextual_seqlen;
int max_seqlen_kv = max_uih_seqlen_kv + max_target + contextual_seqlen;
int max_seqlen_kv = is_cross_attention ? max_uih_seqlen_kv + contextual_seqlen
: max_uih_seqlen_kv + max_target + contextual_seqlen;
std::vector<int> seq_offsets_q;
std::vector<int> seq_offsets_kv;
@@ -356,12 +360,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
for(int i = 0; i < num_batch; i++)
{
int batch_seqlen = num_targets.empty()
? seq_lengths_kv[i] + contextual_seqlen
: seq_lengths_kv[i] + num_targets[i] + contextual_seqlen;
if(!is_cross_attention)
{
int batch_seqlen = num_targets.empty()
? seq_lengths_kv[i] + contextual_seqlen
: seq_lengths_kv[i] + num_targets[i] + contextual_seqlen;
phy_seqlen_kv += batch_seqlen;
seq_offsets_kv.push_back(phy_seqlen_kv);
phy_seqlen_kv += batch_seqlen;
seq_offsets_kv.push_back(phy_seqlen_kv);
}
else // for cross_attention, assume target_in_kv == false
{
int batch_seqlen = seq_lengths_kv[i] + contextual_seqlen;
phy_seqlen_kv += batch_seqlen;
seq_offsets_kv.push_back(phy_seqlen_kv);
}
};
}
else

View File

@@ -47,7 +47,7 @@ struct HstuCrossAttentionBlockMaskWithLocal
min_full_attn_seqlen(min_full_attn_seqlen_)
{
max_q_uih_len = seqlen_q - num_target_;
max_k_uih_len = seqlen_k - num_target_;
max_k_uih_len = seqlen_k; // assuming target_in_kv == false
// in case user provided max_attn_len_ could be bigger than max_uih_len
max_attn_len = min(max_k_uih_len, min(max_q_uih_len, max_attn_len));
@@ -223,11 +223,12 @@ struct HstuCrossAttentionBlockMaskWithLocal
}
else
{
// Non-causal: only apply sliding window constraint, no diagonal inclusion
// logic This matches PyTorch reference which just returns boundary mask for non-causal
bool in_min_full_scope =
(min_full_attn_seqlen > 0) ? (row_id >= max_row_id - min_full_attn_seqlen) : false;
return (((row_id != col_id) || (row == col)) &&
((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope));
return ((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope);
}
};
@@ -523,7 +524,7 @@ struct HstuCrossAttentionBlockMaskNoLocal
: seqlen_q(seqlen_q_), seqlen_k(seqlen_k_), contextual_seqlen(contextual_seqlen_)
{
max_q_uih_len = seqlen_q - num_target_;
max_k_uih_len = seqlen_k - num_target_;
max_k_uih_len = seqlen_k; // assuming target_in_kv == false
if(contextual_seqlen > 0)
{
@@ -609,7 +610,9 @@ struct HstuCrossAttentionBlockMaskNoLocal
}
else
{
return (row_id != col_id) || (row == col);
// Non-causal: no masking needed, everything in bounds is allowed
// This matches PyTorch reference which just returns boundary mask for non-causal
return true;
};
};

View File

@@ -187,6 +187,10 @@ struct reference_hstu_attention
{
for(int sq = 0; sq < max_seqlen_q; sq++)
for(int sk = 0; sk < max_seqlen_kv; sk++)
mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) = 0;
for(int sq = 0; sq < seqlen_q; sq++)
for(int sk = 0; sk < seqlen_kv; sk++)
mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) =
static_cast<int8_t>(mask.IsTokenPairInsideMask(sq, sk));
}

View File

@@ -51,4 +51,4 @@ for T in "fp16" "bf16"; do
done
## This case is used to verify the masking when seqlen_kv > seqlen_q by comparing the saved mask tensor with the output of test_pytorch_hstu_mask_v2.py
$EXE -v=1 -prec=bf16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=49,52,55 -seqlens_kv=63,68,71 -causal=1 -local_len=0 -context_len=3 -minfull_len=0 -targets=4,5,6 -attn_scale=0 -norm_dist=0 -save_mask=1
$EXE -v=1 -prec=bf16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=52,55,58 -seqlens_kv=70,76,80 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=4,5,6 -attn_scale=0 -norm_dist=0 -save_mask=1

View File

@@ -1,7 +1,9 @@
import math
import torch
from typing import Optional
import torch
def get_valid_attn_mask_v2(
device: torch.device,
causal: bool,
@@ -10,97 +12,82 @@ def get_valid_attn_mask_v2(
seq_lengths_q: torch.Tensor,
seq_lengths_kv: torch.Tensor,
num_targets: Optional[torch.Tensor] = None,
max_attn_len: int = 0,
contextual_seq_len: int = 0,
min_full_attn_seq_len: int = 0,
) -> torch.Tensor:
N = max(max_seqlen_q, max_seqlen_kv)
ids = torch.arange(0, N, device=device).view(1, N)
max_ids_q = seq_lengths_q.view(-1, 1, 1)
max_ids_kv = seq_lengths_kv.view(-1, 1, 1)
diff_q_kv = max_ids_kv - max_ids_q
if contextual_seq_len > 0:
ids = ids - contextual_seq_len + 1
ids = torch.clamp(ids, min=0)
max_ids_q = max_ids_q - contextual_seq_len + 1
max_ids_kv = max_ids_kv - contextual_seq_len + 1
if num_targets is not None:
max_ids_q = max_ids_q - num_targets.view(-1, 1, 1)
max_ids_kv = max_ids_kv - num_targets.view(-1, 1, 1)
"""
Generate attention mask for HSTU attention.
raw_row_ids = torch.clamp(
ids,
max=max_ids_q,
This implementation matches the Hammer/PyTorch reference (_get_valid_attn_mask_v2)
with targets_in_kv=False, meaning:
- Q sequence: [UIH_Q][Targets] with length seq_lengths_q
- KV sequence: [UIH_KV] only with length seq_lengths_kv (NO targets in KV)
For causal attention with num_targets:
- UIH rows (col_ids < uih_lengths_q): apply causal mask (shifted_col_ids >= row_ids)
- Target rows (col_ids >= uih_lengths_q): can attend to everything in KV (full attention)
"""
# Create position indices - matching Hammer convention:
# col_ids indexes Q dimension (rows in attention matrix when viewed as Q x KV)
# row_ids indexes KV dimension (columns in attention matrix)
col_ids = torch.arange(0, max_seqlen_q, device=device).view(1, max_seqlen_q, 1)
row_ids = torch.arange(0, max_seqlen_kv, device=device).view(1, 1, max_seqlen_kv)
# Boundary mask: positions within valid sequence bounds
in_boundary_valid_attn_mask = torch.logical_and(
row_ids < seq_lengths_kv.view(-1, 1, 1), col_ids < seq_lengths_q.view(-1, 1, 1)
)
raw_row_ids = raw_row_ids + diff_q_kv
max_ids_q = max_ids_q + diff_q_kv
raw_col_ids = torch.clamp(
ids,
max=max_ids_kv,
)
row_ids = raw_row_ids.view(-1, N, 1).expand(-1, N, N)
col_ids = raw_col_ids.view(-1, 1, N).expand(-1, N, N)
row_col_dist = row_ids - col_ids
## ensure mask value in diagonal is always 1
##valid_attn_mask = torch.eye(N, device=device, dtype=torch.bool).view(1, N, N)
valid_attn_mask = torch.zeros_like(row_col_dist).to(torch.bool)
for idx0 in range(valid_attn_mask.size(0)):
for idx1 in torch.arange(max_seqlen_q):
valid_attn_mask[idx0, idx1, idx1 + diff_q_kv[idx0]] = 1
if not causal:
row_col_dist = torch.where(row_col_dist > 0, row_col_dist, -row_col_dist)
## 1) for token pair in [seqlen-num_target, N) x [seqlen-num_target, N), row_col_dist is 0
## 2) for token pair in [seqlen-num-target, N) x [0, seqlen-num_target), row_col_dist > 0
## 3) for token_pair in [0, seqlen-num_target) x [seqlen-num_target, N). row_col_dist < 0 if causal, else row_col_dist > 0
valid_attn_mask = torch.logical_or(valid_attn_mask, row_col_dist > 0)
if max_attn_len > 0:
if min_full_attn_seq_len > 0:
valid_attn_mask = torch.logical_and(
valid_attn_mask,
torch.logical_or(
row_col_dist <= max_attn_len,
row_ids >= max_ids_q - min_full_attn_seq_len,
),
if causal:
if num_targets is None:
# Causal without num_targets: simple shifted causal mask
delta_col_ids = seq_lengths_kv - seq_lengths_q
shifted_col_ids = col_ids + delta_col_ids.view(-1, 1, 1)
causal_mask = shifted_col_ids >= row_ids
return torch.logical_and(in_boundary_valid_attn_mask, causal_mask).to(
torch.int8
)
else:
valid_attn_mask = torch.logical_and(
valid_attn_mask, row_col_dist <= max_attn_len
# Causal with num_targets and targets_in_kv=False
# This exactly mirrors the Hammer logic with targets_in_kv=False
uih_lengths_q = seq_lengths_q - num_targets
delta_col_ids = seq_lengths_kv - uih_lengths_q
# targets_in_kv=False: NO subtraction of num_targets from delta_col_ids
shifted_col_ids = col_ids + delta_col_ids.view(-1, 1, 1)
# UIH rows: apply causal mask
causal_mask = torch.logical_and(
col_ids < uih_lengths_q.view(-1, 1, 1), shifted_col_ids >= row_ids
)
if contextual_seq_len > 0:
## ensure first contextual_seqlen rows (where row_ids==0) attend to all cols less than max_ids
valid_attn_mask = torch.logical_or(
valid_attn_mask, torch.logical_and(row_ids == diff_q_kv, col_ids < max_ids_kv)
)
fit_valid_attn_mask = valid_attn_mask[:, :max_seqlen_q, :]
# Target rows: full attention to KV (no additional constraint for targets_in_kv=False)
target_mask = col_ids >= uih_lengths_q.view(-1, 1, 1)
return torch.logical_and(
in_boundary_valid_attn_mask, torch.logical_or(causal_mask, target_mask)
).to(torch.int8)
else:
# Non-causal: everything in bounds is allowed
return in_boundary_valid_attn_mask.to(torch.int8)
return fit_valid_attn_mask.to(torch.int8)
def main():
max_seqlen_q=64
max_seqlen_kv=80
contextual_seq_len=3
max_attn_len=0
causal=True
min_full_attn_seq_len=0
dev_type=torch.device("cpu")
seq_lengths_q=torch.tensor((56,60,64), device=dev_type, dtype=torch.int32)
seq_lengths_kv=torch.tensor((70,76,80), device=dev_type, dtype=torch.int32)
num_targets=torch.tensor((4,5,6), device=dev_type, dtype=torch.int32)
max_seqlen_q = 64
max_seqlen_kv = 80
causal = True
dev_type = torch.device("cpu")
seq_lengths_q = torch.tensor((56, 60, 64), device=dev_type, dtype=torch.int32)
seq_lengths_kv = torch.tensor((70, 76, 80), device=dev_type, dtype=torch.int32)
num_targets = torch.tensor((4, 5, 6), device=dev_type, dtype=torch.int32)
valid_attn_mask=get_valid_attn_mask_v2(dev_type, causal, max_seqlen_q, max_seqlen_kv, seq_lengths_q, seq_lengths_kv, num_targets, max_attn_len, contextual_seq_len, min_full_attn_seq_len)
valid_attn_mask = get_valid_attn_mask_v2(
dev_type,
causal,
max_seqlen_q,
max_seqlen_kv,
seq_lengths_q,
seq_lengths_kv,
num_targets,
)
torch.save(valid_attn_mask, "torch_hstu_mask_0.pt")
max_attn_len=4
min_full_attn_seq_len=6
valid_attn_mask=get_valid_attn_mask_v2(dev_type, causal, max_seqlen_q, max_seqlen_kv, seq_lengths_q, seq_lengths_kv, num_targets, max_attn_len, contextual_seq_len, min_full_attn_seq_len)
torch.save(valid_attn_mask, "torch_hstu_mask_1.pt")
if __name__ == "__main__":
main()