Fix comments in test_pytorch_hstu_mask.py scripts

This commit is contained in:
Qianfeng Zhang
2025-07-22 13:21:01 +00:00
parent 47c4a0c2ec
commit b57939ff64

View File

@@ -40,7 +40,7 @@ def get_valid_attn_mask(
## 1) for token pair in [seqlen-num_target, N) x [seqlen-num_target, N), row_col_dist is 0
## 2) for token pair in [seqlen-num-target, N) x [0, seqlen-num_target), row_col_dist > 0
## 3) for token_pair in [0, seq-num_target) x [seqlen-num_target, N). row_col_dist < 0 if causal, else row_col_dist > 0
## 3) for token_pair in [0, seqlen-num_target) x [seqlen-num_target, N). row_col_dist < 0 if causal, else row_col_dist > 0
valid_attn_mask = torch.logical_or(valid_attn_mask, row_col_dist > 0)
if max_attn_len > 0:
if min_full_attn_seq_len > 0: