mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-28 18:56:59 +00:00
Update the README.md according to the summary by claude code
@@ -1,61 +1,335 @@
|
||||
# HSTU attention operator
|
||||
# HSTU Attention Forward Operator
|
||||
|
||||
HSTU-attention operator is an operator which takes tensors `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk]`,
`v: [batches, seqlen, nhead, hdim_v]`, as well as several parameters that define the functional masking, to do the following:
|
||||
## Overview
|
||||
|
||||
* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_qk]` to get the intermediate tensor `s: [batches, nhead, seqlen, seqlen]`
* Update `s` by filtering it with a functional mask that includes a lower-triangular causal mask, a diagonal window causal mask and a sequence mask
* Do element-wise SiLU on the `lower seqlen` dimension of `s` to get the intermediate tensor `p: [batches, nhead, seqlen, seqlen]`
* Multiply `p: [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get the output tensor `o: [batches, seqlen, nhead, hdim_v]`
|
||||
HSTU (Hierarchical Sequential Transduction Unit) attention is a custom attention variant designed
for recommendation-system workloads. Unlike standard attention, HSTU applies
**SiLU (Sigmoid Linear Unit)** instead of softmax as the non-linearity, together with a composite
functional masking scheme.
|
||||
|
||||
## build
|
||||
### Forward computation
|
||||
|
||||
``` bash
#> mkdir build
#> cd build
#> ../script/cmake-ck-dev.sh .. gfx942 -G Ninja   # use: rocminfo | grep "gfx" to check your GPU arch
#> ninja tile_example_hstu_attention              # or: make -j tile_example_hstu_attention
```
|
||||
Given inputs:
- `q : [batch, seqlen_q, nhead, hdim_qk]`
- `k : [batch, seqlen_kv, nhead, hdim_qk]`
- `v : [batch, seqlen_kv, nhead, hdim_v]`
|
||||
|
||||
## test/verify
|
||||
the forward pass performs:
|
||||
|
||||
``` bash
#> build/bin/tile_example_hstu_attention -v=1 -prec=bf16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=750,730,733,860,870,788,760,821,833,779 -targets=5,5,6,6,5,6,5,6,4,6 \
     -causal=1 -local_len=5 -context_len=6 -minfull_len=6
#> . example/ck_tile/07_hstu_attention/test_hstu_attention.sh
```
|
||||
1. **QK projection**: `s[b, h, i, j] = scale_s * q[b, i, h, :] @ k[b, j, h, :]`
2. **Functional masking**: set masked positions of `s` to 0 (see [Masking](#masking) below)
3. **Non-linearity**: `p = SiLU(s)` (element-wise, over the `seqlen_kv` dimension);
   alternatively `p = attn_scale * softmax(s)` when `-softmax=1`
4. **Output projection**: `o[b, i, h, :] = p[b, h, i, :] @ v[b, :, h, :]`
|
||||
|
||||
Check the example file `example_hstu_attention.cpp` for more information about the command-line arguments.
|
||||
Output: `o : [batch, seqlen_q, nhead, hdim_v]`
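The four forward steps can be sketched as a plain CPU loop for one `(batch, head)` slice. This is only an illustrative sketch, not the kernel or the repository's reference implementation: the function name `hstu_forward_one_head`, the `Matrix` alias and the boolean `visible` mask layout are hypothetical, and multiplying by `attn_scale` after SiLU follows the description in the Non-linearity section. Masked positions are skipped because a masked `s` is 0 and `SiLU(0) = 0`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>; // row-major [rows][cols]

// SiLU(x) = x * sigmoid(x)
inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// Hypothetical reference for a single (batch, head) slice.
// q: [seqlen_q][hdim_qk], k: [seqlen_kv][hdim_qk], v: [seqlen_kv][hdim_v]
// visible[i][j] == true means query i may attend to key j.
Matrix hstu_forward_one_head(const Matrix& q,
                             const Matrix& k,
                             const Matrix& v,
                             const std::vector<std::vector<bool>>& visible,
                             float scale_s,
                             float attn_scale)
{
    const std::size_t seqlen_q  = q.size();
    const std::size_t seqlen_kv = k.size();
    const std::size_t hdim_qk   = q.empty() ? 0 : q[0].size();
    const std::size_t hdim_v    = v.empty() ? 0 : v[0].size();

    Matrix o(seqlen_q, std::vector<float>(hdim_v, 0.0f));
    for(std::size_t i = 0; i < seqlen_q; ++i)
    {
        for(std::size_t j = 0; j < seqlen_kv; ++j)
        {
            if(!visible[i][j])
                continue; // step 2: masked s is 0, and SiLU(0) == 0

            float s = 0.0f; // step 1: scaled dot product of q[i] and k[j]
            for(std::size_t d = 0; d < hdim_qk; ++d)
                s += q[i][d] * k[j][d];
            s *= scale_s;

            const float p = attn_scale * silu(s); // step 3 (scale after SiLU)

            for(std::size_t d = 0; d < hdim_v; ++d) // step 4: accumulate p * v[j]
                o[i][d] += p * v[j][d];
        }
    }
    return o;
}
```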
|
||||
|
||||
``` C++
arg_parser.insert("v", "1", "weather do CPU validation or not")
    .insert("g", "1", "num of attention group, bigger than 1 indicating group hstu")
    .insert("prec", "fp16", "data type. fp16/bf16")
    .insert("jagged", "0", "q/k/v batched sequence is jagged or not")
    .insert("b", "12", "number of batches")
    .insert("nhead", "4", "number of heads")
    .insert("hdim_qk", "64", "headdim size of Q/K")
    .insert("hdim_v", "64", "headdim size of V/O")
    .insert("seqlens", "400", "uih seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
    .insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
    .insert("g_max_seqlens", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
    .insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
    .insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
    .insert("causal", "1", "enable causal mask or not")
    .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
    .insert("g_local_lens", "5,", "list of all group's length of the diagonal window for enabling masking, value 0 to disable")
    .insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
    .insert("g_context_lens", "6,", "list of all group's sequence length at the begin of the query sequence that should be included for attention")
    .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
    .insert("g_minfull_lens", "6", "list of all groups's sequence length at the end of the query sequence that should be included for attention")
    .insert("seed", "13579", "seed by the uniform or normal distribution generator")
    .insert("norm_dist", "0", "if true, initialize the data in normal distribution, or else in uniform distribution")
    .insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)")
    .insert("attn_scale", "0", "scale factor of SiLU(Q@K). 0 means using 1/max_seqlen for scaling")
    .insert("g_attn_scales", "1.0,", "list of all groups's scale factors of S=@@K. 0 means using 1/max_seqlen of the group for scaling")
    .insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.data")
    .insert("save_mask", "0", "save the mask tensor to disk by the CPU validation codes")
    .insert("perf", "0", "weather measure execution time or not")
    .insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");
```

---

## Masking

The functional mask is a composite of the components listed below, all expressed in terms of the
**UIH (User Interaction History)** positions, i.e. positions that are neither targets nor
contextual:

| Component | Parameter | Description |
|-----------|-----------|-------------|
| **Causal** | `-causal=1` | Standard lower-triangular causal mask |
| **Local (diagonal window)** | `-local_len=N` | Allow each query token to attend only to the `N` most recent keys (sliding-window causal) |
| **Contextual** | `-context_len=N` | The first `N` positions of the sequence are always visible to all queries |
| **Min-full attention** | `-minfull_len=N` | The last `N` UIH positions attend to the full UIH history (no local restriction) |
| **Targets** | `-targets=T[,T,...]` | Per-batch token counts at the *end* of the Q sequence that are excluded from attending |

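As a rough illustration of how the causal, local and min-full components interact, the sketch below evaluates visibility for UIH positions only. It is a simplified reading of the table, not the logic in `hstu_block_masking.hpp`: the function name `uih_visible` is hypothetical, contextual and target positions are deliberately left out, and two details are assumptions made here, namely that the local window of `N` keys includes the query's own position and that a non-causal window is symmetric.

```cpp
#include <cassert>
#include <cstdlib>

// Hypothetical predicate: may UIH query position i attend to UIH key position j?
// i, j are UIH-relative indices in [0, uih_len).
// Assumptions (not confirmed by the source): the window of local_len keys
// counts the query's own position, and the non-causal window is symmetric.
// Contextual and target positions are not modeled here.
inline bool uih_visible(int i, int j, int uih_len, bool causal, int local_len, int minfull_len)
{
    if(causal && j > i)
        return false; // lower-triangular causal mask

    if(local_len > 0)
    {
        const int dist        = causal ? (i - j) : std::abs(i - j);
        const bool in_window  = dist < local_len;
        const bool is_full_row = i >= uih_len - minfull_len; // last minfull_len rows
        if(!in_window && !is_full_row)
            return false;
    }
    return true; // local_len == 0 disables the window restriction
}
```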
Physical sequence length is computed as:

```
phy_seqlen_q  = uih_seqlen + num_targets + context_len
phy_seqlen_kv = uih_seqlen + num_targets + context_len   (self-attention)
phy_seqlen_kv = uih_seqlen_kv + context_len              (cross-attention)
```

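The self-attention formula can be applied per batch to derive jagged offsets. The sketch below is illustrative only: the helper names are hypothetical, and the layout of the offsets array (an exclusive prefix sum with `num_batch + 1` entries, whose last entry is the total token count) is an assumed convention rather than something the source states about `seq_offsets`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// phy_seqlen = uih_seqlen + num_targets + context_len (self-attention case)
inline int phy_seqlen_self(int uih_seqlen, int num_targets, int context_len)
{
    return uih_seqlen + num_targets + context_len;
}

// Hypothetical helper: exclusive prefix sum of the physical lengths.
// offsets[b] is the first token of batch b; offsets.back() is total_tokens.
inline std::vector<std::int64_t> build_seq_offsets(const std::vector<int>& uih_seqlens,
                                                   const std::vector<int>& num_targets,
                                                   int context_len)
{
    assert(uih_seqlens.size() == num_targets.size());
    std::vector<std::int64_t> offsets(uih_seqlens.size() + 1, 0);
    for(std::size_t b = 0; b < uih_seqlens.size(); ++b)
        offsets[b + 1] =
            offsets[b] + phy_seqlen_self(uih_seqlens[b], num_targets[b], context_len);
    return offsets;
}
```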
---

## Modes of operation

### Batch modes

| Mode | Flag | Description |
|------|------|-------------|
| **Batched** | `-jagged=0` | All sequences in the batch share the same `seqlen`; tensors are `[B, S, H, D]` |
| **Jagged** | `-jagged=1` | Each sequence has its own length given by `seq_offsets`; tensors are flattened `[1, total_tokens, H, D]` |

### Attention variants

| Variant | Flag | Description |
|---------|------|-------------|
| **No-group HSTU** | `-g=1` (default) | All batches share the same masking parameters (`local_len`, `context_len`, `minfull_len`, `attn_scale`) |
| **Group HSTU** | `-g=G` (G > 1) | `num_batch` must be a multiple of `G`; each group has its own per-group masking parameters passed via `-g_*` flags |
| **Self-attention** | (default) | `seqlen_kv` == `seqlen_q` |
| **Cross-attention** | `-seqlens_kv=...` | Enabled implicitly when KV sequence lengths differ from Q |

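In group mode, each batch has to be mapped to the group whose parameters it uses. A minimal sketch of one such mapping follows, assuming batches are assigned to groups in contiguous, equally sized blocks (suggested by the "6 per group" example, but not confirmed as the actual dispatch rule). The `GroupParams` struct and `group_of_batch` are hypothetical names used only for illustration.

```cpp
#include <cassert>

// Hypothetical per-group parameter bundle mirroring the -g_* flags.
struct GroupParams
{
    int local_len;
    int context_len;
    int minfull_len;
    float attn_scale;
};

// Assumption: batches are split into num_groups contiguous blocks of equal size.
inline int group_of_batch(int batch_idx, int num_batch, int num_groups)
{
    assert(num_groups > 0 && num_batch % num_groups == 0);
    assert(batch_idx >= 0 && batch_idx < num_batch);
    return batch_idx / (num_batch / num_groups);
}
```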
### Non-linearity

| Mode | Flag | Description |
|------|------|-------------|
| **SiLU** | `-softmax=0` (default) | Element-wise SiLU; `attn_scale` applied afterwards |
| **Softmax** | `-softmax=1` | Standard scaled softmax; LSE saved when `-training=1` |

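For the softmax mode, the LSE (log-sum-exp) of a score row is the quantity that normalizes it. The sketch below shows the standard numerically stable computation for one row of `s`; the `SoftmaxRow` struct and function name are hypothetical, masking is not modeled, and this is not taken from the kernel's pipeline code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

struct SoftmaxRow
{
    std::vector<float> p; // softmax probabilities of the row
    float lse;            // log(sum_j exp(s_j)), the value saved for training
};

// Numerically stable softmax over one non-empty row of finite scores.
inline SoftmaxRow softmax_with_lse(const std::vector<float>& s)
{
    assert(!s.empty());
    const float m = *std::max_element(s.begin(), s.end());
    float sum = 0.0f;
    for(float x : s)
        sum += std::exp(x - m); // subtract the max to avoid overflow

    SoftmaxRow out;
    out.lse = m + std::log(sum);
    out.p.reserve(s.size());
    for(float x : s)
        out.p.push_back(std::exp(x - out.lse)); // p_j = exp(s_j - lse)
    return out;
}
```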
### Split-KV

A split-KV kernel path (`hstu_attention_fwd_splitkv_kernel.hpp` /
`hstu_attention_batched_forward_splitkv_dispatch.hpp` etc.) is also available, which splits the
KV dimension across multiple work-groups and uses a separate combine pass
(`hstu_attention_fwd_splitkv_combine_kernel.hpp`) to merge partial results.

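What a combine pass has to compute can be sketched on the host for a single output row. Because SiLU is element-wise, partial outputs from disjoint KV chunks simply add up; for softmax, each chunk's locally normalized output is reweighted by its LSE, as in the usual split-KV attention formulation. The function names below are hypothetical, and this sketch does not claim to reflect how the combine kernel is actually written.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<float>;

// SiLU path: o = sum over KV chunks of (p_chunk @ v_chunk).
inline Vec combine_silu(const std::vector<Vec>& partial_o)
{
    Vec o(partial_o.empty() ? 0 : partial_o[0].size(), 0.0f);
    for(const Vec& part : partial_o)
        for(std::size_t d = 0; d < o.size(); ++d)
            o[d] += part[d];
    return o;
}

// Softmax path: partial_o[i] is normalized within chunk i, partial_lse[i] is
// that chunk's log-sum-exp. Weight of chunk i is exp(lse_i - lse_total).
inline Vec combine_softmax(const std::vector<Vec>& partial_o, const Vec& partial_lse)
{
    assert(!partial_o.empty() && partial_o.size() == partial_lse.size());
    const float m = *std::max_element(partial_lse.begin(), partial_lse.end());
    float sum = 0.0f;
    for(float l : partial_lse)
        sum += std::exp(l - m);
    const float lse_total = m + std::log(sum);

    Vec o(partial_o[0].size(), 0.0f);
    for(std::size_t i = 0; i < partial_o.size(); ++i)
    {
        const float w = std::exp(partial_lse[i] - lse_total);
        for(std::size_t d = 0; d < o.size(); ++d)
            o[d] += w * partial_o[i][d];
    }
    return o;
}
```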
---

## Data types

- **fp16** (`half_t`) and **bf16** (`bfloat16_t`) for Q/K/V/O
- Accumulation and computation use `float32`
- Supported head dimensions (`hdim_qk` / `hdim_v`): **64, 96, 128, 256**

---

## File structure

```
18_hstu_attention/
├── CMakeLists.txt # Build configuration
├── example_hstu_attention_fwd.cpp # Driver / benchmark / validation harness
├── generate_instances.py # Python script to regenerate pre-compiled instances
│
│-- Core kernel headers --
├── hstu_attention_fwd_kernel.hpp # Main forward kernel
├── hstu_attention_fwd_splitkv_kernel.hpp # Split-KV forward kernel
├── hstu_attention_fwd_splitkv_combine_kernel.hpp # Split-KV combine kernel
├── hstu_attention_no_softmax_fwd_pipeline.hpp # SiLU pipeline (no softmax, batched/jagged)
├── hstu_attention_with_softmax_fwd_pipeline.hpp # Softmax pipeline
├── hstu_attention_no_softmax_fwd_trload_pipeline.hpp # SiLU pipeline (transposed load)
├── hstu_attention_with_softmax_fwd_trload_pipeline.hpp # Softmax pipeline (transposed load)
├── hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp
├── hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp
│
│-- Policy / settings --
├── hstu_attention_fwd_pipeline_policy.hpp # Pipeline policy selection
├── hstu_attention_fwd_splitkv_combine_pipeline_policy.hpp
├── hstu_attention_fwd_setting.hpp # Tile/warp sizes and dispatch settings
├── hstu_attention_fwd_splitkv_combine_setting.hpp
├── hstu_attention_fwd_type_config.hpp # Type aliases (CompDataType, GemmAccDataType)
├── hstu_attention_pipeline_problem.hpp # Problem descriptor passed to pipelines
├── hstu_attention_traits.hpp # Trait structs for padding/occupancy
│
│-- Masking --
├── hstu_block_masking.hpp # Block-level masking logic
│ # (HstuSelfAttentionBlockMaskWithLocal,
│ # HstuCrossAttentionBlockMaskWithLocal)
├── hstu_attention_epilogue.hpp # Epilogue (output write-back + scaling)
│
│-- Dispatch layer --
├── hstu_attention_params.hpp # HstuAttentionNoGroupFwdParams /
│ # HstuAttentionGroupFwdParams structs
├── hstu_attention_api.hpp # Public C API declarations
├── hstu_attention_batched_forward_dispatch.hpp
├── hstu_attention_batched_forward_splitkv_dispatch.hpp
├── hstu_attention_jagged_forward_dispatch.hpp
├── hstu_attention_jagged_forward_splitkv_dispatch.hpp
├── hstu_attention_group_forward_dispatch.hpp
├── hstu_attention_group_forward_splitkv_dispatch.hpp
├── hstu_attention_no_group_forward_fp16.cpp # fp16 entry points (no-group)
├── hstu_attention_no_group_forward_bf16.cpp # bf16 entry points (no-group)
├── hstu_attention_group_forward_fp16.cpp # fp16 entry points (group)
├── hstu_attention_group_forward_bf16.cpp # bf16 entry points (group)
│
│-- Switch helpers --
├── hstu_attention_bool_switch.hpp # BOOL_SWITCH macros
├── hstu_attention_hdim_switch.hpp # Head-dim dispatch switch
├── hstu_attention_max_splits_switch.hpp # Max-splits dispatch switch
├── hstu_attention_splitkv_helper.hpp # SplitKV helper utilities
├── hstu_attention_tile_setting_define.hpp # Tile-size macro definitions
│
│-- Utility / host --
├── hstu_attention_host_util.hpp # Host-side helpers (offset computation, etc.)
├── hstu_attention_kernel_util.hpp # Kernel-side helpers
├── reference_hstu_attention_fwd.hpp # CPU reference implementation for validation
│
│-- Custom GEMM hacks --
├── block_gemm_areg_bsmem_creg_v2_hack_0.hpp
├── block_gemm_areg_bsmem_creg_v2_hack_1.hpp
├── block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp
│
├── instances/ # Pre-compiled kernel instances (auto-generated)
│ ├── hstu_attention_{batched,group,jagged}_forward_{fp16,bf16}_*.cpp
│ └── hstu_attention_{batched,group,jagged}_forward_{fp16,bf16}_instances_ref.hpp
│
├── scripts/ # Test and benchmark shell scripts
│ ├── test_hstu_attention.sh
│ ├── test_hstu_softmax_attention.sh
│ ├── test_group_hstu_attention.sh
│ ├── test_group_hstu_softmax_attention.sh
│ ├── test_hstu_cross_attention.sh
│ ├── test_hstu_attention_hdim96_hdim64.sh
│ ├── test_hstu_softmax_attention_hdim96_hdim64.sh
│ ├── test_ck_hstu_mask.sh
│ ├── test_cross_attention_with_sparsity.sh
│ ├── test_jagged_causal_mattn0_full0.sh
│ ├── test_jagged_causal_mattn256_full0.sh
│ ├── test_jagged_causal_mattn256_full256.sh
│ ├── bench_batched_causal.sh
│ ├── bench_jagged_causal.sh
│ ├── bench_jagged_causal_local.sh
│ ├── bench_jagged_causal_mattn0_full0.sh
│ ├── bench_jagged_causal_mattn256_full256.sh
│ ├── bench_jagged_causal_mattn256_full256_sparsity_90.sh
│ ├── bench_cross_attention_with_sparsity.sh
│ └── benchmark_hstu_attention.sh
│
├── test_pytorch_hstu_mask.py # PyTorch mask validation script
└── test_pytorch_hstu_mask_v2.py
```

---

## Build

```bash
mkdir build
cd build
../script/cmake-ck-dev.sh .. gfx942 -G Ninja   # use: rocminfo | grep "gfx" to check GPU arch
ninja tile_example_hstu_attention              # or: make -j tile_example_hstu_attention
```

The build target is `tile_example_hstu_attention` (excluded from `make all` by default).

**Optional compile-time flags:**

| Environment variable | Effect |
|---------------------|--------|
| `ASSUME_HIGHLY_VARIED_SEQLEN=1` | Schedules batch dimension as a non-leading grid dimension (trades occupancy for better load balance when sequence lengths vary widely) |

On `gfx950`-only builds (`-DBUILD_HSTU_FOR_GFX95_ONLY`), SLP vectorization is disabled to
improve pipeline performance.

---

## Test / Verify

### No-group HSTU (single set of masking parameters)

```bash
# Jagged batches, fp16/bf16, causal + local + context + targets
build/bin/tile_example_hstu_attention \
    -v=1 -prec=bf16 -b=10 -jagged=1 \
    -nhead=4 -hdim_qk=128 -hdim_v=128 \
    -seqlens=750,730,733,860,870,788,760,821,833,779 \
    -targets=5,5,6,6,5,6,5,6,4,6 \
    -causal=1 -local_len=5 -context_len=6 -minfull_len=6

# Run the full standard test suite
. example/ck_tile/18_hstu_attention/scripts/test_hstu_attention.sh

# Softmax variant
. example/ck_tile/18_hstu_attention/scripts/test_hstu_softmax_attention.sh

# Cross-attention
. example/ck_tile/18_hstu_attention/scripts/test_hstu_cross_attention.sh

# Asymmetric head dims (hdim_qk=96, hdim_v=64)
. example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_hdim96_hdim64.sh
```

### Group HSTU (per-group masking parameters)

```bash
# 3 groups, 18 batches (6 per group), each group has distinct local/context/minfull/scale params
build/bin/tile_example_hstu_attention \
    -v=1 -prec=bf16 -b=18 -g=3 \
    -nhead=4 -hdim_qk=128 -hdim_v=128 \
    -seqlens=300,300,290,280,310,308,312 \
    -causal=1 -targets=8 \
    -g_max_seqlens=310,312,312 \
    -g_local_lens=5,5,5 \
    -g_context_lens=8,8,8 \
    -g_minfull_lens=7,7,7 \
    -g_attn_scales=0.0,0.1,0.0

# Run the full group test suite
. example/ck_tile/18_hstu_attention/scripts/test_group_hstu_attention.sh
```

---

## Command-line arguments

| Argument | Default | Description |
|----------|---------|-------------|
| `-v` | `1` | Run CPU validation (0 = disabled) |
| `-prec` | `fp16` | Data type: `fp16` or `bf16` |
| `-g` | `1` | Number of groups (>1 enables group-HSTU mode) |
| `-b` | `12` | Number of batches |
| `-jagged` | `0` | Jagged sequence mode (1 = enabled) |
| `-nhead` | `4` | Number of attention heads |
| `-hdim_qk` | `64` | Head dimension for Q and K |
| `-hdim_v` | `64` | Head dimension for V and O |
| `-seqlens` | `400` | UIH sequence length(s); comma-separated for jagged mode |
| `-seqlens_kv` | (same as Q) | KV UIH lengths; enables cross-attention when set |
| `-max_seqlen` | `0` | Override max UIH seqlen for Q (0 = auto) |
| `-targets` | (empty) | Per-batch target token counts appended to Q (and K/V for self-attn) |
| `-softmax` | `0` | Use softmax instead of SiLU (1 = enabled) |
| `-training` | `0` | Training mode; saves LSE when softmax is also enabled |
| `-causal` | `1` | Enable lower-triangular causal mask |
| `-local_len` | `5` | Diagonal window size (0 = no local mask) |
| `-context_len` | `6` | Contextual prefix length always visible to all queries |
| `-minfull_len` | `6` | Tail UIH length that receives full (non-local) attention |
| `-alpha` | `0` | QK scale factor (`0` = `1/sqrt(hdim_qk)`) |
| `-attn_scale` | `0` | Post-SiLU scale (`0` = `1/max_seqlen`) |
| `-seed` | `13579` | RNG seed for random initialization |
| `-norm_dist` | `0` | Use normal distribution for QKV init (0 = uniform) |
| `-init_qkv` | `0` | Load Q/K/V from binary files `q.dat`, `k.dat`, `v.dat` |
| `-perf` | `0` | Measure and report average execution time and TFLOPS |
| `-dump_output` | `0` | Dump device and reference outputs to binary files |
| `-save_mask` | `0` | Save the attention mask tensor to `ck_hstu_mask.dat` |

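The `0` defaults of `-alpha` and `-attn_scale` act as sentinels for derived values. A minimal sketch of that resolution, following the table's descriptions, is shown below. The function names are hypothetical, and which `max_seqlen` the driver uses (UIH or physical length) is not specified here, so the caller is assumed to pass the appropriate one.

```cpp
#include <cassert>
#include <cmath>

// alpha == 0 selects the default QK scale 1/sqrt(hdim_qk).
inline float resolve_alpha(float alpha, int hdim_qk)
{
    return alpha == 0.0f ? 1.0f / std::sqrt(static_cast<float>(hdim_qk)) : alpha;
}

// attn_scale == 0 selects the default post-SiLU scale 1/max_seqlen.
// Assumption: max_seqlen is whatever maximum length the driver resolves.
inline float resolve_attn_scale(float attn_scale, int max_seqlen)
{
    return attn_scale == 0.0f ? 1.0f / static_cast<float>(max_seqlen) : attn_scale;
}
```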
### Group-HSTU-specific arguments

| Argument | Default | Description |
|----------|---------|-------------|
| `-g_max_seqlens` | `0` | Per-group max UIH seqlens (comma-separated) |
| `-g_local_lens` | `5,` | Per-group local window sizes |
| `-g_context_lens` | `6,` | Per-group contextual prefix lengths |
| `-g_minfull_lens` | `6` | Per-group min-full-attention tail lengths |
| `-g_attn_scales` | `1.0,` | Per-group post-SiLU scale factors |

---

## Benchmark

```bash
# Batched causal
. example/ck_tile/18_hstu_attention/scripts/bench_batched_causal.sh

# Jagged causal
. example/ck_tile/18_hstu_attention/scripts/bench_jagged_causal.sh

# Jagged causal + local window
. example/ck_tile/18_hstu_attention/scripts/bench_jagged_causal_local.sh

# With -perf=1 flag directly:
build/bin/tile_example_hstu_attention -v=0 -perf=1 -prec=bf16 \
    -b=32 -jagged=1 -nhead=8 -hdim_qk=128 -hdim_v=128 \
    -seqlens=512 -causal=1 -local_len=5 -context_len=8 -minfull_len=8
```

Performance output reports average kernel execution time (ms) and estimated TFLOPS, counting
only the two GEMMs (QK and PV), ignoring masking, scaling, and SiLU overhead.

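A two-GEMM FLOP count of this kind can be written down directly. The sketch below assumes a dense count over all `seqlen_q x seqlen_kv` pairs with equal lengths across batches; whether the driver discounts masked-out positions or sums per-sequence lengths in jagged mode is not stated here, and the function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// FLOPs of the QK GEMM (2 * Sq * Skv * hdim_qk) plus the PV GEMM
// (2 * Sq * Skv * hdim_v), per head and per batch. Assumes a dense count.
inline double two_gemm_flops(std::int64_t batch,
                             std::int64_t nhead,
                             std::int64_t seqlen_q,
                             std::int64_t seqlen_kv,
                             std::int64_t hdim_qk,
                             std::int64_t hdim_v)
{
    return 2.0 * static_cast<double>(batch) * static_cast<double>(nhead) *
           static_cast<double>(seqlen_q) * static_cast<double>(seqlen_kv) *
           static_cast<double>(hdim_qk + hdim_v);
}

// Convert a FLOP count and an average time in milliseconds to TFLOPS.
inline double to_tflops(double flops, double avg_ms)
{
    return flops / (avg_ms * 1e-3) / 1e12;
}
```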
---

## Regenerating kernel instances

The `instances/` directory is auto-generated by `generate_instances.py`. To regenerate after
changing template parameters (dtypes, head dims, causal/softmax/bias/dropout combinations):

```bash
cd example/ck_tile/18_hstu_attention
python3 generate_instances.py
```