[rocm-libraries] ROCm/rocm-libraries#6152 (commit 36b016a)

[CK_TILE] Use Unified Workspace for FMHA BWD

## Motivation
`dq_acc` is the intermediate accumulation buffer used by the FMHA backward
pass in deterministic mode. The current implementation allocates it as
a **single rectangular tensor**:

```
shape = [shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q]
```

where `nsplits = launcher.dq_acc_splits` (a single scalar), computed
from `max_seqlen_k` and shared across all batches.

### Problems

1. **Memory waste**: In group mode, each batch may have a different
`seqlen_k`, but `nsplits` is computed from `max_seqlen_k`, causing
batches with shorter `seqlen_k` to over-allocate in the split dimension.

2. **Interface coupling**: `fmha_bwd_args` exposes internal layout
details such as `stride_dq_acc`, `nhead_stride_dq_acc`,
`batch_stride_dq_acc`, and `split_stride_dq_acc`. The caller is
responsible for computing these strides, but this logic belongs inside
the kernel.
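Problem 1 can be made concrete with a small host-side sketch. The `kN0 = 128` block size and the `ceil(seqlen_k / kN0)` nsplits formula below are assumptions for illustration only, not the kernel's actual computation:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Assumed for illustration: nsplits for a given seqlen_k is ceil(seqlen_k / kN0).
constexpr int kN0 = 128; // hypothetical k-block size

int nsplits_for(int seqlen_k) { return (seqlen_k + kN0 - 1) / kN0; }

// Rectangular layout: every batch pays for nsplits(max_seqlen_k).
std::size_t rectangular_elems(const std::vector<int>& seqlen_ks,
                              int seqlen_q, int nhead, int hdim_q)
{
    const int max_k = *std::max_element(seqlen_ks.begin(), seqlen_ks.end());
    return seqlen_ks.size() * std::size_t(nhead) * nsplits_for(max_k) * seqlen_q * hdim_q;
}

// Compact layout: each batch pays only for its own nsplits.
std::size_t compact_elems(const std::vector<int>& seqlen_ks,
                          int seqlen_q, int nhead, int hdim_q)
{
    std::size_t total = 0;
    for (int k : seqlen_ks)
        total += std::size_t(nhead) * nsplits_for(k) * seqlen_q * hdim_q;
    return total;
}
```

With seqlen_ks of {128, 512}, the short batch allocates 4x fewer split planes under the compact layout than under the rectangular one.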

### Goals

1. Switch `dq_acc` buffer to a **compact layout**: batches are
concatenated contiguously, with each batch occupying `nhead * nsplits_i
* seqq_i * hdim_q` elements (nhead outermost).
2. **Remove all `*_stride_dq_acc` fields** from `fmha_bwd_args`,
replacing them with a single `workspace_ptr`; the kernel splits this
internally using a fixed layout.
3. `fmha_bwd_launcher` provides a **workspace management interface**:
the caller only needs to allocate GPU memory and call
`prepare_workspace()` — no layout computation required.
4. **Isolate kernel internals from the caller API**: the `dq_acc` layout
(nsplits, strides, buffer size) is determined entirely inside the
launcher/kernel. Future changes to block shape, pipeline type, or
persistent kernel strategy require no modifications to the caller's
`fmha_bwd_args` or workspace allocation logic.

## Technical Details

### Interface Design

#### New fields in `fmha_bwd_traits`

```cpp
struct fmha_bwd_traits
{
    int seqlen_q;
    int seqlen_k;
    int batch;
    int max_seqlen_q;
    int max_seqlen_k;
    int hdim_q;
    int hdim_v;
    int nhead_q;
    int nhead_k;
    std::string data_type;
    bool is_group_mode;
    mask_enum mask_type;
    bias_enum bias_type;
    bool has_dbias;
    bool has_dropout;
    bool is_store_randval;
    bool is_deterministic;
    // New: cumulative physical seqlen pointers for group mode (pass nullptr for batch mode).
    // seqstart_qs[i+1] - seqstart_qs[i] = physical seqlen_q of batch i (including padding); length = batch+1
    // seqstart_ks[i+1] - seqstart_ks[i] = physical seqlen_k of batch i (including padding); length = batch+1
    const int* seqstart_qs = nullptr;
    const int* seqstart_ks = nullptr;
};
```
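As a small illustration of the seqstart convention (the helper name below is hypothetical, not part of the API): the physical length of batch `i` is the difference of adjacent entries.

```cpp
// Hypothetical helper illustrating the seqstart_qs/seqstart_ks convention:
// entry i+1 minus entry i is the physical (padded) sequence length of batch i.
inline int physical_seqlen(const int* seqstart, int batch_idx)
{
    return seqstart[batch_idx + 1] - seqstart[batch_idx];
}
```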

#### `fmha_bwd_launcher` actual structure

```cpp
struct fmha_bwd_launcher
{
    std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};

    // Total workspace size in bytes (host_ws_size + device_ws_size), computed by init().
    // Zero for kUseQrQtrDorPipeline (writes dq directly, no acc buffer needed).
    size_t workspace_size = 0;

    fmha_bwd_launcher(const fmha_bwd_traits&);

    // Copies auxiliary data (nsplits[], offsets[]) via hipMemcpy to the head of the GPU workspace,
    // and zeros the dq_acc buffer portion (tail of workspace) if required.
    // The memory pointed to by device_ws must be >= workspace_size bytes.
    std::function<void(void* device_ws)> prepare_workspace{};

    template <typename... Args>
    float operator()(Args&&... args) const { return run(std::forward<Args>(args)...); }

private:
    size_t host_ws_size   = 0;  // CPU workspace size (nsplits[] + offsets[] arrays)
    size_t device_ws_size = 0;  // GPU-only data size (dq_acc buffer)
    std::unique_ptr<char[]> ws_host;  // host-side workspace buffer

public:
    template <typename T0, typename T1, typename T2, typename Arch>
    void init(const fmha_bwd_traits& traits);
};
```

The `init<>()` template method (invoked by codegen dispatch branches as
`this->init<...>(t)`) is responsible for:
1. Setting the `run` lambda
2. Calling `FmhaBwdDQDKDVKernel::GetWorkspaceHostSize(batch)` to obtain
`host_ws_size`
3. Allocating `ws_host` (host memory)
4. Calling `FmhaBwdDQDKDVKernel::PrepareWorkspaceHost(ws_host.get(),
...)` to fill nsplits/offsets; return value is `device_ws_size`
5. `workspace_size = host_ws_size + device_ws_size`
6. Setting the `prepare_workspace` lambda (captures `this`, calls
`PrepareWorkspaceDevice`)

When no kernel matches the given traits, both `run` and
`prepare_workspace` are initialized to default lambdas that print a
warning to `std::cerr` and return gracefully (no exception).

#### Workspace overall layout

The workspace is managed by `FmhaBwdWorkspaceManager` and consists of
two segments:

```
Offset 0 (CPU-prepared segment, host_ws_size bytes; also hipMemcpy'd to the head of GPU workspace):
  index_t nsplits[batch or 1]       — per-batch nsplits array
                                      group mode: batch elements
                                      batch mode / non-deterministic: 1 element
  [group mode only] long_index_t dq_acc_offsets[batch+1]
                                    — per-batch element offset (inclusive prefix sum)
                                      offsets[0]=0, offsets[i+1] = offsets[i] + nhead*nsplits_i*seqq_i*hdim_q

Offset host_ws_size (device data segment, device_ws_size bytes):
  AccDataType dq_acc[total_elements] — compact dq_acc buffer (zeroed if required)
                                       total_elements = sum_i(nhead * nsplits_i * seqq_i * hdim_q)
                                       layout within each batch: [nhead, nsplits_i, seqq_i, hdim_q]
                                       note: seqq_i uses the physical length (including padding)
```

Alignment constant (`ALIGNMENT = 16`):
```
nsplits_size  = align_up(sizeof(index_t) * N, 16)          // N = batch (group) or 1 (batch/non-det)
offsets_size  = align_up(sizeof(long_index_t) * (batch+1), 16)  // group mode only
host_ws_size  = nsplits_size + offsets_size
dq_acc_offset = host_ws_size  // GetDqAccDataOffset(batch)
```
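A hedged host-side sketch of these size and offset rules follows. The function names, the `int32_t`/`int64_t` widths standing in for `index_t`/`long_index_t`, and the assumption that `dq_acc_offsets[]` is emitted whenever group mode is active are illustrative; the real logic lives in `GetWorkspaceHostSize`/`PrepareWorkspaceHost`:

```cpp
#include <cstddef>
#include <cstdint>

constexpr std::size_t kAlignment = 16;

std::size_t align_up(std::size_t x) { return (x + kAlignment - 1) / kAlignment * kAlignment; }

// Sketch of the host segment size: nsplits[] is always present,
// dq_acc_offsets[] only in group mode (assumption based on the layout above).
std::size_t host_ws_size(int batch, bool is_group_mode, bool is_deterministic)
{
    const int n = (is_group_mode && is_deterministic) ? batch : 1;   // nsplits entries
    std::size_t size = align_up(sizeof(std::int32_t) * n);           // nsplits[]
    if (is_group_mode)
        size += align_up(sizeof(std::int64_t) * (batch + 1));        // dq_acc_offsets[]
    return size;
}

// Sketch of the inclusive prefix sum that fills dq_acc_offsets[] (group mode):
// offsets[i+1] = offsets[i] + nhead * nsplits_i * seqq_i * hdim_q.
void fill_dq_acc_offsets(std::int64_t* offsets, const std::int32_t* nsplits,
                         const int* seqstart_qs, int batch, int nhead, int hdim_q)
{
    offsets[0] = 0;
    for (int i = 0; i < batch; ++i) {
        const int seqq_i = seqstart_qs[i + 1] - seqstart_qs[i]; // physical length, incl. padding
        offsets[i + 1] = offsets[i] + std::int64_t(nhead) * nsplits[i] * seqq_i * hdim_q;
    }
}
```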

**Key benefits**:
- The kernel reads nsplits/offsets directly from the workspace head — no
device-side recomputation.
- `FmhaBwdConvertQGradKernel` is completely decoupled from the pipeline
block shape (`kN0`): nsplits is read from `nsplits_ptr`, `kN0` is no
longer a template parameter, and multiple dq_dk_dv tiles with different
`F_bn0` values now share a single convert_dq kernel instance (under
receipt 1/2, deterministic convert_dq kernel count drops from ~300 to
60).
- nsplits/offsets are computed on the host and transferred in one
`hipMemcpy`; the dq_acc buffer follows immediately, at the offset given
by `GetDqAccDataOffset`.

#### Workspace size by scenario

| Scenario | `workspace_size` | Notes |
|----------|------------------|-------|
| **kUseQrQtrDorPipeline** (any mode) | `0` | Writes dq directly; no acc buffer; `PrepareWorkspaceHost` returns 0 |
| **Non-deterministic + batch mode** | `> 0` | nsplits[1]=1; dq_acc used for atomic add; `workspace_size = host_ws_size + batch*nhead*seqlen_q*hdim_q*ebytes` |
| **Non-deterministic + group mode** | `> 0` | nsplits[1]=1; dq_acc contiguous layout; `workspace_size = host_ws_size + nhead*seqstart_qs[batch]*hdim_q*ebytes` |
| **Deterministic + group mode** | `> 0` | nsplits[batch], offsets[batch+1], compact dq_acc; nsplits_i computed independently per batch |
| **Deterministic + batch mode persistent** | `> 0` | nsplits[1] (uniform across batches); dq_acc `batch*nhead*nsplits*seqlen_q*hdim_q` |

**NeedsZeroDqAcc** (determines whether `PrepareWorkspaceDevice` calls
`hipMemset`):
- Persistent kernel (deterministic batch mode) or non-deterministic:
**must zero** (atomic add requires zero initialization)
- Deterministic group mode + no mask: **no zeroing needed** (every tile
writes its full region)
- Deterministic + with mask: **must zero** (some blocks are skipped,
leaving uninitialized tiles that would contribute to the reduction)
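The three rules above reduce to a small predicate. A sketch, with parameter names assumed for illustration:

```cpp
// Sketch of the NeedsZeroDqAcc decision described above (flag names are illustrative).
bool needs_zero_dq_acc(bool is_deterministic, bool is_persistent, bool has_mask)
{
    if (!is_deterministic || is_persistent)
        return true;   // atomic-add accumulation requires a zeroed buffer
    return has_mask;   // masked-out blocks skip their tiles, so stale data must be zeroed
}
```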

#### Caller usage

```cpp
// 1. Create launcher (traits include seqstart_qs/ks pointers; workspace_size is computed during construction)
fmha_bwd_launcher launcher(fmha_traits);

// 2. Read launcher.workspace_size directly
const auto ws_size = launcher.workspace_size;

// 3. Allocate a single GPU workspace
ck_tile::DeviceMem ws_buf(ws_size);

// 4. Copy nsplits/offsets to GPU head and zero dq_acc if required
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());

// 5. Build args with a single workspace pointer; the kernel splits it internally
fmha_bwd_args args{
    ...,
    ws_size > 0 ? ws_buf.GetDeviceBuffer() : nullptr,  // workspace_ptr
};
launcher(args, stream_config);
```
commit 207a95d5e4 (parent 250c29f914)
Author: Yi DING
Date: 2026-05-07 02:23:28 +00:00
Committed by: assistant-librarian[bot]
6 changed files with 629 additions and 394 deletions


@@ -169,10 +169,21 @@ int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
}}
template <>
int fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(const fmha_bwd_traits& t)
size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(int batch_size)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::GetDqAccSplits(t.batch, t.nhead_q, t.max_seqlen_k);
return k_::GetWorkspaceHostSize(batch_size);
}}
template <>
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
void* cpu_ws, ck_tile::index_t batch_size, ck_tile::index_t hdim_q,
ck_tile::index_t nhead_q, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k,
const ck_tile::index_t* seqstart_qs, const ck_tile::index_t* seqstart_ks)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::PrepareWorkspaceHost(
cpu_ws, batch_size, hdim_q, nhead_q, seqlen_q, seqlen_k, seqstart_qs, seqstart_ks);
}}
template <>
@@ -197,9 +208,6 @@ FMHA_BWD_API = """
fmha_bwd_launcher::fmha_bwd_launcher(const fmha_bwd_traits& t){{
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
{F_launcher}
run = [](fmha_bwd_args, const ck_tile::stream_config&) {{ return -1.0f; }};
dq_acc_splits = 1;
needs_zero_dq_acc = false;
}}
@@ -228,7 +236,7 @@ FMHA_BWD_API_INNER_DISPATCH_COMMON = """{F_if}((t.is_group_mode == {F_mode}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_max_seq_q_cond}{F_cond_extra}) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}>;
"""
FMHA_BWD_API_INNER_DISPATCH_RUN = """
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
@@ -236,11 +244,7 @@ FMHA_BWD_API_INNER_DISPATCH_RUN = """
}}
"""
FMHA_BWD_API_INNER_DISPATCH_LAUNCHER = """
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {{
return fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
}};
dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_, {F_arch.tag}>(t);
needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<dq_dk_dv_trait_, {F_arch.tag}>();
this->init<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(t);
return;
}}
"""
@@ -650,7 +654,6 @@ using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
/* BlockSize = */ 256,
{F_bm0},
{F_bn0},
{F_hdim},
{F_mode},
{F_deterministic},
@@ -667,8 +670,7 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_mode},
{F_spad},
{F_dpad},
{F_deterministic},
{F_bn0}>;
{F_deterministic}>;
template <>
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
@@ -712,7 +714,6 @@ class FmhaBwdConvertQGradKernel:
F_hdim: int # hdim
F_dtype: str # data type
F_bm0: int # tile size along q seqlen (block size)
F_bn0: int # tile size along k seqlen
F_spad: str # true/false
F_dpad: str #
F_mode: str # value from MODE_MAP
@@ -728,7 +729,6 @@ class FmhaBwdConvertQGradKernel:
F_hdim=self.F_hdim,
F_dtype=BWD_DTYPE_MAP[self.F_dtype],
F_bm0=self.F_bm0,
F_bn0=self.F_bn0,
F_spad=BOOL_MAP[self.F_spad],
F_dpad=BOOL_MAP[self.F_dpad],
F_mode=MODE_MAP[self.F_mode],
@@ -749,7 +749,7 @@ class FmhaBwdConvertQGradKernel:
return n
pn = pad_name()
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}_{self.F_mode}_o{self.F_occupancy}"
if pn != "":
n += f"_{pn}"
else:
@@ -838,10 +838,6 @@ class FmhaBwdApiTrait:
else:
return ""
@property
def convert_dq_bn0(self) -> int:
return self.tile.F_bn0 if self.deterministic == "t" else 0
@property
def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
@@ -896,7 +892,6 @@ class FmhaBwdApiTrait:
F_hdim=self.hdim,
F_dtype=self.dtype,
F_bm0=M0_1D,
F_bn0=self.convert_dq_bn0,
F_spad=self.spad1d,
F_dpad=F_dpad,
F_mode=self.mode,
@@ -949,7 +944,6 @@ class FmhaBwdApiPool:
F_max_seq_q_cond=trait.max_seq_q_cond,
F_cond_extra=trait.extra_cond,
F_bn0=trait.tile.F_bn0,
F_convert_dq_bn0=trait.convert_dq_bn0,
)
inners += inners_common + FMHA_BWD_API_INNER_DISPATCH_RUN.format(
F_arch=trait.arch,


@@ -11,11 +11,12 @@
#include "mask.hpp"
#include "bias.hpp"
#include <functional>
#include <iostream>
#include <memory>
#include <type_traits>
#include <utility>
#include <variant>
#include <iostream>
#include <functional>
struct FmhaBwdFp32
{
@@ -115,7 +116,7 @@ struct fmha_bwd_args
void* dk_ptr;
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
void* workspace_ptr;
const void*
sink_ptr; // sink scores [batch, nhead] in log-space (LSEDataType); nullptr disables sink
void* d_sink_ptr; // sink gradient output [nhead] (LSEDataType); nullptr disables sink gradient
@@ -128,13 +129,13 @@ struct fmha_bwd_args
// With padding:
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
// lengths. [array size: batch + 1]
// lengths. [array size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
// sequence. [array size: batch]
// sequence. [array size: batch]
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// sequence lengths. [array size: batch + 1]
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
// exclusive. Use one set, not both.
// exclusive. Use one set, not both.
//
// Batch mode:
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
@@ -181,7 +182,6 @@ struct fmha_bwd_args
ck_tile::index_t stride_o;
ck_tile::index_t stride_randval;
ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv;
@@ -194,7 +194,6 @@ struct fmha_bwd_args
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::long_index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
@@ -207,12 +206,10 @@ struct fmha_bwd_args
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::long_index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias;
ck_tile::index_t split_stride_dq_acc;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
@@ -227,12 +224,6 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
constexpr bool dq_uss_acc = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0;
const auto dq_ptr = dq_uss_acc ? args.dq_acc_ptr : args.dq_ptr;
const auto stride_dq = dq_uss_acc ? args.stride_dq_acc : args.stride_dq;
const auto nhead_stride_dq = dq_uss_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq;
const auto batch_stride_dq = dq_uss_acc ? args.batch_stride_dq_acc : args.batch_stride_dq;
// create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{
@@ -244,10 +235,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
dq_ptr,
args.workspace_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
@@ -266,7 +258,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_do,
stride_dq,
args.stride_dq,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
@@ -277,11 +269,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
nhead_stride_dq,
args.nhead_stride_dq,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
@@ -298,10 +289,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
dq_ptr,
args.workspace_ptr,
args.seqlen_q,
args.seqlen_k,
args.batch,
@@ -316,7 +308,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_do,
stride_dq,
args.stride_dq,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
@@ -327,7 +319,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
nhead_stride_dq,
args.nhead_stride_dq,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
@@ -338,11 +330,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.batch_stride_randval,
args.batch_stride_do,
args.batch_stride_lsed,
batch_stride_dq,
args.batch_stride_dq,
args.batch_stride_dk,
args.batch_stride_dv,
args.batch_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
@@ -414,8 +405,10 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
// create group mode kernel arguments
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
{
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
return FmhaBwdConvertQGradKernel::MakeKargs(args.workspace_ptr,
args.dq_ptr,
args.batch,
args.nhead_q,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
@@ -424,27 +417,20 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.cu_seqlen_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.split_stride_dq_acc);
args.nhead_stride_dq);
}
else
{ // create batch mode kernel arguments
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
return FmhaBwdConvertQGradKernel::MakeKargs(args.workspace_ptr,
args.dq_ptr,
args.batch,
args.nhead_q,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.batch_stride_dq,
args.batch_stride_dq_acc,
args.split_stride_dq_acc,
args.batch,
args.nhead_q);
args.batch_stride_dq);
}
}();
@@ -482,7 +468,16 @@ template <typename Traits_, typename Arch = void>
int fmha_bwd_dq_dk_dv_maxq_();
struct fmha_bwd_traits;
template <typename Traits_, typename Arch = void>
int fmha_bwd_dq_dk_dv_dq_acc_splits_(const fmha_bwd_traits& t);
size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_(int batch_size);
template <typename Traits_, typename Arch = void>
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws,
ck_tile::index_t batch_size,
ck_tile::index_t hdim_q,
ck_tile::index_t nhead_q,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
const ck_tile::index_t* seqstart_qs,
const ck_tile::index_t* seqstart_ks);
template <typename Traits_, typename Arch = void>
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_();
@@ -510,8 +505,7 @@ template <ck_tile::index_t HDim_,
bool kIsGroupMode_,
bool kPadS_,
bool kPadD_,
bool kIsDeterministic_,
ck_tile::index_t kN0>
bool kIsDeterministic_>
struct fmha_bwd_convert_dq_traits_
{
};
@@ -545,6 +539,11 @@ struct fmha_bwd_traits
bool has_dropout;
bool is_store_randval;
bool is_deterministic;
// Raw pointers for group mode: cumulative physical seqlen arrays of length batch+1.
// Only need to remain valid during fmha_bwd_launcher construction (i.e. through
// PrepareWorkspaceHost); they are not retained afterward.
const int* seqstart_qs = nullptr;
const int* seqstart_ks = nullptr;
// TODO: padding check is inside this api
};
@@ -585,12 +584,61 @@ float fmha_bwd(const fmha_bwd_traits&, fmha_bwd_args, const ck_tile::stream_conf
struct fmha_bwd_launcher
{
std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};
ck_tile::index_t dq_acc_splits{0};
bool needs_zero_dq_acc{true};
std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{
[](fmha_bwd_args, const ck_tile::stream_config&) {
std::cerr << "fmha_bwd: no kernel found for given traits, skipping run\n";
return -1.0f;
}};
size_t workspace_size = 0;
std::function<void(void*)> prepare_workspace{[](void*) {
std::cerr << "fmha_bwd: no kernel found for given traits, skipping prepare_workspace\n";
}};
fmha_bwd_launcher(const fmha_bwd_traits&);
fmha_bwd_launcher(fmha_bwd_launcher&&) = delete;
fmha_bwd_launcher& operator=(fmha_bwd_launcher&&) = delete;
private:
size_t host_ws_size = 0;
size_t device_ws_size = 0;
std::unique_ptr<char[]> ws_host;
template <typename T0 /*dot_do_o_trait*/,
typename T1 /*dq_dk_dv_trait*/,
typename T2 /*convert_dq_trait*/,
typename Arch>
void init(const fmha_bwd_traits& traits)
{
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {
return fmha_bwd_<T0, T1, T2, Arch>(s, a);
};
host_ws_size = fmha_bwd_dq_dk_dv_dq_ws_host_size_<T1, Arch>(traits.batch);
if(host_ws_size > 0)
{
ws_host = std::make_unique<char[]>(host_ws_size); // TODO: support host mem allocator
device_ws_size = fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<T1, Arch>( //
ws_host.get(),
traits.batch,
traits.hdim_q,
traits.nhead_q,
traits.seqlen_q,
traits.seqlen_k,
traits.seqstart_qs,
traits.seqstart_ks);
}
workspace_size = host_ws_size + device_ws_size;
const bool needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<T1, Arch>();
prepare_workspace = [this, needs_zero_dq_acc](void* device_ws) {
if(host_ws_size > 0)
HIP_CHECK_ERROR(
hipMemcpy(device_ws, ws_host.get(), host_ws_size, hipMemcpyHostToDevice));
if(needs_zero_dq_acc)
HIP_CHECK_ERROR(
hipMemset(static_cast<char*>(device_ws) + host_ws_size, 0, device_ws_size));
};
}
public:
template <typename... Args>
float operator()(Args&&... args) const
{


@@ -9,6 +9,7 @@
#include "ck_tile/utility/json_dump.hpp"
#include <array>
#include <chrono>
#include <cstring>
#include <functional>
#include <numeric>
@@ -243,29 +244,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_ks[0] : seqstart_k_host.back());
const fmha_bwd_traits fmha_traits{
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
max_seqlen_k,
hdim_q,
hdim_v,
nhead,
nhead_k,
data_type,
mode == mode_enum::group,
mask.type,
bias.type,
use_dbias,
p_drop > 0.0f,
s_randval,
deterministic,
};
fmha_bwd_launcher launcher(fmha_traits);
const ck_tile::index_t nsplits = launcher.dq_acc_splits;
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KDataType> k_host(
@@ -318,8 +296,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
{
d_sink_host.ForEach([&](auto& self, auto i) { self(i) = 0; });
}
ck_tile::HostTensor<AccDataType> dq_acc_host(
std::array<ck_tile::index_t, 5>{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q});
if(init_method == "ui" || init_method == "0")
{
@@ -396,7 +372,37 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
const auto t0_launcher = std::chrono::high_resolution_clock::now();
fmha_bwd_launcher launcher(fmha_bwd_traits{
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
max_seqlen_k,
hdim_q,
hdim_v,
nhead,
nhead_k,
data_type,
mode == mode_enum::group,
mask.type,
bias.type,
use_dbias,
p_drop > 0.0f,
s_randval,
deterministic,
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
});
const auto t1_launcher = std::chrono::high_resolution_clock::now();
const double launcher_ctor_ms =
std::chrono::duration<double, std::milli>(t1_launcher - t0_launcher).count();
const size_t ws_size = launcher.workspace_size;
ck_tile::DeviceMem ws_buf(ws_size);
ck_tile::gpu_timer prepare_ws_timer;
prepare_ws_timer.start(nullptr);
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
prepare_ws_timer.stop(nullptr);
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
@@ -433,7 +439,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
// clang-format on
const std::size_t workspace_size_in_megabytes =
ck_tile::integer_divide_ceil(dq_acc_host.get_element_space_size_in_bytes(), 1024 * 1024);
ck_tile::integer_divide_ceil(ws_size, 1024 * 1024);
std::cout << "[" << data_type << "|" << mode << "|" << io_layout(i_perm, o_perm)
<< "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0]
@@ -441,11 +447,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
<< ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop
<< (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval
<< ", deterministic:" << deterministic
<< (deterministic
? std::string(", workspace:") + std::to_string(workspace_size_in_megabytes) +
"MiB|" + std::to_string(nsplits) + "splits"
: "")
<< ", mask:" << mask << std::flush;
<< ", workspace:" << std::to_string(workspace_size_in_megabytes) << "MiB"
<< ", mask:" << mask << ", init:" << launcher_ctor_ms << "ms"
<< ", prws:" << prepare_ws_timer.duration() << "ms" << std::flush;
auto fmha_args = [&]() {
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
@@ -462,7 +466,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
const auto split_stride_dq_acc = (shape_seqlen_q * hdim_q);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
@@ -474,8 +477,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
const auto nhead_stride_dq_acc =
static_cast<ck_tile::long_index_t>(split_stride_dq_acc) * nsplits;
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
@@ -488,7 +489,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
const auto batch_stride_dq_acc = nhead * nhead_stride_dq_acc;
void* ws_ptr = ws_size > 0 ? ws_buf.GetDeviceBuffer() : nullptr;
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
if(drop_prefs)
@@ -518,7 +520,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
dk_buf.GetDeviceBuffer(),
dv_buf.GetDeviceBuffer(),
dbias_buf.GetDeviceBuffer(),
dq_acc_buf.GetDeviceBuffer(),
ws_ptr,
sink_buf.GetDeviceBuffer(),
d_sink_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
@@ -545,8 +547,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
stride_o,
stride_randval,
stride_do,
hdim_q, // stride_dq_acc
stride_q, // stride_dq
stride_q, // stride_dq (same layout as q for dq output)
stride_dk,
stride_dv,
stride_dbias,
@@ -558,7 +559,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_q, // nhead_stride_dq
nhead_stride_k, // nhead_stride_dk
nhead_stride_v, // nhead_stride_dv
@@ -571,12 +571,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_dq_acc,
batch_stride_q, // batch_stride_dq
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
@@ -901,11 +899,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::FillConstant<QGradDataType>{ck_tile::numeric<QGradDataType>::infinity()}(dq_host);
ck_tile::FillConstant<KGradDataType>{ck_tile::numeric<KGradDataType>::infinity()}(dk_host);
ck_tile::FillConstant<VGradDataType>{ck_tile::numeric<VGradDataType>::infinity()}(dv_host);
ck_tile::FillConstant<AccDataType>{ck_tile::numeric<AccDataType>::infinity()}(dq_acc_host);
dq_buf.ToDevice(dq_host.data());
dk_buf.ToDevice(dk_host.data());
dv_buf.ToDevice(dv_host.data());
dq_acc_buf.ToDevice(dq_acc_host.data());
// re-initialize workspace for validation run
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
o_buf.ToDevice(o_host.data());
lse_buf.ToDevice(lse_host.data());
@@ -913,9 +911,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
if(sink_grad)
d_sink_buf.SetZero();
if(launcher.needs_zero_dq_acc)
dq_acc_buf.SetZero();
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
launcher(fmha_args, stream_config_v);