mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Revert "[CK_TILE] Add sequence padding and variable length support in fmha (a…" (#2883)
This reverts commit 86dd59cd01.
This commit is contained in:
@@ -291,11 +291,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
// Optional cumulative sequence length pointers for batch mode
|
||||
// If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD
|
||||
};
|
||||
|
||||
struct FmhaFwdGroupModeKargs
|
||||
@@ -315,11 +310,6 @@ struct FmhaFwdKernel
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
// Optional cumulative padded sequence starts (including PAD tokens)
|
||||
// Used solely to compute memory offsets when sequences are physically padded.
|
||||
const int32_t* seqstart_padded_q_ptr = nullptr;
|
||||
const int32_t* seqstart_padded_k_ptr = nullptr;
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
|
||||
@@ -470,105 +460,6 @@ struct FmhaFwdKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
// Overload: Batch mode with optional cu_seqlen pointers (unpadded cumulative lengths)
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr,
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr)
|
||||
{
|
||||
auto kargs = MakeKargsImpl(q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
logits_soft_cap,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
s_randval,
|
||||
drop_seed_offset);
|
||||
|
||||
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
|
||||
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
|
||||
return kargs;
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
@@ -890,95 +781,6 @@ struct FmhaFwdKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
// Overload: Group mode with optional padded seqstarts for memory offsets
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
const void* seqstart_padded_q_ptr,
|
||||
const void* seqstart_padded_k_ptr)
|
||||
{
|
||||
auto kargs = MakeKargsImpl(q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
logits_soft_cap,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
min_seqlen_q,
|
||||
p_drop,
|
||||
s_randval,
|
||||
drop_seed_offset);
|
||||
|
||||
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
|
||||
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
|
||||
return kargs;
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
@@ -1271,44 +1073,35 @@ struct FmhaFwdKernel
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// logical and physical (padded) starts
|
||||
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
|
||||
? kargs.seqstart_padded_q_ptr[i_batch]
|
||||
: query_start_unpadded;
|
||||
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
|
||||
? kargs.seqstart_padded_k_ptr[i_batch]
|
||||
: key_start_unpadded;
|
||||
|
||||
// DRAM base offsets use physical padded starts
|
||||
batch_offset_q = query_start_padded * kargs.stride_q;
|
||||
batch_offset_k = key_start_padded * kargs.stride_k;
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start_padded * kargs.stride_v;
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_v = key_start_padded;
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start_padded * kargs.stride_bias;
|
||||
batch_offset_bias = query_start * kargs.stride_bias;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
// LSE stays indexed by unpadded starts
|
||||
batch_offset_lse = query_start_unpadded;
|
||||
batch_offset_lse = query_start;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
batch_offset_randval = query_start_padded * kargs.stride_randval;
|
||||
batch_offset_randval = query_start * kargs.stride_randval;
|
||||
}
|
||||
batch_offset_o = query_start_padded * kargs.stride_o;
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
// real logical lengths (exclude PAD)
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
|
||||
@@ -1320,7 +1113,8 @@ struct FmhaFwdKernel
|
||||
}
|
||||
}
|
||||
|
||||
// terminate unnecessary blocks earlier
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
{
|
||||
return;
|
||||
@@ -1356,18 +1150,6 @@ struct FmhaFwdKernel
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
|
||||
// If cumulative seqlen pointers are provided, override per-batch effective lengths
|
||||
if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
if(kargs.cu_seqlen_kv_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
@@ -1766,35 +1548,26 @@ struct FmhaFwdKernel
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
|
||||
? kargs.seqstart_padded_q_ptr[i_batch]
|
||||
: query_start_unpadded;
|
||||
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
|
||||
? kargs.seqstart_padded_k_ptr[i_batch]
|
||||
: key_start_unpadded;
|
||||
|
||||
batch_offset_q = query_start_padded * kargs.stride_q;
|
||||
batch_offset_k = key_start_padded * kargs.stride_k;
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start_padded * kargs.stride_v;
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
// col-major V: offset along seqlen dimension is scalar index
|
||||
batch_offset_v = key_start_padded;
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start_padded * kargs.stride_bias;
|
||||
batch_offset_bias = query_start * kargs.stride_bias;
|
||||
}
|
||||
|
||||
// LSE layout is [nhead, total_seqlen], index by unpadded start
|
||||
batch_offset_lse = query_start_unpadded;
|
||||
batch_offset_o = query_start_padded * kargs.stride_o;
|
||||
batch_offset_lse = query_start;
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
|
||||
@@ -1832,18 +1605,6 @@ struct FmhaFwdKernel
|
||||
batch_offset_bias =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
}
|
||||
|
||||
// If cumulative seqlen pointers are provided, override per-batch effective lengths
|
||||
if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
if(kargs.cu_seqlen_kv_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
|
||||
@@ -100,11 +100,6 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
// Optional cumulative sequence length pointers for batch mode
|
||||
// If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
|
||||
};
|
||||
|
||||
struct FmhaFwdGroupModeKargs
|
||||
@@ -115,11 +110,6 @@ struct FmhaFwdV3Kernel
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
// Optional cumulative padded sequence starts (including PAD tokens)
|
||||
// Used solely to compute memory offsets when sequences are physically padded.
|
||||
const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1]
|
||||
const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1]
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
|
||||
@@ -200,78 +190,6 @@ struct FmhaFwdV3Kernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
// Overload: Batch mode with optional cu_seqlen pointers
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t remap_opt,
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr,
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr)
|
||||
{
|
||||
auto kargs = MakeKargs(q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
remap_opt);
|
||||
|
||||
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
|
||||
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
|
||||
return kargs;
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
@@ -342,70 +260,6 @@ struct FmhaFwdV3Kernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
// Overload: Group mode with optional padded seqstarts for memory offsets
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t remap_opt,
|
||||
const void* seqstart_padded_q_ptr,
|
||||
const void* seqstart_padded_k_ptr)
|
||||
{
|
||||
auto kargs = MakeKargs(q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
remap_opt);
|
||||
|
||||
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
|
||||
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
|
||||
return kargs;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
@@ -519,26 +373,18 @@ struct FmhaFwdV3Kernel
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
|
||||
? kargs.seqstart_padded_q_ptr[i_batch]
|
||||
: query_start_unpadded;
|
||||
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
|
||||
? kargs.seqstart_padded_k_ptr[i_batch]
|
||||
: key_start_unpadded;
|
||||
|
||||
batch_offset_q = query_start_padded * kargs.stride_q;
|
||||
batch_offset_k = key_start_padded * kargs.stride_k;
|
||||
batch_offset_v = key_start_padded * kargs.stride_v;
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
// LSE layout is [nhead, total_seqlen], index by unpadded start
|
||||
batch_offset_lse = query_start_unpadded;
|
||||
batch_offset_lse = query_start;
|
||||
}
|
||||
batch_offset_o = query_start_padded * kargs.stride_o;
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
@@ -571,18 +417,6 @@ struct FmhaFwdV3Kernel
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
|
||||
// If cumulative seqlen pointers are provided, override per-batch effective lengths
|
||||
if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
if(kargs.cu_seqlen_kv_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
|
||||
Reference in New Issue
Block a user