mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
[rocm-libraries] ROCm/rocm-libraries#6152 (commit 36b016a)
[CK_TILE] Use Unified Workspace for FMHA BWD
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
## Motivation
`dq_acc` is the intermediate accumulation buffer used in FMHA backward
pass for deterministic mode. The current implementation allocates it as
a **single rectangular tensor**:
```
shape = [shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q]
```
where `nsplits = launcher.dq_acc_splits` (a single scalar), computed
from `max_seqlen_k` and shared across all batches.
### Problems
1. **Memory waste**: In group mode, each batch may have a different
`seqlen_k`, but `nsplits` is computed from `max_seqlen_k`, causing
batches with shorter `seqlen_k` to over-allocate in the split dimension.
2. **Interface coupling**: `fmha_bwd_args` exposes internal layout
details such as `stride_dq_acc`, `nhead_stride_dq_acc`,
`batch_stride_dq_acc`, and `split_stride_dq_acc`. The caller is
responsible for computing these strides, but this logic belongs inside
the kernel.
### Goals
1. Switch `dq_acc` buffer to a **compact layout**: batches are
concatenated contiguously, with each batch occupying `nhead * nsplits_i
* seqq_i * hdim_q` elements (nhead outermost).
2. **Remove all `*_stride_dq_acc` fields** from `fmha_bwd_args`,
replacing them with a single `workspace_ptr`; the kernel splits this
internally using a fixed layout.
4. `fmha_bwd_launcher` provides a **workspace management interface**:
the caller only needs to allocate GPU memory and call
`prepare_workspace()` — no layout computation required.
5. **Isolate kernel internals from the caller API**: the `dq_acc` layout
(nsplits, strides, buffer size) is determined entirely inside the
launcher/kernel. Future changes to block shape, pipeline type, or
persistent kernel strategy require no modifications to the caller's
`fmha_bwd_args` or workspace allocation logic.
## Technical Details
### Interface Design
#### New fields in `fmha_bwd_traits`
```cpp
struct fmha_bwd_traits
{
int seqlen_q;
int seqlen_k;
int batch;
int max_seqlen_q;
int max_seqlen_k;
int hdim_q;
int hdim_v;
int nhead_q;
int nhead_k;
std::string data_type;
bool is_group_mode;
mask_enum mask_type;
bias_enum bias_type;
bool has_dbias;
bool has_dropout;
bool is_store_randval;
bool is_deterministic;
// New: cumulative physical seqlen pointers for group mode (pass nullptr for batch mode).
// seqstart_qs[i+1] - seqstart_qs[i] = physical seqlen_q of batch i (including padding); length = batch+1
// seqstart_ks[i+1] - seqstart_ks[i] = physical seqlen_k of batch i (including padding); length = batch+1
const int* seqstart_qs = nullptr;
const int* seqstart_ks = nullptr;
};
```
#### `fmha_bwd_launcher` actual structure
```cpp
struct fmha_bwd_launcher
{
std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};
// Total workspace size in bytes (host_ws_size + device_ws_size), computed by init().
// Zero for kUseQrQtrDorPipeline (writes dq directly, no acc buffer needed).
size_t workspace_size = 0;
fmha_bwd_launcher(const fmha_bwd_traits&);
// Copies auxiliary data (nsplits[], offsets[]) via hipMemcpy to the head of the GPU workspace,
// and zeros the dq_acc buffer portion (tail of workspace) if required.
// The memory pointed to by device_ws must be >= workspace_size bytes.
std::function<void(void* device_ws)> prepare_workspace{};
template <typename... Args>
float operator()(Args&&... args) const { return run(std::forward<Args>(args)...); }
private:
size_t host_ws_size = 0; // CPU workspace size (nsplits[] + offsets[] arrays)
size_t device_ws_size = 0; // GPU-only data size (dq_acc buffer)
std::unique_ptr<char[]> ws_host; // host-side workspace buffer
public:
template <typename T0, typename T1, typename T2, typename Arch>
void init(const fmha_bwd_traits& traits);
};
```
The `init<>()` template method (invoked by codegen dispatch branches as
`this->init<...>(t)`) is responsible for:
1. Setting the `run` lambda
2. Calling `FmhaBwdDQDKDVKernel::GetWorkspaceHostSize(batch)` to obtain
`host_ws_size`
3. Allocating `ws_host` (host memory)
4. Calling `FmhaBwdDQDKDVKernel::PrepareWorkspaceHost(ws_host.get(),
...)` to fill nsplits/offsets; return value is `device_ws_size`
5. `workspace_size = host_ws_size + device_ws_size`
6. Setting the `prepare_workspace` lambda (captures `this`, calls
`PrepareWorkspaceDevice`)
When no kernel matches the given traits, both `run` and
`prepare_workspace` are initialized to default lambdas that print a
warning to `std::cerr` and return gracefully (no exception).
#### Workspace overall layout
The workspace is managed by `FmhaBwdWorkspaceManager` and consists of
two segments:
```
Offset 0 (CPU-prepared segment, host_ws_size bytes; also hipMemcpy'd to the head of GPU workspace):
index_t nsplits[batch or 1] — per-batch nsplits array
group mode: batch elements
batch mode / non-deterministic: 1 element
[group mode only] long_index_t dq_acc_offsets[batch+1]
— per-batch element offset (inclusive prefix sum)
offsets[0]=0, offsets[i+1] = offsets[i] + nhead*nsplits_i*seqq_i*hdim_q
Offset host_ws_size (device data segment, device_ws_size bytes):
AccDataType dq_acc[total_elements] — compact dq_acc buffer (zeroed if required)
total_elements = sum_i(nhead * nsplits_i * seqq_i * hdim_q)
layout within each batch: [nhead, nsplits_i, seqq_i, hdim_q]
note: seqq_i uses the physical length (including padding)
```
Alignment constant (`ALIGNMENT = 16`):
```
nsplits_size = align_up(sizeof(index_t) * N, 16) // N = batch (group) or 1 (batch/non-det)
offsets_size = align_up(sizeof(long_index_t) * (batch+1), 16) // group mode only
host_ws_size = nsplits_size + offsets_size
dq_acc_offset = host_ws_size // GetDqAccDataOffset(batch)
```
**Key benefits**:
- The kernel reads nsplits/offsets directly from the workspace head — no
device-side recomputation.
- `FmhaBwdConvertQGradKernel` is completely decoupled from the pipeline
block shape (`kN0`): nsplits is read from `nsplits_ptr`, `kN0` is no
longer a template parameter, and multiple dq_dk_dv tiles with different
`F_bn0` values now share a single convert_dq kernel instance (under
receipt 1/2, deterministic convert_dq kernel count drops from ~300 to
60).
- nsplits/offsets are computed on the host and transferred in one
`hipMemcpy`; the dq_acc buffer follows immediately, at the offset given
by `GetDqAccDataOffset`.
#### Workspace size by scenario
| Scenario | `workspace_size` | Notes |
|----------|-----------------|-------|
| **kUseQrQtrDorPipeline** (any mode) | `0` | Writes dq directly; no acc
buffer; `PrepareWorkspaceHost` returns 0 |
| **Non-deterministic + batch mode** | `> 0` | nsplits[1]=1; dq_acc used
for atomic add; `workspace_size = host_ws_size +
batch*nhead*seqlen_q*hdim_q*ebytes` |
| **Non-deterministic + group mode** | `> 0` | nsplits[1]=1; dq_acc
contiguous layout; `workspace_size = host_ws_size +
nhead*seqstart_qs[batch]*hdim_q*ebytes` |
| **Deterministic + group mode** | `> 0` | nsplits[batch],
offsets[batch+1], compact dq_acc; nsplits_i computed independently per
batch |
| **Deterministic + batch mode persistent** | `> 0` | nsplits[1]
(uniform across batches); dq_acc `batch*nhead*nsplits*seqlen_q*hdim_q` |
**NeedsZeroDqAcc** (determines whether `PrepareWorkspaceDevice` calls
`hipMemset`):
- Persistent kernel (deterministic batch mode) or non-deterministic:
**must zero** (atomic add requires zero initialization)
- Deterministic group mode + no mask: **no zeroing needed** (every tile
writes its full region)
- Deterministic + with mask: **must zero** (some blocks are skipped,
leaving uninitialized tiles that would contribute to the reduction)
#### Caller usage
```cpp
// 1. Create launcher (traits include seqstart_qs/ks pointers; workspace_size is computed during construction)
fmha_bwd_launcher launcher(fmha_traits);
// 2. Read launcher.workspace_size directly
const auto ws_size = launcher.workspace_size;
// 3. Allocate a single GPU workspace
ck_tile::DeviceMem ws_buf(ws_size);
// 4. Copy nsplits/offsets to GPU head and zero dq_acc if required
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
// 5. Build args with a single workspace pointer; the kernel splits it internally
fmha_bwd_args args{
...,
ws_size > 0 ? ws_buf.GetDeviceBuffer() : nullptr, // workspace_ptr
};
launcher(args, stream_config);
```
This commit is contained in:
committed by
assistant-librarian[bot]
parent
250c29f914
commit
207a95d5e4
@@ -169,10 +169,21 @@ int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
|
||||
}}
|
||||
|
||||
template <>
|
||||
int fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(const fmha_bwd_traits& t)
|
||||
size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(int batch_size)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::GetDqAccSplits(t.batch, t.nhead_q, t.max_seqlen_k);
|
||||
return k_::GetWorkspaceHostSize(batch_size);
|
||||
}}
|
||||
|
||||
template <>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
|
||||
void* cpu_ws, ck_tile::index_t batch_size, ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t nhead_q, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k,
|
||||
const ck_tile::index_t* seqstart_qs, const ck_tile::index_t* seqstart_ks)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::PrepareWorkspaceHost(
|
||||
cpu_ws, batch_size, hdim_q, nhead_q, seqlen_q, seqlen_k, seqstart_qs, seqstart_ks);
|
||||
}}
|
||||
|
||||
template <>
|
||||
@@ -197,9 +208,6 @@ FMHA_BWD_API = """
|
||||
fmha_bwd_launcher::fmha_bwd_launcher(const fmha_bwd_traits& t){{
|
||||
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
|
||||
{F_launcher}
|
||||
run = [](fmha_bwd_args, const ck_tile::stream_config&) {{ return -1.0f; }};
|
||||
dq_acc_splits = 1;
|
||||
needs_zero_dq_acc = false;
|
||||
}}
|
||||
|
||||
|
||||
@@ -228,7 +236,7 @@ FMHA_BWD_API_INNER_DISPATCH_COMMON = """{F_if}((t.is_group_mode == {F_mode}) &&
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_max_seq_q_cond}{F_cond_extra}) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}>;
|
||||
"""
|
||||
FMHA_BWD_API_INNER_DISPATCH_RUN = """
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
|
||||
@@ -236,11 +244,7 @@ FMHA_BWD_API_INNER_DISPATCH_RUN = """
|
||||
}}
|
||||
"""
|
||||
FMHA_BWD_API_INNER_DISPATCH_LAUNCHER = """
|
||||
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {{
|
||||
return fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
|
||||
}};
|
||||
dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_, {F_arch.tag}>(t);
|
||||
needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<dq_dk_dv_trait_, {F_arch.tag}>();
|
||||
this->init<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(t);
|
||||
return;
|
||||
}}
|
||||
"""
|
||||
@@ -650,7 +654,6 @@ using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
|
||||
/* BlockSize = */ 256,
|
||||
{F_bm0},
|
||||
{F_bn0},
|
||||
{F_hdim},
|
||||
{F_mode},
|
||||
{F_deterministic},
|
||||
@@ -667,8 +670,7 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
|
||||
{F_mode},
|
||||
{F_spad},
|
||||
{F_dpad},
|
||||
{F_deterministic},
|
||||
{F_bn0}>;
|
||||
{F_deterministic}>;
|
||||
|
||||
template <>
|
||||
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
@@ -712,7 +714,6 @@ class FmhaBwdConvertQGradKernel:
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_bm0: int # tile size along q seqlen (block size)
|
||||
F_bn0: int # tile size along k seqlen
|
||||
F_spad: str # true/false
|
||||
F_dpad: str #
|
||||
F_mode: str # value from MODE_MAP
|
||||
@@ -728,7 +729,6 @@ class FmhaBwdConvertQGradKernel:
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=BWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0=self.F_bm0,
|
||||
F_bn0=self.F_bn0,
|
||||
F_spad=BOOL_MAP[self.F_spad],
|
||||
F_dpad=BOOL_MAP[self.F_dpad],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
@@ -749,7 +749,7 @@ class FmhaBwdConvertQGradKernel:
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
|
||||
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}_{self.F_mode}_o{self.F_occupancy}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
else:
|
||||
@@ -838,10 +838,6 @@ class FmhaBwdApiTrait:
|
||||
else:
|
||||
return ""
|
||||
|
||||
@property
|
||||
def convert_dq_bn0(self) -> int:
|
||||
return self.tile.F_bn0 if self.deterministic == "t" else 0
|
||||
|
||||
@property
|
||||
def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
@@ -896,7 +892,6 @@ class FmhaBwdApiTrait:
|
||||
F_hdim=self.hdim,
|
||||
F_dtype=self.dtype,
|
||||
F_bm0=M0_1D,
|
||||
F_bn0=self.convert_dq_bn0,
|
||||
F_spad=self.spad1d,
|
||||
F_dpad=F_dpad,
|
||||
F_mode=self.mode,
|
||||
@@ -949,7 +944,6 @@ class FmhaBwdApiPool:
|
||||
F_max_seq_q_cond=trait.max_seq_q_cond,
|
||||
F_cond_extra=trait.extra_cond,
|
||||
F_bn0=trait.tile.F_bn0,
|
||||
F_convert_dq_bn0=trait.convert_dq_bn0,
|
||||
)
|
||||
inners += inners_common + FMHA_BWD_API_INNER_DISPATCH_RUN.format(
|
||||
F_arch=trait.arch,
|
||||
|
||||
@@ -11,11 +11,12 @@
|
||||
#include "mask.hpp"
|
||||
#include "bias.hpp"
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
|
||||
struct FmhaBwdFp32
|
||||
{
|
||||
@@ -115,7 +116,7 @@ struct fmha_bwd_args
|
||||
void* dk_ptr;
|
||||
void* dv_ptr;
|
||||
void* dbias_ptr;
|
||||
void* dq_acc_ptr;
|
||||
void* workspace_ptr;
|
||||
const void*
|
||||
sink_ptr; // sink scores [batch, nhead] in log-space (LSEDataType); nullptr disables sink
|
||||
void* d_sink_ptr; // sink gradient output [nhead] (LSEDataType); nullptr disables sink gradient
|
||||
@@ -128,13 +129,13 @@ struct fmha_bwd_args
|
||||
// With padding:
|
||||
// Group mode:
|
||||
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
|
||||
// lengths. [array size: batch + 1]
|
||||
// lengths. [array size: batch + 1]
|
||||
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
|
||||
// sequence. [array size: batch]
|
||||
// sequence. [array size: batch]
|
||||
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
|
||||
// sequence lengths. [array size: batch + 1]
|
||||
// sequence lengths. [array size: batch + 1]
|
||||
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
|
||||
// exclusive. Use one set, not both.
|
||||
// exclusive. Use one set, not both.
|
||||
//
|
||||
// Batch mode:
|
||||
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
|
||||
@@ -181,7 +182,6 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_do;
|
||||
ck_tile::index_t stride_dq_acc;
|
||||
ck_tile::index_t stride_dq;
|
||||
ck_tile::index_t stride_dk;
|
||||
ck_tile::index_t stride_dv;
|
||||
@@ -194,7 +194,6 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
ck_tile::long_index_t nhead_stride_dq_acc;
|
||||
ck_tile::index_t nhead_stride_dq;
|
||||
ck_tile::index_t nhead_stride_dk;
|
||||
ck_tile::index_t nhead_stride_dv;
|
||||
@@ -207,12 +206,10 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
ck_tile::long_index_t batch_stride_dq_acc;
|
||||
ck_tile::index_t batch_stride_dq;
|
||||
ck_tile::index_t batch_stride_dk;
|
||||
ck_tile::index_t batch_stride_dv;
|
||||
ck_tile::index_t batch_stride_dbias;
|
||||
ck_tile::index_t split_stride_dq_acc;
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
@@ -227,12 +224,6 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
constexpr bool dq_uss_acc = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0;
|
||||
const auto dq_ptr = dq_uss_acc ? args.dq_acc_ptr : args.dq_ptr;
|
||||
const auto stride_dq = dq_uss_acc ? args.stride_dq_acc : args.stride_dq;
|
||||
const auto nhead_stride_dq = dq_uss_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq;
|
||||
const auto batch_stride_dq = dq_uss_acc ? args.batch_stride_dq_acc : args.batch_stride_dq;
|
||||
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
|
||||
{
|
||||
@@ -244,10 +235,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dq_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
dq_ptr,
|
||||
args.workspace_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
@@ -266,7 +258,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
stride_dq,
|
||||
args.stride_dq,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
@@ -277,11 +269,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
nhead_stride_dq,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
@@ -298,10 +289,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dq_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
dq_ptr,
|
||||
args.workspace_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.batch,
|
||||
@@ -316,7 +308,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
stride_dq,
|
||||
args.stride_dq,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
@@ -327,7 +319,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
nhead_stride_dq,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
@@ -338,11 +330,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_lsed,
|
||||
batch_stride_dq,
|
||||
args.batch_stride_dq,
|
||||
args.batch_stride_dk,
|
||||
args.batch_stride_dv,
|
||||
args.batch_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
@@ -414,8 +405,10 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.workspace_ptr,
|
||||
args.dq_ptr,
|
||||
args.batch,
|
||||
args.nhead_q,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
@@ -424,27 +417,20 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
args.nhead_stride_dq);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.workspace_ptr,
|
||||
args.dq_ptr,
|
||||
args.batch,
|
||||
args.nhead_q,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.batch_stride_dq,
|
||||
args.batch_stride_dq_acc,
|
||||
args.split_stride_dq_acc,
|
||||
args.batch,
|
||||
args.nhead_q);
|
||||
args.batch_stride_dq);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -482,7 +468,16 @@ template <typename Traits_, typename Arch = void>
|
||||
int fmha_bwd_dq_dk_dv_maxq_();
|
||||
struct fmha_bwd_traits;
|
||||
template <typename Traits_, typename Arch = void>
|
||||
int fmha_bwd_dq_dk_dv_dq_acc_splits_(const fmha_bwd_traits& t);
|
||||
size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_(int batch_size);
|
||||
template <typename Traits_, typename Arch = void>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws,
|
||||
ck_tile::index_t batch_size,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t nhead_q,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
const ck_tile::index_t* seqstart_qs,
|
||||
const ck_tile::index_t* seqstart_ks);
|
||||
template <typename Traits_, typename Arch = void>
|
||||
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_();
|
||||
|
||||
@@ -510,8 +505,7 @@ template <ck_tile::index_t HDim_,
|
||||
bool kIsGroupMode_,
|
||||
bool kPadS_,
|
||||
bool kPadD_,
|
||||
bool kIsDeterministic_,
|
||||
ck_tile::index_t kN0>
|
||||
bool kIsDeterministic_>
|
||||
struct fmha_bwd_convert_dq_traits_
|
||||
{
|
||||
};
|
||||
@@ -545,6 +539,11 @@ struct fmha_bwd_traits
|
||||
bool has_dropout;
|
||||
bool is_store_randval;
|
||||
bool is_deterministic;
|
||||
// Raw pointers for group mode: cumulative physical seqlen arrays of length batch+1.
|
||||
// Only need to remain valid during fmha_bwd_launcher construction (i.e. through
|
||||
// PrepareWorkspaceHost); they are not retained afterward.
|
||||
const int* seqstart_qs = nullptr;
|
||||
const int* seqstart_ks = nullptr;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
@@ -585,12 +584,61 @@ float fmha_bwd(const fmha_bwd_traits&, fmha_bwd_args, const ck_tile::stream_conf
|
||||
|
||||
struct fmha_bwd_launcher
|
||||
{
|
||||
std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};
|
||||
ck_tile::index_t dq_acc_splits{0};
|
||||
bool needs_zero_dq_acc{true};
|
||||
std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{
|
||||
[](fmha_bwd_args, const ck_tile::stream_config&) {
|
||||
std::cerr << "fmha_bwd: no kernel found for given traits, skipping run\n";
|
||||
return -1.0f;
|
||||
}};
|
||||
size_t workspace_size = 0;
|
||||
std::function<void(void*)> prepare_workspace{[](void*) {
|
||||
std::cerr << "fmha_bwd: no kernel found for given traits, skipping prepare_workspace\n";
|
||||
}};
|
||||
|
||||
fmha_bwd_launcher(const fmha_bwd_traits&);
|
||||
fmha_bwd_launcher(fmha_bwd_launcher&&) = delete;
|
||||
fmha_bwd_launcher& operator=(fmha_bwd_launcher&&) = delete;
|
||||
|
||||
private:
|
||||
size_t host_ws_size = 0;
|
||||
size_t device_ws_size = 0;
|
||||
std::unique_ptr<char[]> ws_host;
|
||||
|
||||
template <typename T0 /*dot_do_o_trait*/,
|
||||
typename T1 /*dq_dk_dv_trait*/,
|
||||
typename T2 /*convert_dq_trait*/,
|
||||
typename Arch>
|
||||
void init(const fmha_bwd_traits& traits)
|
||||
{
|
||||
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {
|
||||
return fmha_bwd_<T0, T1, T2, Arch>(s, a);
|
||||
};
|
||||
host_ws_size = fmha_bwd_dq_dk_dv_dq_ws_host_size_<T1, Arch>(traits.batch);
|
||||
if(host_ws_size > 0)
|
||||
{
|
||||
ws_host = std::make_unique<char[]>(host_ws_size); // TODO: support host mem allocator
|
||||
device_ws_size = fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<T1, Arch>( //
|
||||
ws_host.get(),
|
||||
traits.batch,
|
||||
traits.hdim_q,
|
||||
traits.nhead_q,
|
||||
traits.seqlen_q,
|
||||
traits.seqlen_k,
|
||||
traits.seqstart_qs,
|
||||
traits.seqstart_ks);
|
||||
}
|
||||
workspace_size = host_ws_size + device_ws_size;
|
||||
const bool needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<T1, Arch>();
|
||||
prepare_workspace = [this, needs_zero_dq_acc](void* device_ws) {
|
||||
if(host_ws_size > 0)
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(device_ws, ws_host.get(), host_ws_size, hipMemcpyHostToDevice));
|
||||
if(needs_zero_dq_acc)
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemset(static_cast<char*>(device_ws) + host_ws_size, 0, device_ws_size));
|
||||
};
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename... Args>
|
||||
float operator()(Args&&... args) const
|
||||
{
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <chrono>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
@@ -243,29 +244,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_ks[0] : seqstart_k_host.back());
|
||||
|
||||
const fmha_bwd_traits fmha_traits{
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
nhead_k,
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
mask.type,
|
||||
bias.type,
|
||||
use_dbias,
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic,
|
||||
};
|
||||
fmha_bwd_launcher launcher(fmha_traits);
|
||||
|
||||
const ck_tile::index_t nsplits = launcher.dq_acc_splits;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
@@ -318,8 +296,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
{
|
||||
d_sink_host.ForEach([&](auto& self, auto i) { self(i) = 0; });
|
||||
}
|
||||
ck_tile::HostTensor<AccDataType> dq_acc_host(
|
||||
std::array<ck_tile::index_t, 5>{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q});
|
||||
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
{
|
||||
@@ -396,7 +372,37 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
|
||||
const auto t0_launcher = std::chrono::high_resolution_clock::now();
|
||||
fmha_bwd_launcher launcher(fmha_bwd_traits{
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
nhead_k,
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
mask.type,
|
||||
bias.type,
|
||||
use_dbias,
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic,
|
||||
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
|
||||
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
|
||||
});
|
||||
const auto t1_launcher = std::chrono::high_resolution_clock::now();
|
||||
const double launcher_ctor_ms =
|
||||
std::chrono::duration<double, std::milli>(t1_launcher - t0_launcher).count();
|
||||
const size_t ws_size = launcher.workspace_size;
|
||||
ck_tile::DeviceMem ws_buf(ws_size);
|
||||
ck_tile::gpu_timer prepare_ws_timer;
|
||||
prepare_ws_timer.start(nullptr);
|
||||
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
|
||||
prepare_ws_timer.stop(nullptr);
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
@@ -433,7 +439,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
// clang-format on
|
||||
|
||||
const std::size_t workspace_size_in_megabytes =
|
||||
ck_tile::integer_divide_ceil(dq_acc_host.get_element_space_size_in_bytes(), 1024 * 1024);
|
||||
ck_tile::integer_divide_ceil(ws_size, 1024 * 1024);
|
||||
|
||||
std::cout << "[" << data_type << "|" << mode << "|" << io_layout(i_perm, o_perm)
|
||||
<< "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0]
|
||||
@@ -441,11 +447,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
<< ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop
|
||||
<< (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval
|
||||
<< ", deterministic:" << deterministic
|
||||
<< (deterministic
|
||||
? std::string(", workspace:") + std::to_string(workspace_size_in_megabytes) +
|
||||
"MiB|" + std::to_string(nsplits) + "splits"
|
||||
: "")
|
||||
<< ", mask:" << mask << std::flush;
|
||||
<< ", workspace:" << std::to_string(workspace_size_in_megabytes) << "MiB"
|
||||
<< ", mask:" << mask << ", init:" << launcher_ctor_ms << "ms"
|
||||
<< ", prws:" << prepare_ws_timer.duration() << "ms" << std::flush;
|
||||
|
||||
auto fmha_args = [&]() {
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
@@ -462,7 +466,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
|
||||
const auto split_stride_dq_acc = (shape_seqlen_q * hdim_q);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
|
||||
@@ -474,8 +477,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_dbias =
|
||||
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
|
||||
const auto nhead_stride_dq_acc =
|
||||
static_cast<ck_tile::long_index_t>(split_stride_dq_acc) * nsplits;
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
|
||||
@@ -488,7 +489,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
|
||||
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const auto batch_stride_dq_acc = nhead * nhead_stride_dq_acc;
|
||||
|
||||
void* ws_ptr = ws_size > 0 ? ws_buf.GetDeviceBuffer() : nullptr;
|
||||
|
||||
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
|
||||
if(drop_prefs)
|
||||
@@ -518,7 +520,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dk_buf.GetDeviceBuffer(),
|
||||
dv_buf.GetDeviceBuffer(),
|
||||
dbias_buf.GetDeviceBuffer(),
|
||||
dq_acc_buf.GetDeviceBuffer(),
|
||||
ws_ptr,
|
||||
sink_buf.GetDeviceBuffer(),
|
||||
d_sink_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
@@ -545,8 +547,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
stride_o,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
hdim_q, // stride_dq_acc
|
||||
stride_q, // stride_dq
|
||||
stride_q, // stride_dq (same layout as q for dq output)
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
@@ -558,7 +559,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_q, // nhead_stride_dq
|
||||
nhead_stride_k, // nhead_stride_dk
|
||||
nhead_stride_v, // nhead_stride_dv
|
||||
@@ -571,12 +571,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
batch_stride_randval,
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
batch_stride_dq_acc,
|
||||
batch_stride_q, // batch_stride_dq
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
batch_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
@@ -901,11 +899,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::FillConstant<QGradDataType>{ck_tile::numeric<QGradDataType>::infinity()}(dq_host);
|
||||
ck_tile::FillConstant<KGradDataType>{ck_tile::numeric<KGradDataType>::infinity()}(dk_host);
|
||||
ck_tile::FillConstant<VGradDataType>{ck_tile::numeric<VGradDataType>::infinity()}(dv_host);
|
||||
ck_tile::FillConstant<AccDataType>{ck_tile::numeric<AccDataType>::infinity()}(dq_acc_host);
|
||||
dq_buf.ToDevice(dq_host.data());
|
||||
dk_buf.ToDevice(dk_host.data());
|
||||
dv_buf.ToDevice(dv_host.data());
|
||||
dq_acc_buf.ToDevice(dq_acc_host.data());
|
||||
// re-initialize workspace for validation run
|
||||
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
|
||||
|
||||
o_buf.ToDevice(o_host.data());
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
@@ -913,9 +911,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
if(sink_grad)
|
||||
d_sink_buf.SetZero();
|
||||
|
||||
if(launcher.needs_zero_dq_acc)
|
||||
dq_acc_buf.SetZero();
|
||||
|
||||
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
|
||||
launcher(fmha_args, stream_config_v);
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -15,7 +15,6 @@ struct BlockFmhaBwdConvertQGrad
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
|
||||
static constexpr index_t kM0 = Problem::kM0;
|
||||
static constexpr index_t kN0 = Problem::kN0;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
@@ -97,7 +97,6 @@ template <typename AccDataType_,
|
||||
typename QGradDataType_,
|
||||
index_t kBlockSize_,
|
||||
index_t kM0_,
|
||||
index_t kN0_,
|
||||
index_t kQKHeaddim_,
|
||||
bool kIsGroupMode_,
|
||||
bool kIsDeterministic_,
|
||||
@@ -113,7 +112,6 @@ struct BlockFmhaBwdConvertQGradPipelineProblem
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kM0 = kM0_;
|
||||
static constexpr index_t kN0 = kN0_;
|
||||
static constexpr index_t kQKHeaddim = kQKHeaddim_;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
|
||||
Reference in New Issue
Block a user