mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Update to use BottomRight-Diagonal masking when seqlen_kv is bigger than seqlen_q
This commit is contained in:
@@ -99,6 +99,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("seqlens", "400", "uih seqlen of single or all batches for query tensor, actually allocated seqlen will include the target of each batch and context_len")
|
||||
.insert("seqlens_kv", "", "uih seqlen of single or all batches for key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
|
||||
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
|
||||
.insert("max_seqlen_kv", "0", "max uih_seqlen_kv, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
|
||||
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
|
||||
.insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
|
||||
.insert("softmax", "0", "use softmax or not")
|
||||
@@ -253,11 +254,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::string str_of_lengths_kv = arg_parser.get_str("seqlens_kv");
|
||||
std::vector<int> seq_lengths_kv = get_integers_from_string(str_of_lengths_kv);
|
||||
|
||||
int input_max_uih_seqlen = arg_parser.get_int("max_seqlen");
|
||||
int input_max_target = arg_parser.get_int("max_target");
|
||||
int input_max_uih_seqlen_q = arg_parser.get_int("max_seqlen");
|
||||
int input_max_uih_seqlen_kv = arg_parser.get_int("max_seqlen_kv");
|
||||
int input_max_target = arg_parser.get_int("max_target");
|
||||
|
||||
int max_uih_seqlen = 0;
|
||||
int max_target = 0;
|
||||
int max_uih_seqlen_q = 0;
|
||||
int max_uih_seqlen_kv = 0;
|
||||
|
||||
int max_target = 0;
|
||||
|
||||
if(!num_targets.empty())
|
||||
{
|
||||
@@ -275,6 +279,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(seq_lengths_kv.empty())
|
||||
seq_lengths_kv = seq_lengths_q;
|
||||
|
||||
// assume input_max_uih_seqlen_kv is same as input_max_uih_seqlen_q if not strictly defined
|
||||
if(input_max_uih_seqlen_kv <= 0)
|
||||
input_max_uih_seqlen_kv = input_max_uih_seqlen_q;
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
// supplement seq_lengths_q using the last input value if user-provided lengths not enough
|
||||
@@ -286,19 +294,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// only consider num_batch values even if more values are provided by the user
|
||||
for(int i = 0; i < num_batch; i++)
|
||||
{
|
||||
max_uih_seqlen = max(max_uih_seqlen, max(seq_lengths_q[i], seq_lengths_kv[i]));
|
||||
max_uih_seqlen_q = max(max_uih_seqlen_q, seq_lengths_q[i]);
|
||||
max_uih_seqlen_kv = max(max_uih_seqlen_kv, seq_lengths_kv[i]);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
HSTU_CHECK(1 == seq_lengths_q.size() && 1 == seq_lengths_kv.size(),
|
||||
"sequence lengths for batched mode shoud have single element!");
|
||||
max_uih_seqlen = max(seq_lengths_q[0], seq_lengths_kv[0]);
|
||||
max_uih_seqlen_q = seq_lengths_q[0];
|
||||
max_uih_seqlen_kv = seq_lengths_kv[0];
|
||||
};
|
||||
|
||||
// the user input of max_uih_seqlen can either be ignored or be bigger than all uih_seqlens
|
||||
// the user input of max_target can either be ignored or be bigger than all targets
|
||||
HSTU_CHECK(input_max_uih_seqlen <= 0 || input_max_uih_seqlen >= max_uih_seqlen,
|
||||
HSTU_CHECK(input_max_uih_seqlen_q <= 0 || input_max_uih_seqlen_q >= max_uih_seqlen_q,
|
||||
"the user input of max_uih_seqlen can either be ignored or be bigger than all "
|
||||
"uih_seqlens!");
|
||||
HSTU_CHECK(input_max_uih_seqlen_kv <= 0 || input_max_uih_seqlen_kv >= max_uih_seqlen_kv,
|
||||
"the user input of max_uih_seqlen can either be ignored or be bigger than all "
|
||||
"uih_seqlens!");
|
||||
HSTU_CHECK(input_max_target <= 0 || input_max_target >= max_target,
|
||||
@@ -306,12 +319,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
HSTU_CHECK(contextual_seqlen >= 0, "contextual_seqlen should be non-negative!");
|
||||
|
||||
max_uih_seqlen = (input_max_uih_seqlen > 0) ? input_max_uih_seqlen : max_uih_seqlen;
|
||||
max_target = (input_max_target > 0) ? input_max_target : max_target;
|
||||
max_uih_seqlen_q = (input_max_uih_seqlen_q > 0) ? input_max_uih_seqlen_q : max_uih_seqlen_q;
|
||||
max_uih_seqlen_kv = (input_max_uih_seqlen_kv > 0) ? input_max_uih_seqlen_kv : max_uih_seqlen_kv;
|
||||
max_target = (input_max_target > 0) ? input_max_target : max_target;
|
||||
|
||||
int phy_seqlen_q = 0;
|
||||
int phy_seqlen_kv = 0;
|
||||
int max_seqlen = max_uih_seqlen + max_target + contextual_seqlen;
|
||||
int max_seqlen_q = max_uih_seqlen_q + max_target + contextual_seqlen;
|
||||
int max_seqlen_kv = max_uih_seqlen_kv + max_target + contextual_seqlen;
|
||||
|
||||
std::vector<int> seq_offsets_q;
|
||||
std::vector<int> seq_offsets_kv;
|
||||
@@ -344,8 +359,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
else
|
||||
{
|
||||
phy_seqlen_q = max_seqlen;
|
||||
phy_seqlen_kv = max_seqlen;
|
||||
phy_seqlen_q = max_seqlen_q;
|
||||
phy_seqlen_kv = max_seqlen_kv;
|
||||
};
|
||||
|
||||
long total_flops = 0;
|
||||
@@ -384,8 +399,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
|
||||
|
||||
ck_tile::HostTensor<int8_t> mask_host(
|
||||
save_mask ? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_seqlen, max_seqlen}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
save_mask
|
||||
? std::array<ck_tile::index_t, 4>{num_batch, num_head, max_seqlen_q, max_seqlen_kv}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
|
||||
if(!initialize_qkv)
|
||||
{
|
||||
@@ -440,7 +456,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
params.num_batch = num_batch;
|
||||
params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer();
|
||||
params.seq_kv_offsets_ptr = seq_offsets_kv_dev.GetDeviceBuffer();
|
||||
params.max_seqlen = max_seqlen;
|
||||
params.max_seqlen = max(max_seqlen_q, max_seqlen_kv);
|
||||
params.q_ptr = q_dev.GetDeviceBuffer();
|
||||
params.k_ptr = k_dev.GetDeviceBuffer();
|
||||
params.v_ptr = v_dev.GetDeviceBuffer();
|
||||
@@ -558,7 +574,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_batch,
|
||||
scale_s,
|
||||
attn_scale,
|
||||
max_seqlen,
|
||||
max_seqlen_q,
|
||||
max_seqlen_kv,
|
||||
seq_offsets_q,
|
||||
seq_offsets_kv,
|
||||
num_targets,
|
||||
|
||||
@@ -29,6 +29,7 @@ struct HstuBlockMaskWithLocal
|
||||
int max_k_uih_len;
|
||||
int max_row_id;
|
||||
int max_col_id;
|
||||
int diff_q_kv_len;
|
||||
|
||||
CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(bool is_tile_in_first_split_,
|
||||
int seqlen_q_,
|
||||
@@ -63,6 +64,9 @@ struct HstuBlockMaskWithLocal
|
||||
max_row_id = max_q_uih_len;
|
||||
max_col_id = max_k_uih_len;
|
||||
}
|
||||
|
||||
diff_q_kv_len = max_k_uih_len - max_q_uih_len;
|
||||
max_row_id += diff_q_kv_len;
|
||||
};
|
||||
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
@@ -77,7 +81,7 @@ struct HstuBlockMaskWithLocal
|
||||
{
|
||||
if constexpr(kUseCausal)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen_k);
|
||||
index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else
|
||||
@@ -90,7 +94,7 @@ struct HstuBlockMaskWithLocal
|
||||
}
|
||||
else // tile is completely inside [max_q_uih_len, seqlen_q)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen_k);
|
||||
index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
};
|
||||
};
|
||||
@@ -105,7 +109,7 @@ struct HstuBlockMaskWithLocal
|
||||
// some row of the tile in [contextual_seqlen+max_attn_len, max_q_uih_len)
|
||||
if(i_y < max_q_uih_len)
|
||||
{
|
||||
index_t x_start = i_y - max_attn_len;
|
||||
index_t x_start = i_y + diff_q_kv_len - max_attn_len;
|
||||
index_t x_start_aligned = x_start - x_start % XTile;
|
||||
|
||||
// some rows of the tile in [max_q_uih_len - max_attn_len, max_q_uih_len)
|
||||
@@ -116,14 +120,14 @@ struct HstuBlockMaskWithLocal
|
||||
else // whole tile in [contextual_seqlen+max_attn_len, max_q_uih_len
|
||||
// -max_attn_len)
|
||||
{
|
||||
index_t x_end = i_y + YTile + max_attn_len;
|
||||
index_t x_end = i_y + YTile + diff_q_kv_len + max_attn_len;
|
||||
return ck_tile::make_tuple(x_start_aligned, x_end);
|
||||
};
|
||||
}
|
||||
else // whole tile in [max_uih_len, seqlen)
|
||||
{
|
||||
index_t x_start = max_k_uih_len - max_attn_len;
|
||||
index_t x_end = min(i_y + YTile, seqlen_k);
|
||||
index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k);
|
||||
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
}
|
||||
@@ -132,12 +136,13 @@ struct HstuBlockMaskWithLocal
|
||||
{
|
||||
if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen)
|
||||
{
|
||||
index_t x_end = min(max(i_y + YTile + max_attn_len, max_k_uih_len), seqlen_k);
|
||||
index_t x_end = min(
|
||||
max(i_y + YTile + diff_q_kv_len + max_attn_len, max_k_uih_len), seqlen_k);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else // whole tile in [contextual_seqlen, seqlen)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile + max_attn_len, seqlen_k);
|
||||
index_t x_end = min(i_y + YTile + diff_q_kv_len + max_attn_len, seqlen_k);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
}
|
||||
@@ -146,15 +151,15 @@ struct HstuBlockMaskWithLocal
|
||||
{
|
||||
if(i_y >= min(contextual_seqlen, 1) + max_attn_len)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen_k);
|
||||
index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k);
|
||||
|
||||
// some row of the tile in [contextual_seqlen+max_attn_len, max_q_uih_len)
|
||||
if(i_y < max_q_uih_len)
|
||||
{
|
||||
index_t x_start = i_y - max_attn_len;
|
||||
index_t x_start = i_y + diff_q_kv_len - max_attn_len;
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
}
|
||||
else // whole tile in [max_uih_len, seqlen)
|
||||
else // whole tile in [max_q_uih_len, seqlen_q)
|
||||
{
|
||||
index_t x_start = max_k_uih_len - max_attn_len;
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
@@ -164,12 +169,12 @@ struct HstuBlockMaskWithLocal
|
||||
{
|
||||
if(i_y < contextual_seqlen) // some row of the tile in [0, contextual_seqlen)
|
||||
{
|
||||
index_t x_end = min(max(i_y + YTile, max_k_uih_len), seqlen_k);
|
||||
index_t x_end = min(max(i_y + YTile + diff_q_kv_len, max_k_uih_len), seqlen_k);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else // whole tile in [contextual_seqlen, seqlen)
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen_k);
|
||||
index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k);
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
}
|
||||
@@ -181,17 +186,19 @@ struct HstuBlockMaskWithLocal
|
||||
int row_id;
|
||||
int col_id;
|
||||
|
||||
row += diff_q_kv_len;
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
row_id = max(row - contextual_seqlen + 1, 0);
|
||||
row_id = max(row - contextual_seqlen + 1, diff_q_kv_len);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_row_id);
|
||||
col_id = min(col_id, max_col_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_col_id)
|
||||
if(row_id == diff_q_kv_len && col_id < max_col_id)
|
||||
return true;
|
||||
}
|
||||
else
|
||||
@@ -227,17 +234,19 @@ struct HstuBlockMaskWithLocal
|
||||
int row_id;
|
||||
int col_id;
|
||||
|
||||
row += diff_q_kv_len;
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
row_id = max(row - contextual_seqlen + 1, 0);
|
||||
row_id = max(row - contextual_seqlen + 1, diff_q_kv_len);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_row_id);
|
||||
col_id = min(col_id, max_col_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_col_id)
|
||||
if(row_id == diff_q_kv_len && col_id < max_col_id)
|
||||
return true;
|
||||
}
|
||||
else
|
||||
@@ -281,7 +290,8 @@ struct HstuBlockMaskWithLocal
|
||||
{
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
|
||||
if(!is_tile_in_first_split && i_tile_right <= min(i_tile_top + 1, max_k_uih_len))
|
||||
if(!is_tile_in_first_split &&
|
||||
i_tile_right <= min(i_tile_top + diff_q_kv_len + 1, max_k_uih_len))
|
||||
return true;
|
||||
}
|
||||
else
|
||||
@@ -315,6 +325,7 @@ struct HstuBlockMaskNoLocal
|
||||
int max_k_uih_len;
|
||||
int max_row_id;
|
||||
int max_col_id;
|
||||
int diff_q_kv_len;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
HstuBlockMaskNoLocal(int seqlen_q_, int seqlen_k_, int contextual_seqlen_, int num_target_)
|
||||
@@ -333,6 +344,9 @@ struct HstuBlockMaskNoLocal
|
||||
max_row_id = max_q_uih_len;
|
||||
max_col_id = max_k_uih_len;
|
||||
}
|
||||
|
||||
diff_q_kv_len = max_k_uih_len - max_q_uih_len;
|
||||
max_row_id += diff_q_kv_len;
|
||||
};
|
||||
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
@@ -348,11 +362,11 @@ struct HstuBlockMaskNoLocal
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t x_end = min(i_y + YTile, seqlen_k);
|
||||
index_t x_end = min(i_y + YTile + diff_q_kv_len, seqlen_k);
|
||||
|
||||
if(i_y < contextual_seqlen)
|
||||
{
|
||||
if(i_y + YTile > max_k_uih_len)
|
||||
if(i_y + YTile + diff_q_kv_len > max_k_uih_len)
|
||||
{
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
@@ -373,17 +387,19 @@ struct HstuBlockMaskNoLocal
|
||||
int row_id;
|
||||
int col_id;
|
||||
|
||||
row += diff_q_kv_len;
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
// row_id/col_id is clamped from physical row/col according to contextual_seqlen and
|
||||
// max_uih_len
|
||||
row_id = max(row - contextual_seqlen + 1, 0);
|
||||
row_id = max(row - contextual_seqlen + 1, diff_q_kv_len);
|
||||
col_id = max(col - contextual_seqlen + 1, 0);
|
||||
|
||||
row_id = min(row_id, max_row_id);
|
||||
col_id = min(col_id, max_col_id);
|
||||
|
||||
if(row_id == 0 && col_id < max_col_id)
|
||||
if(row_id == diff_q_kv_len && col_id < max_col_id)
|
||||
return true;
|
||||
}
|
||||
else
|
||||
@@ -420,7 +436,7 @@ struct HstuBlockMaskNoLocal
|
||||
|
||||
// assume num_target > 0 with high probability, don't check whether num_target is 0;
|
||||
// so if num_target is 0, IsTokenPairInsideMask() will be called for the bottom tile
|
||||
if(i_tile_bottom >= max_q_uih_len || i_tile_right > i_tile_top)
|
||||
if(i_tile_bottom >= max_q_uih_len || i_tile_right > i_tile_top + diff_q_kv_len)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
|
||||
@@ -41,7 +41,8 @@ struct reference_hstu_attention
|
||||
int num_batch,
|
||||
float alpha,
|
||||
float attn_scale,
|
||||
int max_seqlen,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_kv,
|
||||
std::vector<int> seq_q_offsets,
|
||||
std::vector<int> seq_kv_offsets,
|
||||
std::vector<int> num_targets, // define masking length at the end of token
|
||||
@@ -93,8 +94,8 @@ struct reference_hstu_attention
|
||||
|
||||
if(static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[0]) == num_batch &&
|
||||
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[1]) == num_head &&
|
||||
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_seqlen &&
|
||||
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_seqlen)
|
||||
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[2]) == max_seqlen_q &&
|
||||
static_cast<int>(mask_batch_nhead_seq_seq.get_lengths()[3]) == max_seqlen_kv)
|
||||
save_mask = true;
|
||||
|
||||
// check num_tagets
|
||||
@@ -114,7 +115,9 @@ struct reference_hstu_attention
|
||||
|
||||
int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
|
||||
|
||||
float scale_p = attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen);
|
||||
float scale_p = attn_scale
|
||||
? attn_scale
|
||||
: 1.0f / static_cast<float>(max(max_seqlen_q, max_seqlen_kv));
|
||||
|
||||
BOOL_SWITCH(window_size > 0, kHasLocal, [&] {
|
||||
using HstuMaskType = typename HstuBlockMasking<kUseCausal, kHasLocal>::Type;
|
||||
@@ -148,9 +151,8 @@ struct reference_hstu_attention
|
||||
|
||||
if(save_mask)
|
||||
{
|
||||
// initialize the mask
|
||||
for(int sq = 0; sq < max_seqlen; sq++)
|
||||
for(int sk = 0; sk < max_seqlen; sk++)
|
||||
for(int sq = 0; sq < max_seqlen_q; sq++)
|
||||
for(int sk = 0; sk < max_seqlen_kv; sk++)
|
||||
mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) =
|
||||
static_cast<int8_t>(mask.IsTokenPairInsideMask(sq, sk));
|
||||
}
|
||||
|
||||
@@ -49,3 +49,6 @@ for T in "fp16" "bf16"; do
|
||||
$EXE -v=1 -prec=$T -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -seqlens_kv=300 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=0 -norm_dist=0
|
||||
set +x
|
||||
done
|
||||
|
||||
## This case is used to verify the masking when seqlen_kv > seqlen_q by comparing the saved mask tensor with the output of test_pytorch_hstu_mask_v2.py
|
||||
$EXE -v=1 -prec=bf16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=49,52,55 -seqlens_kv=63,68,71 -causal=1 -local_len=0 -context_len=3 -minfull_len=0 -targets=4,5,6 -attn_scale=0 -norm_dist=0 -save_mask=1
|
||||
|
||||
106
example/ck_tile/18_hstu_attention/test_pytorch_hstu_mask_v2.py
Normal file
106
example/ck_tile/18_hstu_attention/test_pytorch_hstu_mask_v2.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import math
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
def get_valid_attn_mask_v2(
|
||||
device: torch.device,
|
||||
causal: bool,
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_kv: int,
|
||||
seq_lengths_q: torch.Tensor,
|
||||
seq_lengths_kv: torch.Tensor,
|
||||
num_targets: Optional[torch.Tensor] = None,
|
||||
max_attn_len: int = 0,
|
||||
contextual_seq_len: int = 0,
|
||||
min_full_attn_seq_len: int = 0,
|
||||
) -> torch.Tensor:
|
||||
N = max(max_seqlen_q, max_seqlen_kv)
|
||||
ids = torch.arange(0, N, device=device).view(1, N)
|
||||
max_ids_q = seq_lengths_q.view(-1, 1, 1)
|
||||
max_ids_kv = seq_lengths_kv.view(-1, 1, 1)
|
||||
diff_q_kv = max_ids_kv - max_ids_q
|
||||
if contextual_seq_len > 0:
|
||||
ids = ids - contextual_seq_len + 1
|
||||
ids = torch.clamp(ids, min=0)
|
||||
max_ids_q = max_ids_q - contextual_seq_len + 1
|
||||
max_ids_kv = max_ids_kv - contextual_seq_len + 1
|
||||
if num_targets is not None:
|
||||
max_ids_q = max_ids_q - num_targets.view(-1, 1, 1)
|
||||
max_ids_kv = max_ids_kv - num_targets.view(-1, 1, 1)
|
||||
|
||||
raw_row_ids = torch.clamp(
|
||||
ids,
|
||||
max=max_ids_q,
|
||||
)
|
||||
raw_row_ids = raw_row_ids + diff_q_kv
|
||||
max_ids_q = max_ids_q + diff_q_kv
|
||||
raw_col_ids = torch.clamp(
|
||||
ids,
|
||||
max=max_ids_kv,
|
||||
)
|
||||
row_ids = raw_row_ids.view(-1, N, 1).expand(-1, N, N)
|
||||
col_ids = raw_col_ids.view(-1, 1, N).expand(-1, N, N)
|
||||
|
||||
row_col_dist = row_ids - col_ids
|
||||
|
||||
## ensure mask value in diagonal is always 1
|
||||
##valid_attn_mask = torch.eye(N, device=device, dtype=torch.bool).view(1, N, N)
|
||||
valid_attn_mask = torch.zeros_like(row_col_dist).to(torch.bool)
|
||||
for idx0 in range(valid_attn_mask.size(0)):
|
||||
for idx1 in torch.arange(max_seqlen_q):
|
||||
valid_attn_mask[idx0, idx1, idx1 + diff_q_kv[idx0]] = 1
|
||||
|
||||
if not causal:
|
||||
row_col_dist = torch.where(row_col_dist > 0, row_col_dist, -row_col_dist)
|
||||
|
||||
## 1) for token pair in [seqlen-num_target, N) x [seqlen-num_target, N), row_col_dist is 0
|
||||
## 2) for token pair in [seqlen-num-target, N) x [0, seqlen-num_target), row_col_dist > 0
|
||||
## 3) for token_pair in [0, seqlen-num_target) x [seqlen-num_target, N). row_col_dist < 0 if causal, else row_col_dist > 0
|
||||
valid_attn_mask = torch.logical_or(valid_attn_mask, row_col_dist > 0)
|
||||
if max_attn_len > 0:
|
||||
if min_full_attn_seq_len > 0:
|
||||
valid_attn_mask = torch.logical_and(
|
||||
valid_attn_mask,
|
||||
torch.logical_or(
|
||||
row_col_dist <= max_attn_len,
|
||||
row_ids >= max_ids_q - min_full_attn_seq_len,
|
||||
),
|
||||
)
|
||||
else:
|
||||
valid_attn_mask = torch.logical_and(
|
||||
valid_attn_mask, row_col_dist <= max_attn_len
|
||||
)
|
||||
if contextual_seq_len > 0:
|
||||
## ensure first contextual_seqlen rows (where row_ids==0) attend to all cols less than max_ids
|
||||
valid_attn_mask = torch.logical_or(
|
||||
valid_attn_mask, torch.logical_and(row_ids == diff_q_kv, col_ids < max_ids_kv)
|
||||
)
|
||||
|
||||
fit_valid_attn_mask = valid_attn_mask[:, :max_seqlen_q, :]
|
||||
|
||||
return fit_valid_attn_mask.to(torch.int8)
|
||||
|
||||
def main():
|
||||
max_seqlen_q=64
|
||||
max_seqlen_kv=80
|
||||
contextual_seq_len=3
|
||||
max_attn_len=0
|
||||
causal=True
|
||||
min_full_attn_seq_len=0
|
||||
dev_type=torch.device("cpu")
|
||||
seq_lengths_q=torch.tensor((56,60,64), device=dev_type, dtype=torch.int32)
|
||||
seq_lengths_kv=torch.tensor((70,76,80), device=dev_type, dtype=torch.int32)
|
||||
num_targets=torch.tensor((4,5,6), device=dev_type, dtype=torch.int32)
|
||||
|
||||
valid_attn_mask=get_valid_attn_mask_v2(dev_type, causal, max_seqlen_q, max_seqlen_kv, seq_lengths_q, seq_lengths_kv, num_targets, max_attn_len, contextual_seq_len, min_full_attn_seq_len)
|
||||
torch.save(valid_attn_mask, "torch_hstu_mask_0.pt")
|
||||
|
||||
max_attn_len=4
|
||||
min_full_attn_seq_len=6
|
||||
valid_attn_mask=get_valid_attn_mask_v2(dev_type, causal, max_seqlen_q, max_seqlen_kv, seq_lengths_q, seq_lengths_kv, num_targets, max_attn_len, contextual_seq_len, min_full_attn_seq_len)
|
||||
torch.save(valid_attn_mask, "torch_hstu_mask_1.pt")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user