diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 5093945095..64643f157e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -53,6 +53,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, {F_lse}, {F_dropout}, {F_squant}, + {F_pagedkv}, kHasUnevenSplits, {F_occupancy}>; @@ -97,8 +98,9 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a) }}; }} -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; +using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, + {F_dvpad}>; #include @@ -225,14 +227,22 @@ float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream """ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; return fmha_fwd_splitkv_(s, a); }} """ +@dataclass +class FmhaFwdSplitKVApiTrait(FmhaFwdApiTrait): + pagedkv : str + + @property + def name(self) -> str: + return FmhaFwdApiTrait.name + f'-{self.pagedkv}' + @dataclass class FmhaFwdSplitKVPipeline: tag : str @@ -246,6 +256,7 @@ class FmhaFwdSplitKVPipeline: F_lse : str # F_dropout : str # F_squant : str # + F_pagedkv : str # t/f F_mask : str # value from MASK_MAP @property @@ -269,6 +280,7 @@ class FmhaFwdSplitKVPipeline: if self.F_lse == 't' : n += '_lse' if self.F_dropout == 't' : n += '_dropout' if self.F_squant == 't' : n += '_squant' + if self.F_pagedkv == 't' : n += '_pagedkv' return n @dataclass @@ -300,7 +312,7 @@ class FmhaFwdSplitKVApiPool: self.pool = dict() self.mask_impl = mask_impl - def register_traits(self, trait : FmhaFwdApiTrait) -> None: + def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None: # TODO: do we need to check duplication? if trait.dtype not in self.pool.keys(): self.pool[trait.dtype] = dict() @@ -322,8 +334,8 @@ class FmhaFwdSplitKVApiPool: inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) @@ -385,6 +397,7 @@ class FmhaFwdSplitKVKernel: F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], F_occupancy = self.F_tile.F_occupancy, F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], @@ -401,8 +414,8 @@ class FmhaFwdSplitKVKernel: def filename(self) -> str: return self.name + ".cpp" - def api_trait(self) -> FmhaFwdApiTrait: - return FmhaFwdApiTrait( + def api_trait(self) -> FmhaFwdSplitKVApiTrait: + return FmhaFwdSplitKVApiTrait( pipeline_tag=self.F_pipeline.tag, hdim=str(self.F_hdim), dtype=self.F_dtype, @@ -419,6 +432,7 @@ class FmhaFwdSplitKVKernel: lse=self.F_pipeline.F_lse, dropout=self.F_pipeline.F_dropout, squant=self.F_pipeline.F_squant, + pagedkv=self.F_pipeline.F_pagedkv, spad=self.F_pipeline.F_spad, skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, @@ -460,29 +474,6 @@ class FmhaFwdSplitKVCombineKernel: def filename(self) -> str: return self.name + ".cpp" - def api_trait(self) -> FmhaFwdApiTrait: - return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0blen=self.F_tile.F_bk0blen, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - dropout=self.F_pipeline.F_dropout, - squant=self.F_pipeline.F_squant, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad) - # TODO: design a more practical way to do it # this is current supported tile size per hdim def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: @@ -534,26 +525,26 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> pipelines = [] if dtype in ['fp16', 'bf16']: # splitkv kernel donot support dropout - for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]): + for mask, bias, lse, dropout, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"], ["t", "f"]): if hdim == 256: # if True: - pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) else: - pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) if receipt == 1: - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: - # no need lse/dropout kernels + # no need lse/dropout/paged-kv kernels for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, 'f', mask)) else: assert False return pipelines diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 8b02e27322..984081b9c9 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -125,6 +125,7 @@ auto create_args(int argc, char* argv[]) .insert( "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all") .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE") + .insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe.") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -257,7 +258,7 @@ float fmha_fwd_dispatch(fmha_fwd_traits traits, const ck_tile::stream_config& config) { #if CK_TILE_FMHA_FWD_SPLITKV_API - if(1 < args.num_splits) + if(1 < args.num_splits || args.block_table_ptr != nullptr) { return fmha_fwd_splitkv(traits, args, config); } @@ -415,9 +416,11 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } +#if CK_TILE_FMHA_FWD_APPENDKV_API const bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved"); +#endif - int num_splits = arg_parser.get_int("num_splits"); + ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); #if !CK_TILE_FMHA_FWD_SPLITKV_API if(num_splits != 1) { @@ -425,6 +428,23 @@ bool run(const ck_tile::ArgParser& arg_parser) num_splits = 1; } #endif + ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size"); +#if !CK_TILE_FMHA_FWD_SPLITKV_API + if(0 < page_block_size) + { + std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option" + << std::endl; + page_block_size = 0; + } +#endif + if(!(page_block_size % 256 == 0)) + { + std::cerr << "only paged-kvcache block size divisible by 256 are currently supported" + << std::endl; + return false; + } + +#define ENABLE_PAGED_KVCACHE 0 int stream_warmup = arg_parser.get_int("warmup"); int stream_repeat = arg_parser.get_int("repeat"); @@ -486,6 +506,11 @@ bool run(const ck_tile::ArgParser& arg_parser) } } + const ck_tile::index_t max_num_blocks = + (ENABLE_PAGED_KVCACHE && 0 < page_block_size + ? batch * std::min(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size)) + : 0); + // legalize num_splits according to other options if(num_splits < 1) { @@ -520,18 +545,26 @@ bool run(const ck_tile::ArgParser& arg_parser) : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() : seqstart_k_with_padding_host.back())); + std::cerr << "[POYENC] num_blocks: " << max_num_blocks << std::endl; + ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); ck_tile::HostTensor k_host( - get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + ENABLE_PAGED_KVCACHE && 0 < page_block_size + ? get_lengths(i_perm, max_num_blocks, nhead_k, page_block_size, hdim_q) + : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode ck_tile::HostTensor knew_host( 0 < seqlen_knew ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor v_host( - is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) - : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)); + ENABLE_PAGED_KVCACHE && 0 < page_block_size + ? (is_v_rowmajor + ? get_lengths(i_perm, max_num_blocks, nhead_k, page_block_size, hdim_v) + : get_lengths(i_perm, max_num_blocks, nhead_k, hdim_v, page_block_size)) + : (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) + : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k))); ck_tile::HostTensor vnew_host( 0 < seqlen_knew ? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v) @@ -571,6 +604,11 @@ bool run(const ck_tile::ArgParser& arg_parser) p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) : std::array{1, 1, 1, 1}); + ck_tile::HostTensor block_table_host( + ENABLE_PAGED_KVCACHE && 0 < page_block_size + ? std::array{batch, max_num_blocks / batch} + : std::array{1, 1}); + if(init_method == "ui" || init_method == "0") { ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); @@ -648,6 +686,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } } } + iota_shuffle(block_table_host.begin(), block_table_host.end(), 0); ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); @@ -667,6 +706,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); @@ -682,6 +722,7 @@ bool run(const ck_tile::ArgParser& arg_parser) rotary_cos_buf.ToDevice(rotary_cos_host.data()); rotary_sin_buf.ToDevice(rotary_sin_host.data()); alibi_slope_buf.ToDevice(alibi_slope_host.data()); + block_table_buf.ToDevice(block_table_host.data()); // clang-format off auto layout_str = [&](bool permute){ @@ -842,7 +883,9 @@ bool run(const ck_tile::ArgParser& arg_parser) if(is_v_rowmajor) return i_perm ? hdim_v : nhead_k * hdim_v; else - return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k; + return ENABLE_PAGED_KVCACHE && 0 < page_block_size + ? (i_perm ? page_block_size : nhead_k * page_block_size) + : (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); }(); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); @@ -850,12 +893,18 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = (ENABLE_PAGED_KVCACHE && 0 < page_block_size + ? (i_perm ? page_block_size * hdim_q : hdim_q) + : (i_perm ? shape_seqlen_k * hdim_q : hdim_q)); const ck_tile::index_t nhead_stride_v = [&]() { if(is_v_rowmajor) - return i_perm ? shape_seqlen_k * hdim_v : hdim_v; + return ENABLE_PAGED_KVCACHE && 0 < page_block_size + ? (i_perm ? page_block_size * hdim_v : hdim_v) + : (i_perm ? shape_seqlen_k * hdim_v : hdim_v); else - return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; + return ENABLE_PAGED_KVCACHE && 0 < page_block_size + ? (i_perm ? hdim_v * page_block_size : page_block_size) + : (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k); }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); @@ -865,78 +914,86 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); - const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = + (ENABLE_PAGED_KVCACHE && 0 < page_block_size ? (nhead_k * page_block_size * hdim_q) + : (nhead_k * shape_seqlen_k * hdim_q)); + const ck_tile::index_t batch_stride_v = + (ENABLE_PAGED_KVCACHE && 0 < page_block_size ? (nhead_k * hdim_v * page_block_size) + : (nhead_k * hdim_v * shape_seqlen_k)); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q); const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q); const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_block_table = (max_num_blocks / batch); // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q); const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); - return fmha_fwd_args{q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() - : bias_buf.GetDeviceBuffer(), - randval_buf.GetDeviceBuffer(), - lse_acc_buf.GetDeviceBuffer(), - o_acc_buf.GetDeviceBuffer(), - lse_buf.GetDeviceBuffer(), - o_buf.GetDeviceBuffer(), - seqstart_q.GetDeviceBuffer(), - seqstart_k.GetDeviceBuffer(), - k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(), - shape_seqlen_q, - shape_seqlen_k, - batch, - max_seqlen_q, - hdim_q, - hdim_v, - nhead, - nhead_k, - num_splits, - scale_s, - scale_p, - scale_o, - stride_q, - stride_k, - stride_v, - bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) - : stride_bias, - stride_randval, - stride_o_acc, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_lse_acc, - nhead_stride_o_acc, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_lse, - batch_stride_lse_acc, - batch_stride_o_acc, - batch_stride_o, - split_stride_lse_acc, - split_stride_o_acc, - mask.left, - mask.right, - static_cast(mask.type), - p_drop, - s_randval, - {drop_seed, drop_offset}}; + return fmha_fwd_args{ + q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(), + randval_buf.GetDeviceBuffer(), + 1 < num_splits ? lse_acc_buf.GetDeviceBuffer() : nullptr, + 1 < num_splits ? o_acc_buf.GetDeviceBuffer() : nullptr, + lse_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(), + 0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr, + batch_stride_block_table, // only used if 'block_table_ptr' is not nullptr + page_block_size, // only used if 'block_table_ptr' is not nullptr + shape_seqlen_q, + shape_seqlen_k, + batch, + max_seqlen_q, + hdim_q, + hdim_v, + nhead, + nhead_k, + num_splits, // only used in splitkv kernel + scale_s, + scale_p, + scale_o, + stride_q, + stride_k, + stride_v, + bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias, + stride_randval, + stride_o_acc, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_lse_acc, + nhead_stride_o_acc, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_randval, + batch_stride_lse, + batch_stride_lse_acc, + batch_stride_o_acc, + batch_stride_o, + split_stride_lse_acc, + split_stride_o_acc, + mask.left, + mask.right, + static_cast(mask.type), + p_drop, + s_randval, + {drop_seed, drop_offset}}; }(); const float fwd_ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config); @@ -1018,8 +1075,20 @@ bool run(const ck_tile::ArgParser& arg_parser) } #endif - if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); - else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); + if (ENABLE_PAGED_KVCACHE && 0 < page_block_size) { + if(i_perm) { + k_host_ref.ForEach([&](auto& self, auto i) { + self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); + }); + } else { + k_host_ref.ForEach([&](auto& self, auto i) { + self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); + }); + } + } else { + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); + } #if CK_TILE_FMHA_FWD_APPENDKV_API // copy Knew to the end of K @@ -1058,16 +1127,40 @@ bool run(const ck_tile::ArgParser& arg_parser) }); } #endif - - if (is_v_rowmajor) { - // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); - // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] - else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); - } - else { - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); }); - else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); }); + if (ENABLE_PAGED_KVCACHE && 0 < page_block_size) { + if (is_v_rowmajor) { + if(i_perm) { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); + }); + } else { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]); + }); + } + } + else { + if(i_perm) { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); + }); + } else { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[1], i[0] / nr, i[2] % page_block_size); + }); + } + } + } else { + if (is_v_rowmajor) { + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); + } + else { + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); }); + } } #if CK_TILE_FMHA_FWD_APPENDKV_API diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 6c4edf3827..546b70a4b0 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -103,6 +103,9 @@ struct fmha_fwd_args const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; + ck_tile::index_t page_block_size; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -315,6 +318,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, args.scale_s, args.scale_p, args.stride_q, @@ -359,6 +365,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, args.scale_s, args.scale_p, args.stride_q, @@ -585,6 +594,51 @@ struct fmha_fwd_traits_ template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); +template +struct fmha_fwd_splitkv_traits_ : fmha_fwd_traits_ +{ + static constexpr bool kIsPagedKV = kIsPagedKV_; +}; + template void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp index 4af4e6959e..22bac8cca1 100644 --- a/example/ck_tile/01_fmha/utils.hpp +++ b/example/ck_tile/01_fmha/utils.hpp @@ -3,16 +3,17 @@ #pragma once +#include #include #include +#include #include #include #include +#include #include #include #include -#include -#include #include "ck_tile/core/container/span.hpp" @@ -209,3 +210,15 @@ auto randint(Int low, Int high, std::optional seed = std::nullopt) std::uniform_int_distribution dist(low, high); return dist(engine); } + +template +std::enable_if_t> iota_shuffle(RandomAccessIterator first, + RandomAccessIterator last, + Int value, + std::optional seed = std::nullopt) +{ + std::iota(first, last, value); + + std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}()); + std::shuffle(first, last, engine); +} diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 45ed185ada..6bf68d2c6d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -48,6 +48,7 @@ struct FmhaFwdSplitKVKernel static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; @@ -122,6 +123,9 @@ struct FmhaFwdSplitKVKernel // if this param is larger than 1, indicate MQA/GQA case ck_tile::index_t nhead_ratio_qk; ck_tile::index_t num_splits; + const void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; + ck_tile::index_t page_block_size; float scale_s; @@ -254,6 +258,9 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, + const void* block_table_ptr, + ck_tile::index_t batch_stride_block_table, + ck_tile::index_t page_block_size, float scale_s, float scale_p, ck_tile::index_t stride_q, @@ -299,6 +306,9 @@ struct FmhaFwdSplitKVKernel num_head_q, nhead_ratio_qk, num_splits, + block_table_ptr, + batch_stride_block_table, + page_block_size, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale_s * ck_tile::log2e_v<>), #else @@ -379,6 +389,9 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, + const void* block_table_ptr, + ck_tile::index_t batch_stride_block_table, + ck_tile::index_t page_block_size, float scale_s, float scale_p, ck_tile::index_t stride_q, @@ -419,6 +432,9 @@ struct FmhaFwdSplitKVKernel num_head_q, nhead_ratio_qk, num_splits, + block_table_ptr, + batch_stride_block_table, + page_block_size, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale_s * ck_tile::log2e_v<>), #else @@ -565,8 +581,24 @@ struct FmhaFwdSplitKVKernel else { batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + + if(true || kargs.block_table_ptr == nullptr) + { + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + } + else + { + const auto* block_table = reinterpret_cast(kargs.block_table_ptr) + + i_batch * kargs.batch_stride_block_table; + const auto i_block = + static_cast(block_table[i_n1 / kargs.page_block_size]); + + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + // batch_offset_k = i_block * kargs.batch_stride_k; + batch_offset_v = i_block * kargs.batch_stride_v; + } + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index a6d74b3885..ed14fa8bbe 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -51,6 +51,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = true; // always store LSE (acc) static constexpr bool kHasDropout = false; // ignore this flag + static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; // last dimension vector length used to create tensor view(and decide buffer_load vector length) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp index ae363a4978..da9a973bba 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp @@ -56,6 +56,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = true; // always store LSE (acc) static constexpr bool kHasDropout = false; // ignore this flag + static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; // last dimension vector length used to create tensor view(and decide buffer_load vector length) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 23b75f16ac..2ab21c56bd 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -85,6 +85,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem { + static constexpr bool kIsPagedKV = Traits::kIsPagedKV; static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; }; diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index e11b404b99..3f1fcf72de 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -42,8 +42,9 @@ template + bool kIsPagedKV_, + bool kHasUnevenSplits_, + index_t kBlockPerCu = -1 /* overwrite occupancy if not -1 */> struct TileFmhaFwdSplitKVTraits : TileFmhaTraits { + static constexpr bool kIsPagedKV = kIsPagedKV_; // determine if some split (length) is not divisible by tile size static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; };