Update to README.md

Qianfeng Zhang
2025-11-01 13:16:50 +00:00
parent 8408ec0a02
commit 10133e5d51


@@ -1,21 +1,12 @@
# HSTU attention operator
The HSTU-attention operator takes the tensors `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk]` and
`v: [batches, seqlen, nhead, hdim_v]`, as well as several parameters that define the functional masking, and does the following:
* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_qk]` to get the intermediate tensor `s: [batches, nhead, seqlen, seqlen]`
* Update `s` by filtering it with a functional mask that combines a lower-triangular causal mask, a diagonal window causal mask and a sequence mask
* Do element-wise SiLU over the last `seqlen` dimension of `s` to get the intermediate tensor `p: [batches, nhead, seqlen, seqlen]`
* Multiply `p: [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get the output tensor `o: [batches, seqlen, nhead, hdim_v]`
* Jagged inputs are also supported, where each batch has a separate seqlen defined by `sequence_offsets[]`
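
To make the math concrete, below is a naive single-`(batch, head)` reference sketch in plain C++. It is only an illustration, not the code of this example: the function name, the row-major `[seqlen, hdim]` slice layout and the `masked(i, j)` callback standing in for the full functional mask are all assumptions.
``` C++
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Naive single-(batch, head) reference: o = SiLU(mask(q * k^T)) * v.
// Layouts assumed row-major: q[seqlen][hdim_qk], k[seqlen][hdim_qk],
// v[seqlen][hdim_v], o[seqlen][hdim_v]. masked(i, j) abstracts the
// functional mask (causal / diagonal window / sequence mask).
void hstu_attention_ref(const float* q, const float* k, const float* v,
                        float* o, std::size_t seqlen, std::size_t hdim_qk,
                        std::size_t hdim_v,
                        const std::function<bool(std::size_t, std::size_t)>& masked)
{
    std::vector<float> p_row(seqlen); // one row of s/p at a time
    for (std::size_t i = 0; i < seqlen; ++i) {
        for (std::size_t j = 0; j < seqlen; ++j) {
            // s[i][j] = dot(q[i], k[j]), filtered by the functional mask
            float s = 0.f;
            for (std::size_t d = 0; d < hdim_qk; ++d)
                s += q[i * hdim_qk + d] * k[j * hdim_qk + d];
            // p[i][j] = SiLU(s[i][j]) = s * sigmoid(s); masked entries become 0
            p_row[j] = masked(i, j) ? 0.f : s / (1.f + std::exp(-s));
        }
        // o[i] = p[i] * v
        for (std::size_t d = 0; d < hdim_v; ++d) {
            float acc = 0.f;
            for (std::size_t j = 0; j < seqlen; ++j)
                acc += p_row[j] * v[j * hdim_v + d];
            o[i * hdim_v + d] = acc;
        }
    }
}
```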
## implementation
The operator is implemented as a fused kernel in the example:
* Tensor `s` and tensor `p` only exist in VGPRs as per-workgroup tiles, so no global memory access is needed for them (see the sketch below)
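
A rough CPU analogue of that fusion is sketched below. This is illustrative only, not the ck_tile kernel: the tile size, function name and the causal-only mask (standing in for the full functional mask) are assumptions. The point it shows is that because SiLU is element-wise (no softmax normalization across the row), `o` can be accumulated tile by tile while only a `kTileN`-wide slice of `s`/`p` is ever live:
``` C++
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>

constexpr std::size_t kTileN = 64; // key/value tile width (illustrative)

// Processes one query row against k/v in tiles; p_tile mirrors the
// register-resident per-workgroup tile of s/p in the fused kernel.
void hstu_attention_row_fused(const float* q_row, const float* k, const float* v,
                              float* o_row, std::size_t seqlen,
                              std::size_t hdim_qk, std::size_t hdim_v,
                              std::size_t row)
{
    for (std::size_t d = 0; d < hdim_v; ++d) o_row[d] = 0.f;
    std::array<float, kTileN> p_tile; // never written back to memory
    for (std::size_t n0 = 0; n0 < seqlen; n0 += kTileN) {
        const std::size_t n = std::min(kTileN, seqlen - n0);
        for (std::size_t j = 0; j < n; ++j) { // s tile = q_row * k_tile^T
            float s = 0.f;
            for (std::size_t d = 0; d < hdim_qk; ++d)
                s += q_row[d] * k[(n0 + j) * hdim_qk + d];
            // mask (causal shown for brevity) + element-wise SiLU
            p_tile[j] = (n0 + j <= row) ? s / (1.f + std::exp(-s)) : 0.f;
        }
        for (std::size_t d = 0; d < hdim_v; ++d) // o_row += p_tile * v_tile
            for (std::size_t j = 0; j < n; ++j)
                o_row[d] += p_tile[j] * v[(n0 + j) * hdim_v + d];
    }
}
```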
## build
@@ -34,7 +25,7 @@
#> . example/ck_tile/07_hstu_attention/test_hstu_attention.sh
```
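For a one-off run, an invocation could look like the following. This is a hypothetical example: the binary name and the `-name=value` flag syntax are assumptions based on the arguments listed below, not taken from the example's documentation.
```
#> ./example_hstu_attention -v=1 -causal=1 -local_len=5 -context_len=6
```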
Check the example file `example_hstu_attention.cpp` for more information about the command-line arguments, which look like the following:
``` C++
arg_parser.insert("v", "1", "whether to do CPU validation or not")
@@ -48,7 +39,6 @@
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
.insert("softmax", "0", "use softmax or not")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")