mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
HSTU attention operator
HSTU-attention operator is an operator which takes the tensors q: [batches, seqlen, nhead, hdim_qk], k: [batches, seqlen, nhead, hdim_qk] and
v: [batches, seqlen, nhead, hdim_v], as well as several parameters that define the functional masking, and does the following:
- Multiply q: [batches, seqlen, nhead, hdim_qk] with k: [batches, seqlen, nhead, hdim_qk] (contracting over hdim_qk) to get the intermediate tensor s: [batches, nhead, seqlen, seqlen]
- Update s by filtering it with a functional mask that includes a lower-triangular causal mask, a diagonal window causal mask and a sequence mask
- Do element-wise SiLU on the lower seqlen dimension of s to get the intermediate tensor p: [batches, nhead, seqlen, seqlen]
- Multiply p: [batches, nhead, seqlen, seqlen] with v: [batches, seqlen, nhead, hdim_v] to get the output tensor o: [batches, seqlen, nhead, hdim_v]
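The four steps above can be illustrated with a naive CPU sketch for a single batch and a single head. It is neither the CK kernel nor the example's own reference code: it assumes row-major [seqlen, hdim] slices, models only the lower-triangular causal mask (the diagonal window and sequence masks are omitted), and assumes the scaling p = attn_scale * SiLU(alpha * s), reading alpha and attn_scale from the command-line option descriptions of the example. The names `silu` and `hstu_attention_ref` are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)).
float silu(float x) { return x / (1.0f + std::exp(-x)); }

// Naive HSTU attention for ONE batch and ONE head (hypothetical helper).
// Layout assumption: q, k are row-major [seqlen, hdim_qk], v is [seqlen, hdim_v].
// Only the lower-triangular causal mask is modelled; the diagonal-window and
// sequence masks of the real operator are omitted.
std::vector<float> hstu_attention_ref(const std::vector<float>& q,
                                      const std::vector<float>& k,
                                      const std::vector<float>& v,
                                      std::size_t seqlen,
                                      std::size_t hdim_qk,
                                      std::size_t hdim_v,
                                      float alpha,      // assumed scale of s = q @ k^T
                                      float attn_scale, // assumed scale applied after SiLU
                                      bool causal)
{
    assert(q.size() == seqlen * hdim_qk && k.size() == seqlen * hdim_qk);
    assert(v.size() == seqlen * hdim_v);
    std::vector<float> o(seqlen * hdim_v, 0.0f); // row-major [seqlen, hdim_v]
    for (std::size_t i = 0; i < seqlen; ++i) {
        for (std::size_t j = 0; j < seqlen; ++j) {
            // Step 2 (simplified, checked first to skip work): the causal mask
            // keeps only j <= i, diagonal included; masked p[i][j] stays 0.
            if (causal && j > i) continue;
            // Step 1: s[i][j] = dot(q[i], k[j]) over hdim_qk.
            float s = 0.0f;
            for (std::size_t d = 0; d < hdim_qk; ++d)
                s += q[i * hdim_qk + d] * k[j * hdim_qk + d];
            // Step 3: p[i][j] = attn_scale * SiLU(alpha * s[i][j]).
            const float p = attn_scale * silu(alpha * s);
            // Step 4: o[i] += p[i][j] * v[j].
            for (std::size_t d = 0; d < hdim_v; ++d)
                o[i * hdim_v + d] += p * v[j * hdim_v + d];
        }
    }
    return o;
}
```

Because SiLU, unlike softmax, needs no row-wise normalization, each p[i][j] depends only on s[i][j], so the sketch can accumulate o directly without storing s or p.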
build
#> mkdir build
#> cd build
#> ../script/cmake-ck-dev.sh .. gfx942 -G Ninja ; replace gfx942 with your GPU arch, which you can check with #> rocminfo | grep "gfx"
#> ninja tile_example_hstu_attention ; or, when not using the Ninja generator, make -j tile_example_hstu_attention
test/verify
#> build/bin/tile_example_hstu_attention -v=1 -prec=bf16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=750,730,733,860,870,788,760,821,833,779 -targets=5,5,6,6,5,6,5,6,4,6 \
     -causal=1 -local_len=5 -context_len=6 -minfull_len=6
#> . example/ck_tile/07_hstu_attention/test_hstu_attention.sh
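With -jagged=1 the batches have different lengths. The following is a minimal sketch of how per-batch allocated lengths and packed start offsets could be derived. It assumes, from the seqlens option description, that each batch's allocated length is its UIH seqlen plus its number of targets plus context_len, and that batches are packed back to back along the sequence dimension. `JaggedLayout`, `per_batch` and `make_jagged_layout` are hypothetical names, not part of the example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical layout description; not part of the CK example.
struct JaggedLayout {
    std::vector<int> phys_seqlens; // allocated length of each batch
    std::vector<int> seq_offsets;  // batch + 1 entries; batch b spans [seq_offsets[b], seq_offsets[b + 1])
};

// Expands a single value to every batch ("single or all batches"),
// or checks that exactly one value per batch was given.
std::vector<int> per_batch(const std::vector<int>& vals, std::size_t batch)
{
    if (vals.size() == 1) return std::vector<int>(batch, vals[0]);
    assert(vals.size() == batch);
    return vals;
}

// Assumption: allocated length = uih seqlen + targets + context_len,
// with batches packed back to back (prefix sum of the lengths).
JaggedLayout make_jagged_layout(std::size_t batch,
                                const std::vector<int>& seqlens,
                                const std::vector<int>& targets,
                                int context_len)
{
    const std::vector<int> uih = per_batch(seqlens, batch);
    const std::vector<int> tgt = per_batch(targets, batch);
    JaggedLayout layout;
    layout.seq_offsets.push_back(0);
    for (std::size_t b = 0; b < batch; ++b) {
        const int len = uih[b] + tgt[b] + context_len;
        layout.phys_seqlens.push_back(len);
        layout.seq_offsets.push_back(layout.seq_offsets.back() + len);
    }
    return layout;
}
```

For the first two batches of the command above (seqlens 750 and 730, targets 5 and 5, context_len 6), this gives allocated lengths 761 and 741 and start offsets 0, 761 and 1502.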
Check the example file example_hstu_attention.cpp for more information about the command-line arguments.
arg_parser.insert("v", "1", "whether to do CPU validation or not")
.insert("prec", "fp16", "data type. fp16/bf16")
.insert("jagged", "0", "q/k/v batched sequence is jagged or not")
.insert("b", "12", "batch size")
.insert("nhead", "4", "number of heads")
.insert("hdim_qk", "64", "headdim size of Q/K")
.insert("hdim_v", "64", "headdim size of V/O")
.insert("seqlens", "400", "seqlen of a single batch or of all batches for query and key/value tensor; the actually allocated seqlen will include the targets of each batch and context_len")
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal to or bigger than the maximum of all uih seqlens")
.insert("targets", "16", "sequence length at the end of the query/key token sequence that should be excluded from attention")
.insert("max_target", "0", "max target, can be ignored, or else must be equal to or bigger than the maximum of all targets")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
.insert("context_len", "6", "sequence length at the beginning of the query sequence that should be included for attention")
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
.insert("init_qkv", "0", "initialize q, k, v tensors from local files q.dat, k.dat and v.data")
.insert("seed", "13579", "seed for the uniform or normal distribution generator")
.insert("norm_dist", "0", "if true, initialize the data in normal distribution, or else in uniform distribution")
.insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)")
.insert("attn_scale", "0", "scale factor of SiLU(Q@K), 0 means using 1/max_seqlen for scaling")
.insert("save_mask", "1", "save the mask tensor to disk by the CPU validation code")
.insert("perf", "0", "whether to measure execution time or not")
.insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");
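The "0 means default" conventions of max_seqlen, alpha and attn_scale can be summarized in a small sketch. It is hypothetical (`Scales` and `resolve_scales` are not names from the example) and assumes that the hdim in 1/sqrt(hdim) is hdim_qk and that max_seqlen refers to the UIH seqlens without targets or context, as the max_seqlen description suggests.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical resolved values; not part of the CK example.
struct Scales {
    int max_seqlen;
    float alpha;
    float attn_scale;
};

Scales resolve_scales(const std::vector<int>& uih_seqlens, int max_seqlen,
                      int hdim_qk, float alpha, float attn_scale)
{
    assert(!uih_seqlens.empty() && hdim_qk > 0);
    const int actual_max = *std::max_element(uih_seqlens.begin(), uih_seqlens.end());
    // max_seqlen: 0 means "derive it"; otherwise it must cover every UIH seqlen.
    if (max_seqlen == 0) max_seqlen = actual_max;
    assert(max_seqlen >= actual_max);
    // alpha: 0 means 1/sqrt(hdim), assumed here to be hdim_qk.
    if (alpha == 0.0f) alpha = 1.0f / std::sqrt(static_cast<float>(hdim_qk));
    // attn_scale: 0 means 1/max_seqlen.
    if (attn_scale == 0.0f) attn_scale = 1.0f / static_cast<float>(max_seqlen);
    return {max_seqlen, alpha, attn_scale};
}
```

With the defaults (seqlens=400, hdim_qk=64, alpha=0, attn_scale=0), this yields max_seqlen = 400, alpha = 0.125 and attn_scale = 0.0025; explicitly given non-zero values are passed through unchanged.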