diff --git a/example/ck_tile/18_hstu_attention/scripts/test_ck_hstu_mask.sh b/example/ck_tile/18_hstu_attention/scripts/test_ck_hstu_mask.sh index 6ed1cac3a5..5eef292343 100644 --- a/example/ck_tile/18_hstu_attention/scripts/test_ck_hstu_mask.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_ck_hstu_mask.sh @@ -3,4 +3,8 @@ BUILD=build EXE=$BUILD/bin/tile_example_hstu_attention -$EXE -v=1 -prec=fp16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=56,60,64 -causal=1 -local_len=4 -context_len=3 -minfull_len=0 -targets=4,5,6 -save_mask=1 +$EXE -v=1 -prec=fp16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=49,52,55 -causal=1 -local_len=4 -context_len=3 -minfull_len=0 -targets=4,5,6 -save_mask=1 +mv ck_hstu_mask.dat ck_hstu_mask_0.dat + +$EXE -v=1 -prec=fp16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=49,52,55 -causal=1 -local_len=4 -context_len=3 -minfull_len=6 -targets=4,5,6 -save_mask=1 +mv ck_hstu_mask.dat ck_hstu_mask_1.dat diff --git a/example/ck_tile/18_hstu_attention/test_pytorch_hstu_mask.py b/example/ck_tile/18_hstu_attention/test_pytorch_hstu_mask.py index 04530a6257..35e5f427f2 100644 --- a/example/ck_tile/18_hstu_attention/test_pytorch_hstu_mask.py +++ b/example/ck_tile/18_hstu_attention/test_pytorch_hstu_mask.py @@ -73,9 +73,11 @@ def main(): num_targets=torch.tensor((4,5,6), device=dev_type, dtype=torch.int32) valid_attn_mask=get_valid_attn_mask(dev_type, causal, N, seq_lengths, num_targets, max_attn_len, contextual_seq_len, min_full_attn_seq_len) - ##torch.set_printoptions(profile="full", linewidth=1024) - ##print(valid_attn_mask) - torch.save(valid_attn_mask, "torch_hstu_mask.pt") + torch.save(valid_attn_mask, "torch_hstu_mask_0.pt") + + min_full_attn_seq_len=6 + valid_attn_mask=get_valid_attn_mask(dev_type, causal, N, seq_lengths, num_targets, max_attn_len, contextual_seq_len, min_full_attn_seq_len) + torch.save(valid_attn_mask, "torch_hstu_mask_1.pt") if __name__ == "__main__": main()