diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index def95129ab..638bc468b2 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -99,6 +99,7 @@ auto create_args(int argc, char* argv[]) .insert("seqlens", "400", "uih seqlen of single or all batches for query tensor, actually allocated seqlen will include the target of each batch and context_len") .insert("seqlens_kv", "", "uih seqlen of single or all batches for key/value tensor, actually allocated seqlen will include the target of each batch and context_len") .insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens") + .insert("max_seqlen_kv", "0", "max uih_seqlen_kv, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens") .insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention") .insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets") .insert("softmax", "0", "use softmax or not") @@ -253,11 +254,14 @@ bool run(const ck_tile::ArgParser& arg_parser) std::string str_of_lengths_kv = arg_parser.get_str("seqlens_kv"); std::vector seq_lengths_kv = get_integers_from_string(str_of_lengths_kv); - int input_max_uih_seqlen = arg_parser.get_int("max_seqlen"); - int input_max_target = arg_parser.get_int("max_target"); + int input_max_uih_seqlen_q = arg_parser.get_int("max_seqlen"); + int input_max_uih_seqlen_kv = arg_parser.get_int("max_seqlen_kv"); + int input_max_target = arg_parser.get_int("max_target"); - int max_uih_seqlen = 0; - int max_target = 0; + int max_uih_seqlen_q = 0; + int max_uih_seqlen_kv = 0; + + int max_target = 0; if(!num_targets.empty()) { @@ -275,6 +279,10 @@ bool run(const ck_tile::ArgParser& arg_parser) if(seq_lengths_kv.empty()) seq_lengths_kv = seq_lengths_q; + // assume input_max_uih_seqlen_kv is same as input_max_uih_seqlen_q if not strictly defined + if(input_max_uih_seqlen_kv <= 0) + input_max_uih_seqlen_kv = input_max_uih_seqlen_q; + if(is_jagged) { // supplement seq_lengths_q using the last input value if user-provided lengths not enough @@ -286,19 +294,24 @@ bool run(const ck_tile::ArgParser& arg_parser) // only consider num_batch values even if more values are provided by the user for(int i = 0; i < num_batch; i++) { - max_uih_seqlen = max(max_uih_seqlen, max(seq_lengths_q[i], seq_lengths_kv[i])); + max_uih_seqlen_q = max(max_uih_seqlen_q, seq_lengths_q[i]); + max_uih_seqlen_kv = max(max_uih_seqlen_kv, seq_lengths_kv[i]); }; } else { HSTU_CHECK(1 == seq_lengths_q.size() && 1 == seq_lengths_kv.size(), "sequence lengths for batched mode shoud have single element!"); - max_uih_seqlen = max(seq_lengths_q[0], seq_lengths_kv[0]); + max_uih_seqlen_q = seq_lengths_q[0]; + max_uih_seqlen_kv = seq_lengths_kv[0]; }; // the user input of max_uih_seqlen can either be ignored or be bigger than all uih_seqlens // the user input of max_target can either be ignored or be bigger than all targets - HSTU_CHECK(input_max_uih_seqlen <= 0 || input_max_uih_seqlen >= max_uih_seqlen, + HSTU_CHECK(input_max_uih_seqlen_q <= 0 || input_max_uih_seqlen_q >= max_uih_seqlen_q, + "the user input of max_uih_seqlen can either be ignored or be bigger than all " + "uih_seqlens!"); + HSTU_CHECK(input_max_uih_seqlen_kv <= 0 || input_max_uih_seqlen_kv >= max_uih_seqlen_kv, "the user input of max_uih_seqlen can either be ignored or be bigger than all " "uih_seqlens!"); HSTU_CHECK(input_max_target <= 0 || input_max_target >= max_target, @@ -306,12 +319,14 @@ bool run(const ck_tile::ArgParser& arg_parser) HSTU_CHECK(contextual_seqlen >= 0, "contextual_seqlen should be non-negative!"); - max_uih_seqlen = (input_max_uih_seqlen > 0) ? input_max_uih_seqlen : max_uih_seqlen; - max_target = (input_max_target > 0) ? input_max_target : max_target; + max_uih_seqlen_q = (input_max_uih_seqlen_q > 0) ? input_max_uih_seqlen_q : max_uih_seqlen_q; + max_uih_seqlen_kv = (input_max_uih_seqlen_kv > 0) ? input_max_uih_seqlen_kv : max_uih_seqlen_kv; + max_target = (input_max_target > 0) ? input_max_target : max_target; int phy_seqlen_q = 0; int phy_seqlen_kv = 0; - int max_seqlen = max_uih_seqlen + max_target + contextual_seqlen; + int max_seqlen_q = max_uih_seqlen_q + max_target + contextual_seqlen; + int max_seqlen_kv = max_uih_seqlen_kv + max_target + contextual_seqlen; std::vector seq_offsets_q; std::vector seq_offsets_kv; @@ -344,8 +359,8 @@ bool run(const ck_tile::ArgParser& arg_parser) } else { - phy_seqlen_q = max_seqlen; - phy_seqlen_kv = max_seqlen; + phy_seqlen_q = max_seqlen_q; + phy_seqlen_kv = max_seqlen_kv; }; long total_flops = 0; @@ -384,8 +399,9 @@ bool run(const ck_tile::ArgParser& arg_parser) std::array{batches_for_alloc, phy_seqlen_q, num_head, hdim_v}); ck_tile::HostTensor mask_host( - save_mask ? std::array{num_batch, num_head, max_seqlen, max_seqlen} - : std::array{1, 1, 1, 1}); + save_mask + ? std::array{num_batch, num_head, max_seqlen_q, max_seqlen_kv} + : std::array{1, 1, 1, 1}); if(!initialize_qkv) { @@ -440,7 +456,7 @@ bool run(const ck_tile::ArgParser& arg_parser) params.num_batch = num_batch; params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer(); params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer(); - params.max_seqlen = max_seqlen; + params.max_seqlen = max(max_seqlen_q, max_seqlen_kv); params.q_ptr = q_dev.GetDeviceBuffer(); params.k_ptr = k_dev.GetDeviceBuffer(); params.v_ptr = v_dev.GetDeviceBuffer(); @@ -558,7 +574,8 @@ bool run(const ck_tile::ArgParser& arg_parser) num_batch, scale_s, attn_scale, - max_seqlen, + max_seqlen_q, + max_seqlen_kv, seq_offsets_q, seq_offsets_kv, num_targets, diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index f8e53ebd05..84122577a5 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -29,6 +29,7 @@ struct HstuBlockMaskWithLocal int max_k_uih_len; int max_row_id; int max_col_id; + int diff_q_kv_len; CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(bool is_tile_in_first_split_, int seqlen_q_, @@ -63,6 +64,9 @@ struct HstuBlockMaskWithLocal max_row_id = max_q_uih_len; max_col_id = max_k_uih_len; } + + diff_q_kv_len = max_k_uih_len - max_q_uih_len; + max_row_id += diff_q_kv_len; }; // to get the loop length along X axis, return index:[start, end), end-start=length @@ -77,7 +81,7 @@ struct HstuBlockMaskWithLocal { if constexpr(kUseCausal) { - index_t x_end = min(i_y + YTile, seqlen_k); + index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k); return ck_tile::make_tuple(0, x_end); } else @@ -90,7 +94,7 @@ struct HstuBlockMaskWithLocal } else // tile is completely inside [max_q_uih_len, seqlen_q) { - index_t x_end = min(i_y + YTile, seqlen_k); + index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k); return ck_tile::make_tuple(0, x_end); }; }; @@ -105,7 +109,7 @@ struct HstuBlockMaskWithLocal // some row of the tile in [contextual_seqlen+max_attn_len, max_q_uih_len) if(i_y < max_q_uih_len) { - index_t x_start = i_y - max_attn_len; + index_t x_start = i_y + diff_q_kv_len - max_attn_len; index_t x_start_aligned = x_start - x_start % XTile; // some rows of the tile in [max_q_uih_len - max_attn_len, max_q_uih_len) @@ -116,14 +120,14 @@ struct HstuBlockMaskWithLocal else // whole tile in [contextual_seqlen+max_attn_len, max_q_uih_len // -max_attn_len) { - index_t x_end = i_y + YTile + max_attn_len; + index_t x_end = i_y + YTile + diff_q_kv_len + max_attn_len; return ck_tile::make_tuple(x_start_aligned, x_end); }; } else // whole tile in [max_uih_len, seqlen) { index_t x_start = max_k_uih_len - max_attn_len; - index_t x_end = min(i_y + YTile, seqlen_k); + index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k); return ck_tile::make_tuple(x_start - x_start % XTile, x_end); } @@ -132,12 +136,13 @@ struct HstuBlockMaskWithLocal { if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen) { - index_t x_end = min(max(i_y + YTile + max_attn_len, max_k_uih_len), seqlen_k); + index_t x_end = min( + max(i_y + YTile + diff_q_kv_len + max_attn_len, max_k_uih_len), seqlen_k); return ck_tile::make_tuple(0, x_end); } else // whole tile in [contextual_seqlen, seqlen) { - index_t x_end = min(i_y + YTile + max_attn_len, seqlen_k); + index_t x_end = min(i_y + YTile + diff_q_kv_len + max_attn_len, seqlen_k); return ck_tile::make_tuple(0, x_end); } } @@ -146,15 +151,15 @@ struct HstuBlockMaskWithLocal { if(i_y >= min(contextual_seqlen, 1) + max_attn_len) { - index_t x_end = min(i_y + YTile, seqlen_k); + index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k); // some row of the tile in [contextual_seqlen+max_attn_len, max_q_uih_len) if(i_y < max_q_uih_len) { - index_t x_start = i_y - max_attn_len; + index_t x_start = i_y + diff_q_kv_len - max_attn_len; return ck_tile::make_tuple(x_start - x_start % XTile, x_end); } - else // whole tile in [max_uih_len, seqlen) + else // whole tile in [max_q_uih_len, seqlen_q) { index_t x_start = max_k_uih_len - max_attn_len; return ck_tile::make_tuple(x_start - x_start % XTile, x_end); @@ -164,12 +169,12 @@ struct HstuBlockMaskWithLocal { if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen) { - index_t x_end = min(max(i_y + YTile, max_k_uih_len), seqlen_k); + index_t x_end = min(max(i_y + YTile + diff_q_kv_len, max_k_uih_len), seqlen_k); return ck_tile::make_tuple(0, x_end); } else // whole tile in [contextual_seqlen, seqlen) { - index_t x_end = min(i_y + YTile, seqlen_k); + index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k); return ck_tile::make_tuple(0, x_end); } } @@ -181,17 +186,19 @@ struct HstuBlockMaskWithLocal int row_id; int col_id; + row += diff_q_kv_len; + if(contextual_seqlen > 0) { // row_id/col_id is clamped from physical row/col according to contextual_seqlen and // max_uih_len - row_id = max(row - contextual_seqlen + 1, 0); + row_id = max(row - contextual_seqlen + 1, diff_q_kv_len); col_id = max(col - contextual_seqlen + 1, 0); row_id = min(row_id, max_row_id); col_id = min(col_id, max_col_id); - if(row_id == 0 && col_id < max_col_id) + if(row_id == diff_q_kv_len && col_id < max_col_id) return true; } else @@ -227,17 +234,19 @@ struct HstuBlockMaskWithLocal int row_id; int col_id; + row += diff_q_kv_len; + if(contextual_seqlen > 0) { // row_id/col_id is clamped from physical row/col according to contextual_seqlen and // max_uih_len - row_id = max(row - contextual_seqlen + 1, 0); + row_id = max(row - contextual_seqlen + 1, diff_q_kv_len); col_id = max(col - contextual_seqlen + 1, 0); row_id = min(row_id, max_row_id); col_id = min(col_id, max_col_id); - if(row_id == 0 && col_id < max_col_id) + if(row_id == diff_q_kv_len && col_id < max_col_id) return true; } else @@ -281,7 +290,8 @@ struct HstuBlockMaskWithLocal { index_t i_tile_right = i_tile_left + TileWidth; - if(!is_tile_in_first_split && i_tile_right <= min(i_tile_top + 1, max_k_uih_len)) + if(!is_tile_in_first_split && + i_tile_right <= min(i_tile_top + diff_q_kv_len + 1, max_k_uih_len)) return true; } else @@ -315,6 +325,7 @@ struct HstuBlockMaskNoLocal int max_k_uih_len; int max_row_id; int max_col_id; + int diff_q_kv_len; CK_TILE_HOST_DEVICE HstuBlockMaskNoLocal(int seqlen_q_, int seqlen_k_, int contextual_seqlen_, int num_target_) @@ -333,6 +344,9 @@ struct HstuBlockMaskNoLocal max_row_id = max_q_uih_len; max_col_id = max_k_uih_len; } + + diff_q_kv_len = max_k_uih_len - max_q_uih_len; + max_row_id += diff_q_kv_len; }; // to get the loop length along X axis, return index:[start, end), end-start=length @@ -348,11 +362,11 @@ struct HstuBlockMaskNoLocal } else { - index_t x_end = min(i_y + YTile, seqlen_k); + index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k); if(i_y < contextual_seqlen) { - if(i_y + YTile > max_k_uih_len) + if(i_y + YTile + diff_q_kv_len > max_k_uih_len) { return ck_tile::make_tuple(0, x_end); } @@ -373,17 +387,19 @@ struct HstuBlockMaskNoLocal int row_id; int col_id; + row += diff_q_kv_len; + if(contextual_seqlen > 0) { // row_id/col_id is clamped from physical row/col according to contextual_seqlen and // max_uih_len - row_id = max(row - contextual_seqlen + 1, 0); + row_id = max(row - contextual_seqlen + 1, diff_q_kv_len); col_id = max(col - contextual_seqlen + 1, 0); row_id = min(row_id, max_row_id); col_id = min(col_id, max_col_id); - if(row_id == 0 && col_id < max_col_id) + if(row_id == diff_q_kv_len && col_id < max_col_id) return true; } else @@ -420,7 +436,7 @@ struct HstuBlockMaskNoLocal // assume num_target > 0 with high probability, don't check whether num_target is 0; // so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile - if(i_tile_bottom >= max_q_uih_len || i_tile_right > i_tile_top) + if(i_tile_bottom >= max_q_uih_len || i_tile_right > i_tile_top + diff_q_kv_len) return false; return true; diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index 68ec7e514e..ee844f2bcb 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -41,7 +41,8 @@ struct reference_hstu_attention int num_batch, float alpha, float attn_scale, - int max_seqlen, + int max_seqlen_q, + int max_seqlen_kv, std::vector seq_q_offsets, std::vector seq_kv_offsets, std::vector num_targets, // define masking length at the end of token @@ -93,8 +94,8 @@ struct reference_hstu_attention if(static_cast(mask_batch_nhead_seq_seq.get_lengths()[0]) == num_batch && static_cast(mask_batch_nhead_seq_seq.get_lengths()[1]) == num_head && - static_cast(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_seqlen && - static_cast(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_seqlen) + static_cast(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_seqlen_q && + static_cast(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_seqlen_kv) save_mask = true; // check num_tagets @@ -114,7 +115,9 @@ struct reference_hstu_attention int num_target = num_targets.empty() ? 0 : num_targets[i_batch]; - float scale_p = attn_scale ? attn_scale : 1.0f / static_cast(max_seqlen); + float scale_p = attn_scale + ? attn_scale + : 1.0f / static_cast(max(max_seqlen_q, max_seqlen_kv)); BOOL_SWITCH(window_size > 0, kHasLocal, [&] { using HstuMaskType = typename HstuBlockMasking::Type; @@ -148,9 +151,8 @@ struct reference_hstu_attention if(save_mask) { - // initialize the mask - for(int sq = 0; sq < max_seqlen; sq++) - for(int sk = 0; sk < max_seqlen; sk++) + for(int sq = 0; sq < max_seqlen_q; sq++) + for(int sk = 0; sk < max_seqlen_kv; sk++) mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) = static_cast(mask.IsTokenPairInsideMask(sq, sk)); } diff --git a/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_seqlen_kv.sh b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_seqlen_kv.sh index f6c06401f3..f7b5abbe45 100644 --- a/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_seqlen_kv.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_seqlen_kv.sh @@ -49,3 +49,6 @@ for T in "fp16" "bf16"; do $EXE -v=1 -prec=$T -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -seqlens_kv=300 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=0 -norm_dist=0 set +x done + +## This case is used to verify the masking when seqlen_kv > seqlen_q by comparing the saved mask tensor with the output of test_pytorch_hstu_mask_v2.py +$EXE -v=1 -prec=bf16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=49,52,55 -seqlens_kv=63,68,71 -causal=1 -local_len=0 -context_len=3 -minfull_len=0 -targets=4,5,6 -attn_scale=0 -norm_dist=0 -save_mask=1 diff --git a/example/ck_tile/18_hstu_attention/test_pytorch_hstu_mask_v2.py b/example/ck_tile/18_hstu_attention/test_pytorch_hstu_mask_v2.py new file mode 100644 index 0000000000..21bc843f42 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/test_pytorch_hstu_mask_v2.py @@ -0,0 +1,106 @@ +import math +import torch +from typing import Optional + +def get_valid_attn_mask_v2( + device: torch.device, + causal: bool, + max_seqlen_q: int, + max_seqlen_kv: int, + seq_lengths_q: torch.Tensor, + seq_lengths_kv: torch.Tensor, + num_targets: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, + min_full_attn_seq_len: int = 0, +) -> torch.Tensor: + N = max(max_seqlen_q, max_seqlen_kv) + ids = torch.arange(0, N, device=device).view(1, N) + max_ids_q = seq_lengths_q.view(-1, 1, 1) + max_ids_kv = seq_lengths_kv.view(-1, 1, 1) + diff_q_kv = max_ids_kv - max_ids_q + if contextual_seq_len > 0: + ids = ids - contextual_seq_len + 1 + ids = torch.clamp(ids, min=0) + max_ids_q = max_ids_q - contextual_seq_len + 1 + max_ids_kv = max_ids_kv - contextual_seq_len + 1 + if num_targets is not None: + max_ids_q = max_ids_q - num_targets.view(-1, 1, 1) + max_ids_kv = max_ids_kv - num_targets.view(-1, 1, 1) + + raw_row_ids = torch.clamp( + ids, + max=max_ids_q, + ) + raw_row_ids = raw_row_ids + diff_q_kv + max_ids_q = max_ids_q + diff_q_kv + raw_col_ids = torch.clamp( + ids, + max=max_ids_kv, + ) + row_ids = raw_row_ids.view(-1, N, 1).expand(-1, N, N) + col_ids = raw_col_ids.view(-1, 1, N).expand(-1, N, N) + + row_col_dist = row_ids - col_ids + + ## ensure mask value in diagonal is always 1 + ##valid_attn_mask = torch.eye(N, device=device, dtype=torch.bool).view(1, N, N) + valid_attn_mask = torch.zeros_like(row_col_dist).to(torch.bool) + for idx0 in range(valid_attn_mask.size(0)): + for idx1 in torch.arange(max_seqlen_q): + valid_attn_mask[idx0, idx1, idx1 + diff_q_kv[idx0]] = 1 + + if not causal: + row_col_dist = torch.where(row_col_dist > 0, row_col_dist, -row_col_dist) + + ## 1) for token pair in [seqlen-num_target, N) x [seqlen-num_target, N), row_col_dist is 0 + ## 2) for token pair in [seqlen-num-target, N) x [0, seqlen-num_target), row_col_dist > 0 + ## 3) for token_pair in [0, seqlen-num_target) x [seqlen-num_target, N). row_col_dist < 0 if causal, else row_col_dist > 0 + valid_attn_mask = torch.logical_or(valid_attn_mask, row_col_dist > 0) + if max_attn_len > 0: + if min_full_attn_seq_len > 0: + valid_attn_mask = torch.logical_and( + valid_attn_mask, + torch.logical_or( + row_col_dist <= max_attn_len, + row_ids >= max_ids_q - min_full_attn_seq_len, + ), + ) + else: + valid_attn_mask = torch.logical_and( + valid_attn_mask, row_col_dist <= max_attn_len + ) + if contextual_seq_len > 0: + ## ensure first contextual_seqlen rows (where row_ids==0) attend to all cols less than max_ids + valid_attn_mask = torch.logical_or( + valid_attn_mask, torch.logical_and(row_ids == diff_q_kv, col_ids < max_ids_kv) + ) + + fit_valid_attn_mask = valid_attn_mask[:, :max_seqlen_q, :] + + return fit_valid_attn_mask.to(torch.int8) + +def main(): + max_seqlen_q=64 + max_seqlen_kv=80 + contextual_seq_len=3 + max_attn_len=0 + causal=True + min_full_attn_seq_len=0 + dev_type=torch.device("cpu") + seq_lengths_q=torch.tensor((56,60,64), device=dev_type, dtype=torch.int32) + seq_lengths_kv=torch.tensor((70,76,80), device=dev_type, dtype=torch.int32) + num_targets=torch.tensor((4,5,6), device=dev_type, dtype=torch.int32) + + valid_attn_mask=get_valid_attn_mask_v2(dev_type, causal, max_seqlen_q, max_seqlen_kv, seq_lengths_q, seq_lengths_kv, num_targets, max_attn_len, contextual_seq_len, min_full_attn_seq_len) + torch.save(valid_attn_mask, "torch_hstu_mask_0.pt") + + max_attn_len=4 + min_full_attn_seq_len=6 + valid_attn_mask=get_valid_attn_mask_v2(dev_type, causal, max_seqlen_q, max_seqlen_kv, seq_lengths_q, seq_lengths_kv, num_targets, max_attn_len, contextual_seq_len, min_full_attn_seq_len) + torch.save(valid_attn_mask, "torch_hstu_mask_1.pt") + +if __name__ == "__main__": + main() + +