diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 919a7aa8c0..2cec9c713a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -269,14 +269,14 @@ class FmhaFwdApiTrait: return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true if self.pipeline_tag == "qr_async": if self.skpad == "t": - return f"(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)" + return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)" else: - return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" + return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" elif self.pipeline_tag in ["qr", "qs"]: if self.skpad == "t": return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) else: - return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" + return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" elif self.pipeline_tag == "qr_async_trload": if self.skpad == "t": return "true" diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index eac1840a19..bf7092c6b8 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -114,10 +114,51 @@ struct fmha_bwd_args void* dv_ptr; void* dbias_ptr; void* dq_acc_ptr; - const void* seqstart_q_ptr; - const void* seqstart_k_ptr; - const void* seqlen_q_ptr; - const void* seqlen_k_ptr; + + // Usage notes for sequence length pointer parameters: + // + // Group mode: variable-length sequences are packed back to back and located via + // seqstart_*. Batch mode: every sequence shares the scalar seqlen_q/seqlen_k. + // + // With padding: + // Group mode: + // - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including 
padding) sequence + // lengths. [array size: batch + 1] + // - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each + // sequence. [array size: batch] + // - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding) + // sequence lengths. [array size: batch + 1] + // - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually + // exclusive. Use one set, not both. + // + // Batch mode: + // - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding) + // sequence lengths. [array size: batch + 1] + // - seqstart_* and seqlen_* pointers must be nullptr. + // + // Without padding: + // (Note: Physical length equals logical length) + // + // Group mode: + // - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array + // size: batch + 1] + // - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr. + // + // Batch mode: + // - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr. + // + const void* seqstart_q_ptr = + nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode) + const void* seqstart_k_ptr = + nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode) + const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array + // [batch]. (Used in Group mode with padding) + const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array + // [batch]. (Used in Group mode with padding) + const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length + // array [batch + 1]. (Used with padding) + const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length + // array [batch + 1]. 
(Used with padding) ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -190,54 +231,57 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) // create group mode kernel arguments if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) { - return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_ptr, - args.do_ptr, - args.d_ptr, - args.rand_val_ptr, - args.dk_ptr, - args.dv_ptr, - args.dbias_ptr, - dq_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_q_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_do, - stride_dq, - args.stride_dk, - args.stride_dv, - args.stride_dbias, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_do, - args.nhead_stride_lsed, - nhead_stride_dq, - args.nhead_stride_dk, - args.nhead_stride_dv, - args.nhead_stride_dbias, - args.split_stride_dq_acc, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.drop_seed_offset); + return FmhaBwdDQDKDVKernel::MakeKargsImpl( + args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + dq_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_q_ptr, + args.seqlen_k_ptr, + static_cast(args.cu_seqlen_q_ptr), + static_cast(args.cu_seqlen_k_ptr), + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + stride_dq, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 
args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + nhead_stride_dq, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); } else { // create batch mode kernel arguments @@ -318,6 +362,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) args.p_undrop, args.seqstart_q_ptr, args.seqlen_q_ptr, + args.cu_seqlen_q_ptr, args.hdim_v, args.stride_do, args.stride_o, @@ -361,6 +406,8 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) args.seqstart_k_ptr, args.seqlen_q_ptr, args.seqlen_k_ptr, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr, args.hdim_q, args.stride_dq, args.stride_dq_acc, diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index 0f34ef851f..52adcdc21d 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -493,6 +493,8 @@ bwd_result fmha_bwd_run(mode_enum mode, seqstart_k.GetDeviceBuffer(), seqlen_q_ptr_dev, seqlen_k_ptr_dev, + nullptr, + nullptr, shape_seqlen_q, shape_seqlen_k, batch, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 761def6d6a..383be6e099 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -182,19 +182,50 @@ struct fmha_fwd_args void* lse_ptr; void* o_ptr; - // Optional cumulative sequence length arrays - // Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD) - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] - - const void* seqstart_q_ptr; - const void* seqstart_k_ptr; - const void* - seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr - - // Group mode: 
seqstart_padded_* provide physical starts including PAD (optional) - const void* seqstart_padded_q_ptr = nullptr; // [batch+1] - const void* seqstart_padded_k_ptr = nullptr; // [batch+1] + // Usage notes for sequence length pointer parameters: + // + // Group mode: variable-length sequences are packed back to back and located via + // seqstart_*. Batch mode: every sequence shares the scalar seqlen_q/seqlen_k. + // + // With padding: + // Group mode: + // - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence + // lengths. [array size: batch + 1] + // - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each + // sequence. [array size: batch] + // - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding) + // sequence lengths. [array size: batch + 1] + // - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually + // exclusive. Use one set, not both. + // + // Batch mode: + // - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding) + // sequence lengths. [array size: batch + 1] + // - seqstart_* and seqlen_* pointers must be nullptr. + // + // Without padding: + // (Note: Physical length equals logical length) + // + // Group mode: + // - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array + // size: batch + 1] + // - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr. + // + // Batch mode: + // - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr. + // + const void* seqstart_q_ptr = + nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode) + const void* seqstart_k_ptr = + nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode) + const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array + // [batch]. 
(Used in Group mode with padding) + const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array + // [batch]. (Used in Group mode with padding) + const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length + // array [batch + 1]. (Used with padding) + const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length + // array [batch + 1]. (Used with padding) ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -555,6 +586,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.o_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, + args.seqlen_q_ptr, args.seqlen_k_ptr, args.hdim_q, args.hdim_v, @@ -584,8 +616,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, - args.seqstart_padded_q_ptr, - args.seqstart_padded_k_ptr); + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr); } else { // create batch mode kernel arguments @@ -633,7 +665,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.s_randval, args.drop_seed_offset, args.cu_seqlen_q_ptr, - args.cu_seqlen_kv_ptr); + args.cu_seqlen_k_ptr); } }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 69a5fcbeda..ca3cd51c57 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -313,16 +313,19 @@ fwd_result fmha_fwd_run(mode_enum mode, const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); // Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv) - const bool has_group_padding = - (mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1)) || - (mode == mode_enum::group && (seqlen_kpads[0] >= 0)); - const bool has_batch_efflens = (mode == mode_enum::batch && (!q_eff_lens_per_batch.empty() || - !kv_eff_lens_per_batch.empty())); - const bool 
using_appendkv = (0 < seqlen_knew || 0 < rotary_dim); - const bool using_pagedkv = (0 < page_block_size); - const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx; + const bool has_group_q_padding = + mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] > 0); + const bool has_group_k_padding = + mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] > 0); + const bool has_group_padding = has_group_q_padding || has_group_k_padding; + const bool has_batch_q_padding = mode == mode_enum::batch && !q_eff_lens_per_batch.empty(); + const bool has_batch_k_padding = mode == mode_enum::batch && !kv_eff_lens_per_batch.empty(); + const bool has_batch_padding = has_batch_q_padding || has_batch_k_padding; + const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim); + const bool using_pagedkv = (0 < page_block_size); + const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx; if((using_appendkv || using_pagedkv || using_splitkv) && - (has_group_padding || has_batch_efflens)) + (has_group_padding || has_batch_padding)) { std::cerr << "Padding (physical or effective lengths) is not supported with " "appendkv/splitkv/pagedkv pipelines" @@ -399,23 +402,9 @@ fwd_result fmha_fwd_run(mode_enum mode, const auto seqstart_q_host = to_seqstarts(seqlen_qs); const auto seqstart_k_host = to_seqstarts(seqlen_ks); + const auto seqstart_q_with_padding_host = to_seqstarts(seqlen_qpads); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); - // Optional padded Q seqstarts (group-mode only) - std::vector seqstart_q_with_padding_host; - if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1) - { - if(seqlen_qpads.size() < static_cast(batch)) - { - seqlen_qpads.resize(batch, seqlen_qpads.back()); - } - if(seqlen_qpads.size() == static_cast(batch)) - { - seqstart_q_with_padding_host = to_seqstarts( - ck_tile::span(seqlen_qpads.data(), seqlen_qpads.size())); - } - } - // Optional batch-mode cumulative 
seqlen overrides std::vector cuq_cum, cukv_cum; if(mode == mode_enum::batch) @@ -524,14 +513,15 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); // physical(padded) total seqlen_q for group when s_qpad is provided; else use logical const ck_tile::index_t shape_seqlen_q = - (mode == mode_enum::batch - ? seqlen_qs[0] - : (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back() - : seqstart_q_with_padding_host.back())); + (mode == mode_enum::batch ? seqlen_qs[0] + : (has_group_q_padding && !seqstart_q_with_padding_host.empty() + ? seqstart_q_with_padding_host.back() + : seqstart_q_host.back())); const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_ks[0] - : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() - : seqstart_k_with_padding_host.back())); + : (has_group_k_padding && !seqstart_k_with_padding_host.empty() + ? seqstart_k_with_padding_host.back() + : seqstart_k_host.back())); ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); @@ -689,14 +679,18 @@ fwd_result fmha_fwd_run(mode_enum mode, sizeof(int32_t)); ck_tile::DeviceMem seqstart_k_padded_buf( seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t)); + // Buffers for query per-sequence logical (unpadded) lengths (used in group mode with padding + // enabled) + ck_tile::DeviceMem seqlen_q_buf(has_group_q_padding ? seqlen_qs.size() * sizeof(int32_t) : 0); + // Buffers for key/value per-sequence logical (unpadded) lengths (used in batch mode with + // kvcache or group mode with padding enabled) + ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || has_group_k_padding + ? seqlen_ks.size() * sizeof(int32_t) + : 0); ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 : cuq_cum.size() * sizeof(ck_tile::index_t)); ck_tile::DeviceMem cu_seqlen_kv_buf( cukv_cum.empty() ? 
0 : cukv_cum.size() * sizeof(ck_tile::index_t)); - ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || - 0 <= seqlen_kpads[0] - ? seqlen_ks.size() * sizeof(int32_t) - : 0); ck_tile::DeviceMem cache_seqlen_k_buf( need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0); ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); @@ -792,7 +786,8 @@ fwd_result fmha_fwd_run(mode_enum mode, : seqstart_k_with_padding_host.data()); cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data()); cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data()); - seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] + seqlen_q_buf.ToDevice(has_group_q_padding ? seqlen_qs.data() : nullptr); + seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || has_group_k_padding ? seqlen_ks.data() : nullptr); cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr); @@ -873,7 +868,7 @@ fwd_result fmha_fwd_run(mode_enum mode, print_vec("k_padded", seqlen_kpads); } } - else if(has_batch_efflens) + else if(has_batch_padding) { // derive effective lengths from cumulative arrays if present if(!cuq_cum.empty()) @@ -1056,14 +1051,6 @@ fwd_result fmha_fwd_run(mode_enum mode, args.lse_ptr = lse_buf.GetDeviceBuffer(); args.o_ptr = o_buf.GetDeviceBuffer(); - args.seqstart_q_ptr = - (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); - args.seqstart_k_ptr = - (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); - args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] - ? 
seqlen_k_buf.GetDeviceBuffer() - : nullptr); - args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) args.max_seqlen_q = max_seqlen_q; @@ -1107,27 +1094,54 @@ fwd_result fmha_fwd_run(mode_enum mode, args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); } - // Group-mode: optional physical padded starts for Q/K + // Sequence length and padding parameters (mode-specific) if(mode == mode_enum::group) { - args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty() - ? nullptr - : seqstart_q_padded_buf.GetDeviceBuffer()); - args.seqstart_padded_k_ptr = - (seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer()); - } + // Group mode: use physical (padded) cumulative starts + logical per-sequence + // lengths - // Batch-mode: optional cumulative effective seqlen overrides - if(mode == mode_enum::batch) + // Physical cumulative starts (including padding) + args.seqstart_q_ptr = + has_group_q_padding && !seqstart_q_with_padding_host.empty() + ? seqstart_q_padded_buf.GetDeviceBuffer() + : seqstart_q.GetDeviceBuffer(); + args.seqstart_k_ptr = + has_group_k_padding && !seqstart_k_with_padding_host.empty() + ? seqstart_k_padded_buf.GetDeviceBuffer() + : seqstart_k.GetDeviceBuffer(); + + // Logical (unpadded) per-sequence lengths, used when padding is enabled + args.seqlen_q_ptr = + (has_group_q_padding && !seqstart_q_with_padding_host.empty()) + ? seqlen_q_buf.GetDeviceBuffer() + : nullptr; + args.seqlen_k_ptr = + (has_group_k_padding && !seqstart_k_with_padding_host.empty()) + ? seqlen_k_buf.GetDeviceBuffer() + : nullptr; + // Cumulative lengths not used in group mode + args.cu_seqlen_q_ptr = nullptr; + args.cu_seqlen_k_ptr = nullptr; + } + else // mode == mode_enum::batch { - args.cu_seqlen_q_ptr = cuq_cum.empty() - ? nullptr - : reinterpret_cast( - cu_seqlen_q_buf.GetDeviceBuffer()); - args.cu_seqlen_kv_ptr = cukv_cum.empty() - ? 
nullptr - : reinterpret_cast( - cu_seqlen_kv_buf.GetDeviceBuffer()); + // Batch mode: use cumulative logical lengths for tail padding + + // seqstart pointers not used in batch mode + args.seqstart_q_ptr = nullptr; + args.seqstart_k_ptr = nullptr; + + // seqlen_q_ptr/seqlen_k_ptr not used in batch mode + args.seqlen_q_ptr = nullptr; + args.seqlen_k_ptr = nullptr; + + // Cumulative logical lengths for effective length handling + args.cu_seqlen_q_ptr = has_batch_q_padding && !cuq_cum.empty() + ? cu_seqlen_q_buf.GetDeviceBuffer() + : nullptr; + args.cu_seqlen_k_ptr = has_batch_k_padding && !cukv_cum.empty() + ? cu_seqlen_kv_buf.GetDeviceBuffer() + : nullptr; } } else if constexpr(std::is_same_v>) @@ -1153,6 +1167,15 @@ fwd_result fmha_fwd_run(mode_enum mode, args.batch_stride_o_acc = batch_stride_o_acc; args.split_stride_lse_acc = split_stride_lse_acc; args.split_stride_o_acc = split_stride_o_acc; + + args.seqstart_q_ptr = + (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); + args.seqstart_k_ptr = + (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); + args.seqlen_k_ptr = + ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] + ? seqlen_k_buf.GetDeviceBuffer() + : nullptr); } else if constexpr(std::is_same_v>) { @@ -1164,6 +1187,15 @@ fwd_result fmha_fwd_run(mode_enum mode, args.cache_batch_idx = (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + + args.seqstart_q_ptr = + (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); + args.seqstart_k_ptr = + (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); + args.seqlen_k_ptr = + ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] + ? seqlen_k_buf.GetDeviceBuffer() + : nullptr); } } }; @@ -1365,16 +1397,19 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t cache_b_idx = (use_cache_batch_idx ? 
cache_batch_idx_host(b_idx) : b_idx); + // Use physical offset if padding info is valid (not -1) and buffers are available const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 - : (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb] - : seqstart_q_with_padding_host[wb])); + : ((seqstart_q_with_padding_host.empty() || seqlen_qpads[0] < 0) + ? seqstart_q_host[wb] + : seqstart_q_with_padding_host[wb])); const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 - : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] - : seqstart_k_with_padding_host[wb])); + : ((seqstart_k_with_padding_host.empty() || seqlen_kpads[0] < 0) + ? seqstart_k_host[wb] + : seqstart_k_with_padding_host[wb])); ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); @@ -1723,8 +1758,14 @@ fwd_result fmha_fwd_run(mode_enum mode, std::cerr << "OUT mismatch found at batch: " << wb << std::endl << "\tseqlen_q: " << real_seqlen_q << std::endl << "\tseqlen_k: " << real_seqlen_k << std::endl - << "\tseqstart_q: " << seqstart_q_host << std::endl - << "\tseqstart_k: " << seqstart_k_host << std::endl; + << "\tseqstart_q (logical): " << seqstart_q_host << std::endl + << "\tseqstart_q (physical): " << seqstart_q_with_padding_host + << std::endl + << "\tseqstart_k (logical): " << seqstart_k_host << std::endl + << "\tseqstart_k (physical): " << seqstart_k_with_padding_host + << std::endl + << "\tquery_offset used: " << query_offset << std::endl + << "\tkey_offset used: " << key_offset << std::endl; break; } @@ -1735,8 +1776,6 @@ fwd_result fmha_fwd_run(mode_enum mode, lse_host_result.ForEach([&](auto& self, auto idx) { self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); }); - std::cout << "lse_host_result shape: " << shape_batch << ", " << nhead << ", " - << shape_seqlen_q << std::endl; cur_pass = ck_tile::check_err(lse_host_result, lse_host_ref, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp 
b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index e6cd9c0b7b..668fab3fd3 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -313,8 +313,10 @@ struct FmhaBwdDQDKDVKernel { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; - const int32_t* seqlen_q_ptr; - const int32_t* seqlen_k_ptr; + const int32_t* seqlen_q_ptr; // per-batch actual length [batch] + const int32_t* seqlen_k_ptr; // per-batch actual length [batch] + const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional + const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional }; using Kargs = std::conditional_t; @@ -523,6 +525,8 @@ struct FmhaBwdDQDKDVKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, + const void* cu_seqlen_q_ptr, + const void* cu_seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -597,7 +601,9 @@ struct FmhaBwdDQDKDVKernel reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_q_ptr), - reinterpret_cast(seqlen_k_ptr)}; + reinterpret_cast(seqlen_k_ptr), + reinterpret_cast(cu_seqlen_q_ptr), + reinterpret_cast(cu_seqlen_k_ptr)}; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -739,13 +745,29 @@ struct FmhaBwdDQDKDVKernel batch_offset_randval = query_start * kargs.stride_randval; } - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - const ck_tile::index_t physical_seqlen_q = - adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; - kargs.seqlen_q = kargs.seqlen_q_ptr ? 
kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q; + // Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q + if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + const ck_tile::index_t physical_seqlen_q = + adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + kargs.seqlen_q = + kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q; + } - if(kargs.seqlen_k_ptr != nullptr) + // Priority: cu_seqlen_k_ptr > seqlen_k_ptr > seqstart_k + if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } + else if(kargs.seqlen_k_ptr != nullptr) { kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; } @@ -1258,7 +1280,8 @@ struct FmhaBwdOGradDotOKernel struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs { const int32_t* seqstart_q_ptr; - const int32_t* seqlen_q_ptr; + const int32_t* seqlen_q_ptr; // per-batch actual length [batch] + const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional }; using Kargs = std:: @@ -1307,6 +1330,7 @@ struct FmhaBwdOGradDotOKernel float p_undrop, const void* seqstart_q_ptr, const void* seqlen_q_ptr, + const void* cu_seqlen_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, @@ -1326,7 +1350,8 @@ struct FmhaBwdOGradDotOKernel nhead_stride_o, nhead_stride_d}, reinterpret_cast(seqstart_q_ptr), - reinterpret_cast(seqlen_q_ptr)}; + reinterpret_cast(seqlen_q_ptr), + reinterpret_cast(cu_seqlen_q_ptr)}; return kargs; } @@ -1370,14 +1395,23 @@ struct FmhaBwdOGradDotOKernel batch_offset_do = query_start * kargs.stride_do; batch_offset_d = query_start; - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - const 
ck_tile::index_t physical_seqlen_q = - adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; - const ck_tile::index_t logical_seqlen_q = - kargs.seqlen_q_ptr ? static_cast(kargs.seqlen_q_ptr[i_batch]) - : physical_seqlen_q; - kargs.seqlen_q = logical_seqlen_q; + // Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q + if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + const ck_tile::index_t physical_seqlen_q = + adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + kargs.seqlen_q = kargs.seqlen_q_ptr + ? static_cast(kargs.seqlen_q_ptr[i_batch]) + : physical_seqlen_q; + } + // # of required blocks is different in each groups, terminate unnecessary blocks // earlier if(kargs.seqlen_q <= i_m0) @@ -1541,8 +1575,10 @@ struct FmhaBwdConvertQGradKernel { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; - const int32_t* seqlen_q_ptr; - const int32_t* seqlen_k_ptr; + const int32_t* seqlen_q_ptr; // per-batch actual length [batch] + const int32_t* seqlen_k_ptr; // per-batch actual length [batch] + const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional + const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional }; using Kargs = std::conditional_t(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_q_ptr), - reinterpret_cast(seqlen_k_ptr)}; + reinterpret_cast(seqlen_k_ptr), + reinterpret_cast(cu_seqlen_q_ptr), + reinterpret_cast(cu_seqlen_k_ptr)}; if constexpr(kIsDeterministic) { @@ -1658,22 +1698,41 @@ struct FmhaBwdConvertQGradKernel batch_offset_dq = query_start * kargs.stride_dq; batch_offset_dq_acc = query_start * kargs.stride_dq_acc; - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - const 
ck_tile::index_t physical_seqlen_q = - adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; - const ck_tile::index_t logical_seqlen_q = - kargs.seqlen_q_ptr ? static_cast(kargs.seqlen_q_ptr[i_batch]) - : physical_seqlen_q; - kargs.seqlen_q = logical_seqlen_q; + if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + const ck_tile::index_t physical_seqlen_q = + adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + kargs.seqlen_q = kargs.seqlen_q_ptr + ? static_cast(kargs.seqlen_q_ptr[i_batch]) + : physical_seqlen_q; + } + if constexpr(kIsDeterministic) { const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; const ck_tile::index_t physical_seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; - kargs.seqlen_k = kargs.seqlen_k_ptr - ? static_cast(kargs.seqlen_k_ptr[i_batch]) - : physical_seqlen_k; + + // Priority: cu_seqlen_k_ptr > seqlen_k_ptr > physical_seqlen_k + if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } + else + { + kargs.seqlen_k = + kargs.seqlen_k_ptr + ? static_cast(kargs.seqlen_k_ptr[i_batch]) + : physical_seqlen_k; + } } // # of required blocks is different in each groups, terminate unnecessary blocks // earlier diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 1b2554d0a2..f539c9d7e9 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -296,8 +296,8 @@ struct FmhaFwdKernel // Optional cumulative sequence length pointers for batch mode // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. 
- const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD + const int32_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD + const int32_t* cu_seqlen_k_ptr = nullptr; // cumulative, length without PAD }; struct FmhaFwdGroupModeKargs @@ -316,12 +316,12 @@ struct FmhaFwdKernel { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; + const int32_t* seqlen_q_ptr; const int32_t* seqlen_k_ptr; - // Optional cumulative padded sequence starts (including PAD tokens) - // Used solely to compute memory offsets when sequences are physically padded. - const int32_t* seqstart_padded_q_ptr = nullptr; - const int32_t* seqstart_padded_k_ptr = nullptr; + // Optional per-sequence and cumulative logical (excluding padding) sequence length arrays + const int32_t* cu_seqlen_q_ptr = nullptr; + const int32_t* cu_seqlen_k_ptr = nullptr; }; using Kargs = std::conditional_t; @@ -379,8 +379,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -471,8 +471,8 @@ struct FmhaFwdKernel kargs.init_logits_soft_cap(logits_soft_cap); } - kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; - kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; + kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); return kargs; } @@ -522,8 +522,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -570,7 +570,7 @@ struct FmhaFwdKernel 
s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), cu_seqlen_q_ptr, - cu_seqlen_kv_ptr); + cu_seqlen_k_ptr); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -619,8 +619,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -667,7 +667,7 @@ struct FmhaFwdKernel s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), cu_seqlen_q_ptr, - cu_seqlen_kv_ptr); + cu_seqlen_k_ptr); } template @@ -681,6 +681,7 @@ struct FmhaFwdKernel void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, + const void* seqlen_q_ptr, const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -711,8 +712,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - const void* seqstart_padded_q_ptr = nullptr, - const void* seqstart_padded_k_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -746,6 +747,7 @@ struct FmhaFwdKernel {}, // placeholder for min_seqlen_q reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_q_ptr), reinterpret_cast(seqlen_k_ptr)}; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) @@ -804,8 +806,8 @@ struct FmhaFwdKernel kargs.min_seqlen_q = min_seqlen_q; } - kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); - kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); + kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); return kargs; } @@ -821,6 +823,7 @@ struct FmhaFwdKernel void* o_ptr, const void* 
seqstart_q_ptr, const void* seqstart_k_ptr, + const void* seqlen_q_ptr, const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -850,8 +853,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - const void* seqstart_padded_q_ptr = nullptr, - const void* seqstart_padded_k_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -863,6 +866,7 @@ struct FmhaFwdKernel o_ptr, seqstart_q_ptr, seqstart_k_ptr, + seqlen_q_ptr, seqlen_k_ptr, hdim_q, hdim_v, @@ -892,8 +896,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - seqstart_padded_q_ptr, - seqstart_padded_k_ptr); + cu_seqlen_q_ptr, + cu_seqlen_k_ptr); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -908,6 +912,7 @@ struct FmhaFwdKernel void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, + const void* seqlen_q_ptr, const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -937,8 +942,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - const void* seqstart_padded_q_ptr = nullptr, - const void* seqstart_padded_k_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -950,6 +955,7 @@ struct FmhaFwdKernel o_ptr, seqstart_q_ptr, seqstart_k_ptr, + seqlen_q_ptr, seqlen_k_ptr, hdim_q, hdim_v, @@ -979,8 +985,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - seqstart_padded_q_ptr, - seqstart_padded_k_ptr); + cu_seqlen_q_ptr, + cu_seqlen_k_ptr); } CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, @@ -1109,46 +1115,52 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) { - // logical and physical (padded) starts - const 
long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; + // Use seqstart_q_ptr and seqstart_k_ptr for physical starts + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - // DRAM base offsets use physical padded starts - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; + // DRAM base offsets use physical starts + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; if constexpr(std::is_same_v) { - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_v = key_start * kargs.stride_v; } else { - batch_offset_v = key_start_padded; + batch_offset_v = key_start; } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start_padded * kargs.stride_bias; + batch_offset_bias = query_start * kargs.stride_bias; } if constexpr(kStoreLSE) { - // LSE follows the padded layout to stay consistent with other tensors - batch_offset_lse = query_start_padded; + // LSE follows the physical layout to stay consistent with other tensors + batch_offset_lse = query_start; } if constexpr(kHasDropout) { - batch_offset_randval = query_start_padded * kargs.stride_randval; + batch_offset_randval = query_start * kargs.stride_randval; } - batch_offset_o = query_start_padded * kargs.stride_o; + batch_offset_o = query_start * kargs.stride_o; // real logical lengths (exclude PAD) - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - 
adjusted_seqstart_q_ptr[0]; + // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr + if(kargs.seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch]; + } + else if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + } if constexpr(kSkipMinSeqlenQ) { @@ -1168,6 +1180,11 @@ struct FmhaFwdKernel { kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; } + else if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } else { const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; @@ -1201,10 +1218,10 @@ struct FmhaFwdKernel kargs.seqlen_q = kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; } - if(kargs.cu_seqlen_kv_ptr != nullptr) + if(kargs.cu_seqlen_k_ptr != nullptr) { kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; } } @@ -1603,39 +1620,46 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) { - // get starting offset for each batch - const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; + // get starting offset for each batch - use seqstart_q_ptr/seqstart_k_ptr for + // physical starts + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? 
kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; if constexpr(std::is_same_v) { - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_v = key_start * kargs.stride_v; } else { // col-major V: offset along seqlen dimension is scalar index - batch_offset_v = key_start_padded; + batch_offset_v = key_start; } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start_padded * kargs.stride_bias; + batch_offset_bias = query_start * kargs.stride_bias; } - // LSE layout is [nhead, total_seqlen] following the padded layout for Q/O - batch_offset_lse = query_start_padded; - batch_offset_o = query_start_padded * kargs.stride_o; + // LSE layout is [nhead, total_seqlen] following the physical layout for Q/O + batch_offset_lse = query_start; + batch_offset_o = query_start * kargs.stride_o; // get real # queries & # keys under group mode - kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; + if(kargs.seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch]; + } + else if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + kargs.seqlen_q = + kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; + } // # of required blocks is different in each groups, terminate unnecessary blocks // earlier @@ -1648,6 +1672,11 @@ struct FmhaFwdKernel { kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; } + else if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } else { kargs.seqlen_k = @@ -1677,10 +1706,10 @@ struct FmhaFwdKernel kargs.seqlen_q = kargs.cu_seqlen_q_ptr[i_batch + 1] - 
kargs.cu_seqlen_q_ptr[i_batch]; } - if(kargs.cu_seqlen_kv_ptr != nullptr) + if(kargs.cu_seqlen_k_ptr != nullptr) { kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; } }