Files
composable_kernel/example/ck_tile/18_hstu_attention
..

HSTU attention operator

The HSTU attention operator takes the tensors q: [batches, seqlen, nhead, hdim_qk], k: [batches, seqlen, nhead, hdim_qk] and v: [batches, seqlen, nhead, hdim_v], together with several parameters that define a functional mask, and performs the following steps:

  • Multiply q: [batches, seqlen, nhead, hdim_qk] with k: [batches, seqlen, nhead, hdim_qk] to get the intermediate tensor s: [batches, nhead, seqlen, seqlen]
  • Update s by filtering it with a functional mask that combines a lower-triangular causal mask, a diagonal-window causal mask and a sequence mask
  • Apply SiLU element-wise on the lower seqlen dimension of s to get the intermediate tensor p: [batches, nhead, seqlen, seqlen]
  • Multiply p: [batches, nhead, seqlen, seqlen] with v: [batches, seqlen, nhead, hdim_v] to get the output tensor o: [batches, seqlen, nhead, hdim_v]

build

#> mkdir build
#> cd build
#> ../script/cmake-ck-dev.sh .. gfx942 -G Ninja              ; use `rocminfo | grep gfx` to check your GPU arch
#> ninja tile_example_hstu_attention                         ; or: make -j tile_example_hstu_attention

test/verify

#>  build/bin/tile_example_hstu_attention -v=1 -prec=bf16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=750,730,733,860,870,788,760,821,833,779 -targets=5,5,6,6,5,6,5,6,4,6 \
    -causal=1 -local_len=5 -context_len=6 -minfull_len=6
#>  . example/ck_tile/18_hstu_attention/test_hstu_attention.sh

See the example file example_hstu_attention.cpp for more information about the command-line arguments; the options it registers are:

  // Command-line options registered with the example's argument parser.
  // All .insert() calls form one chained expression, so the statement is
  // terminated by a single ';' after the last option.
  arg_parser.insert("v", "1", "weather do CPU validation or not")
      .insert("prec", "fp16", "data type. fp16/bf16")
      .insert("jagged", "0", "q/k/v batched sequence is jagged or not")
      .insert("b", "12", "batch size")
      .insert("nhead", "4", "number of heads")
      .insert("hdim_qk", "64", "headdim size of Q/K")
      .insert("hdim_v", "64", "headdim size of V/O")
      .insert("seqlens", "400", "seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
      .insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
      .insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
      .insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
      .insert("causal", "1", "enable causal mask or not")
      .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
      .insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
      .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
      .insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
      .insert("seed", "13579", "seed by the uniform or normal distribution generator")
      .insert("norm_dist", "0", "if true, initialize the data in normal distribution, or else in uniform distribution")
      .insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)")
      .insert("attn_scale", "0", "scale factor of SiLu(Q@K), 0 means using 1/max_seqlen for scaling")
      .insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes")
      .insert("perf", "0", "weather measure execution time or not")
      .insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");