fix the jagged mode tensor access in reference_hstu_attention
@@ -29,7 +29,8 @@
 ## test/verify
 
 ``` bash
-#> build/bin/tile_example_hstu_attention -v=1 -prec=fp16 -b=10 -nidx=9 -nhead=4 -hsizeq=64 -hsizev=64 -seqq=13 -seqk=512 -init=u -seed=123 -perf=0 -maskmax=0
+#> build/bin/tile_example_hstu_attention -v=1 -prec=bf16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=750,730,733,860,870,788,760,821,833,779 -targets=5,5,6,6,5,6,5,6,4,6
+-causal=1 -local_len=5 -context_len=6 -minfull_len=6
 #> . example/ck_tile/07_hstu_attention/test_hstu_attention.sh
 ```
@@ -38,16 +39,18 @@
 ``` C++
 arg_parser.insert("v", "1", "whether to do CPU validation or not")
     .insert("prec", "fp16", "data type. fp16/bf16")
+    .insert("jagged", "0", "q/k/v batched sequence is jagged or not")
     .insert("b", "12", "batch size")
-    .insert("nidx", "9", "number of indices for accessing the batches")
     .insert("nhead", "4", "number of heads")
-    .insert("hsizeq", "64", "headdim size of Q/K")
-    .insert("hsizev", "64", "headdim size of V/O")
-    .insert("seqq", "13", "length of the sequence dimension of query tensor")
-    .insert("seqv", "1024", "length of the sequence dimension of key tensor")
     .insert("init", "u", "init method for input tensor values: u, uniform random float values; n, normalized random float values")
+    .insert("hdim_qk", "64", "headdim size of Q/K")
+    .insert("hdim_v", "64", "headdim size of V/O")
+    .insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor")
+    .insert("targets", "16", "sequence length at the end of the query/key token sequence that should be excluded from attention")
+    .insert("causal", "1", "enable causal mask or not")
+    .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
+    .insert("context_len", "6", "sequence length at the beginning of the query sequence that should be included for attention")
+    .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
     .insert("seed", "13579", "seed for the uniform or normal distribution generator")
-    .insert("perf", "0", "whether to measure execution time or not")
-    .insert("maskmax", "0", "used to set mask values to random [0, maskmax); maskmax should be in [0, 128]; 0 means set all values to 1");
+    .insert("perf", "0", "whether to measure execution time or not");
 ```
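
For jagged inputs, the per-batch `-seqlen` list is packed by concatenation, and the reference path below indexes it through an offset table with `num_batch + 1` entries. A minimal sketch of that prefix-sum relationship, assuming plain `std::vector<int>` storage (the helper name `build_seq_offsets` is hypothetical, not part of the example):

``` C++
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: build the seq_offsets table the jagged reference path
// expects, i.e. seq_offsets[b + 1] - seq_offsets[b] == seqlen of batch b.
std::vector<int> build_seq_offsets(const std::vector<int>& seqlens)
{
    std::vector<int> offsets(seqlens.size() + 1, 0);
    for(std::size_t b = 0; b < seqlens.size(); b++)
        offsets[b + 1] = offsets[b] + seqlens[b]; // running prefix sum
    return offsets;
}

int main()
{
    // The first three lengths from the bf16 test command above.
    std::vector<int> offsets = build_seq_offsets({750, 730, 733});
    assert(offsets.size() == 4); // num_batch + 1 entries
    assert(offsets[3] == 2213);  // total tokens in the packed seq dimension
}
```
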
@@ -216,22 +216,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using GemmAccDataType   = typename HSTUAttentionTypeConfig<InOutDataType>::GemmAccDataType;
     using SMComputeDataType = typename HSTUAttentionTypeConfig<InOutDataType>::SMComputeDataType;
 
-    BOOL_SWITCH_2(use_causal, USE_CAUSAL_, use_local, USE_LOCAL_, [&] {
+    BOOL_SWITCH_3(is_jagged, kIsJagged, use_causal, kUseCausal, use_local, kUseLocal, [&] {
         ck_tile::reference_hstu_attention<InOutDataType,
                                           GemmAccDataType,
                                           SMComputeDataType,
-                                          USE_CAUSAL_,
-                                          USE_LOCAL_>::Run(q_host,
-                                                           k_host,
-                                                           v_host,
-                                                           o_host_ref,
-                                                           num_batch,
-                                                           1.0f,
-                                                           seq_offsets,
-                                                           num_targets,
-                                                           max_attn_len,
-                                                           contextual_seq_len,
-                                                           min_full_seq_len);
+                                          kIsJagged,
+                                          kUseCausal,
+                                          kUseLocal>::Run(q_host,
+                                                          k_host,
+                                                          v_host,
+                                                          o_host_ref,
+                                                          num_batch,
+                                                          1.0f,
+                                                          seq_offsets,
+                                                          num_targets,
+                                                          max_attn_len,
+                                                          contextual_seq_len,
+                                                          min_full_seq_len);
     });
     return 0;
 }
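
The switch from `BOOL_SWITCH_2` to `BOOL_SWITCH_3` threads the new `is_jagged` flag through to a compile-time template parameter. A sketch of the dispatch pattern such macros are assumed to implement (this is not the actual ck_tile macro, which may differ in detail):

``` C++
#include <type_traits>

// Sketch of the runtime-bool -> compile-time-constant dispatch: test each
// flag at runtime and re-invoke the body with a std::bool_constant, so the
// value is usable as a template argument inside the lambda.
template <typename F>
void bool_switch(bool flag, F&& body)
{
    if(flag)
        body(std::true_type{});
    else
        body(std::false_type{});
}

template <typename F>
void bool_switch_3(bool a, bool b, bool c, F&& body)
{
    bool_switch(a, [&](auto ka) {
        bool_switch(b, [&](auto kb) {
            bool_switch(c, [&](auto kc) { body(ka, kb, kc); });
        });
    });
}

// Usage mirroring the call site above: eight instantiations are generated,
// one per flag combination.
void dispatch(bool is_jagged, bool use_causal, bool use_local)
{
    bool_switch_3(is_jagged, use_causal, use_local,
                  [&](auto kIsJagged, auto kUseCausal, auto kUseLocal) {
                      // kIsJagged() etc. are constexpr here, e.g.:
                      // reference_hstu_attention<..., kIsJagged(), kUseCausal(),
                      //                          kUseLocal()>::Run(...);
                      (void)kIsJagged; (void)kUseCausal; (void)kUseLocal;
                  });
}
```
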
@@ -27,8 +27,9 @@ namespace ck_tile {
 template <typename InOutDataType,
           typename GemmAccDataType,
           typename CompDataType,
-          bool use_causal,
-          bool use_local>
+          bool kIsJagged,
+          bool kUseCausal,
+          bool kUseLocal>
 struct reference_hstu_attention
 {
     struct hstu_mask
@@ -49,15 +50,15 @@ struct reference_hstu_attention
             max_uih_len = max_uih_len_;
         };
 
-        bool IsPixelInsideMask(int row, int col)
+        bool IsTokenPairInsideMask(int row, int col)
         {
             if(row < contextual_seq_len)
                 return true;
 
             bool result = false;
-            if constexpr(use_local)
+            if constexpr(kUseLocal)
             {
-                if constexpr(use_causal)
+                if constexpr(kUseCausal)
                     result = (row >= col) && (row - col <= max_attn_len);
                 else
                     result = std::abs(row - col) <= max_attn_len;
@@ -67,7 +68,7 @@ struct reference_hstu_attention
             }
             else
             {
-                if constexpr(use_causal)
+                if constexpr(kUseCausal)
                     result = (row >= col);
             };
 
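
The rename from `IsPixelInsideMask` to `IsTokenPairInsideMask` better matches what the predicate decides: whether a (row, col) query/key token pair participates in attention, via the contextual prefix, the (optionally causal) local window, or plain causal masking. A self-contained restatement of the branches shown above, with the struct state passed in as parameters for illustration (target and min-full handling live elsewhere in the struct):

``` C++
#include <cassert>
#include <cstdlib>

// Standalone restatement of the mask branches above; use_local/use_causal are
// runtime parameters here instead of compile-time template flags.
bool is_token_pair_inside_mask(int row, int col,
                               int contextual_seq_len, int max_attn_len,
                               bool use_local, bool use_causal)
{
    if(row < contextual_seq_len) // contextual prefix attends everywhere
        return true;

    bool result = false;
    if(use_local)
    {
        if(use_causal) // causal local window: col in [row - max_attn_len, row]
            result = (row >= col) && (row - col <= max_attn_len);
        else // symmetric local window around the diagonal
            result = std::abs(row - col) <= max_attn_len;
    }
    else
    {
        if(use_causal) // plain causal mask
            result = (row >= col);
    }
    return result;
}

int main()
{
    // causal + local window of 5: row 7 can see cols 2..7
    assert(is_token_pair_inside_mask(7, 3, 0, 5, true, true));
    assert(!is_token_pair_inside_mask(7, 1, 0, 5, true, true));
    // contextual prefix of 6: row 2 always attends
    assert(is_token_pair_inside_mask(2, 100, 6, 5, true, true));
}
```
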
@@ -90,12 +91,10 @@ struct reference_hstu_attention
                     int min_full_attn_seq_len) // define masking length at the end of query token
                                                // sequence which is included for full attention
     {
-        bool is_jagged = !seq_offsets.empty();
-
-        if(is_jagged)
+        if constexpr(kIsJagged)
         {
             // check the number of batches
-            assert(seq_offsets.size() == num_batch + 1);
+            assert(!seq_offsets.empty() && seq_offsets.size() == num_batch + 1);
             assert(q_batch_seq_nhead_hdim.get_lengths()[0] == 1);
             assert(k_batch_seq_nhead_hdim.get_lengths()[0] == 1);
             assert(v_batch_seq_nhead_hdim.get_lengths()[0] == 1);
@@ -103,6 +102,7 @@ struct reference_hstu_attention
         }
         else
         {
+            assert(seq_offsets.empty());
             assert(q_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
             assert(k_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
             assert(v_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
@@ -140,7 +140,7 @@ struct reference_hstu_attention
         assert(num_targets.size() == num_batch);
 
         auto f = [&](auto i_batch, auto i_head) {
-            int seqlen = is_jagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch])
+            int seqlen = kIsJagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch])
                                    : q_batch_seq_nhead_hdim.get_lengths()[1];
 
             int max_uih_len = seqlen;
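
This is the crux of the fix that follows: in jagged mode all batches share batch index 0 and are concatenated along the sequence dimension, so token `s` of batch `i_batch` lives at sequence position `seq_offsets[i_batch] + s`. A minimal sketch of the two addressing schemes over a flat row-major buffer (the layout helper is illustrative, not ck_tile's HostTensor API):

``` C++
#include <cstddef>
#include <vector>

// Illustrative 4-D addressing for a [batch, seq, nhead, hdim] tensor stored
// row-major in a flat buffer; ck_tile::HostTensor hides this behind operator().
struct Shape { int batch, seq, nhead, hdim; };

std::size_t flat(const Shape& s, int b, int t, int h, int d)
{
    return ((static_cast<std::size_t>(b) * s.seq + t) * s.nhead + h) * s.hdim + d;
}

// Dense (non-jagged) access: every batch owns its own slot of the batch dim.
float load_dense(const std::vector<float>& buf, const Shape& s,
                 int i_batch, int sq, int i_head, int k)
{
    return buf[flat(s, i_batch, sq, i_head, k)];
}

// Jagged access: the batch dim collapses to 1 and per-batch sequences are
// concatenated, so the batch is folded into the sequence index via seq_offsets.
float load_jagged(const std::vector<float>& buf, const Shape& s,
                  const std::vector<int>& seq_offsets,
                  int i_batch, int sq, int i_head, int k)
{
    return buf[flat(s, 0, seq_offsets[i_batch] + sq, i_head, k)];
}
```
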
@@ -161,16 +161,29 @@ struct reference_hstu_attention
                 // for all cols in the batch
                 for(int sk = 0; sk < max_uih_len; sk++)
                 {
-                    if(mask.IsPixelInsideMask(sq, sk))
+                    if(mask.IsTokenPairInsideMask(sq, sk))
                     {
                         GemmAccDataType dot_prod = 0.f;
                         for(int k = 0; k < hdim_qk; k++)
                         {
-                            InOutDataType qreg = q_batch_seq_nhead_hdim(i_batch, sq, i_head, k);
-                            InOutDataType kreg = k_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
-
-                            dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
-                                        ck_tile::type_convert<GemmAccDataType>(kreg);
+                            if constexpr(kIsJagged)
+                            {
+                                InOutDataType qreg =
+                                    q_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k);
+                                InOutDataType kreg =
+                                    k_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
+
+                                dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
+                                            ck_tile::type_convert<GemmAccDataType>(kreg);
+                            }
+                            else
+                            {
+                                InOutDataType qreg = q_batch_seq_nhead_hdim(i_batch, sq, i_head, k);
+                                InOutDataType kreg = k_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
+
+                                dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
+                                            ck_tile::type_convert<GemmAccDataType>(kreg);
+                            };
                         }
 
                         locals.push_back(ck_tile::type_convert<CompDataType>(dot_prod) *
@@ -191,15 +204,31 @@ struct reference_hstu_attention
 
                     for(int sk = 0; sk < max_uih_len; sk++)
                     {
-                        InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
-                        InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
-
-                        dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
-                                    ck_tile::type_convert<GemmAccDataType>(vreg);
+                        if constexpr(kIsJagged)
+                        {
+                            InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
+                            InOutDataType vreg =
+                                v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
+
+                            dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
+                                        ck_tile::type_convert<GemmAccDataType>(vreg);
+                        }
+                        else
+                        {
+                            InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
+                            InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
+
+                            dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
+                                        ck_tile::type_convert<GemmAccDataType>(vreg);
+                        };
                     };
 
-                    o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) =
-                        ck_tile::type_convert<InOutDataType>(dot_prod);
+                    if constexpr(kIsJagged)
+                        o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
+                            ck_tile::type_convert<InOutDataType>(dot_prod);
+                    else
+                        o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) =
+                            ck_tile::type_convert<InOutDataType>(dot_prod);
                 };
             };
         };
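
The jagged/dense branch is now repeated at three sites (the Q/K loads, the P/V loads, and the output store). One way to factor the duplication out is a small coordinate-mapping helper; the sketch below is an editorial suggestion, not part of the commit:

``` C++
#include <utility>
#include <vector>

// Hypothetical refactor: map a (batch, seq) coordinate to the pair actually
// used for tensor access, so each load/store site branches only once.
template <bool kIsJagged>
std::pair<int, int> batch_seq_coord(const std::vector<int>& seq_offsets, int i_batch, int s)
{
    if constexpr(kIsJagged)
        return {0, seq_offsets[i_batch] + s}; // packed: batch folded into seq
    else
        return {i_batch, s};                  // dense: coordinates pass through
}

// Usage at the Q load site would then collapse to:
//   auto [b, t] = batch_seq_coord<kIsJagged>(seq_offsets, i_batch, sq);
//   InOutDataType qreg = q_batch_seq_nhead_hdim(b, t, i_head, k);
```
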