mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Revert "[CK_TILE] Add sequence padding and variable length support in fmha (a…" (#2883)
This reverts commit 86dd59cd01.
This commit is contained in:
@@ -259,11 +259,11 @@ class FmhaFwdApiTrait:
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)'
|
||||
else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)'
|
||||
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
|
||||
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)'
|
||||
else : return f'a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.pipeline_tag == 'qr_async_trload':
|
||||
if self.skpad == 't' : return 'true'
|
||||
else: return 'true'
|
||||
|
||||
@@ -33,10 +33,6 @@ auto create_args(int argc, char* argv[])
|
||||
"0",
|
||||
"seqlen_k for new key/value, 0 means not to use this at all; "
|
||||
"-1 to choose s_knew in [1, s] randomly.")
|
||||
.insert("s_qpad",
|
||||
"-1",
|
||||
"seqlen_q stride between 2 batches (group-mode optional).\n"
|
||||
"Provide positive strides per-batch to simulate physical padding on Q.")
|
||||
.insert("s_kpad",
|
||||
"-1",
|
||||
"seqlen_k stride between 2 batches, currently used in group-mode only\n"
|
||||
@@ -111,15 +107,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "fmha_fwd.json", "json file name to dump results")
|
||||
.insert("q_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.")
|
||||
.insert("kv_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.");
|
||||
.insert("jsonfile", "fmha_fwd.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -139,9 +127,6 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew");
|
||||
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
|
||||
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
|
||||
auto q_eff_lens_per_batch = arg_parser.get_int_vec("q_eff_lens");
|
||||
auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens");
|
||||
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
@@ -189,10 +174,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
seqlen_knew,
|
||||
seqlen_qpads,
|
||||
seqlen_kpads,
|
||||
q_eff_lens_per_batch,
|
||||
kv_eff_lens_per_batch,
|
||||
rotary_dim,
|
||||
i_perm,
|
||||
o_perm,
|
||||
|
||||
@@ -52,16 +52,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
"non-deterministic seed")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "30", "number of iterations to benchmark the kernel")
|
||||
// Optional effective seqlen override (exclude PAD) for batch mode
|
||||
.insert("q_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.")
|
||||
.insert("kv_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.");
|
||||
.insert("repeat", "30", "number of iterations to benchmark the kernel");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_pair(result, arg_parser);
|
||||
@@ -120,8 +111,6 @@ struct Problem
|
||||
|
||||
input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
|
||||
output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
|
||||
q_eff_lens = args.get_int_vec("q_eff_lens");
|
||||
kv_eff_lens = args.get_int_vec("kv_eff_lens");
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> get_query_shape() const
|
||||
@@ -183,8 +172,6 @@ struct Problem
|
||||
mask_info mask;
|
||||
TensorLayout input_layout;
|
||||
TensorLayout output_layout;
|
||||
std::vector<int> q_eff_lens;
|
||||
std::vector<int> kv_eff_lens;
|
||||
};
|
||||
|
||||
struct RunConfig
|
||||
@@ -339,10 +326,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
q_buf.ToDevice(q.data());
|
||||
k_buf.ToDevice(k.data());
|
||||
v_buf.ToDevice(v.data());
|
||||
// Ensure output buffer is zero-initialized so padded regions compare cleanly
|
||||
o_buf.SetZero();
|
||||
|
||||
ck_tile::fmha_fwd_v3_args args{};
|
||||
ck_tile::fmha_fwd_v3_args args;
|
||||
|
||||
args.data_type = problem.data_type;
|
||||
args.batch = problem.batch;
|
||||
@@ -395,60 +380,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
: problem.seqlen_q * problem.hdim;
|
||||
args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim;
|
||||
|
||||
// Optional cumulative seqlen overrides (exclude PAD)
|
||||
const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1;
|
||||
const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1;
|
||||
|
||||
auto make_effective_vec = [&](const std::vector<int>& opt_vec, ck_tile::index_t fallback) {
|
||||
std::vector<ck_tile::index_t> eff;
|
||||
if(!opt_vec.empty() && opt_vec[0] != -1)
|
||||
{
|
||||
eff.assign(opt_vec.begin(), opt_vec.end());
|
||||
if(eff.size() < static_cast<size_t>(problem.batch))
|
||||
{
|
||||
eff.resize(problem.batch, eff.back());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
eff.assign(problem.batch, fallback);
|
||||
}
|
||||
return eff;
|
||||
};
|
||||
|
||||
const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q);
|
||||
const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k);
|
||||
|
||||
// Calculate cumulative sums for kernel arguments if varlen is used
|
||||
std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
|
||||
auto calculate_cumulative = [&](const std::vector<ck_tile::index_t>& per_batch_vec,
|
||||
std::vector<ck_tile::index_t>& cum_vec) {
|
||||
cum_vec.resize(per_batch_vec.size() + 1);
|
||||
cum_vec[0] = 0;
|
||||
for(std::size_t i = 0; i < per_batch_vec.size(); ++i)
|
||||
cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
|
||||
};
|
||||
|
||||
if(has_varlen_q)
|
||||
{
|
||||
calculate_cumulative(eff_q_vec, cuq_cum);
|
||||
}
|
||||
if(has_varlen_k)
|
||||
{
|
||||
calculate_cumulative(eff_kv_vec, cukv_cum);
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0);
|
||||
ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0);
|
||||
cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr);
|
||||
cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr);
|
||||
args.cu_seqlen_q_ptr =
|
||||
!cuq_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cuq_buf.GetDeviceBuffer())
|
||||
: nullptr;
|
||||
args.cu_seqlen_kv_ptr =
|
||||
!cukv_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cukv_buf.GetDeviceBuffer())
|
||||
: nullptr;
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
/*log_level=*/0,
|
||||
@@ -511,72 +442,15 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
o_ref = o_ref.transpose({0, 2, 1, 3});
|
||||
}
|
||||
|
||||
// If variable lengths are provided, compute per-batch references
|
||||
// with the effective lengths; else compute a single full reference.
|
||||
if(has_varlen_q || has_varlen_k)
|
||||
{
|
||||
// Variable-length aware verification: zero-fill padded region and only compute valid part.
|
||||
o_ref.SetZero();
|
||||
|
||||
for(int b = 0; b < problem.batch; ++b)
|
||||
{
|
||||
const ck_tile::index_t seqlen_q_eff = eff_q_vec[b];
|
||||
const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
|
||||
|
||||
if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
|
||||
continue;
|
||||
|
||||
// Slice current batch from inputs (bshd) and build single-batch tensors
|
||||
ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
|
||||
// Copy effective region
|
||||
q_b.ForEach([&](auto& self, auto idx) {
|
||||
// idx: [0, s, h, d]
|
||||
self(idx) = q(b, idx[1], idx[2], idx[3]);
|
||||
});
|
||||
k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
|
||||
v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
|
||||
|
||||
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
|
||||
host::fmha_fwd<float, DataType>(q_b,
|
||||
k_b,
|
||||
v_b,
|
||||
problem.mask,
|
||||
o_b,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales{problem.softmax_scale});
|
||||
|
||||
// Scatter into o_ref's bshd descriptor memory
|
||||
for(int s = 0; s < seqlen_q_eff; ++s)
|
||||
{
|
||||
for(int h = 0; h < problem.nhead_q; ++h)
|
||||
{
|
||||
for(int d = 0; d < problem.hdim; ++d)
|
||||
{
|
||||
o_ref(b, s, h, d) = o_b(0, s, h, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// No varlen override: compute the full reference once
|
||||
host::fmha_fwd<float, DataType>(q,
|
||||
k,
|
||||
v,
|
||||
problem.mask,
|
||||
o_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales{problem.softmax_scale});
|
||||
}
|
||||
host::fmha_fwd<float, DataType>(q,
|
||||
k,
|
||||
v,
|
||||
problem.mask,
|
||||
o_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales{problem.softmax_scale});
|
||||
|
||||
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
|
||||
o_buf.FromDevice(o.data());
|
||||
|
||||
@@ -162,20 +162,11 @@ struct fmha_fwd_args
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
// Optional cumulative sequence length arrays
|
||||
// Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD)
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
|
||||
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void*
|
||||
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
|
||||
|
||||
// Group mode: seqstart_padded_* provide physical starts including PAD (optional)
|
||||
const void* seqstart_padded_q_ptr = nullptr; // [batch+1]
|
||||
const void* seqstart_padded_k_ptr = nullptr; // [batch+1]
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
@@ -563,9 +554,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.min_seqlen_q,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.seqstart_padded_q_ptr,
|
||||
args.seqstart_padded_k_ptr);
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -611,9 +600,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_kv_ptr);
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -151,10 +151,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t seqlen_knew,
|
||||
std::vector<ck_tile::index_t> seqlen_qpads,
|
||||
std::vector<ck_tile::index_t> seqlen_kpads,
|
||||
std::vector<ck_tile::index_t> q_eff_lens_per_batch,
|
||||
std::vector<ck_tile::index_t> kv_eff_lens_per_batch,
|
||||
ck_tile::index_t rotary_dim,
|
||||
bool i_perm,
|
||||
bool o_perm,
|
||||
@@ -365,44 +362,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
|
||||
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
|
||||
|
||||
// Optional padded Q seqstarts (group-mode only)
|
||||
std::vector<int32_t> seqstart_q_with_padding_host;
|
||||
if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1)
|
||||
{
|
||||
if(seqlen_qpads.size() < static_cast<size_t>(batch))
|
||||
{
|
||||
seqlen_qpads.resize(batch, seqlen_qpads.back());
|
||||
}
|
||||
if(seqlen_qpads.size() == static_cast<size_t>(batch))
|
||||
{
|
||||
seqstart_q_with_padding_host = to_seqstarts(
|
||||
ck_tile::span<const int32_t>(seqlen_qpads.data(), seqlen_qpads.size()));
|
||||
}
|
||||
}
|
||||
|
||||
// Optional batch-mode cumulative seqlen overrides
|
||||
std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
|
||||
if(mode == mode_enum::batch)
|
||||
{
|
||||
auto calculate_cumulative = [&](std::vector<ck_tile::index_t>& per_batch_vec,
|
||||
std::vector<ck_tile::index_t>& cum_vec) {
|
||||
if(!per_batch_vec.empty() && per_batch_vec[0] != -1)
|
||||
{
|
||||
if(per_batch_vec.size() < static_cast<size_t>(batch))
|
||||
{
|
||||
per_batch_vec.resize(batch, per_batch_vec.back());
|
||||
}
|
||||
cum_vec.resize(batch + 1);
|
||||
cum_vec[0] = 0;
|
||||
for(int i = 0; i < batch; ++i)
|
||||
cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
|
||||
}
|
||||
};
|
||||
|
||||
calculate_cumulative(q_eff_lens_per_batch, cuq_cum);
|
||||
calculate_cumulative(kv_eff_lens_per_batch, cukv_cum);
|
||||
}
|
||||
|
||||
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
|
||||
|
||||
using QDataType = typename TypeConfig::QDataType;
|
||||
@@ -486,15 +445,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
// host memory for storing all the tensor elements
|
||||
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
|
||||
// logical(unpadded) total seqlen_q for group; batch uses fixed seqlen
|
||||
const ck_tile::index_t shape_seqlen_q_lse =
|
||||
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
|
||||
// physical(padded) total seqlen_q for group when s_qpad is provided; else use logical
|
||||
const ck_tile::index_t shape_seqlen_q =
|
||||
(mode == mode_enum::batch
|
||||
? seqlen_qs[0]
|
||||
: (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back()
|
||||
: seqstart_q_with_padding_host.back()));
|
||||
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_ks[0]
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
|
||||
@@ -552,7 +504,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
// batch mode of lse data layout is [batch, nhead, seqlen_q]
|
||||
// group mode of lse data layout is [nhead, total_seqlen_q]
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q_lse}
|
||||
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
|
||||
|
||||
ck_tile::HostTensor<ODataType> o_host(
|
||||
@@ -650,16 +602,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty()
|
||||
? 0
|
||||
: seqstart_q_with_padding_host.size() *
|
||||
sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k_padded_buf(
|
||||
seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0
|
||||
: cuq_cum.size() * sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem cu_seqlen_kv_buf(
|
||||
cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
|
||||
0 <= seqlen_kpads[0]
|
||||
? seqlen_ks.size() * sizeof(int32_t)
|
||||
@@ -751,14 +693,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
vnew_buf.ToDevice(vnew_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
// Keep logical starts in seqstart_k; pass padded K via separate pointer
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
seqstart_q_padded_buf.ToDevice(
|
||||
seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data());
|
||||
seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr
|
||||
: seqstart_k_with_padding_host.data());
|
||||
cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data());
|
||||
cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data());
|
||||
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
|
||||
: seqstart_k_with_padding_host.data());
|
||||
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
|
||||
? seqlen_ks.data()
|
||||
: nullptr);
|
||||
@@ -894,8 +830,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
// setup batch_stride_* arguments
|
||||
@@ -910,8 +846,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
|
||||
@@ -1025,29 +961,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
{
|
||||
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
|
||||
}
|
||||
|
||||
// Group-mode: optional physical padded starts for Q/K
|
||||
if(mode == mode_enum::group)
|
||||
{
|
||||
args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty()
|
||||
? nullptr
|
||||
: seqstart_q_padded_buf.GetDeviceBuffer());
|
||||
args.seqstart_padded_k_ptr =
|
||||
(seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer());
|
||||
}
|
||||
|
||||
// Batch-mode: optional cumulative effective seqlen overrides
|
||||
if(mode == mode_enum::batch)
|
||||
{
|
||||
args.cu_seqlen_q_ptr = cuq_cum.empty()
|
||||
? nullptr
|
||||
: reinterpret_cast<const ck_tile::index_t*>(
|
||||
cu_seqlen_q_buf.GetDeviceBuffer());
|
||||
args.cu_seqlen_kv_ptr = cukv_cum.empty()
|
||||
? nullptr
|
||||
: reinterpret_cast<const ck_tile::index_t*>(
|
||||
cu_seqlen_kv_buf.GetDeviceBuffer());
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
@@ -1254,29 +1167,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
if(mode == mode_enum::batch)
|
||||
{
|
||||
if(!cuq_cum.empty())
|
||||
{
|
||||
real_seqlen_q = cuq_cum[wb + 1] - cuq_cum[wb];
|
||||
}
|
||||
if(!cukv_cum.empty())
|
||||
{
|
||||
real_seqlen_k = cukv_cum[wb + 1] - cukv_cum[wb];
|
||||
}
|
||||
}
|
||||
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
|
||||
// adjust matrix index according to the mode
|
||||
const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t cache_b_idx =
|
||||
(use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
|
||||
const ck_tile::index_t query_offset =
|
||||
(mode == mode_enum::batch
|
||||
? 0
|
||||
: (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb]
|
||||
: seqstart_q_with_padding_host[wb]));
|
||||
(mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck_tile::index_t key_offset =
|
||||
(mode == mode_enum::batch
|
||||
? 0
|
||||
@@ -1639,10 +1538,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
const ck_tile::index_t query_offset_lse =
|
||||
(mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
lse_host_result.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse);
|
||||
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
|
||||
});
|
||||
|
||||
cur_pass = ck_tile::check_err(lse_host_result,
|
||||
|
||||
@@ -56,11 +56,6 @@ struct fmha_fwd_v3_args
|
||||
index_t stride_o;
|
||||
index_t nhead_stride_o;
|
||||
index_t batch_stride_o;
|
||||
|
||||
// Optional batch-mode cumulative seqlen overrides (exclude PAD)
|
||||
// If provided, they override per-batch effective lengths to skip tail padding.
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type);
|
||||
|
||||
@@ -158,9 +158,7 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
remap_opt,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_kv_ptr);
|
||||
remap_opt);
|
||||
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
@@ -18,36 +18,3 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
#Padding Benchmarks: batch mode (baseline vs low/med/high pad)
|
||||
prec="fp16"
|
||||
base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID"
|
||||
|
||||
# baseline (no pad)
|
||||
$EXE $base_batch_args
|
||||
|
||||
# low pad (≈90–95% effective)
|
||||
$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
|
||||
|
||||
# medium pad (≈60–75% effective)
|
||||
$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
|
||||
|
||||
# high pad (≈30–40% effective)
|
||||
$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
|
||||
|
||||
# Padding Benchmarks: group mode (baseline vs low/med/high physical pad)
|
||||
seqlens_q="1024,768,512,256"
|
||||
seqlens_k="1024,768,512,256"
|
||||
base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID"
|
||||
|
||||
# baseline (no physical pad)
|
||||
$EXE $base_group_args
|
||||
|
||||
# low physical pad
|
||||
$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
|
||||
|
||||
# medium physical pad
|
||||
$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384
|
||||
|
||||
# high physical pad
|
||||
$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512
|
||||
|
||||
@@ -23,20 +23,3 @@ done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# Padding benchmark comparisons for v3 (batch mode only)
|
||||
# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ====
|
||||
prec="fp16"
|
||||
base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID"
|
||||
|
||||
# baseline (no pad)
|
||||
$EXE $base_v3_args
|
||||
|
||||
# low pad (≈90–95% effective)
|
||||
$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
|
||||
|
||||
# medium pad (≈60–75% effective)
|
||||
$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
|
||||
|
||||
# high pad (≈30–40% effective)
|
||||
$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
|
||||
|
||||
@@ -137,118 +137,9 @@ run_fp16_appendkv_tests() {
|
||||
done ; done ; done
|
||||
}
|
||||
|
||||
run_padding_smoke_tests() {
|
||||
# Padding-only smoke tests for batch/group mode using COMMON_ARGS
|
||||
local prec="fp16"
|
||||
|
||||
# Batch mode: padding via effective lengths (exclude PAD)
|
||||
# Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches
|
||||
local base_batch="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS"
|
||||
# low pad (≈90–95% effective)
|
||||
$EXE $base_batch -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
|
||||
# medium pad (≈60–75% effective)
|
||||
$EXE $base_batch -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
|
||||
# high pad (≈30–40% effective)
|
||||
$EXE $base_batch -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
|
||||
|
||||
# Group mode: padding via physical stride along seqlen
|
||||
local seqlens_q="1024,768,512,256"
|
||||
local seqlens_k="1024,768,512,256"
|
||||
local base_group="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS"
|
||||
# low physical pad
|
||||
$EXE $base_group -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
|
||||
# medium physical pad
|
||||
$EXE $base_group -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384
|
||||
# high physical pad
|
||||
$EXE $base_group -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512
|
||||
}
|
||||
|
||||
run_padding_basic_boundary_tests() {
|
||||
# Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh)
|
||||
local prec
|
||||
local perm
|
||||
|
||||
# Group mode: Q&K padded with per-batch different strides
|
||||
for prec in fp16 bf16 ; do
|
||||
for perm in 0 1 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \
|
||||
-s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \
|
||||
-bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
|
||||
# slightly larger, uneven padding strides
|
||||
for prec in fp16 bf16 ; do
|
||||
for perm in 0 1 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \
|
||||
-s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
|
||||
# only K padded; Q unpadded
|
||||
for prec in fp16 bf16 ; do
|
||||
for perm in 0 1 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \
|
||||
-s=55 -s_k=256 -s_kpad=272,260 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
|
||||
# use cu_seqlen overrides to skip tail PAD
|
||||
for prec in fp16 bf16 ; do
|
||||
for perm in 0 1 ; do
|
||||
$EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \
|
||||
-q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
$EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \
|
||||
-q_eff_lens=55,60 -kv_eff_lens=200,256 \
|
||||
-bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
|
||||
# no padding (equal), mixed Q/KV, all len=1
|
||||
for prec in fp16 bf16 ; do
|
||||
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
|
||||
-q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
|
||||
-q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
|
||||
-q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
|
||||
# highly variable logical lengths
|
||||
for prec in fp16 bf16 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \
|
||||
-s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
|
||||
# GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16
|
||||
for prec in fp16 bf16 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \
|
||||
-s=256,129 -s_k=256,129 -s_kpad=256 \
|
||||
-bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
}
|
||||
|
||||
set -x
|
||||
|
||||
run_fp16_bf16_tests
|
||||
run_padding_smoke_tests
|
||||
run_padding_basic_boundary_tests
|
||||
run_fp8_tests
|
||||
run_fp8bf16_tests
|
||||
run_fp8fp32_tests
|
||||
|
||||
Reference in New Issue
Block a user