mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Update test_ck_hstu_mask.sh and test_pytorch_hstu_mask.py to align their test cases
This commit is contained in:
@@ -3,4 +3,8 @@
|
||||
BUILD=build
|
||||
EXE=$BUILD/bin/tile_example_hstu_attention
|
||||
|
||||
$EXE -v=1 -prec=fp16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=56,60,64 -causal=1 -local_len=4 -context_len=3 -minfull_len=0 -targets=4,5,6 -save_mask=1
|
||||
$EXE -v=1 -prec=fp16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=49,52,55 -causal=1 -local_len=4 -context_len=3 -minfull_len=0 -targets=4,5,6 -save_mask=1
|
||||
mv ck_hstu_mask.dat ck_hstu_mask_0.dat
|
||||
|
||||
$EXE -v=1 -prec=fp16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=49,52,55 -causal=1 -local_len=4 -context_len=3 -minfull_len=6 -targets=4,5,6 -save_mask=1
|
||||
mv ck_hstu_mask.dat ck_hstu_mask_1.dat
|
||||
|
||||
@@ -73,9 +73,11 @@ def main():
|
||||
num_targets=torch.tensor((4,5,6), device=dev_type, dtype=torch.int32)
|
||||
|
||||
valid_attn_mask=get_valid_attn_mask(dev_type, causal, N, seq_lengths, num_targets, max_attn_len, contextual_seq_len, min_full_attn_seq_len)
|
||||
##torch.set_printoptions(profile="full", linewidth=1024)
|
||||
##print(valid_attn_mask)
|
||||
torch.save(valid_attn_mask, "torch_hstu_mask.pt")
|
||||
torch.save(valid_attn_mask, "torch_hstu_mask_0.pt")
|
||||
|
||||
min_full_attn_seq_len=6
|
||||
valid_attn_mask=get_valid_attn_mask(dev_type, causal, N, seq_lengths, num_targets, max_attn_len, contextual_seq_len, min_full_attn_seq_len)
|
||||
torch.save(valid_attn_mask, "torch_hstu_mask_1.pt")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user