[CK_TILE] fmha: Add query padding support to backward pass (#3097)

* [CK_TILE] fmha: Add query padding support to backward pass

Introduce support for query sequence padding (q_padding) in the FMHA backward pass kernels:
- Pass `seqlen_q_ptr` to the backward kernels to distinguish logical from physical sequence lengths (see the sketch after this list).
- Update the `OGradDotO`, `ConvertQGrad`, and `DQDKDV` kernels to respect logical lengths and handle zero-length sequences.
- Align LSE indexing in the forward kernel with the padded layout for consistency.
- Add a new GTest suite (`test_fmha_bwd_kernel_padding.cpp`) with comprehensive tests for various padding scenarios, including zero-length sequences and deterministic mode.
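
As a minimal host-side sketch of the logical-vs-physical distinction (the array contents are illustrative, and the variable names simply mirror the kernel arguments rather than any runner API): `seqstart_q` describes where each sequence physically lives in the packed buffer, padding included, while `seqlen_q` gives the logical number of tokens the kernels should actually process.

#include <cstdint>
#include <vector>

// Illustrative sketch; not part of the CK codebase.
int main()
{
    // Three sequences, each physically padded to 128 tokens in the packed
    // Q buffer, so the physical starts (PAD included) land on multiples of 128.
    std::vector<int32_t> seqstart_q = {0, 128, 256, 384}; // [batch+1]

    // Logical (unpadded) length of each sequence; sequence 1 is empty, which
    // the backward kernels now detect and skip instead of computing on PAD.
    std::vector<int32_t> seqlen_q = {100, 0, 77}; // [batch]

    // Device copies of these arrays are what reach the kernels:
    // seqstart_q_ptr drives addressing, seqlen_q_ptr drives loop bounds.
    return 0;
}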

* Fix clang-format

* Adapt fmha_bwd_runner.cpp to the new q/kv sequence padding interface
Add backward q/kv sequence padding unit tests (a sketch of the test shape follows).
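
A hedged sketch of the shape such a test can take (the test and helper names below are illustrative, not the actual contents of test_fmha_bwd_kernel_padding.cpp):

#include <gtest/gtest.h>

#include <cstdint>
#include <vector>

// Illustrative sketch; not part of the CK codebase.
// Reference helper: logical length of batch i from a cumulative-length array.
static int32_t logical_len(const std::vector<int32_t>& cu_seqlen, int i)
{
    return cu_seqlen[i + 1] - cu_seqlen[i];
}

TEST(FmhaBwdKernelPaddingSketch, ZeroLengthSequenceIsSkipped)
{
    // Physical layout pads every sequence to 128 tokens.
    const std::vector<int32_t> seqstart_q = {0, 128, 256, 384};
    // Logical cumulative lengths; batch 1 is empty.
    const std::vector<int32_t> cu_seqlen_q = {0, 100, 100, 177};

    EXPECT_EQ(logical_len(cu_seqlen_q, 0), 100);
    // The backward kernels early-return for this batch (seqlen_q == 0).
    EXPECT_EQ(logical_len(cu_seqlen_q, 1), 0);
    EXPECT_EQ(logical_len(cu_seqlen_q, 2), 77);

    // Every logical length must fit inside its physical slot.
    for(int i = 0; i < 3; ++i)
        EXPECT_LE(logical_len(cu_seqlen_q, i), seqstart_q[i + 1] - seqstart_q[i]);
}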

* [CK_TILE] fmha: Unify sequence length and padding handling

Refactor the handling of sequence lengths and padding in the
FMHA forward and backward kernels to provide a more unified and flexible
interface.

- Replaced `seqstart_padded_*_ptr` with a more robust scheme that uses
  `seqstart_*_ptr` for physical sequence starts (offsets, from which
  physical lengths are derived) and introduces `seqlen_*_ptr` and
  `cu_seqlen_*_ptr` for logical (unpadded) lengths.
- Established a clear order of precedence for determining the logical
  sequence length: cumulative lengths (`cu_seqlen_*_ptr`) take priority,
  followed by per-sequence lengths (`seqlen_*_ptr`), and finally
  physical lengths derived from `seqstart_*_ptr` (see the sketch after
  this list).
- Clarified the distinction between "group mode" and "batch mode" and
  how sequence lengths are handled in each case.
- Renamed `cu_seqlen_kv_ptr` to `cu_seqlen_k_ptr` for consistency.
- Updated comments and documentation to reflect the new argument
  structure and usage.
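
A standalone sketch of that precedence rule (the free function and its name are illustrative; in the kernels this logic lives inline in each kernel's group-mode argument setup):

#include <cstdint>

// Illustrative sketch; not part of the CK codebase.
using index_t = int32_t; // stand-in for ck_tile::index_t

// Resolve the logical sequence length of batch i_batch.
// Priority: cu_seqlen_ptr > seqlen_ptr > physical length from seqstart_ptr.
index_t resolve_seqlen(index_t i_batch,
                       const int32_t* cu_seqlen_ptr, // cumulative [batch+1], optional
                       const int32_t* seqlen_ptr,    // per-sequence [batch], optional
                       const int32_t* seqstart_ptr)  // physical starts [batch+1]
{
    if(cu_seqlen_ptr != nullptr)
        return cu_seqlen_ptr[i_batch + 1] - cu_seqlen_ptr[i_batch];
    if(seqlen_ptr != nullptr)
        return seqlen_ptr[i_batch];
    return seqstart_ptr[i_batch + 1] - seqstart_ptr[i_batch];
}

One subtlety visible in the diffs below: the backward kernels check cu_seqlen_*_ptr first, as above, while the forward kernel's group-mode path checks seqlen_*_ptr before cu_seqlen_*_ptr.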

---------

Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Author: Jeff Huang
Committed: 2025-10-29 13:56:11 +08:00 (via GitHub)
Parent: 13e13ce359
Commit: 7c6430eca0
11 changed files with 1292 additions and 214 deletions

File 1 of 2 shown: FMHA backward kernels (FmhaBwdDQDKDVKernel, FmhaBwdOGradDotOKernel, FmhaBwdConvertQGradKernel)

@@ -313,7 +313,10 @@ struct FmhaBwdDQDKDVKernel
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
-        const int32_t* seqlen_k_ptr;
+        const int32_t* seqlen_q_ptr;    // per-batch actual length [batch]
+        const int32_t* seqlen_k_ptr;    // per-batch actual length [batch]
+        const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
+        const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional
     };
     using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>;
@@ -520,7 +523,10 @@ struct FmhaBwdDQDKDVKernel
         void* dq_acc_ptr,
         const void* seqstart_q_ptr,
         const void* seqstart_k_ptr,
+        const void* seqlen_q_ptr,
         const void* seqlen_k_ptr,
+        const void* cu_seqlen_q_ptr,
+        const void* cu_seqlen_k_ptr,
         ck_tile::index_t hdim_q,
         ck_tile::index_t hdim_v,
         ck_tile::index_t num_head_q,
@@ -594,7 +600,10 @@ struct FmhaBwdDQDKDVKernel
             {}, // placeholder for deterministic
             reinterpret_cast<const int32_t*>(seqstart_q_ptr),
             reinterpret_cast<const int32_t*>(seqstart_k_ptr),
-            reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
+            reinterpret_cast<const int32_t*>(seqlen_q_ptr),
+            reinterpret_cast<const int32_t*>(seqlen_k_ptr),
+            reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr),
+            reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr)};
         if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
         {
@@ -736,10 +745,29 @@ struct FmhaBwdDQDKDVKernel
                 batch_offset_randval = query_start * kargs.stride_randval;
             }
-            // get real # queries & # keys under group mode
-            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
-            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
-            if(kargs.seqlen_k_ptr != nullptr)
+            // Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
+            if(kargs.cu_seqlen_q_ptr != nullptr)
             {
+                kargs.seqlen_q =
+                    kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
+            }
+            else
+            {
+                // get real # queries & # keys under group mode
+                const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
+                const ck_tile::index_t physical_seqlen_q =
+                    adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
+                kargs.seqlen_q =
+                    kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q;
+            }
+            // Priority: cu_seqlen_k_ptr > seqlen_k_ptr > seqstart_k
+            if(kargs.cu_seqlen_k_ptr != nullptr)
+            {
+                kargs.seqlen_k =
+                    kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
+            }
+            else if(kargs.seqlen_k_ptr != nullptr)
             {
                 kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
             }
@@ -749,6 +777,12 @@ struct FmhaBwdDQDKDVKernel
                 kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
             }
+            // skip if logical lengths are zero
+            if(kargs.seqlen_q == 0 || kargs.seqlen_k == 0)
+            {
+                return;
+            }
             // # of required blocks is different in each groups, terminate unnecessary blocks
             // earlier
             if constexpr(!kUseQrQtrDorPipeline)
@@ -1246,6 +1280,8 @@ struct FmhaBwdOGradDotOKernel
     struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
     {
         const int32_t* seqstart_q_ptr;
+        const int32_t* seqlen_q_ptr;    // per-batch actual length [batch]
+        const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
     };
     using Kargs = std::
@@ -1293,6 +1329,8 @@ struct FmhaBwdOGradDotOKernel
         void* d_ptr,
         float p_undrop,
         const void* seqstart_q_ptr,
+        const void* seqlen_q_ptr,
+        const void* cu_seqlen_q_ptr,
         ck_tile::index_t hdim_v,
         ck_tile::index_t stride_do,
         ck_tile::index_t stride_o,
@@ -1311,7 +1349,9 @@ struct FmhaBwdOGradDotOKernel
             nhead_stride_do,
             nhead_stride_o,
             nhead_stride_d},
-            reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
+            reinterpret_cast<const int32_t*>(seqstart_q_ptr),
+            reinterpret_cast<const int32_t*>(seqlen_q_ptr),
+            reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr)};
         return kargs;
     }
@@ -1355,9 +1395,23 @@ struct FmhaBwdOGradDotOKernel
             batch_offset_do = query_start * kargs.stride_do;
             batch_offset_d = query_start;
-            // get real # queries & # keys under group mode
-            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
-            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
+            // Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
+            if(kargs.cu_seqlen_q_ptr != nullptr)
+            {
+                kargs.seqlen_q =
+                    kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
+            }
+            else
+            {
+                // get real # queries & # keys under group mode
+                const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
+                const ck_tile::index_t physical_seqlen_q =
+                    adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
+                kargs.seqlen_q = kargs.seqlen_q_ptr
+                                     ? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
+                                     : physical_seqlen_q;
+            }
             // # of required blocks is different in each groups, terminate unnecessary blocks
             // earlier
             if(kargs.seqlen_q <= i_m0)
@@ -1521,6 +1575,10 @@ struct FmhaBwdConvertQGradKernel
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
+        const int32_t* seqlen_q_ptr;    // per-batch actual length [batch]
+        const int32_t* seqlen_k_ptr;    // per-batch actual length [batch]
+        const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
+        const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional
     };
     using Kargs = std::conditional_t<kIsGroupMode,
@@ -1569,6 +1627,10 @@ struct FmhaBwdConvertQGradKernel
         void* dq_ptr,
         const void* seqstart_q_ptr,
         const void* seqstart_k_ptr,
+        const void* seqlen_q_ptr,
+        const void* seqlen_k_ptr,
+        const void* cu_seqlen_q_ptr,
+        const void* cu_seqlen_k_ptr,
         ck_tile::index_t hdim_q,
         ck_tile::index_t stride_dq,
         ck_tile::index_t stride_dq_acc,
@@ -1587,7 +1649,11 @@ struct FmhaBwdConvertQGradKernel
             nhead_stride_dq_acc},
             {},
             reinterpret_cast<const int32_t*>(seqstart_q_ptr),
-            reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
+            reinterpret_cast<const int32_t*>(seqstart_k_ptr),
+            reinterpret_cast<const int32_t*>(seqlen_q_ptr),
+            reinterpret_cast<const int32_t*>(seqlen_k_ptr),
+            reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr),
+            reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr)};
         if constexpr(kIsDeterministic)
         {
@@ -1632,13 +1698,41 @@ struct FmhaBwdConvertQGradKernel
             batch_offset_dq = query_start * kargs.stride_dq;
             batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
-            // get real # queries & # keys under group mode
-            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
-            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
+            if(kargs.cu_seqlen_q_ptr != nullptr)
+            {
+                kargs.seqlen_q =
+                    kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
+            }
+            else
+            {
+                // get real # queries & # keys under group mode
+                const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
+                const ck_tile::index_t physical_seqlen_q =
+                    adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
+                kargs.seqlen_q = kargs.seqlen_q_ptr
+                                     ? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
+                                     : physical_seqlen_q;
+            }
             if constexpr(kIsDeterministic)
             {
                 const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
-                kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
+                const ck_tile::index_t physical_seqlen_k =
+                    adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
+                // Priority: cu_seqlen_k_ptr > seqlen_k_ptr > physical_seqlen_k
+                if(kargs.cu_seqlen_k_ptr != nullptr)
+                {
+                    kargs.seqlen_k =
+                        kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
+                }
+                else
+                {
+                    kargs.seqlen_k =
+                        kargs.seqlen_k_ptr
+                            ? static_cast<ck_tile::index_t>(kargs.seqlen_k_ptr[i_batch])
+                            : physical_seqlen_k;
+                }
             }
             // # of required blocks is different in each groups, terminate unnecessary blocks
             // earlier

File 2 of 2 shown: FMHA forward kernel (FmhaFwdKernel)

@@ -296,8 +296,8 @@ struct FmhaFwdKernel
         // Optional cumulative sequence length pointers for batch mode
         // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
-        const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
-        const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD
+        const int32_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
+        const int32_t* cu_seqlen_k_ptr = nullptr; // cumulative, length without PAD
     };
     struct FmhaFwdGroupModeKargs
@@ -316,12 +316,12 @@ struct FmhaFwdKernel
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
+        const int32_t* seqlen_q_ptr;
         const int32_t* seqlen_k_ptr;
-        // Optional cumulative padded sequence starts (including PAD tokens)
-        // Used solely to compute memory offsets when sequences are physically padded.
-        const int32_t* seqstart_padded_q_ptr = nullptr;
-        const int32_t* seqstart_padded_k_ptr = nullptr;
+        // Optional per-sequence and cumulative logical (excluding padding) sequence length arrays
+        const int32_t* cu_seqlen_q_ptr = nullptr;
+        const int32_t* cu_seqlen_k_ptr = nullptr;
     };
     using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
@@ -379,8 +379,8 @@ struct FmhaFwdKernel
         bool s_randval,
         std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
             drop_seed_offset,
-        const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
-        const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
+        const void* cu_seqlen_q_ptr = nullptr,
+        const void* cu_seqlen_k_ptr = nullptr)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
@@ -471,8 +471,8 @@ struct FmhaFwdKernel
             kargs.init_logits_soft_cap(logits_soft_cap);
         }
-        kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
-        kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
+        kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
+        kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
         return kargs;
     }
@@ -522,8 +522,8 @@ struct FmhaFwdKernel
         float p_drop,
         bool s_randval,
         const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
-        const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
-        const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
+        const void* cu_seqlen_q_ptr = nullptr,
+        const void* cu_seqlen_k_ptr = nullptr)
     {
         return MakeKargsImpl(
             q_ptr,
@@ -570,7 +570,7 @@ struct FmhaFwdKernel
             s_randval,
             std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
             cu_seqlen_q_ptr,
-            cu_seqlen_kv_ptr);
+            cu_seqlen_k_ptr);
     }
// std::variant<> can't take in a list initializer, overload for backward compatibility
@@ -619,8 +619,8 @@ struct FmhaFwdKernel
         float p_drop,
         bool s_randval,
         const std::tuple<const void*, const void*>& drop_seed_offset,
-        const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
-        const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
+        const void* cu_seqlen_q_ptr = nullptr,
+        const void* cu_seqlen_k_ptr = nullptr)
     {
         return MakeKargsImpl(
             q_ptr,
@@ -667,7 +667,7 @@ struct FmhaFwdKernel
             s_randval,
             std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
             cu_seqlen_q_ptr,
-            cu_seqlen_kv_ptr);
+            cu_seqlen_k_ptr);
     }
template <bool Cond = kIsGroupMode>
@@ -681,6 +681,7 @@ struct FmhaFwdKernel
         void* o_ptr,
         const void* seqstart_q_ptr,
         const void* seqstart_k_ptr,
+        const void* seqlen_q_ptr,
         const void* seqlen_k_ptr,
         ck_tile::index_t hdim_q,
         ck_tile::index_t hdim_v,
@@ -711,8 +712,8 @@ struct FmhaFwdKernel
         bool s_randval,
         std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
             drop_seed_offset,
-        const void* seqstart_padded_q_ptr = nullptr,
-        const void* seqstart_padded_k_ptr = nullptr)
+        const void* cu_seqlen_q_ptr = nullptr,
+        const void* cu_seqlen_k_ptr = nullptr)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
@@ -746,6 +747,7 @@ struct FmhaFwdKernel
             {}, // placeholder for min_seqlen_q
             reinterpret_cast<const int32_t*>(seqstart_q_ptr),
             reinterpret_cast<const int32_t*>(seqstart_k_ptr),
+            reinterpret_cast<const int32_t*>(seqlen_q_ptr),
             reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
         if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
@@ -804,8 +806,8 @@ struct FmhaFwdKernel
             kargs.min_seqlen_q = min_seqlen_q;
         }
-        kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
-        kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
+        kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
+        kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
         return kargs;
     }
@@ -821,6 +823,7 @@ struct FmhaFwdKernel
         void* o_ptr,
         const void* seqstart_q_ptr,
         const void* seqstart_k_ptr,
+        const void* seqlen_q_ptr,
         const void* seqlen_k_ptr,
         ck_tile::index_t hdim_q,
         ck_tile::index_t hdim_v,
@@ -850,8 +853,8 @@ struct FmhaFwdKernel
         float p_drop,
         bool s_randval,
         const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
-        const void* seqstart_padded_q_ptr = nullptr,
-        const void* seqstart_padded_k_ptr = nullptr)
+        const void* cu_seqlen_q_ptr = nullptr,
+        const void* cu_seqlen_k_ptr = nullptr)
     {
         return MakeKargsImpl(
             q_ptr,
@@ -863,6 +866,7 @@ struct FmhaFwdKernel
             o_ptr,
             seqstart_q_ptr,
             seqstart_k_ptr,
+            seqlen_q_ptr,
             seqlen_k_ptr,
             hdim_q,
             hdim_v,
@@ -892,8 +896,8 @@ struct FmhaFwdKernel
             p_drop,
             s_randval,
             std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
-            seqstart_padded_q_ptr,
-            seqstart_padded_k_ptr);
+            cu_seqlen_q_ptr,
+            cu_seqlen_k_ptr);
     }
// std::variant<> can't take in a list initializer, overload for backward compatibility
@@ -908,6 +912,7 @@ struct FmhaFwdKernel
         void* o_ptr,
         const void* seqstart_q_ptr,
         const void* seqstart_k_ptr,
+        const void* seqlen_q_ptr,
         const void* seqlen_k_ptr,
         ck_tile::index_t hdim_q,
         ck_tile::index_t hdim_v,
@@ -937,8 +942,8 @@ struct FmhaFwdKernel
         float p_drop,
         bool s_randval,
         const std::tuple<const void*, const void*>& drop_seed_offset,
-        const void* seqstart_padded_q_ptr = nullptr,
-        const void* seqstart_padded_k_ptr = nullptr)
+        const void* cu_seqlen_q_ptr = nullptr,
+        const void* cu_seqlen_k_ptr = nullptr)
     {
         return MakeKargsImpl(
             q_ptr,
@@ -950,6 +955,7 @@ struct FmhaFwdKernel
             o_ptr,
             seqstart_q_ptr,
             seqstart_k_ptr,
+            seqlen_q_ptr,
             seqlen_k_ptr,
             hdim_q,
             hdim_v,
@@ -979,8 +985,8 @@ struct FmhaFwdKernel
             p_drop,
             s_randval,
             std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
-            seqstart_padded_q_ptr,
-            seqstart_padded_k_ptr);
+            cu_seqlen_q_ptr,
+            cu_seqlen_k_ptr);
     }
     CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
@@ -1109,46 +1115,52 @@ struct FmhaFwdKernel
         if constexpr(kIsGroupMode)
         {
-            // logical and physical (padded) starts
-            const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
-            const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
+            // Use seqstart_q_ptr and seqstart_k_ptr for physical starts
+            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
+            const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
-            const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
-                                                        ? kargs.seqstart_padded_q_ptr[i_batch]
-                                                        : query_start_unpadded;
-            const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
-                                                      ? kargs.seqstart_padded_k_ptr[i_batch]
-                                                      : key_start_unpadded;
-            // DRAM base offsets use physical padded starts
-            batch_offset_q = query_start_padded * kargs.stride_q;
-            batch_offset_k = key_start_padded * kargs.stride_k;
+            // DRAM base offsets use physical starts
+            batch_offset_q = query_start * kargs.stride_q;
+            batch_offset_k = key_start * kargs.stride_k;
             if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
             {
-                batch_offset_v = key_start_padded * kargs.stride_v;
+                batch_offset_v = key_start * kargs.stride_v;
             }
             else
             {
-                batch_offset_v = key_start_padded;
+                batch_offset_v = key_start;
             }
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
-                batch_offset_bias = query_start_padded * kargs.stride_bias;
+                batch_offset_bias = query_start * kargs.stride_bias;
             }
             if constexpr(kStoreLSE)
             {
-                // LSE stays indexed by unpadded starts
-                batch_offset_lse = query_start_unpadded;
+                // LSE follows the physical layout to stay consistent with other tensors
+                batch_offset_lse = query_start;
             }
             if constexpr(kHasDropout)
             {
-                batch_offset_randval = query_start_padded * kargs.stride_randval;
+                batch_offset_randval = query_start * kargs.stride_randval;
             }
-            batch_offset_o = query_start_padded * kargs.stride_o;
+            batch_offset_o = query_start * kargs.stride_o;
-            // real logical lengths (exclude PAD)
-            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
-            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
+            // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
+            if(kargs.seqlen_q_ptr != nullptr)
+            {
+                kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
+            }
+            else if(kargs.cu_seqlen_q_ptr != nullptr)
+            {
+                kargs.seqlen_q =
+                    kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
+            }
+            else
+            {
+                const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
+                kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
+            }
             if constexpr(kSkipMinSeqlenQ)
             {
@@ -1168,6 +1180,11 @@ struct FmhaFwdKernel
             {
                 kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
             }
+            else if(kargs.cu_seqlen_k_ptr != nullptr)
+            {
+                kargs.seqlen_k =
+                    kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
+            }
             else
             {
                 const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
@@ -1201,10 +1218,10 @@ struct FmhaFwdKernel
                 kargs.seqlen_q =
                     kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
             }
-            if(kargs.cu_seqlen_kv_ptr != nullptr)
+            if(kargs.cu_seqlen_k_ptr != nullptr)
             {
                 kargs.seqlen_k =
-                    kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
+                    kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
             }
         }
@@ -1603,39 +1620,46 @@ struct FmhaFwdKernel
         if constexpr(kIsGroupMode)
         {
-            // get starting offset for each batch
-            const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
-            const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
+            // get starting offset for each batch - use seqstart_q_ptr/seqstart_k_ptr for
+            // physical starts
+            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
+            const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
-            const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
-                                                        ? kargs.seqstart_padded_q_ptr[i_batch]
-                                                        : query_start_unpadded;
-            const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
-                                                      ? kargs.seqstart_padded_k_ptr[i_batch]
-                                                      : key_start_unpadded;
-            batch_offset_q = query_start_padded * kargs.stride_q;
-            batch_offset_k = key_start_padded * kargs.stride_k;
+            batch_offset_q = query_start * kargs.stride_q;
+            batch_offset_k = key_start * kargs.stride_k;
             if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
             {
-                batch_offset_v = key_start_padded * kargs.stride_v;
+                batch_offset_v = key_start * kargs.stride_v;
             }
             else
             {
                 // col-major V: offset along seqlen dimension is scalar index
-                batch_offset_v = key_start_padded;
+                batch_offset_v = key_start;
             }
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
-                batch_offset_bias = query_start_padded * kargs.stride_bias;
+                batch_offset_bias = query_start * kargs.stride_bias;
             }
-            // LSE layout is [nhead, total_seqlen], index by unpadded start
-            batch_offset_lse = query_start_unpadded;
-            batch_offset_o = query_start_padded * kargs.stride_o;
+            // LSE layout is [nhead, total_seqlen] following the physical layout for Q/O
+            batch_offset_lse = query_start;
+            batch_offset_o = query_start * kargs.stride_o;
             // get real # queries & # keys under group mode
-            kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
+            if(kargs.seqlen_q_ptr != nullptr)
+            {
+                kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
+            }
+            else if(kargs.cu_seqlen_q_ptr != nullptr)
+            {
+                kargs.seqlen_q =
+                    kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
+            }
+            else
+            {
+                kargs.seqlen_q =
+                    kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
+            }
             // # of required blocks is different in each groups, terminate unnecessary blocks
             // earlier
@@ -1648,6 +1672,11 @@ struct FmhaFwdKernel
             {
                 kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
             }
+            else if(kargs.cu_seqlen_k_ptr != nullptr)
+            {
+                kargs.seqlen_k =
+                    kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
+            }
             else
             {
                 kargs.seqlen_k =
@@ -1677,10 +1706,10 @@ struct FmhaFwdKernel
                 kargs.seqlen_q =
                     kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
             }
-            if(kargs.cu_seqlen_kv_ptr != nullptr)
+            if(kargs.cu_seqlen_k_ptr != nullptr)
            {
                 kargs.seqlen_k =
-                    kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
+                    kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
             }
         }