composable_kernel/example/ck_tile/01_fmha
Yi DING 0398b864c3 [CK_TILE] Use Unified Workspace for FMHA BWD (#6152)
## Motivation
`dq_acc` is the intermediate accumulation buffer used in FMHA backward
pass for deterministic mode. The current implementation allocates it as
a **single rectangular tensor**:

```
shape = [shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q]
```

where `nsplits = launcher.dq_acc_splits` (a single scalar), computed
from `max_seqlen_k` and shared across all batches.

### Problems

1. **Memory waste**: In group mode, each batch may have a different
`seqlen_k`, but `nsplits` is computed from `max_seqlen_k`, causing
batches with shorter `seqlen_k` to over-allocate in the split dimension.

2. **Interface coupling**: `fmha_bwd_args` exposes internal layout
details such as `stride_dq_acc`, `nhead_stride_dq_acc`,
`batch_stride_dq_acc`, and `split_stride_dq_acc`. The caller is
responsible for computing these strides, but this logic belongs inside
the kernel.

### Goals

1. Switch the `dq_acc` buffer to a **compact layout**: batches are
concatenated contiguously, with each batch occupying
`nhead * nsplits_i * seqq_i * hdim_q` elements (nhead outermost).
2. **Remove all `*_stride_dq_acc` fields** from `fmha_bwd_args`,
replacing them with a single `workspace_ptr`; the kernel splits this
internally using a fixed layout.
3. `fmha_bwd_launcher` provides a **workspace management interface**:
the caller only needs to allocate GPU memory and call
`prepare_workspace()`; no layout computation is required.
4. **Isolate kernel internals from the caller API**: the `dq_acc` layout
(nsplits, strides, buffer size) is determined entirely inside the
launcher/kernel. Future changes to block shape, pipeline type, or
persistent kernel strategy require no modifications to the caller's
`fmha_bwd_args` or workspace allocation logic.

## Technical Details

### Interface Design

#### New fields in `fmha_bwd_traits`

```cpp
struct fmha_bwd_traits
{
    int seqlen_q;
    int seqlen_k;
    int batch;
    int max_seqlen_q;
    int max_seqlen_k;
    int hdim_q;
    int hdim_v;
    int nhead_q;
    int nhead_k;
    std::string data_type;
    bool is_group_mode;
    mask_enum mask_type;
    bias_enum bias_type;
    bool has_dbias;
    bool has_dropout;
    bool is_store_randval;
    bool is_deterministic;
    // New: cumulative physical seqlen pointers for group mode (pass nullptr for batch mode).
    // seqstart_qs[i+1] - seqstart_qs[i] = physical seqlen_q of batch i (including padding); length = batch+1
    // seqstart_ks[i+1] - seqstart_ks[i] = physical seqlen_k of batch i (including padding); length = batch+1
    const int* seqstart_qs = nullptr;
    const int* seqstart_ks = nullptr;
};
```

#### `fmha_bwd_launcher` actual structure

```cpp
struct fmha_bwd_launcher
{
    std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};

    // Total workspace size in bytes (host_ws_size + device_ws_size), computed by init().
    // Zero for kUseQrQtrDorPipeline (writes dq directly, no acc buffer needed).
    size_t workspace_size = 0;

    fmha_bwd_launcher(const fmha_bwd_traits&);

    // Copies auxiliary data (nsplits[], offsets[]) via hipMemcpy to the head of the GPU workspace,
    // and zeros the dq_acc buffer portion (tail of workspace) if required.
    // The memory pointed to by device_ws must be >= workspace_size bytes.
    std::function<void(void* device_ws)> prepare_workspace{};

    template <typename... Args>
    float operator()(Args&&... args) const { return run(std::forward<Args>(args)...); }

private:
    size_t host_ws_size   = 0;  // CPU workspace size (nsplits[] + offsets[] arrays)
    size_t device_ws_size = 0;  // GPU-only data size (dq_acc buffer)
    std::unique_ptr<char[]> ws_host;  // host-side workspace buffer

public:
    template <typename T0, typename T1, typename T2, typename Arch>
    void init(const fmha_bwd_traits& traits);
};
```

The `init<>()` template method (invoked by codegen dispatch branches as
`this->init<...>(t)`) is responsible for:
1. Setting the `run` lambda
2. Calling `FmhaBwdDQDKDVKernel::GetWorkspaceHostSize(batch)` to obtain
`host_ws_size`
3. Allocating `ws_host` (host memory)
4. Calling `FmhaBwdDQDKDVKernel::PrepareWorkspaceHost(ws_host.get(),
...)` to fill nsplits/offsets; return value is `device_ws_size`
5. `workspace_size = host_ws_size + device_ws_size`
6. Setting the `prepare_workspace` lambda (captures `this`, calls
`PrepareWorkspaceDevice`)

When no kernel matches the given traits, both `run` and
`prepare_workspace` are initialized to default lambdas that print a
warning to `std::cerr` and return gracefully (no exception).
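
The fallback behaviour can be pictured with a minimal sketch. The struct name, the simplified `run` signature (no arguments) and the `-1.0f` return value are assumptions made for illustration; the source only states that the defaults warn on `std::cerr` and do not throw.

```cpp
#include <functional>
#include <iostream>

// Hypothetical, simplified stand-in for the launcher's two callable members.
struct launcher_fallback_sketch
{
    // Default used when no kernel matches: warn and return a sentinel.
    // The sentinel value -1.0f is an assumption, not taken from the source.
    std::function<float()> run = [] {
        std::cerr << "fmha_bwd: no kernel matches the given traits\n";
        return -1.0f;
    };

    // Default used when no kernel matches: warn and leave the buffer untouched.
    std::function<void(void*)> prepare_workspace = [](void*) {
        std::cerr << "fmha_bwd: no kernel matches; workspace not prepared\n";
    };
};
```

Because both members are always callable, a caller can invoke them unconditionally and inspect the result instead of guarding against an empty `std::function`.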

#### Workspace overall layout

The workspace is managed by `FmhaBwdWorkspaceManager` and consists of
two segments:

```
Offset 0 (CPU-prepared segment, host_ws_size bytes; also hipMemcpy'd to the head of GPU workspace):
  index_t nsplits[batch or 1]       — per-batch nsplits array
                                      group mode: batch elements
                                      batch mode / non-deterministic: 1 element
  [group mode only] long_index_t dq_acc_offsets[batch+1]
                                      per-batch element offset (exclusive prefix sum; the last entry is the total)
                                      offsets[0]=0, offsets[i+1] = offsets[i] + nhead*nsplits_i*seqq_i*hdim_q

Offset host_ws_size (device data segment, device_ws_size bytes):
  AccDataType dq_acc[total_elements] — compact dq_acc buffer (zeroed if required)
                                       total_elements = sum_i(nhead * nsplits_i * seqq_i * hdim_q)
                                       layout within each batch: [nhead, nsplits_i, seqq_i, hdim_q]
                                       note: seqq_i uses the physical length (including padding)
```
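
As a rough illustration of the offsets recurrence above, the following host-side sketch builds the `dq_acc_offsets` array. The function name and the use of `std::vector` are assumptions; the actual `PrepareWorkspaceHost` writes into a raw host buffer.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Builds dq_acc_offsets[batch+1] (in elements, not bytes) from the cumulative
// physical seqlen_q array and the per-batch split counts.
// Precondition: seqstart_qs.size() == nsplits.size() + 1.
std::vector<std::int64_t> make_dq_acc_offsets(const std::vector<int>& seqstart_qs,
                                              const std::vector<int>& nsplits,
                                              int nhead,
                                              int hdim_q)
{
    std::vector<std::int64_t> offsets(nsplits.size() + 1, 0);
    for(std::size_t i = 0; i < nsplits.size(); ++i)
    {
        // Physical length of batch i, including padding.
        const std::int64_t seqq_i = seqstart_qs[i + 1] - seqstart_qs[i];
        offsets[i + 1] =
            offsets[i] + static_cast<std::int64_t>(nhead) * nsplits[i] * seqq_i * hdim_q;
    }
    return offsets; // offsets.back() equals total_elements
}
```

For example, with `seqstart_qs = {0, 4, 10}`, `nsplits = {2, 1}`, `nhead = 2` and `hdim_q = 8`, the two batches occupy 128 and 96 elements, giving offsets `{0, 128, 224}`.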

Alignment constant (`ALIGNMENT = 16`):
```
nsplits_size  = align_up(sizeof(index_t) * N, 16)          // N = batch (group) or 1 (batch/non-det)
offsets_size  = align_up(sizeof(long_index_t) * (batch+1), 16)  // group mode only
host_ws_size  = nsplits_size + offsets_size
dq_acc_offset = host_ws_size  // GetDqAccDataOffset(batch)
```
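
The size formulas above can be checked with a small sketch. It assumes a 4-byte `index_t` and an 8-byte `long_index_t`, and covers the deterministic case, where group mode stores `batch` split counts; the free-function names are hypothetical and only mirror the `FmhaBwdWorkspaceManager` static methods.

```cpp
#include <cstddef>
#include <cstdint>

// Assumed stand-ins for ck_tile::index_t / ck_tile::long_index_t.
using index_t      = std::int32_t;
using long_index_t = std::int64_t;

constexpr std::size_t kAlignment = 16;

// Rounds n up to the next multiple of a (a > 0).
constexpr std::size_t align_up(std::size_t n, std::size_t a) { return (n + a - 1) / a * a; }

// Bytes for nsplits[]: `batch` entries in group mode, one entry otherwise.
constexpr std::size_t splits_size(int batch, bool group_mode)
{
    return align_up(sizeof(index_t) * (group_mode ? static_cast<std::size_t>(batch) : 1),
                    kAlignment);
}

// Bytes for dq_acc_offsets[]: batch+1 entries, present in group mode only.
constexpr std::size_t offsets_size(int batch, bool group_mode)
{
    return group_mode
               ? align_up(sizeof(long_index_t) * (static_cast<std::size_t>(batch) + 1), kAlignment)
               : 0;
}

// host_ws_size, which is also the byte offset at which the dq_acc data starts.
constexpr std::size_t host_ws_size(int batch, bool group_mode)
{
    return splits_size(batch, group_mode) + offsets_size(batch, group_mode);
}
```

With `batch = 3` in group mode this gives 16 bytes for the split counts plus 32 bytes for the offsets, so the `dq_acc` data would start at byte 48 under these assumptions.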

**Key benefits**:
- The kernel reads nsplits/offsets directly from the workspace head — no
device-side recomputation.
- `FmhaBwdConvertQGradKernel` is completely decoupled from the pipeline
block shape (`kN0`): nsplits is read from `nsplits_ptr`, `kN0` is no
longer a template parameter, and multiple dq_dk_dv tiles with different
`F_bn0` values now share a single convert_dq kernel instance (under
receipt 1/2, deterministic convert_dq kernel count drops from ~300 to
60).
- nsplits/offsets are computed on the host and transferred in one
`hipMemcpy`; the dq_acc buffer follows immediately, at the offset given
by `GetDqAccDataOffset`.

#### Workspace size by scenario

| Scenario | `workspace_size` | Notes |
|----------|-----------------|-------|
| **kUseQrQtrDorPipeline** (any mode) | `0` | Writes dq directly; no acc buffer; `PrepareWorkspaceHost` returns 0 |
| **Non-deterministic + batch mode** | `> 0` | nsplits[1]=1; dq_acc used for atomic add; `workspace_size = host_ws_size + batch*nhead*seqlen_q*hdim_q*ebytes` |
| **Non-deterministic + group mode** | `> 0` | nsplits[1]=1; dq_acc contiguous layout; `workspace_size = host_ws_size + nhead*seqstart_qs[batch]*hdim_q*ebytes` |
| **Deterministic + group mode** | `> 0` | nsplits[batch], offsets[batch+1], compact dq_acc; nsplits_i computed independently per batch |
| **Deterministic + batch mode persistent** | `> 0` | nsplits[1] (uniform across batches); dq_acc `batch*nhead*nsplits*seqlen_q*hdim_q` |

**NeedsZeroDqAcc** (determines whether `PrepareWorkspaceDevice` calls
`hipMemset`):
- Persistent kernel (deterministic batch mode) or non-deterministic:
**must zero** (atomic add requires zero initialization)
- Deterministic group mode + no mask: **no zeroing needed** (every tile
writes its full region)
- Deterministic + with mask: **must zero** (some blocks are skipped,
leaving uninitialized tiles that would contribute to the reduction)
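
The zeroing rules above amount to a small predicate. This is a sketch of that decision table only; the function name and the runtime `bool` parameters are assumptions (the real code selects on template parameters such as `kUseQrQtrDorPipeline` and `kHasMask`).

```cpp
// Returns true when the dq_acc region must be zeroed before the kernel runs.
constexpr bool needs_zero_dq_acc(bool use_qr_qtr_dor,
                                 bool deterministic,
                                 bool group_mode,
                                 bool has_mask)
{
    if(use_qr_qtr_dor)
        return false; // dq is written directly; there is no dq_acc buffer
    if(!deterministic)
        return true; // atomic add needs a zero-initialized buffer
    if(!group_mode)
        return true; // persistent kernel (deterministic batch mode)
    return has_mask; // masked runs skip blocks, leaving unwritten tiles
}
```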

#### Caller usage

```cpp
// 1. Create launcher (traits include seqstart_qs/ks pointers; workspace_size is computed during construction)
fmha_bwd_launcher launcher(fmha_traits);

// 2. Read launcher.workspace_size directly
const auto ws_size = launcher.workspace_size;

// 3. Allocate a single GPU workspace
ck_tile::DeviceMem ws_buf(ws_size);

// 4. Copy nsplits/offsets to GPU head and zero dq_acc if required
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());

// 5. Build args with a single workspace pointer; the kernel splits it internally
fmha_bwd_args args{
    ...,
    ws_size > 0 ? ws_buf.GetDeviceBuffer() : nullptr,  // workspace_ptr
};
launcher(args, stream_config);
```

---

### Key Code Structure

#### FmhaBwdWorkspaceManager (`fmha_bwd_kernel.hpp`, new class)

```cpp
template <typename AccDataType, bool kIsGroupMode, bool kIsDeterministic>
struct FmhaBwdWorkspaceManager
{
    static constexpr size_t ALIGNMENT = 16;

    // CPU workspace (nsplits + offsets) sizes
    static size_t GetDqAccSplitsSize(int batch);   // align_up(sizeof(index_t)*N, 16)
    static size_t GetDqAccOffsetsSize(int batch);  // group mode only: align_up(sizeof(long_index_t)*(batch+1), 16)
    static size_t GetWorkspaceHostSize(int batch);  // = SplitsSize + OffsetsSize

    // Starting offset of dq_acc data within the full workspace (= host_ws_size)
    static size_t GetDqAccDataOffset(int batch);   // = GetWorkspaceHostSize(batch)

    // Fills nsplits/offsets in the CPU workspace; returns device_ws_size (dq_acc buffer bytes)
    template <bool kUseQrQtrDorPipeline, index_t kN0>
    static size_t PrepareWorkspaceHost(void* cpu_ws, index_t batch_size, index_t hdim_q,
                                       index_t nhead_q, index_t seqlen_q, index_t seqlen_k,
                                       const index_t* seqstart_qs, const index_t* seqstart_ks);

    // hipMemcpy's cpu_ws to device_ws head; hipMemset's the dq_acc portion to 0 if required
    template <bool kUseQrQtrDorPipeline, bool kHasMask>
    static void PrepareWorkspaceDevice(void* device_ws, const void* host_ws,
                                       size_t device_ws_size, size_t host_ws_size);
};
```

#### workspace_ptr parsing (inside the kernel)

The kernel parses three address regions from `kargs.workspace_ptr`:

**Group mode (`FmhaBwdDQDKDVKernel::MakeKargs`)**:
```cpp
const uint8_t* ws = reinterpret_cast<uint8_t*>(workspace_ptr);
// dq_acc_ptr (stored in FmhaBwdCommonKargs)
ws + WorkspaceManager::GetDqAccDataOffset(batch)
// dq_acc_batch_offset_ptr (FmhaBwdGroupModeKargs field)
reinterpret_cast<const long_index_t*>(ws + WorkspaceManager::GetDqAccOffsetsOffset(batch))
```

**Batch mode**:
```cpp
ws + WorkspaceManager::GetDqAccDataOffset(batch)  // dq_acc_ptr
// No offsets pointer; batch offset is computed inside run_() from nsplits
```

**`FmhaBwdConvertQGradKernel`** follows the same pattern:
- Group mode: extracts `dq_acc_ptr`, `dq_acc_batch_offset_ptr`, and
`nsplits_ptr` (`GetDqAccSplitsOffset(batch)`) from workspace
- Batch mode: reads nsplits from `nsplits_ptr[0]`; batch offset computed
internally

#### Addressing in `run_()` (group mode)

```cpp
// Per-batch processing:
const long_index_t batch_offset_dq_acc = kargs.dq_acc_batch_offset_ptr[i_batch];
// seqq_i (physical length) derived from seqstart_q_ptr
const index_t seqq_i = kargs.seqstart_q_ptr[i_batch+1] - kargs.seqstart_q_ptr[i_batch];
// nsplits_i read from nsplits_ptr (convert_dq kernel) or from GetDqAccSplits
const long_index_t split_stride_i = static_cast<long_index_t>(seqq_i) * kargs.hdim_q;
const long_index_t nhead_stride_i = static_cast<long_index_t>(nsplits_i) * split_stride_i;
// Final address:
dq_acc_base + batch_offset_dq_acc + i_nhead * nhead_stride_i + i_split * split_stride_i
```
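
The address arithmetic above can be written as a standalone helper, which makes it easy to check by hand. The function name is hypothetical and the result is an element offset relative to `dq_acc_base`.

```cpp
#include <cstdint>

// Element offset of the (i_nhead, i_split) slice of one batch inside the
// compact dq_acc buffer, following the [nhead, nsplits_i, seqq_i, hdim_q] layout.
constexpr std::int64_t dq_acc_slice_offset(std::int64_t batch_offset,
                                           int seqq_i,
                                           int hdim_q,
                                           int nsplits_i,
                                           int i_nhead,
                                           int i_split)
{
    const std::int64_t split_stride = static_cast<std::int64_t>(seqq_i) * hdim_q;
    const std::int64_t nhead_stride = static_cast<std::int64_t>(nsplits_i) * split_stride;
    return batch_offset + i_nhead * nhead_stride + i_split * split_stride;
}
```

For a batch starting at offset 0 with `seqq_i = 4`, `hdim_q = 8` and `nsplits_i = 2`, the slice for head 1, split 1 starts at element `1*64 + 1*32 = 96`.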

#### nsplits computation (`PrepareWorkspaceHost`)

`PrepareWorkspaceHost` is a template method of `FmhaBwdWorkspaceManager`
that still takes `kN0` as a template parameter (from
`BlockFmhaShape::kN0` of the dq_dk_dv pipeline). However, this parameter
is **only used inside this host-side function** to compute nsplits — it
is no longer passed into the convert_dq kernel.

| Mode | nsplits computation |
|------|---------------------|
| kUseQrQtrDorPipeline | Writes dq directly; nsplits[0]=0; returns device_ws_size=0 |
| Non-deterministic | nsplits[0]=1; dq_acc used for atomic add |
| Deterministic + group mode | `ceil((seqstart_ks[i+1]-seqstart_ks[i]) / kN0)` computed per batch |
| Deterministic + batch mode persistent | Same logic as the original `GetDqAccSplits` (`dqdqkdv_workers` based) |
### Removing kN0 dependency from `FmhaBwdConvertQGradKernel`

`FmhaBwdConvertQGradKernel` previously required `kN0` as a template
parameter (via `BlockFmhaBwdConvertQGradPipelineProblem`) for two
purposes:
1. In batch mode `operator()`: self-computing `nsplits = ceil(seqlen_k /
kN0)`
2. The `b{kM0}x{kN0}` component of the kernel name string

Both have been removed in this refactor:
- **Batch mode**: now reads `kargs.nsplits_ptr[0]` directly (guarded by
`if constexpr(kIsDeterministic)` to avoid accessing a non-existent field
in non-deterministic instances)
- **Kernel name**: simplified to `b{kM0}`, no longer includes `kN0`
- **Template parameters**: `BlockFmhaBwdConvertQGradPipelineProblem`
drops the `kN0_` parameter; `fmha_bwd_convert_dq_traits_` drops the
`kN0` parameter; `F_bn0`/`convert_dq_bn0` fields removed from codegen

Effect: all dq_dk_dv tiles sharing the same `(hdim, dtype, mode, pad,
deterministic)` combination — regardless of `F_bn0` value
(16/64/128/192/256) — now share a **single** convert_dq kernel instance.

---

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-07 10:22:28 +08:00

fused multi-head attention

This folder contains an example of fmha (fused multi-head attention) implemented with the ck_tile tile-programming API. It demonstrates how to use the tile-programming API and illustrates the new approach of constructing a kernel template and instantiating it while keeping compile times short.

build

# 1. In the root of composable_kernel project, create the build directory.
[~/composable_kernel] mkdir build && cd build
# 2. In the build directory, run the CMake wrapper script to generate the build system files. Replace <arch> with the gfx architectures string.
[~/composable_kernel/build] ../script/cmake-ck-dev.sh .. <arch> -G Ninja
# 3. In the build directory, run the build system recipe.
[~/composable_kernel/build] ninja tile_example_fmha_fwd

Running the build recipe will produce the executable tile_example_fmha_fwd.

The executables reside in the bin subdirectory of the build directory.

This example provides recipes for tile_example_fmha_fwd, tile_example_fmha_bwd.

Note

cmake-ck-dev.sh is a CMake wrapper.

The first argument is the path to composable_kernel sources.

The second argument is the gfx architectures string (e.g. "gfx950" or "gfx90a;gfx942").

The remaining arguments are optional and are passed through to CMake. E.g. -G Ninja specifies ninja as the build system.

kernel

The kernel template is fmha_fwd_kernel.hpp; this is the grid-wise op in old ck_tile terminology. We put it here on purpose, to demonstrate that one can construct a kernel from various internal components of ck_tile. We may still provide an implementation of the kernel template under ck_tile's include path in the future.

There are 2 template parameters for this kernel template.

  • FmhaPipeline is one of the block_tile_pipeline implementations (under include/ck_tile/tile_program/block_tile_pipeline), which is a performance-critical component. We have done a lot of optimization work and experiments on the pipeline, and may still work out better-performing pipelines and add them to that folder. Users only need to replace this pipeline type to benefit from a different, more performant implementation (stay tuned for updated pipelines).
  • EpiloguePipeline modifies and stores the result in the last phase. A lot of post-fusion is usually done at this stage, so we abstract this concept as well. Currently little is done at the epilogue stage, but it leaves room for future support.

codegen

To speed up compilation, we instantiate the kernels into separate files, so the build can benefit from the parallelism of the CMake/Make system. This is done by the generate.py script. You can also read this script to learn how to instantiate a kernel instance step by step, as described in the FMHA_FWD_KERNEL_BODY variable.

executable

tile_example_fmha_fwd is the example executable, implemented in fmha_fwd.cpp. Run ./bin/tile_example_fmha_fwd -? to list all the arguments. Below is an example of the output (subject to change):

args:
          -v    weather do CPU validation or not (default:1)
       -mode    kernel mode. 0:batch, 1:group (default:0)
          -b    batch size (default:2)
          -h    num of head, for q (default:8)
        -h_k    num of head, for k/v, -1 means equal to h (default:-1)
                if not equal to h, then this is GQA/MQA case
          -s    seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
                total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
                also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode)
        -s_k    seqlen_k (including new key/value), -1 means equal to s (default:-1)
                also with "-s_k=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode)
     -s_qpad    seqlen_q stride between 2 batches (group-mode optional) (default:-1)
                Provide positive strides per-batch to simulate physical padding on Q
     -s_kpad    seqlen_k stride between 2 batches, currently used in group-mode only  (default:-1)
                for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride
                along seqlen, instead of packed, same as xformer kv_padding,
                must be greater than or equal to s_k
          -d    head dim for q, k (default:128)
        -d_v    head dim for v, -1 means equal to d (default:-1)
    -scale_s    scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
     -qscale    n or 0, no scaling (default:n)
                pt or 1, per-tensor scale
                bs or 2, block scale
                kvbs or 3, Q per-tensor, K/V per-page block scale, only in batch_prefill
                mx or 4, microscaling (exclusively for mxfp8/mxfp4)
      -iperm    permute input (default:1)
                if true, will be b*h*s*d, else b*s*h*d
      -operm    permute output (default:1)
       -bias    n or 0, no bias (default:n)
                e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
                a(libi) or 2, alibi with 1*h. a:1, b*h
       -prec    data type. fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4 (default:fp16)
       -mask    0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
                't', top-left causal mask, 'b', bottom-r causal mask
                't:l,r', top-left sliding window attn(swa) with FA style left right size
                'b:l,r', bottom-r sliding window attn(swa) with FA style left right size
                'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
                'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
                'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
    -vlayout    r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
        -lse    0 not store lse, 1 store lse (default:0)
      -kname    if set to 1 will print kernel name (default:0)
       -init    init method. ui, uniform random int, ni, normalized random int (default:uf)
                uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization
       -seed    random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
  -drop_seed    seed for random number generator (default:1)
-drop_offset    offset for random number generator (default:0)
 -drop_prefs    seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0)
 -num_splits    number of splits for key/value. 0 to determine actual number by heuristic (default:1)
     -warmup    number of iterations before benchmark the kernel (default:5)
     -repeat    number of iterations to benchmark the kernel (default:20)
       -json    0: No Json, 1: Dump Results in Json format (default:0)
   -jsonfile    json file name to dump results (default:fmha_fwd.json)
 -q_eff_lens    Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"")
                Comma-separated list of length 'b'. If empty, no override
-kv_eff_lens    Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"")
                Comma-separated list of length 'b'. If empty, no override

Example 1: ./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128 runs an fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16.

Example 2: ./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234 runs an fmha case with batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=10 (in GPU memory), drop_offset=1234 (in GPU memory), fp16.

Padding Examples

Example 3 (Group mode with padding): ./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128 will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively.

Example 4 (Batch mode with effective lengths): ./bin/tile_example_fmha_fwd -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536 will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively.

support features

We are still in a rapid development stage, so more features and optimizations will be coming soon.

hdim

Currently we support hdim 32/64/128/256 for fp16/bf16, of which 64/128 are better optimized. hdim should be a multiple of 8, while seqlen_s can be arbitrary. An arbitrary hdim can be supported through the padding kernel of the qr pipeline (not generated by generate.py by default).

group/batch mode

Currently we support both batch mode and group mode (or varlen, in FA's terms), selected with -mode=0 or -mode=1. In group mode, the different kinds of attention mask are also supported (see below).

MQA/GQA

By setting -h (nhead for q) and -h_k (nhead for k/v) to different numbers, you can achieve MQA/GQA. Make sure that h % h_k == 0 when you set different numbers.
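
The divisibility requirement follows from how query heads are grouped onto k/v heads. A minimal sketch of the conventional mapping is shown below; the function name is hypothetical, and the README does not spell out the kernel's exact indexing, so treat this as an illustration of the general GQA scheme.

```cpp
// Maps a query head index to the k/v head it shares.
// Assumes nhead_q % nhead_k == 0, so every k/v head serves
// exactly nhead_q / nhead_k consecutive query heads.
constexpr int kv_head_for_q_head(int i_nhead_q, int nhead_q, int nhead_k)
{
    return i_nhead_q / (nhead_q / nhead_k);
}
```

With -h=8 and -h_k=2, query heads 0..3 read k/v head 0 and query heads 4..7 read k/v head 1; with -h_k=1 (MQA) every query head reads k/v head 0.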

input/output permute, and b*s*3*h*d

If you look at the kernel arguments in fmha_fwd_kernel.hpp, we support arbitrary strides for the seqlen (stride_q/k/v), nhead, and batch dimensions of the q/k/v matrices, so it is flexible enough to support b*h*s*d or b*s*h*d input/output permutation. -iperm=0/1 and -operm=0/1 are a convenient way to select this through the executable. We do not provide a command-line argument to test the b*s*3*h*d layout, which is used by default by torch/FA, but it is trivial to achieve by setting stride_q/k/v to 3*h*d.
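
To make the stride choices concrete, the sketch below computes element strides for the three layouts mentioned above. The struct and function names are hypothetical; the innermost hdim stride is taken to be 1.

```cpp
#include <cstdint>

// Element strides for one of q/k/v; the hdim dimension is contiguous.
struct qkv_strides
{
    std::int64_t batch;
    std::int64_t nhead;
    std::int64_t seqlen;
};

// b*h*s*d (-iperm=1)
constexpr qkv_strides strides_bhsd(int h, int s, int d)
{
    return {static_cast<std::int64_t>(h) * s * d, static_cast<std::int64_t>(s) * d,
            static_cast<std::int64_t>(d)};
}

// b*s*h*d (-iperm=0)
constexpr qkv_strides strides_bshd(int h, int s, int d)
{
    return {static_cast<std::int64_t>(s) * h * d, static_cast<std::int64_t>(d),
            static_cast<std::int64_t>(h) * d};
}

// Packed b*s*3*h*d: q, k and v interleaved per token, so the seqlen stride is 3*h*d.
constexpr qkv_strides strides_bs3hd(int h, int s, int d)
{
    return {static_cast<std::int64_t>(s) * 3 * h * d, static_cast<std::int64_t>(d),
            static_cast<std::int64_t>(3) * h * d};
}
```

In the packed layout q, k and v share these strides and differ only in their base pointers, which are offset by 0, h*d and 2*h*d elements respectively.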

attention bias

Attention bias is supported with the layout 1*1*s*s (similar to input/output, different layouts can be supported by changing the stride values for bias, or even extended to b*h*s*s), with bias values in floating point.

alibi

alibi is supported

lse

For training kernels, the "log sum exp" needs to be stored in the forward pass and used in the backward pass. We support this by setting -lse=1.

vlayout

We support the v matrix in both row-major (seqlen*hdim) and col-major (hdim*seqlen) layouts. Since the accumulate (reduce) dimension for V is along seqlen, and AMD's current mfma layout expects each thread to hold contiguous registers of pixels along the reduce dimension, the col-major V layout is easier to support. However, col-major is not necessarily faster than row-major; many factors affect the overall performance. We still provide -vlayout=r/c to switch and test between the layouts.

attention mask

We support causal mask and sliding window attention (swa) mask in both batch and group mode, from either top-left or bottom-right. Underneath, we unify the mask expression into a generic attention mask coordinate, providing a uniform approach for each batch to locate the pixels that need to be masked out.

Since the FA/xformer style with window_size_left/right is more popular, we accept window_size as a parameter and convert it internally to our generic coordinate (which can express more cases). The table below shows examples of how to achieve different kinds of mask through the command line.

| mask case | cmdline | FA style | xformer style |
|-----------|---------|----------|---------------|
| no mask | -mask=0 (default) | | |
| causal mask from top-left | -mask=1 or -mask=t | -mask=t:-1,0 | -mask=xt:-1 |
| causal mask from bottom-right | -mask=2 or -mask=b | -mask=b:-1,0 | -mask=xb:-1 |
| swa from top-left | | -mask=t:3,5 | -mask=xt:4 |
| swa from bottom-right | | -mask=b:10,11 | -mask=xb:16 |

Note: FA uses bottom-right by default to express the swa case; here we require you to explicitly specify top-left or bottom-right.

dropout

TBD

sequence padding and variable length support

We support sequence padding and variable-length processing in both batch and group modes of fmha forward, to handle real-world scenarios where sequences have different lengths.

Group Mode Padding: Use -s_qpad and -s_kpad to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (-s, -s_k) but use larger physical strides for memory alignment.

Batch Mode Variable Length: Use -q_eff_lens and -kv_eff_lens to specify effective sequence lengths per batch. All batches share the same physical sequence length, but the kernel processes only the effective portions. This enables efficient variable-length attention without memory waste.

Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios.

FP8 support

FP8 FMHA kernels are supported on gfx942/gfx950 machines with ROCm 6.0+. Three fp8-based precision modes are available via -prec:

| -prec value | Q/K/V input type | Output type | Description |
|-------------|------------------|-------------|-------------|
| fp8 | fp8 | fp8 | Fully fp8: both inputs and output are in fp8 |
| fp8bf16 | fp8 | bf16 | Mixed precision: fp8 inputs, bf16 output; useful when the consumer expects a wider-range output format |
| fp8fp32 | fp8 | fp32 | Mixed precision: fp8 inputs, fp32 output; highest-precision output, suitable for debugging or further fp32 processing |

The following quantization scale modes are available via -qscale:

| -qscale value | Description |
|---------------|-------------|
| n or 0 | No quantization scale (default) |
| pt or 1 | Per-tensor quantization scale: a single scale factor is applied to the entire tensor |
| bs or 2 | Per-block quantization scale: a scale factor is applied per block of elements |
| kvbs or 3 | Q per-tensor + K/V per-page block scale (batch_prefill only) |
| mx or 4 | Microscaling (MX format), exclusively for mxfp8 and mxfp4 data types |

Currently only -vlayout=r (seqlen*hdim for V matrix) is supported for fp8 data types.