fmha: fix coredump issue in small seqlen

PER_TOKEN_HEAD batch prefill reads q_descale_per_token for a full kM0
query tile. For small sequence lengths or chunked-prefill tail chunks,
some tile rows are padding rows beyond the valid query range, but the
kernel still indexed q_descale_per_token_ptr for those rows. In group
mode, q_descale_ptr also missed the per-batch query_start offset, so
later batches could read descales from the wrong position.

Fix the issue by offsetting q_descale_ptr with query_start, exposing
the mask's valid query row count, and guarding q_descale_per_token
loads for rows outside y_total.

Patch credits: Zhen Han <zhen.han@amd.com>
This commit is contained in:
yashagar-amd
2026-06-05 11:04:59 +00:00
parent c8c0eaf982
commit 2867294db3
3 changed files with 31 additions and 5 deletions

View File

@@ -310,6 +310,11 @@ struct GenericAttentionMask
}
}
CK_TILE_HOST_DEVICE constexpr auto GetYTotal() const
{
return y_total;
}
private:
index_t y, x, sink;
index_t y_total, x_total;
@@ -549,6 +554,11 @@ struct SimplifiedGenericAttentionMask
}
}
CK_TILE_HOST_DEVICE constexpr auto GetYTotal() const
{
return y_total;
}
private:
index_t y, x, sink;
index_t y_total, x_total;
@@ -735,6 +745,11 @@ struct SimplifiedRatioAttentionMask
}
}
CK_TILE_HOST_DEVICE constexpr auto GetYTotal() const
{
return y_total;
}
private:
index_t y, x;
index_t y_total, x_total;

View File

@@ -868,10 +868,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
}
}();
long_index_t query_start = 0;
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
@@ -1459,7 +1460,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
assert(kargs.q_descale_ptr != nullptr);
assert(kargs.k_descale_ptr != nullptr);
assert(kargs.v_descale_ptr != nullptr);
const float* q_descale_ptr = reinterpret_cast<const float*>(kargs.q_descale_ptr);
const float* q_descale_ptr =
reinterpret_cast<const float*>(kargs.q_descale_ptr) +
query_start * kargs.stride_q_descale_token;
const float* k_descale_ptr = reinterpret_cast<const float*>(kargs.k_descale_ptr);
const float* v_descale_ptr = reinterpret_cast<const float*>(kargs.v_descale_ptr);

View File

@@ -1232,11 +1232,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(0);
// Q-row descales (kM0 entries).
// Guard against small-seqlen tiles whose tail rows fall past the
// valid query range: rows >= y_total have no per-token descale and
// must not read out of bounds. Stage 0.0f for them (their masked
// s_acc lanes are dropped before softmax anyway).
const index_t q_row_total = mask.GetYTotal();
for(index_t off = tid_in_block; off < kM0; off += threads_per_block)
{
lds_q_descale[off] = q_descale_per_token_ptr[
(q_row_base + off) * stride_q_descale_token +
qo_head * nhead_stride_q_descale];
const index_t q_row = q_row_base + off;
lds_q_descale[off] =
q_row < q_row_total
? q_descale_per_token_ptr[q_row * stride_q_descale_token +
qo_head * nhead_stride_q_descale]
: 0.0f;
}
// K-col descales (kN0 entries).