mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
fix issue
This commit is contained in:
@@ -1237,7 +1237,7 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr,
|
||||
args.block_scale_size_kv,
|
||||
args.block_scale_seqstart_k_ptr);
|
||||
reinterpret_cast<const int32_t*>(args.block_scale_seqstart_k_ptr));
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
|
||||
@@ -624,11 +624,17 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
kargs.batch_stride_k_descale = batch_stride_k_descale;
|
||||
kargs.batch_stride_v_descale = batch_stride_v_descale;
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
kargs.block_scale_seqstart_k_ptr = block_scale_seqstart_k_ptr;
|
||||
}
|
||||
else // batch mode
|
||||
{
|
||||
kargs.batch_stride_k_descale = batch_stride_k_descale;
|
||||
kargs.batch_stride_v_descale = batch_stride_v_descale;
|
||||
}
|
||||
|
||||
kargs.block_scale_size_kv = block_scale_size_kv;
|
||||
kargs.block_scale_seqstart_k_ptr = block_scale_seqstart_k_ptr;
|
||||
kargs.block_scale_size_kv = block_scale_size_kv;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
@@ -895,21 +901,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
{
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// For group mode, use seqstart_k_ptr to calculate offset
|
||||
// For group mode, block_scale_seqstart_k_ptr[i_batch] gives the direct offset
|
||||
const index_t block_scale_offset = kargs.block_scale_seqstart_k_ptr
|
||||
? kargs.block_scale_seqstart_k_ptr[i_batch]
|
||||
: 0;
|
||||
k_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.k_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_k_descale +
|
||||
kargs.batch_stride_k_descale * (kargs.block_scale_seqstart_k_ptr
|
||||
? kargs.block_scale_seqstart_k_ptr[i_batch]
|
||||
: 0);
|
||||
block_scale_offset;
|
||||
v_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.v_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_v_descale +
|
||||
kargs.batch_stride_v_descale * (kargs.block_scale_seqstart_k_ptr
|
||||
? kargs.block_scale_seqstart_k_ptr[i_batch]
|
||||
: 0);
|
||||
block_scale_offset;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1359,11 +1364,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
{
|
||||
// Q: per-tensor (handled in scale_s), K/V: block scale (handled in pipeline)
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
scales<float>{1.0f}, // s_acc_element_func (will be multiplied by k_descale in pipeline)
|
||||
identity{}, // p_compute_element_func
|
||||
identity{}, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
@@ -1386,11 +1399,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
{
|
||||
// NO_SCALE
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
identity{}, // p_compute_element_func
|
||||
identity{}, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
|
||||
Reference in New Issue
Block a user