From 2867294db336689c3cd7bd0f79450f116ab50eda Mon Sep 17 00:00:00 2001
From: yashagar-amd
Date: Fri, 5 Jun 2026 11:04:59 +0000
Subject: [PATCH] fmha: fix coredump issue in small seqlen

PER_TOKEN_HEAD batch prefill reads q_descale_per_token for a full kM0
query tile. For small sequence lengths or chunked-prefill tail chunks,
some tile rows are padding rows beyond the valid query range, but the
kernel still indexed q_descale_per_token_ptr for those rows.

In group mode, q_descale_ptr also missed the per-batch query_start
offset, so later batches could read descales from the wrong position.

Fix the issue by offsetting q_descale_ptr with query_start, exposing the
mask's valid query row count, and guarding q_descale_per_token loads for
rows outside y_total.

Patch credits: Zhen Han
---
Reviewer note (ignored by git am): this copy was restored from a
flattened rendering of the patch. Line breaks, the <const float*>
arguments of reinterpret_cast, and the indentation of every hunk line
were reconstructed and are not guaranteed to match the tree; run
`git apply --check` (optionally with --ignore-whitespace) before use.
The From: header carries no e-mail address and may need to be completed
before `git am` accepts the mail.

 include/ck_tile/ops/fmha/block/block_masking.hpp  | 15 +++++++++++++++
 .../ops/fmha/kernel/fmha_batch_prefill_kernel.hpp |  7 +++++--
 ...fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp | 14 +++++++++++---
 3 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp
index 134cb6acbb..dc5cd9af1c 100644
--- a/include/ck_tile/ops/fmha/block/block_masking.hpp
+++ b/include/ck_tile/ops/fmha/block/block_masking.hpp
@@ -310,6 +310,11 @@ struct GenericAttentionMask
         }
     }
 
+    CK_TILE_HOST_DEVICE constexpr auto GetYTotal() const
+    {
+        return y_total;
+    }
+
     private:
     index_t y, x, sink;
     index_t y_total, x_total;
@@ -549,6 +554,11 @@ struct SimplifiedGenericAttentionMask
         }
     }
 
+    CK_TILE_HOST_DEVICE constexpr auto GetYTotal() const
+    {
+        return y_total;
+    }
+
     private:
     index_t y, x, sink;
     index_t y_total, x_total;
@@ -735,6 +745,11 @@ struct SimplifiedRatioAttentionMask
         }
     }
 
+    CK_TILE_HOST_DEVICE constexpr auto GetYTotal() const
+    {
+        return y_total;
+    }
+
     private:
     index_t y, x;
     index_t y_total, x_total;
diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
index ce6899e586..3a5b30760d 100644
--- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
+++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
@@ -868,10 +868,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
             }
         }();
 
+        long_index_t query_start = 0;
         if constexpr(kIsGroupMode)
         {
             // get starting offset for each batch
-            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
+            query_start = kargs.seqstart_q_ptr[i_batch];
 
             batch_offset_q = query_start * kargs.stride_q;
 
@@ -1459,7 +1460,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
                 assert(kargs.q_descale_ptr != nullptr);
                 assert(kargs.k_descale_ptr != nullptr);
                 assert(kargs.v_descale_ptr != nullptr);
-                const float* q_descale_ptr = reinterpret_cast<const float*>(kargs.q_descale_ptr);
+                const float* q_descale_ptr =
+                    reinterpret_cast<const float*>(kargs.q_descale_ptr) +
+                    query_start * kargs.stride_q_descale_token;
                 const float* k_descale_ptr = reinterpret_cast<const float*>(kargs.k_descale_ptr);
                 const float* v_descale_ptr = reinterpret_cast<const float*>(kargs.v_descale_ptr);
 
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
index 74f496c6dd..4690a9e361 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
@@ -1232,11 +1232,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                 __builtin_amdgcn_sched_barrier(0);
 
                 // Q-row descales (kM0 entries).
+                // Guard against small-seqlen tiles whose tail rows fall past the
+                // valid query range: rows >= y_total have no per-token descale and
+                // must not read out of bounds. Stage 0.0f for them (their masked
+                // s_acc lanes are dropped before softmax anyway).
+                const index_t q_row_total = mask.GetYTotal();
                 for(index_t off = tid_in_block; off < kM0; off += threads_per_block)
                 {
-                    lds_q_descale[off] = q_descale_per_token_ptr[
-                        (q_row_base + off) * stride_q_descale_token +
-                        qo_head * nhead_stride_q_descale];
+                    const index_t q_row = q_row_base + off;
+                    lds_q_descale[off] =
+                        q_row < q_row_total
+                            ? q_descale_per_token_ptr[q_row * stride_q_descale_token +
+                                                      qo_head * nhead_stride_q_descale]
+                            : 0.0f;
                 }
 
                 // K-col descales (kN0 entries).