Merge commit '7c6430eca04e62454217630ae2a0bbd70ff50a00' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-29 07:13:01 +00:00
parent 7b759ce7e9
commit e571490afc
11 changed files with 1292 additions and 214 deletions

View File

@@ -269,14 +269,14 @@ class FmhaFwdApiTrait:
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
if self.pipeline_tag == "qr_async":
if self.skpad == "t":
return f"(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
else:
return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
elif self.pipeline_tag in ["qr", "qs"]:
if self.skpad == "t":
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
elif self.pipeline_tag == "qr_async_trload":
if self.skpad == "t":
return "true"

View File

@@ -24,11 +24,19 @@ auto create_args(int argc, char* argv[])
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_qpad",
"-1",
"padded seqlen_q per batch (group mode only). "
"Use \"-s_qpad=p0,p1,...\"; -1 disables explicit padding")
.insert("s_k",
"-1",
"seqlen_k, -1 means equal to s\n"
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_kpad",
"-1",
"padded seqlen_k per batch (group mode only). "
"Use \"-s_kpad=k0,k1,...\"; -1 disables explicit padding")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
@@ -96,7 +104,9 @@ auto run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
auto seqlen_qs = arg_parser.get_int_vec("s");
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
auto seqlen_ks = arg_parser.get_int_vec("s_k");
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
bool i_perm = arg_parser.get_bool("iperm");
@@ -130,6 +140,8 @@ auto run(const ck_tile::ArgParser& arg_parser)
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
hdim_q,
hdim_v,
i_perm,

View File

@@ -114,9 +114,51 @@ struct fmha_bwd_args
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
// Usage notes for sequence length pointer parameters:
//
// [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles
// MQA/GQA..."]
//
// With padding:
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
// lengths. [array size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
// sequence. [array size: batch]
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
// exclusive. Use one set, not both.
//
// Batch mode:
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqstart_* and seqlen_* pointers must be nullptr.
//
// Without padding:
// (Note: Physical length equals logical length)
//
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
// size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
//
// Batch mode:
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
//
const void* seqstart_q_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqstart_k_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
@@ -203,7 +245,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
@@ -315,6 +360,8 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args.d_ptr,
args.p_undrop,
args.seqstart_q_ptr,
args.seqlen_q_ptr,
args.cu_seqlen_q_ptr,
args.hdim_v,
args.stride_do,
args.stride_o,
@@ -356,6 +403,10 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,

View File

@@ -65,6 +65,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::index_t nhead_k,
std::vector<ck_tile::index_t> seqlen_qs,
std::vector<ck_tile::index_t> seqlen_ks,
std::vector<ck_tile::index_t> seqlen_qpads,
std::vector<ck_tile::index_t> seqlen_kpads,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
bool i_perm,
@@ -119,13 +121,26 @@ bwd_result fmha_bwd_run(mode_enum mode,
std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
return bwd_result::invalid_args;
}
std::vector<ck_tile::index_t> seqlen_kpads;
std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
generate_missing_seqlens(mode, batch, seqlen_qs, seqlen_ks, {}, 0, false, random_engine);
ck_tile::ignore = seqlen_kpads;
std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) = generate_missing_seqlens(
mode, batch, seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads, 0, false, random_engine);
bool use_qpadding =
mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1);
bool use_kpadding =
mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] != -1);
#if 0
std::cout << "use_qpadding: " << use_qpadding << std::endl;
std::cout << "use_kpadding: " << use_kpadding << std::endl;
std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
if (use_qpadding) {
std::cout << "seqlen_qpads: " << seqlen_qpads << std::endl;
}
if (use_kpadding) {
std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl;
}
#endif
mask_info mask = mask_info::decode(mask_str, seqlen_qs[0], seqlen_ks[0]);
@@ -146,8 +161,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
s_randval = true;
}
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_q_host =
(use_qpadding ? to_seqstarts(seqlen_qpads) : to_seqstarts(seqlen_qs));
const auto seqstart_k_host =
(use_kpadding ? to_seqstarts(seqlen_kpads) : to_seqstarts(seqlen_ks));
using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;
@@ -176,8 +193,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
{
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
// When padding is enabled, use logical lengths for flop/bandwidth calculation
const int32_t real_seqlen_q =
use_qpadding ? seqlen_qs[wb] : (seqstart_q_host[wb + 1] - seqstart_q_host[wb]);
const int32_t real_seqlen_k =
use_kpadding ? seqlen_ks[wb] : (seqstart_k_host[wb + 1] - seqstart_k_host[wb]);
if(max_seqlen_q < real_seqlen_q)
{
@@ -336,6 +356,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_q_dev(mode == mode_enum::batch ? 0
: seqlen_qs.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_k_dev(mode == mode_enum::batch ? 0
: seqlen_ks.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
@@ -349,6 +373,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
do_buf.ToDevice(do_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
if(mode == mode_enum::group)
{
std::vector<int32_t> seqlen_q_host(seqlen_qs.begin(), seqlen_qs.end());
seqlen_q_dev.ToDevice(seqlen_q_host.data());
std::vector<int32_t> seqlen_k_host(seqlen_ks.begin(), seqlen_ks.end());
seqlen_k_dev.ToDevice(seqlen_k_host.data());
}
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
alibi_slope_buf.ToDevice(alibi_slope_host.data());
@@ -440,6 +471,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
}
}();
const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr;
const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr;
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
@@ -457,6 +491,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
dq_acc_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
seqlen_q_ptr_dev,
seqlen_k_ptr_dev,
nullptr,
nullptr,
shape_seqlen_q,
shape_seqlen_k,
@@ -551,8 +588,18 @@ bwd_result fmha_bwd_run(mode_enum mode,
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
// When padding is enabled, use logical lengths instead of computing from padded
// prefix-sum
const ck_tile::index_t real_seqlen_q =
use_qpadding ? seqlen_qs[wb] : (seqstart_q_host[wb + 1] - seqstart_q_host[wb]);
const ck_tile::index_t real_seqlen_k =
use_kpadding ? seqlen_ks[wb] : (seqstart_k_host[wb + 1] - seqstart_k_host[wb]);
// Skip forward reference computation for batches with zero length sequences
if(real_seqlen_q == 0 || real_seqlen_k == 0)
{
continue;
}
// adjust matrix index according to the mode
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
@@ -797,10 +844,23 @@ bwd_result fmha_bwd_run(mode_enum mode,
dv_buf.FromDevice(dv_host.data());
dbias_buf.FromDevice(dbias_host.data());
// Track the index into reference vectors (may differ from wb if batches were skipped)
ck_tile::index_t ref_idx = 0;
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
// When padding is enabled, use logical lengths instead of computing from padded
// prefix-sum
const ck_tile::index_t real_seqlen_q =
use_qpadding ? seqlen_qs[wb] : (seqstart_q_host[wb + 1] - seqstart_q_host[wb]);
const ck_tile::index_t real_seqlen_k =
use_kpadding ? seqlen_ks[wb] : (seqstart_k_host[wb + 1] - seqstart_k_host[wb]);
// Skip validation for batches with zero length sequences
if(real_seqlen_q == 0 || real_seqlen_k == 0)
{
continue;
}
// adjust matrix index according to the mode
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
@@ -833,14 +893,14 @@ bwd_result fmha_bwd_run(mode_enum mode,
// dP = dO@V x Z w/ dropout
// dP = dO@V w/o dropout
auto v_t_host_ref = v_host_refs[wb].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o
auto v_t_host_ref = v_host_refs[ref_idx].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o
ck_tile::reference_batched_gemm<OGradDataType, VDataType, AccDataType, AccDataType>(
do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o
if(p_drop > 0)
{
ck_tile::reference_batched_dropout(
dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop);
dp_hp_host_ref, randval_host_refs[ref_idx], p_undrop_in_uint8_t, rp_undrop);
}
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
@@ -849,11 +909,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
AccDataType do_dot_o = 0;
for(int o = 0; o < hdim_v; o++)
{
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
ck_tile::type_convert<AccDataType>(o_host_refs[wb](i0, i1, o));
do_dot_o +=
ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
ck_tile::type_convert<AccDataType>(o_host_refs[ref_idx](i0, i1, o));
}
ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert<AccDataType>(
p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o));
ds_hp_host_ref(i0, i1, i2) =
ck_tile::type_convert<AccDataType>(p_hp_host_refs[ref_idx](i0, i1, i2) *
(dp_hp_host_ref(i0, i1, i2) - do_dot_o));
},
ds_hp_host_ref.mDesc.get_lengths()[0],
ds_hp_host_ref.mDesc.get_lengths()[1],
@@ -869,14 +931,14 @@ bwd_result fmha_bwd_run(mode_enum mode,
// dV = P_drop^T@dO^T
// dV = P^T@dO^T w/o dropout
auto p_t_lp_host_ref =
p_lp_host_refs[wb].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m
p_lp_host_refs[ref_idx].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m
auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m
ck_tile::
reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
// dQ = scale * dS@K^T
auto k_t_host_ref = k_host_refs[wb].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
auto k_t_host_ref = k_host_refs[ref_idx].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
ck_tile::reference_batched_gemm<GemmDataType, KDataType, AccDataType, QGradDataType>(
ds_lp_host_ref,
k_t_host_ref,
@@ -886,8 +948,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n
// dK = scale * dS^T@Q^T
auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
auto q_t_host_ref = q_host_refs[wb].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m
auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
auto q_t_host_ref = q_host_refs[ref_idx].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m
ck_tile::reference_batched_gemm<GemmDataType, QDataType, AccDataType, KGradDataType>(
ds_t_lp_host_ref,
q_t_host_ref,
@@ -961,6 +1023,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
break;
}
// Increment reference vector index for successfully validated batches
ref_idx++;
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;

View File

@@ -182,19 +182,50 @@ struct fmha_fwd_args
void* lse_ptr;
void* o_ptr;
// Optional cumulative sequence length arrays
// Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD)
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void*
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
// Group mode: seqstart_padded_* provide physical starts including PAD (optional)
const void* seqstart_padded_q_ptr = nullptr; // [batch+1]
const void* seqstart_padded_k_ptr = nullptr; // [batch+1]
// Usage notes for sequence length pointer parameters:
//
// [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles
// MQA/GQA..."]
//
// With padding:
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
// lengths. [array size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
// sequence. [array size: batch]
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
// exclusive. Use one set, not both.
//
// Batch mode:
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqstart_* and seqlen_* pointers must be nullptr.
//
// Without padding:
// (Note: Physical length equals logical length)
//
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
// size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
//
// Batch mode:
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
//
const void* seqstart_q_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqstart_k_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -555,6 +586,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.o_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
@@ -584,8 +616,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.seqstart_padded_q_ptr,
args.seqstart_padded_k_ptr);
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
}
else
{ // create batch mode kernel arguments
@@ -633,7 +665,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.s_randval,
args.drop_seed_offset,
args.cu_seqlen_q_ptr,
args.cu_seqlen_kv_ptr);
args.cu_seqlen_k_ptr);
}
}();

View File

@@ -313,16 +313,19 @@ fwd_result fmha_fwd_run(mode_enum mode,
const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
// Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv)
const bool has_group_padding =
(mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1)) ||
(mode == mode_enum::group && (seqlen_kpads[0] >= 0));
const bool has_batch_efflens = (mode == mode_enum::batch && (!q_eff_lens_per_batch.empty() ||
!kv_eff_lens_per_batch.empty()));
const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim);
const bool using_pagedkv = (0 < page_block_size);
const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx;
const bool has_group_q_padding =
mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] > 0);
const bool has_group_k_padding =
mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] > 0);
const bool has_group_padding = has_group_q_padding || has_group_k_padding;
const bool has_batch_q_padding = mode == mode_enum::batch && !q_eff_lens_per_batch.empty();
const bool has_batch_k_padding = mode == mode_enum::batch && !kv_eff_lens_per_batch.empty();
const bool has_batch_padding = has_batch_q_padding || has_batch_k_padding;
const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim);
const bool using_pagedkv = (0 < page_block_size);
const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx;
if((using_appendkv || using_pagedkv || using_splitkv) &&
(has_group_padding || has_batch_efflens))
(has_group_padding || has_batch_padding))
{
std::cerr << "Padding (physical or effective lengths) is not supported with "
"appendkv/splitkv/pagedkv pipelines"
@@ -330,11 +333,12 @@ fwd_result fmha_fwd_run(mode_enum mode,
return fwd_result::invalid_args;
}
std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) =
generate_missing_seqlens(mode,
batch,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
/*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
need_append_kvcache,
@@ -346,7 +350,13 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::cerr << "kpad must be greater than or equal to seqlen for k" << std::endl;
return fwd_result::invalid_args;
}
if(seqlen_qpads[wb] > 0 && seqlen_qpads[wb] < seqlen_qs[wb])
{
std::cerr << "qpad must be greater than or equal to seqlen for q" << std::endl;
return fwd_result::invalid_args;
}
}
// compute kvcache seqlen_k (before appending knew/vnew)
auto cache_seqlen_ks = seqlen_ks;
std::transform(cache_seqlen_ks.begin(),
@@ -357,6 +367,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
#if 0
std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
std::cout << "seqlen_qpads: " << seqlen_qpads << std::endl;
std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl;
std::cout << "cache_seqlen_ks: " << cache_seqlen_ks << std::endl;
#endif
@@ -391,23 +402,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_q_with_padding_host = to_seqstarts(seqlen_qpads);
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
// Optional padded Q seqstarts (group-mode only)
std::vector<int32_t> seqstart_q_with_padding_host;
if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1)
{
if(seqlen_qpads.size() < static_cast<size_t>(batch))
{
seqlen_qpads.resize(batch, seqlen_qpads.back());
}
if(seqlen_qpads.size() == static_cast<size_t>(batch))
{
seqstart_q_with_padding_host = to_seqstarts(
ck_tile::span<const int32_t>(seqlen_qpads.data(), seqlen_qpads.size()));
}
}
// Optional batch-mode cumulative seqlen overrides
std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
if(mode == mode_enum::batch)
@@ -514,19 +511,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
// host memory for storing all the tensor elements
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
// logical(unpadded) total seqlen_q for group; batch uses fixed seqlen
const ck_tile::index_t shape_seqlen_q_lse =
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
// physical(padded) total seqlen_q for group when s_qpad is provided; else use logical
const ck_tile::index_t shape_seqlen_q =
(mode == mode_enum::batch
? seqlen_qs[0]
: (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back()
: seqstart_q_with_padding_host.back()));
(mode == mode_enum::batch ? seqlen_qs[0]
: (has_group_q_padding && !seqstart_q_with_padding_host.empty()
? seqstart_q_with_padding_host.back()
: seqstart_q_host.back()));
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_ks[0]
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
: seqstart_k_with_padding_host.back()));
: (has_group_k_padding && !seqstart_k_with_padding_host.empty()
? seqstart_k_with_padding_host.back()
: seqstart_k_host.back()));
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
@@ -580,7 +575,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q_lse}
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<ODataType> o_host(
@@ -684,14 +679,18 @@ fwd_result fmha_fwd_run(mode_enum mode,
sizeof(int32_t));
ck_tile::DeviceMem seqstart_k_padded_buf(
seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t));
// Buffers for query per-sequence logical (unpadded) lengths (used in group mode with padding
// enabled)
ck_tile::DeviceMem seqlen_q_buf(has_group_q_padding ? seqlen_qs.size() * sizeof(int32_t) : 0);
// Buffers for key/value per-sequence logical (unpadded) lengths (used in batch mode with
// kvcache or group mode with padding enabled)
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
? seqlen_ks.size() * sizeof(int32_t)
: 0);
ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0
: cuq_cum.size() * sizeof(ck_tile::index_t));
ck_tile::DeviceMem cu_seqlen_kv_buf(
cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t));
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
0 <= seqlen_kpads[0]
? seqlen_ks.size() * sizeof(int32_t)
: 0);
ck_tile::DeviceMem cache_seqlen_k_buf(
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
@@ -787,7 +786,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
: seqstart_k_with_padding_host.data());
cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data());
cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data());
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
seqlen_q_buf.ToDevice(has_group_q_padding ? seqlen_qs.data() : nullptr);
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
? seqlen_ks.data()
: nullptr);
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
@@ -868,7 +868,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
print_vec("k_padded", seqlen_kpads);
}
}
else if(has_batch_efflens)
else if(has_batch_padding)
{
// derive effective lengths from cumulative arrays if present
if(!cuq_cum.empty())
@@ -970,8 +970,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
@@ -986,8 +986,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse);
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
@@ -1051,14 +1051,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.lse_ptr = lse_buf.GetDeviceBuffer();
args.o_ptr = o_buf.GetDeviceBuffer();
args.seqstart_q_ptr =
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
: nullptr);
args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
args.max_seqlen_q = max_seqlen_q;
@@ -1102,27 +1094,54 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
}
// Group-mode: optional physical padded starts for Q/K
// Sequence length and padding parameters (mode-specific)
if(mode == mode_enum::group)
{
args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty()
? nullptr
: seqstart_q_padded_buf.GetDeviceBuffer());
args.seqstart_padded_k_ptr =
(seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer());
}
// Group mode: use physical (padded) cumulative starts + logical per-sequence
// lengths
// Batch-mode: optional cumulative effective seqlen overrides
if(mode == mode_enum::batch)
// Physical cumulative starts (including padding)
args.seqstart_q_ptr =
has_group_q_padding && !seqstart_q_with_padding_host.empty()
? seqstart_q_padded_buf.GetDeviceBuffer()
: seqstart_q.GetDeviceBuffer();
args.seqstart_k_ptr =
has_group_k_padding && !seqstart_k_with_padding_host.empty()
? seqstart_k_padded_buf.GetDeviceBuffer()
: seqstart_k.GetDeviceBuffer();
// Logical (unpadded) per-sequence lengths, used when padding is enabled
args.seqlen_q_ptr =
(has_group_q_padding && !seqstart_q_with_padding_host.empty())
? seqlen_q_buf.GetDeviceBuffer()
: nullptr;
args.seqlen_k_ptr =
(has_group_k_padding && !seqstart_k_with_padding_host.empty())
? seqlen_k_buf.GetDeviceBuffer()
: nullptr;
// Cumulative lengths not used in group mode
args.cu_seqlen_q_ptr = nullptr;
args.cu_seqlen_k_ptr = nullptr;
}
else // mode == mode_enum::batch
{
args.cu_seqlen_q_ptr = cuq_cum.empty()
? nullptr
: reinterpret_cast<const ck_tile::index_t*>(
cu_seqlen_q_buf.GetDeviceBuffer());
args.cu_seqlen_kv_ptr = cukv_cum.empty()
? nullptr
: reinterpret_cast<const ck_tile::index_t*>(
cu_seqlen_kv_buf.GetDeviceBuffer());
// Batch mode: use cumulative logical lengths for tail padding
// seqstart pointers not used in batch mode
args.seqstart_q_ptr = nullptr;
args.seqstart_k_ptr = nullptr;
// seqlen_q_ptr/seqlen_k_ptr not used in batch mode
args.seqlen_q_ptr = nullptr;
args.seqlen_k_ptr = nullptr;
// Cumulative logical lengths for effective length handling
args.cu_seqlen_q_ptr = has_batch_q_padding && !cuq_cum.empty()
? cu_seqlen_q_buf.GetDeviceBuffer()
: nullptr;
args.cu_seqlen_k_ptr = has_batch_k_padding && !cukv_cum.empty()
? cu_seqlen_kv_buf.GetDeviceBuffer()
: nullptr;
}
}
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
@@ -1148,6 +1167,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.batch_stride_o_acc = batch_stride_o_acc;
args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc;
args.seqstart_q_ptr =
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr =
((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
: nullptr);
}
else if constexpr(std::is_same_v<fmha_fwd_pagedkv_args, std::decay_t<decltype(args)>>)
{
@@ -1159,6 +1187,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.cache_batch_idx =
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
args.seqstart_q_ptr =
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr =
((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
: nullptr);
}
}
};
@@ -1360,16 +1397,19 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t cache_b_idx =
(use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
// Use physical offset if padding info is valid (not -1) and buffers are available
const ck_tile::index_t query_offset =
(mode == mode_enum::batch
? 0
: (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb]
: seqstart_q_with_padding_host[wb]));
: ((seqstart_q_with_padding_host.empty() || seqlen_qpads[0] < 0)
? seqstart_q_host[wb]
: seqstart_q_with_padding_host[wb]));
const ck_tile::index_t key_offset =
(mode == mode_enum::batch
? 0
: (seqlen_kpads[0] < 0 ? seqstart_k_host[wb]
: seqstart_k_with_padding_host[wb]));
: ((seqstart_k_with_padding_host.empty() || seqlen_kpads[0] < 0)
? seqstart_k_host[wb]
: seqstart_k_with_padding_host[wb]));
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
@@ -1718,8 +1758,14 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::cerr << "OUT mismatch found at batch: " << wb << std::endl
<< "\tseqlen_q: " << real_seqlen_q << std::endl
<< "\tseqlen_k: " << real_seqlen_k << std::endl
<< "\tseqstart_q: " << seqstart_q_host << std::endl
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
<< "\tseqstart_q (logical): " << seqstart_q_host << std::endl
<< "\tseqstart_q (physical): " << seqstart_q_with_padding_host
<< std::endl
<< "\tseqstart_k (logical): " << seqstart_k_host << std::endl
<< "\tseqstart_k (physical): " << seqstart_k_with_padding_host
<< std::endl
<< "\tquery_offset used: " << query_offset << std::endl
<< "\tkey_offset used: " << key_offset << std::endl;
break;
}
@@ -1727,10 +1773,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
if(lse)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
const ck_tile::index_t query_offset_lse =
(mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
lse_host_result.ForEach([&](auto& self, auto idx) {
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse);
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
});
cur_pass = ck_tile::check_err(lse_host_result,

View File

@@ -142,12 +142,14 @@ auto randints(ForwardIterator first,
*/
template <typename RandomEngine>
std::tuple<std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>>
generate_missing_seqlens(mode_enum mode,
ck_tile::index_t batch,
const std::vector<ck_tile::index_t>& q_val,
const std::vector<ck_tile::index_t>& k_val,
const std::vector<ck_tile::index_t>& q_pad_val,
const std::vector<ck_tile::index_t>& k_pad_val,
ck_tile::index_t seqlen_k_min,
bool need_append_kvcache,
@@ -177,7 +179,7 @@ generate_missing_seqlens(mode_enum mode,
return seqlen_ks;
}();
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
auto s_qpad = std::vector<ck_tile::index_t>(batch, -1);
// s_k should be greater than or equal to seqlen_k_min if provided
if(s_k.back() < seqlen_k_min)
{
@@ -187,13 +189,14 @@ generate_missing_seqlens(mode_enum mode,
throw std::runtime_error(msg.str());
}
return std::make_tuple(s_q, s_k, s_kpad);
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
}
else
{
std::vector<ck_tile::index_t> s_q;
std::vector<ck_tile::index_t> s_k;
std::vector<ck_tile::index_t> s_kpad;
std::vector<ck_tile::index_t> s_qpad;
ck_tile::index_t idx = 0;
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
{
@@ -205,9 +208,15 @@ generate_missing_seqlens(mode_enum mode,
? -1
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
ck_tile::index_t qp =
q_pad_val.empty()
? -1
: q_pad_val[std::min(idx, static_cast<ck_tile::index_t>(q_pad_val.size()) - 1)];
s_q.push_back(q);
s_k.push_back(k < 0 ? q : k);
s_kpad.push_back(kp);
s_qpad.push_back(qp);
// s_k should be greater than or equal to seqlen_k_min
if(s_k.back() < seqlen_k_min)
@@ -228,8 +237,9 @@ generate_missing_seqlens(mode_enum mode,
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back());
}
return std::make_tuple(s_q, s_k, s_kpad);
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
}
}