From 295136e48bba10a0ea2a401e84bd57be6920d4b2 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Wed, 10 Jun 2026 14:42:49 +0000 Subject: [PATCH] Update the README.md according to the summary by claude code --- example/ck_tile/18_hstu_attention/README.md | 378 +++++++++++++++++--- 1 file changed, 326 insertions(+), 52 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/README.md b/example/ck_tile/18_hstu_attention/README.md index 98830e9e30..d7249ad9ca 100644 --- a/example/ck_tile/18_hstu_attention/README.md +++ b/example/ck_tile/18_hstu_attention/README.md @@ -1,61 +1,335 @@ -# HSTU attention operator +# HSTU Attention Forward Operator - HSTU-attention operator is an operator which takes tensor `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk`, - `v: [batches, seqlen, nhead, hdim_v]`, as well as several parameters that defines the functional masking to do the following: +## Overview - * Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_k]` to get tbe intermediate tensor `s: [batches, nhead, seqlen, seqlen]` - * Update `s` by filtering it with a functional mask that includes a lower-triangular causal mask, a diagonal window causal mask and a sequence mask - * Do element-wise SiLu on the `lower seqlen` dimension of `s` to get the intermediat tensor `p: [batches, nhead, seqlen, seqlen]` - * Multiply `p : [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get output tensor `o: [batches, seqlen_q, nhead, headsz_v]` +HSTU (Hierarchical Sequential Transduction Unit) attention is a custom attention variant designed +for recommendation-system workloads. Unlike standard softmax attention, HSTU uses **SiLU +(Sigmoid Linear Unit)** as the non-linearity instead of softmax, together with a composite +functional masking scheme. -## build +### Forward computation - ``` bash - #> mkdir build - #> cd build - #> ../script/cmake-ck-dev.sh .. gfx942 -G Ninja ; use #> rocminfo |grep "gfx" to check your gpu arch - #> ninja tile_example_hstu_attention ; or using make -j tile_example_hstu_attention ; - ``` +Given inputs: +- `q : [batch, seqlen_q, nhead, hdim_qk]` +- `k : [batch, seqlen_kv, nhead, hdim_qk]` +- `v : [batch, seqlen_kv, nhead, hdim_v]` -## test/verify +the forward pass performs: - ``` bash - #> build/bin/tile_example_hstu_attention -v=1 -prec=bf16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=750,730,733,860,870,788,760,821,833,779 -targets=5,5,6,6,5,6,5,6,4,6 - -causal=1 -local_len=5 -context_len=6 -minfull_len=6 - #> . example/ck_tile/07_hstu_attention/test_hstu_attention.sh - ``` +1. **QK projection** — `s[b, h, i, j] = scale_s * q[b, i, h, :] @ k[b, j, h, :]` +2. **Functional masking** — set masked positions of `s` to 0 (see [Masking](#masking) below) +3. **Non-linearity** — `p = SiLU(s)` (element-wise, over the `seqlen_kv` dimension); + alternatively `p = attn_scale * softmax(s)` when `--softmax=1` +4. **Output projection** — `o[b, i, h, :] = p[b, h, i, :] @ v[b, :, h, :]` - Check the example file `example_hstu_attention.cpp` for more information about the command-line arguments. +Output: `o : [batch, seqlen_q, nhead, hdim_v]` - ``` C++ - arg_parser.insert("v", "1", "weather do CPU validation or not") - .insert("g", "1", "num of attention group, bigger than 1 indicating group hstu") - .insert("prec", "fp16", "data type. fp16/bf16") - .insert("jagged", "0", "q/k/v batched sequence is jagged or not") - .insert("b", "12", "number of batches") - .insert("nhead", "4", "number of heads") - .insert("hdim_qk", "64", "headdim size of Q/K") - .insert("hdim_v", "64", "headdim size of V/O") - .insert("seqlens", "400", "uih seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len") - .insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens") - .insert("g_max_seqlens", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens") - .insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention") - .insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets") - .insert("causal", "1", "enable causal mask or not") - .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable") - .insert("g_local_lens", "5,", "list of all group's length of the diagonal window for enabling masking, value 0 to disable") - .insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention") - .insert("g_context_lens", "6,", "list of all group's sequence length at the begin of the query sequence that should be included for attention") - .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention") - .insert("g_minfull_lens", "6", "list of all groups's sequence length at the end of the query sequence that should be included for attention") - .insert("seed", "13579", "seed by the uniform or normal distribution generator") - .insert("norm_dist", "0", "if true, initialize the data in normal distribution, or else in uniform distribution") - .insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)") - .insert("attn_scale", "0", "scale factor of SiLU(Q@K). 0 means using 1/max_seqlen for scaling") - .insert("g_attn_scales", "1.0,", "list of all groups's scale factors of S=@@K. 0 means using 1/max_seqlen of the group for scaling") - .insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data") - .insert("save_mask", "0", "save the mask tensor to disk by the CPU validation codes") - .insert("perf", "0", "weather measure execution time or not") - .insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true"); - ``` +--- +## Masking + +The functional mask is a composite of up to four components, all expressed in terms of the +**UIH (User Interaction History)** positions, i.e. positions that are not targets and not +contextual: + +| Component | Parameter | Description | +|-----------|-----------|-------------| +| **Causal** | `--causal=1` | Standard lower-triangular causal mask | +| **Local (diagonal window)** | `--local_len=N` | Allow each query token to attend only to the `N` most recent keys (sliding-window causal) | +| **Contextual** | `--context_len=N` | The first `N` positions of the sequence are always visible to all queries | +| **Min-full attention** | `--minfull_len=N` | The last `N` UIH positions attend to the full UIH history (no local restriction) | +| **Targets** | `--targets=T[,T,...]` | Per-batch token counts at the *end* of the Q sequence that are excluded from attending | + +Physical sequence length is computed as: +``` +phy_seqlen_q = uih_seqlen + num_targets + context_len +phy_seqlen_kv = uih_seqlen + num_targets + context_len (self-attention) +phy_seqlen_kv = uih_seqlen_kv + context_len (cross-attention) +``` + +--- + +## Modes of operation + +### Batch modes + +| Mode | Flag | Description | +|------|------|-------------| +| **Batched** | `-jagged=0` | All sequences in the batch share the same `seqlen`; tensors are `[B, S, H, D]` | +| **Jagged** | `-jagged=1` | Each sequence has its own length given by `seq_offsets`; tensors are flattened `[1, total_tokens, H, D]` | + +### Attention variants + +| Variant | Flag | Description | +|---------|------|-------------| +| **No-group HSTU** | `-g=1` (default) | All batches share the same masking parameters (`local_len`, `context_len`, `minfull_len`, `attn_scale`) | +| **Group HSTU** | `-g=G` (G > 1) | `num_batch` must be a multiple of `G`; each group has its own per-group masking parameters passed via `--g_*` flags | +| **Self-attention** | (default) | `seqlen_kv` == `seqlen_q` | +| **Cross-attention** | `--seqlens_kv=...` | Enabled implicitly when KV sequence lengths differ from Q | + +### Non-linearity + +| Mode | Flag | Description | +|------|------|-------------| +| **SiLU** | `-softmax=0` (default) | Element-wise SiLU; `attn_scale` applied afterwards | +| **Softmax** | `-softmax=1` | Standard scaled softmax; LSE saved when `-training=1` | + +### Split-KV + +A split-KV kernel path (`hstu_attention_fwd_splitkv_kernel.hpp` / +`hstu_attention_batched_forward_splitkv_dispatch.hpp` etc.) is also available, which splits the +KV dimension across multiple work-groups and uses a separate combine pass +(`hstu_attention_fwd_splitkv_combine_kernel.hpp`) to merge partial results. + +--- + +## Data types + +- **fp16** (`half_t`) and **bf16** (`bfloat16_t`) for Q/K/V/O +- Accumulation and computation use `float32` +- Supported head dimensions (`hdim_qk` / `hdim_v`): **64, 96, 128, 256** + +--- + +## File structure + +``` +18_hstu_attention/ +├── CMakeLists.txt # Build configuration +├── example_hstu_attention_fwd.cpp # Driver / benchmark / validation harness +├── generate_instances.py # Python script to regenerate pre-compiled instances +│ +│-- Core kernel headers -- +├── hstu_attention_fwd_kernel.hpp # Main forward kernel +├── hstu_attention_fwd_splitkv_kernel.hpp # Split-KV forward kernel +├── hstu_attention_fwd_splitkv_combine_kernel.hpp # Split-KV combine kernel +├── hstu_attention_no_softmax_fwd_pipeline.hpp # SiLU pipeline (no softmax, batched/jagged) +├── hstu_attention_with_softmax_fwd_pipeline.hpp # Softmax pipeline +├── hstu_attention_no_softmax_fwd_trload_pipeline.hpp # SiLU pipeline (transposed load) +├── hstu_attention_with_softmax_fwd_trload_pipeline.hpp # Softmax pipeline (transposed load) +├── hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp +├── hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp +│ +│-- Policy / settings -- +├── hstu_attention_fwd_pipeline_policy.hpp # Pipeline policy selection +├── hstu_attention_fwd_splitkv_combine_pipeline_policy.hpp +├── hstu_attention_fwd_setting.hpp # Tile/warp sizes and dispatch settings +├── hstu_attention_fwd_splitkv_combine_setting.hpp +├── hstu_attention_fwd_type_config.hpp # Type aliases (CompDataType, GemmAccDataType) +├── hstu_attention_pipeline_problem.hpp # Problem descriptor passed to pipelines +├── hstu_attention_traits.hpp # Trait structs for padding/occupancy +│ +│-- Masking -- +├── hstu_block_masking.hpp # Block-level masking logic +│ # (HstuSelfAttentionBlockMaskWithLocal, +│ # HstuCrossAttentionBlockMaskWithLocal) +├── hstu_attention_epilogue.hpp # Epilogue (output write-back + scaling) +│ +│-- Dispatch layer -- +├── hstu_attention_params.hpp # HstuAttentionNoGroupFwdParams / +│ # HstuAttentionGroupFwdParams structs +├── hstu_attention_api.hpp # Public C API declarations +├── hstu_attention_batched_forward_dispatch.hpp +├── hstu_attention_batched_forward_splitkv_dispatch.hpp +├── hstu_attention_jagged_forward_dispatch.hpp +├── hstu_attention_jagged_forward_splitkv_dispatch.hpp +├── hstu_attention_group_forward_dispatch.hpp +├── hstu_attention_group_forward_splitkv_dispatch.hpp +├── hstu_attention_no_group_forward_fp16.cpp # fp16 entry points (no-group) +├── hstu_attention_no_group_forward_bf16.cpp # bf16 entry points (no-group) +├── hstu_attention_group_forward_fp16.cpp # fp16 entry points (group) +├── hstu_attention_group_forward_bf16.cpp # bf16 entry points (group) +│ +│-- Switch helpers -- +├── hstu_attention_bool_switch.hpp # BOOL_SWITCH macros +├── hstu_attention_hdim_switch.hpp # Head-dim dispatch switch +├── hstu_attention_max_splits_switch.hpp # Max-splits dispatch switch +├── hstu_attention_splitkv_helper.hpp # SplitKV helper utilities +├── hstu_attention_tile_setting_define.hpp # Tile-size macro definitions +│ +│-- Utility / host -- +├── hstu_attention_host_util.hpp # Host-side helpers (offset computation, etc.) +├── hstu_attention_kernel_util.hpp # Kernel-side helpers +├── reference_hstu_attention_fwd.hpp # CPU reference implementation for validation +│ +│-- Custom GEMM hacks -- +├── block_gemm_areg_bsmem_creg_v2_hack_0.hpp +├── block_gemm_areg_bsmem_creg_v2_hack_1.hpp +├── block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp +│ +├── instances/ # Pre-compiled kernel instances (auto-generated) +│ ├── hstu_attention_{batched,group,jagged}_forward_{fp16,bf16}_*.cpp +│ └── hstu_attention_{batched,group,jagged}_forward_{fp16,bf16}_instances_ref.hpp +│ +├── scripts/ # Test and benchmark shell scripts +│ ├── test_hstu_attention.sh +│ ├── test_hstu_softmax_attention.sh +│ ├── test_group_hstu_attention.sh +│ ├── test_group_hstu_softmax_attention.sh +│ ├── test_hstu_cross_attention.sh +│ ├── test_hstu_attention_hdim96_hdim64.sh +│ ├── test_hstu_softmax_attention_hdim96_hdim64.sh +│ ├── test_ck_hstu_mask.sh +│ ├── test_cross_attention_with_sparsity.sh +│ ├── test_jagged_causal_mattn0_full0.sh +│ ├── test_jagged_causal_mattn256_full0.sh +│ ├── test_jagged_causal_mattn256_full256.sh +│ ├── bench_batched_causal.sh +│ ├── bench_jagged_causal.sh +│ ├── bench_jagged_causal_local.sh +│ ├── bench_jagged_causal_mattn0_full0.sh +│ ├── bench_jagged_causal_mattn256_full256.sh +│ ├── bench_jagged_causal_mattn256_full256_sparsity_90.sh +│ ├── bench_cross_attention_with_sparsity.sh +│ └── benchmark_hstu_attention.sh +│ +├── test_pytorch_hstu_mask.py # PyTorch mask validation script +└── test_pytorch_hstu_mask_v2.py +``` + +--- + +## Build + +```bash +mkdir build +cd build +../script/cmake-ck-dev.sh .. gfx942 -G Ninja # use: rocminfo | grep "gfx" to check GPU arch +ninja tile_example_hstu_attention # or: make -j tile_example_hstu_attention +``` + +The build target is `tile_example_hstu_attention` (excluded from `make all` by default). + +**Optional compile-time flags:** + +| Environment variable | Effect | +|---------------------|--------| +| `ASSUME_HIGHLY_VARIED_SEQLEN=1` | Schedules batch dimension as a non-leading grid dimension (trades occupancy for better load balance when sequence lengths vary widely) | + +On `gfx950`-only builds (`-DBUILD_HSTU_FOR_GFX95_ONLY`), SLP vectorization is disabled to +improve pipeline performance. + +--- + +## Test / Verify + +### No-group HSTU (single set of masking parameters) + +```bash +# Jagged batches, fp16/bf16, causal + local + context + targets +build/bin/tile_example_hstu_attention \ + -v=1 -prec=bf16 -b=10 -jagged=1 \ + -nhead=4 -hdim_qk=128 -hdim_v=128 \ + -seqlens=750,730,733,860,870,788,760,821,833,779 \ + -targets=5,5,6,6,5,6,5,6,4,6 \ + -causal=1 -local_len=5 -context_len=6 -minfull_len=6 + +# Run the full standard test suite +. example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh + +# Softmax variant +. example/ck_tile/18_hstu_attention/scripts/test_hstu_softmax_attention.sh + +# Cross-attention +. example/ck_tile/18_hstu_attention/scripts/test_hstu_cross_attention.sh + +# Asymmetric head dims (hdim_qk=96, hdim_v=64) +. example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_hdim96_hdim64.sh +``` + +### Group HSTU (per-group masking parameters) + +```bash +# 3 groups, 18 batches (6 per group), each group has distinct local/context/minfull/scale params +build/bin/tile_example_hstu_attention \ + -v=1 -prec=bf16 -b=18 -g=3 \ + -nhead=4 -hdim_qk=128 -hdim_v=128 \ + -seqlens=300,300,290,280,310,308,312 \ + -causal=1 -targets=8 \ + -g_max_seqlens=310,312,312 \ + -g_local_lens=5,5,5 \ + -g_context_lens=8,8,8 \ + -g_minfull_lens=7,7,7 \ + -g_attn_scales=0.0,0.1,0.0 + +# Run the full group test suite +. example/ck_tile/18_hstu_attention/scripts/test_group_hstu_attention.sh +``` + +--- + +## Command-line arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `-v` | `1` | Run CPU validation (0 = disabled) | +| `-prec` | `fp16` | Data type: `fp16` or `bf16` | +| `-g` | `1` | Number of groups (>1 enables group-HSTU mode) | +| `-b` | `12` | Number of batches | +| `-jagged` | `0` | Jagged sequence mode (1 = enabled) | +| `-nhead` | `4` | Number of attention heads | +| `-hdim_qk` | `64` | Head dimension for Q and K | +| `-hdim_v` | `64` | Head dimension for V and O | +| `-seqlens` | `400` | UIH sequence length(s); comma-separated for jagged mode | +| `-seqlens_kv` | (same as Q) | KV UIH lengths; enables cross-attention when set | +| `-max_seqlen` | `0` | Override max UIH seqlen for Q (0 = auto) | +| `-targets` | (empty) | Per-batch target token counts appended to Q (and K/V for self-attn) | +| `-softmax` | `0` | Use softmax instead of SiLU (1 = enabled) | +| `-training` | `0` | Training mode; saves LSE when softmax is also enabled | +| `-causal` | `1` | Enable lower-triangular causal mask | +| `-local_len` | `5` | Diagonal window size (0 = no local mask) | +| `-context_len` | `6` | Contextual prefix length always visible to all queries | +| `-minfull_len` | `6` | Tail UIH length that receives full (non-local) attention | +| `-alpha` | `0` | QK scale factor (`0` = `1/sqrt(hdim_qk)`) | +| `-attn_scale` | `0` | Post-SiLU scale (`0` = `1/max_seqlen`) | +| `-seed` | `13579` | RNG seed for random initialization | +| `-norm_dist` | `0` | Use normal distribution for QKV init (0 = uniform) | +| `-init_qkv` | `0` | Load Q/K/V from binary files `q.dat`, `k.dat`, `v.dat` | +| `-perf` | `0` | Measure and report average execution time and TFLOPS | +| `-dump_output` | `0` | Dump device and reference outputs to binary files | +| `-save_mask` | `0` | Save the attention mask tensor to `ck_hstu_mask.dat` | + +### Group-HSTU-specific arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `-g_max_seqlens` | `0` | Per-group max UIH seqlens (comma-separated) | +| `-g_local_lens` | `5,` | Per-group local window sizes | +| `-g_context_lens` | `6,` | Per-group contextual prefix lengths | +| `-g_minfull_lens` | `6` | Per-group min-full-attention tail lengths | +| `-g_attn_scales` | `1.0,` | Per-group post-SiLU scale factors | + +--- + +## Benchmark + +```bash +# Batched causal +. example/ck_tile/18_hstu_attention/scripts/bench_batched_causal.sh + +# Jagged causal +. example/ck_tile/18_hstu_attention/scripts/bench_jagged_causal.sh + +# Jagged causal + local window +. example/ck_tile/18_hstu_attention/scripts/bench_jagged_causal_local.sh + +# With -perf=1 flag directly: +build/bin/tile_example_hstu_attention -v=0 -perf=1 -prec=bf16 \ + -b=32 -jagged=1 -nhead=8 -hdim_qk=128 -hdim_v=128 \ + -seqlens=512 -causal=1 -local_len=5 -context_len=8 -minfull_len=8 +``` + +Performance output reports average kernel execution time (ms) and estimated TFLOPS, counting +only the two GEMMs (QK and PV), ignoring masking, scaling, and SiLU overhead. + +--- + +## Regenerating kernel instances + +The `instances/` directory is auto-generated by `generate_instances.py`. To regenerate after +changing template parameters (dtypes, head dims, causal/softmax/bias/dropout combinations): + +```bash +cd example/ck_tile/18_hstu_attention +python3 generate_instances.py +```