mirror of
https://github.com/ROCm/composable_kernel.git
Update to README.md
# HSTU attention operator

The HSTU-attention operator takes the tensors `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk]` and
`v: [batches, seqlen, nhead, hdim_v]`, as well as several parameters that define the functional masking, and does the following:

* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_qk]` to get the intermediate tensor `s: [batches, nhead, seqlen, seqlen]`
* Update `s` by filtering it with a functional mask that combines a lower-triangular causal mask, a diagonal window causal mask and a sequence mask
* Do element-wise SiLU on the last `seqlen` dimension of `s` to get the intermediate tensor `p: [batches, nhead, seqlen, seqlen]`
* Multiply `p: [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get the output tensor `o: [batches, seqlen, nhead, hdim_v]`
* Jagged inputs are also supported, where each batch has a separate seqlen defined by `sequence_offsets[]`
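
To make the data flow concrete, below is a naive CPU reference sketch of the steps above for a single batch and a single head. The function name, the flat row-major layouts and the simplified mask (only the lower-triangular causal part and the diagonal window, ignoring the sequence/target/context handling) are illustrative assumptions, not the example's API; the example itself runs these steps fused in a GPU kernel.

``` C++
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU(x) = x * sigmoid(x)
inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// Naive single-batch, single-head reference of the steps above (hypothetical
// helper, not part of the example). q, k: [seqlen][hdim_qk] row-major,
// v, o: [seqlen][hdim_v] row-major. Only the causal and diagonal-window parts
// of the functional mask are modeled; local_len == 0 disables the window.
void hstu_attention_ref(const std::vector<float>& q, const std::vector<float>& k,
                        const std::vector<float>& v, std::vector<float>& o,
                        std::size_t seqlen, std::size_t hdim_qk, std::size_t hdim_v,
                        std::size_t local_len)
{
    o.assign(seqlen * hdim_v, 0.0f);
    for (std::size_t i = 0; i < seqlen; ++i) {
        for (std::size_t j = 0; j < seqlen; ++j) {
            if (j > i) continue;                                 // lower-triangular causal mask
            if (local_len != 0 && i - j >= local_len) continue;  // diagonal window mask
            float s = 0.0f;                                      // s[i][j] = q[i] . k[j]
            for (std::size_t d = 0; d < hdim_qk; ++d)
                s += q[i * hdim_qk + d] * k[j * hdim_qk + d];
            const float p = silu(s);                             // p[i][j] = SiLU(s[i][j])
            for (std::size_t d = 0; d < hdim_v; ++d)             // o[i] += p[i][j] * v[j]
                o[i * hdim_v + d] += p * v[j * hdim_v + d];
        }
    }
}
```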
## implementation

The operator is implemented using a fused kernel in the example:

* Tensor `s` and tensor `p` only exist in VGPRs as per-workgroup tiles; no global memory access is needed for them
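
To illustrate what it means for `s` and `p` to exist only as per-workgroup tiles, here is a schematic host-side restructuring of the reference sketch above: `s`/`p` are held in one small fixed-size tile at a time while `o` is accumulated, and the full `[seqlen, seqlen]` tensors are never materialized. The tile sizes, the plain causal mask and the function name are assumptions for illustration only and do not describe the actual kernel's tiling or scheduling.

``` C++
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr std::size_t kTileM = 32; // illustrative tile sizes, not the kernel's
constexpr std::size_t kTileN = 32;

// q, k: [seqlen][hdim_qk], v, o: [seqlen][hdim_v], row-major, one batch and one head.
void hstu_attention_tiled(const std::vector<float>& q, const std::vector<float>& k,
                          const std::vector<float>& v, std::vector<float>& o,
                          std::size_t seqlen, std::size_t hdim_qk, std::size_t hdim_v)
{
    o.assign(seqlen * hdim_v, 0.0f);
    for (std::size_t i0 = 0; i0 < seqlen; i0 += kTileM) {     // row tile of q / o
        const std::size_t i1 = std::min(i0 + kTileM, seqlen);
        for (std::size_t j0 = 0; j0 < seqlen; j0 += kTileN) { // column tile of k / v
            const std::size_t j1 = std::min(j0 + kTileN, seqlen);
            std::array<float, kTileM * kTileN> p_tile{};      // the only storage for s/p
            for (std::size_t i = i0; i < i1; ++i)
                for (std::size_t j = j0; j < j1; ++j) {
                    if (j > i) continue;                      // causal part of the mask only
                    float s = 0.0f;                           // s[i][j] = q[i] . k[j]
                    for (std::size_t d = 0; d < hdim_qk; ++d)
                        s += q[i * hdim_qk + d] * k[j * hdim_qk + d];
                    p_tile[(i - i0) * kTileN + (j - j0)] = s / (1.0f + std::exp(-s)); // SiLU
                }
            for (std::size_t i = i0; i < i1; ++i)             // o[i] += p[i][j] * v[j]
                for (std::size_t j = j0; j < j1; ++j)
                    for (std::size_t d = 0; d < hdim_v; ++d)
                        o[i * hdim_v + d] += p_tile[(i - i0) * kTileN + (j - j0)] * v[j * hdim_v + d];
        }
    }
}
```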
## build
@@ -34,7 +25,7 @@

```
#> . example/ck_tile/07_hstu_attention/test_hstu_attention.sh
```

Check the example file `example_hstu_attention.cpp` for more information about the command-line arguments, which are defined as follows:

``` C++
arg_parser.insert("v", "1", "whether to do CPU validation or not")
@@ -48,7 +39,6 @@
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal to or bigger than the maximum of all uih seqlens")
.insert("targets", "16", "sequence length at the end of the query/key token sequence that should be excluded from attention")
.insert("max_target", "0", "max target, can be ignored, or else must be equal to or bigger than the maximum of all targets")
.insert("softmax", "0", "use softmax or not")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window mask, value 0 to disable")
.insert("context_len", "6", "sequence length at the beginning of the query sequence that should be included for attention")