mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
fmha: fix coredump issue in small seqlen
PER_TOKEN_HEAD batch prefill reads q_descale_per_token for a full kM0 query tile. For small sequence lengths or chunked-prefill tail chunks, some tile rows are padding rows beyond the valid query range, but the kernel still indexed q_descale_per_token_ptr for those rows. In group mode, q_descale_ptr also missed the per-batch query_start offset, so later batches could read descales from the wrong position. Fix the issue by offsetting q_descale_ptr with query_start, exposing the mask's valid query row count, and guarding q_descale_per_token loads for rows outside y_total. Patch credits: Zhen Han <zhen.han@amd.com>
This commit is contained in:
@@ -310,6 +310,11 @@ struct GenericAttentionMask
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto GetYTotal() const
|
||||
{
|
||||
return y_total;
|
||||
}
|
||||
|
||||
private:
|
||||
index_t y, x, sink;
|
||||
index_t y_total, x_total;
|
||||
@@ -549,6 +554,11 @@ struct SimplifiedGenericAttentionMask
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto GetYTotal() const
|
||||
{
|
||||
return y_total;
|
||||
}
|
||||
|
||||
private:
|
||||
index_t y, x, sink;
|
||||
index_t y_total, x_total;
|
||||
@@ -735,6 +745,11 @@ struct SimplifiedRatioAttentionMask
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto GetYTotal() const
|
||||
{
|
||||
return y_total;
|
||||
}
|
||||
|
||||
private:
|
||||
index_t y, x;
|
||||
index_t y_total, x_total;
|
||||
|
||||
@@ -868,10 +868,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
}
|
||||
}();
|
||||
|
||||
long_index_t query_start = 0;
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
|
||||
@@ -1459,7 +1460,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
assert(kargs.q_descale_ptr != nullptr);
|
||||
assert(kargs.k_descale_ptr != nullptr);
|
||||
assert(kargs.v_descale_ptr != nullptr);
|
||||
const float* q_descale_ptr = reinterpret_cast<const float*>(kargs.q_descale_ptr);
|
||||
const float* q_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.q_descale_ptr) +
|
||||
query_start * kargs.stride_q_descale_token;
|
||||
const float* k_descale_ptr = reinterpret_cast<const float*>(kargs.k_descale_ptr);
|
||||
const float* v_descale_ptr = reinterpret_cast<const float*>(kargs.v_descale_ptr);
|
||||
|
||||
|
||||
@@ -1232,11 +1232,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Q-row descales (kM0 entries).
|
||||
// Guard against small-seqlen tiles whose tail rows fall past the
|
||||
// valid query range: rows >= y_total have no per-token descale and
|
||||
// must not read out of bounds. Stage 0.0f for them (their masked
|
||||
// s_acc lanes are dropped before softmax anyway).
|
||||
const index_t q_row_total = mask.GetYTotal();
|
||||
for(index_t off = tid_in_block; off < kM0; off += threads_per_block)
|
||||
{
|
||||
lds_q_descale[off] = q_descale_per_token_ptr[
|
||||
(q_row_base + off) * stride_q_descale_token +
|
||||
qo_head * nhead_stride_q_descale];
|
||||
const index_t q_row = q_row_base + off;
|
||||
lds_q_descale[off] =
|
||||
q_row < q_row_total
|
||||
? q_descale_per_token_ptr[q_row * stride_q_descale_token +
|
||||
qo_head * nhead_stride_q_descale]
|
||||
: 0.0f;
|
||||
}
|
||||
|
||||
// K-col descales (kN0 entries).
|
||||
|
||||
Reference in New Issue
Block a user