Update test_ck_hstu_mask.sh and test_pytorch_hstu_mask.py to align their tests

This commit is contained in:
Qianfeng Zhang
2025-06-22 15:20:47 +00:00
parent 463a19859a
commit c87a217475
2 changed files with 10 additions and 4 deletions

View File

@@ -3,4 +3,8 @@
BUILD=build
EXE=$BUILD/bin/tile_example_hstu_attention
$EXE -v=1 -prec=fp16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=56,60,64 -causal=1 -local_len=4 -context_len=3 -minfull_len=0 -targets=4,5,6 -save_mask=1
$EXE -v=1 -prec=fp16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=49,52,55 -causal=1 -local_len=4 -context_len=3 -minfull_len=0 -targets=4,5,6 -save_mask=1
mv ck_hstu_mask.dat ck_hstu_mask_0.dat
$EXE -v=1 -prec=fp16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=49,52,55 -causal=1 -local_len=4 -context_len=3 -minfull_len=6 -targets=4,5,6 -save_mask=1
mv ck_hstu_mask.dat ck_hstu_mask_1.dat

View File

@@ -73,9 +73,11 @@ def main():
num_targets=torch.tensor((4,5,6), device=dev_type, dtype=torch.int32)
valid_attn_mask=get_valid_attn_mask(dev_type, causal, N, seq_lengths, num_targets, max_attn_len, contextual_seq_len, min_full_attn_seq_len)
##torch.set_printoptions(profile="full", linewidth=1024)
##print(valid_attn_mask)
torch.save(valid_attn_mask, "torch_hstu_mask.pt")
torch.save(valid_attn_mask, "torch_hstu_mask_0.pt")
min_full_attn_seq_len=6
valid_attn_mask=get_valid_attn_mask(dev_type, causal, N, seq_lengths, num_targets, max_attn_len, contextual_seq_len, min_full_attn_seq_len)
torch.save(valid_attn_mask, "torch_hstu_mask_1.pt")
if __name__ == "__main__":
main()