mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Align the masking logic in HstuCrossAttentionBlockMask with the PyTorch mask_v2 scripts
This commit is contained in:
@@ -284,9 +284,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else
|
||||
is_cross_attention = true;
|
||||
|
||||
// assume input_max_uih_seqlen_kv is the same as input_max_uih_seqlen_q if not explicitly defined
|
||||
if(input_max_uih_seqlen_kv <= 0)
|
||||
input_max_uih_seqlen_kv = input_max_uih_seqlen_q;
|
||||
if(!is_cross_attention)
|
||||
{
|
||||
// assume input_max_uih_seqlen_kv is the same as input_max_uih_seqlen_q if not explicitly defined
|
||||
if(input_max_uih_seqlen_kv <= 0)
|
||||
input_max_uih_seqlen_kv = input_max_uih_seqlen_q;
|
||||
};
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
@@ -333,7 +336,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int phy_seqlen_q = 0;
|
||||
int phy_seqlen_kv = 0;
|
||||
int max_seqlen_q = max_uih_seqlen_q + max_target + contextual_seqlen;
|
||||
int max_seqlen_kv = max_uih_seqlen_kv + max_target + contextual_seqlen;
|
||||
int max_seqlen_kv = is_cross_attention ? max_uih_seqlen_kv + contextual_seqlen
|
||||
: max_uih_seqlen_kv + max_target + contextual_seqlen;
|
||||
|
||||
std::vector<int> seq_offsets_q;
|
||||
std::vector<int> seq_offsets_kv;
|
||||
@@ -356,12 +360,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
for(int i = 0; i < num_batch; i++)
|
||||
{
|
||||
int batch_seqlen = num_targets.empty()
|
||||
? seq_lengths_kv[i] + contextual_seqlen
|
||||
: seq_lengths_kv[i] + num_targets[i] + contextual_seqlen;
|
||||
if(!is_cross_attention)
|
||||
{
|
||||
int batch_seqlen = num_targets.empty()
|
||||
? seq_lengths_kv[i] + contextual_seqlen
|
||||
: seq_lengths_kv[i] + num_targets[i] + contextual_seqlen;
|
||||
|
||||
phy_seqlen_kv += batch_seqlen;
|
||||
seq_offsets_kv.push_back(phy_seqlen_kv);
|
||||
phy_seqlen_kv += batch_seqlen;
|
||||
seq_offsets_kv.push_back(phy_seqlen_kv);
|
||||
}
|
||||
else // for cross_attention, assume target_in_kv == false
|
||||
{
|
||||
int batch_seqlen = seq_lengths_kv[i] + contextual_seqlen;
|
||||
|
||||
phy_seqlen_kv += batch_seqlen;
|
||||
seq_offsets_kv.push_back(phy_seqlen_kv);
|
||||
}
|
||||
};
|
||||
}
|
||||
else
|
||||
|
||||
@@ -47,7 +47,7 @@ struct HstuCrossAttentionBlockMaskWithLocal
|
||||
min_full_attn_seqlen(min_full_attn_seqlen_)
|
||||
{
|
||||
max_q_uih_len = seqlen_q - num_target_;
|
||||
max_k_uih_len = seqlen_k - num_target_;
|
||||
max_k_uih_len = seqlen_k; // assuming target_in_kv == false
|
||||
|
||||
// in case user provided max_attn_len_ could be bigger than max_uih_len
|
||||
max_attn_len = min(max_k_uih_len, min(max_q_uih_len, max_attn_len));
|
||||
@@ -223,11 +223,12 @@ struct HstuCrossAttentionBlockMaskWithLocal
|
||||
}
|
||||
else
|
||||
{
|
||||
// Non-causal: only apply the sliding-window constraint, with no diagonal-inclusion logic.
|
||||
// This matches the PyTorch reference, which just returns the boundary mask for non-causal.
|
||||
bool in_min_full_scope =
|
||||
(min_full_attn_seqlen > 0) ? (row_id >= max_row_id - min_full_attn_seqlen) : false;
|
||||
|
||||
return (((row_id != col_id) || (row == col)) &&
|
||||
((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope));
|
||||
return ((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -523,7 +524,7 @@ struct HstuCrossAttentionBlockMaskNoLocal
|
||||
: seqlen_q(seqlen_q_), seqlen_k(seqlen_k_), contextual_seqlen(contextual_seqlen_)
|
||||
{
|
||||
max_q_uih_len = seqlen_q - num_target_;
|
||||
max_k_uih_len = seqlen_k - num_target_;
|
||||
max_k_uih_len = seqlen_k; // assuming target_in_kv == false
|
||||
|
||||
if(contextual_seqlen > 0)
|
||||
{
|
||||
@@ -609,7 +610,9 @@ struct HstuCrossAttentionBlockMaskNoLocal
|
||||
}
|
||||
else
|
||||
{
|
||||
return (row_id != col_id) || (row == col);
|
||||
// Non-causal: no masking needed, everything in bounds is allowed
|
||||
// This matches PyTorch reference which just returns boundary mask for non-causal
|
||||
return true;
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -187,6 +187,10 @@ struct reference_hstu_attention
|
||||
{
|
||||
for(int sq = 0; sq < max_seqlen_q; sq++)
|
||||
for(int sk = 0; sk < max_seqlen_kv; sk++)
|
||||
mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) = 0;
|
||||
|
||||
for(int sq = 0; sq < seqlen_q; sq++)
|
||||
for(int sk = 0; sk < seqlen_kv; sk++)
|
||||
mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) =
|
||||
static_cast<int8_t>(mask.IsTokenPairInsideMask(sq, sk));
|
||||
}
|
||||
|
||||
@@ -51,4 +51,4 @@ for T in "fp16" "bf16"; do
|
||||
done
|
||||
|
||||
## This case is used to verify the masking when seqlen_kv > seqlen_q by comparing the saved mask tensor with the output of test_pytorch_hstu_mask_v2.py
|
||||
$EXE -v=1 -prec=bf16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=49,52,55 -seqlens_kv=63,68,71 -causal=1 -local_len=0 -context_len=3 -minfull_len=0 -targets=4,5,6 -attn_scale=0 -norm_dist=0 -save_mask=1
|
||||
$EXE -v=1 -prec=bf16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=52,55,58 -seqlens_kv=70,76,80 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=4,5,6 -attn_scale=0 -norm_dist=0 -save_mask=1
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import math
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_valid_attn_mask_v2(
|
||||
device: torch.device,
|
||||
causal: bool,
|
||||
@@ -10,97 +12,82 @@ def get_valid_attn_mask_v2(
|
||||
seq_lengths_q: torch.Tensor,
|
||||
seq_lengths_kv: torch.Tensor,
|
||||
num_targets: Optional[torch.Tensor] = None,
|
||||
max_attn_len: int = 0,
|
||||
contextual_seq_len: int = 0,
|
||||
min_full_attn_seq_len: int = 0,
|
||||
) -> torch.Tensor:
|
||||
N = max(max_seqlen_q, max_seqlen_kv)
|
||||
ids = torch.arange(0, N, device=device).view(1, N)
|
||||
max_ids_q = seq_lengths_q.view(-1, 1, 1)
|
||||
max_ids_kv = seq_lengths_kv.view(-1, 1, 1)
|
||||
diff_q_kv = max_ids_kv - max_ids_q
|
||||
if contextual_seq_len > 0:
|
||||
ids = ids - contextual_seq_len + 1
|
||||
ids = torch.clamp(ids, min=0)
|
||||
max_ids_q = max_ids_q - contextual_seq_len + 1
|
||||
max_ids_kv = max_ids_kv - contextual_seq_len + 1
|
||||
if num_targets is not None:
|
||||
max_ids_q = max_ids_q - num_targets.view(-1, 1, 1)
|
||||
max_ids_kv = max_ids_kv - num_targets.view(-1, 1, 1)
|
||||
"""
|
||||
Generate attention mask for HSTU attention.
|
||||
|
||||
raw_row_ids = torch.clamp(
|
||||
ids,
|
||||
max=max_ids_q,
|
||||
This implementation matches the Hammer/PyTorch reference (_get_valid_attn_mask_v2)
|
||||
with targets_in_kv=False, meaning:
|
||||
- Q sequence: [UIH_Q][Targets] with length seq_lengths_q
|
||||
- KV sequence: [UIH_KV] only with length seq_lengths_kv (NO targets in KV)
|
||||
|
||||
For causal attention with num_targets:
|
||||
- UIH rows (col_ids < uih_lengths_q): apply causal mask (shifted_col_ids >= row_ids)
|
||||
- Target rows (col_ids >= uih_lengths_q): can attend to everything in KV (full attention)
|
||||
"""
|
||||
# Create position indices - matching Hammer convention:
|
||||
# col_ids indexes Q dimension (rows in attention matrix when viewed as Q x KV)
|
||||
# row_ids indexes KV dimension (columns in attention matrix)
|
||||
col_ids = torch.arange(0, max_seqlen_q, device=device).view(1, max_seqlen_q, 1)
|
||||
row_ids = torch.arange(0, max_seqlen_kv, device=device).view(1, 1, max_seqlen_kv)
|
||||
|
||||
# Boundary mask: positions within valid sequence bounds
|
||||
in_boundary_valid_attn_mask = torch.logical_and(
|
||||
row_ids < seq_lengths_kv.view(-1, 1, 1), col_ids < seq_lengths_q.view(-1, 1, 1)
|
||||
)
|
||||
raw_row_ids = raw_row_ids + diff_q_kv
|
||||
max_ids_q = max_ids_q + diff_q_kv
|
||||
raw_col_ids = torch.clamp(
|
||||
ids,
|
||||
max=max_ids_kv,
|
||||
)
|
||||
row_ids = raw_row_ids.view(-1, N, 1).expand(-1, N, N)
|
||||
col_ids = raw_col_ids.view(-1, 1, N).expand(-1, N, N)
|
||||
|
||||
row_col_dist = row_ids - col_ids
|
||||
|
||||
## ensure mask value in diagonal is always 1
|
||||
##valid_attn_mask = torch.eye(N, device=device, dtype=torch.bool).view(1, N, N)
|
||||
valid_attn_mask = torch.zeros_like(row_col_dist).to(torch.bool)
|
||||
for idx0 in range(valid_attn_mask.size(0)):
|
||||
for idx1 in torch.arange(max_seqlen_q):
|
||||
valid_attn_mask[idx0, idx1, idx1 + diff_q_kv[idx0]] = 1
|
||||
|
||||
if not causal:
|
||||
row_col_dist = torch.where(row_col_dist > 0, row_col_dist, -row_col_dist)
|
||||
|
||||
## 1) for token pair in [seqlen-num_target, N) x [seqlen-num_target, N), row_col_dist is 0
|
||||
## 2) for token pair in [seqlen-num_target, N) x [0, seqlen-num_target), row_col_dist > 0
|
||||
## 3) for token pair in [0, seqlen-num_target) x [seqlen-num_target, N), row_col_dist < 0 if causal, else row_col_dist > 0
|
||||
valid_attn_mask = torch.logical_or(valid_attn_mask, row_col_dist > 0)
|
||||
if max_attn_len > 0:
|
||||
if min_full_attn_seq_len > 0:
|
||||
valid_attn_mask = torch.logical_and(
|
||||
valid_attn_mask,
|
||||
torch.logical_or(
|
||||
row_col_dist <= max_attn_len,
|
||||
row_ids >= max_ids_q - min_full_attn_seq_len,
|
||||
),
|
||||
if causal:
|
||||
if num_targets is None:
|
||||
# Causal without num_targets: simple shifted causal mask
|
||||
delta_col_ids = seq_lengths_kv - seq_lengths_q
|
||||
shifted_col_ids = col_ids + delta_col_ids.view(-1, 1, 1)
|
||||
causal_mask = shifted_col_ids >= row_ids
|
||||
return torch.logical_and(in_boundary_valid_attn_mask, causal_mask).to(
|
||||
torch.int8
|
||||
)
|
||||
else:
|
||||
valid_attn_mask = torch.logical_and(
|
||||
valid_attn_mask, row_col_dist <= max_attn_len
|
||||
# Causal with num_targets and targets_in_kv=False
|
||||
# This exactly mirrors the Hammer logic with targets_in_kv=False
|
||||
uih_lengths_q = seq_lengths_q - num_targets
|
||||
delta_col_ids = seq_lengths_kv - uih_lengths_q
|
||||
# targets_in_kv=False: NO subtraction of num_targets from delta_col_ids
|
||||
shifted_col_ids = col_ids + delta_col_ids.view(-1, 1, 1)
|
||||
|
||||
# UIH rows: apply causal mask
|
||||
causal_mask = torch.logical_and(
|
||||
col_ids < uih_lengths_q.view(-1, 1, 1), shifted_col_ids >= row_ids
|
||||
)
|
||||
if contextual_seq_len > 0:
|
||||
## ensure first contextual_seqlen rows (where row_ids==0) attend to all cols less than max_ids
|
||||
valid_attn_mask = torch.logical_or(
|
||||
valid_attn_mask, torch.logical_and(row_ids == diff_q_kv, col_ids < max_ids_kv)
|
||||
)
|
||||
|
||||
fit_valid_attn_mask = valid_attn_mask[:, :max_seqlen_q, :]
|
||||
# Target rows: full attention to KV (no additional constraint for targets_in_kv=False)
|
||||
target_mask = col_ids >= uih_lengths_q.view(-1, 1, 1)
|
||||
|
||||
return torch.logical_and(
|
||||
in_boundary_valid_attn_mask, torch.logical_or(causal_mask, target_mask)
|
||||
).to(torch.int8)
|
||||
else:
|
||||
# Non-causal: everything in bounds is allowed
|
||||
return in_boundary_valid_attn_mask.to(torch.int8)
|
||||
|
||||
return fit_valid_attn_mask.to(torch.int8)
|
||||
|
||||
def main():
|
||||
max_seqlen_q=64
|
||||
max_seqlen_kv=80
|
||||
contextual_seq_len=3
|
||||
max_attn_len=0
|
||||
causal=True
|
||||
min_full_attn_seq_len=0
|
||||
dev_type=torch.device("cpu")
|
||||
seq_lengths_q=torch.tensor((56,60,64), device=dev_type, dtype=torch.int32)
|
||||
seq_lengths_kv=torch.tensor((70,76,80), device=dev_type, dtype=torch.int32)
|
||||
num_targets=torch.tensor((4,5,6), device=dev_type, dtype=torch.int32)
|
||||
max_seqlen_q = 64
|
||||
max_seqlen_kv = 80
|
||||
causal = True
|
||||
dev_type = torch.device("cpu")
|
||||
seq_lengths_q = torch.tensor((56, 60, 64), device=dev_type, dtype=torch.int32)
|
||||
seq_lengths_kv = torch.tensor((70, 76, 80), device=dev_type, dtype=torch.int32)
|
||||
num_targets = torch.tensor((4, 5, 6), device=dev_type, dtype=torch.int32)
|
||||
|
||||
valid_attn_mask=get_valid_attn_mask_v2(dev_type, causal, max_seqlen_q, max_seqlen_kv, seq_lengths_q, seq_lengths_kv, num_targets, max_attn_len, contextual_seq_len, min_full_attn_seq_len)
|
||||
valid_attn_mask = get_valid_attn_mask_v2(
|
||||
dev_type,
|
||||
causal,
|
||||
max_seqlen_q,
|
||||
max_seqlen_kv,
|
||||
seq_lengths_q,
|
||||
seq_lengths_kv,
|
||||
num_targets,
|
||||
)
|
||||
torch.save(valid_attn_mask, "torch_hstu_mask_0.pt")
|
||||
|
||||
max_attn_len=4
|
||||
min_full_attn_seq_len=6
|
||||
valid_attn_mask=get_valid_attn_mask_v2(dev_type, causal, max_seqlen_q, max_seqlen_kv, seq_lengths_q, seq_lengths_kv, num_targets, max_attn_len, contextual_seq_len, min_full_attn_seq_len)
|
||||
torch.save(valid_attn_mask, "torch_hstu_mask_1.pt")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user