mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
[CK_TILE] fmha: Unify sequence length and padding handling
Refactor the handling of sequence lengths and padding in the FMHA forward and backward kernels to provide a more unified and flexible interface. - Replaced `seqstart_padded_*_ptr` with a more robust system that uses `seqstart_*_ptr` for physical sequence lengths and introduces `seqlen_*_ptr` and `cu_seqlen_*_ptr` for logical (unpadded) lengths. - Established a clear order of precedence for determining sequence length: cumulative lengths (`cu_seqlen_*_ptr`) take priority, followed by per-sequence lengths (`seqlen_*_ptr`), and finally physical lengths derived from `seqstart_*_ptr`. - Clarified the distinction between "group mode" and "batch mode" and how sequence lengths are handled in each case. - Renamed `cu_seqlen_kv_ptr` to `cu_seqlen_k_ptr` for consistency. - Updated comments and documentation to reflect the new argument structure and usage.
This commit is contained in:
@@ -269,14 +269,14 @@ class FmhaFwdApiTrait:
|
||||
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == "qr_async":
|
||||
if self.skpad == "t":
|
||||
return f"(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
|
||||
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
|
||||
else:
|
||||
return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
|
||||
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
if self.skpad == "t":
|
||||
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
|
||||
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
|
||||
elif self.pipeline_tag == "qr_async_trload":
|
||||
if self.skpad == "t":
|
||||
return "true"
|
||||
|
||||
@@ -114,10 +114,51 @@ struct fmha_bwd_args
|
||||
void* dv_ptr;
|
||||
void* dbias_ptr;
|
||||
void* dq_acc_ptr;
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_q_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
|
||||
// Usage notes for sequence length pointer parameters:
|
||||
//
|
||||
// [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles
|
||||
// MQA/GQA..."]
|
||||
//
|
||||
// With padding:
|
||||
// Group mode:
|
||||
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
|
||||
// lengths. [array size: batch + 1]
|
||||
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
|
||||
// sequence. [array size: batch]
|
||||
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
|
||||
// sequence lengths. [array size: batch + 1]
|
||||
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
|
||||
// exclusive. Use one set, not both.
|
||||
//
|
||||
// Batch mode:
|
||||
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
|
||||
// sequence lengths. [array size: batch + 1]
|
||||
// - seqstart_* and seqlen_* pointers must be nullptr.
|
||||
//
|
||||
// Without padding:
|
||||
// (Note: Physical length equals logical length)
|
||||
//
|
||||
// Group mode:
|
||||
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
|
||||
// size: batch + 1]
|
||||
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
|
||||
//
|
||||
// Batch mode:
|
||||
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
|
||||
//
|
||||
const void* seqstart_q_ptr =
|
||||
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
|
||||
const void* seqstart_k_ptr =
|
||||
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
|
||||
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
|
||||
// [batch]. (Used in Group mode with padding)
|
||||
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
|
||||
// [batch]. (Used in Group mode with padding)
|
||||
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
@@ -190,54 +231,57 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
stride_dq,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
nhead_stride_dq,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.drop_seed_offset);
|
||||
return FmhaBwdDQDKDVKernel::MakeKargsImpl(
|
||||
args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
static_cast<const ck_tile::index_t*>(args.cu_seqlen_q_ptr),
|
||||
static_cast<const ck_tile::index_t*>(args.cu_seqlen_k_ptr),
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
stride_dq,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
nhead_stride_dq,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -318,6 +362,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.p_undrop,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.hdim_v,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
@@ -361,6 +406,8 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
|
||||
@@ -493,6 +493,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
seqlen_q_ptr_dev,
|
||||
seqlen_k_ptr_dev,
|
||||
nullptr,
|
||||
nullptr,
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
|
||||
@@ -182,19 +182,50 @@ struct fmha_fwd_args
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
// Optional cumulative sequence length arrays
|
||||
// Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD)
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
|
||||
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void*
|
||||
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
|
||||
|
||||
// Group mode: seqstart_padded_* provide physical starts including PAD (optional)
|
||||
const void* seqstart_padded_q_ptr = nullptr; // [batch+1]
|
||||
const void* seqstart_padded_k_ptr = nullptr; // [batch+1]
|
||||
// Usage notes for sequence length pointer parameters:
|
||||
//
|
||||
// [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles
|
||||
// MQA/GQA..."]
|
||||
//
|
||||
// With padding:
|
||||
// Group mode:
|
||||
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
|
||||
// lengths. [array size: batch + 1]
|
||||
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
|
||||
// sequence. [array size: batch]
|
||||
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
|
||||
// sequence lengths. [array size: batch + 1]
|
||||
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
|
||||
// exclusive. Use one set, not both.
|
||||
//
|
||||
// Batch mode:
|
||||
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
|
||||
// sequence lengths. [array size: batch + 1]
|
||||
// - seqstart_* and seqlen_* pointers must be nullptr.
|
||||
//
|
||||
// Without padding:
|
||||
// (Note: Physical length equals logical length)
|
||||
//
|
||||
// Group mode:
|
||||
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
|
||||
// size: batch + 1]
|
||||
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
|
||||
//
|
||||
// Batch mode:
|
||||
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
|
||||
//
|
||||
const void* seqstart_q_ptr =
|
||||
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
|
||||
const void* seqstart_k_ptr =
|
||||
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
|
||||
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
|
||||
// [batch]. (Used in Group mode with padding)
|
||||
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
|
||||
// [batch]. (Used in Group mode with padding)
|
||||
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -555,6 +586,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
@@ -584,8 +616,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.seqstart_padded_q_ptr,
|
||||
args.seqstart_padded_k_ptr);
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -633,7 +665,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_kv_ptr);
|
||||
args.cu_seqlen_k_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -313,16 +313,19 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
|
||||
|
||||
// Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv)
|
||||
const bool has_group_padding =
|
||||
(mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1)) ||
|
||||
(mode == mode_enum::group && (seqlen_kpads[0] >= 0));
|
||||
const bool has_batch_efflens = (mode == mode_enum::batch && (!q_eff_lens_per_batch.empty() ||
|
||||
!kv_eff_lens_per_batch.empty()));
|
||||
const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim);
|
||||
const bool using_pagedkv = (0 < page_block_size);
|
||||
const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx;
|
||||
const bool has_group_q_padding =
|
||||
mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] > 0);
|
||||
const bool has_group_k_padding =
|
||||
mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] > 0);
|
||||
const bool has_group_padding = has_group_q_padding || has_group_k_padding;
|
||||
const bool has_batch_q_padding = mode == mode_enum::batch && !q_eff_lens_per_batch.empty();
|
||||
const bool has_batch_k_padding = mode == mode_enum::batch && !kv_eff_lens_per_batch.empty();
|
||||
const bool has_batch_padding = has_batch_q_padding || has_batch_k_padding;
|
||||
const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim);
|
||||
const bool using_pagedkv = (0 < page_block_size);
|
||||
const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx;
|
||||
if((using_appendkv || using_pagedkv || using_splitkv) &&
|
||||
(has_group_padding || has_batch_efflens))
|
||||
(has_group_padding || has_batch_padding))
|
||||
{
|
||||
std::cerr << "Padding (physical or effective lengths) is not supported with "
|
||||
"appendkv/splitkv/pagedkv pipelines"
|
||||
@@ -399,23 +402,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
|
||||
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
|
||||
const auto seqstart_q_with_padding_host = to_seqstarts(seqlen_qpads);
|
||||
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
|
||||
|
||||
// Optional padded Q seqstarts (group-mode only)
|
||||
std::vector<int32_t> seqstart_q_with_padding_host;
|
||||
if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1)
|
||||
{
|
||||
if(seqlen_qpads.size() < static_cast<size_t>(batch))
|
||||
{
|
||||
seqlen_qpads.resize(batch, seqlen_qpads.back());
|
||||
}
|
||||
if(seqlen_qpads.size() == static_cast<size_t>(batch))
|
||||
{
|
||||
seqstart_q_with_padding_host = to_seqstarts(
|
||||
ck_tile::span<const int32_t>(seqlen_qpads.data(), seqlen_qpads.size()));
|
||||
}
|
||||
}
|
||||
|
||||
// Optional batch-mode cumulative seqlen overrides
|
||||
std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
|
||||
if(mode == mode_enum::batch)
|
||||
@@ -524,14 +513,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
|
||||
// physical(padded) total seqlen_q for group when s_qpad is provided; else use logical
|
||||
const ck_tile::index_t shape_seqlen_q =
|
||||
(mode == mode_enum::batch
|
||||
? seqlen_qs[0]
|
||||
: (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back()
|
||||
: seqstart_q_with_padding_host.back()));
|
||||
(mode == mode_enum::batch ? seqlen_qs[0]
|
||||
: (has_group_q_padding && !seqstart_q_with_padding_host.empty()
|
||||
? seqstart_q_with_padding_host.back()
|
||||
: seqstart_q_host.back()));
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_ks[0]
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
|
||||
: seqstart_k_with_padding_host.back()));
|
||||
: (has_group_k_padding && !seqstart_k_with_padding_host.empty()
|
||||
? seqstart_k_with_padding_host.back()
|
||||
: seqstart_k_host.back()));
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
@@ -689,14 +679,18 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k_padded_buf(
|
||||
seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t));
|
||||
// Buffers for query per-sequence logical (unpadded) lengths (used in group mode with padding
|
||||
// enabled)
|
||||
ck_tile::DeviceMem seqlen_q_buf(has_group_q_padding ? seqlen_qs.size() * sizeof(int32_t) : 0);
|
||||
// Buffers for key/value per-sequence logical (unpadded) lengths (used in batch mode with
|
||||
// kvcache or group mode with padding enabled)
|
||||
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
|
||||
? seqlen_ks.size() * sizeof(int32_t)
|
||||
: 0);
|
||||
ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0
|
||||
: cuq_cum.size() * sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem cu_seqlen_kv_buf(
|
||||
cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
|
||||
0 <= seqlen_kpads[0]
|
||||
? seqlen_ks.size() * sizeof(int32_t)
|
||||
: 0);
|
||||
ck_tile::DeviceMem cache_seqlen_k_buf(
|
||||
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
|
||||
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
|
||||
@@ -792,7 +786,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
: seqstart_k_with_padding_host.data());
|
||||
cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data());
|
||||
cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data());
|
||||
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
|
||||
seqlen_q_buf.ToDevice(has_group_q_padding ? seqlen_qs.data() : nullptr);
|
||||
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
|
||||
? seqlen_ks.data()
|
||||
: nullptr);
|
||||
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
|
||||
@@ -873,7 +868,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
print_vec("k_padded", seqlen_kpads);
|
||||
}
|
||||
}
|
||||
else if(has_batch_efflens)
|
||||
else if(has_batch_padding)
|
||||
{
|
||||
// derive effective lengths from cumulative arrays if present
|
||||
if(!cuq_cum.empty())
|
||||
@@ -1056,14 +1051,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
args.lse_ptr = lse_buf.GetDeviceBuffer();
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
|
||||
args.seqstart_q_ptr =
|
||||
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
|
||||
args.seqstart_k_ptr =
|
||||
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
|
||||
args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
|
||||
@@ -1107,27 +1094,54 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
|
||||
}
|
||||
|
||||
// Group-mode: optional physical padded starts for Q/K
|
||||
// Sequence length and padding parameters (mode-specific)
|
||||
if(mode == mode_enum::group)
|
||||
{
|
||||
args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty()
|
||||
? nullptr
|
||||
: seqstart_q_padded_buf.GetDeviceBuffer());
|
||||
args.seqstart_padded_k_ptr =
|
||||
(seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer());
|
||||
}
|
||||
// Group mode: use physical (padded) cumulative starts + logical per-sequence
|
||||
// lengths
|
||||
|
||||
// Batch-mode: optional cumulative effective seqlen overrides
|
||||
if(mode == mode_enum::batch)
|
||||
// Physical cumulative starts (including padding)
|
||||
args.seqstart_q_ptr =
|
||||
has_group_q_padding && !seqstart_q_with_padding_host.empty()
|
||||
? seqstart_q_padded_buf.GetDeviceBuffer()
|
||||
: seqstart_q.GetDeviceBuffer();
|
||||
args.seqstart_k_ptr =
|
||||
has_group_k_padding && !seqstart_k_with_padding_host.empty()
|
||||
? seqstart_k_padded_buf.GetDeviceBuffer()
|
||||
: seqstart_k.GetDeviceBuffer();
|
||||
|
||||
// Logical (unpadded) per-sequence lengths, used when padding is enabled
|
||||
args.seqlen_q_ptr =
|
||||
(has_group_q_padding && !seqstart_q_with_padding_host.empty())
|
||||
? seqlen_q_buf.GetDeviceBuffer()
|
||||
: nullptr;
|
||||
args.seqlen_k_ptr =
|
||||
(has_group_k_padding && !seqstart_k_with_padding_host.empty())
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
: nullptr;
|
||||
// Cumulative lengths not used in group mode
|
||||
args.cu_seqlen_q_ptr = nullptr;
|
||||
args.cu_seqlen_k_ptr = nullptr;
|
||||
}
|
||||
else // mode == mode_enum::batch
|
||||
{
|
||||
args.cu_seqlen_q_ptr = cuq_cum.empty()
|
||||
? nullptr
|
||||
: reinterpret_cast<const ck_tile::index_t*>(
|
||||
cu_seqlen_q_buf.GetDeviceBuffer());
|
||||
args.cu_seqlen_kv_ptr = cukv_cum.empty()
|
||||
? nullptr
|
||||
: reinterpret_cast<const ck_tile::index_t*>(
|
||||
cu_seqlen_kv_buf.GetDeviceBuffer());
|
||||
// Batch mode: use cumulative logical lengths for tail padding
|
||||
|
||||
// seqstart pointers not used in batch mode
|
||||
args.seqstart_q_ptr = nullptr;
|
||||
args.seqstart_k_ptr = nullptr;
|
||||
|
||||
// seqlen_q_ptr/seqlen_k_ptr not used in batch mode
|
||||
args.seqlen_q_ptr = nullptr;
|
||||
args.seqlen_k_ptr = nullptr;
|
||||
|
||||
// Cumulative logical lengths for effective length handling
|
||||
args.cu_seqlen_q_ptr = has_batch_q_padding && !cuq_cum.empty()
|
||||
? cu_seqlen_q_buf.GetDeviceBuffer()
|
||||
: nullptr;
|
||||
args.cu_seqlen_k_ptr = has_batch_k_padding && !cukv_cum.empty()
|
||||
? cu_seqlen_kv_buf.GetDeviceBuffer()
|
||||
: nullptr;
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
|
||||
@@ -1153,6 +1167,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
args.batch_stride_o_acc = batch_stride_o_acc;
|
||||
args.split_stride_lse_acc = split_stride_lse_acc;
|
||||
args.split_stride_o_acc = split_stride_o_acc;
|
||||
|
||||
args.seqstart_q_ptr =
|
||||
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
|
||||
args.seqstart_k_ptr =
|
||||
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr =
|
||||
((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_fwd_pagedkv_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
@@ -1164,6 +1187,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
args.cache_batch_idx =
|
||||
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
|
||||
|
||||
args.seqstart_q_ptr =
|
||||
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
|
||||
args.seqstart_k_ptr =
|
||||
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr =
|
||||
((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -1365,16 +1397,19 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t cache_b_idx =
|
||||
(use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
|
||||
// Use physical offset if padding info is valid (not -1) and buffers are available
|
||||
const ck_tile::index_t query_offset =
|
||||
(mode == mode_enum::batch
|
||||
? 0
|
||||
: (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb]
|
||||
: seqstart_q_with_padding_host[wb]));
|
||||
: ((seqstart_q_with_padding_host.empty() || seqlen_qpads[0] < 0)
|
||||
? seqstart_q_host[wb]
|
||||
: seqstart_q_with_padding_host[wb]));
|
||||
const ck_tile::index_t key_offset =
|
||||
(mode == mode_enum::batch
|
||||
? 0
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host[wb]
|
||||
: seqstart_k_with_padding_host[wb]));
|
||||
: ((seqstart_k_with_padding_host.empty() || seqlen_kpads[0] < 0)
|
||||
? seqstart_k_host[wb]
|
||||
: seqstart_k_with_padding_host[wb]));
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
|
||||
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
|
||||
@@ -1723,8 +1758,14 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
std::cerr << "OUT mismatch found at batch: " << wb << std::endl
|
||||
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
||||
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
||||
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
||||
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
|
||||
<< "\tseqstart_q (logical): " << seqstart_q_host << std::endl
|
||||
<< "\tseqstart_q (physical): " << seqstart_q_with_padding_host
|
||||
<< std::endl
|
||||
<< "\tseqstart_k (logical): " << seqstart_k_host << std::endl
|
||||
<< "\tseqstart_k (physical): " << seqstart_k_with_padding_host
|
||||
<< std::endl
|
||||
<< "\tquery_offset used: " << query_offset << std::endl
|
||||
<< "\tkey_offset used: " << key_offset << std::endl;
|
||||
|
||||
break;
|
||||
}
|
||||
@@ -1735,8 +1776,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
lse_host_result.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
|
||||
});
|
||||
std::cout << "lse_host_result shape: " << shape_batch << ", " << nhead << ", "
|
||||
<< shape_seqlen_q << std::endl;
|
||||
|
||||
cur_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
|
||||
@@ -313,8 +313,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_q_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
|
||||
const int32_t* seqlen_k_ptr; // per-batch actual length [batch]
|
||||
const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
|
||||
const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>;
|
||||
@@ -523,6 +525,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* cu_seqlen_q_ptr,
|
||||
const void* cu_seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -597,7 +601,9 @@ struct FmhaBwdDQDKDVKernel
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr)};
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -739,13 +745,29 @@ struct FmhaBwdDQDKDVKernel
|
||||
batch_offset_randval = query_start * kargs.stride_randval;
|
||||
}
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
const ck_tile::index_t physical_seqlen_q =
|
||||
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
kargs.seqlen_q = kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q;
|
||||
// Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
|
||||
if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
const ck_tile::index_t physical_seqlen_q =
|
||||
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
kargs.seqlen_q =
|
||||
kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q;
|
||||
}
|
||||
|
||||
if(kargs.seqlen_k_ptr != nullptr)
|
||||
// Priority: cu_seqlen_k_ptr > seqlen_k_ptr > seqstart_k
|
||||
if(kargs.cu_seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else if(kargs.seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
}
|
||||
@@ -1258,7 +1280,8 @@ struct FmhaBwdOGradDotOKernel
|
||||
struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqlen_q_ptr;
|
||||
const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
|
||||
const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
|
||||
};
|
||||
|
||||
using Kargs = std::
|
||||
@@ -1307,6 +1330,7 @@ struct FmhaBwdOGradDotOKernel
|
||||
float p_undrop,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* cu_seqlen_q_ptr,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_o,
|
||||
@@ -1326,7 +1350,8 @@ struct FmhaBwdOGradDotOKernel
|
||||
nhead_stride_o,
|
||||
nhead_stride_d},
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr)};
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr)};
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -1370,14 +1395,23 @@ struct FmhaBwdOGradDotOKernel
|
||||
batch_offset_do = query_start * kargs.stride_do;
|
||||
batch_offset_d = query_start;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
const ck_tile::index_t physical_seqlen_q =
|
||||
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
const ck_tile::index_t logical_seqlen_q =
|
||||
kargs.seqlen_q_ptr ? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
|
||||
: physical_seqlen_q;
|
||||
kargs.seqlen_q = logical_seqlen_q;
|
||||
// Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
|
||||
if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
const ck_tile::index_t physical_seqlen_q =
|
||||
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
kargs.seqlen_q = kargs.seqlen_q_ptr
|
||||
? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
|
||||
: physical_seqlen_q;
|
||||
}
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
@@ -1541,8 +1575,10 @@ struct FmhaBwdConvertQGradKernel
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_q_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
|
||||
const int32_t* seqlen_k_ptr; // per-batch actual length [batch]
|
||||
const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
|
||||
const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode,
|
||||
@@ -1593,6 +1629,8 @@ struct FmhaBwdConvertQGradKernel
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* cu_seqlen_q_ptr,
|
||||
const void* cu_seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t stride_dq,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
@@ -1613,7 +1651,9 @@ struct FmhaBwdConvertQGradKernel
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr)};
|
||||
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
@@ -1658,22 +1698,41 @@ struct FmhaBwdConvertQGradKernel
|
||||
batch_offset_dq = query_start * kargs.stride_dq;
|
||||
batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
const ck_tile::index_t physical_seqlen_q =
|
||||
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
const ck_tile::index_t logical_seqlen_q =
|
||||
kargs.seqlen_q_ptr ? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
|
||||
: physical_seqlen_q;
|
||||
kargs.seqlen_q = logical_seqlen_q;
|
||||
if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
const ck_tile::index_t physical_seqlen_q =
|
||||
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
kargs.seqlen_q = kargs.seqlen_q_ptr
|
||||
? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
|
||||
: physical_seqlen_q;
|
||||
}
|
||||
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
|
||||
const ck_tile::index_t physical_seqlen_k =
|
||||
adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr
|
||||
? static_cast<ck_tile::index_t>(kargs.seqlen_k_ptr[i_batch])
|
||||
: physical_seqlen_k;
|
||||
|
||||
// Priority: cu_seqlen_k_ptr > seqlen_k_ptr > physical_seqlen_k
|
||||
if(kargs.cu_seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.seqlen_k_ptr
|
||||
? static_cast<ck_tile::index_t>(kargs.seqlen_k_ptr[i_batch])
|
||||
: physical_seqlen_k;
|
||||
}
|
||||
}
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
|
||||
@@ -296,8 +296,8 @@ struct FmhaFwdKernel
|
||||
|
||||
// Optional cumulative sequence length pointers for batch mode
|
||||
// If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD
|
||||
const int32_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
|
||||
const int32_t* cu_seqlen_k_ptr = nullptr; // cumulative, length without PAD
|
||||
};
|
||||
|
||||
struct FmhaFwdGroupModeKargs
|
||||
@@ -316,12 +316,12 @@ struct FmhaFwdKernel
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_q_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
// Optional cumulative padded sequence starts (including PAD tokens)
|
||||
// Used solely to compute memory offsets when sequences are physically padded.
|
||||
const int32_t* seqstart_padded_q_ptr = nullptr;
|
||||
const int32_t* seqstart_padded_k_ptr = nullptr;
|
||||
// Optional per-sequence and cumulative logical (excluding padding) sequence length arrays
|
||||
const int32_t* cu_seqlen_q_ptr = nullptr;
|
||||
const int32_t* cu_seqlen_k_ptr = nullptr;
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
|
||||
@@ -379,8 +379,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -471,8 +471,8 @@ struct FmhaFwdKernel
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
|
||||
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
|
||||
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
|
||||
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
|
||||
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
|
||||
return kargs;
|
||||
}
|
||||
|
||||
@@ -522,8 +522,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -570,7 +570,7 @@ struct FmhaFwdKernel
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_kv_ptr);
|
||||
cu_seqlen_k_ptr);
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
@@ -619,8 +619,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -667,7 +667,7 @@ struct FmhaFwdKernel
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_kv_ptr);
|
||||
cu_seqlen_k_ptr);
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
@@ -681,6 +681,7 @@ struct FmhaFwdKernel
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
@@ -711,8 +712,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
const void* seqstart_padded_q_ptr = nullptr,
|
||||
const void* seqstart_padded_k_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -746,6 +747,7 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for min_seqlen_q
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
@@ -804,8 +806,8 @@ struct FmhaFwdKernel
|
||||
kargs.min_seqlen_q = min_seqlen_q;
|
||||
}
|
||||
|
||||
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
|
||||
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
|
||||
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
|
||||
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
|
||||
return kargs;
|
||||
}
|
||||
|
||||
@@ -821,6 +823,7 @@ struct FmhaFwdKernel
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
@@ -850,8 +853,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
const void* seqstart_padded_q_ptr = nullptr,
|
||||
const void* seqstart_padded_k_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -863,6 +866,7 @@ struct FmhaFwdKernel
|
||||
o_ptr,
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_q_ptr,
|
||||
seqlen_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
@@ -892,8 +896,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
seqstart_padded_q_ptr,
|
||||
seqstart_padded_k_ptr);
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
@@ -908,6 +912,7 @@ struct FmhaFwdKernel
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
@@ -937,8 +942,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
const void* seqstart_padded_q_ptr = nullptr,
|
||||
const void* seqstart_padded_k_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -950,6 +955,7 @@ struct FmhaFwdKernel
|
||||
o_ptr,
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_q_ptr,
|
||||
seqlen_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
@@ -979,8 +985,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
seqstart_padded_q_ptr,
|
||||
seqstart_padded_k_ptr);
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
@@ -1109,46 +1115,52 @@ struct FmhaFwdKernel
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// logical and physical (padded) starts
|
||||
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
|
||||
// Use seqstart_q_ptr and seqstart_k_ptr for physical starts
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
|
||||
? kargs.seqstart_padded_q_ptr[i_batch]
|
||||
: query_start_unpadded;
|
||||
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
|
||||
? kargs.seqstart_padded_k_ptr[i_batch]
|
||||
: key_start_unpadded;
|
||||
|
||||
// DRAM base offsets use physical padded starts
|
||||
batch_offset_q = query_start_padded * kargs.stride_q;
|
||||
batch_offset_k = key_start_padded * kargs.stride_k;
|
||||
// DRAM base offsets use physical starts
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start_padded * kargs.stride_v;
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_v = key_start_padded;
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start_padded * kargs.stride_bias;
|
||||
batch_offset_bias = query_start * kargs.stride_bias;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
// LSE follows the padded layout to stay consistent with other tensors
|
||||
batch_offset_lse = query_start_padded;
|
||||
// LSE follows the physical layout to stay consistent with other tensors
|
||||
batch_offset_lse = query_start;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
batch_offset_randval = query_start_padded * kargs.stride_randval;
|
||||
batch_offset_randval = query_start * kargs.stride_randval;
|
||||
}
|
||||
batch_offset_o = query_start_padded * kargs.stride_o;
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
// real logical lengths (exclude PAD)
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
// Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
|
||||
if(kargs.seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
}
|
||||
|
||||
if constexpr(kSkipMinSeqlenQ)
|
||||
{
|
||||
@@ -1168,6 +1180,11 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else if(kargs.cu_seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
|
||||
@@ -1201,10 +1218,10 @@ struct FmhaFwdKernel
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
if(kargs.cu_seqlen_kv_ptr != nullptr)
|
||||
if(kargs.cu_seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
|
||||
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1603,39 +1620,46 @@ struct FmhaFwdKernel
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
|
||||
// get starting offset for each batch - use seqstart_q_ptr/seqstart_k_ptr for
|
||||
// physical starts
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
|
||||
? kargs.seqstart_padded_q_ptr[i_batch]
|
||||
: query_start_unpadded;
|
||||
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
|
||||
? kargs.seqstart_padded_k_ptr[i_batch]
|
||||
: key_start_unpadded;
|
||||
|
||||
batch_offset_q = query_start_padded * kargs.stride_q;
|
||||
batch_offset_k = key_start_padded * kargs.stride_k;
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start_padded * kargs.stride_v;
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
// col-major V: offset along seqlen dimension is scalar index
|
||||
batch_offset_v = key_start_padded;
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start_padded * kargs.stride_bias;
|
||||
batch_offset_bias = query_start * kargs.stride_bias;
|
||||
}
|
||||
|
||||
// LSE layout is [nhead, total_seqlen] following the padded layout for Q/O
|
||||
batch_offset_lse = query_start_padded;
|
||||
batch_offset_o = query_start_padded * kargs.stride_o;
|
||||
// LSE layout is [nhead, total_seqlen] following the physical layout for Q/O
|
||||
batch_offset_lse = query_start;
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
|
||||
if(kargs.seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
|
||||
}
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
@@ -1648,6 +1672,11 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else if(kargs.cu_seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
@@ -1677,10 +1706,10 @@ struct FmhaFwdKernel
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
if(kargs.cu_seqlen_kv_ptr != nullptr)
|
||||
if(kargs.cu_seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
|
||||
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user