[CK_TILE] fmha: Unify sequence length and padding handling

Refactor the handling of sequence lengths and padding in the
FMHA forward and backward kernels to provide a more unified and flexible
interface.

- Replaced `seqstart_padded_*_ptr` with a more robust system that uses
  `seqstart_*_ptr` for physical sequence lengths and introduces
  `seqlen_*_ptr` and `cu_seqlen_*_ptr` for logical (unpadded) lengths.
- Established a clear order of precedence for determining sequence
  length: cumulative lengths (`cu_seqlen_*_ptr`) take priority,
  followed by per-sequence lengths (`seqlen_*_ptr`), and finally
  physical lengths derived from `seqstart_*_ptr`.
- Clarified the distinction between "group mode" and "batch mode" and
  how sequence lengths are handled in each case.
- Renamed `cu_seqlen_kv_ptr` to `cu_seqlen_k_ptr` for consistency.
- Updated comments and documentation to reflect the new argument
  structure and usage.
This commit is contained in:
Jeff Huang
2025-10-24 15:46:06 +08:00
parent eeffd2717a
commit 1fe3c20ef2
7 changed files with 454 additions and 246 deletions

View File

@@ -269,14 +269,14 @@ class FmhaFwdApiTrait:
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
if self.pipeline_tag == "qr_async":
if self.skpad == "t":
return f"(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
else:
return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
elif self.pipeline_tag in ["qr", "qs"]:
if self.skpad == "t":
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
elif self.pipeline_tag == "qr_async_trload":
if self.skpad == "t":
return "true"

View File

@@ -114,10 +114,51 @@ struct fmha_bwd_args
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_q_ptr;
const void* seqlen_k_ptr;
// Usage notes for sequence length pointer parameters:
//
// [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles
// MQA/GQA..."]
//
// With padding:
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
// lengths. [array size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
// sequence. [array size: batch]
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
// exclusive. Use one set, not both.
//
// Batch mode:
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqstart_* and seqlen_* pointers must be nullptr.
//
// Without padding:
// (Note: Physical length equals logical length)
//
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
// size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
//
// Batch mode:
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
//
const void* seqstart_q_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqstart_k_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
@@ -190,54 +231,57 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
// create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_do,
stride_dq,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
nhead_stride_dq,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.drop_seed_offset);
return FmhaBwdDQDKDVKernel::MakeKargsImpl(
args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
static_cast<const ck_tile::index_t*>(args.cu_seqlen_q_ptr),
static_cast<const ck_tile::index_t*>(args.cu_seqlen_k_ptr),
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_do,
stride_dq,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
nhead_stride_dq,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.drop_seed_offset);
}
else
{ // create batch mode kernel arguments
@@ -318,6 +362,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args.p_undrop,
args.seqstart_q_ptr,
args.seqlen_q_ptr,
args.cu_seqlen_q_ptr,
args.hdim_v,
args.stride_do,
args.stride_o,
@@ -361,6 +406,8 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,

View File

@@ -493,6 +493,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
seqstart_k.GetDeviceBuffer(),
seqlen_q_ptr_dev,
seqlen_k_ptr_dev,
nullptr,
nullptr,
shape_seqlen_q,
shape_seqlen_k,
batch,

View File

@@ -182,19 +182,50 @@ struct fmha_fwd_args
void* lse_ptr;
void* o_ptr;
// Optional cumulative sequence length arrays
// Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD)
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void*
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
// Group mode: seqstart_padded_* provide physical starts including PAD (optional)
const void* seqstart_padded_q_ptr = nullptr; // [batch+1]
const void* seqstart_padded_k_ptr = nullptr; // [batch+1]
// Usage notes for sequence length pointer parameters:
//
// [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles
// MQA/GQA..."]
//
// With padding:
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
// lengths. [array size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
// sequence. [array size: batch]
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
// exclusive. Use one set, not both.
//
// Batch mode:
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqstart_* and seqlen_* pointers must be nullptr.
//
// Without padding:
// (Note: Physical length equals logical length)
//
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
// size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
//
// Batch mode:
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
//
const void* seqstart_q_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqstart_k_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -555,6 +586,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.o_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
@@ -584,8 +616,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.seqstart_padded_q_ptr,
args.seqstart_padded_k_ptr);
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
}
else
{ // create batch mode kernel arguments
@@ -633,7 +665,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.s_randval,
args.drop_seed_offset,
args.cu_seqlen_q_ptr,
args.cu_seqlen_kv_ptr);
args.cu_seqlen_k_ptr);
}
}();

View File

@@ -313,16 +313,19 @@ fwd_result fmha_fwd_run(mode_enum mode,
const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
// Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv)
const bool has_group_padding =
(mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1)) ||
(mode == mode_enum::group && (seqlen_kpads[0] >= 0));
const bool has_batch_efflens = (mode == mode_enum::batch && (!q_eff_lens_per_batch.empty() ||
!kv_eff_lens_per_batch.empty()));
const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim);
const bool using_pagedkv = (0 < page_block_size);
const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx;
const bool has_group_q_padding =
mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] > 0);
const bool has_group_k_padding =
mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] > 0);
const bool has_group_padding = has_group_q_padding || has_group_k_padding;
const bool has_batch_q_padding = mode == mode_enum::batch && !q_eff_lens_per_batch.empty();
const bool has_batch_k_padding = mode == mode_enum::batch && !kv_eff_lens_per_batch.empty();
const bool has_batch_padding = has_batch_q_padding || has_batch_k_padding;
const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim);
const bool using_pagedkv = (0 < page_block_size);
const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx;
if((using_appendkv || using_pagedkv || using_splitkv) &&
(has_group_padding || has_batch_efflens))
(has_group_padding || has_batch_padding))
{
std::cerr << "Padding (physical or effective lengths) is not supported with "
"appendkv/splitkv/pagedkv pipelines"
@@ -399,23 +402,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_q_with_padding_host = to_seqstarts(seqlen_qpads);
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
// Optional padded Q seqstarts (group-mode only)
std::vector<int32_t> seqstart_q_with_padding_host;
if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1)
{
if(seqlen_qpads.size() < static_cast<size_t>(batch))
{
seqlen_qpads.resize(batch, seqlen_qpads.back());
}
if(seqlen_qpads.size() == static_cast<size_t>(batch))
{
seqstart_q_with_padding_host = to_seqstarts(
ck_tile::span<const int32_t>(seqlen_qpads.data(), seqlen_qpads.size()));
}
}
// Optional batch-mode cumulative seqlen overrides
std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
if(mode == mode_enum::batch)
@@ -524,14 +513,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
// physical(padded) total seqlen_q for group when s_qpad is provided; else use logical
const ck_tile::index_t shape_seqlen_q =
(mode == mode_enum::batch
? seqlen_qs[0]
: (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back()
: seqstart_q_with_padding_host.back()));
(mode == mode_enum::batch ? seqlen_qs[0]
: (has_group_q_padding && !seqstart_q_with_padding_host.empty()
? seqstart_q_with_padding_host.back()
: seqstart_q_host.back()));
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_ks[0]
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
: seqstart_k_with_padding_host.back()));
: (has_group_k_padding && !seqstart_k_with_padding_host.empty()
? seqstart_k_with_padding_host.back()
: seqstart_k_host.back()));
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
@@ -689,14 +679,18 @@ fwd_result fmha_fwd_run(mode_enum mode,
sizeof(int32_t));
ck_tile::DeviceMem seqstart_k_padded_buf(
seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t));
// Buffers for query per-sequence logical (unpadded) lengths (used in group mode with padding
// enabled)
ck_tile::DeviceMem seqlen_q_buf(has_group_q_padding ? seqlen_qs.size() * sizeof(int32_t) : 0);
// Buffers for key/value per-sequence logical (unpadded) lengths (used in batch mode with
// kvcache or group mode with padding enabled)
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
? seqlen_ks.size() * sizeof(int32_t)
: 0);
ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0
: cuq_cum.size() * sizeof(ck_tile::index_t));
ck_tile::DeviceMem cu_seqlen_kv_buf(
cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t));
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
0 <= seqlen_kpads[0]
? seqlen_ks.size() * sizeof(int32_t)
: 0);
ck_tile::DeviceMem cache_seqlen_k_buf(
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
@@ -792,7 +786,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
: seqstart_k_with_padding_host.data());
cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data());
cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data());
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
seqlen_q_buf.ToDevice(has_group_q_padding ? seqlen_qs.data() : nullptr);
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
? seqlen_ks.data()
: nullptr);
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
@@ -873,7 +868,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
print_vec("k_padded", seqlen_kpads);
}
}
else if(has_batch_efflens)
else if(has_batch_padding)
{
// derive effective lengths from cumulative arrays if present
if(!cuq_cum.empty())
@@ -1056,14 +1051,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.lse_ptr = lse_buf.GetDeviceBuffer();
args.o_ptr = o_buf.GetDeviceBuffer();
args.seqstart_q_ptr =
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
: nullptr);
args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
args.max_seqlen_q = max_seqlen_q;
@@ -1107,27 +1094,54 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
}
// Group-mode: optional physical padded starts for Q/K
// Sequence length and padding parameters (mode-specific)
if(mode == mode_enum::group)
{
args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty()
? nullptr
: seqstart_q_padded_buf.GetDeviceBuffer());
args.seqstart_padded_k_ptr =
(seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer());
}
// Group mode: use physical (padded) cumulative starts + logical per-sequence
// lengths
// Batch-mode: optional cumulative effective seqlen overrides
if(mode == mode_enum::batch)
// Physical cumulative starts (including padding)
args.seqstart_q_ptr =
has_group_q_padding && !seqstart_q_with_padding_host.empty()
? seqstart_q_padded_buf.GetDeviceBuffer()
: seqstart_q.GetDeviceBuffer();
args.seqstart_k_ptr =
has_group_k_padding && !seqstart_k_with_padding_host.empty()
? seqstart_k_padded_buf.GetDeviceBuffer()
: seqstart_k.GetDeviceBuffer();
// Logical (unpadded) per-sequence lengths, used when padding is enabled
args.seqlen_q_ptr =
(has_group_q_padding && !seqstart_q_with_padding_host.empty())
? seqlen_q_buf.GetDeviceBuffer()
: nullptr;
args.seqlen_k_ptr =
(has_group_k_padding && !seqstart_k_with_padding_host.empty())
? seqlen_k_buf.GetDeviceBuffer()
: nullptr;
// Cumulative lengths not used in group mode
args.cu_seqlen_q_ptr = nullptr;
args.cu_seqlen_k_ptr = nullptr;
}
else // mode == mode_enum::batch
{
args.cu_seqlen_q_ptr = cuq_cum.empty()
? nullptr
: reinterpret_cast<const ck_tile::index_t*>(
cu_seqlen_q_buf.GetDeviceBuffer());
args.cu_seqlen_kv_ptr = cukv_cum.empty()
? nullptr
: reinterpret_cast<const ck_tile::index_t*>(
cu_seqlen_kv_buf.GetDeviceBuffer());
// Batch mode: use cumulative logical lengths for tail padding
// seqstart pointers not used in batch mode
args.seqstart_q_ptr = nullptr;
args.seqstart_k_ptr = nullptr;
// seqlen_q_ptr/seqlen_k_ptr not used in batch mode
args.seqlen_q_ptr = nullptr;
args.seqlen_k_ptr = nullptr;
// Cumulative logical lengths for effective length handling
args.cu_seqlen_q_ptr = has_batch_q_padding && !cuq_cum.empty()
? cu_seqlen_q_buf.GetDeviceBuffer()
: nullptr;
args.cu_seqlen_k_ptr = has_batch_k_padding && !cukv_cum.empty()
? cu_seqlen_kv_buf.GetDeviceBuffer()
: nullptr;
}
}
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
@@ -1153,6 +1167,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.batch_stride_o_acc = batch_stride_o_acc;
args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc;
args.seqstart_q_ptr =
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr =
((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
: nullptr);
}
else if constexpr(std::is_same_v<fmha_fwd_pagedkv_args, std::decay_t<decltype(args)>>)
{
@@ -1164,6 +1187,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.cache_batch_idx =
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
args.seqstart_q_ptr =
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr =
((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
: nullptr);
}
}
};
@@ -1365,16 +1397,19 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t cache_b_idx =
(use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
// Use physical offset if padding info is valid (not -1) and buffers are available
const ck_tile::index_t query_offset =
(mode == mode_enum::batch
? 0
: (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb]
: seqstart_q_with_padding_host[wb]));
: ((seqstart_q_with_padding_host.empty() || seqlen_qpads[0] < 0)
? seqstart_q_host[wb]
: seqstart_q_with_padding_host[wb]));
const ck_tile::index_t key_offset =
(mode == mode_enum::batch
? 0
: (seqlen_kpads[0] < 0 ? seqstart_k_host[wb]
: seqstart_k_with_padding_host[wb]));
: ((seqstart_k_with_padding_host.empty() || seqlen_kpads[0] < 0)
? seqstart_k_host[wb]
: seqstart_k_with_padding_host[wb]));
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
@@ -1723,8 +1758,14 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::cerr << "OUT mismatch found at batch: " << wb << std::endl
<< "\tseqlen_q: " << real_seqlen_q << std::endl
<< "\tseqlen_k: " << real_seqlen_k << std::endl
<< "\tseqstart_q: " << seqstart_q_host << std::endl
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
<< "\tseqstart_q (logical): " << seqstart_q_host << std::endl
<< "\tseqstart_q (physical): " << seqstart_q_with_padding_host
<< std::endl
<< "\tseqstart_k (logical): " << seqstart_k_host << std::endl
<< "\tseqstart_k (physical): " << seqstart_k_with_padding_host
<< std::endl
<< "\tquery_offset used: " << query_offset << std::endl
<< "\tkey_offset used: " << key_offset << std::endl;
break;
}
@@ -1735,8 +1776,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
lse_host_result.ForEach([&](auto& self, auto idx) {
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
});
std::cout << "lse_host_result shape: " << shape_batch << ", " << nhead << ", "
<< shape_seqlen_q << std::endl;
cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,