mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Add SplitKV combine kernel codegen logics
This commit is contained in:
@@ -334,7 +334,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <typename FmhaFwdSplitKVCombineKernel>
|
||||
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
|
||||
{
|
||||
@@ -343,31 +342,18 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaFwdSplitKVCombineKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaFwdSplitKVCombineKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
return FmhaFwdSplitKVCombineKernel::MakeKargs(args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.batch,
|
||||
args.nhead,
|
||||
args.max_seqlen_q,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_splits,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o);
|
||||
}
|
||||
@@ -395,7 +381,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
|
||||
dim3 grids = FmhaFwdSplitKVCombineKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
#endif
|
||||
|
||||
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
@@ -451,6 +436,12 @@ float fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
template <typename Traits_>
|
||||
std::string fmha_fwd_splitkv_get_name_();
|
||||
|
||||
// Launches the split-KV combine kernel selected by Traits_ once on the given stream
// config. Only declared here; the generated sources provide one explicit
// specialization per trait (see the combine kernel body template in the generator).
template <typename Traits_>
|
||||
float fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
// Returns the name of the split-KV combine kernel selected by Traits_. Only declared
// here; the generated sources specialize it to forward to the kernel's GetName().
template <typename Traits_>
|
||||
std::string fmha_fwd_splitkv_combine_get_name_();
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct fmha_fwd_traits
|
||||
{
|
||||
|
||||
@@ -249,6 +249,86 @@ std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>;
|
||||
using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
|
||||
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
fmha_block_warps_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_bias},
|
||||
false,
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_occupancy},
|
||||
16>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdSplitKVCombinePipeline<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdSplitKVCombineKernel<ck_tile::FmhaFwdSplitKVCombineTilePartitioner<fmha_shape_{F_idx}>,
|
||||
fmha_pipeline_{F_idx},
|
||||
fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
|
||||
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template<>
|
||||
std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
return k_::GetName();
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
|
||||
@@ -288,19 +368,18 @@ FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp"
|
||||
FMHA_FWD_SPLITKV_API="""
|
||||
#include <iostream>
|
||||
|
||||
template<typename fmha_fwd_splitkv_trait_
|
||||
// , typename fmha_fwd_splitkv_combine_trait_
|
||||
>
|
||||
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
|
||||
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_trait_>()
|
||||
// << ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_trait_>()
|
||||
std::cout
|
||||
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
|
||||
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
|
||||
<< std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_trait_>(s_, a); }}
|
||||
// , [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<dq_dk_dv_trait_>(s_, a); }}
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
|
||||
@@ -313,8 +392,8 @@ float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream
|
||||
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
return fmha_fwd_splitkv_<trait_>(s, a);
|
||||
using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
return fmha_fwd_splitkv_<traits_, traits_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
@@ -713,6 +792,85 @@ class FmhaFwdSplitKVKernel:
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
|
||||
@dataclass
class FmhaFwdSplitKVCombineKernel:
    """Describes one generated split-KV combine kernel instance.

    Holds the values substituted into FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY and derives
    the instance name, the output file name and the matching FmhaFwdApiTrait.
    """
    direction       : str
    F_idx           : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim          : int  # hdim
    F_dtype         : str  # data type
    F_mode          : str  # value from MODE_MAP
    F_tile          : FmhaFwdTileSize
    F_pipeline      : FmhaFwdPipeline
    mask_impl       : str

    @property
    def template(self) -> str:
        """Text of the generated source file: common header followed by the formatted combine body."""
        return FMHA_FWD_KERNEL_HEADER + \
            FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(
                F_idx           = self.F_idx,
                F_hdim          = self.F_hdim,
                F_dtype         = DTYPE_MAP[self.F_dtype],
                F_bm0           = self.F_tile.F_bm0,
                F_bn0           = self.F_tile.F_bn0,
                F_bk0           = self.F_tile.F_bk0,
                F_bn1           = self.F_tile.F_bn1,
                F_bk1           = self.F_tile.F_bk1,
                F_bk0blen       = self.F_tile.F_bk0blen,
                F_rm            = self.F_tile.F_rm,
                F_rn            = self.F_tile.F_rn,
                F_rk            = self.F_tile.F_rk,
                F_wm            = self.F_tile.F_wm,
                F_wn            = self.F_tile.F_wn,
                F_wk            = self.F_tile.F_wk,
                F_vlayout       = LAYOUT_MAP[self.F_pipeline.F_vlayout],
                F_spad          = BOOL_MAP[self.F_pipeline.F_spad],
                F_skpad         = BOOL_MAP[self.F_pipeline.F_skpad],
                F_dpad          = BOOL_MAP[self.F_pipeline.F_dpad],
                F_dvpad         = BOOL_MAP[self.F_pipeline.F_dvpad],
                F_bias          = BIAS_MAP[self.F_pipeline.F_bias],
                F_lse           = BOOL_MAP[self.F_pipeline.F_lse],
                F_dropout       = BOOL_MAP[self.F_pipeline.F_dropout],
                F_squant        = BOOL_MAP[self.F_pipeline.F_squant],
                F_occupancy     = self.F_tile.F_occupancy,
                F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
                F_mask          = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
                F_mode          = MODE_MAP[self.F_mode],
                # NOTE(review): the combine body has no {F_pipeline} placeholder, so this
                # value is ignored by format(); the lookup is kept so that an unknown
                # pipeline tag still raises KeyError exactly as before.
                F_pipeline      = FMHA_FWD_SPLITKV_PIPELINE_MAP[self.F_pipeline.tag])

    @property
    def name(self) -> str:
        """Instance name built from direction, hdim, dtype, mode, tile name and pipeline name."""
        # TODO: we don't encode idx here
        return f"fmha_{self.direction}_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
            self.F_tile.name + '_' + self.F_pipeline.name

    @property
    def filename(self) -> str:
        """File name of the generated source: the instance name with a .cpp suffix."""
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        """Builds the FmhaFwdApiTrait carrying this instance's tile and pipeline settings."""
        return FmhaFwdApiTrait(
                pipeline_tag=self.F_pipeline.tag,
                hdim=str(self.F_hdim),
                dtype=self.F_dtype,
                mode=self.F_mode,
                bm0=self.F_tile.F_bm0,
                bn0=self.F_tile.F_bn0,
                bk0=self.F_tile.F_bk0,
                bn1=self.F_tile.F_bn1,
                bk1=self.F_tile.F_bk1,
                bk0blen=self.F_tile.F_bk0blen,
                vlayout=self.F_pipeline.F_vlayout,
                mask=self.F_pipeline.F_mask,
                bias=self.F_pipeline.F_bias,
                lse=self.F_pipeline.F_lse,
                dropout=self.F_pipeline.F_dropout,
                squant=self.F_pipeline.F_squant,
                spad=self.F_pipeline.F_spad,
                skpad=self.F_pipeline.F_skpad,
                dpad=self.F_pipeline.F_dpad,
                dvpad=self.F_pipeline.F_dvpad)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]:
|
||||
@@ -812,6 +970,74 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, is_splitkv
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> List[FmhaFwdSplitKVCombineKernel]:
    """Enumerates the split-KV combine kernel instances to generate.

    kernel_filter: optional fnmatch pattern; when given, only kernels whose name matches
                   are returned.
    receipt:       when equal to 1, extra 'qr' pipelines are added for fp16/bf16 head
                   dims other than 256.
    mask_impl:     key passed to get_mask_map() to pick the set of mask variants.

    Returns the kernel descriptors in generation order.
    """
    Pipeline = FmhaFwdSplitKVPipeline
    Kernel = FmhaFwdSplitKVCombineKernel

    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    #       support this in future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
        # this function will populate a list of possible pipelines
        # TODO: the order of List matters! entries later in this list will also be checked later
        # TODO: currently for qr pipeline, let 't' padding to appear later!!
        # TODO: how to design this more generic?
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
                if hdim == 256:
                    pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
                    pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))

                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                else:
                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    if receipt == 1:
                        pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitrary hdim
                        pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitrary hdim
        elif dtype in ['fp8', 'bf8']:
            # no need lse/dropout kernels
            for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
                pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
        else:
            assert False, f"unsupported dtype: {dtype}"
        return pipelines

    gen = list()

    for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()):
        d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype)
        if d is None:
            # no tile table for this dtype, nothing to generate
            continue
        for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
                if mode == "group":
                    if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
                        # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                        continue
                k = Kernel(direction=direction,
                           F_idx=0,
                           F_hdim=hdim,
                           F_dtype=dtype,
                           F_mode=mode,
                           F_tile=tile,
                           F_pipeline=pipeline,
                           mask_impl=mask_impl)
                if kernel_filter is not None and not fnmatch.fnmatch(k.name, kernel_filter):
                    continue
                gen.append(k)

    return gen
|
||||
|
||||
BWD_DQDKDV_PIPELINE_MAP = {
|
||||
"ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR",
|
||||
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS",
|
||||
@@ -1386,7 +1612,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
|
||||
|
||||
return gen
|
||||
|
||||
def write_single_fwd_kernel(kernel: Union[FmhaFwdKernel, FmhaFwdSplitKVKernel], autogen_dir: Path) -> None:
|
||||
def write_single_fwd_kernel(kernel: Union[FmhaFwdKernel, FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
@@ -1416,6 +1642,9 @@ def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optio
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
# write split-kv blobs
|
||||
kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, True)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
@@ -1441,6 +1670,9 @@ def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Opt
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
|
||||
|
||||
# get split-kv blobs
|
||||
kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, True)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
|
||||
@@ -26,7 +28,6 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
|
||||
@@ -0,0 +1,439 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
struct FmhaFwdSplitKVCombineKernel
|
||||
{
|
||||
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
static_assert(kBlockPerCu > 0);
|
||||
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
|
||||
|
||||
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
|
||||
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
|
||||
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
|
||||
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
// clang-format on
|
||||
|
||||
// Builds the instance name of this kernel configuration. The name is assembled from the
// head-dim, Q data type, batch/group mode, block tile sizes, Gemm0 block-warp layout,
// Gemm0 warp tile, optional occupancy, V layout, padding flags, bias kind, mask name and
// the dropout / fp8 static-quant flags.
__host__ static std::string GetName()
{
    // sync with generate.py
    // clang-format off
    using bfs = typename FmhaPipeline::BlockFmhaShape;
    using gbr = typename bfs::Gemm0BlockWarps;
    using gwt = typename bfs::Gemm0WarpTile;
    using ss  = std::string;

    // Local helper instead of function-like macros: identifiers that start with an
    // underscore followed by an uppercase letter are reserved for the implementation.
    const auto ts = [](auto value) { return std::to_string(value); };

    // padding tag: "p" followed by the padded dimensions, or empty when nothing is padded
    const std::string pn = [] {
        std::string n;
        if (kPadSeqLenQ) n += "s";
        if (kPadHeadDimV) n += "dv";
        return n.empty() ? n : std::string("p") + n; }();

    return
        ss("fmha_fwd_splitkv_combine_d") + ts(bfs::kK0BlockLength) + "_" + ss(t2s<QDataType>::name) +
        "_" + (kIsGroupMode ? "group" : "batch") + "_b" +
        ts(bfs::kM0) + "x" + ts(bfs::kN0) + "x" + ts(bfs::kK0) + "x" +
        ts(bfs::kN1) + "x" + ts(bfs::kK1) + "x" + ts(bfs::kK0BlockLength) + "_" +
        "r" + ts(gbr::at(ck_tile::number<0>{})) + "x" + ts(gbr::at(ck_tile::number<1>{})) + "x" + ts(gbr::at(ck_tile::number<2>{})) + "_" +
        "w" + ts(gwt::at(ck_tile::number<0>{})) + "x" + ts(gwt::at(ck_tile::number<1>{})) + "x" + ts(gwt::at(ck_tile::number<2>{})) + "_" +
        // the occupancy tag is emitted only when the problem requested an explicit value
        (kBlockPerCuInput == -1 ? "" : ("o" + ts(kBlockPerCu) + "_")) +
        "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
        (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? ss("") : (ss("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
        (kHasMask ? "_" + ss(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
    // clang-format on
}
|
||||
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
|
||||
// arg
|
||||
struct EmptyKargs
|
||||
{
|
||||
};
|
||||
|
||||
// kargs use aggregate initializer, so no constructor will provided
|
||||
// use inheritance to minimize karg size
|
||||
// user need to use MakeKargs() function to create kargs.
|
||||
struct CommonKargs
|
||||
{
|
||||
const void* lse_acc_ptr;
|
||||
const void* o_acc_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t nhead;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t num_splits;
|
||||
|
||||
ck_tile::index_t row_stride_o;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
};
|
||||
|
||||
struct CommonLSEKargs
|
||||
{
|
||||
void* lse_ptr = nullptr;
|
||||
ck_tile::index_t nhead_stride_lse = 0;
|
||||
};
|
||||
|
||||
// Output scale; part of the kargs only when kDoFp8StaticQuant is set.
struct Fp8StaticQuantKargs
{
    float scale_o;
};
|
||||
|
||||
struct BatchModeLSEKargs : CommonLSEKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_lse = 0;
|
||||
};
|
||||
|
||||
struct BatchModeKargs
|
||||
: CommonKargs,
|
||||
std::conditional_t<kStoreLSE, BatchModeLSEKargs, EmptyKargs<0>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_o;
|
||||
};
|
||||
|
||||
struct GroupModeKargs
|
||||
: CommonKargs,
|
||||
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<3>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
|
||||
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* lse_acc_ptr,
|
||||
const void* o_acc_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits,
|
||||
float scale_o,
|
||||
ck_tile::index_t row_stride_o,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o)
|
||||
{
|
||||
Kargs kargs{{lse_acc_ptr,
|
||||
o_acc_ptr,
|
||||
o_ptr,
|
||||
batch,
|
||||
nhead,
|
||||
max_seqlen_q,
|
||||
seqlen_q,
|
||||
hdim_v,
|
||||
num_splits,
|
||||
row_stride_o,
|
||||
nhead_stride_o}, // args for common karg
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
batch_stride_o};
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
kargs.batch_stride_lse = batch_stride_lse;
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
kargs.scale_o = scale_o;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* lse_acc_ptr,
|
||||
const void* o_acc_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
const void* seqstart_q_ptr,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits,
|
||||
float scale_o,
|
||||
ck_tile::index_t row_stride_o,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o)
|
||||
{
|
||||
Kargs kargs{{lse_acc_ptr,
|
||||
o_acc_ptr,
|
||||
o_ptr,
|
||||
batch,
|
||||
nhead,
|
||||
max_seqlen_q,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
hdim_v,
|
||||
num_splits,
|
||||
row_stride_o,
|
||||
nhead_stride_o}, // args for common karg
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
kargs.scale_o = scale_o;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
// Launch grid for the given problem size; the computation is delegated unchanged to
// the tile partitioner.
__host__ static constexpr auto GridSize(index_t batch_size_, index_t nhead_, index_t seqlen_q_)
{
    return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
}
|
||||
|
||||
// Thread-block dimensions: kBlockSize in x, with dim3's defaults for the other axes.
__host__ static constexpr auto BlockSize()
{
    return dim3(kBlockSize);
}
|
||||
|
||||
// Size of the shared buffer declared in operator(): the larger of what the pipeline
// and the epilogue each report.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    constexpr auto pipeline_smem = FmhaPipeline::GetSmemSize();
    constexpr auto epilogue_smem = EpiloguePipeline::GetSmemSize();
    return ck_tile::max(pipeline_smem, epilogue_smem);
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
|
||||
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
batch_offset_lse_acc = query_start;
|
||||
batch_offset_o_acc = query_start * kargs.hdim_v;
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = query_start;
|
||||
}
|
||||
batch_offset_o = query_start * kargs.row_stride_o;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_lse_acc =
|
||||
static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
|
||||
batch_offset_o_acc = static_cast<long_index_t>(i_batch) *
|
||||
(kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v);
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
const LSEDataType* lse_acc_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * (kargs.max_seqlen_q) +
|
||||
batch_offset_lse_acc;
|
||||
const OaccDataType* o_acc_ptr =
|
||||
reinterpret_cast<const OaccDataType*>(kargs.o_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * (kargs.max_seqlen_q * kargs.hdim_v) +
|
||||
batch_offset_o_acc;
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
|
||||
batch_offset_o;
|
||||
|
||||
// LSEacc/Oacc DRAM and DRAM windows
|
||||
const auto lse_acc_dram = [&]() {
|
||||
const auto lse_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
lse_acc_ptr,
|
||||
make_tuple(kargs.num_splits, kargs.seqlen_q),
|
||||
make_tuple(kargs.batch * kargs.nhead * kargs.max_seqlen_q, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
lse_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{}),
|
||||
sequence<true, kPadSeqLenQ>{});
|
||||
}();
|
||||
|
||||
auto o_acc_dram = [&]() {
|
||||
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_acc_ptr,
|
||||
make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v),
|
||||
make_tuple(
|
||||
kargs.batch * kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v, kargs.hdim_v, 1),
|
||||
number<FmhaPipeline::kAlignmentO>{},
|
||||
number<1>{});
|
||||
|
||||
auto o_acc_dram_view = pad_tensor_view(
|
||||
o_acc_dram_naive,
|
||||
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
|
||||
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
|
||||
|
||||
const index_t new_seqlen_q =
|
||||
integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) * FmhaPipeline::kM0;
|
||||
const index_t new_hdim_v =
|
||||
integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1) * FmhaPipeline::kN1;
|
||||
|
||||
return transform_tensor_view(
|
||||
o_acc_dram_view,
|
||||
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, new_seqlen_q)),
|
||||
make_pass_through_transform(new_hdim_v)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}();
|
||||
|
||||
auto lse_acc_dram_window = make_tile_window(
|
||||
lse_acc_dram,
|
||||
[&]() {
|
||||
return make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{});
|
||||
}(),
|
||||
{0, i_m0});
|
||||
|
||||
auto o_acc_dram_window = make_tile_window(
|
||||
o_acc_dram,
|
||||
[&]() {
|
||||
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{});
|
||||
}(),
|
||||
{i_m0, 0});
|
||||
|
||||
// LSE DRAM window
|
||||
auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
|
||||
constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
LSEDataType* lse_ptr =
|
||||
reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
|
||||
|
||||
const auto lse_dram = [&]() {
|
||||
const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
lse_ptr,
|
||||
make_tuple(kargs.seqlen_q),
|
||||
make_tuple(1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(lse_dram_window_lengths);
|
||||
}
|
||||
}();
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
return FmhaPipeline{}(
|
||||
lse_acc_dram_window,
|
||||
o_acc_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
|
||||
smem_ptr,
|
||||
kargs.num_splits,
|
||||
kargs.max_seqlen_q);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(lse_acc_dram_window,
|
||||
o_acc_dram_window,
|
||||
lse_dram_window,
|
||||
smem_ptr,
|
||||
kargs.num_splits,
|
||||
kargs.max_seqlen_q);
|
||||
}
|
||||
}();
|
||||
|
||||
// O DRAM and DRAM window
|
||||
auto o_dram = [&]() {
|
||||
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.row_stride_o, 1),
|
||||
number<FmhaPipeline::kAlignmentO>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
o_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimV>{});
|
||||
}();
|
||||
|
||||
auto o_dram_window =
|
||||
make_tile_window(o_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
|
||||
{i_m0, 0});
|
||||
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockFmhaShape_>
|
||||
struct FmhaFwdSplitKVCombineTilePartitioner
|
||||
{
|
||||
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
|
||||
|
||||
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN0;
|
||||
// constexpr static ck_tile::index_t kBlockM = kN1 % 128 == 0 ? 4 : (kN1 % 64 == 0 ? 8 : 16);
|
||||
|
||||
__host__ static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t max_seqlen_q)
|
||||
{
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0), nhead, batch_size);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t /*hdim_v*/)
|
||||
{
|
||||
const index_t i_tile_m = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,433 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
||||
struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr index_t kMaxSplits = Problem::kMaxSplits;
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
|
||||
// Blocks-per-CU value for this pipeline.
// A value other than -1 in Problem::kBlockPerCu overrides the table below; otherwise the
// value is chosen from kK0BlockLength. No branch handles kK0BlockLength > 256, so such a
// configuration does not yield a value (the lambda then has no return statement).
static constexpr index_t kBlockPerCu = []() {
    if constexpr(Problem::kBlockPerCu != -1)
    {
        return Problem::kBlockPerCu;
    }
    else if constexpr(kK0BlockLength <= 32)
    {
        return 2;
    }
    else if constexpr(kK0BlockLength <= 64)
    {
        return 3;
    }
    else if constexpr(kK0BlockLength <= 128)
    {
        return 2;
    }
    else if constexpr(kK0BlockLength <= 256)
    {
        return 1;
    }
}();
|
||||
|
||||
/// Returns the number of bytes of shared memory this pipeline indexes: one LSEDataType
/// element per (row, split) pair, i.e. kM0 * kMaxSplits elements.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    /// TODO: add padding to avoid bank conflict
    constexpr auto num_lse_elements = kM0 * kMaxSplits;
    return num_lse_elements * sizeof(LSEDataType);
}
|
||||
|
||||
#define MARKER(msg) \
|
||||
__builtin_amdgcn_sched_barrier(0); \
|
||||
asm volatile("; [POYENC] " msg ::); \
|
||||
__builtin_amdgcn_sched_barrier(0)
|
||||
|
||||
template <typename LSEaccDramBlockWindowTmp,
|
||||
typename OaccDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename LSEElementFunction,
|
||||
typename OaccElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp,
|
||||
const OaccDramBlockWindowTmp& o_acc_dram_block_window_tmp,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp,
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const OaccElementFunction& o_acc_element_func,
|
||||
void* smem_ptr,
|
||||
index_t num_splits,
|
||||
index_t max_seqlen_q) const
|
||||
{
|
||||
LSEDataType* lse_acc_lds_ptr =
|
||||
static_cast<LSEDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
|
||||
auto lse_acc_dist = Policy::template MakeLSEaccDramTileDistribution<Problem>();
|
||||
auto lse_acc_dram_window =
|
||||
make_tile_window(lse_acc_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_acc_dram_block_window_tmp.get_window_lengths(),
|
||||
lse_acc_dram_block_window_tmp.get_window_origin(),
|
||||
lse_acc_dist);
|
||||
|
||||
auto lse_acc = load_tile(lse_acc_dram_window); // [kMaxSplits, kM0]
|
||||
|
||||
#if defined(ENABLE_DEBUG_STMTS)
|
||||
#define DEBUG_STMTS if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID)
|
||||
#else
|
||||
#define DEBUG_STMTS if(false)
|
||||
#endif
|
||||
// copy lse_acc to LDS
|
||||
{
|
||||
using DataType = LSEDataType;
|
||||
using StaticTileDistribution = decltype(lse_acc_dist);
|
||||
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<DataType,
|
||||
StaticTileDistribution>::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
StaticTileDistribution{}, distributed_indices);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
|
||||
lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices);
|
||||
});
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
auto lse_accum_dist = Policy::template MakeLSEaccTDramTileDistribution<Problem>();
|
||||
auto lse_accum = make_static_distributed_tensor<LSEDataType>(lse_accum_dist);
|
||||
|
||||
// copy LDS to lse_accum
|
||||
{
|
||||
using DataType = LSEDataType;
|
||||
using StaticTileDistribution = decltype(lse_accum_dist);
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<DataType,
|
||||
StaticTileDistribution>::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
StaticTileDistribution{}, distributed_indices);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
|
||||
if(col < num_splits)
|
||||
{
|
||||
lse_accum(distributed_indices) = lse_acc_lds_ptr[col + row * kMaxSplits];
|
||||
}
|
||||
else
|
||||
{
|
||||
lse_accum(distributed_indices) = -numeric<LSEDataType>::infinity();
|
||||
}
|
||||
|
||||
DEBUG_STMTS
|
||||
{
|
||||
printf("[POYENC][DEVICE] lse_accum[%2d,%2d]: %11.7f\n",
|
||||
row,
|
||||
col,
|
||||
lse_accum(distributed_indices));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// calculate row_max of lse_accum
|
||||
const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
auto lse_max = block_tile_reduce<LSEDataType>(
|
||||
lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity());
|
||||
block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});
|
||||
|
||||
#if defined(PRINT_LSE_MAX)
|
||||
DEBUG_STMTS
|
||||
{
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<LSEDataType, decltype(lse_max.get_tile_distribution())>::
|
||||
get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_max.get_tile_distribution(), distributed_indices);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
printf(
|
||||
"[POYENC][DEVICE] lse_max[%2d]: %11.7f\n", row, lse_max(distributed_indices));
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
static const auto get_validated_m = [](LSEDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_m;
|
||||
}
|
||||
};
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<LSEDataType>(
|
||||
lse_accum.get_tile_distribution()); // Pcompute{j}
|
||||
clear_tile(p_compute);
|
||||
{
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
p_compute.get_tile_distribution(), i_j_idx);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
|
||||
#if 0
|
||||
// from dist tensor
|
||||
p_compute(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx)));
|
||||
#else
|
||||
if (col < num_splits) {
|
||||
// from shared memory
|
||||
p_compute(i_j_idx) = ck_tile::exp(lse_acc_lds_ptr[col + row * kMaxSplits] -
|
||||
get_validated_m(lse_max(i_idx)));
|
||||
}
|
||||
#endif
|
||||
#if 0
|
||||
DEBUG_STMTS
|
||||
{
|
||||
printf("[POYENC][DEVICE] p_compute[%2d,%2d]: %11.7f\n",
|
||||
row,
|
||||
col,
|
||||
p_compute(i_j_idx));
|
||||
}
|
||||
#endif
|
||||
});
|
||||
});
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto lse_sum = block_tile_reduce<LSEDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, type_convert<LSEDataType>(0));
|
||||
block_tile_reduce_sync(lse_sum, f_sum, bool_constant<false>{});
|
||||
|
||||
#if defined(PRINT_LSE_SUM)
|
||||
DEBUG_STMTS
|
||||
{
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<LSEDataType, decltype(lse_sum.get_tile_distribution())>::
|
||||
get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_sum.get_tile_distribution(), distributed_indices);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
printf(
|
||||
"[POYENC][DEVICE] lse_sum[%2d]: %11.7f\n", row, lse_sum(distributed_indices));
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
decltype(lse_max) lse_logsum;
|
||||
{
|
||||
constexpr auto out_spans = static_distributed_tensor<
|
||||
LSEDataType,
|
||||
decltype(lse_logsum.get_tile_distribution())>::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0);
|
||||
|
||||
if(lse_sum(distributed_indices) == 0.f ||
|
||||
lse_sum(distributed_indices) != lse_sum(distributed_indices))
|
||||
{
|
||||
lse_logsum(distributed_indices) = numeric<LSEDataType>::infinity();
|
||||
}
|
||||
else
|
||||
{
|
||||
lse_logsum(distributed_indices) =
|
||||
ck_tile::log(lse_sum(distributed_indices)) + lse_max(distributed_indices);
|
||||
}
|
||||
|
||||
#if defined(PRINT_LSE_LOGSUM)
|
||||
DEBUG_STMTS
|
||||
{
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_logsum.get_tile_distribution(), distributed_indices);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
printf("[POYENC][DEVICE] lse_logsum[%d]: %11.7f\n",
|
||||
row,
|
||||
lse_logsum(distributed_indices));
|
||||
}
|
||||
#endif
|
||||
});
|
||||
}
|
||||
|
||||
#if defined(PRINT_LSE_ACCUM)
|
||||
DEBUG_STMTS
|
||||
{
|
||||
for(index_t row = 0; row < kM0; ++row)
|
||||
{
|
||||
printf("[POYENC][DEVICE] lse_accum[%d] = ", row);
|
||||
for(index_t col = 0; col < num_splits; ++col)
|
||||
{
|
||||
printf("%11.7f", lse_acc_lds_ptr[col + row * kMaxSplits]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// write lse scales into LDS
|
||||
{
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<LSEDataType, decltype(lse_sum.get_tile_distribution())>::
|
||||
get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_sum.get_tile_distribution(), distributed_indices);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
for(index_t col = 0; col < num_splits; ++col)
|
||||
{
|
||||
lse_acc_lds_ptr[col + row * kMaxSplits] = ck_tile::exp(
|
||||
lse_acc_lds_ptr[col + row * kMaxSplits] - lse_logsum(distributed_indices));
|
||||
}
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
#if defined(PRINT_LSE_SCALE)
|
||||
DEBUG_STMTS
|
||||
{
|
||||
for(index_t row = 0; row < 32; ++row)
|
||||
{
|
||||
printf("[POYENC][DEVICE] lse_scale[%2d] = ", row);
|
||||
for(index_t col = 0; col < num_splits; ++col)
|
||||
{
|
||||
printf("%11.7f", lse_acc_lds_ptr[col + row * kMaxSplits]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
static_assert(kBlockSize == 256);
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
|
||||
}
|
||||
|
||||
auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
|
||||
auto o_acc_dram_window =
|
||||
make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
o_acc_dram_block_window_tmp.get_window_lengths(),
|
||||
o_acc_dram_block_window_tmp.get_window_origin(),
|
||||
o_acc_dist);
|
||||
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist); // Pcompute{j}
|
||||
clear_tile(o_acc);
|
||||
|
||||
// [POYENC] added
|
||||
for(index_t i_split = 0; i_split < num_splits; ++i_split)
|
||||
{
|
||||
auto o_tile = load_tile(o_acc_dram_window);
|
||||
{
|
||||
using DataType = OaccDataType;
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<DataType,
|
||||
decltype(o_acc_dist)>::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0, idx1);
|
||||
const auto x_indices =
|
||||
get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits];
|
||||
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
move_tile_window(o_acc_dram_window, {max_seqlen_q, 0});
|
||||
}
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename LSEaccDramBlockWindow,
|
||||
typename OaccDramBlockWindow,
|
||||
typename LSEDramBlockWindow>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window,
|
||||
const OaccDramBlockWindow& o_acc_dram_block_window,
|
||||
LSEDramBlockWindow& lse_dram_block_window,
|
||||
void* smem_ptr,
|
||||
index_t num_splits,
|
||||
index_t max_seqlen_q) const
|
||||
{
|
||||
return operator()(lse_acc_dram_block_window,
|
||||
o_acc_dram_block_window,
|
||||
lse_dram_block_window,
|
||||
identity{},
|
||||
identity{},
|
||||
smem_ptr,
|
||||
num_splits,
|
||||
max_seqlen_q);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,18 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy =
|
||||
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
||||
/* AsyncCopyK = */ false,
|
||||
/* AsyncCopyV = */ false,
|
||||
/* NumPrefetchK = */ 1,
|
||||
/* NumPrefetchV = */ 1>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -52,6 +52,7 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static constexpr index_t kMaxSplits = Traits::kMaxSplits;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -985,6 +985,106 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
WarpGemm>;
|
||||
return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
|
||||
{
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = 256;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kMPerBlock = Problem::kMaxSplits;
|
||||
|
||||
constexpr index_t NumElements = (kMPerBlock * kNPerBlock);
|
||||
|
||||
if constexpr(NumElements < kBlockSize)
|
||||
{
|
||||
static_assert(false);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(sizeof(LSEDataType) == 4);
|
||||
|
||||
constexpr index_t NPerThread = 16 / sizeof(LSEDataType); // 4
|
||||
constexpr index_t NThreads = kNPerBlock / NPerThread; // 32
|
||||
static_assert(NThreads <= kBlockSize);
|
||||
|
||||
constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads; // 2
|
||||
constexpr index_t TotalWarps = kBlockSize / get_warp_size(); // 4
|
||||
constexpr index_t MPerThread = kMPerBlock / (TotalWarps * MThreadsPerWarp); // 2
|
||||
|
||||
static_assert(NThreads * NPerThread == kNPerBlock);
|
||||
static_assert(MPerThread * TotalWarps * MThreadsPerWarp == kMPerBlock);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<MPerThread, TotalWarps, MThreadsPerWarp>,
|
||||
sequence<NThreads, NPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccTDramTileDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = 256;
|
||||
|
||||
constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size());
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
|
||||
constexpr index_t NumElements = (kMPerBlock * kNPerBlock);
|
||||
|
||||
static_assert(kBlockSize < NumElements);
|
||||
if constexpr(NumElements < kBlockSize)
|
||||
{
|
||||
static_assert(false);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t NThreads = get_warp_size(); // 64
|
||||
constexpr index_t NPerThread = kNPerBlock / NThreads; // 1
|
||||
|
||||
constexpr index_t MThreads = kBlockSize / NThreads; // 4
|
||||
constexpr index_t MPerThread = kMPerBlock / MThreads; // 32
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<MThreads, MPerThread>, sequence<NThreads, NPerThread>>,
|
||||
tuple<sequence<1>, sequence<2>>,
|
||||
tuple<sequence<0>, sequence<0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
|
||||
{
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
|
||||
constexpr index_t N1 = 16 / sizeof(OaccDataType);
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t M2 = get_warp_size() / N0;
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -17,7 +17,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kStoreLSE_,
|
||||
bool kHasDropout_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */,
|
||||
index_t kMaxSplits_ = 1>
|
||||
struct TileFmhaTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
@@ -30,16 +31,7 @@ struct TileFmhaTraits
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
index_t kBlockPerCu_ = 2 /* hint to occupancy */>
|
||||
struct TileFmhaFwdSplitKVCombineTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr index_t kMaxSplits = kMaxSplits_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
|
||||
Reference in New Issue
Block a user