mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK_TILE] fmha: Add query padding support to backward pass (#3097)
* [CK_TILE] fmha: Add query padding support to backward pass Introduces support for query sequence padding (q_padding) in the FMHA backward pass kernels. - Passing `seqlen_q_ptr` to the backward kernels to distinguish logical from physical sequence lengths. - Updating `OGradDotO`, `ConvertQGrad`, and `DQDKDV` kernels to respect logical lengths and handle zero-length sequences. - Aligning LSE indexing in the forward kernel with the padded layout for consistency. - Adding a new GTest suite (`test_fmha_bwd_kernel_padding.cpp`) with comprehensive tests for various padding scenarios, including zero-length sequences and deterministic mode. * fix clang format * Adapt fmha_bwd_runner.cpp to new q, kv sequence padding Add backward q/kv sequence padding unit tests. * [CK_TILE] fmha: Unify sequence length and padding handling Refactor the handling of sequence lengths and padding in the FMHA forward and backward kernels to provide a more unified and flexible interface. - Replaced `seqstart_padded_*_ptr` with a more robust system that uses `seqstart_*_ptr` for physical sequence lengths and introduces `seqlen_*_ptr` and `cu_seqlen_*_ptr` for logical (unpadded) lengths. - Established a clear order of precedence for determining sequence length: cumulative lengths (`cu_seqlen_*_ptr`) take priority, followed by per-sequence lengths (`seqlen_*_ptr`), and finally physical lengths derived from `seqstart_*_ptr`. - Clarified the distinction between "group mode" and "batch mode" and how sequence lengths are handled in each case. - Renamed `cu_seqlen_kv_ptr` to `cu_seqlen_k_ptr` for consistency. - Updated comments and documentation to reflect the new argument structure and usage. --------- Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
@@ -313,7 +313,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
|
||||
const int32_t* seqlen_k_ptr; // per-batch actual length [batch]
|
||||
const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
|
||||
const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>;
|
||||
@@ -520,7 +523,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
void* dq_acc_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* cu_seqlen_q_ptr,
|
||||
const void* cu_seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -594,7 +600,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
{}, // placeholder for deterministic
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr)};
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -736,10 +745,29 @@ struct FmhaBwdDQDKDVKernel
|
||||
batch_offset_randval = query_start * kargs.stride_randval;
|
||||
}
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
if(kargs.seqlen_k_ptr != nullptr)
|
||||
// Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
|
||||
if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
const ck_tile::index_t physical_seqlen_q =
|
||||
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
kargs.seqlen_q =
|
||||
kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q;
|
||||
}
|
||||
|
||||
// Priority: cu_seqlen_k_ptr > seqlen_k_ptr > seqstart_k
|
||||
if(kargs.cu_seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else if(kargs.seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
}
|
||||
@@ -749,6 +777,12 @@ struct FmhaBwdDQDKDVKernel
|
||||
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
|
||||
}
|
||||
|
||||
// skip if logical lengths are zero
|
||||
if(kargs.seqlen_q == 0 || kargs.seqlen_k == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if constexpr(!kUseQrQtrDorPipeline)
|
||||
@@ -1246,6 +1280,8 @@ struct FmhaBwdOGradDotOKernel
|
||||
struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
|
||||
const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
|
||||
};
|
||||
|
||||
using Kargs = std::
|
||||
@@ -1293,6 +1329,8 @@ struct FmhaBwdOGradDotOKernel
|
||||
void* d_ptr,
|
||||
float p_undrop,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* cu_seqlen_q_ptr,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_o,
|
||||
@@ -1311,7 +1349,9 @@ struct FmhaBwdOGradDotOKernel
|
||||
nhead_stride_do,
|
||||
nhead_stride_o,
|
||||
nhead_stride_d},
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr)};
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -1355,9 +1395,23 @@ struct FmhaBwdOGradDotOKernel
|
||||
batch_offset_do = query_start * kargs.stride_do;
|
||||
batch_offset_d = query_start;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
// Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
|
||||
if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
const ck_tile::index_t physical_seqlen_q =
|
||||
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
kargs.seqlen_q = kargs.seqlen_q_ptr
|
||||
? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
|
||||
: physical_seqlen_q;
|
||||
}
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
@@ -1521,6 +1575,10 @@ struct FmhaBwdConvertQGradKernel
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
|
||||
const int32_t* seqlen_k_ptr; // per-batch actual length [batch]
|
||||
const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
|
||||
const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode,
|
||||
@@ -1569,6 +1627,10 @@ struct FmhaBwdConvertQGradKernel
|
||||
void* dq_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* cu_seqlen_q_ptr,
|
||||
const void* cu_seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t stride_dq,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
@@ -1587,7 +1649,11 @@ struct FmhaBwdConvertQGradKernel
|
||||
nhead_stride_dq_acc},
|
||||
{},
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr)};
|
||||
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
@@ -1632,13 +1698,41 @@ struct FmhaBwdConvertQGradKernel
|
||||
batch_offset_dq = query_start * kargs.stride_dq;
|
||||
batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
const ck_tile::index_t physical_seqlen_q =
|
||||
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
kargs.seqlen_q = kargs.seqlen_q_ptr
|
||||
? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
|
||||
: physical_seqlen_q;
|
||||
}
|
||||
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
|
||||
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
|
||||
const ck_tile::index_t physical_seqlen_k =
|
||||
adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
|
||||
|
||||
// Priority: cu_seqlen_k_ptr > seqlen_k_ptr > physical_seqlen_k
|
||||
if(kargs.cu_seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.seqlen_k_ptr
|
||||
? static_cast<ck_tile::index_t>(kargs.seqlen_k_ptr[i_batch])
|
||||
: physical_seqlen_k;
|
||||
}
|
||||
}
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
|
||||
@@ -296,8 +296,8 @@ struct FmhaFwdKernel
|
||||
|
||||
// Optional cumulative sequence length pointers for batch mode
|
||||
// If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD
|
||||
const int32_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
|
||||
const int32_t* cu_seqlen_k_ptr = nullptr; // cumulative, length without PAD
|
||||
};
|
||||
|
||||
struct FmhaFwdGroupModeKargs
|
||||
@@ -316,12 +316,12 @@ struct FmhaFwdKernel
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_q_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
// Optional cumulative padded sequence starts (including PAD tokens)
|
||||
// Used solely to compute memory offsets when sequences are physically padded.
|
||||
const int32_t* seqstart_padded_q_ptr = nullptr;
|
||||
const int32_t* seqstart_padded_k_ptr = nullptr;
|
||||
// Optional per-sequence and cumulative logical (excluding padding) sequence length arrays
|
||||
const int32_t* cu_seqlen_q_ptr = nullptr;
|
||||
const int32_t* cu_seqlen_k_ptr = nullptr;
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
|
||||
@@ -379,8 +379,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -471,8 +471,8 @@ struct FmhaFwdKernel
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
|
||||
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
|
||||
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
|
||||
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
|
||||
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
|
||||
return kargs;
|
||||
}
|
||||
|
||||
@@ -522,8 +522,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -570,7 +570,7 @@ struct FmhaFwdKernel
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_kv_ptr);
|
||||
cu_seqlen_k_ptr);
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
@@ -619,8 +619,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -667,7 +667,7 @@ struct FmhaFwdKernel
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_kv_ptr);
|
||||
cu_seqlen_k_ptr);
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
@@ -681,6 +681,7 @@ struct FmhaFwdKernel
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
@@ -711,8 +712,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
const void* seqstart_padded_q_ptr = nullptr,
|
||||
const void* seqstart_padded_k_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -746,6 +747,7 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for min_seqlen_q
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
@@ -804,8 +806,8 @@ struct FmhaFwdKernel
|
||||
kargs.min_seqlen_q = min_seqlen_q;
|
||||
}
|
||||
|
||||
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
|
||||
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
|
||||
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
|
||||
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
|
||||
return kargs;
|
||||
}
|
||||
|
||||
@@ -821,6 +823,7 @@ struct FmhaFwdKernel
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
@@ -850,8 +853,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
const void* seqstart_padded_q_ptr = nullptr,
|
||||
const void* seqstart_padded_k_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -863,6 +866,7 @@ struct FmhaFwdKernel
|
||||
o_ptr,
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_q_ptr,
|
||||
seqlen_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
@@ -892,8 +896,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
seqstart_padded_q_ptr,
|
||||
seqstart_padded_k_ptr);
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
@@ -908,6 +912,7 @@ struct FmhaFwdKernel
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
@@ -937,8 +942,8 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
const void* seqstart_padded_q_ptr = nullptr,
|
||||
const void* seqstart_padded_k_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -950,6 +955,7 @@ struct FmhaFwdKernel
|
||||
o_ptr,
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_q_ptr,
|
||||
seqlen_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
@@ -979,8 +985,8 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
seqstart_padded_q_ptr,
|
||||
seqstart_padded_k_ptr);
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
@@ -1109,46 +1115,52 @@ struct FmhaFwdKernel
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// logical and physical (padded) starts
|
||||
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
|
||||
// Use seqstart_q_ptr and seqstart_k_ptr for physical starts
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
|
||||
? kargs.seqstart_padded_q_ptr[i_batch]
|
||||
: query_start_unpadded;
|
||||
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
|
||||
? kargs.seqstart_padded_k_ptr[i_batch]
|
||||
: key_start_unpadded;
|
||||
|
||||
// DRAM base offsets use physical padded starts
|
||||
batch_offset_q = query_start_padded * kargs.stride_q;
|
||||
batch_offset_k = key_start_padded * kargs.stride_k;
|
||||
// DRAM base offsets use physical starts
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start_padded * kargs.stride_v;
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_v = key_start_padded;
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start_padded * kargs.stride_bias;
|
||||
batch_offset_bias = query_start * kargs.stride_bias;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
// LSE stays indexed by unpadded starts
|
||||
batch_offset_lse = query_start_unpadded;
|
||||
// LSE follows the physical layout to stay consistent with other tensors
|
||||
batch_offset_lse = query_start;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
batch_offset_randval = query_start_padded * kargs.stride_randval;
|
||||
batch_offset_randval = query_start * kargs.stride_randval;
|
||||
}
|
||||
batch_offset_o = query_start_padded * kargs.stride_o;
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
// real logical lengths (exclude PAD)
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
// Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
|
||||
if(kargs.seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
}
|
||||
|
||||
if constexpr(kSkipMinSeqlenQ)
|
||||
{
|
||||
@@ -1168,6 +1180,11 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else if(kargs.cu_seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
|
||||
@@ -1201,10 +1218,10 @@ struct FmhaFwdKernel
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
if(kargs.cu_seqlen_kv_ptr != nullptr)
|
||||
if(kargs.cu_seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
|
||||
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1603,39 +1620,46 @@ struct FmhaFwdKernel
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
|
||||
// get starting offset for each batch - use seqstart_q_ptr/seqstart_k_ptr for
|
||||
// physical starts
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
|
||||
? kargs.seqstart_padded_q_ptr[i_batch]
|
||||
: query_start_unpadded;
|
||||
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
|
||||
? kargs.seqstart_padded_k_ptr[i_batch]
|
||||
: key_start_unpadded;
|
||||
|
||||
batch_offset_q = query_start_padded * kargs.stride_q;
|
||||
batch_offset_k = key_start_padded * kargs.stride_k;
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start_padded * kargs.stride_v;
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
// col-major V: offset along seqlen dimension is scalar index
|
||||
batch_offset_v = key_start_padded;
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start_padded * kargs.stride_bias;
|
||||
batch_offset_bias = query_start * kargs.stride_bias;
|
||||
}
|
||||
|
||||
// LSE layout is [nhead, total_seqlen], index by unpadded start
|
||||
batch_offset_lse = query_start_unpadded;
|
||||
batch_offset_o = query_start_padded * kargs.stride_o;
|
||||
// LSE layout is [nhead, total_seqlen] following the physical layout for Q/O
|
||||
batch_offset_lse = query_start;
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
|
||||
if(kargs.seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
|
||||
}
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
@@ -1648,6 +1672,11 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else if(kargs.cu_seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
@@ -1677,10 +1706,10 @@ struct FmhaFwdKernel
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
if(kargs.cu_seqlen_kv_ptr != nullptr)
|
||||
if(kargs.cu_seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
|
||||
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user