From 207a95d5e4081316f3fb18a035b3918c118367c4 Mon Sep 17 00:00:00 2001
From: Yi DING <28386673+DDEle@users.noreply.github.com>
Date: Thu, 7 May 2026 02:23:28 +0000
Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#6152 (commit 36b016a)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

[CK_TILE] Use Unified Workspace for FMHA BWD

## Motivation

`dq_acc` is the intermediate accumulation buffer used in the FMHA backward pass for deterministic mode. The current implementation allocates it as a **single rectangular tensor**:

```
shape = [shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q]
```

where `nsplits = launcher.dq_acc_splits` (a single scalar), computed from `max_seqlen_k` and shared across all batches.

### Problems

1. **Memory waste**: In group mode, each batch may have a different `seqlen_k`, but `nsplits` is computed from `max_seqlen_k`, causing batches with shorter `seqlen_k` to over-allocate in the split dimension.
2. **Interface coupling**: `fmha_bwd_args` exposes internal layout details such as `stride_dq_acc`, `nhead_stride_dq_acc`, `batch_stride_dq_acc`, and `split_stride_dq_acc`. The caller is responsible for computing these strides, but this logic belongs inside the kernel.

### Goals

1. Switch the `dq_acc` buffer to a **compact layout**: batches are concatenated contiguously, with each batch occupying `nhead * nsplits_i * seqq_i * hdim_q` elements (nhead outermost).
2. **Remove all `*_stride_dq_acc` fields** from `fmha_bwd_args`, replacing them with a single `workspace_ptr`; the kernel splits this internally using a fixed layout.
3. `fmha_bwd_launcher` provides a **workspace management interface**: the caller only needs to allocate GPU memory and call `prepare_workspace()` — no layout computation required.
4. **Isolate kernel internals from the caller API**: the `dq_acc` layout (nsplits, strides, buffer size) is determined entirely inside the launcher/kernel. Future changes to block shape, pipeline type, or persistent kernel strategy require no modifications to the caller's `fmha_bwd_args` or workspace allocation logic.

## Technical Details

### Interface Design

#### New fields in `fmha_bwd_traits`

```cpp
struct fmha_bwd_traits
{
    int seqlen_q;
    int seqlen_k;
    int batch;
    int max_seqlen_q;
    int max_seqlen_k;
    int hdim_q;
    int hdim_v;
    int nhead_q;
    int nhead_k;
    std::string data_type;
    bool is_group_mode;
    mask_enum mask_type;
    bias_enum bias_type;
    bool has_dbias;
    bool has_dropout;
    bool is_store_randval;
    bool is_deterministic;
    // New: cumulative physical seqlen pointers for group mode (pass nullptr for batch mode).
    // seqstart_qs[i+1] - seqstart_qs[i] = physical seqlen_q of batch i (including padding); length = batch+1
    // seqstart_ks[i+1] - seqstart_ks[i] = physical seqlen_k of batch i (including padding); length = batch+1
    const int* seqstart_qs = nullptr;
    const int* seqstart_ks = nullptr;
};
```

#### `fmha_bwd_launcher` actual structure

```cpp
struct fmha_bwd_launcher
{
    std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};

    // Total workspace size in bytes (host_ws_size + device_ws_size), computed by init().
    // Zero for kUseQrQtrDorPipeline (writes dq directly, no acc buffer needed).
    size_t workspace_size = 0;

    fmha_bwd_launcher(const fmha_bwd_traits&);

    // Copies auxiliary data (nsplits[], offsets[]) via hipMemcpy to the head of the GPU workspace,
    // and zeros the dq_acc buffer portion (tail of workspace) if required.
    // The memory pointed to by device_ws must be >= workspace_size bytes.
    std::function<void(void*)> prepare_workspace{};

    template <typename... Args>
    float operator()(Args&&... args) const
    {
        return run(std::forward<Args>(args)...);
    }

private:
    size_t host_ws_size   = 0;          // CPU workspace size (nsplits[] + offsets[] arrays)
    size_t device_ws_size = 0;          // GPU-only data size (dq_acc buffer)
    std::unique_ptr<uint8_t[]> ws_host; // host-side workspace buffer

public:
    template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_, auto arch>
    void init(const fmha_bwd_traits& traits);
};
```
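To illustrate how the two structs fit together, a group-mode caller might wire up the new traits fields as follows. This is a hypothetical sketch: the batch count and sequence lengths are made up, and the remaining traits fields (elided below) are filled in as before.

```cpp
// Hypothetical group-mode setup (batch = 2, made-up sizes); only the fields
// relevant to the new workspace flow are shown, the rest keep their defaults.
// The seqstart arrays hold cumulative *physical* (padded) lengths, batch+1 entries.
std::vector<int> seqstart_qs{0, 512, 1536}; // physical seqlen_q per batch: 512, 1024
std::vector<int> seqstart_ks{0, 384, 2432}; // physical seqlen_k per batch: 384, 2048

fmha_bwd_traits t{};
t.batch            = 2;
t.is_group_mode    = true;
t.is_deterministic = true;
// ... shape/type/feature fields as usual ...
t.seqstart_qs = seqstart_qs.data(); // only read during launcher construction
t.seqstart_ks = seqstart_ks.data();

fmha_bwd_launcher launcher(t); // runs PrepareWorkspaceHost, sets workspace_size
```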
The `init<>()` template method (invoked by codegen dispatch branches as `this->init<...>(t)`) is responsible for:

1. Setting the `run` lambda
2. Calling `FmhaBwdDQDKDVKernel::GetWorkspaceHostSize(batch)` to obtain `host_ws_size`
3. Allocating `ws_host` (host memory)
4. Calling `FmhaBwdDQDKDVKernel::PrepareWorkspaceHost(ws_host.get(), ...)` to fill nsplits/offsets; the return value is `device_ws_size`
5. Setting `workspace_size = host_ws_size + device_ws_size`
6. Setting the `prepare_workspace` lambda (captures `this`, calls `PrepareWorkspaceDevice`)

When no kernel matches the given traits, both `run` and `prepare_workspace` are left as the default lambdas, which print a warning to `std::cerr` and return gracefully (no exception).

#### Workspace overall layout

The workspace is managed by `FmhaBwdWorkspaceManager` and consists of two segments:

```
Offset 0 (CPU-prepared segment, host_ws_size bytes; also hipMemcpy'd to the head of the GPU workspace):

  index_t nsplits[batch or 1]        — per-batch nsplits array
                                       group mode: batch elements
                                       batch mode / non-deterministic: 1 element

  [group mode only]
  long_index_t dq_acc_offsets[batch] — per-batch element offset (inclusive prefix sum)
                                       offsets[0] = 0, offsets[i+1] = offsets[i] + nhead*nsplits_i*seqq_i*hdim_q

Offset host_ws_size (device data segment, device_ws_size bytes):

  AccDataType dq_acc[total_elements] — compact dq_acc buffer (zeroed if required)
                                       total_elements = sum_i(nhead * nsplits_i * seqq_i * hdim_q)
                                       layout within each batch: [nhead, nsplits_i, seqq_i, hdim_q]
                                       note: seqq_i uses the physical length (including padding)
```

Alignment constants (`ALIGNMENT = 16`; the CPU-prepared segment is additionally padded to 4 KiB so the dq_acc buffer always starts on a page-aligned boundary):

```
nsplits_size  = align_up(sizeof(index_t) * N, 16)          // N = batch (group) or 1 (batch/non-det)
offsets_size  = align_up(sizeof(long_index_t) * batch, 16) // group mode only
host_ws_size  = align_up(nsplits_size + offsets_size, 4096)
dq_acc_offset = host_ws_size                               // GetDqAccDataOffset(batch)
```

**Key benefits**:

- The kernel reads nsplits/offsets directly from the workspace head — no device-side recomputation.
- `FmhaBwdConvertQGradKernel` is completely decoupled from the pipeline block shape (`kN0`): nsplits is read from `nsplits_ptr`, `kN0` is no longer a template parameter, and multiple dq_dk_dv tiles with different `F_bn0` values now share a single convert_dq kernel instance (under receipt 1/2, the deterministic convert_dq kernel count drops from ~300 to 60).
- nsplits/offsets are computed on the host and transferred in one `hipMemcpy`; the dq_acc buffer follows immediately, at the offset given by `GetDqAccDataOffset` (see the sketch below).
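To make the layout arithmetic concrete, here is a small self-contained host-side sketch for deterministic group mode. It is illustrative only, not the ck_tile implementation: `kN0`, `align_up`, and the sample seqstart values are assumptions.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

using index_t      = int32_t;
using long_index_t = int64_t;

static size_t align_up(size_t x, size_t a) { return (x + a - 1) / a * a; }

int main()
{
    // All constants below are illustrative, not taken from a real kernel config.
    constexpr index_t kN0 = 128; // k-seqlen tile size of the dq_dk_dv pipeline
    const index_t nhead = 8, hdim_q = 64;
    // cumulative physical (padded) sequence lengths, batch+1 entries
    const std::vector<index_t> seqstart_qs{0, 512, 640, 1664};
    const std::vector<index_t> seqstart_ks{0, 384, 512, 2560};
    const index_t batch = static_cast<index_t>(seqstart_qs.size()) - 1;

    std::vector<index_t> nsplits(batch);      // per-batch split count
    std::vector<long_index_t> offsets(batch); // element offset of each batch's dq_acc slice
    long_index_t total_elements = 0;
    for(index_t i = 0; i < batch; ++i)
    {
        const index_t seqq_i = seqstart_qs[i + 1] - seqstart_qs[i];
        const index_t seqk_i = seqstart_ks[i + 1] - seqstart_ks[i];
        nsplits[i] = (seqk_i + kN0 - 1) / kN0; // ceil-div: one split per kN0 tile of seqlen_k
        offsets[i] = total_elements;           // inclusive prefix sum
        total_elements += static_cast<long_index_t>(nhead) * nsplits[i] * seqq_i * hdim_q;
    }

    // segment sizes, mirroring GetWorkspaceHostSize / GetDqAccDataOffset
    const size_t nsplits_size   = align_up(sizeof(index_t) * batch, 16);
    const size_t offsets_size   = align_up(sizeof(long_index_t) * batch, 16);
    const size_t host_ws_size   = align_up(nsplits_size + offsets_size, 4096);
    const size_t device_ws_size = sizeof(float) * total_elements; // AccDataType assumed float
    std::printf("host_ws=%zu, device_ws=%zu, workspace_size=%zu bytes\n",
                host_ws_size, device_ws_size, host_ws_size + device_ws_size);
    return 0;
}
```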
#### Workspace size by scenario

| Scenario | `workspace_size` | Notes |
|----------|-----------------|-------|
| **kUseQrQtrDorPipeline** (any mode) | `0` | Writes dq directly; no acc buffer; `PrepareWorkspaceHost` returns 0 |
| **Non-deterministic + batch mode** | `> 0` | nsplits[1]=1; dq_acc used for atomic add; `workspace_size = host_ws_size + batch*nhead*seqlen_q*hdim_q*ebytes` |
| **Non-deterministic + group mode** | `> 0` | nsplits[1]=1; dq_acc contiguous layout; `workspace_size = host_ws_size + nhead*seqstart_qs[batch]*hdim_q*ebytes` |
| **Deterministic + group mode** | `> 0` | nsplits[batch], offsets[batch], compact dq_acc; nsplits_i computed independently per batch |
| **Deterministic + batch mode persistent** | `> 0` | nsplits[1] (uniform across batches); dq_acc `batch*nhead*nsplits*seqlen_q*hdim_q` |

**NeedsZeroDqAcc** (determines whether `PrepareWorkspaceDevice` calls `hipMemset`):

- Persistent kernel (deterministic batch mode) or non-deterministic: **must zero** (atomic add requires zero initialization)
- Deterministic group mode + no mask: **no zeroing needed** (every tile writes its full region)
- Deterministic + with mask: **must zero** (some blocks are skipped, leaving uninitialized tiles that would contribute to the reduction)

#### Caller usage

```cpp
// 1. Create launcher (traits include seqstart_qs/ks pointers; workspace_size is computed during construction)
fmha_bwd_launcher launcher(fmha_traits);

// 2. Read launcher.workspace_size directly
const auto ws_size = launcher.workspace_size;

// 3. Allocate a single GPU workspace
ck_tile::DeviceMem ws_buf(ws_size);

// 4. Copy nsplits/offsets to GPU head and zero dq_acc if required
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());

// 5. Build args with a single workspace pointer; the kernel splits it internally
fmha_bwd_args args{
    ...,
    ws_size > 0 ? ws_buf.GetDeviceBuffer() : nullptr, // workspace_ptr
};
launcher(args, stream_config);
```
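For reference, a simplified compile-only sketch of what `MakeKargs` does with the single `workspace_ptr` on the kernel side. The offset helpers here are stand-ins for `FmhaBwdWorkspaceManager::GetDqAccSplitsOffset` / `GetDqAccOffsetsOffset` / `GetDqAccDataOffset`, specialized to deterministic group mode, and `AccDataType` is assumed to be float.

```cpp
#include <cstddef>
#include <cstdint>

using index_t      = int32_t;
using long_index_t = int64_t;

// Stand-ins for the FmhaBwdWorkspaceManager offset helpers (deterministic
// group mode: 16-byte aligned arrays, data segment padded to 4 KiB).
static size_t align_up(size_t x, size_t a) { return (x + a - 1) / a * a; }
static size_t splits_offset(int /*batch*/) { return 0; }
static size_t offsets_offset(int batch) { return align_up(sizeof(index_t) * batch, 16); }
static size_t data_offset(int batch)
{
    return align_up(offsets_offset(batch) + align_up(sizeof(long_index_t) * batch, 16), 4096);
}

struct CarvedWorkspace
{
    const index_t* nsplits;             // per-batch split counts (workspace head)
    const long_index_t* dq_acc_offsets; // per-batch element offsets into dq_acc
    float* dq_acc;                      // compact accumulation buffer (page-aligned tail)
};

// What MakeKargs effectively computes from workspace_ptr; no strides cross the API.
CarvedWorkspace carve_workspace(void* workspace_ptr, int batch)
{
    auto* ws = reinterpret_cast<uint8_t*>(workspace_ptr);
    return {reinterpret_cast<const index_t*>(ws + splits_offset(batch)),
            reinterpret_cast<const long_index_t*>(ws + offsets_offset(batch)),
            reinterpret_cast<float*>(ws + data_offset(batch))};
}
```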
--- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 40 +- example/ck_tile/01_fmha/fmha_bwd.hpp | 138 ++-- example/ck_tile/01_fmha/fmha_bwd_runner.hpp | 89 +-- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 753 +++++++++++------- .../pipeline/block_fmha_bwd_convert_dq.hpp | 1 - .../block_fmha_bwd_pipeline_problem.hpp | 2 - 6 files changed, 629 insertions(+), 394 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 7105f1aa5c..f89a7d75e4 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -169,10 +169,21 @@ int fmha_bwd_dq_dk_dv_maxq_() }} template <> -int fmha_bwd_dq_dk_dv_dq_acc_splits_(const fmha_bwd_traits& t) +size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_(int batch_size) {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; - return k_::GetDqAccSplits(t.batch, t.nhead_q, t.max_seqlen_k); + return k_::GetWorkspaceHostSize(batch_size); +}} + +template <> +size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_( + void* cpu_ws, ck_tile::index_t batch_size, ck_tile::index_t hdim_q, + ck_tile::index_t nhead_q, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, + const ck_tile::index_t* seqstart_qs, const ck_tile::index_t* seqstart_ks) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::PrepareWorkspaceHost( + cpu_ws, batch_size, hdim_q, nhead_q, seqlen_q, seqlen_k, seqstart_qs, seqstart_ks); }} template <> @@ -197,9 +208,6 @@ FMHA_BWD_API = """ fmha_bwd_launcher::fmha_bwd_launcher(const fmha_bwd_traits& t){{ [[maybe_unused]] const std::string device_name = ck_tile::get_device_name(); {F_launcher} - run = [](fmha_bwd_args, const ck_tile::stream_config&) {{ return -1.0f; }}; - dq_acc_splits = 1; - needs_zero_dq_acc = false; }} @@ -228,7 +236,7 @@ FMHA_BWD_API_INNER_DISPATCH_COMMON = """{F_if}((t.is_group_mode == {F_mode}) && ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_max_seq_q_cond}{F_cond_extra}) {{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>; using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}>; """ FMHA_BWD_API_INNER_DISPATCH_RUN = """ r = fmha_bwd_, {F_arch.tag}>(s, a); }} """ FMHA_BWD_API_INNER_DISPATCH_LAUNCHER = """ - run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {{ - return fmha_bwd_, {F_arch.tag}>(s, a); - }}; - dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_(t); - needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_(); + this->init, {F_arch.tag}>(t); return; }} """ @@ -650,7 +654,6 @@ using fmha_bwd_convert_dq_pipeline_problem_{F_idx} = typename FmhaBwdTypeConfig::QGradDataType, /* BlockSize = */ 256, {F_bm0}, - {F_bn0}, {F_hdim}, {F_mode}, {F_deterministic}, @@ -667,8 +670,7 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_mode}, {F_spad}, {F_dpad}, - {F_deterministic}, - {F_bn0}>; + {F_deterministic}>; template <> float
fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) @@ -712,7 +714,6 @@ class FmhaBwdConvertQGradKernel: F_hdim: int # hdim F_dtype: str # data type F_bm0: int # tile size along q seqlen (block size) - F_bn0: int # tile size along k seqlen F_spad: str # true/false F_dpad: str # F_mode: str # value from MODE_MAP @@ -728,7 +729,6 @@ class FmhaBwdConvertQGradKernel: F_hdim=self.F_hdim, F_dtype=BWD_DTYPE_MAP[self.F_dtype], F_bm0=self.F_bm0, - F_bn0=self.F_bn0, F_spad=BOOL_MAP[self.F_spad], F_dpad=BOOL_MAP[self.F_dpad], F_mode=MODE_MAP[self.F_mode], @@ -749,7 +749,7 @@ class FmhaBwdConvertQGradKernel: return n pn = pad_name() - n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}" + n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}_{self.F_mode}_o{self.F_occupancy}" if pn != "": n += f"_{pn}" else: @@ -838,10 +838,6 @@ class FmhaBwdApiTrait: else: return "" - @property - def convert_dq_bn0(self) -> int: - return self.tile.F_bn0 if self.deterministic == "t" else 0 - @property def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel: # TODO: we don't support tuning yet, so pick up one value for pad/occupancy @@ -896,7 +892,6 @@ class FmhaBwdApiTrait: F_hdim=self.hdim, F_dtype=self.dtype, F_bm0=M0_1D, - F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=F_dpad, F_mode=self.mode, @@ -949,7 +944,6 @@ class FmhaBwdApiPool: F_max_seq_q_cond=trait.max_seq_q_cond, F_cond_extra=trait.extra_cond, F_bn0=trait.tile.F_bn0, - F_convert_dq_bn0=trait.convert_dq_bn0, ) inners += inners_common + FMHA_BWD_API_INNER_DISPATCH_RUN.format( F_arch=trait.arch, diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 4496a6c9dd..14f4c210f0 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -11,11 +11,12 @@ #include "mask.hpp" #include "bias.hpp" +#include +#include +#include #include #include #include -#include -#include struct FmhaBwdFp32 { @@ -115,7 +116,7 @@ struct fmha_bwd_args void* dk_ptr; void* dv_ptr; void* dbias_ptr; - void* dq_acc_ptr; + void* workspace_ptr; const void* sink_ptr; // sink scores [batch, nhead] in log-space (LSEDataType); nullptr disables sink void* d_sink_ptr; // sink gradient output [nhead] (LSEDataType); nullptr disables sink gradient @@ -128,13 +129,13 @@ struct fmha_bwd_args // With padding: // Group mode: // - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence - // lengths. [array size: batch + 1] + // lengths. [array size: batch + 1] // - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each - // sequence. [array size: batch] + // sequence. [array size: batch] // - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding) - // sequence lengths. [array size: batch + 1] + // sequence lengths. [array size: batch + 1] // - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually - // exclusive. Use one set, not both. + // exclusive. Use one set, not both. 
// // Batch mode: // - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding) @@ -181,7 +182,6 @@ struct fmha_bwd_args ck_tile::index_t stride_o; ck_tile::index_t stride_randval; ck_tile::index_t stride_do; - ck_tile::index_t stride_dq_acc; ck_tile::index_t stride_dq; ck_tile::index_t stride_dk; ck_tile::index_t stride_dv; @@ -194,7 +194,6 @@ struct fmha_bwd_args ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_lsed; - ck_tile::long_index_t nhead_stride_dq_acc; ck_tile::index_t nhead_stride_dq; ck_tile::index_t nhead_stride_dk; ck_tile::index_t nhead_stride_dv; @@ -207,12 +206,10 @@ struct fmha_bwd_args ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_lsed; - ck_tile::long_index_t batch_stride_dq_acc; ck_tile::index_t batch_stride_dq; ck_tile::index_t batch_stride_dk; ck_tile::index_t batch_stride_dv; ck_tile::index_t batch_stride_dbias; - ck_tile::index_t split_stride_dq_acc; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; @@ -227,12 +224,6 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { - constexpr bool dq_uss_acc = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0; - const auto dq_ptr = dq_uss_acc ? args.dq_acc_ptr : args.dq_ptr; - const auto stride_dq = dq_uss_acc ? args.stride_dq_acc : args.stride_dq; - const auto nhead_stride_dq = dq_uss_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq; - const auto batch_stride_dq = dq_uss_acc ? args.batch_stride_dq_acc : args.batch_stride_dq; - // create group mode kernel arguments if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) { @@ -244,10 +235,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.do_ptr, args.d_ptr, args.rand_val_ptr, + args.dq_ptr, args.dk_ptr, args.dv_ptr, args.dbias_ptr, - dq_ptr, + args.workspace_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, args.seqlen_q_ptr, @@ -266,7 +258,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.stride_bias, args.stride_randval, args.stride_do, - stride_dq, + args.stride_dq, args.stride_dk, args.stride_dv, args.stride_dbias, @@ -277,11 +269,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.nhead_stride_randval, args.nhead_stride_do, args.nhead_stride_lsed, - nhead_stride_dq, + args.nhead_stride_dq, args.nhead_stride_dk, args.nhead_stride_dv, args.nhead_stride_dbias, - args.split_stride_dq_acc, args.window_size_left, args.window_size_right, args.mask_type, @@ -298,10 +289,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.do_ptr, args.d_ptr, args.rand_val_ptr, + args.dq_ptr, args.dk_ptr, args.dv_ptr, args.dbias_ptr, - dq_ptr, + args.workspace_ptr, args.seqlen_q, args.seqlen_k, args.batch, @@ -316,7 +308,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.stride_bias, args.stride_randval, args.stride_do, - stride_dq, + args.stride_dq, args.stride_dk, args.stride_dv, args.stride_dbias, @@ -327,7 +319,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.nhead_stride_randval, args.nhead_stride_do, args.nhead_stride_lsed, - nhead_stride_dq, + args.nhead_stride_dq, args.nhead_stride_dk, args.nhead_stride_dv, args.nhead_stride_dbias, @@ -338,11 +330,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.batch_stride_randval, args.batch_stride_do, args.batch_stride_lsed, - 
batch_stride_dq, + args.batch_stride_dq, args.batch_stride_dk, args.batch_stride_dv, args.batch_stride_dbias, - args.split_stride_dq_acc, args.window_size_left, args.window_size_right, args.mask_type, @@ -414,8 +405,10 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) // create group mode kernel arguments if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode) { - return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, + return FmhaBwdConvertQGradKernel::MakeKargs(args.workspace_ptr, args.dq_ptr, + args.batch, + args.nhead_q, args.seqstart_q_ptr, args.seqstart_k_ptr, args.seqlen_q_ptr, @@ -424,27 +417,20 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) args.cu_seqlen_k_ptr, args.hdim_q, args.stride_dq, - args.stride_dq_acc, - args.nhead_stride_dq, - args.nhead_stride_dq_acc, - args.split_stride_dq_acc); + args.nhead_stride_dq); } else { // create batch mode kernel arguments - return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, + return FmhaBwdConvertQGradKernel::MakeKargs(args.workspace_ptr, args.dq_ptr, + args.batch, + args.nhead_q, args.seqlen_q, args.seqlen_k, args.hdim_q, args.stride_dq, - args.stride_dq_acc, args.nhead_stride_dq, - args.nhead_stride_dq_acc, - args.batch_stride_dq, - args.batch_stride_dq_acc, - args.split_stride_dq_acc, - args.batch, - args.nhead_q); + args.batch_stride_dq); } }(); @@ -482,7 +468,16 @@ template int fmha_bwd_dq_dk_dv_maxq_(); struct fmha_bwd_traits; template -int fmha_bwd_dq_dk_dv_dq_acc_splits_(const fmha_bwd_traits& t); +size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_(int batch_size); +template +size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws, + ck_tile::index_t batch_size, + ck_tile::index_t hdim_q, + ck_tile::index_t nhead_q, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + const ck_tile::index_t* seqstart_qs, + const ck_tile::index_t* seqstart_ks); template bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_(); @@ -510,8 +505,7 @@ template + bool kIsDeterministic_> struct fmha_bwd_convert_dq_traits_ { }; @@ -545,6 +539,11 @@ struct fmha_bwd_traits bool has_dropout; bool is_store_randval; bool is_deterministic; + // Raw pointers for group mode: cumulative physical seqlen arrays of length batch+1. + // Only need to remain valid during fmha_bwd_launcher construction (i.e. through + // PrepareWorkspaceHost); they are not retained afterward. 
+ const int* seqstart_qs = nullptr; + const int* seqstart_ks = nullptr; // TODO: padding check is inside this api }; @@ -585,12 +584,61 @@ float fmha_bwd(const fmha_bwd_traits&, fmha_bwd_args, const ck_tile::stream_conf struct fmha_bwd_launcher { - std::function run{}; - ck_tile::index_t dq_acc_splits{0}; - bool needs_zero_dq_acc{true}; + std::function run{ + [](fmha_bwd_args, const ck_tile::stream_config&) { + std::cerr << "fmha_bwd: no kernel found for given traits, skipping run\n"; + return -1.0f; + }}; + size_t workspace_size = 0; + std::function prepare_workspace{[](void*) { + std::cerr << "fmha_bwd: no kernel found for given traits, skipping prepare_workspace\n"; + }}; fmha_bwd_launcher(const fmha_bwd_traits&); + fmha_bwd_launcher(fmha_bwd_launcher&&) = delete; + fmha_bwd_launcher& operator=(fmha_bwd_launcher&&) = delete; + private: + size_t host_ws_size = 0; + size_t device_ws_size = 0; + std::unique_ptr ws_host; + + template + void init(const fmha_bwd_traits& traits) + { + run = [](fmha_bwd_args a, const ck_tile::stream_config& s) { + return fmha_bwd_(s, a); + }; + host_ws_size = fmha_bwd_dq_dk_dv_dq_ws_host_size_(traits.batch); + if(host_ws_size > 0) + { + ws_host = std::make_unique(host_ws_size); // TODO: support host mem allocator + device_ws_size = fmha_bwd_dq_dk_dv_dq_prepare_ws_host_( // + ws_host.get(), + traits.batch, + traits.hdim_q, + traits.nhead_q, + traits.seqlen_q, + traits.seqlen_k, + traits.seqstart_qs, + traits.seqstart_ks); + } + workspace_size = host_ws_size + device_ws_size; + const bool needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_(); + prepare_workspace = [this, needs_zero_dq_acc](void* device_ws) { + if(host_ws_size > 0) + HIP_CHECK_ERROR( + hipMemcpy(device_ws, ws_host.get(), host_ws_size, hipMemcpyHostToDevice)); + if(needs_zero_dq_acc) + HIP_CHECK_ERROR( + hipMemset(static_cast(device_ws) + host_ws_size, 0, device_ws_size)); + }; + } + + public: template float operator()(Args&&... args) const { diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index 361bda20eb..f81ae34501 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -9,6 +9,7 @@ #include "ck_tile/utility/json_dump.hpp" #include +#include #include #include #include @@ -243,29 +244,6 @@ bwd_result fmha_bwd_run(mode_enum mode, const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_ks[0] : seqstart_k_host.back()); - const fmha_bwd_traits fmha_traits{ - shape_seqlen_q, - shape_seqlen_k, - batch, - max_seqlen_q, - max_seqlen_k, - hdim_q, - hdim_v, - nhead, - nhead_k, - data_type, - mode == mode_enum::group, - mask.type, - bias.type, - use_dbias, - p_drop > 0.0f, - s_randval, - deterministic, - }; - fmha_bwd_launcher launcher(fmha_traits); - - const ck_tile::index_t nsplits = launcher.dq_acc_splits; - ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); ck_tile::HostTensor k_host( @@ -318,8 +296,6 @@ bwd_result fmha_bwd_run(mode_enum mode, { d_sink_host.ForEach([&](auto& self, auto i) { self(i) = 0; }); } - ck_tile::HostTensor dq_acc_host( - std::array{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q}); if(init_method == "ui" || init_method == "0") { @@ -396,7 +372,37 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); ck_tile::DeviceMem drop_offset_buf(drop_prefs ? 
sizeof(uint64_t) : 0); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes()); + const auto t0_launcher = std::chrono::high_resolution_clock::now(); + fmha_bwd_launcher launcher(fmha_bwd_traits{ + shape_seqlen_q, + shape_seqlen_k, + batch, + max_seqlen_q, + max_seqlen_k, + hdim_q, + hdim_v, + nhead, + nhead_k, + data_type, + mode == mode_enum::group, + mask.type, + bias.type, + use_dbias, + p_drop > 0.0f, + s_randval, + deterministic, + (mode == mode_enum::group) ? seqstart_q_host.data() : nullptr, + (mode == mode_enum::group) ? seqstart_k_host.data() : nullptr, + }); + const auto t1_launcher = std::chrono::high_resolution_clock::now(); + const double launcher_ctor_ms = + std::chrono::duration(t1_launcher - t0_launcher).count(); + const size_t ws_size = launcher.workspace_size; + ck_tile::DeviceMem ws_buf(ws_size); + ck_tile::gpu_timer prepare_ws_timer; + prepare_ws_timer.start(nullptr); + launcher.prepare_workspace(ws_buf.GetDeviceBuffer()); + prepare_ws_timer.stop(nullptr); q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); @@ -433,7 +439,7 @@ bwd_result fmha_bwd_run(mode_enum mode, // clang-format on const std::size_t workspace_size_in_megabytes = - ck_tile::integer_divide_ceil(dq_acc_host.get_element_space_size_in_bytes(), 1024 * 1024); + ck_tile::integer_divide_ceil(ws_size, 1024 * 1024); std::cout << "[" << data_type << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] @@ -441,11 +447,9 @@ bwd_result fmha_bwd_run(mode_enum mode, << ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop << (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval << ", deterministic:" << deterministic - << (deterministic - ? std::string(", workspace:") + std::to_string(workspace_size_in_megabytes) + - "MiB|" + std::to_string(nsplits) + "splits" - : "") - << ", mask:" << mask << std::flush; + << ", workspace:" << std::to_string(workspace_size_in_megabytes) << "MiB" + << ", mask:" << mask << ", init:" << launcher_ctor_ms << "ms" + << ", prws:" << prepare_ws_timer.duration() << "ms" << std::flush; auto fmha_args = [&]() { /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, @@ -462,7 +466,6 @@ bwd_result fmha_bwd_run(mode_enum mode, const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q); const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k); - const auto split_stride_dq_acc = (shape_seqlen_q * hdim_q); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); @@ -474,8 +477,6 @@ bwd_result fmha_bwd_run(mode_enum mode, const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q; const ck_tile::index_t nhead_stride_dbias = (i_perm ? 
shape_seqlen_q * max_seqlen_k : max_seqlen_k); - const auto nhead_stride_dq_acc = - static_cast(split_stride_dq_acc) * nsplits; // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); @@ -488,7 +489,8 @@ bwd_result fmha_bwd_run(mode_enum mode, const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k); - const auto batch_stride_dq_acc = nhead * nhead_stride_dq_acc; + + void* ws_ptr = ws_size > 0 ? ws_buf.GetDeviceBuffer() : nullptr; const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) { if(drop_prefs) @@ -518,7 +520,7 @@ bwd_result fmha_bwd_run(mode_enum mode, dk_buf.GetDeviceBuffer(), dv_buf.GetDeviceBuffer(), dbias_buf.GetDeviceBuffer(), - dq_acc_buf.GetDeviceBuffer(), + ws_ptr, sink_buf.GetDeviceBuffer(), d_sink_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), @@ -545,8 +547,7 @@ bwd_result fmha_bwd_run(mode_enum mode, stride_o, stride_randval, stride_do, - hdim_q, // stride_dq_acc - stride_q, // stride_dq + stride_q, // stride_dq (same layout as q for dq output) stride_dk, stride_dv, stride_dbias, @@ -558,7 +559,6 @@ bwd_result fmha_bwd_run(mode_enum mode, nhead_stride_randval, nhead_stride_do, nhead_stride_lsed, - nhead_stride_dq_acc, nhead_stride_q, // nhead_stride_dq nhead_stride_k, // nhead_stride_dk nhead_stride_v, // nhead_stride_dv @@ -571,12 +571,10 @@ bwd_result fmha_bwd_run(mode_enum mode, batch_stride_randval, batch_stride_do, batch_stride_lsed, - batch_stride_dq_acc, batch_stride_q, // batch_stride_dq batch_stride_dk, batch_stride_dv, batch_stride_dbias, - split_stride_dq_acc, mask.left, mask.right, static_cast(mask.type), @@ -901,11 +899,11 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_host); ck_tile::FillConstant{ck_tile::numeric::infinity()}(dk_host); ck_tile::FillConstant{ck_tile::numeric::infinity()}(dv_host); - ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_acc_host); dq_buf.ToDevice(dq_host.data()); dk_buf.ToDevice(dk_host.data()); dv_buf.ToDevice(dv_host.data()); - dq_acc_buf.ToDevice(dq_acc_host.data()); + // re-initialize workspace for validation run + launcher.prepare_workspace(ws_buf.GetDeviceBuffer()); o_buf.ToDevice(o_host.data()); lse_buf.ToDevice(lse_host.data()); @@ -913,9 +911,6 @@ bwd_result fmha_bwd_run(mode_enum mode, if(sink_grad) d_sink_buf.SetZero(); - if(launcher.needs_zero_dq_acc) - dq_acc_buf.SetZero(); - ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; launcher(fmha_args, stream_config_v); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index e9f0258710..23c73e5f43 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -27,6 +27,147 @@ namespace ck_tile { +template +struct FmhaBwdWorkspaceManager +{ + // CPU workspace (prepared by host, read-only for kernels): + + // index_t nsplits[batch or 1] + // — per-batch nsplits array (batch element in deterministic group mode) + + // [OPTIONAL, only for deterministic group mode] + // long_index_t dq_acc_offsets[batch] + // — per-batch offset array + + // GPU WORKSPACE BELOW (read & written by kernels): + + // [OPTIONAL, only for !kUseQrQtrDorPipeline] + // 
AccDataType dq_acc[total_elements] + // — dq_acc compact buffer (zeroed if necessary) + // - total_elements = sum_i(nhead * nsplits_i * seqq_i) * hdim_q + // - Layout within each batch: [nhead, nsplits_i, seqq_i, hdim_q] + // - note: use physical (including padding) length for seqq_i for group mode + + static constexpr size_t ALIGNMENT = 16; + + template + CK_TILE_HOST static size_t GetDqAccSplitsSize(const int batch) + { + if constexpr(kUseQrQtrDorPipeline) + return 0; + const auto dqAccSplitsElems = + (kIsGroupMode && kIsDeterministic) ? static_cast(batch) : 1; + return integer_least_multiple(sizeof(index_t) * dqAccSplitsElems, ALIGNMENT); + } + CK_TILE_HOST static size_t GetDqAccOffsetsSize(const int batch) + { + const auto dqAccOffsetsElems = + (kIsGroupMode && kIsDeterministic) ? static_cast(batch) : 0; + return integer_least_multiple(sizeof(long_index_t) * dqAccOffsetsElems, ALIGNMENT); + } + template + CK_TILE_HOST static size_t GetWorkspaceHostSize(const int batch) + { + if constexpr(kUseQrQtrDorPipeline) + return 0; + const size_t raw = + GetDqAccSplitsSize(batch) + GetDqAccOffsetsSize(batch); + // Pad to 4K so dq_acc buffer always starts on a page-aligned boundary. + return integer_least_multiple(raw, static_cast(4096)); + } + + CK_TILE_HOST static size_t GetDqAccSplitsOffset(const int) { return 0; } + template + CK_TILE_HOST static size_t GetDqAccOffsetsOffset(const int batch) + { + return GetDqAccSplitsSize(batch); + } + template + CK_TILE_HOST static size_t GetDqAccDataOffset(const int batch) + { + return GetWorkspaceHostSize(batch); + } + + // Fill CPU prepared workspace and return size of non CPU prepared workspace size + template + CK_TILE_HOST static size_t + PrepareWorkspaceHost(void* cpu_ws, + index_t batch_size, + index_t hdim_q, + index_t nhead_q, + index_t seqlen_q = 0, // only for batch mode + index_t seqlen_k = 0, // only for deterministic batch mode + const index_t* seqstart_qs = nullptr, + const index_t* seqstart_ks = nullptr) + { + if constexpr(kUseQrQtrDorPipeline) + { + // QrQtrDor writes dq directly; no workspace is allocated so cpu_ws is nullptr. 
+ throw std::logic_error( + "PrepareWorkspaceHost: QrQtrDor pipeline does not use workspace"); + } + const auto nsplits = reinterpret_cast(cpu_ws); + const auto offsets = reinterpret_cast(reinterpret_cast(cpu_ws) + + GetDqAccSplitsSize(batch_size)); + if constexpr(kIsGroupMode) + if(!seqstart_qs || !seqstart_ks) + throw std::runtime_error("seqstart_qs and seqstart_ks are required for group mode"); + + if constexpr(!kIsDeterministic) + { + nsplits[0] = 1; + if constexpr(!kIsGroupMode) + return sizeof(AccDataType) * static_cast(batch_size) * nhead_q * + seqlen_q * hdim_q; + else + return sizeof(AccDataType) * static_cast(nhead_q) * + seqstart_qs[batch_size] * hdim_q; + } + else if constexpr(kIsGroupMode) + { // deterministic group mode + offsets[0] = 0; + index_t i = 0; + for(; i < batch_size - 1; ++i) + { + nsplits[i] = integer_divide_ceil(seqstart_ks[i + 1] - seqstart_ks[i], kN0); + offsets[i + 1] = offsets[i] + static_cast(nhead_q) * nsplits[i] * + (seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q; + } + nsplits[i] = integer_divide_ceil(seqstart_ks[i + 1] - seqstart_ks[i], kN0); + return sizeof(AccDataType) * + (offsets[i] + static_cast(nhead_q) * nsplits[i] * + (seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q); + } + else // deterministic non-group mode (kUsePersistent) + { + const index_t dqdqkdv_workers = get_num_cus(); + const index_t jobs_per_head = integer_divide_ceil(seqlen_k, kN0); + const index_t total_jobs = batch_size * nhead_q * jobs_per_head; + const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers); + if(jobs_per_head % jobs_per_worker == 0) + nsplits[0] = jobs_per_head / jobs_per_worker; + else if(jobs_per_worker % jobs_per_head == 0) + nsplits[0] = 1; + else + nsplits[0] = 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker); + return sizeof(AccDataType) * static_cast(batch_size) * nhead_q * + nsplits[0] * seqlen_q * hdim_q; + } + } + + template + CK_TILE_HOST static constexpr bool NeedsZeroDqAcc() + { + constexpr bool kUsePersistent = !kUseQrQtrDorPipeline && kIsDeterministic && !kIsGroupMode; + // non-deterministic and persistent kernels use atomic-add to write dq + if constexpr(kUsePersistent || !kIsDeterministic) + return true; + // Some block may be skipped with causal mask and dq are not set to zeros + // In these cases we need to zero out it first + return kHasMask; + } +}; + template ; // clang-format off template struct t2s; @@ -126,42 +268,22 @@ struct FmhaBwdDQDKDVKernel #undef _TS_ // clang-format on } - CK_TILE_HOST static index_t - GetDqAccSplits(index_t batch_size_, index_t nhead_, index_t seqlen_k_) + template + CK_TILE_HOST static constexpr auto GetWorkspaceHostSize(Args&&... 
args) { - // Be consistent with convert_dq kernel, though qrqtrdor pipeline doesn't use persistent - static constexpr bool kUsePersistent__ = kIsDeterministic && !kIsGroupMode; - if constexpr(kUsePersistent__) - { - const index_t dqdqkdv_workers = get_num_cus(); - const index_t jobs_per_head = - integer_divide_ceil(seqlen_k_, FmhaPipeline::BlockFmhaShape::kN0); - const index_t total_jobs = batch_size_ * nhead_ * jobs_per_head; - const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers); - if(jobs_per_head % jobs_per_worker == 0) - return jobs_per_head / jobs_per_worker; - else if(jobs_per_worker % jobs_per_head == 0) - return 1; - else - return 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker); - } - else if constexpr(kIsDeterministic) - return integer_divide_ceil(seqlen_k_, FmhaPipeline::BlockFmhaShape::kN0); - else - return 1; + return WorkspaceManager::template GetWorkspaceHostSize( + std::forward(args)...); + } + template + CK_TILE_HOST static constexpr auto PrepareWorkspaceHost(Args&&... args) + { + return WorkspaceManager::template PrepareWorkspaceHost( + std::forward(args)...); } CK_TILE_HOST static constexpr bool NeedsZeroDqAcc() { - // Be consistent with convert_dq kernel, though qrqtrdor pipeline doesn't use persistent - constexpr bool kUsePersistent__ = kIsDeterministic && !kIsGroupMode; - - // non-deterministic adn persistent kernels use atomic-add to write dq - if constexpr(kUsePersistent__ || !kIsDeterministic) - return true; - - // Some block may be skipped with causal mask and dq are not set to zeros - // In these cases we need to zero out it first - return kHasMask; + return WorkspaceManager::template NeedsZeroDqAcc(); } template // to avoid duplicated base class prblem, introduce an template @@ -192,7 +314,7 @@ struct FmhaBwdDQDKDVKernel // for MQA/GQA, nhead could be different. 
This parameter is nhead_q / nhead_k // if this param is larger than 1, indicate MQA/GQA case - ck_tile::index_t num_head_q; + ck_tile::index_t nhead_q; ck_tile::index_t nhead_ratio_qk; float raw_scale; float scale; @@ -201,7 +323,6 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t stride_k; ck_tile::index_t stride_v; ck_tile::index_t stride_do; - ck_tile::index_t stride_dq_acc; ck_tile::index_t stride_dk; ck_tile::index_t stride_dv; @@ -210,11 +331,18 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_lsed; - ck_tile::long_index_t nhead_stride_dq_acc; ck_tile::index_t nhead_stride_dk; ck_tile::index_t nhead_stride_dv; }; + // strides for the QrQtrDor pipeline which writes dq directly (no split accumulator) + struct FmhaBwdQrQtrDorKargs + { + ck_tile::index_t stride_dq; + ck_tile::index_t nhead_stride_dq; + std::conditional_t, ck_tile::index_t> batch_stride_dq; + }; + struct FmhaBwdCommonBiasKargs { const void* bias_ptr = nullptr; @@ -313,8 +441,8 @@ struct FmhaBwdDQDKDVKernel struct FmhaBwdDeterministicKargs { - ck_tile::index_t split_stride_dq_acc = 0; - ck_tile::index_t batch; // used for persistent kernel implementation + ck_tile::index_t batch; // used for persistent kernel implementation + const ck_tile::index_t* nsplits_ptr; // points to nsplits[0] in workspace (batch mode) }; struct FmhaBwdBatchModeKargs @@ -323,18 +451,18 @@ struct FmhaBwdDQDKDVKernel FmhaBwdBatchModeBiasKargs, std::conditional_t>>, - std::conditional_t>, - std::conditional_t>, - std::conditional_t>, - std::conditional_t> + FmhaBwdEmptyKargs<1>>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_lsed; - ck_tile::long_index_t batch_stride_dq_acc; ck_tile::index_t batch_stride_dk; ck_tile::index_t batch_stride_dv; }; @@ -349,7 +477,8 @@ struct FmhaBwdDQDKDVKernel std::conditional_t>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -357,6 +486,8 @@ struct FmhaBwdDQDKDVKernel const int32_t* seqlen_k_ptr; // per-batch actual length [batch] const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional + // per-batch element offset into dq_acc buffer (compact layout); used when deterministic + const ck_tile::long_index_t* dq_acc_batch_offset_ptr; }; using Kargs = std::conditional_t; @@ -389,16 +520,17 @@ struct FmhaBwdDQDKDVKernel const void* do_ptr, const void* d_ptr, void* rand_val_ptr, + void* dq_ptr, // only used with qrqtrdor pipeline void* dk_ptr, void* dv_ptr, void* dbias_ptr, - void* dq_acc_ptr, // can be dq_acc_ptr for qrqtrdor pipeline + void* workspace_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t batch, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, + ck_tile::index_t nhead_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, @@ -407,7 +539,7 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, - ck_tile::index_t stride_dq_acc, + ck_tile::index_t stride_dq, // only used for QrQtrDor pipeline ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, 
ck_tile::index_t stride_dbias, @@ -418,7 +550,7 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, - ck_tile::long_index_t nhead_stride_dq_acc, + ck_tile::index_t nhead_stride_dq, // only used for QrQtrDor pipeline ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, @@ -429,11 +561,10 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_lsed, - ck_tile::long_index_t batch_stride_dq_acc, + ck_tile::index_t batch_stride_dq, // only used for QrQtrDor pipeline ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dbias, - ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, @@ -441,51 +572,58 @@ struct FmhaBwdDQDKDVKernel std::variant, std::pair> drop_seed_offset) { - Kargs kargs{{q_ptr, - k_ptr, - v_ptr, - lse_ptr, - do_ptr, - d_ptr, - dq_acc_ptr, - dk_ptr, - dv_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale, - static_cast(scale * ck_tile::log2e_v<>), - stride_q, - stride_k, - stride_v, - stride_do, - stride_dq_acc, - stride_dk, - stride_dv, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, - nhead_stride_dk, - nhead_stride_dv}, // args for common karg - {}, // placeholder for bias - {}, // placeholder for dbias - {}, // placeholder for mask - {}, // placeholder for dropout - {}, // placeholder for deterministic - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_do, - batch_stride_lsed, - batch_stride_dq_acc, - batch_stride_dk, - batch_stride_dv}; + uint8_t* ws = reinterpret_cast(workspace_ptr); + Kargs kargs{ + {q_ptr, + k_ptr, + v_ptr, + lse_ptr, + do_ptr, + d_ptr, + [&]() { + if constexpr(kUseQrQtrDorPipeline) + return dq_ptr; + else + return ws + + WorkspaceManager::template GetDqAccDataOffset( + batch); + }(), + dk_ptr, + dv_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + nhead_q, + nhead_ratio_qk, + scale, + static_cast(scale * ck_tile::log2e_v<>), + stride_q, + stride_k, + stride_v, + stride_do, + stride_dk, + stride_dv, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_do, + nhead_stride_lsed, + nhead_stride_dk, + nhead_stride_dv}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dbias + {}, // placeholder for mask + {}, // placeholder for dropout + {}, // placeholder for deterministic + {}, // placeholder for QrQtrDor + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_do, + batch_stride_lsed, + batch_stride_dk, + batch_stride_dv}; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -540,11 +678,20 @@ struct FmhaBwdDQDKDVKernel } } - if constexpr(kIsDeterministic && !kUseQrQtrDorPipeline) - kargs.split_stride_dq_acc = split_stride_dq_acc; + if constexpr(kUseQrQtrDorPipeline) + { + kargs.stride_dq = stride_dq; + kargs.nhead_stride_dq = nhead_stride_dq; + kargs.batch_stride_dq = batch_stride_dq; + } if constexpr(kUsePersistent) - kargs.batch = batch; + { + kargs.batch = batch; + kargs.nsplits_ptr = reinterpret_cast( + reinterpret_cast(workspace_ptr) + + WorkspaceManager::GetDqAccSplitsOffset(batch)); + } return kargs; } @@ -559,10 +706,11 @@ struct FmhaBwdDQDKDVKernel const void* do_ptr, const void* d_ptr, void* rand_val_ptr, + 
void* dq_ptr, void* dk_ptr, void* dv_ptr, void* dbias_ptr, - void* dq_acc_ptr, + void* workspace_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, const void* seqlen_q_ptr, @@ -572,7 +720,7 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t batch, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, + ck_tile::index_t nhead_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, @@ -581,7 +729,7 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, - ck_tile::index_t stride_dq_acc, + ck_tile::index_t stride_dq, // only used for QrQtrDor pipeline ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, @@ -592,11 +740,10 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, - ck_tile::long_index_t nhead_stride_dq_acc, + ck_tile::index_t nhead_stride_dq, // only used for QrQtrDor pipeline ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, - ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, @@ -604,49 +751,63 @@ struct FmhaBwdDQDKDVKernel std::variant, std::pair> drop_seed_offset) { - Kargs kargs{{q_ptr, - k_ptr, - v_ptr, - lse_ptr, - do_ptr, - d_ptr, - dq_acc_ptr, - dk_ptr, - dv_ptr, - -1, // seqlen will be updated by another pointer - -1, // - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale, - static_cast(scale * ck_tile::log2e_v<>), - stride_q, - stride_k, - stride_v, - stride_do, - stride_dq_acc, - stride_dk, - stride_dv, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, - nhead_stride_dk, - nhead_stride_dv}, // args for common karg - {}, // placeholder for bias - {}, // placeholder for dbias - {}, // placeholder for mask - {}, // placeholder for dropout - {}, // placeholder for deterministic - reinterpret_cast(seqstart_q_ptr), - reinterpret_cast(seqstart_k_ptr), - reinterpret_cast(seqlen_q_ptr), - reinterpret_cast(seqlen_k_ptr), - reinterpret_cast(cu_seqlen_q_ptr), - reinterpret_cast(cu_seqlen_k_ptr)}; + const auto ws = reinterpret_cast(workspace_ptr); + Kargs kargs{ + {q_ptr, + k_ptr, + v_ptr, + lse_ptr, + do_ptr, + d_ptr, + [&]() { + if constexpr(kUseQrQtrDorPipeline) + return dq_ptr; + else + return ws + + WorkspaceManager::template GetDqAccDataOffset( + batch); + }(), + dk_ptr, + dv_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + nhead_q, + nhead_ratio_qk, + scale, + static_cast(scale * ck_tile::log2e_v<>), + stride_q, + stride_k, + stride_v, + stride_do, + stride_dk, + stride_dv, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_do, + nhead_stride_lsed, + nhead_stride_dk, + nhead_stride_dv}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dbias + {}, // placeholder for mask + {}, // placeholder for dropout + {}, // placeholder for deterministic + {}, // placeholder for QrQtrDor + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_q_ptr), + reinterpret_cast(seqlen_k_ptr), + reinterpret_cast(cu_seqlen_q_ptr), + reinterpret_cast(cu_seqlen_k_ptr), + nullptr, // dq_acc_batch_offset_ptr (set below for non-QrQtrDor deterministic) + }; + + if constexpr(!kUseQrQtrDorPipeline) + kargs.dq_acc_batch_offset_ptr = 
reinterpret_cast( + ws + WorkspaceManager::template GetDqAccOffsetsOffset(batch)); if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -694,8 +855,12 @@ struct FmhaBwdDQDKDVKernel kargs.nhead_stride_randval = nhead_stride_randval; } } - if constexpr(kIsDeterministic) - kargs.split_stride_dq_acc = split_stride_dq_acc; + if constexpr(kUseQrQtrDorPipeline) + { + kargs.stride_dq = stride_dq; + kargs.nhead_stride_dq = nhead_stride_dq; + } + if constexpr(kUsePersistent) kargs.batch = batch; @@ -738,7 +903,16 @@ struct FmhaBwdDQDKDVKernel { if constexpr(!kUsePersistent) { - run_(std::move(kargs), blockIdx, blockIdx.x); + if constexpr(kUseQrQtrDorPipeline || kIsGroupMode) + { + run_(std::move(kargs), blockIdx, blockIdx.x, 0); + } + else + { + static_assert(!kIsDeterministic, + "Deterministic Batch Mode should use persistent kernel"); + run_(std::move(kargs), blockIdx, blockIdx.x, 1); + } } else { @@ -749,7 +923,7 @@ struct FmhaBwdDQDKDVKernel const index_t jobs_per_head = integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0); - const index_t total_heads = kargs.batch * kargs.num_head_q; + const index_t total_heads = kargs.batch * kargs.nhead_q; const index_t total_jobs = jobs_per_head * total_heads; const index_t jobs_per_worker = integer_divide_ceil(total_jobs, worker_num); @@ -766,25 +940,27 @@ struct FmhaBwdDQDKDVKernel return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2); }; - index_t job_id = begin_job_id; - index_t i_split = integer_divide_ceil(job_id % jobs_per_head, jobs_per_worker); + const auto n_splits = kargs.nsplits_ptr[0]; + index_t job_id = begin_job_id; + index_t i_split = integer_divide_ceil(job_id % jobs_per_head, jobs_per_worker); do { // loop over jobs assigned to this worker const index_t i_head_flatten = job_id / jobs_per_head; const index_t i_tile_n_ = job_id % jobs_per_head; const index_t i_tile_n = tile_n_interleave(i_tile_n_, jobs_per_head); - const index_t i_batch = i_head_flatten / kargs.num_head_q; - const index_t i_nhead = i_head_flatten % kargs.num_head_q; + const index_t i_batch = i_head_flatten / kargs.nhead_q; + const index_t i_nhead = i_head_flatten % kargs.nhead_q; if(i_tile_n_ == 0) // reset dq_acc writing idx when starting a new head i_split = 0; - run_(kargs, dim3(i_tile_n, i_nhead, i_batch), i_split); + run_(kargs, dim3(i_tile_n, i_nhead, i_batch), i_split, n_splits); } while(++job_id < end_job_id); } } } - CK_TILE_DEVICE void run_(Kargs kargs, const dim3& tile_index, const index_t i_split) const + CK_TILE_DEVICE void + run_(Kargs kargs, const dim3& tile_index, const index_t i_split, const index_t n_splits) const { // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; @@ -807,6 +983,9 @@ struct FmhaBwdDQDKDVKernel long_index_t batch_offset_dk = 0; long_index_t batch_offset_dv = 0; long_index_t batch_offset_dbias = 0; + // dq_acc per-nhead stride uses padded seqlen_q in group mode; equals kargs.seqlen_q + // in batch mode. See FmhaBwdWorkspaceManager doc. 
+ index_t physical_seqlen_q = kargs.seqlen_q; if constexpr(kIsGroupMode) { @@ -814,14 +993,24 @@ struct FmhaBwdDQDKDVKernel const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - batch_offset_v = key_start * kargs.stride_v; - batch_offset_do = query_start * kargs.stride_do; - batch_offset_lsed = query_start; - batch_offset_dq_acc = query_start * kargs.stride_dq_acc; - batch_offset_dk = key_start * kargs.stride_dk; - batch_offset_dv = key_start * kargs.stride_dv; + physical_seqlen_q = + static_cast(kargs.seqstart_q_ptr[i_batch + 1] - query_start); + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; + batch_offset_do = query_start * kargs.stride_do; + batch_offset_lsed = query_start; + // All !kUseQrQtrDorPipeline paths use per-batch compact dq_acc layout + // QrQtrDor: direct write to dq_ptr (flat layout with per-nhead strides) + if constexpr(kUseQrQtrDorPipeline) + batch_offset_dq_acc = query_start * kargs.stride_dq; + else if constexpr(!kIsDeterministic) + batch_offset_dq_acc = query_start * kargs.hdim_q * kargs.nhead_q; + else + batch_offset_dq_acc = kargs.dq_acc_batch_offset_ptr[i_batch]; + batch_offset_dk = key_start * kargs.stride_dk; + batch_offset_dv = key_start * kargs.stride_dv; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = query_start * kargs.stride_bias; @@ -847,10 +1036,6 @@ struct FmhaBwdDQDKDVKernel } else { - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - const ck_tile::index_t physical_seqlen_q = - adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; kargs.seqlen_q = kargs.seqlen_q_ptr ? 
kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q; } @@ -885,14 +1070,22 @@ struct FmhaBwdDQDKDVKernel } else { - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - batch_offset_do = static_cast(i_batch) * kargs.batch_stride_do; - batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; - batch_offset_dq_acc = static_cast(i_batch) * kargs.batch_stride_dq_acc; - batch_offset_dk = static_cast(i_batch) * kargs.batch_stride_dk; - batch_offset_dv = static_cast(i_batch) * kargs.batch_stride_dv; + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_do = static_cast(i_batch) * kargs.batch_stride_do; + batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; + + if constexpr(kUseQrQtrDorPipeline) + batch_offset_dq_acc = static_cast(i_batch) * kargs.batch_stride_dq; + else if constexpr(!kIsDeterministic) + batch_offset_dq_acc = static_cast(i_batch) * kargs.nhead_q * + kargs.seqlen_q * kargs.hdim_q; + else + batch_offset_dq_acc = static_cast(i_batch) * kargs.nhead_q * + n_splits * kargs.seqlen_q * kargs.hdim_q; + batch_offset_dk = static_cast(i_batch) * kargs.batch_stride_dk; + batch_offset_dv = static_cast(i_batch) * kargs.batch_stride_dv; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; @@ -1013,22 +1206,43 @@ struct FmhaBwdDQDKDVKernel using DType = std::conditional_t; auto dq_acc_ptr = reinterpret_cast(kargs.dq_acc_ptr) + [&]() { - if constexpr(kUseKSplit) - return static_cast(i_nhead_) * kargs.nhead_stride_dq_acc + - static_cast(i_split) * kargs.split_stride_dq_acc + - batch_offset_dq_acc; + if constexpr(kUseQrQtrDorPipeline) + { + return batch_offset_dq_acc + + static_cast(i_nhead_) * kargs.nhead_stride_dq; + } + else if constexpr(!kIsDeterministic) + { + return batch_offset_dq_acc + + static_cast(i_nhead_) * physical_seqlen_q * kargs.hdim_q; + } else - return static_cast(i_nhead_) * kargs.nhead_stride_dq_acc + - batch_offset_dq_acc; + { + const long_index_t split_stride = + static_cast(physical_seqlen_q) * kargs.hdim_q; + const auto nsplits = [&]() { + if constexpr(!kIsGroupMode) + return n_splits; + else + return integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0); + }(); + return batch_offset_dq_acc + (i_nhead_ * nsplits + i_split) * split_stride; + } }(); constexpr auto DstInMemOp = conditional_expr<(kUseKSplit && !kUsePersistent)>( memory_operation_enum::set, memory_operation_enum::atomic_add); + const index_t stride_dq_acc = [&]() { + if constexpr(kUseQrQtrDorPipeline) + return kargs.stride_dq; + else + return kargs.hdim_q; + }(); const auto dq_acc_dram_naive = make_naive_tensor_view( dq_acc_ptr, make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_dq_acc, 1), + make_tuple(stride_dq_acc, 1), number{}, number<1>{}); const auto dq_acc_dram = pad_tensor_view( @@ -1150,7 +1364,7 @@ struct FmhaBwdDQDKDVKernel { return FmhaDropout{i_batch_, i_nhead_, - kargs.num_head_q, + kargs.nhead_q, kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val : *kargs.drop_seed.ptr, kargs.is_drop_seed_offset_from_host ? 
kargs.drop_offset.val @@ -1649,7 +1863,6 @@ struct FmhaBwdConvertQGradKernel static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu; static constexpr ck_tile::index_t kM0 = FmhaBwdConvertQGrad::kM0; - static constexpr ck_tile::index_t kN0 = FmhaBwdConvertQGrad::kN0; static constexpr ck_tile::index_t kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim; using AccDataType = ck_tile::remove_cvref_t; @@ -1660,6 +1873,7 @@ struct FmhaBwdConvertQGradKernel static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ; static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic; static constexpr bool kUsePersistent = kIsDeterministic && !kIsGroupMode; + using WorkspaceManager = FmhaBwdWorkspaceManager; // clang-format off template struct t2s; @@ -1683,7 +1897,7 @@ struct FmhaBwdConvertQGradKernel return _SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_" + _SS_(t2s::name) + "_" - + "b" + _TS_(kM0) + "x" + _TS_(kN0) + "_" + + "b" + _TS_(kM0) + "_" + (kIsGroupMode ? "group" : "batch") + "_" + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) + (kIsDeterministic ? "_deterministic" : "_ndeterministic") ; @@ -1706,22 +1920,18 @@ struct FmhaBwdConvertQGradKernel const void* dq_acc_ptr; void* dq_ptr; + ck_tile::index_t nhead_q; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t hdim_q; ck_tile::index_t stride_dq; - ck_tile::index_t stride_dq_acc; ck_tile::index_t nhead_stride_dq; - ck_tile::long_index_t nhead_stride_dq_acc; }; struct FmhaBwdConvertQGradDeterministicKargs { - index_t split_stride_dq_acc = 0; - index_t dqdqkdv_workers = 0; // 0 for not using persistent kernel - index_t batch_size = 0; // for nsplits calc of persistent kernel - index_t nhead = 0; // for nsplits calc of persistent kernel + const index_t* nsplits_ptr; }; struct FmhaBwdConvertQGradBatchModeKargs @@ -1730,8 +1940,7 @@ struct FmhaBwdConvertQGradKernel FmhaBwdConvertQGradDeterministicKargs, FmhaBwdConvertQGradEmptyKargs<0>> { - ck_tile::index_t batch_stride_dq; - ck_tile::long_index_t batch_stride_dq_acc; + index_t batch_stride_dq; }; struct FmhaBwdConvertQGradGroupModeKargs @@ -1746,6 +1955,8 @@ struct FmhaBwdConvertQGradKernel const int32_t* seqlen_k_ptr; // per-batch actual length [batch] const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional + // per-batch element offset into compact dq_acc buffer + const long_index_t* dq_acc_batch_offset_ptr; }; using Kargs = std::conditional_t CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* dq_acc_ptr, + MakeKargs(const void* workspace, void* dq_ptr, + ck_tile::index_t batch_size, + ck_tile::index_t nhead_q, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, - ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, - ck_tile::long_index_t nhead_stride_dq_acc, - ck_tile::index_t batch_stride_dq, - ck_tile::long_index_t batch_stride_dq_acc, - ck_tile::index_t split_stride_dq_acc, - ck_tile::index_t batch_size, - ck_tile::index_t nhead) + ck_tile::index_t batch_stride_dq) { - Kargs kargs{{dq_acc_ptr, - dq_ptr, - seqlen_q, - seqlen_k, - hdim_q, - stride_dq, - stride_dq_acc, - nhead_stride_dq, - nhead_stride_dq_acc}, - {}, - batch_stride_dq, - batch_stride_dq_acc}; - + const uint8_t* ws = reinterpret_cast(workspace); + Kargs kargs{ + {ws + 
    using Kargs = std::conditional_t<kIsGroupMode,
                                     FmhaBwdConvertQGradGroupModeKargs,
                                     FmhaBwdConvertQGradBatchModeKargs>;

    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
-    MakeKargs(const void* dq_acc_ptr,
+    MakeKargs(const void* workspace,
              void* dq_ptr,
+              ck_tile::index_t batch_size,
+              ck_tile::index_t nhead_q,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q,
              ck_tile::index_t stride_dq,
-              ck_tile::index_t stride_dq_acc,
              ck_tile::index_t nhead_stride_dq,
-              ck_tile::long_index_t nhead_stride_dq_acc,
-              ck_tile::index_t batch_stride_dq,
-              ck_tile::long_index_t batch_stride_dq_acc,
-              ck_tile::index_t split_stride_dq_acc,
-              ck_tile::index_t batch_size,
-              ck_tile::index_t nhead)
+              ck_tile::index_t batch_stride_dq)
    {
-        Kargs kargs{{dq_acc_ptr,
-                     dq_ptr,
-                     seqlen_q,
-                     seqlen_k,
-                     hdim_q,
-                     stride_dq,
-                     stride_dq_acc,
-                     nhead_stride_dq,
-                     nhead_stride_dq_acc},
-                    {},
-                    batch_stride_dq,
-                    batch_stride_dq_acc};
-
+        const uint8_t* ws = reinterpret_cast<const uint8_t*>(workspace);
+        Kargs kargs{
+            {ws + WorkspaceManager::template GetDqAccDataOffset(batch_size),
+             dq_ptr,
+             nhead_q,
+             seqlen_q,
+             seqlen_k,
+             hdim_q,
+             stride_dq,
+             nhead_stride_dq},
+            {},
+            batch_stride_dq,
+        };

        if constexpr(kIsDeterministic)
        {
-            kargs.split_stride_dq_acc = split_stride_dq_acc;
-            if constexpr(kUsePersistent)
-            {
-                kargs.dqdqkdv_workers = get_num_cus();
-                kargs.batch_size      = batch_size;
-                kargs.nhead           = nhead;
-            }
+            kargs.nsplits_ptr = reinterpret_cast<const index_t*>(
+                ws + WorkspaceManager::GetDqAccSplitsOffset(batch_size));
        }

        return kargs;
@@ -1798,8 +2000,10 @@ struct FmhaBwdConvertQGradKernel
    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
-    MakeKargs(const void* dq_acc_ptr,
+    MakeKargs(const void* workspace,
              void* dq_ptr,
+              ck_tile::index_t batch_size,
+              ck_tile::index_t nhead_q,
              const void* seqstart_q_ptr,
              const void* seqstart_k_ptr,
              const void* seqlen_q_ptr,
@@ -1808,31 +2012,31 @@ struct FmhaBwdConvertQGradKernel
              const void* cu_seqlen_k_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t stride_dq,
-              ck_tile::index_t stride_dq_acc,
-              ck_tile::index_t nhead_stride_dq,
-              ck_tile::long_index_t nhead_stride_dq_acc,
-              ck_tile::index_t split_stride_dq_acc)
+              ck_tile::index_t nhead_stride_dq)
    {
-        Kargs kargs{{dq_acc_ptr,
+        const uint8_t* ws = reinterpret_cast<const uint8_t*>(workspace);
+        Kargs kargs{{ws + WorkspaceManager::template GetDqAccDataOffset(batch_size),
                     dq_ptr,
+                     nhead_q,
                     -1, // seqlen will be updated by another pointer
                     -1, //
                     hdim_q,
                     stride_dq,
-                     stride_dq_acc,
-                     nhead_stride_dq,
-                     nhead_stride_dq_acc},
+                     nhead_stride_dq},
                    {},
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                    reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                    reinterpret_cast<const int32_t*>(seqlen_q_ptr),
                    reinterpret_cast<const int32_t*>(seqlen_k_ptr),
                    reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr),
-                    reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr)};
+                    reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr),
+                    reinterpret_cast<const long_index_t*>(
+                        ws + WorkspaceManager::template GetDqAccOffsetsOffset(batch_size))};

        if constexpr(kIsDeterministic)
        {
-            kargs.split_stride_dq_acc = split_stride_dq_acc;
+            kargs.nsplits_ptr = reinterpret_cast<const index_t*>(
+                ws + WorkspaceManager::GetDqAccSplitsOffset(batch_size));
        }

        return kargs;
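The `GetDqAccSplitsOffset` / `GetDqAccOffsetsOffset` / `GetDqAccDataOffset` calls in both `MakeKargs` overloads resolve against the fixed layout documented earlier: nsplits array at offset 0, then the group-mode offsets array, then the dq_acc data, each segment 16-byte aligned. A self-contained re-derivation for illustration only; the real `FmhaBwdWorkspaceManager` methods are templated on the mode, whereas this sketch takes it as a runtime flag:

```cpp
#include <cstddef>
#include <cstdint>

struct WsLayoutSketch
{
    static constexpr size_t kAlign = 16;
    static constexpr size_t align_up(size_t x) { return (x + kAlign - 1) / kAlign * kAlign; }

    // nsplits[] sits at offset 0: `batch` entries in group mode, 1 otherwise.
    static size_t splits_offset(int /*batch*/) { return 0; }
    static size_t splits_bytes(int batch, bool group)
    {
        return align_up(sizeof(int32_t) * (group ? batch : 1));
    }
    // offsets[] exists only in group mode and follows nsplits[].
    static size_t offsets_offset(int batch, bool group) { return splits_bytes(batch, group); }
    static size_t offsets_bytes(int batch, bool group)
    {
        return group ? align_up(sizeof(int64_t) * (batch + 1)) : 0;
    }
    // dq_acc data begins right after the CPU-prepared segment (== host_ws_size).
    static size_t data_offset(int batch, bool group)
    {
        return offsets_offset(batch, group) + offsets_bytes(batch, group);
    }
};
```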
@@ -1866,28 +2070,26 @@ struct FmhaBwdConvertQGradKernel
        long_index_t batch_offset_dq     = 0;
        long_index_t batch_offset_dq_acc = 0;
+        index_t physical_seqlen_q        = 0;
        if constexpr(kIsGroupMode)
        {
-            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
-            batch_offset_dq     = query_start * kargs.stride_dq;
-            batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
+            physical_seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - query_start;
+            // get starting offset for each batch
+            batch_offset_dq = query_start * kargs.stride_dq;
+            if constexpr(!kIsDeterministic)
+                batch_offset_dq_acc = query_start * kargs.hdim_q * kargs.nhead_q;
+            else
+                batch_offset_dq_acc = kargs.dq_acc_batch_offset_ptr[i_batch];

            if(kargs.cu_seqlen_q_ptr != nullptr)
-            {
                kargs.seqlen_q =
                    kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
-            }
-            else
-            {
+            else if(kargs.seqlen_q_ptr != nullptr)
                // get real # queries & # keys under group mode
-                const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
-                const ck_tile::index_t physical_seqlen_q =
-                    adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
-                kargs.seqlen_q = kargs.seqlen_q_ptr ? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
-                                                    : physical_seqlen_q;
-            }
+                kargs.seqlen_q = static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch]);
+            else
+                kargs.seqlen_q = physical_seqlen_q;

            if constexpr(kIsDeterministic)
            {
@@ -1918,49 +2120,49 @@ struct FmhaBwdConvertQGradKernel
        }
        else
        {
-            batch_offset_dq     = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
-            batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
+            batch_offset_dq   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
+            physical_seqlen_q = kargs.seqlen_q;
+            // batch mode: nsplits was pre-computed by PrepareWorkspaceHost and stored in workspace
+            index_t nsplits = 1;
+            if constexpr(kIsDeterministic)
+                nsplits = kargs.nsplits_ptr[0];
+            const long_index_t nhead_stride_dq_acc =
+                static_cast<long_index_t>(nsplits) * kargs.seqlen_q * kargs.hdim_q;
+            batch_offset_dq_acc =
+                static_cast<long_index_t>(i_batch) * kargs.nhead_q * nhead_stride_dq_acc;
        }

        // for simplicity, batch stride we just modify the pointer
        QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
                                batch_offset_dq;

-        const index_t nsplits = [&]() {
-            const index_t jobs_per_head = integer_divide_ceil(kargs.seqlen_k, kN0);
-            if constexpr(!kIsDeterministic)
-                return 1;
-            else if constexpr(!kUsePersistent)
-                return jobs_per_head;
-            else
-            {
-                const index_t total_heads = kargs.batch_size * kargs.nhead;
-                const index_t total_jobs  = jobs_per_head * total_heads;
-                const index_t jobs_per_worker =
-                    integer_divide_ceil(total_jobs, kargs.dqdqkdv_workers);
-                const index_t i_head_flatten = i_batch * kargs.nhead + i_nhead;
-
-                const index_t i_job_start     = jobs_per_head * i_head_flatten;
-                const index_t begin_worker_id = i_job_start / jobs_per_worker;
-                const index_t end_worker_id   = // inclusive
-                    (i_job_start + jobs_per_head - 1) / jobs_per_worker;
-                return end_worker_id - begin_worker_id + 1;
-            }
-        }();

        // dQAcc/dQ DRAM and DRAM window
+        // compact layout: stride_dq_acc=hdim_q, split_stride=physical_seqlen_q*hdim_q,
+        // nhead_stride=nsplits*physical_seqlen_q*hdim_q
+        const long_index_t split_stride_dq_acc =
+            static_cast<long_index_t>(physical_seqlen_q) * kargs.hdim_q;
+        const index_t nsplits = [&, i_batch_ = i_batch]() {
+            if constexpr(!kIsDeterministic)
+                return 1;
+            else if constexpr(!kIsGroupMode)
+                return kargs.nsplits_ptr[0];
+            else // deterministic group mode
+                return kargs.nsplits_ptr[i_batch_];
+        }();
+        const long_index_t nhead_stride_dq_acc = split_stride_dq_acc * nsplits;
+
        const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
            if constexpr(kIsDeterministic)
            {
                const AccDataType* dq_acc_ptr =
                    reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
-                    static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
-                    batch_offset_dq_acc;
+                    static_cast<long_index_t>(i_nhead_) * nhead_stride_dq_acc + batch_offset_dq_acc;

                auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    dq_acc_ptr,
                    make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
-                    make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1),
+                    make_tuple(split_stride_dq_acc, kargs.hdim_q, 1),
                    number{},
                    number<1>{});
                return pad_tensor_view(dq_acc_dram_naive,
@@ -1971,13 +2173,12 @@ struct FmhaBwdConvertQGradKernel
            {
                const AccDataType* dq_acc_ptr =
                    reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
-                    static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
-                    batch_offset_dq_acc;
+                    static_cast<long_index_t>(i_nhead_) * nhead_stride_dq_acc + batch_offset_dq_acc;

                auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    dq_acc_ptr,
                    make_tuple(kargs.seqlen_q, kargs.hdim_q),
-                    make_tuple(kargs.stride_dq_acc, 1),
+                    make_tuple(kargs.hdim_q, 1),
                    number{},
                    number<1>{});
                return pad_tensor_view(dq_acc_dram_naive,
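What the convert step computes through these views, per (batch, head): in deterministic mode it reduces the `nsplits` partial dq slices in a fixed order and casts the result to `QGradDataType`; the non-deterministic path only casts. A scalar CPU reference of the deterministic reduction (float-only simplification with an illustrative name, not the tile-based kernel logic):

```cpp
#include <cstdint>

// dq_acc: compact [nsplits, physical_seqlen_q, hdim_q] slice of one (batch, head)
// dq:     [seqlen_q, hdim_q] output with row stride stride_dq
void convert_dq_reference(const float* dq_acc, float* dq, int nsplits, int seqlen_q,
                          int hdim_q, int stride_dq, int physical_seqlen_q)
{
    const int64_t split_stride = static_cast<int64_t>(physical_seqlen_q) * hdim_q;
    for(int m = 0; m < seqlen_q; ++m)
        for(int d = 0; d < hdim_q; ++d)
        {
            float acc = 0.f;
            for(int s = 0; s < nsplits; ++s) // fixed summation order => determinism
                acc += dq_acc[s * split_stride + static_cast<int64_t>(m) * hdim_q + d];
            dq[static_cast<int64_t>(m) * stride_dq + d] = acc; // cast to QGradDataType in-kernel
        }
}
```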
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
index e6d7c622f7..98c40497ec 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
@@ -15,7 +15,6 @@ struct BlockFmhaBwdConvertQGrad
    using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;

    static constexpr index_t kM0         = Problem::kM0;
-    static constexpr index_t kN0         = Problem::kN0;
    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;

diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
index d66ce4311e..f553945a37 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
@@ -97,7 +97,6 @@ template