
HSTU attention operator

The HSTU attention operator takes tensors q: [batches, seqlen, nhead, hdim_qk], k: [batches, seqlen, nhead, hdim_qk], v: [batches, seqlen, nhead, hdim_v], together with parameters that define the functional mask, and does the following:

  • Multiply q: [batches, seqlen, nhead, hdim_qk] with k: [batches, seqlen, nhead, hdim_qk] to get temporary tensor s: [batches, nhead, seqlen, seqlen]
  • Update s by filtering its values according to a special functional mask, which combines the logic of the lower-triangular and diagonal-window causal masks as well as the sequence mask
  • Apply element-wise SiLU on the lower seqlen dimension of s to get temporary tensor p: [batches, nhead, seqlen, seqlen]
  • Multiply p: [batches, nhead, seqlen, seqlen] with v: [batches, seqlen, nhead, hdim_v] to get the final output o: [batches, seqlen, nhead, hdim_v]
  • Jagged inputs are also supported, where each batch has its own seqlen defined by sequence_offsets[]
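
The steps above can be sketched as a small CPU reference for a single (batch, head) slice. This is an illustrative sketch, not the example's validation code: `simple_mask`, `hstu_attention_ref` and `jagged_slices` are hypothetical names, the mask covers only the causal and diagonal-window parts (context, target and min-full handling are omitted), and `sequence_offsets` is assumed to hold batches + 1 cumulative row offsets.

```python
import math

def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def simple_mask(i, j, causal=True, local_len=0):
    """Hypothetical simplified mask: may query row i attend to key column j?

    Only the lower-triangular and diagonal-window rules are modelled here.
    """
    if causal and j > i:
        return False
    if local_len > 0 and abs(i - j) > local_len:
        return False
    return True

def hstu_attention_ref(q, k, v, alpha=None, causal=True, local_len=0):
    """q, k: seqlen x hdim_qk nested lists; v: seqlen x hdim_v. Returns o."""
    seqlen, hdim_qk, hdim_v = len(q), len(q[0]), len(v[0])
    assert len(k) == seqlen and len(v) == seqlen
    if alpha is None:
        alpha = 1.0 / math.sqrt(hdim_qk)  # default scale, as for alpha=0 in the example
    o = [[0.0] * hdim_v for _ in range(seqlen)]
    for i in range(seqlen):
        for j in range(seqlen):
            if not simple_mask(i, j, causal, local_len):
                continue  # masked-out entries contribute nothing to o
            s_ij = alpha * sum(q[i][d] * k[j][d] for d in range(hdim_qk))
            p_ij = silu(s_ij)
            for d in range(hdim_v):
                o[i][d] += p_ij * v[j][d]
    return o

def jagged_slices(sequence_offsets):
    """Assumed layout: batch b owns rows [offsets[b], offsets[b + 1])."""
    return [(sequence_offsets[b], sequence_offsets[b + 1])
            for b in range(len(sequence_offsets) - 1)]
```

For jagged inputs, one would call `hstu_attention_ref` once per head on the row range returned by `jagged_slices` for each batch.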

implementation

The operator is implemented using a fused kernel in the example:

  • Tensors S and P exist only in VGPRs as per-workgroup tiles, so no global memory access is needed for them
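
Why tiling works here can be shown with a short sketch. Because SiLU is applied element-wise (there is no row-wise normalization as in softmax), each key/value tile contributes to the output independently, so S and P never need to exist as full matrices. The code below is only a CPU illustration under that assumption: `attend_full`, `attend_tiled`, and the tile sizes are hypothetical and do not reflect the kernel's actual tile shapes or pipeline, and a plain causal mask stands in for the functional mask.

```python
import math

def silu(x):
    return x / (1.0 + math.exp(-x))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attend_full(q, k, v, alpha):
    """Materialize full S and P, then compute O = P @ V (causal mask only)."""
    n = len(q)
    s = [[alpha * dot(q[i], k[j]) for j in range(n)] for i in range(n)]
    p = [[silu(s[i][j]) if j <= i else 0.0 for j in range(n)] for i in range(n)]
    return [[sum(p[i][j] * v[j][d] for j in range(n)) for d in range(len(v[0]))]
            for i in range(n)]

def attend_tiled(q, k, v, alpha, tile_m=2, tile_n=2):
    """Same result, but P exists only as a tile_m x tile_n temporary."""
    n, hdim_v = len(q), len(v[0])
    o = [[0.0] * hdim_v for _ in range(n)]
    for m0 in range(0, n, tile_m):          # one query tile per "workgroup"
        rows = range(m0, min(m0 + tile_m, n))
        for n0 in range(0, n, tile_n):      # loop over key/value tiles
            cols = range(n0, min(n0 + tile_n, n))
            # This tile of P lives only inside the current iteration.
            p_tile = [[silu(alpha * dot(q[i], k[j])) if j <= i else 0.0
                       for j in cols] for i in rows]
            for ti, i in enumerate(rows):
                for tj, j in enumerate(cols):
                    for d in range(hdim_v):
                        o[i][d] += p_tile[ti][tj] * v[j][d]
    return o
```

Both functions should agree up to floating-point rounding for any tile sizes, which is the property a fused kernel relies on.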

build

#> mkdir build
#> cd build
#> ../script/cmake-ck-dev.sh .. gfx942              # run "rocminfo | grep gfx" to check your GPU arch
#> make -j tile_example_hstu_attention

test/verify

#>  build/bin/tile_example_hstu_attention -v=1 -prec=bf16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=750,730,733,860,870,788,760,821,833,779 -targets=5,5,6,6,5,6,5,6,4,6 \
    -causal=1 -local_len=5 -context_len=6 -minfull_len=6
#>  . example/ck_tile/18_hstu_attention/test_hstu_attention.sh

See the example file example_hstu_attention.cpp for the full set of command-line arguments, which are defined as follows:

  arg_parser.insert("v", "1", "whether do CPU validation or not")
      .insert("prec", "fp16", "data type. fp16/bf16")
      .insert("jagged", "0", "q/k/v batched sequence is jagged or not")
      .insert("b", "12", "batch size")
      .insert("nhead", "4", "number of heads")
      .insert("hdim_qk", "64", "headdim size of Q/K")
      .insert("hdim_v", "64", "headdim size of V/O")
      .insert("seqlens", "400", "seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
      .insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
      .insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
      .insert("max_target", "0", "max target, can be ignored, or else must be equal or bigger than the maximum of all targets")
      .insert("causal", "1", "enable causal mask or not")
      .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
      .insert("context_len", "6", "sequence length at the beginning of the query sequence that should be included for attention")
      .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
      .insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
      .insert("seed", "13579", "seed used by the uniform or normal distribution generator")
      .insert("alpha", "0", "scale factor of P=Q@K. 0 means equal to 1/sqrt(hdim)")
      .insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes")
      .insert("perf", "0", "whether measure execution time or not")
      .insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");
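
The help text for seqlens says the allocated length of each batch also includes that batch's target and context_len. A minimal sketch of how the per-batch lengths and cumulative offsets could be derived is given below. It rests on assumptions: the allocated length is taken to be seqlen + target + context_len, a single value is broadcast to all batches, and `parse_int_list` and `physical_offsets` are hypothetical helpers, not functions from the example.

```python
from itertools import accumulate

def parse_int_list(text, batches):
    """Parse "a,b,c"; a single value is assumed to apply to every batch."""
    vals = [int(tok) for tok in text.split(",")]
    if len(vals) == 1:
        vals = vals * batches
    if len(vals) != batches:
        raise ValueError("expected 1 or %d values, got %d" % (batches, len(vals)))
    return vals

def physical_offsets(seqlens, targets, context_len):
    """Assumed layout: allocated length = seqlen + target + context_len."""
    lens = [s + t + context_len for s, t in zip(seqlens, targets)]
    return [0] + list(accumulate(lens))  # batches + 1 cumulative offsets

# Values taken from the test command above.
seqlens = parse_int_list("750,730,733,860,870,788,760,821,833,779", 10)
targets = parse_int_list("5,5,6,6,5,6,5,6,4,6", 10)
offsets = physical_offsets(seqlens, targets, context_len=6)
print(offsets[:3], offsets[-1])  # [0, 761, 1502] 8038
```

Under these assumptions the first batch occupies rows 0 to 760 of the jagged q/k/v tensors, and the total allocated length is 8038 rows.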