mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
[CK_TILE] Use Unified Workspace for FMHA BWD
## Motivation
`dq_acc` is the intermediate accumulation buffer used by the FMHA
backward pass in deterministic mode. The current implementation
allocates it as a **single rectangular tensor**:
```
shape = [shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q]
```
where `nsplits = launcher.dq_acc_splits` (a single scalar), computed
from `max_seqlen_k` and shared across all batches.
### Problems
1. **Memory waste**: In group mode, each batch may have a different
`seqlen_k`, but `nsplits` is computed from `max_seqlen_k`, causing
batches with shorter `seqlen_k` to over-allocate in the split dimension.
2. **Interface coupling**: `fmha_bwd_args` exposes internal layout
details such as `stride_dq_acc`, `nhead_stride_dq_acc`,
`batch_stride_dq_acc`, and `split_stride_dq_acc`. The caller is
responsible for computing these strides, but this logic belongs inside
the kernel.
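
To make the memory-waste point concrete, the following sketch compares the element count of the rectangular layout with a per-batch compact count. It rests on assumptions not stated above: `nsplits` is taken to be `ceil(seqlen_k / kN0)`, the rectangular group-mode buffer is taken to span the summed `seqlen_q` of all batches, and `GroupShape`, `num_splits`, `rectangular_elems`, and `compact_elems` are hypothetical names used only for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical description of a group-mode problem (illustrative names only).
struct GroupShape
{
    std::vector<std::int64_t> seqlen_q; // physical seqlen_q per batch
    std::vector<std::int64_t> seqlen_k; // physical seqlen_k per batch
    std::int64_t nhead;
    std::int64_t hdim_q;
    std::int64_t kN0; // assumed K-tile width: one split per tile
};

// Assumption: nsplits = ceil(seqlen_k / kN0).
inline std::int64_t num_splits(std::int64_t seqlen_k, std::int64_t kN0)
{
    assert(kN0 > 0);
    return (seqlen_k + kN0 - 1) / kN0;
}

// Rectangular layout: a single nsplits, derived from max_seqlen_k, applies to every batch.
inline std::int64_t rectangular_elems(const GroupShape& g)
{
    std::int64_t max_k = 0, total_q = 0;
    for(std::size_t i = 0; i < g.seqlen_q.size(); ++i)
    {
        max_k = std::max(max_k, g.seqlen_k[i]);
        total_q += g.seqlen_q[i];
    }
    return g.nhead * num_splits(max_k, g.kN0) * total_q * g.hdim_q;
}

// Compact layout: each batch pays only for its own nsplits_i.
inline std::int64_t compact_elems(const GroupShape& g)
{
    std::int64_t total = 0;
    for(std::size_t i = 0; i < g.seqlen_q.size(); ++i)
        total += g.nhead * num_splits(g.seqlen_k[i], g.kN0) * g.seqlen_q[i] * g.hdim_q;
    return total;
}
```

With two batches of `seqlen_q = 128`, `seqlen_k = {128, 1024}`, `nhead = 2`, `hdim_q = 64`, and an assumed `kN0 = 128`, the rectangular count is 262144 elements and the compact count is 147456.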
### Goals
1. Switch the `dq_acc` buffer to a **compact layout**: batches are
concatenated contiguously, with each batch occupying
`nhead * nsplits_i * seqq_i * hdim_q` elements (nhead outermost), where
`seqq_i` is the physical `seqlen_q` of batch `i`.
2. **Remove all `*_stride_dq_acc` fields** from `fmha_bwd_args`,
replacing them with a single `workspace_ptr`; the kernel splits this
internally using a fixed layout.
3. `fmha_bwd_launcher` provides a **workspace management interface**:
the caller only needs to allocate GPU memory and call
`prepare_workspace()`; no layout computation is required.
4. **Isolate kernel internals from the caller API**: the `dq_acc` layout
(nsplits, strides, buffer size) is determined entirely inside the
launcher/kernel. Future changes to block shape, pipeline type, or
persistent kernel strategy require no modifications to the caller's
`fmha_bwd_args` or workspace allocation logic.
## Technical Details
### Interface Design
#### New fields in `fmha_bwd_traits`
```cpp
struct fmha_bwd_traits
{
    int seqlen_q;
    int seqlen_k;
    int batch;
    int max_seqlen_q;
    int max_seqlen_k;
    int hdim_q;
    int hdim_v;
    int nhead_q;
    int nhead_k;
    std::string data_type;
    bool is_group_mode;
    mask_enum mask_type;
    bias_enum bias_type;
    bool has_dbias;
    bool has_dropout;
    bool is_store_randval;
    bool is_deterministic;
    // New: cumulative physical seqlen pointers for group mode (pass nullptr for batch mode).
    // seqstart_qs[i+1] - seqstart_qs[i] = physical seqlen_q of batch i (including padding); length = batch+1
    // seqstart_ks[i+1] - seqstart_ks[i] = physical seqlen_k of batch i (including padding); length = batch+1
    const int* seqstart_qs = nullptr;
    const int* seqstart_ks = nullptr;
};
```
#### `fmha_bwd_launcher` actual structure
```cpp
struct fmha_bwd_launcher
{
    std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};
    // Total workspace size in bytes (host_ws_size + device_ws_size), computed by init().
    // Zero for kUseQrQtrDorPipeline (writes dq directly, no acc buffer needed).
    size_t workspace_size = 0;
    fmha_bwd_launcher(const fmha_bwd_traits&);
    // Copies auxiliary data (nsplits[], offsets[]) via hipMemcpy to the head of the GPU workspace,
    // and zeros the dq_acc buffer portion (tail of workspace) if required.
    // The memory pointed to by device_ws must be >= workspace_size bytes.
    std::function<void(void* device_ws)> prepare_workspace{};
    template <typename... Args>
    float operator()(Args&&... args) const { return run(std::forward<Args>(args)...); }

    private:
    size_t host_ws_size   = 0;       // CPU workspace size (nsplits[] + offsets[] arrays)
    size_t device_ws_size = 0;       // GPU-only data size (dq_acc buffer)
    std::unique_ptr<char[]> ws_host; // host-side workspace buffer

    public:
    template <typename T0, typename T1, typename T2, typename Arch>
    void init(const fmha_bwd_traits& traits);
};
```
The `init<>()` template method (invoked by codegen dispatch branches as
`this->init<...>(t)`) is responsible for:
1. Setting the `run` lambda
2. Calling `FmhaBwdDQDKDVKernel::GetWorkspaceHostSize(batch)` to obtain
`host_ws_size`
3. Allocating `ws_host` (host memory)
4. Calling `FmhaBwdDQDKDVKernel::PrepareWorkspaceHost(ws_host.get(),
...)` to fill nsplits/offsets; return value is `device_ws_size`
5. Setting `workspace_size = host_ws_size + device_ws_size`
6. Setting the `prepare_workspace` lambda (captures `this`, calls
`PrepareWorkspaceDevice`)
When no kernel matches the given traits, both `run` and
`prepare_workspace` are initialized to default lambdas that print a
warning to `std::cerr` and return gracefully (no exception).
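
The fallback behaviour described above can be sketched with a small stand-in type. This is not the real launcher: `launcher_sketch` is a hypothetical name, the `run` signature is simplified to take an `int`, and the `-1.0f` return value is an assumption, since the text only says the default lambdas warn and return without throwing.

```cpp
#include <cstddef>
#include <functional>
#include <iostream>

// Hypothetical stand-in for fmha_bwd_launcher; only the no-match fallback is modelled.
struct launcher_sketch
{
    std::function<float(int)> run;                // real signature: (fmha_bwd_args, stream_config)
    std::function<void(void*)> prepare_workspace; // same shape as in the real launcher
    std::size_t workspace_size = 0;               // stays 0 when no kernel matches

    launcher_sketch()
        : run([](int) {
              std::cerr << "warning: no kernel matches the given traits; run() does nothing\n";
              return -1.0f; // assumed sentinel, not confirmed by the source
          }),
          prepare_workspace([](void*) {
              std::cerr << "warning: no kernel matches the given traits; workspace untouched\n";
          })
    {
    }
};
```

Because both members are always callable, a caller can invoke `prepare_workspace` and `run` unconditionally and inspect the result, instead of guarding every call with an empty-`std::function` check.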
#### Workspace overall layout
The workspace is managed by `FmhaBwdWorkspaceManager` and consists of
two segments:
```
Offset 0 (CPU-prepared segment, host_ws_size bytes; also hipMemcpy'd to the head of the GPU workspace):
  index_t nsplits[batch or 1]: per-batch nsplits array
      deterministic group mode: batch elements
      batch mode / non-deterministic: 1 element
  [group mode only] long_index_t dq_acc_offsets[batch+1]: per-batch element offsets
      (exclusive prefix sum with the total appended as the last entry)
      offsets[0] = 0, offsets[i+1] = offsets[i] + nhead*nsplits_i*seqq_i*hdim_q

Offset host_ws_size (device data segment, device_ws_size bytes):
  AccDataType dq_acc[total_elements]: compact dq_acc buffer (zeroed if required)
      total_elements = sum_i(nhead * nsplits_i * seqq_i * hdim_q)
      layout within each batch: [nhead, nsplits_i, seqq_i, hdim_q]
      note: seqq_i uses the physical length (including padding)
```
Alignment constant (`ALIGNMENT = 16`):
```
nsplits_size = align_up(sizeof(index_t) * N, 16) // N = batch (group) or 1 (batch/non-det)
offsets_size = align_up(sizeof(long_index_t) * (batch+1), 16) // group mode only
host_ws_size = nsplits_size + offsets_size
dq_acc_offset = host_ws_size // GetDqAccDataOffset(batch)
```
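
The size formulas and the offset recurrence can be written out as a short host-side sketch. It assumes `index_t` is 32-bit and `long_index_t` is 64-bit, and it models only deterministic group mode versus the single-element case; `host_ws_size` and `dq_acc_offsets` are hypothetical free functions here, not the kernel's actual static methods.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using index_t      = std::int32_t; // assumption: 32-bit index type
using long_index_t = std::int64_t; // assumption: 64-bit index type

constexpr std::size_t ALIGNMENT = 16;

constexpr std::size_t align_up(std::size_t n, std::size_t a) { return (n + a - 1) / a * a; }

// Size in bytes of the CPU-prepared head segment.
// per_batch == true models deterministic group mode (nsplits[batch] + offsets[batch+1]);
// per_batch == false models the single-element case (nsplits[1], no offsets array).
inline std::size_t host_ws_size(std::size_t batch, bool per_batch)
{
    const std::size_t n            = per_batch ? batch : 1;
    const std::size_t nsplits_size = align_up(sizeof(index_t) * n, ALIGNMENT);
    const std::size_t offsets_size =
        per_batch ? align_up(sizeof(long_index_t) * (batch + 1), ALIGNMENT) : 0;
    return nsplits_size + offsets_size;
}

// offsets[0] = 0, offsets[i+1] = offsets[i] + nhead * nsplits_i * seqq_i * hdim_q
inline std::vector<long_index_t> dq_acc_offsets(long_index_t nhead,
                                                long_index_t hdim_q,
                                                const std::vector<index_t>& nsplits,
                                                const std::vector<index_t>& seqlen_q)
{
    std::vector<long_index_t> offsets(nsplits.size() + 1, 0);
    for(std::size_t i = 0; i < nsplits.size(); ++i)
        offsets[i + 1] = offsets[i] + nhead * nsplits[i] * seqlen_q[i] * hdim_q;
    return offsets;
}
```

For `batch = 3` in deterministic group mode this gives `16 + 32 = 48` bytes for the head segment, so the `dq_acc` data would start at byte offset 48. The last entry of the offsets vector equals `total_elements`.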
**Key benefits**:
- The kernel reads nsplits/offsets directly from the workspace head — no
device-side recomputation.
- `FmhaBwdConvertQGradKernel` is completely decoupled from the pipeline
block shape (`kN0`): nsplits is read from `nsplits_ptr`, `kN0` is no
longer a template parameter, and multiple dq_dk_dv tiles with different
`F_bn0` values now share a single convert_dq kernel instance (under
receipt 1/2, deterministic convert_dq kernel count drops from ~300 to
60).
- nsplits/offsets are computed on the host and transferred in one
`hipMemcpy`; the dq_acc buffer follows immediately, at the offset given
by `GetDqAccDataOffset`.
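
As an illustration of how a single `workspace_ptr` can be split under this fixed layout, the sketch below derives the three sub-pointers and a flat `dq_acc` element index for deterministic group mode. It runs on host memory, uses `std::memcpy` in place of the real `hipMemcpy`, and assumes `AccDataType` is `float`; `workspace_view`, `write_head`, `split_workspace`, and `dq_acc_index` are hypothetical names.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

using index_t      = std::int32_t; // assumption: 32-bit index type
using long_index_t = std::int64_t; // assumption: 64-bit index type
using AccDataType  = float;        // assumption: accumulator element type

constexpr std::size_t align_up(std::size_t n, std::size_t a) { return (n + a - 1) / a * a; }

struct workspace_view
{
    const index_t* nsplits;      // nsplits[batch]
    const long_index_t* offsets; // offsets[batch + 1]
    AccDataType* dq_acc;         // compact accumulation buffer
};

// Stand-in for the host-to-device copy of the head segment.
inline void write_head(void* ws,
                       const std::vector<index_t>& nsplits,
                       const std::vector<long_index_t>& offsets)
{
    char* base = static_cast<char*>(ws);
    std::memcpy(base, nsplits.data(), sizeof(index_t) * nsplits.size());
    std::memcpy(base + align_up(sizeof(index_t) * nsplits.size(), 16),
                offsets.data(),
                sizeof(long_index_t) * offsets.size());
}

// Split one workspace pointer into its three regions (deterministic group mode).
inline workspace_view split_workspace(void* ws, std::size_t batch)
{
    char* base                     = static_cast<char*>(ws);
    const std::size_t nsplits_size = align_up(sizeof(index_t) * batch, 16);
    const std::size_t offsets_size = align_up(sizeof(long_index_t) * (batch + 1), 16);
    return {reinterpret_cast<const index_t*>(base),
            reinterpret_cast<const long_index_t*>(base + nsplits_size),
            reinterpret_cast<AccDataType*>(base + nsplits_size + offsets_size)};
}

// Flat index of element (batch i, head h, split s, row r, col c) in the
// per-batch layout [nhead, nsplits_i, seqq_i, hdim_q].
inline long_index_t dq_acc_index(const workspace_view& v,
                                 long_index_t i, long_index_t h, long_index_t s,
                                 long_index_t r, long_index_t c,
                                 long_index_t seqq_i, long_index_t hdim_q)
{
    const long_index_t nsplits_i = v.nsplits[i];
    return v.offsets[i] + ((h * nsplits_i + s) * seqq_i + r) * hdim_q + c;
}
```

Since `nsplits_i` is read from the workspace at run time, the indexing code has no compile-time dependence on the tile width, which is the property that lets one convert kernel serve several tile shapes.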
#### Workspace size by scenario
| Scenario | `workspace_size` | Notes |
|----------|-----------------|-------|
| **kUseQrQtrDorPipeline** (any mode) | `0` | Writes dq directly; no acc buffer; `PrepareWorkspaceHost` returns 0 |
| **Non-deterministic + batch mode** | `> 0` | nsplits[1]=1; dq_acc used for atomic add; `workspace_size = host_ws_size + batch*nhead*seqlen_q*hdim_q*ebytes` |
| **Non-deterministic + group mode** | `> 0` | nsplits[1]=1; dq_acc contiguous layout; `workspace_size = host_ws_size + nhead*seqstart_qs[batch]*hdim_q*ebytes` |
| **Deterministic + group mode** | `> 0` | nsplits[batch], offsets[batch+1], compact dq_acc; nsplits_i computed independently per batch |
| **Deterministic + batch mode persistent** | `> 0` | nsplits[1] (uniform across batches); dq_acc `batch*nhead*nsplits*seqlen_q*hdim_q` |
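
The `dq_acc` byte counts in the table reduce to two small formulas, sketched here with hypothetical helper names. `ebytes` is assumed to mean the size in bytes of one accumulator element, and `seqstart_qs` is assumed to be available as a host-side vector of length `batch + 1`.

```cpp
#include <cstddef>
#include <vector>

// Batch mode: batch * nhead * nsplits * seqlen_q * hdim_q elements.
// Pass nsplits = 1 for the non-deterministic case.
inline std::size_t dq_acc_bytes_batch_mode(std::size_t batch,
                                           std::size_t nhead,
                                           std::size_t nsplits,
                                           std::size_t seqlen_q,
                                           std::size_t hdim_q,
                                           std::size_t ebytes)
{
    return batch * nhead * nsplits * seqlen_q * hdim_q * ebytes;
}

// Non-deterministic group mode: nhead * seqstart_qs[batch] * hdim_q elements,
// where seqstart_qs[batch] is the total physical seqlen_q over all batches.
inline std::size_t dq_acc_bytes_group_mode(const std::vector<int>& seqstart_qs,
                                           std::size_t nhead,
                                           std::size_t hdim_q,
                                           std::size_t ebytes)
{
    const std::size_t total_q = static_cast<std::size_t>(seqstart_qs.back());
    return nhead * total_q * hdim_q * ebytes;
}
```

Adding `host_ws_size` to either result gives the `workspace_size` listed for the corresponding row.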
**NeedsZeroDqAcc** (determines whether `PrepareWorkspaceDevice` calls
`hipMemset`):
- Persistent kernel (deterministic batch mode) or non-deterministic:
**must zero** (atomic add requires zero initialization)
- Deterministic group mode + no mask: **no zeroing needed** (every tile
writes its full region)
- Deterministic + with mask: **must zero** (some blocks are skipped,
leaving uninitialized tiles that would contribute to the reduction)
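
The three rules above can be condensed into one predicate. This is a sketch of the decision logic as described, not the kernel's actual `NeedsZeroDqAcc` signature; the function name and parameters are hypothetical, and deterministic batch mode is treated as the persistent case, as the first rule suggests.

```cpp
// Returns true when the dq_acc region must be zeroed before the kernel runs.
inline bool needs_zero_dq_acc(bool is_deterministic, bool is_persistent, bool has_mask)
{
    if(!is_deterministic)
        return true; // atomic add accumulates into the buffer
    if(is_persistent)
        return true; // persistent kernel (deterministic batch mode) also uses atomic add
    if(has_mask)
        return true; // skipped blocks would leave stale tiles in the reduction
    return false;    // deterministic group mode, no mask: every tile is fully written
}
```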
#### Caller usage
```cpp
// 1. Create launcher (traits include seqstart_qs/ks pointers; workspace_size is computed during construction)
fmha_bwd_launcher launcher(fmha_traits);
// 2. Read launcher.workspace_size directly
const auto ws_size = launcher.workspace_size;
// 3. Allocate a single GPU workspace
ck_tile::DeviceMem ws_buf(ws_size);
// 4. Copy nsplits/offsets to GPU head and zero dq_acc if required
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
// 5. Build args with a single workspace pointer; the kernel splits it internally
fmha_bwd_args args{
    ...,
    ws_size > 0 ? ws_buf.GetDeviceBuffer() : nullptr, // workspace_ptr
};
launcher(args, stream_config);
```