mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
PER_TOKEN_HEAD batch prefill reads q_descale_per_token for a full kM0 query tile. For small sequence lengths or chunked-prefill tail chunks, some tile rows are padding rows beyond the valid query range, but the kernel still indexed q_descale_per_token_ptr for those rows.

In group mode, q_descale_ptr also missed the per-batch query_start offset, so later batches could read descales from the wrong position.

Fix both issues by:
- offsetting q_descale_ptr with query_start,
- exposing the mask's valid query row count,
- guarding q_descale_per_token loads for rows outside y_total.

Patch credits: Zhen Han <zhen.han@amd.com>
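
The two fixes can be illustrated with a small host-side sketch. This is not the kernel code from the patch: the function name load_q_descale_tile, its parameters, and the choice of 1.0f as the value for padding rows are all assumptions made for illustration. It only models the indexing: the base pointer is offset by query_start, and rows at or beyond the valid row count are never read.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side model of loading per-token Q descales for one tile.
//   q_descale   : descales for all tokens of all batches, packed back to back
//                 (group mode layout, assumed here)
//   query_start : index of this batch's first token in the packed buffer
//   valid_rows  : number of real query rows in this batch (y_total in the text)
//   tile_row0   : first row of the tile, relative to the batch
//   kM0         : tile height in rows
std::vector<float> load_q_descale_tile(const std::vector<float>& q_descale,
                                       std::size_t query_start,
                                       std::size_t valid_rows,
                                       std::size_t tile_row0,
                                       std::size_t kM0)
{
    assert(query_start + valid_rows <= q_descale.size());

    // Fix 1: offset the base pointer so each batch reads its own descales.
    const float* batch_ptr = q_descale.data() + query_start;

    // Padding rows keep a neutral scale (assumed value, not from the patch).
    std::vector<float> tile(kM0, 1.0f);

    for (std::size_t i = 0; i < kM0; ++i)
    {
        const std::size_t row = tile_row0 + i;
        // Fix 2: only load rows inside the valid query range.
        if (row < valid_rows)
        {
            tile[i] = batch_ptr[row];
        }
    }
    return tile;
}
```

With a packed buffer holding two tokens for batch 0 and three for batch 1, calling load_q_descale_tile(buf, 2, 3, 0, 4) returns the three descales of batch 1 followed by one padding value, without touching memory past the end of the buffer.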