mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 16:26:10 +00:00
Pass cache_batch_idx to kernels
This commit is contained in:
@@ -634,6 +634,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict whether the seqlen of the current batch needs padding or not
|
||||
continue
|
||||
if pipeline.F_pagedkv == 't':
|
||||
# we only use batch mode kernels to handle (paged-) kvcache problems
|
||||
continue
|
||||
k = Kernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
|
||||
@@ -125,7 +125,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("num_splits",
|
||||
"1",
|
||||
"# of splits for key/value. 0 to determine actual number by heuristic")
|
||||
.insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe.")
|
||||
.insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe")
|
||||
.insert("cache_batch_idx", "0", "whether to use index map to the kvcache")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
@@ -306,6 +307,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return false;
|
||||
}
|
||||
|
||||
bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
|
||||
if(0 < page_block_size && use_cache_batch_idx)
|
||||
{
|
||||
std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
|
||||
"'cache_batch_idx' option"
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
|
||||
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
|
||||
if((0 < seqlen_knew || 0 < page_block_size) && mode != mode_enum::batch)
|
||||
{
|
||||
@@ -589,11 +599,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
generate_rotary_cos_sin<KDataType>(shape_seqlen_k, rotary_dim, seed);
|
||||
|
||||
ck_tile::HostTensor<LSEDataType> lse_acc_host(
|
||||
1 < num_splits || 0 < seqlen_knew || 0 < page_block_size
|
||||
1 < num_splits || 0 < seqlen_knew || use_cache_batch_idx || 0 < page_block_size
|
||||
? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
ck_tile::HostTensor<OaccDataType> o_acc_host(
|
||||
1 < num_splits || 0 < seqlen_knew || 0 < page_block_size
|
||||
1 < num_splits || 0 < seqlen_knew || use_cache_batch_idx || 0 < page_block_size
|
||||
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
@@ -613,6 +623,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
0 < page_block_size ? std::array<ck_tile::index_t, 2>{batch, max_num_page_blocks / batch}
|
||||
: std::array<ck_tile::index_t, 2>{1, 1});
|
||||
|
||||
ck_tile::HostTensor<int32_t> cache_batch_idx_host(use_cache_batch_idx
|
||||
? std::array<ck_tile::index_t, 1>{batch}
|
||||
: std::array<ck_tile::index_t, 1>{1});
|
||||
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
{
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
|
||||
@@ -691,6 +705,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0);
|
||||
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0);
|
||||
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
@@ -711,6 +726,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
@@ -727,6 +743,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
rotary_sin_buf.ToDevice(rotary_sin_host.data());
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
block_table_buf.ToDevice(block_table_host.data());
|
||||
cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());
|
||||
|
||||
// clang-format off
|
||||
auto layout_str = [&](bool permute){
|
||||
@@ -901,10 +918,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.rotary_dim = rotary_dim;
|
||||
|
||||
args.block_table_ptr =
|
||||
0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr;
|
||||
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
|
||||
args.batch_stride_block_table = batch_stride_block_table;
|
||||
args.page_block_size = page_block_size;
|
||||
|
||||
args.cache_batch_idx =
|
||||
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
|
||||
|
||||
args.stride_knew = stride_knew;
|
||||
args.stride_vnew = stride_vnew;
|
||||
args.nhead_stride_knew = nhead_stride_knew;
|
||||
@@ -935,7 +955,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.scale_o = scale_o;
|
||||
|
||||
args.stride_bias =
|
||||
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias;
|
||||
(bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias);
|
||||
args.stride_o = stride_o;
|
||||
args.nhead_stride_bias = nhead_stride_bias;
|
||||
args.nhead_stride_lse = nhead_stride_lse;
|
||||
@@ -966,10 +986,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
|
||||
|
||||
args.block_table_ptr =
|
||||
0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr;
|
||||
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
|
||||
args.batch_stride_block_table = batch_stride_block_table;
|
||||
args.page_block_size = page_block_size;
|
||||
|
||||
args.cache_batch_idx =
|
||||
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
|
||||
|
||||
args.num_splits = num_splits;
|
||||
|
||||
args.stride_o_acc = stride_o_acc;
|
||||
@@ -1001,7 +1024,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
const float fwd_ave_time = [&] {
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(1 < num_splits || 0 < seqlen_knew || 0 < page_block_size)
|
||||
if(1 < num_splits || 0 < seqlen_knew || use_cache_batch_idx || 0 < page_block_size)
|
||||
{
|
||||
fmha_fwd_splitkv_traits fmha_splitkv_traits;
|
||||
init_traits(fmha_splitkv_traits);
|
||||
@@ -1074,7 +1097,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
|
||||
// adjust matrix index according to the mode
|
||||
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck_tile::index_t key_offset =
|
||||
(mode == mode_enum::batch
|
||||
@@ -1094,8 +1117,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
// clang-format off
|
||||
// permute
|
||||
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); });
|
||||
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
|
||||
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); });
|
||||
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); });
|
||||
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
// optionally apply RoPE to the q_host_ref
|
||||
@@ -1126,8 +1149,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
|
||||
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
|
||||
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
|
||||
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
|
||||
}
|
||||
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
@@ -1197,14 +1220,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
if(is_v_rowmajor) {
|
||||
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
|
||||
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); });
|
||||
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[0] / nr, i[2] + key_offset, i[1]); });
|
||||
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); });
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[2] + key_offset, i[0] / nr, i[1]); });
|
||||
}
|
||||
else
|
||||
{
|
||||
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); });
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); });
|
||||
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[0] / nr, i[1], i[2] + key_offset); });
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[1], i[0] / nr, i[2] + key_offset); });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1361,7 +1384,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k});
|
||||
randval_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
|
||||
self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]);
|
||||
});
|
||||
ck_tile::reference_batched_dropout(
|
||||
p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
@@ -1378,8 +1401,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
|
||||
// clang-format off
|
||||
// permute
|
||||
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); });
|
||||
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); });
|
||||
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>(init_method);
|
||||
|
||||
@@ -162,6 +162,8 @@ struct fmha_fwd_splitkv_args
|
||||
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
|
||||
const void* cache_batch_idx;
|
||||
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not
|
||||
@@ -237,6 +239,8 @@ struct fmha_fwd_appendkv_args
|
||||
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
|
||||
const void* cache_batch_idx;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_knew;
|
||||
@@ -374,9 +378,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_splits,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.stride_q,
|
||||
@@ -420,6 +421,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.cache_batch_idx,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.stride_q,
|
||||
@@ -537,6 +539,7 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.cache_batch_idx,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_knew,
|
||||
|
||||
@@ -98,10 +98,6 @@ struct FmhaFwdAppendKVKernel
|
||||
// if this param is larger than 1, it indicates the MQA/GQA case
|
||||
ck_tile::index_t nhead_ratio_qk;
|
||||
|
||||
const void* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table;
|
||||
ck_tile::index_t page_block_size;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_knew;
|
||||
@@ -128,7 +124,21 @@ struct FmhaFwdAppendKVKernel
|
||||
ck_tile::index_t rotary_dim;
|
||||
};
|
||||
|
||||
struct Kargs : BasicKargs, std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>
|
||||
// Kernel arguments for the paged-kvcache path. Used as a base of Kargs only
// when kIsPagedKV is true; the kargs builder in this kernel fills the fields
// inside its `if constexpr(kIsPagedKV)` branch.
struct PageBlockTableKargs
|
||||
{
|
||||
// block table pointer, reinterpret_cast from the caller's `const void*`
const int32_t* block_table_ptr;
|
||||
// forwarded unchanged from the caller
// NOTE(review): presumably the per-batch stride of the block table, in elements - confirm
ck_tile::index_t batch_stride_block_table;
|
||||
// paged-kvcache block size, forwarded unchanged from the caller
ck_tile::index_t page_block_size;
|
||||
};
|
||||
|
||||
// Kernel argument for the non-paged path. Used as a base of Kargs when
// kIsPagedKV is false; the kargs builder in this kernel fills it in the `else`
// branch of its `if constexpr(kIsPagedKV)`.
struct CacheBatchIdxKargs
|
||||
{
|
||||
// reinterpret_cast from the caller's `const void*`; the example driver passes
// nullptr when the 'cache_batch_idx' option is off, so this may be null.
// NOTE(review): presumably a per-batch index map into the kvcache - confirm
// against the kernel body.
const int32_t* cache_batch_idx;
|
||||
};
|
||||
|
||||
// Complete kernel-argument type: BasicKargs, plus RoPEKargs when kApplyRoPE is
// set (EmptyKargs<0> otherwise), plus exactly one of PageBlockTableKargs /
// CacheBatchIdxKargs selected at compile time by kIsPagedKV.
// The base order matters: the kargs builder brace-initializes this type with
// one initializer per base, in this order.
struct Kargs : BasicKargs,
|
||||
std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>,
|
||||
std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
|
||||
{
|
||||
};
|
||||
|
||||
@@ -150,6 +160,7 @@ struct FmhaFwdAppendKVKernel
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t batch_stride_block_table,
|
||||
ck_tile::index_t page_block_size,
|
||||
const void* cache_batch_idx,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_knew,
|
||||
@@ -180,9 +191,6 @@ struct FmhaFwdAppendKVKernel
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
block_table_ptr,
|
||||
batch_stride_block_table,
|
||||
page_block_size,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_knew,
|
||||
@@ -198,7 +206,8 @@ struct FmhaFwdAppendKVKernel
|
||||
batch_stride_knew,
|
||||
batch_stride_v,
|
||||
batch_stride_vnew}, // args for common karg
|
||||
{} // placeholder for rope
|
||||
{}, // placeholder for rope
|
||||
{} // placeholder for paged-block table or cache_batch_idx
|
||||
};
|
||||
|
||||
if constexpr(kApplyRoPE)
|
||||
@@ -208,6 +217,17 @@ struct FmhaFwdAppendKVKernel
|
||||
kargs.rotary_dim = rotary_dim;
|
||||
}
|
||||
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
|
||||
kargs.batch_stride_block_table = batch_stride_block_table;
|
||||
kargs.page_block_size = page_block_size;
|
||||
}
|
||||
else
|
||||
{
|
||||
kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
|
||||
@@ -46,6 +46,8 @@ struct FmhaFwdSplitKVKernel
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
|
||||
static_assert(!kIsGroupMode || (kIsGroupMode && !kIsPagedKV),
|
||||
"paged-kvcache only supported by batch mode kernels");
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
|
||||
@@ -120,10 +122,6 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t nhead_ratio_qk;
|
||||
ck_tile::index_t num_splits;
|
||||
|
||||
const void* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table;
|
||||
ck_tile::index_t page_block_size;
|
||||
|
||||
float scale_s;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
@@ -175,6 +173,18 @@ struct FmhaFwdSplitKVKernel
|
||||
float scale_p;
|
||||
};
|
||||
|
||||
// Kernel arguments for the paged-kvcache path of the split-kv kernel. Used as
// a base of BatchModeKargs only when kIsPagedKV is true (a static_assert in
// this kernel restricts paged-kvcache to batch mode).
struct PageBlockTableKargs
|
||||
{
|
||||
// block table pointer, reinterpret_cast from the caller's `const void*`
const int32_t* block_table_ptr;
|
||||
// forwarded unchanged from the caller
// NOTE(review): presumably the per-batch stride of the block table, in elements - confirm
ck_tile::index_t batch_stride_block_table;
|
||||
// paged-kvcache block size, forwarded unchanged from the caller
ck_tile::index_t page_block_size;
|
||||
};
|
||||
|
||||
// Kernel argument for the non-paged path of the split-kv kernel. Used as a
// base of BatchModeKargs when kIsPagedKV is false.
struct CacheBatchIdxKargs
|
||||
{
|
||||
// reinterpret_cast from the caller's `const void*`; the example driver passes
// nullptr when the 'cache_batch_idx' option is off, so this may be null.
// NOTE(review): presumably a per-batch index map into the kvcache - confirm
// against the kernel body.
const int32_t* cache_batch_idx;
|
||||
};
|
||||
|
||||
struct BatchModeKargs
|
||||
: CommonKargs,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
|
||||
@@ -183,7 +193,8 @@ struct FmhaFwdSplitKVKernel
|
||||
AlibiKargs,
|
||||
EmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
|
||||
std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
|
||||
{
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
@@ -232,6 +243,7 @@ struct FmhaFwdSplitKVKernel
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t batch_stride_block_table,
|
||||
ck_tile::index_t page_block_size,
|
||||
const void* cache_batch_idx,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
ck_tile::index_t stride_q,
|
||||
@@ -270,9 +282,6 @@ struct FmhaFwdSplitKVKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
num_splits,
|
||||
block_table_ptr,
|
||||
batch_stride_block_table,
|
||||
page_block_size,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
#else
|
||||
@@ -294,6 +303,7 @@ struct FmhaFwdSplitKVKernel
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
{}, // placeholder for paged-block table or cache_batch_idx
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
@@ -321,6 +331,16 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
kargs.scale_p = scale_p;
|
||||
}
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
|
||||
kargs.batch_stride_block_table = batch_stride_block_table;
|
||||
kargs.page_block_size = page_block_size;
|
||||
}
|
||||
else
|
||||
{
|
||||
kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -342,9 +362,6 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
ck_tile::index_t num_splits,
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t batch_stride_block_table,
|
||||
ck_tile::index_t page_block_size,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
ck_tile::index_t stride_q,
|
||||
@@ -381,9 +398,6 @@ struct FmhaFwdSplitKVKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
num_splits,
|
||||
block_table_ptr,
|
||||
batch_stride_block_table,
|
||||
page_block_size,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
#else
|
||||
|
||||
Reference in New Issue
Block a user