Add paged-kv codegen logic for appendkv kernels

This commit is contained in:
PoYen, Chen
2024-08-07 04:19:45 +00:00
parent b98985262d
commit 15d0034a64
10 changed files with 131 additions and 109 deletions

View File

@@ -29,6 +29,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
{F_dpad},
{F_dvpad},
{F_rope},
{F_pagedkv},
{F_occupancy}>;
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
@@ -51,7 +52,7 @@ using fmha_kernel_{F_idx} =
fmha_pipeline_{F_idx}>;
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}>;
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
#include <iostream>
@@ -78,8 +79,9 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co
"""
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}>;
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
return fmha_fwd_appendkv_<trait_>(s, a);
}}
"""
@@ -100,11 +102,13 @@ class FmhaFwdAppendKVApiTrait:
dpad : str
dvpad : str
rope : str # key from ROPE_MAP
pagedkv : str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-'+\
f'{self.vlayout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}'
f'{self.vlayout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-'+\
f'{self.pagedkv}'
@property
def scheck(self) -> str:
@@ -136,6 +140,7 @@ class FmhaFwdAppendKVPipeline:
F_dpad : str #
F_dvpad : str #
F_rope : str # key from ROPE_MAP
F_pagedkv : str # t/f
@property
def name(self) -> str:
@@ -151,6 +156,7 @@ class FmhaFwdAppendKVPipeline:
n = f'v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_rope != 'no': n += f'_{self.F_rope}'
if self.F_pagedkv == 't': n += f'_pagedkv'
return n
class FmhaFwdAppendKVApiPool:
@@ -179,7 +185,7 @@ class FmhaFwdAppendKVApiPool:
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
@@ -227,6 +233,7 @@ class FmhaFwdAppendKVKernel:
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_rope = ROPE_MAP[self.F_pipeline.F_rope],
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy = self.F_tile.F_occupancy,
F_mode = MODE_MAP[self.F_mode])
@@ -254,7 +261,8 @@ class FmhaFwdAppendKVKernel:
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
rope=self.F_pipeline.F_rope)
rope=self.F_pipeline.F_rope,
pagedkv=self.F_pipeline.F_pagedkv)
# TODO: design a more practical way to do it
# these are the currently supported tile sizes per hdim
@@ -289,17 +297,18 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
# NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
# applying rotary embedding, so I just use 't' in inter/half pipelines
for vlayout in ['row', 'col']:
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 'f', 'f', 'no'))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no'))
for pagedkv in ["t", "f"]:
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 'f', 'f', 'no', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 't', 'f', 'inter'))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter'))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 't', 'f', 'inter', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 't', 'f', 'half'))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half'))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 't', 'f', 'half', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv))
elif dtype in ['fp8', 'bf8']:
# rope is not supported
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no'))
# rope/paged-kv is not supported
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
else:
assert False
return pipelines

View File

@@ -546,9 +546,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
(mode == mode_enum::batch ? seqlen_ks[0]
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
: seqstart_k_with_padding_host.back()));
#ifdef ENABLE_HOST_DEBUG_MSG
std::cerr << "[HOST] num_blocks: " << max_num_blocks << std::endl;
#endif
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KDataType> k_host(
@@ -798,48 +798,52 @@ bool run(const ck_tile::ArgParser& arg_parser)
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
}();
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
const ck_tile::index_t batch_stride_block_table = (max_num_blocks / batch);
return fmha_fwd_appendkv_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
knew_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
vnew_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
cache_seqlen_k_buf.GetDeviceBuffer(),
batch,
nhead,
nhead_k,
shape_seqlen_q,
max_seqlen_q,
shape_seqlen_k -
seqlen_knew /* kvcache seqlen_k for batch mode */,
seqlen_knew,
hdim_q,
hdim_v,
rotary_cos_buf.GetDeviceBuffer(),
rotary_sin_buf.GetDeviceBuffer(),
rotary_dim,
stride_q,
stride_k,
stride_knew,
stride_v,
stride_vnew,
nhead_stride_q,
nhead_stride_k,
nhead_stride_knew,
nhead_stride_v,
nhead_stride_vnew,
batch_stride_q,
batch_stride_k,
batch_stride_knew,
batch_stride_v,
batch_stride_vnew};
return fmha_fwd_appendkv_args{
q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
knew_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
vnew_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
cache_seqlen_k_buf.GetDeviceBuffer(),
batch,
nhead,
nhead_k,
shape_seqlen_q,
max_seqlen_q,
shape_seqlen_k - seqlen_knew /* kvcache seqlen_k for batch mode */,
seqlen_knew,
hdim_q,
hdim_v,
rotary_cos_buf.GetDeviceBuffer(),
rotary_sin_buf.GetDeviceBuffer(),
rotary_dim,
0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr,
batch_stride_block_table, // only used if 'block_table_ptr' is not nullptr
page_block_size, // only used if 'block_table_ptr' is not nullptr
stride_q,
stride_k,
stride_knew,
stride_v,
stride_vnew,
nhead_stride_q,
nhead_stride_k,
nhead_stride_knew,
nhead_stride_v,
nhead_stride_vnew,
batch_stride_q,
batch_stride_k,
batch_stride_knew,
batch_stride_v,
batch_stride_vnew};
}();
appendkv_ave_time = fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config);

View File

@@ -100,12 +100,15 @@ struct fmha_fwd_args
void* o_acc_ptr;
void* lse_ptr;
void* o_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
void* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
ck_tile::index_t page_block_size;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
@@ -179,6 +182,10 @@ struct fmha_fwd_appendkv_args
const void* rotary_sin_ptr;
ck_tile::index_t rotary_dim;
void* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
ck_tile::index_t page_block_size;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_knew;
@@ -688,7 +695,8 @@ template <ck_tile::index_t HDim_,
bool kPadSk_,
bool kPadD_,
bool kPadDv_,
ck_tile::RotaryEmbeddingEnum RotaryEnum_>
ck_tile::RotaryEmbeddingEnum RotaryEnum_,
bool kIsPagedKV_>
struct fmha_fwd_appendkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -704,6 +712,7 @@ struct fmha_fwd_appendkv_traits_
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr auto RotaryEnum = RotaryEnum_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>

View File

@@ -32,7 +32,6 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"

View File

@@ -32,6 +32,7 @@ struct FmhaFwdAppendKVKernel
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != RotaryEmbeddingEnum::NONE;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
// clang-format off
template <typename T> struct t2s;
@@ -62,7 +63,8 @@ struct FmhaFwdAppendKVKernel
"b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" +
_TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
+ (!kApplyRoPE ? _SS_("") : (_SS_("_") + RotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name));
+ (!kApplyRoPE ? _SS_("") : (_SS_("_") + RotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name))
+ (kIsPagedKV ? "_pagedkv" : "" );
#undef _SS_
#undef _TS_
// clang-format on
@@ -95,7 +97,11 @@ struct FmhaFwdAppendKVKernel
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
// if this param is larger than 1, it indicates the MQA/GQA case
ck_tile::index_t nhead_ratio_qk;
/*
const void* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
ck_tile::index_t page_block_size;
*/
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_knew;

View File

@@ -124,6 +124,7 @@ struct FmhaFwdSplitKVKernel
// if this param is larger than 1, it indicates the MQA/GQA case
ck_tile::index_t nhead_ratio_qk;
ck_tile::index_t num_splits;
const void* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
ck_tile::index_t page_block_size;

View File

@@ -33,6 +33,7 @@ struct BlockFmhaFwdAppendKVPipeline
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto RotaryEnum = Problem::RotaryEnum;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
// last dimension vector length used to create tensor view (and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should be able to overwrite this

View File

@@ -1,48 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Compile-time bundle of the types and constants that parameterize the FMHA forward
// append-KV pipeline. It holds no data: every member is a type alias or a
// static constexpr value taken from the template arguments or from Traits_.
//
// Template parameters:
//   QDataType_/KDataType_/VDataType_ - element types; cv/ref qualifiers are stripped below
//   kM0_/kN0_/kK0_/kN1_              - re-exported unchanged as kM0/kN0/kK0/kN1
//                                      (presumably block tile extents - TODO confirm)
//   IsVLayoutRowMajor_               - true selects RowMajor, false selects ColumnMajor for VLayout
//   kIsGroupMode_                    - re-exported unchanged as kIsGroupMode
//   Traits_                          - must provide kPadSeqLenQ, kPadSeqLenK, kPadHeadDimQ,
//                                      kPadHeadDimV, RotaryEnum and kBlockPerCu
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
index_t kM0_,
index_t kN0_,
index_t kK0_,
index_t kN1_,
bool IsVLayoutRowMajor_,
bool kIsGroupMode_,
typename Traits_>
struct BlockFmhaFwdAppendKVPipelineProblem
{
// normalized (cv/ref-free) aliases of the template type arguments
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
using Traits = remove_cvref_t<Traits_>;
// hard-coded, not configurable through the template arguments
// NOTE(review): presumably the number of threads per workgroup - confirm against the kernel launch
static constexpr index_t kBlockSize = 256;
static constexpr bool kIsGroupMode = kIsGroupMode_;
// size parameters forwarded verbatim from the template arguments
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;
static constexpr index_t kK0 = kK0_;
static constexpr index_t kN1 = kN1_;
// layout tag type for V, chosen at compile time from the boolean flag
using VLayout = std::conditional_t<IsVLayoutRowMajor_,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto RotaryEnum = Traits::RotaryEnum;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile

View File

@@ -120,4 +120,43 @@ struct BlockFmhaSplitKVCombinePipelineProblem
static constexpr index_t kMaxSplits = Traits::kMaxSplits;
};
// Compile-time bundle of the types and constants that parameterize the FMHA forward
// append-KV pipeline. It holds no data: every member is a type alias or a
// static constexpr value taken from the template arguments or from Traits_.
//
// Template parameters:
//   QDataType_/KDataType_/VDataType_ - element types; cv/ref qualifiers are stripped below
//   kM0_/kN0_/kK0_/kN1_              - re-exported unchanged as kM0/kN0/kK0/kN1
//                                      (presumably block tile extents - TODO confirm)
//   IsVLayoutRowMajor_               - true selects RowMajor, false selects ColumnMajor for VLayout
//   kIsGroupMode_                    - re-exported unchanged as kIsGroupMode
//   Traits_                          - must provide kPadSeqLenQ, kPadSeqLenK, kPadHeadDimQ,
//                                      kPadHeadDimV, RotaryEnum, kIsPagedKV and kBlockPerCu
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
index_t kM0_,
index_t kN0_,
index_t kK0_,
index_t kN1_,
bool IsVLayoutRowMajor_,
bool kIsGroupMode_,
typename Traits_>
struct BlockFmhaFwdAppendKVPipelineProblem
{
// normalized (cv/ref-free) aliases of the template type arguments
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
using Traits = remove_cvref_t<Traits_>;
// hard-coded, not configurable through the template arguments
// NOTE(review): presumably the number of threads per workgroup - confirm against the kernel launch
static constexpr index_t kBlockSize = 256;
static constexpr bool kIsGroupMode = kIsGroupMode_;
// size parameters forwarded verbatim from the template arguments
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;
static constexpr index_t kK0 = kK0_;
static constexpr index_t kN1 = kN1_;
// layout tag type for V, chosen at compile time from the boolean flag
using VLayout = std::conditional_t<IsVLayoutRowMajor_,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto RotaryEnum = Traits::RotaryEnum;
// paged-KV flag forwarded from Traits_; other code in this change reads it as Problem::kIsPagedKV
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile

View File

@@ -84,6 +84,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimQ_ /* padding for hdim_q */,
bool kPadHeadDimV_ /* padding for hdim_v */,
RotaryEmbeddingEnum RotaryEnum_, /* how we apply the rotary embedding */
bool kIsPagedKV_, /* whether use paged-kvcache */
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdAppendKVTraits
{
@@ -92,6 +93,7 @@ struct TileFmhaFwdAppendKVTraits
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr auto RotaryEnum = RotaryEnum_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};