diff --git a/example/ck_tile/18_hstu_attention/test_pytorch_hstu_mask.py b/example/ck_tile/18_hstu_attention/test_pytorch_hstu_mask.py index 35e5f427f2..6ba1955c27 100644 --- a/example/ck_tile/18_hstu_attention/test_pytorch_hstu_mask.py +++ b/example/ck_tile/18_hstu_attention/test_pytorch_hstu_mask.py @@ -40,7 +40,7 @@ def get_valid_attn_mask( ## 1) for token pair in [seqlen-num_target, N) x [seqlen-num_target, N), row_col_dist is 0 ## 2) for token pair in [seqlen-num-target, N) x [0, seqlen-num_target), row_col_dist > 0 - ## 3) for token_pair in [0, seq-num_target) x [seqlen-num_target, N). row_col_dist < 0 if causal, else row_col_dist > 0 + ## 3) for token_pair in [0, seqlen-num_target) x [seqlen-num_target, N). row_col_dist < 0 if causal, else row_col_dist > 0 valid_attn_mask = torch.logical_or(valid_attn_mask, row_col_dist > 0) if max_attn_len > 0: if min_full_attn_seq_len > 0: