[CK_TILE] Use Unified Workspace for FMHA BWD
## Motivation
`dq_acc` is the intermediate dQ accumulation buffer used in the FMHA backward
pass (per-split partial sums in deterministic mode, atomic accumulation in
non-deterministic mode). The current implementation allocates it as
a **single rectangular tensor**:
```
shape = [shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q]
```
where `nsplits = launcher.dq_acc_splits` (a single scalar), computed
from `max_seqlen_k` and shared across all batches.
### Problems
1. **Memory waste**: In group mode, each batch may have a different
`seqlen_k`, but `nsplits` is computed from `max_seqlen_k`, causing
batches with shorter `seqlen_k` to over-allocate in the split dimension.
2. **Interface coupling**: `fmha_bwd_args` exposes internal layout
details such as `stride_dq_acc`, `nhead_stride_dq_acc`,
`batch_stride_dq_acc`, and `split_stride_dq_acc`. The caller is
responsible for computing these strides, but this logic belongs inside
the kernel.
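
The sketch below makes the memory waste concrete by comparing group-mode element counts for the old rectangular layout and the new compact one. It rests on two assumptions, both for illustration only. First, `nsplits` is taken to be `ceil(seqlen_k / kN0)` for a hypothetical tile width `kN0`. Second, group mode is assumed to pack all batches along the sequence axis (`shape_batch = 1`, with `shape_seqlen_q` equal to the total physical `seqlen_q`). The helper names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical split rule (illustration only): one split per kN0-wide tile of
// the key sequence. The real rule is chosen inside the launcher/kernel.
inline int64_t num_splits(int64_t seqlen_k, int64_t kN0)
{
    assert(kN0 > 0);
    return (seqlen_k + kN0 - 1) / kN0;
}

// Old layout (group mode): every batch is sized with nsplits(max_seqlen_k),
// i.e. [1, nhead, nsplits_max, total_seqlen_q, hdim_q].
inline int64_t rectangular_elems(const std::vector<int64_t>& seqlen_q,
                                 const std::vector<int64_t>& seqlen_k,
                                 int64_t nhead, int64_t hdim_q, int64_t kN0)
{
    assert(seqlen_q.size() == seqlen_k.size());
    int64_t total_q = 0;
    int64_t max_k   = 0;
    for (size_t i = 0; i < seqlen_q.size(); ++i)
    {
        total_q += seqlen_q[i];
        max_k = std::max(max_k, seqlen_k[i]);
    }
    return nhead * num_splits(max_k, kN0) * total_q * hdim_q;
}

// New compact layout: batch i occupies nhead * nsplits_i * seqq_i * hdim_q.
inline int64_t compact_elems(const std::vector<int64_t>& seqlen_q,
                             const std::vector<int64_t>& seqlen_k,
                             int64_t nhead, int64_t hdim_q, int64_t kN0)
{
    assert(seqlen_q.size() == seqlen_k.size());
    int64_t total = 0;
    for (size_t i = 0; i < seqlen_q.size(); ++i)
        total += nhead * num_splits(seqlen_k[i], kN0) * seqlen_q[i] * hdim_q;
    return total;
}
```

Take `seqlen_q = {128, 128}`, `seqlen_k = {1024, 128}`, `nhead = 1`, `hdim_q = 64` and `kN0 = 128`. The rectangular layout needs 131072 elements and the compact one needs 73728. The difference arises because the short second batch no longer pays for 8 splits.
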
### Goals
1. Switch `dq_acc` buffer to a **compact layout**: batches are
concatenated contiguously, with each batch occupying `nhead * nsplits_i
* seqq_i * hdim_q` elements (nhead outermost).
2. **Remove all `*_stride_dq_acc` fields** from `fmha_bwd_args`,
replacing them with a single `workspace_ptr`; the kernel splits this
internally using a fixed layout.
3. `fmha_bwd_launcher` provides a **workspace management interface**:
the caller only needs to allocate GPU memory and call
`prepare_workspace()`; no layout computation is required.
4. **Isolate kernel internals from the caller API**: the `dq_acc` layout
(nsplits, strides, buffer size) is determined entirely inside the
launcher/kernel. Future changes to block shape, pipeline type, or
persistent kernel strategy require no modifications to the caller's
`fmha_bwd_args` or workspace allocation logic.
## Technical Details
### Interface Design
#### New fields in `fmha_bwd_traits`
```cpp
struct fmha_bwd_traits
{
    int seqlen_q;
    int seqlen_k;
    int batch;
    int max_seqlen_q;
    int max_seqlen_k;
    int hdim_q;
    int hdim_v;
    int nhead_q;
    int nhead_k;
    std::string data_type;
    bool is_group_mode;
    mask_enum mask_type;
    bias_enum bias_type;
    bool has_dbias;
    bool has_dropout;
    bool is_store_randval;
    bool is_deterministic;
    // New: cumulative physical seqlen pointers for group mode (pass nullptr for batch mode).
    // seqstart_qs[i+1] - seqstart_qs[i] = physical seqlen_q of batch i (including padding); length = batch+1
    // seqstart_ks[i+1] - seqstart_ks[i] = physical seqlen_k of batch i (including padding); length = batch+1
    const int* seqstart_qs = nullptr;
    const int* seqstart_ks = nullptr;
};
```
#### Actual `fmha_bwd_launcher` structure
```cpp
struct fmha_bwd_launcher
{
    std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};

    // Total workspace size in bytes (host_ws_size + device_ws_size), computed by init().
    // Zero for kUseQrQtrDorPipeline (writes dq directly, no acc buffer needed).
    size_t workspace_size = 0;

    fmha_bwd_launcher(const fmha_bwd_traits&);

    // Copies auxiliary data (nsplits[], offsets[]) via hipMemcpy to the head of the GPU workspace,
    // and zeros the dq_acc buffer portion (tail of workspace) if required.
    // The memory pointed to by device_ws must be >= workspace_size bytes.
    std::function<void(void* device_ws)> prepare_workspace{};

    template <typename... Args>
    float operator()(Args&&... args) const { return run(std::forward<Args>(args)...); }

private:
    size_t host_ws_size   = 0;        // CPU workspace size (nsplits[] + offsets[] arrays)
    size_t device_ws_size = 0;        // GPU-only data size (dq_acc buffer)
    std::unique_ptr<char[]> ws_host;  // host-side workspace buffer

public:
    template <typename T0, typename T1, typename T2, typename Arch>
    void init(const fmha_bwd_traits& traits);
};
```
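
The comments on `prepare_workspace` describe two steps: copy the host-prepared head into the workspace, then optionally zero the `dq_acc` tail. The host-only sketch below mirrors that sequence, with `std::memcpy`/`std::memset` standing in for `hipMemcpy`/`hipMemset`. `prepare_workspace_host_sim` is a hypothetical name and is not part of the launcher.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Host-only stand-in for prepare_workspace (illustration only): std::memcpy
// and std::memset play the roles of hipMemcpy and hipMemset.
inline void prepare_workspace_host_sim(void* ws, const char* ws_host,
                                       size_t host_ws_size, size_t device_ws_size,
                                       bool needs_zero_dq_acc)
{
    char* base = static_cast<char*>(ws);
    if (host_ws_size > 0)
    {
        assert(ws != nullptr && ws_host != nullptr);
        std::memcpy(base, ws_host, host_ws_size);  // nsplits[] / offsets[] to the head
    }
    if (needs_zero_dq_acc && device_ws_size > 0)
    {
        assert(ws != nullptr);
        std::memset(base + host_ws_size, 0, device_ws_size);  // zero the dq_acc tail
    }
}
```

When `workspace_size` is 0 (the `kUseQrQtrDorPipeline` case), both branches are skipped, so passing `nullptr` is harmless.
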
The `init<>()` template method (invoked by codegen dispatch branches as
`this->init<...>(t)`) is responsible for:
1. Setting the `run` lambda
2. Calling `FmhaBwdDQDKDVKernel::GetWorkspaceHostSize(batch)` to obtain
`host_ws_size`
3. Allocating `ws_host` (host memory)
4. Calling `FmhaBwdDQDKDVKernel::PrepareWorkspaceHost(ws_host.get(),
...)` to fill nsplits/offsets; return value is `device_ws_size`
5. `workspace_size = host_ws_size + device_ws_size`
6. Setting the `prepare_workspace` lambda (captures `this`, calls
`PrepareWorkspaceDevice`)
When no kernel matches the given traits, both `run` and
`prepare_workspace` are initialized to default lambdas that print a
warning to `std::cerr` and return gracefully (no exception).
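
The size bookkeeping in steps 2 to 5 can be sketched with a mock kernel type. `MockBwdKernel`, `MockLauncher` and the sizes they return are invented for illustration. The real hooks take additional arguments, which are elided in the list above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <memory>

// Hypothetical stand-in for FmhaBwdDQDKDVKernel: only the two static hooks
// that init<>() is described as calling. The sizes are invented.
struct MockBwdKernel
{
    static size_t GetWorkspaceHostSize(int batch) { return 16 * static_cast<size_t>(batch); }

    // Fills the host segment and returns the dq_acc (device) size in bytes.
    static size_t PrepareWorkspaceHost(void* host_ws, int batch)
    {
        assert(batch > 0 && host_ws != nullptr);
        std::memset(host_ws, 0, GetWorkspaceHostSize(batch));
        return 1024 * static_cast<size_t>(batch);
    }
};

// Mirrors steps 2-5 of init<>(): query the host size, allocate ws_host,
// fill it, and sum both segments into workspace_size.
struct MockLauncher
{
    size_t workspace_size = 0;
    size_t host_ws_size   = 0;
    size_t device_ws_size = 0;
    std::unique_ptr<char[]> ws_host;

    template <typename Kernel>
    void init(int batch)
    {
        host_ws_size   = Kernel::GetWorkspaceHostSize(batch);                 // step 2
        ws_host        = std::make_unique<char[]>(host_ws_size);              // step 3
        device_ws_size = Kernel::PrepareWorkspaceHost(ws_host.get(), batch);  // step 4
        workspace_size = host_ws_size + device_ws_size;                       // step 5
    }
};
```

With `batch = 2`, the mock yields `host_ws_size = 32`, `device_ws_size = 2048` and `workspace_size = 2080`.
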
#### Overall workspace layout
The workspace is managed by `FmhaBwdWorkspaceManager` and consists of
two segments:
```
Offset 0 (CPU-prepared segment, host_ws_size bytes; also hipMemcpy'd to the head of the GPU workspace):
    index_t nsplits[batch or 1]            - per-batch nsplits array
        group mode: batch elements
        batch mode / non-deterministic: 1 element
    [group mode only] long_index_t dq_acc_offsets[batch+1]
                                           - per-batch element offset (prefix sum; offsets[batch] = total_elements)
        offsets[0] = 0, offsets[i+1] = offsets[i] + nhead*nsplits_i*seqq_i*hdim_q
Offset host_ws_size (device data segment, device_ws_size bytes):
    AccDataType dq_acc[total_elements]     - compact dq_acc buffer (zeroed if required)
        total_elements = sum_i(nhead * nsplits_i * seqq_i * hdim_q)
        layout within each batch: [nhead, nsplits_i, seqq_i, hdim_q]
        note: seqq_i uses the physical length (including padding)
```
Alignment constant (`ALIGNMENT = 16`):
```
nsplits_size = align_up(sizeof(index_t) * N, 16) // N = batch (group) or 1 (batch/non-det)
offsets_size = align_up(sizeof(long_index_t) * (batch+1), 16) // group mode only
host_ws_size = nsplits_size + offsets_size
dq_acc_offset = host_ws_size // GetDqAccDataOffset(batch)
```
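
The alignment rules translate directly into code. The sketch below assumes `index_t` and `long_index_t` are 32-bit and 64-bit signed integers, respectively. Whether the offsets array is present is passed in explicitly rather than derived from the mode. `compute_host_ws_size` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

using index_t      = int32_t;  // assumed to match ck_tile::index_t
using long_index_t = int64_t;  // assumed to match ck_tile::long_index_t

constexpr size_t ALIGNMENT = 16;

// Rounds n up to the next multiple of a.
constexpr size_t align_up(size_t n, size_t a) { return (n + a - 1) / a * a; }

// host_ws_size, which is also the byte offset of dq_acc (GetDqAccDataOffset).
// nsplits_count: batch (deterministic group mode) or 1 (batch / non-deterministic)
// has_offsets:   whether the dq_acc_offsets[batch+1] array is stored (group mode)
inline size_t compute_host_ws_size(size_t nsplits_count, bool has_offsets, size_t batch)
{
    assert(nsplits_count > 0 && batch > 0);
    const size_t nsplits_size = align_up(sizeof(index_t) * nsplits_count, ALIGNMENT);
    const size_t offsets_size =
        has_offsets ? align_up(sizeof(long_index_t) * (batch + 1), ALIGNMENT) : 0;
    return nsplits_size + offsets_size;
}
```

For a deterministic group-mode launch with `batch = 3`, this gives 16 bytes of `nsplits` plus 32 bytes of offsets, so `dq_acc` starts at byte 48. In batch mode, the offset is 16. Either way, the `dq_acc` segment starts on a 16-byte boundary.
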
**Key benefits**:
- The kernel reads nsplits/offsets directly from the workspace head — no
device-side recomputation.
- `FmhaBwdConvertQGradKernel` is completely decoupled from the pipeline
block shape (`kN0`): nsplits is read from `nsplits_ptr`, `kN0` is no
longer a template parameter, and multiple dq_dk_dv tiles with different
`F_bn0` values now share a single convert_dq kernel instance (under
receipt 1/2, deterministic convert_dq kernel count drops from ~300 to
60).
- nsplits/offsets are computed on the host and transferred in one
`hipMemcpy`; the dq_acc buffer follows immediately, at the offset given
by `GetDqAccDataOffset`.
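
Because nsplits now arrives at run time, the split reduction no longer needs `kN0` at compile time. The sketch below is a host reference of that reduction for one (batch, head) slice. It assumes convert_dq sums the per-split partials before converting to the output type. The conversion step is omitted, and `reduce_splits` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sums the nsplits partial dQ slices of one (batch, head) into dq.
// acc points at that head's slice, laid out as [nsplits, seqq, hdim_q];
// nsplits is a runtime value (e.g. read from the nsplits array).
inline std::vector<float> reduce_splits(const float* acc, int nsplits, int seqq, int hdim_q)
{
    assert(acc != nullptr && nsplits > 0 && seqq >= 0 && hdim_q >= 0);
    const size_t slice = static_cast<size_t>(seqq) * static_cast<size_t>(hdim_q);
    std::vector<float> dq(slice, 0.0f);
    for (int s = 0; s < nsplits; ++s)
        for (size_t e = 0; e < slice; ++e)
            dq[e] += acc[static_cast<size_t>(s) * slice + e];
    return dq;
}
```

Summing the two splits `{1, 2, 3, 4}` and `{10, 20, 30, 40}` gives `{11, 22, 33, 44}`.
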
#### Workspace size by scenario
| Scenario | `workspace_size` | Notes |
|----------|------------------|-------|
| **kUseQrQtrDorPipeline** (any mode) | `0` | Writes dq directly; no acc buffer; `PrepareWorkspaceHost` returns 0 |
| **Non-deterministic + batch mode** | `> 0` | nsplits[1]=1; dq_acc used for atomic add; `workspace_size = host_ws_size + batch*nhead*seqlen_q*hdim_q*ebytes` |
| **Non-deterministic + group mode** | `> 0` | nsplits[1]=1; dq_acc contiguous layout; `workspace_size = host_ws_size + nhead*seqstart_qs[batch]*hdim_q*ebytes` |
| **Deterministic + group mode** | `> 0` | nsplits[batch], offsets[batch+1], compact dq_acc; nsplits_i computed independently per batch |
| **Deterministic + batch mode persistent** | `> 0` | nsplits[1] (uniform across batches); dq_acc `batch*nhead*nsplits*seqlen_q*hdim_q` |

Here `ebytes = sizeof(AccDataType)`, the element size of the dq_acc buffer.

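
Every row that uses `dq_acc` is a special case of the compact sum `sum_i(nhead * nsplits_i * seqq_i * hdim_q) * ebytes`. Non-deterministic rows use `nsplits_i = 1`, and deterministic batch mode uses one uniform `nsplits` with equal `seqq_i`. The sketch below computes the device segment under those assumptions; the host segment is not included. `dq_acc_bytes` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// dq_acc byte size (device_ws_size) for the table rows, as a compact sum.
// seqq[i]    : physical seqlen_q of batch i (batch mode: all equal to seqlen_q)
// nsplits[i] : splits of batch i (non-deterministic: 1; deterministic batch
//              mode: one uniform value repeated for every batch)
inline size_t dq_acc_bytes(bool uses_acc_buffer, const std::vector<size_t>& seqq,
                           const std::vector<size_t>& nsplits, size_t nhead,
                           size_t hdim_q, size_t ebytes)
{
    if (!uses_acc_buffer)  // kUseQrQtrDorPipeline writes dq directly
        return 0;
    assert(seqq.size() == nsplits.size());
    size_t bytes = 0;
    for (size_t i = 0; i < seqq.size(); ++i)
        bytes += nhead * nsplits[i] * seqq[i] * hdim_q * ebytes;
    return bytes;
}
```

For example, non-deterministic group mode with `seqq = {100, 28}` matches the table's `nhead*seqstart_qs[batch]*hdim_q*ebytes` with `seqstart_qs[batch] = 128`.
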
**NeedsZeroDqAcc** (determines whether `PrepareWorkspaceDevice` calls
`hipMemset`):
- Persistent kernel (deterministic batch mode) or non-deterministic:
**must zero** (atomic add requires zero initialization)
- Deterministic group mode + no mask: **no zeroing needed** (every tile
writes its full region)
- Deterministic + with mask: **must zero** (some blocks are skipped,
leaving uninitialized tiles that would contribute to the reduction)
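
The zeroing rules reduce to a small predicate. The sketch below takes hypothetical boolean inputs. `uses_acc_buffer = false` models `kUseQrQtrDorPipeline`, where there is no buffer to zero.

```cpp
#include <cassert>

// Hypothetical restatement of the NeedsZeroDqAcc rules listed above.
inline bool needs_zero_dq_acc(bool uses_acc_buffer, bool deterministic,
                              bool group_mode, bool has_mask)
{
    if (!uses_acc_buffer)
        return false;  // no dq_acc at all (kUseQrQtrDorPipeline)
    if (!deterministic)
        return true;   // atomic add requires zero-initialized memory
    if (!group_mode)
        return true;   // persistent kernel (deterministic batch mode)
    return has_mask;   // masked-out blocks leave tiles unwritten
}
```
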
#### Caller usage
```cpp
// 1. Create launcher (traits include seqstart_qs/ks pointers; workspace_size is computed during construction)
fmha_bwd_launcher launcher(fmha_traits);
// 2. Read launcher.workspace_size directly
const auto ws_size = launcher.workspace_size;
// 3. Allocate a single GPU workspace
ck_tile::DeviceMem ws_buf(ws_size);
// 4. Copy nsplits/offsets to GPU head and zero dq_acc if required
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
// 5. Build args with a single workspace pointer; the kernel splits it internally
fmha_bwd_args args{
...,
ws_size > 0 ? ws_buf.GetDeviceBuffer() : nullptr, // workspace_ptr
};
launcher(args, stream_config);
```
# Composable Kernel

> **Note**
> The published documentation is available at Composable Kernel in an organized, easy-to-read format, with search and a table of contents. The documentation source files reside in the `docs` folder of this repository. As with all ROCm projects, the documentation is open source. For more information on contributing to the documentation, see Contribute to ROCm documentation.
The Composable Kernel (CK) library provides a programming model for writing performance-critical kernels for machine learning workloads across multiple architectures (GPUs, CPUs, etc.). The CK library uses general purpose kernel languages, such as HIP C++.
CK uses two concepts to achieve performance portability and code maintainability:
- A tile-based programming model
- Algorithm complexity reduction for complex machine learning (ML) operators. This uses an innovative technique called Tensor Coordinate Transformation.
The current CK library is structured into four layers:
- Templated Tile Operators
- Templated Kernel and Invoker
- Instantiated Kernel and Invoker
- Client API
## General information
- CK supported operations
- CK Tile supported operations
- CK wrapper
- CK codegen
- CK profiler
- Examples (Custom use of CK supported operations)
- Client examples (Use of CK supported operations with instance factory)
- Terminology
- Contributors
CK is released under the MIT license.
## Building CK
We recommend building CK inside Docker containers, which include all necessary packages. Pre-built Docker images are available on DockerHub.
1. To build a new Docker image, use the Dockerfile provided with the source code:

   ```bash
   DOCKER_BUILDKIT=1 docker build -t ck:latest -f Dockerfile .
   ```

2. Launch the Docker container:

   ```bash
   docker run \
   -it \
   --privileged \
   --group-add sudo \
   -w /root/workspace \
   -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
   ck:latest \
   /bin/bash
   ```

3. Clone CK source code from the GitHub repository and start the build:

   ```bash
   git clone https://github.com/ROCm/composable_kernel.git && \
   cd composable_kernel && \
   mkdir build && \
   cd build
   ```

   You must set the `GPU_TARGETS` macro to specify the GPU target architecture(s) you want to run CK on. You can specify single or multiple architectures. If you specify multiple architectures, use a semicolon between each; for example, `gfx908;gfx90a;gfx942`.

   ```bash
   cmake \
   -D CMAKE_PREFIX_PATH=/opt/rocm \
   -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
   -D CMAKE_BUILD_TYPE=Release \
   -D GPU_TARGETS="gfx908;gfx90a" \
   ..
   ```

   If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets supported by the current compiler (this may take a long time). Tests and examples will only get built if `GPU_TARGETS` is set by the user on the cmake command line.

   **NOTE:** If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the architectures are similar, e.g., `gfx908;gfx90a` or `gfx1100;gfx1101;gfx11012`. Otherwise, if you want to build the library for a list of different architectures, you should use the `GPU_ARCHS` build argument, for example `GPU_ARCHS=gfx908;gfx1030;gfx1100;gfx942`.

   **Convenience script for development builds:**

   Alternatively, you can use the provided convenience script `script/cmake-ck-dev.sh`, which automatically configures CK for development with sensible defaults. In the build directory:

   ```bash
   ../script/cmake-ck-dev.sh
   ```

   This script:

   - Cleans CMake cache files before configuring
   - Sets `BUILD_DEV=ON` for development mode
   - Defaults to GPU targets: `gfx908;gfx90a;gfx942`
   - Enables verbose makefile output
   - Sets additional compiler flags for better error messages

   By default, it considers the parent directory to be the project source directory. You can specify the source directory as the first argument. You can specify custom GPU targets (semicolon-separated) as the second argument:

   ```bash
   ../script/cmake-ck-dev.sh .. gfx1100
   ```

   Or pass additional cmake arguments:

   ```bash
   ../script/cmake-ck-dev.sh .. gfx90a -DCMAKE_BUILD_TYPE=Release
   ```

   **Fast iteration builds:**

   For faster CMake configuration during development (~5s vs ~150s), use the `--minimal` flag to disable building device instances, profiler, examples, tutorials, and tests:

   ```bash
   ../script/cmake-ck-dev.sh --minimal .. gfx90a
   ```

   You can also specify a custom preset:

   ```bash
   ../script/cmake-ck-dev.sh --preset=dev-minimal .. gfx90a
   ```

4. Build the entire CK library:

   ```bash
   make -j"$(nproc)"
   ```

5. Install CK:

   ```bash
   make -j install
   ```
## Building for Windows

Install TheRock and run the CMake configure step as follows:

```
cmake \
-D CMAKE_PREFIX_PATH="C:/dist/TheRock" \
-D CMAKE_CXX_COMPILER="C:/dist/TheRock/bin/hipcc.exe" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx1151" \
-G Ninja \
..
```

Use Ninja to build either the whole library or individual targets.
## Optional post-install steps

- Build examples and tests:

  ```bash
  make -j examples tests
  ```

- Build and run all examples and tests:

  ```bash
  make -j check
  ```

  You can find instructions for running each individual example in `example`.

- Build and run smoke/regression examples and tests:

  ```bash
  make -j smoke       # tests and examples that run for < 30 seconds each
  make -j regression  # tests and examples that run for >= 30 seconds each
  ```

- Build ckProfiler:

  ```bash
  make -j ckProfiler
  ```

  You can find instructions for running ckProfiler in `profiler`.

- Build our documentation locally:

  ```bash
  cd docs
  pip3 install -r sphinx/requirements.txt
  python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html
  ```
### Notes

The `-j` option builds with multiple threads in parallel, which speeds up the build significantly.
However, `-j` without a number launches an unlimited number of threads, which can cause the build to run out of memory and
crash. On average, you should expect each thread to use ~2 GB of RAM.
Depending on the number of CPU cores and the amount of RAM on your system, you may want to
limit the number of threads. For example, if you have a 128-core CPU and 128 GB of RAM, it's advisable to use `-j32`.
Additional cmake flags can be used to significantly speed up the build:

- `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;tf32;fp16;fp8;bf16;int8" to build instances of select data types only. The main default data types are fp32 and fp16; you can safely skip other data types.
- `DISABLE_DL_KERNELS` (default is OFF) must be set to ON in order not to build instances such as `gemm_dl` or `batched_gemm_multi_d_dl`. These instances are useful on architectures like the NAVI2x, as most other platforms have faster instances, such as `xdl` or `wmma`, available.
- `DISABLE_DPP_KERNELS` (default is OFF) must be set to ON in order not to build instances such as `gemm_dpp`. These instances offer slightly better fp16 GEMM performance on NAVI2x, but faster alternatives are available on other architectures.
- `CK_USE_FP8_ON_UNSUPPORTED_ARCH` (default is OFF) must be set to ON in order to build instances such as `gemm_universal`, `gemm_universal_streamk` and `gemm_multiply_multiply` for the fp8 data type on GPU targets that do not have native fp8 support, such as gfx908 or gfx90a. These instances are useful on architectures like the MI100/MI200 for functional support only.
## Using sccache for building

The default CK Docker images come with a pre-installed version of sccache, which supports clang being used as the HIP compiler (`-x hip`). Using sccache can reduce the time to rebuild code from hours to 1-2 minutes. To invoke sccache, run:

```bash
sccache --start-server
```

then add the following flags to the cmake command line:

```bash
-DCMAKE_HIP_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache
```

You may need to clean up the build folder and repeat the cmake and make steps in order to take advantage of sccache during subsequent builds.
## Using CK as pre-built kernel library

You can find instructions for using CK as a pre-built kernel library in `client_example`.
## Contributing to CK

When you contribute to CK, make sure you run clang-format on all changed files. We highly
recommend using git hooks that are managed by the pre-commit framework. To install hooks, run:

```bash
sudo script/install_precommit.sh
```

With this approach, pre-commit adds the appropriate hooks to your local repository and
automatically runs clang-format (and possibly additional checks) before any commit is created.

If you need to uninstall hooks from the repository, you can do so by running the following command:

```bash
script/uninstall_precommit.sh
```

If you need to temporarily disable pre-commit hooks, you can add the `--no-verify` option to the
`git commit` command.

