fix the jagged mode tensor access in reference_hstu_attention

This commit is contained in:
Qianfeng Zhang
2025-03-29 12:55:04 +00:00
parent 4a0fc292d0
commit 83f29243df
3 changed files with 77 additions and 44 deletions

View File

@@ -29,7 +29,8 @@
## test/verify
``` bash
#> build/bin/tile_example_hstu_attention -v=1 -prec=fp16 -b=10 -nidx=9 -nhead=4 -hsizeq=64 -hsizev=64 -seqq=13 -seqk=512 -init=u -seed=123 -perf=0 -maskmax=0
#> build/bin/tile_example_hstu_attention -v=1 -prec=bf16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=750,730,733,860,870,788,760,821,833,779 -targets=5,5,6,6,5,6,5,6,4,6
-causal=1 -local_len=5 -context_len=6 -minfull_len=6
#> . example/ck_tile/07_hstu_attention/test_hstu_attention.sh
```
@@ -38,16 +39,18 @@
``` C++
arg_parser.insert("v", "1", "whether to do CPU validation or not")
.insert("prec", "fp16", "data type. fp16/bf16")
.insert("jagged", "0", "q/k/v batched sequence is jagged or not")
.insert("b", "12", "batch size")
.insert("nidx", "9", "number of indices for accessing the batches")
.insert("nhead", "4", "number of heads")
.insert("hsizeq", "64", "headdim size of Q/K")
.insert("hsizev", "64", "headdim size of V/O")
.insert("seqq", "13", "length of the sequence dimension of query tensor")
.insert("seqv", "1024", "length of the sequence dimension of key tensor")
.insert("init", "u", "init method for input tensor values, u, uniform random float values, n, normalized random float values")
.insert("hdim_qk", "64", "headdim size of Q/K")
.insert("hdim_v", "64", "headdim size of V/O")
.insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor")
.insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
.insert("context_len", "6", "sequence length at the beginning of the query sequence that should be included for attention")
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
.insert("perf", "0", "whether to measure execution time or not")
.insert("maskmax", "0", "used to set mask values to random [0, maskmax), maskmax should in [0, 128], 0 means set all values to 1");
.insert("perf", "0", "whether to measure execution time or not");
```

View File

@@ -216,22 +216,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
using GemmAccDataType = typename HSTUAttentionTypeConfig<InOutDataType>::GemmAccDataType;
using SMComputeDataType = typename HSTUAttentionTypeConfig<InOutDataType>::SMComputeDataType;
BOOL_SWITCH_2(use_causal, USE_CAUSAL_, use_local, USE_LOCAL_, [&] {
BOOL_SWITCH_3(is_jagged, kIsJagged, use_causal, kUseCausal, use_local, kUseLocal, [&] {
ck_tile::reference_hstu_attention<InOutDataType,
GemmAccDataType,
SMComputeDataType,
USE_CAUSAL_,
USE_LOCAL_>::Run(q_host,
k_host,
v_host,
o_host_ref,
num_batch,
1.0f,
seq_offsets,
num_targets,
max_attn_len,
contextual_seq_len,
min_full_seq_len);
kIsJagged,
kUseCausal,
kUseLocal>::Run(q_host,
k_host,
v_host,
o_host_ref,
num_batch,
1.0f,
seq_offsets,
num_targets,
max_attn_len,
contextual_seq_len,
min_full_seq_len);
});
return 0;
}

View File

@@ -27,8 +27,9 @@ namespace ck_tile {
template <typename InOutDataType,
typename GemmAccDataType,
typename CompDataType,
bool use_causal,
bool use_local>
bool kIsJagged,
bool kUseCausal,
bool kUseLocal>
struct reference_hstu_attention
{
struct hstu_mask
@@ -49,15 +50,15 @@ struct reference_hstu_attention
max_uih_len = max_uih_len_;
};
bool IsPixelInsideMask(int row, int col)
bool IsTokenPairInsideMask(int row, int col)
{
if(row < contextual_seq_len)
return true;
bool result = false;
if constexpr(use_local)
if constexpr(kUseLocal)
{
if constexpr(use_causal)
if constexpr(kUseCausal)
result = (row >= col) && (row - col <= max_attn_len);
else
result = std::abs(row - col) <= max_attn_len;
@@ -67,7 +68,7 @@ struct reference_hstu_attention
}
else
{
if constexpr(use_causal)
if constexpr(kUseCausal)
result = (row >= col);
};
@@ -90,12 +91,10 @@ struct reference_hstu_attention
int min_full_attn_seq_len) // define masking length at the end of query token
// sequence which is included for full attention
{
bool is_jagged = !seq_offsets.empty();
if(is_jagged)
if constexpr(kIsJagged)
{
// check the number of batches
assert(seq_offsets.size() == num_batch + 1);
assert(!seq_offsets.empty() && seq_offsets.size() == num_batch + 1);
assert(q_batch_seq_nhead_hdim.get_lengths()[0] == 1);
assert(k_batch_seq_nhead_hdim.get_lengths()[0] == 1);
assert(v_batch_seq_nhead_hdim.get_lengths()[0] == 1);
@@ -103,6 +102,7 @@ struct reference_hstu_attention
}
else
{
assert(seq_offsets.empty());
assert(q_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
assert(k_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
assert(v_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
@@ -140,7 +140,7 @@ struct reference_hstu_attention
assert(num_targets.size() == num_batch);
auto f = [&](auto i_batch, auto i_head) {
int seqlen = is_jagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch])
int seqlen = kIsJagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch])
: q_batch_seq_nhead_hdim.get_lengths()[1];
int max_uih_len = seqlen;
@@ -161,16 +161,29 @@ struct reference_hstu_attention
// for all cols in the batch
for(int sk = 0; sk < max_uih_len; sk++)
{
if(mask.IsPixelInsideMask(sq, sk))
if(mask.IsTokenPairInsideMask(sq, sk))
{
GemmAccDataType dot_prod = 0.f;
for(int k = 0; k < hdim_qk; k++)
{
InOutDataType qreg = q_batch_seq_nhead_hdim(i_batch, sq, i_head, k);
InOutDataType kreg = k_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
if constexpr(kIsJagged)
{
InOutDataType qreg =
q_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k);
InOutDataType kreg =
k_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
ck_tile::type_convert<GemmAccDataType>(kreg);
dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
ck_tile::type_convert<GemmAccDataType>(kreg);
}
else
{
InOutDataType qreg = q_batch_seq_nhead_hdim(i_batch, sq, i_head, k);
InOutDataType kreg = k_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
ck_tile::type_convert<GemmAccDataType>(kreg);
};
}
locals.push_back(ck_tile::type_convert<CompDataType>(dot_prod) *
@@ -191,15 +204,31 @@ struct reference_hstu_attention
for(int sk = 0; sk < max_uih_len; sk++)
{
InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
if constexpr(kIsJagged)
{
InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
InOutDataType vreg =
v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
ck_tile::type_convert<GemmAccDataType>(vreg);
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
ck_tile::type_convert<GemmAccDataType>(vreg);
}
else
{
InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
ck_tile::type_convert<GemmAccDataType>(vreg);
};
};
o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) =
ck_tile::type_convert<InOutDataType>(dot_prod);
if constexpr(kIsJagged)
o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
ck_tile::type_convert<InOutDataType>(dot_prod);
else
o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) =
ck_tile::type_convert<InOutDataType>(dot_prod);
};
};
};