Pass cache_batch_idx to kernels

This commit is contained in:
PoYen, Chen
2024-08-16 15:32:24 +00:00
parent e6239e14f7
commit 9c904b0e4c
5 changed files with 108 additions and 45 deletions

View File

@@ -634,6 +634,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
if pipeline.F_pagedkv == 't':
# we only use batch mode kernels to handle (paged-) kvcache problems
continue
k = Kernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,

View File

@@ -125,7 +125,8 @@ auto create_args(int argc, char* argv[])
.insert("num_splits",
"1",
"# of splits for key/value. 0 to determine actual number by heuristic")
.insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe.")
.insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe")
.insert("cache_batch_idx", "0", "whether to use index map to the kvcache")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
@@ -306,6 +307,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
return false;
}
bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
if(0 < page_block_size && use_cache_batch_idx)
{
std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
"'cache_batch_idx' option"
<< std::endl;
use_cache_batch_idx = false;
}
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
if((0 < seqlen_knew || 0 < page_block_size) && mode != mode_enum::batch)
{
@@ -589,11 +599,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
generate_rotary_cos_sin<KDataType>(shape_seqlen_k, rotary_dim, seed);
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits || 0 < seqlen_knew || 0 < page_block_size
1 < num_splits || 0 < seqlen_knew || use_cache_batch_idx || 0 < page_block_size
? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits || 0 < seqlen_knew || 0 < page_block_size
1 < num_splits || 0 < seqlen_knew || use_cache_batch_idx || 0 < page_block_size
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
@@ -613,6 +623,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
0 < page_block_size ? std::array<ck_tile::index_t, 2>{batch, max_num_page_blocks / batch}
: std::array<ck_tile::index_t, 2>{1, 1});
ck_tile::HostTensor<int32_t> cache_batch_idx_host(use_cache_batch_idx
? std::array<ck_tile::index_t, 1>{batch}
: std::array<ck_tile::index_t, 1>{1});
if(init_method == "ui" || init_method == "0")
{
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
@@ -691,6 +705,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0);
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0);
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
@@ -711,6 +726,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
@@ -727,6 +743,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
rotary_sin_buf.ToDevice(rotary_sin_host.data());
alibi_slope_buf.ToDevice(alibi_slope_host.data());
block_table_buf.ToDevice(block_table_host.data());
cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());
// clang-format off
auto layout_str = [&](bool permute){
@@ -901,10 +918,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.rotary_dim = rotary_dim;
args.block_table_ptr =
0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr;
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
args.batch_stride_block_table = batch_stride_block_table;
args.page_block_size = page_block_size;
args.cache_batch_idx =
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
args.stride_knew = stride_knew;
args.stride_vnew = stride_vnew;
args.nhead_stride_knew = nhead_stride_knew;
@@ -935,7 +955,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.scale_o = scale_o;
args.stride_bias =
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias;
(bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias);
args.stride_o = stride_o;
args.nhead_stride_bias = nhead_stride_bias;
args.nhead_stride_lse = nhead_stride_lse;
@@ -966,10 +986,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
args.block_table_ptr =
0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr;
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
args.batch_stride_block_table = batch_stride_block_table;
args.page_block_size = page_block_size;
args.cache_batch_idx =
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
args.num_splits = num_splits;
args.stride_o_acc = stride_o_acc;
@@ -1001,7 +1024,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const float fwd_ave_time = [&] {
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(1 < num_splits || 0 < seqlen_knew || 0 < page_block_size)
if(1 < num_splits || 0 < seqlen_knew || use_cache_batch_idx || 0 < page_block_size)
{
fmha_fwd_splitkv_traits fmha_splitkv_traits;
init_traits(fmha_splitkv_traits);
@@ -1074,7 +1097,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
// adjust matrix index according to the mode
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
const ck_tile::index_t key_offset =
(mode == mode_enum::batch
@@ -1094,8 +1117,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format off
// permute
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); });
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); });
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); });
#if CK_TILE_FMHA_FWD_APPENDKV_API
// optionally apply RoPE to the q_host_ref
@@ -1126,8 +1149,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
} else
#endif
{
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
}
#if CK_TILE_FMHA_FWD_APPENDKV_API
@@ -1197,14 +1220,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
if(is_v_rowmajor) {
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); });
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[0] / nr, i[2] + key_offset, i[1]); });
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); });
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[2] + key_offset, i[0] / nr, i[1]); });
}
else
{
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); });
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); });
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[0] / nr, i[1], i[2] + key_offset); });
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[1], i[0] / nr, i[2] + key_offset); });
}
}
@@ -1361,7 +1384,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
{nhead, real_seqlen_q, real_seqlen_k});
randval_host_ref.ForEach([&](auto& self, auto idx) {
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]);
});
ck_tile::reference_batched_dropout(
p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
@@ -1378,8 +1401,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
// clang-format off
// permute
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); });
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); });
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); });
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method);

View File

@@ -162,6 +162,8 @@ struct fmha_fwd_splitkv_args
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
const void* cache_batch_idx;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not
@@ -237,6 +239,8 @@ struct fmha_fwd_appendkv_args
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
const void* cache_batch_idx;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_knew;
@@ -374,9 +378,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.scale_s,
args.scale_p,
args.stride_q,
@@ -420,6 +421,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.cache_batch_idx,
args.scale_s,
args.scale_p,
args.stride_q,
@@ -537,6 +539,7 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.cache_batch_idx,
args.stride_q,
args.stride_k,
args.stride_knew,

View File

@@ -98,10 +98,6 @@ struct FmhaFwdAppendKVKernel
// if this param is larger than 1, indicate MQA/GQA case
ck_tile::index_t nhead_ratio_qk;
const void* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
ck_tile::index_t page_block_size;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_knew;
@@ -128,7 +124,21 @@ struct FmhaFwdAppendKVKernel
ck_tile::index_t rotary_dim;
};
struct Kargs : BasicKargs, std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>
// Kernel arguments describing the paged-kvcache block table.
// Mixed into Kargs only when kIsPagedKV is true; the kargs builder fills these
// fields in its `if constexpr(kIsPagedKV)` branch.
struct PageBlockTableKargs
{
// block table, stored as int32 entries (cast from the caller's `const void*`)
const int32_t* block_table_ptr;
// NOTE(review): presumably the element stride between consecutive batches' rows
// in the block table -- confirm against the kernel body
ck_tile::index_t batch_stride_block_table;
// paged-kvcache block size
ck_tile::index_t page_block_size;
};
// Kernel arguments for the non-paged kvcache path.
// Mixed into Kargs only when kIsPagedKV is false; the kargs builder fills this
// field in the `else` branch of `if constexpr(kIsPagedKV)`.
struct CacheBatchIdxKargs
{
// int32 index map into the kvcache (cast from the caller's `const void*`).
// NOTE(review): callers may pass nullptr when no index map is used -- confirm
// that the kernel checks for nullptr before dereferencing.
const int32_t* cache_batch_idx;
};
// Full kernel-argument pack, composed from optional pieces via inheritance:
//   - BasicKargs: always present
//   - RoPEKargs when kApplyRoPE is true, otherwise the empty placeholder EmptyKargs<0>
//   - PageBlockTableKargs when kIsPagedKV is true, otherwise CacheBatchIdxKargs
// The base-class order matters: the kargs builder brace-initializes the bases
// positionally (common args, then the rope placeholder, then the
// paged-block-table / cache_batch_idx placeholder).
struct Kargs : BasicKargs,
std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>,
std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
{
};
@@ -150,6 +160,7 @@ struct FmhaFwdAppendKVKernel
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
const void* cache_batch_idx,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_knew,
@@ -180,9 +191,6 @@ struct FmhaFwdAppendKVKernel
hdim_v,
num_head_q,
nhead_ratio_qk,
block_table_ptr,
batch_stride_block_table,
page_block_size,
stride_q,
stride_k,
stride_knew,
@@ -198,7 +206,8 @@ struct FmhaFwdAppendKVKernel
batch_stride_knew,
batch_stride_v,
batch_stride_vnew}, // args for common karg
{} // placeholder for rope
{}, // placeholder for rope
{} // placeholder for paged-block table or cache_batch_idx
};
if constexpr(kApplyRoPE)
@@ -208,6 +217,17 @@ struct FmhaFwdAppendKVKernel
kargs.rotary_dim = rotary_dim;
}
if constexpr(kIsPagedKV)
{
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
kargs.batch_stride_block_table = batch_stride_block_table;
kargs.page_block_size = page_block_size;
}
else
{
kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
}
return kargs;
}

View File

@@ -46,6 +46,8 @@ struct FmhaFwdSplitKVKernel
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
// Compile-time guard: a kernel must never be instantiated with both group mode
// and paged-kvcache enabled (equivalent to `!kIsGroupMode || !kIsPagedKV`).
static_assert(!(kIsGroupMode && kIsPagedKV),
              "paged-kvcache only supported by batch mode kernels");
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
@@ -120,10 +122,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_ratio_qk;
ck_tile::index_t num_splits;
const void* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
ck_tile::index_t page_block_size;
float scale_s;
ck_tile::index_t stride_q;
@@ -175,6 +173,18 @@ struct FmhaFwdSplitKVKernel
float scale_p;
};
// Kernel arguments describing the paged-kvcache block table.
// Selected as a base of BatchModeKargs only when kIsPagedKV is true; the
// batch-mode kargs builder fills these fields in its `if constexpr(kIsPagedKV)` branch.
struct PageBlockTableKargs
{
// block table, stored as int32 entries (cast from the caller's `const void*`)
const int32_t* block_table_ptr;
// NOTE(review): presumably the element stride between consecutive batches' rows
// in the block table -- confirm against the kernel body
ck_tile::index_t batch_stride_block_table;
// paged-kvcache block size
ck_tile::index_t page_block_size;
};
// Kernel arguments for the non-paged kvcache path.
// Selected as a base of BatchModeKargs only when kIsPagedKV is false; the
// batch-mode kargs builder fills this field in the `else` branch of
// `if constexpr(kIsPagedKV)`.
struct CacheBatchIdxKargs
{
// int32 index map into the kvcache (cast from the caller's `const void*`).
// NOTE(review): callers may pass nullptr when no index map is used -- confirm
// that the kernel checks for nullptr before dereferencing.
const int32_t* cache_batch_idx;
};
struct BatchModeKargs
: CommonKargs,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
@@ -183,7 +193,8 @@ struct FmhaFwdSplitKVKernel
AlibiKargs,
EmptyKargs<0>>>,
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
{
const int32_t* seqlen_k_ptr;
@@ -232,6 +243,7 @@ struct FmhaFwdSplitKVKernel
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
const void* cache_batch_idx,
float scale_s,
float scale_p,
ck_tile::index_t stride_q,
@@ -270,9 +282,6 @@ struct FmhaFwdSplitKVKernel
num_head_q,
nhead_ratio_qk,
num_splits,
block_table_ptr,
batch_stride_block_table,
page_block_size,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
@@ -294,6 +303,7 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for fp8_static_quant args
{}, // placeholder for paged-block table or cache_batch_idx
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
batch_stride_q,
batch_stride_k,
@@ -321,6 +331,16 @@ struct FmhaFwdSplitKVKernel
{
kargs.scale_p = scale_p;
}
if constexpr(kIsPagedKV)
{
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
kargs.batch_stride_block_table = batch_stride_block_table;
kargs.page_block_size = page_block_size;
}
else
{
kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
}
return kargs;
}
@@ -342,9 +362,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
ck_tile::index_t num_splits,
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
float scale_s,
float scale_p,
ck_tile::index_t stride_q,
@@ -381,9 +398,6 @@ struct FmhaFwdSplitKVKernel
num_head_q,
nhead_ratio_qk,
num_splits,
block_table_ptr,
batch_stride_block_table,
page_block_size,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else