From fbae2aba205738c0f67d5c00d4bbf732012f3e01 Mon Sep 17 00:00:00 2001 From: ltqin Date: Thu, 29 Jan 2026 09:46:19 +0000 Subject: [PATCH] fix issue --- example/ck_tile/01_fmha/fmha_fwd.hpp | 2 +- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 43 ++++++++++++++----- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 5a450eb172..87cd4b0b1a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1237,7 +1237,7 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.drop_seed_offset, args.sink_ptr, args.block_scale_size_kv, - args.block_scale_seqstart_k_ptr); + reinterpret_cast(args.block_scale_seqstart_k_ptr)); } else { // create batch mode kernel arguments diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index cb420a938f..d9759057fc 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -624,11 +624,17 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.nhead_stride_k_descale = nhead_stride_k_descale; kargs.nhead_stride_v_descale = nhead_stride_v_descale; - kargs.batch_stride_k_descale = batch_stride_k_descale; - kargs.batch_stride_v_descale = batch_stride_v_descale; + if constexpr(kIsGroupMode) + { + kargs.block_scale_seqstart_k_ptr = block_scale_seqstart_k_ptr; + } + else // batch mode + { + kargs.batch_stride_k_descale = batch_stride_k_descale; + kargs.batch_stride_v_descale = batch_stride_v_descale; + } - kargs.block_scale_size_kv = block_scale_size_kv; - kargs.block_scale_seqstart_k_ptr = block_scale_seqstart_k_ptr; + kargs.block_scale_size_kv = block_scale_size_kv; } if constexpr(kHasDropout) { @@ -895,21 +901,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel { if constexpr(kIsGroupMode) { - // For group mode, use seqstart_k_ptr to calculate offset + // For group mode, block_scale_seqstart_k_ptr[i_batch] gives the direct offset + const index_t block_scale_offset = kargs.block_scale_seqstart_k_ptr + ? kargs.block_scale_seqstart_k_ptr[i_batch] + : 0; k_descale_ptr = reinterpret_cast(kargs.k_descale_ptr) + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k_descale + - kargs.batch_stride_k_descale * (kargs.block_scale_seqstart_k_ptr - ? kargs.block_scale_seqstart_k_ptr[i_batch] - : 0); + block_scale_offset; v_descale_ptr = reinterpret_cast(kargs.v_descale_ptr) + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v_descale + - kargs.batch_stride_v_descale * (kargs.block_scale_seqstart_k_ptr - ? kargs.block_scale_seqstart_k_ptr[i_batch] - : 0); + block_scale_offset; } else { @@ -1359,11 +1364,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel { // Q: per-tensor (handled in scale_s), K/V: block scale (handled in pipeline) return FmhaPipeline{}(q_dram_window, + identity{}, // q_element_func k_dram_window, + identity{}, // k_element_func v_dram_window, + identity{}, // v_element_func bias_dram_window, + identity{}, // bias_element_func randval_dram_window, lse_dram_window, + identity{}, // lse_element_func + scales{1.0f}, // s_acc_element_func (will be multiplied by k_descale in pipeline) + identity{}, // p_compute_element_func + identity{}, // o_acc_element_func mask, position_encoding, variant_params.sm_scale, @@ -1386,11 +1399,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel { // NO_SCALE return FmhaPipeline{}(q_dram_window, + identity{}, // q_element_func k_dram_window, + identity{}, // k_element_func v_dram_window, + identity{}, // v_element_func bias_dram_window, + identity{}, // bias_element_func randval_dram_window, lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + identity{}, // p_compute_element_func + identity{}, // o_acc_element_func mask, position_encoding, variant_params.sm_scale,