Fix group-mode block-scale K/V descale offsets in FMHA batch prefill (paged KV cache) kernel

This commit is contained in:
ltqin
2026-01-29 09:46:19 +00:00
parent a9d85dfe16
commit fbae2aba20
2 changed files with 33 additions and 12 deletions

View File

@@ -1237,7 +1237,7 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.drop_seed_offset,
args.sink_ptr,
args.block_scale_size_kv,
args.block_scale_seqstart_k_ptr);
reinterpret_cast<const int32_t*>(args.block_scale_seqstart_k_ptr));
}
else
{ // create batch mode kernel arguments

View File

@@ -624,11 +624,17 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
kargs.batch_stride_k_descale = batch_stride_k_descale;
kargs.batch_stride_v_descale = batch_stride_v_descale;
if constexpr(kIsGroupMode)
{
kargs.block_scale_seqstart_k_ptr = block_scale_seqstart_k_ptr;
}
else // batch mode
{
kargs.batch_stride_k_descale = batch_stride_k_descale;
kargs.batch_stride_v_descale = batch_stride_v_descale;
}
kargs.block_scale_size_kv = block_scale_size_kv;
kargs.block_scale_seqstart_k_ptr = block_scale_seqstart_k_ptr;
kargs.block_scale_size_kv = block_scale_size_kv;
}
if constexpr(kHasDropout)
{
@@ -895,21 +901,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
{
if constexpr(kIsGroupMode)
{
// For group mode, use seqstart_k_ptr to calculate offset
// For group mode, block_scale_seqstart_k_ptr[i_batch] gives the direct offset
const index_t block_scale_offset = kargs.block_scale_seqstart_k_ptr
? kargs.block_scale_seqstart_k_ptr[i_batch]
: 0;
k_descale_ptr =
reinterpret_cast<const float*>(kargs.k_descale_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) *
kargs.nhead_stride_k_descale +
kargs.batch_stride_k_descale * (kargs.block_scale_seqstart_k_ptr
? kargs.block_scale_seqstart_k_ptr[i_batch]
: 0);
block_scale_offset;
v_descale_ptr =
reinterpret_cast<const float*>(kargs.v_descale_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) *
kargs.nhead_stride_v_descale +
kargs.batch_stride_v_descale * (kargs.block_scale_seqstart_k_ptr
? kargs.block_scale_seqstart_k_ptr[i_batch]
: 0);
block_scale_offset;
}
else
{
@@ -1359,11 +1364,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
{
// Q: per-tensor (handled in scale_s), K/V: block scale (handled in pipeline)
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
scales<float>{1.0f}, // s_acc_element_func (will be multiplied by k_descale in pipeline)
identity{}, // p_compute_element_func
identity{}, // o_acc_element_func
mask,
position_encoding,
variant_params.sm_scale,
@@ -1386,11 +1399,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
{
// NO_SCALE
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
identity{}, // p_compute_element_func
identity{}, // o_acc_element_func
mask,
position_encoding,
variant_params.sm_scale,