Add SplitKV combine kernel codegen logics

This commit is contained in:
PoYen, Chen
2024-06-02 23:32:55 +00:00
parent cacce74f2c
commit 9ac2654b55
11 changed files with 1284 additions and 42 deletions

View File

@@ -334,7 +334,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
return ck_tile::make_tuple(kargs, grids);
}
#if 0
template <typename FmhaFwdSplitKVCombineKernel>
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
{
@@ -343,31 +342,18 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
// create group mode kernel arguments
if constexpr(FmhaFwdSplitKVCombineKernel::kIsGroupMode)
{
return FmhaFwdSplitKVCombineKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
return FmhaFwdSplitKVCombineKernel::MakeKargs(args.lse_acc_ptr,
args.o_acc_ptr,
args.lse_ptr,
args.o_ptr,
args.batch,
args.nhead,
args.max_seqlen_q,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q / args.nhead_k,
args.num_splits,
args.scale_s,
args.scale_p,
args.scale_o,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_lse,
args.nhead_stride_o);
}
@@ -395,7 +381,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
dim3 grids = FmhaFwdSplitKVCombineKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
#endif
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
@@ -451,6 +436,12 @@ float fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
template <typename Traits_>
std::string fmha_fwd_splitkv_get_name_();
template <typename Traits_>
float fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
template <typename Traits_>
std::string fmha_fwd_splitkv_combine_get_name_();
// This is the public API, will be generated by script
struct fmha_fwd_traits
{

View File

@@ -249,6 +249,86 @@ std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
}}
"""
# C++ source template for one generated split-KV combine kernel instance.
#
# The string is rendered with str.format() (see FmhaFwdSplitKVCombineKernel.template),
# so every `{F_*}` is a placeholder and `{{` / `}}` are literal C++ braces.
# The rendered code:
#   * builds the tile shape / traits / pipeline problem type aliases for instance {F_idx},
#   * instantiates ck_tile::FmhaFwdSplitKVCombineKernel with them, and
#   * defines explicit specializations of fmha_fwd_splitkv_combine_oneshot_<> and
#     fmha_fwd_splitkv_combine_get_name_<> for trait_{F_idx}.
# NOTE(review): the trailing `16` in TileFmhaTraits presumably sets the maximum number
# of splits (the combine pipeline reads Problem::kMaxSplits) -- confirm.
FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>;
using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
fmha_block_warps_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps_{F_idx},
fmha_warp_tile_{F_idx},
{F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_bias},
false,
{F_lse},
{F_dropout},
{F_squant},
{F_occupancy},
16>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_mask_{F_idx},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdSplitKVCombinePipeline<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdSplitKVCombineKernel<ck_tile::FmhaFwdSplitKVCombineTilePartitioner<fmha_shape_{F_idx}>,
fmha_pipeline_{F_idx},
fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream>
template<>
float fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template<>
std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}>()
{{
using k_ = fmha_kernel_{F_idx};
return k_::GetName();
}}
"""
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
FMHA_FWD_API="""
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
@@ -288,19 +368,18 @@ FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp"
FMHA_FWD_SPLITKV_API="""
#include <iostream>
template<typename fmha_fwd_splitkv_trait_
// , typename fmha_fwd_splitkv_combine_trait_
>
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_trait_>()
// << ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_trait_>()
std::cout
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
<< std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_trait_>(s_, a); }}
// , [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<dq_dk_dv_trait_>(s_, a); }}
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
);
}}
@@ -313,8 +392,8 @@ float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_splitkv_<trait_>(s, a);
using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits_>(s, a);
}}
"""
@@ -713,6 +792,85 @@ class FmhaFwdSplitKVKernel:
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad)
@dataclass
class FmhaFwdSplitKVCombineKernel:
    """Describes one split-KV combine kernel instance to be generated.

    `template` renders the C++ source of the instance, `name` / `filename`
    give the kernel name and the output file name, and `api_trait()`
    repackages the same parameters as an `FmhaFwdApiTrait`.
    """
    direction   : str
    F_idx       : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim      : int  # hdim
    F_dtype     : str  # data type
    F_mode      : str  # value from MODE_MAP
    F_tile      : FmhaFwdTileSize
    F_pipeline  : FmhaFwdPipeline
    mask_impl   : str

    @property
    def template(self) -> str:
        """C++ source text: the common kernel header plus the formatted combine-kernel body."""
        tile = self.F_tile
        pipeline = self.F_pipeline
        return FMHA_FWD_KERNEL_HEADER + \
            FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(
                F_idx           = self.F_idx,
                F_hdim          = self.F_hdim,
                F_dtype         = DTYPE_MAP[self.F_dtype],
                F_bm0           = tile.F_bm0,
                F_bn0           = tile.F_bn0,
                F_bk0           = tile.F_bk0,
                F_bn1           = tile.F_bn1,
                F_bk1           = tile.F_bk1,
                F_bk0blen       = tile.F_bk0blen,
                F_rm            = tile.F_rm,
                F_rn            = tile.F_rn,
                F_rk            = tile.F_rk,
                F_wm            = tile.F_wm,
                F_wn            = tile.F_wn,
                F_wk            = tile.F_wk,
                F_vlayout       = LAYOUT_MAP[pipeline.F_vlayout],
                F_spad          = BOOL_MAP[pipeline.F_spad],
                F_skpad         = BOOL_MAP[pipeline.F_skpad],
                F_dpad          = BOOL_MAP[pipeline.F_dpad],
                F_dvpad         = BOOL_MAP[pipeline.F_dvpad],
                F_bias          = BIAS_MAP[pipeline.F_bias],
                F_lse           = BOOL_MAP[pipeline.F_lse],
                F_dropout       = BOOL_MAP[pipeline.F_dropout],
                F_squant        = BOOL_MAP[pipeline.F_squant],
                F_occupancy     = tile.F_occupancy,
                F_pipeline_enum = PIPELINE_ENUM_MAP[pipeline.tag],
                F_mask          = get_mask_map(self.mask_impl)[pipeline.F_mask],
                F_mode          = MODE_MAP[self.F_mode],
                # The body template has no {F_pipeline} placeholder, so this value is
                # ignored by format(); the lookup is kept so an unknown pipeline tag
                # still raises KeyError exactly as before.
                F_pipeline      = FMHA_FWD_SPLITKV_PIPELINE_MAP[pipeline.tag])

    @property
    def name(self) -> str:
        """Kernel name built from direction, hdim, dtype, mode, tile name and pipeline name."""
        # TODO: we don't encode idx here
        return f"fmha_{self.direction}_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
            self.F_tile.name + '_' + self.F_pipeline.name

    @property
    def filename(self) -> str:
        """Name of the generated source file for this kernel."""
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        """Return the `FmhaFwdApiTrait` carrying this kernel's tile and pipeline parameters."""
        tile = self.F_tile
        pipeline = self.F_pipeline
        return FmhaFwdApiTrait(
            pipeline_tag=pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=tile.F_bm0,
            bn0=tile.F_bn0,
            bk0=tile.F_bk0,
            bn1=tile.F_bn1,
            bk1=tile.F_bk1,
            bk0blen=tile.F_bk0blen,
            vlayout=pipeline.F_vlayout,
            mask=pipeline.F_mask,
            bias=pipeline.F_bias,
            lse=pipeline.F_lse,
            dropout=pipeline.F_dropout,
            squant=pipeline.F_squant,
            spad=pipeline.F_spad,
            skpad=pipeline.F_skpad,
            dpad=pipeline.F_dpad,
            dvpad=pipeline.F_dvpad)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]:
@@ -812,6 +970,74 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, is_splitkv
return (api_pool, gen)
def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> List[FmhaFwdSplitKVCombineKernel]:
    """Enumerate the split-KV combine kernels to generate.

    For every dtype in DTYPE_MAP that has a forward tile dictionary, every hdim
    in that dictionary and every mode in MODE_MAP, one kernel is created per
    pipeline returned by `get_pipelines`. Kernels whose name does not match
    `kernel_filter` (an fnmatch pattern) are dropped; `None` disables filtering.

    kernel_filter: optional fnmatch pattern applied to the kernel name.
    receipt:       when equal to 1, extra 'qr' pipelines are added for hdim != 256.
    mask_impl:     forwarded to get_mask_map() and stored in each kernel.
    Returns the list of kernels in generation order.
    """
    Pipeline = FmhaFwdSplitKVPipeline
    Kernel = FmhaFwdSplitKVCombineKernel

    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    #       support this in future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
        # this function will populate a list of possible pipelines
        # TODO: the order of List matters! the later in this list will be also be checked later
        # TODO: currently for qr pipeline, let 't' padding to appear later!!
        # TODO: how to design this more generic?
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
                if hdim == 256:
                    pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
                    pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))

                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                else:
                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    if receipt == 1:
                        pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitrary hdim
                        pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitrary hdim
        elif dtype in ['fp8', 'bf8']:
            # no need lse/dropout kernels
            for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
                pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
        else:
            assert False, f"unsupported dtype: {dtype}"
        return pipelines

    gen = list()

    for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()):
        d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype)
        if d is None:
            continue
        for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
                # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                if mode == "group" and (pipeline.F_spad != 't' or pipeline.F_skpad != 't'):
                    continue
                k = Kernel(direction=direction,
                           F_idx=0,
                           F_hdim=hdim,
                           F_dtype=dtype,
                           F_mode=mode,
                           F_tile=tile,
                           F_pipeline=pipeline,
                           mask_impl=mask_impl)
                if kernel_filter is not None and not fnmatch.fnmatch(k.name, kernel_filter):
                    continue
                gen.append(k)

    return gen
BWD_DQDKDV_PIPELINE_MAP = {
"ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR",
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS",
@@ -1386,7 +1612,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
return gen
def write_single_fwd_kernel(kernel: Union[FmhaFwdKernel, FmhaFwdSplitKVKernel], autogen_dir: Path) -> None:
def write_single_fwd_kernel(kernel: Union[FmhaFwdKernel, FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
@@ -1416,6 +1642,9 @@ def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optio
write_fwd_api(api_pool, output_dir)
# write split-kv blobs
kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, True)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
@@ -1441,6 +1670,9 @@ def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Opt
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
# get split-kv blobs
kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, True)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")

View File

@@ -10,6 +10,8 @@
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
@@ -26,7 +28,6 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"

View File

@@ -0,0 +1,439 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdSplitKVCombineKernel
{
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
// clang-format off
// Compile-time map from an element type to the short dtype tag used in kernel names
// (GetName() reads t2s<QDataType>::name). Only the types specialized below are supported;
// the primary template is declared but not defined.
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
// Returns the kernel's name string. The pieces, in order:
//   "fmha_fwd_splitkv_combine_d<kK0BlockLength>", the dtype tag of QDataType,
//   "group"/"batch", the block tile ("b..."), Gemm0 block warps ("r..."),
//   Gemm0 warp tile ("w..."), the occupancy ("o<kBlockPerCu>_", only when the problem's
//   kBlockPerCu is not -1), the V layout ("vr" for row-major, otherwise "vc"),
//   the padding tag, and optional bias / mask / "_dropout" / "_squant" suffixes.
__host__ static std::string GetName()
{
// sync with generate.py
// clang-format off
using bfs = typename FmhaPipeline::BlockFmhaShape;
using gbr = typename bfs::Gemm0BlockWarps;
using gwt = typename bfs::Gemm0WarpTile;
// local shorthands, undefined again before returning from this scope
#define _SS_ std::string
#define _TS_ std::to_string
// padding tag: "p" followed by "s" (seqlen_q padded) and/or "dv" (hdim_v padded);
// empty when neither is padded
auto pn = [&] () {
std::string n;
if (kPadSeqLenQ) n += "s";
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_splitkv_combine_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
// note: the trailing "_" below and the leading "b" on the next line are adjacent
// string literals, so they are concatenated into "_b"
"_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
// clang-format on
}
// Empty base class used in place of an optional kargs group that is disabled.
// The index I makes each instantiation a distinct type, which avoids the
// duplicated-base-class problem when several optional groups are disabled at once.
template <ck_tile::index_t I>
struct EmptyKargs
{
};
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.

// Arguments needed in both batch and group mode.
struct CommonKargs
{
// per-split partial results read by the kernel
const void* lse_acc_ptr;
const void* o_acc_ptr;
// final output written by the kernel
void* o_ptr;
ck_tile::index_t batch;
ck_tile::index_t nhead;
ck_tile::index_t max_seqlen_q;
// in group mode MakeKargs() stores -1 here and operator() overwrites it per batch
ck_tile::index_t seqlen_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_splits;
// strides are applied to typed (ODataType) pointers, i.e. counted in elements
ck_tile::index_t row_stride_o;
ck_tile::index_t nhead_stride_o;
};
// LSE output arguments; only part of the kargs when kStoreLSE is true.
struct CommonLSEKargs
{
void* lse_ptr = nullptr;
ck_tile::index_t nhead_stride_lse = 0;
};
// Output scale; only part of the kargs when kDoFp8StaticQuant is true.
struct Fp8StaticQuantKargs
{
float scale_o;
};
// Batch mode additionally needs the per-batch stride of the LSE output.
struct BatchModeLSEKargs : CommonLSEKargs
{
ck_tile::index_t batch_stride_lse = 0;
};
// Kernel arguments for batch mode; disabled optional groups collapse to EmptyKargs<>.
struct BatchModeKargs
: CommonKargs,
std::conditional_t<kStoreLSE, BatchModeLSEKargs, EmptyKargs<0>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{
ck_tile::index_t batch_stride_o;
};
// Kernel arguments for group mode; per-batch query ranges come from seqstart_q_ptr.
struct GroupModeKargs
: CommonKargs,
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<3>>
{
// operator() reads entries [i_batch] and [i_batch + 1] of this array
const int32_t* seqstart_q_ptr;
};
// The argument struct actually taken by operator(), selected by the pipeline's mode.
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
// Builds the kernel arguments for batch mode (overload enabled when !kIsGroupMode).
// lse_ptr / nhead_stride_lse / batch_stride_lse are stored only when kStoreLSE is set and
// scale_o only when kDoFp8StaticQuant is set; otherwise those parameters are ignored.
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* lse_acc_ptr,
          const void* o_acc_ptr,
          void* lse_ptr,
          void* o_ptr,
          ck_tile::index_t batch,
          ck_tile::index_t nhead,
          ck_tile::index_t max_seqlen_q,
          ck_tile::index_t seqlen_q,
          ck_tile::index_t hdim_v,
          ck_tile::index_t num_splits,
          float scale_o,
          ck_tile::index_t row_stride_o,
          ck_tile::index_t nhead_stride_lse,
          ck_tile::index_t nhead_stride_o,
          ck_tile::index_t batch_stride_lse,
          ck_tile::index_t batch_stride_o)
{
    // start from a value-initialized aggregate (optional groups keep their defaults)
    // and fill it member by member
    Kargs kargs{};

    // members shared with group mode
    kargs.lse_acc_ptr    = lse_acc_ptr;
    kargs.o_acc_ptr      = o_acc_ptr;
    kargs.o_ptr          = o_ptr;
    kargs.batch          = batch;
    kargs.nhead          = nhead;
    kargs.max_seqlen_q   = max_seqlen_q;
    kargs.seqlen_q       = seqlen_q;
    kargs.hdim_v         = hdim_v;
    kargs.num_splits     = num_splits;
    kargs.row_stride_o   = row_stride_o;
    kargs.nhead_stride_o = nhead_stride_o;

    // batch-mode only member
    kargs.batch_stride_o = batch_stride_o;

    // optional groups exist only when the matching feature is compiled in
    if constexpr(kDoFp8StaticQuant)
    {
        kargs.scale_o = scale_o;
    }
    if constexpr(kStoreLSE)
    {
        kargs.lse_ptr          = lse_ptr;
        kargs.nhead_stride_lse = nhead_stride_lse;
        kargs.batch_stride_lse = batch_stride_lse;
    }

    return kargs;
}
// Builds the kernel arguments for group mode (overload enabled when kIsGroupMode).
// lse_ptr / nhead_stride_lse are stored only when kStoreLSE is set and scale_o only when
// kDoFp8StaticQuant is set; otherwise those parameters are ignored.
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* lse_acc_ptr,
          const void* o_acc_ptr,
          void* lse_ptr,
          void* o_ptr,
          ck_tile::index_t batch,
          ck_tile::index_t nhead,
          ck_tile::index_t max_seqlen_q,
          const void* seqstart_q_ptr,
          ck_tile::index_t hdim_v,
          ck_tile::index_t num_splits,
          float scale_o,
          ck_tile::index_t row_stride_o,
          ck_tile::index_t nhead_stride_lse,
          ck_tile::index_t nhead_stride_o)
{
    // start from a value-initialized aggregate (optional groups keep their defaults)
    // and fill it member by member
    Kargs kargs{};

    // members shared with batch mode
    kargs.lse_acc_ptr    = lse_acc_ptr;
    kargs.o_acc_ptr      = o_acc_ptr;
    kargs.o_ptr          = o_ptr;
    kargs.batch          = batch;
    kargs.nhead          = nhead;
    kargs.max_seqlen_q   = max_seqlen_q;
    kargs.seqlen_q       = -1; // placeholder: the kernel derives it from seqstart_q_ptr
    kargs.hdim_v         = hdim_v;
    kargs.num_splits     = num_splits;
    kargs.row_stride_o   = row_stride_o;
    kargs.nhead_stride_o = nhead_stride_o;

    // group-mode only member
    kargs.seqstart_q_ptr = reinterpret_cast<const int32_t*>(seqstart_q_ptr);

    // optional groups exist only when the matching feature is compiled in
    if constexpr(kDoFp8StaticQuant)
    {
        kargs.scale_o = scale_o;
    }
    if constexpr(kStoreLSE)
    {
        kargs.lse_ptr          = lse_ptr;
        kargs.nhead_stride_lse = nhead_stride_lse;
    }

    return kargs;
}
// Launch grid for the given problem size; simply forwards to TilePartitioner::GridSize().
__host__ static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
}
// Thread-block dimensions: kBlockSize threads along x.
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
// Size in bytes of the shared-memory buffer declared in operator(): the larger of what
// the pipeline and the epilogue request (they share one char array).
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
// Kernel entry point. Each workgroup handles the tile picked by the tile partitioner
// (i_tile_m, i_nhead, i_batch): it builds windows over the per-split LSE/O accumulation
// buffers, runs FmhaPipeline to combine them (optionally writing LSE), and stores the
// resulting [kM0 x kN1] output tile through EpiloguePipeline.
// kargs is taken by value because group mode overwrites kargs.seqlen_q below.
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
// first query row handled by this workgroup
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_o_acc = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_lse_acc = query_start;
batch_offset_o_acc = query_start * kargs.hdim_v;
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
}
batch_offset_o = query_start * kargs.row_stride_o;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if(kargs.seqlen_q <= i_m0)
{
return;
}
}
else
{
// batch mode: the accumulation buffers are indexed as densely packed
// [batch, nhead, max_seqlen_q(, hdim_v)] within each split, while the final
// LSE/O outputs use the caller-provided batch strides
batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
batch_offset_o_acc = static_cast<long_index_t>(i_batch) *
(kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v);
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
}
// for simplicity, batch stride we just modify the pointer
// NOTE(review): the per-head offset of the accumulation buffers uses max_seqlen_q in
// both modes; in group mode this assumes the host allocates max_seqlen_q rows per head
// -- confirm against the host-side layout.
const LSEDataType* lse_acc_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_acc_ptr) +
static_cast<long_index_t>(i_nhead) * (kargs.max_seqlen_q) +
batch_offset_lse_acc;
const OaccDataType* o_acc_ptr =
reinterpret_cast<const OaccDataType*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(i_nhead) * (kargs.max_seqlen_q * kargs.hdim_v) +
batch_offset_o_acc;
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;
// LSEacc/Oacc DRAM and DRAM windows
// LSEacc view: (num_splits, seqlen_q); consecutive splits are
// batch * nhead * max_seqlen_q elements apart
const auto lse_acc_dram = [&]() {
const auto lse_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
lse_acc_ptr,
make_tuple(kargs.num_splits, kargs.seqlen_q),
make_tuple(kargs.batch * kargs.nhead * kargs.max_seqlen_q, 1),
number<8>{},
number<1>{});
return pad_tensor_view(
lse_acc_dram_naive,
make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{}),
sequence<true, kPadSeqLenQ>{});
}();
// Oacc view: (num_splits, max_seqlen_q, hdim_v), padded and then merged into a 2-D
// (num_splits * padded seqlen_q, padded hdim_v) view
// NOTE(review): the naive view's row count is max_seqlen_q while the merge below uses
// the padded seqlen_q -- confirm this difference is intended for group mode.
auto o_acc_dram = [&]() {
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v),
make_tuple(
kargs.batch * kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v, kargs.hdim_v, 1),
number<FmhaPipeline::kAlignmentO>{},
number<1>{});
auto o_acc_dram_view = pad_tensor_view(
o_acc_dram_naive,
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
// seqlen_q / hdim_v rounded up to multiples of the tile sizes kM0 / kN1
const index_t new_seqlen_q =
integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) * FmhaPipeline::kM0;
const index_t new_hdim_v =
integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1) * FmhaPipeline::kN1;
return transform_tensor_view(
o_acc_dram_view,
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, new_seqlen_q)),
make_pass_through_transform(new_hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}();
// windows start at this workgroup's first query row (i_m0)
auto lse_acc_dram_window = make_tile_window(
lse_acc_dram,
[&]() {
return make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{});
}(),
{0, i_m0});
auto o_acc_dram_window = make_tile_window(
o_acc_dram,
[&]() {
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{});
}(),
{i_m0, 0});
// LSE DRAM window
// (a null window is used when kStoreLSE is false, so no LSE buffer is touched)
auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
if constexpr(kStoreLSE)
{
LSEDataType* lse_ptr =
reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
const auto lse_dram = [&]() {
const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
lse_ptr,
make_tuple(kargs.seqlen_q),
make_tuple(1),
number<1>{},
number<1>{});
return pad_tensor_view(
lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
}();
return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
}
else
{
return make_null_tile_window(lse_dram_window_lengths);
}
}();
// run the combine pipeline; with fp8 static quantization the output elements are
// scaled by scale_o and saturated to fp8_t via the o_acc element function
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
return FmhaPipeline{}(
lse_acc_dram_window,
o_acc_dram_window,
lse_dram_window,
identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
smem_ptr,
kargs.num_splits,
kargs.max_seqlen_q);
}
else
{
return FmhaPipeline{}(lse_acc_dram_window,
o_acc_dram_window,
lse_dram_window,
smem_ptr,
kargs.num_splits,
kargs.max_seqlen_q);
}
}();
// O DRAM and DRAM window
auto o_dram = [&]() {
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.row_stride_o, 1),
number<FmhaPipeline::kAlignmentO>{},
number<1>{});
return pad_tensor_view(
o_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
}();
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, 0});
// write the combined tile to the output tensor
EpiloguePipeline{}(o_dram_window, o_acc_tile);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Maps the launch grid onto the split-KV combine problem:
//   grid.x -> tile index along seqlen_q (kM0 rows per tile)
//   grid.y -> head index
//   grid.z -> batch index
template <typename BlockFmhaShape_>
struct FmhaFwdSplitKVCombineTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
    // Previously initialised from BlockFmhaShape::kN0, which did not match the
    // constant's name; it is not used by GridSize() or operator() below.
    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
    // constexpr static ck_tile::index_t kBlockM = kN1 % 128 == 0 ? 4 : (kN1 % 64 == 0 ? 8 : 16);

    // Number of workgroups to launch: ceil(max_seqlen_q / kM0) tiles per (head, batch).
    __host__ static constexpr auto
    GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t max_seqlen_q)
    {
        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0), nhead, batch_size);
    }

    // Returns (i_tile_m, i_nhead, i_batch) for the calling workgroup, read straight from
    // blockIdx. Both parameters are currently unused. The call operator is const because
    // it does not touch any state.
    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/,
                                   ck_tile::index_t /*hdim_v*/) const
    {
        const index_t i_tile_m = blockIdx.x;
        const index_t i_nhead  = blockIdx.y;
        const index_t i_batch  = blockIdx.z;

        return ck_tile::make_tuple(i_tile_m, i_nhead, i_batch);
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,433 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
struct BlockFmhaFwdSplitKVCombinePipeline
{
// Problem / policy types with cv-ref qualifiers stripped.
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
// Element types forwarded from the problem description.
// LSEDataType is the element type staged in shared memory (see GetSmemSize) and
// OaccDataType is the element type of the accumulated output tile in operator().
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
// FmhaMask::IsMasking is consulted when validating the row max in operator().
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
// Tile-shape constants re-exported from BlockFmhaShape / Problem.
// NOTE(review): within this struct only kBlockSize, kM0 and kK0BlockLength are referenced;
// the remaining shape constants look like they are kept for interface parity with the other
// FMHA pipelines - confirm before removing any of them.
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
// Feature switches forwarded from the problem description.
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
// BiasEnum and FmhaMask decide whether a row max of -inf must be replaced in operator().
static constexpr auto BiasEnum = Problem::BiasEnum;
// When true, operator() stores the combined log-sum-exp through its LSE window argument.
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// Compile-time upper bound on the number of splits; it is the row stride of the LSE buffer
// in shared memory and sizes that buffer (see GetSmemSize).
static constexpr index_t kMaxSplits = Problem::kMaxSplits;
// Alignment for O accesses: forced to 1 when the V head dimension is padded, otherwise
// whatever the policy reports.
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
/// Blocks-per-CU occupancy hint for this pipeline.
///
/// If the problem description supplies an explicit value (anything other than -1) it is used
/// verbatim. Otherwise a default is picked from kK0BlockLength:
///   <= 32 -> 2, <= 64 -> 3, <= 128 -> 2, <= 256 -> 1.
/// Lengths above 256 have no default; they now fail with an explicit static_assert instead of
/// the lambda silently deducing `void` and producing an unrelated conversion error.
static constexpr index_t kBlockPerCu = []() {
    if constexpr(Problem::kBlockPerCu != -1)
    {
        return Problem::kBlockPerCu;
    }
    else if constexpr(kK0BlockLength <= 32)
    {
        return 2;
    }
    else if constexpr(kK0BlockLength <= 64)
    {
        return 3;
    }
    else if constexpr(kK0BlockLength <= 128)
    {
        return 2;
    }
    else if constexpr(kK0BlockLength <= 256)
    {
        return 1;
    }
    else
    {
        // The condition is value-dependent, so it is only evaluated when this branch is
        // actually instantiated, i.e. when no default exists for the given length.
        static_assert(kK0BlockLength <= 256,
                      "no default kBlockPerCu for kK0BlockLength > 256; "
                      "set Problem::kBlockPerCu explicitly");
        // Never used (the assert above fails first); keeps the deduced return type
        // non-void so the assert is the only diagnostic.
        return 1;
    }
}();
/// Bytes of shared memory needed by operator(): one LSEDataType value for every
/// (query row, split) pair of the tile, i.e. kM0 * kMaxSplits elements.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    /// TODO: add padding to avoid bank conflict
    constexpr auto lse_element_count = kM0 * kMaxSplits;
    return lse_element_count * sizeof(LSEDataType);
}
// Debug aid: emits `; [POYENC] <msg>` as a comment line in the generated assembly so a code
// region can be located in the ISA dump. `msg` must be a string literal (it is concatenated
// with the prefix). The marker is fenced on both sides by scheduling barriers.
// Wrapped in do { } while(0) so that `MARKER("x");` is a single statement and stays correct
// under an unbraced if/else.
#define MARKER(msg)                         \
    do                                      \
    {                                       \
        __builtin_amdgcn_sched_barrier(0);  \
        asm volatile("; [POYENC] " msg ::); \
        __builtin_amdgcn_sched_barrier(0);  \
    } while(0)
/// Combines the per-split partial results of split-KV attention for one tile.
///
/// Steps, in the order implemented below:
///  1. Load the [kMaxSplits, kM0] LSE-accumulator tile and stage it in shared memory so that
///     all splits of one query row are contiguous (row stride kMaxSplits).
///  2. Per query row, reduce over the split axis: max, then sum of exp(lse - max).
///  3. lse_logsum = log(sum) + max, or +infinity when the sum is 0 or NaN.
///  4. Overwrite the staged LSE values in place with the scales exp(lse - lse_logsum).
///  5. If kStoreLSE, store lse_logsum (after lse_element_func) through lse_dram_window_tmp.
///  6. Accumulate sum over i_split < num_splits of scale[i_split] * o_tile[i_split].
///
/// @param lse_acc_dram_block_window_tmp  window whose view/lengths/origin are reused to load
///                                       the per-split LSE values
/// @param o_acc_dram_block_window_tmp    window whose view/lengths/origin are reused to load
///                                       the per-split partial outputs
/// @param lse_dram_window_tmp            destination window for the combined LSE (kStoreLSE)
/// @param lse_element_func               applied element-wise to lse_logsum before the store
/// @param o_acc_element_func             applied element-wise to the accumulated output
/// @param smem_ptr                       shared-memory scratch, indexed as kM0 rows of
///                                       kMaxSplits LSEDataType elements (see GetSmemSize)
/// @param num_splits                     number of splits actually combined; presumably must
///                                       not exceed kMaxSplits - TODO confirm at call site
/// @param max_seqlen_q                   step, along dimension 0 of the o_acc window, between
///                                       the tiles of consecutive splits
/// @return the accumulated output as a distributed tensor of OaccDataType
template <typename LSEaccDramBlockWindowTmp,
typename OaccDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename LSEElementFunction,
typename OaccElementFunction>
CK_TILE_HOST_DEVICE auto
operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp,
const OaccDramBlockWindowTmp& o_acc_dram_block_window_tmp,
LSEDramBlockWindowTmp& lse_dram_window_tmp,
const LSEElementFunction& lse_element_func,
const OaccElementFunction& o_acc_element_func,
void* smem_ptr,
index_t num_splits,
index_t max_seqlen_q) const
{
// The scratch buffer is used, from its first byte, as an array of LSEDataType.
LSEDataType* lse_acc_lds_ptr =
static_cast<LSEDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
// Re-create the caller's LSE window with the policy's load distribution attached.
auto lse_acc_dist = Policy::template MakeLSEaccDramTileDistribution<Problem>();
auto lse_acc_dram_window =
make_tile_window(lse_acc_dram_block_window_tmp.get_bottom_tensor_view(),
lse_acc_dram_block_window_tmp.get_window_lengths(),
lse_acc_dram_block_window_tmp.get_window_origin(),
lse_acc_dist);
auto lse_acc = load_tile(lse_acc_dram_window); // [kMaxSplits, kM0]
// DEBUG_STMTS prefixes the debug printf blocks below. With ENABLE_DEBUG_STMTS defined it
// selects thread TID of block (0, 0, 0); TID is not defined in the visible code and has to be
// supplied by the build. Without it the guarded blocks are compiled as `if(false)`.
#if defined(ENABLE_DEBUG_STMTS)
#define DEBUG_STMTS if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID)
#else
#define DEBUG_STMTS if(false)
#endif
// copy lse_acc to LDS
// Here `row` is the split index and `col` the query-row index of the [kMaxSplits, kM0] tile;
// the element is written to [col][row], i.e. transposed, with kMaxSplits elements per query row.
{
using DataType = LSEDataType;
using StaticTileDistribution = decltype(lse_acc_dist);
constexpr auto out_spans =
static_distributed_tensor<DataType,
StaticTileDistribution>::get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
constexpr auto distributed_indices = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
StaticTileDistribution{}, distributed_indices);
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices);
});
});
}
// The staged values are read back below with a different distribution, so wait for the
// LDS writes first.
block_sync_lds();
auto lse_accum_dist = Policy::template MakeLSEaccTDramTileDistribution<Problem>();
auto lse_accum = make_static_distributed_tensor<LSEDataType>(lse_accum_dist);
// copy LDS to lse_accum
// From here on `row` is the query row and `col` the split index. Columns at or beyond
// num_splits are filled with -infinity so they cannot win the row-max reduction.
{
using DataType = LSEDataType;
using StaticTileDistribution = decltype(lse_accum_dist);
constexpr auto out_spans =
static_distributed_tensor<DataType,
StaticTileDistribution>::get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
constexpr auto distributed_indices = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
StaticTileDistribution{}, distributed_indices);
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
if(col < num_splits)
{
lse_accum(distributed_indices) = lse_acc_lds_ptr[col + row * kMaxSplits];
}
else
{
lse_accum(distributed_indices) = -numeric<LSEDataType>::infinity();
}
DEBUG_STMTS
{
printf("[POYENC][DEVICE] lse_accum[%2d,%2d]: %11.7f\n",
row,
col,
lse_accum(distributed_indices));
}
});
});
}
// calculate row_max of lse_accum
const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// Reduce along dimension 1 (the split axis), then combine the partial results across threads.
auto lse_max = block_tile_reduce<LSEDataType>(
lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity());
block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});
#if defined(PRINT_LSE_MAX)
DEBUG_STMTS
{
constexpr auto out_spans =
static_distributed_tensor<LSEDataType, decltype(lse_max.get_tile_distribution())>::
get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
constexpr auto distributed_indices = make_tuple(idx0);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_max.get_tile_distribution(), distributed_indices);
const auto row = x_indices.at(number<0>{});
printf(
"[POYENC][DEVICE] lse_max[%2d]: %11.7f\n", row, lse_max(distributed_indices));
});
}
#endif
// With an element-wise bias or a mask, a row max of -infinity is replaced by 0 so that the
// subtraction `lse - max` below never becomes (-inf) - (-inf), which would be NaN.
static const auto get_validated_m = [](LSEDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
// p_compute(row, split) = exp(lse(row, split) - max(row)). It is cleared first, so entries
// with split >= num_splits stay 0 and do not contribute to the row sum.
auto p_compute = make_static_distributed_tensor<LSEDataType>(
lse_accum.get_tile_distribution()); // Pcompute{j}
clear_tile(p_compute);
{
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
p_compute.get_tile_distribution(), i_j_idx);
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
#if 0
// from dist tensor
p_compute(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx)));
#else
if (col < num_splits) {
// from shared memory
p_compute(i_j_idx) = ck_tile::exp(lse_acc_lds_ptr[col + row * kMaxSplits] -
get_validated_m(lse_max(i_idx)));
}
#endif
#if 0
DEBUG_STMTS
{
printf("[POYENC][DEVICE] p_compute[%2d,%2d]: %11.7f\n",
row,
col,
p_compute(i_j_idx));
}
#endif
});
});
}
__syncthreads();
// Row sum of p_compute over the split axis.
auto lse_sum = block_tile_reduce<LSEDataType>(
p_compute, sequence<1>{}, f_sum, type_convert<LSEDataType>(0));
block_tile_reduce_sync(lse_sum, f_sum, bool_constant<false>{});
#if defined(PRINT_LSE_SUM)
DEBUG_STMTS
{
constexpr auto out_spans =
static_distributed_tensor<LSEDataType, decltype(lse_sum.get_tile_distribution())>::
get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
constexpr auto distributed_indices = make_tuple(idx0);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_sum.get_tile_distribution(), distributed_indices);
const auto row = x_indices.at(number<0>{});
printf(
"[POYENC][DEVICE] lse_sum[%2d]: %11.7f\n", row, lse_sum(distributed_indices));
});
}
#endif
// Combined log-sum-exp per query row: log(sum) + max. A sum that is 0 or NaN (`x != x` is
// the NaN test) yields +infinity, which turns every scale exp(lse - lse_logsum) below into 0
// for finite lse.
decltype(lse_max) lse_logsum;
{
constexpr auto out_spans = static_distributed_tensor<
LSEDataType,
decltype(lse_logsum.get_tile_distribution())>::get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
constexpr auto distributed_indices = make_tuple(idx0);
if(lse_sum(distributed_indices) == 0.f ||
lse_sum(distributed_indices) != lse_sum(distributed_indices))
{
lse_logsum(distributed_indices) = numeric<LSEDataType>::infinity();
}
else
{
lse_logsum(distributed_indices) =
ck_tile::log(lse_sum(distributed_indices)) + lse_max(distributed_indices);
}
#if defined(PRINT_LSE_LOGSUM)
DEBUG_STMTS
{
const auto x_indices = get_x_indices_from_distributed_indices(
lse_logsum.get_tile_distribution(), distributed_indices);
const auto row = x_indices.at(number<0>{});
printf("[POYENC][DEVICE] lse_logsum[%d]: %11.7f\n",
row,
lse_logsum(distributed_indices));
}
#endif
});
}
#if defined(PRINT_LSE_ACCUM)
DEBUG_STMTS
{
for(index_t row = 0; row < kM0; ++row)
{
printf("[POYENC][DEVICE] lse_accum[%d] = ", row);
for(index_t col = 0; col < num_splits; ++col)
{
printf("%11.7f", lse_acc_lds_ptr[col + row * kMaxSplits]);
}
printf("\n");
}
}
#endif
// write lse scales into LDS
// For every row this thread holds in lse_sum's distribution, the first num_splits staged LSE
// values are replaced in place by exp(lse - lse_logsum).
{
constexpr auto out_spans =
static_distributed_tensor<LSEDataType, decltype(lse_sum.get_tile_distribution())>::
get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
constexpr auto distributed_indices = make_tuple(idx0);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_sum.get_tile_distribution(), distributed_indices);
const auto row = x_indices.at(number<0>{});
for(index_t col = 0; col < num_splits; ++col)
{
lse_acc_lds_ptr[col + row * kMaxSplits] = ck_tile::exp(
lse_acc_lds_ptr[col + row * kMaxSplits] - lse_logsum(distributed_indices));
}
});
}
// The scales are read below through the O distribution, so wait for the LDS writes.
block_sync_lds();
#if defined(PRINT_LSE_SCALE)
DEBUG_STMTS
{
// NOTE(review): this dump prints a fixed 32 rows, whereas the PRINT_LSE_ACCUM dump above
// iterates up to kM0.
for(index_t row = 0; row < 32; ++row)
{
printf("[POYENC][DEVICE] lse_scale[%2d] = ", row);
for(index_t col = 0; col < num_splits; ++col)
{
printf("%11.7f", lse_acc_lds_ptr[col + row * kMaxSplits]);
}
printf("\n");
}
}
#endif
if constexpr(kStoreLSE)
{
// The default policy's LSE tile distributions hard-code a block size of 256.
static_assert(kBlockSize == 256);
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
}
// Re-create the caller's O window with the policy's load distribution attached.
auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
auto o_acc_dram_window =
make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
o_acc_dram_block_window_tmp.get_window_lengths(),
o_acc_dram_block_window_tmp.get_window_origin(),
o_acc_dist);
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist); // running weighted sum of the per-split O tiles
clear_tile(o_acc);
// Accumulate scale(row, i_split) * o_tile(row, :) over the splits that are actually present.
for(index_t i_split = 0; i_split < num_splits; ++i_split)
{
auto o_tile = load_tile(o_acc_dram_window);
{
using DataType = OaccDataType;
constexpr auto out_spans =
static_distributed_tensor<DataType,
decltype(o_acc_dist)>::get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
constexpr auto distributed_indices = make_tuple(idx0, idx1);
const auto x_indices =
get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices);
const auto row = x_indices.at(number<0>{});
LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits];
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
});
});
}
// Step to the same tile of the next split: max_seqlen_q rows further along dimension 0.
// NOTE(review): presumably the o_acc view stacks the splits along dimension 0 - confirm
// against the tensor view built by the kernel.
move_tile_window(o_acc_dram_window, {max_seqlen_q, 0});
}
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
/// Convenience overload: combines the split results without any element-wise
/// post-processing, i.e. forwards to the full overload with `identity` for both the LSE and
/// the O-accumulator element functions.
template <typename LSEaccDramBlockWindow,
          typename OaccDramBlockWindow,
          typename LSEDramBlockWindow>
CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window,
                                    const OaccDramBlockWindow& o_acc_dram_block_window,
                                    LSEDramBlockWindow& lse_dram_block_window,
                                    void* smem_ptr,
                                    index_t num_splits,
                                    index_t max_seqlen_q) const
{
    // A single pass-through functor serves both element-function slots.
    const identity pass_through{};

    return (*this)(lse_acc_dram_block_window,
                   o_acc_dram_block_window,
                   lse_dram_block_window,
                   pass_through,
                   pass_through,
                   smem_ptr,
                   num_splits,
                   max_seqlen_q);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,18 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace ck_tile {
// Default policy of BlockFmhaFwdSplitKVCombinePipeline. It is an instantiation of the generic
// QX/KS/VS custom policy with the template arguments annotated below.
// NOTE(review): the combine pipeline appears to use only the tile-distribution and alignment
// helpers of this policy, so the K/V copy and prefetch settings look like placeholders -
// confirm before tuning them.
using BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>;
} // namespace ck_tile

View File

@@ -52,6 +52,7 @@ struct BlockFmhaPipelineProblem
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr index_t kMaxSplits = Traits::kMaxSplits;
};
} // namespace ck_tile

View File

@@ -985,6 +985,106 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
WarpGemm>;
return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{};
}
/// Tile distribution for the LSE-accumulator tile of shape [kMaxSplits, kM0]
/// (dimension 0 = split, dimension 1 = query row).
///
/// Each thread covers 16 bytes (NPerThread elements) along dimension 1; the threads that are
/// left in a warp, and then the warps of the block, are spread over dimension 0.
/// The numeric trailing comments are the values for one example configuration.
///
/// NOTE(review): the block size is fixed at 256 here, whereas MakeOaccDramTileDistribution
/// reads Problem::kBlockSize - confirm that the combine kernel is always launched with 256
/// threads.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
{
    using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;

    constexpr index_t kBlockSize  = 256;
    constexpr index_t kNPerBlock  = Problem::BlockFmhaShape::kM0;
    constexpr index_t kMPerBlock  = Problem::kMaxSplits;
    constexpr index_t NumElements = (kMPerBlock * kNPerBlock);

    // Value-dependent condition instead of `static_assert(false)` inside `if constexpr`:
    // the latter is only well-formed in a template since CWG2518 and is rejected outright
    // by older compilers.
    static_assert(kBlockSize <= NumElements,
                  "LSE-accumulator tile must hold at least one element per thread");
    static_assert(sizeof(LSEDataType) == 4);

    constexpr index_t NPerThread = 16 / sizeof(LSEDataType); // 4
    constexpr index_t NThreads   = kNPerBlock / NPerThread;  // 32
    static_assert(NThreads <= kBlockSize);

    constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads;                     // 2
    constexpr index_t TotalWarps      = kBlockSize / get_warp_size();                   // 4
    constexpr index_t MPerThread      = kMPerBlock / (TotalWarps * MThreadsPerWarp);    // 2

    // All divisions above truncate; make sure nothing was dropped.
    static_assert(NThreads * NPerThread == kNPerBlock);
    static_assert(MThreadsPerWarp * NThreads == get_warp_size(),
                  "threads along kM0 must evenly divide the warp size");
    static_assert(MPerThread * TotalWarps * MThreadsPerWarp == kMPerBlock);

    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<MPerThread, TotalWarps, MThreadsPerWarp>,
                                         sequence<NThreads, NPerThread>>,
                                   tuple<sequence<1>, sequence<1, 2>>,
                                   tuple<sequence<1>, sequence<2, 0>>,
                                   sequence<1, 2>,
                                   sequence<0, 1>>{});
}
/// Tile distribution for the transposed LSE-accumulator tile of shape
/// [kM0, max(kMaxSplits, warp size)] (dimension 0 = query row, dimension 1 = split).
///
/// Dimension 1 is covered by NThreads = warp size threads with NPerThread elements each;
/// dimension 0 is split into MThreads groups of MPerThread rows.
/// The numeric trailing comments are the values for one example configuration.
///
/// NOTE(review): the block size is fixed at 256 here, whereas MakeOaccDramTileDistribution
/// reads Problem::kBlockSize - confirm that the combine kernel is always launched with 256
/// threads.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccTDramTileDistribution()
{
    constexpr index_t kBlockSize  = 256;
    constexpr index_t kNPerBlock  = max(Problem::kMaxSplits, get_warp_size());
    constexpr index_t kMPerBlock  = Problem::BlockFmhaShape::kM0;
    constexpr index_t NumElements = (kMPerBlock * kNPerBlock);

    // This assert already excludes NumElements < kBlockSize, so no separate
    // `if constexpr` / `static_assert(false)` branch is needed for that case.
    static_assert(kBlockSize < NumElements,
                  "transposed LSE tile must hold more elements than there are threads");

    constexpr index_t NThreads   = get_warp_size();       // 64
    constexpr index_t NPerThread = kNPerBlock / NThreads; // 1
    constexpr index_t MThreads   = kBlockSize / NThreads; // 4
    constexpr index_t MPerThread = kMPerBlock / MThreads; // 32

    // The divisions above truncate. Without these checks a kMaxSplits / kM0 that does not
    // divide evenly would silently leave part of the tile uncovered.
    static_assert(NThreads * NPerThread == kNPerBlock,
                  "max(kMaxSplits, warp size) must be a multiple of the warp size");
    static_assert(MThreads * MPerThread == kMPerBlock,
                  "kM0 must be a multiple of the number of warps");

    return make_static_tile_distribution(
        tile_distribution_encoding<
            sequence<1>,
            tuple<sequence<MThreads, MPerThread>, sequence<NThreads, NPerThread>>,
            tuple<sequence<1>, sequence<2>>,
            tuple<sequence<0>, sequence<0>>,
            sequence<1, 2>,
            sequence<1, 1>>{});
}
/// Tile distribution for one [kM0, kN1] tile of the per-split partial output.
///
/// Each thread covers 16 bytes (N1 elements) along dimension 1, N0 threads cover the whole
/// dimension 1, the remaining M2 threads of a warp and the M1 warps of the block are spread
/// over dimension 0, and every thread iterates M0 times along dimension 0.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
{
    using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;

    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;

    constexpr index_t N1 = 16 / sizeof(OaccDataType);
    constexpr index_t N0 = kNPerBlock / N1;
    constexpr index_t M2 = get_warp_size() / N0;
    constexpr index_t M1 = kBlockSize / get_warp_size();
    constexpr index_t M0 = kMPerBlock / (M2 * M1);

    // The divisions above truncate; reject shapes that would leave part of the tile (or
    // part of a warp) uncovered, mirroring the checks in MakeLSEaccDramTileDistribution.
    static_assert(N0 * N1 == kNPerBlock, "kN1 must be a multiple of the per-thread vector");
    static_assert(M2 * N0 == get_warp_size(),
                  "threads along kN1 must evenly divide the warp size");
    static_assert(M1 * get_warp_size() == kBlockSize,
                  "block size must be a multiple of the warp size");
    static_assert(M0 * M1 * M2 == kMPerBlock,
                  "kM0 must be a multiple of the rows covered by one pass of the block");

    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
                                   tuple<sequence<1>, sequence<1, 2>>,
                                   tuple<sequence<1>, sequence<2, 0>>,
                                   sequence<1, 2>,
                                   sequence<0, 1>>{});
}
};
} // namespace ck_tile

View File

@@ -17,7 +17,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kStoreLSE_,
bool kHasDropout_,
bool kDoFp8StaticQuant_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */,
index_t kMaxSplits_ = 1>
struct TileFmhaTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
@@ -30,16 +31,7 @@ struct TileFmhaTraits
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */,
index_t kBlockPerCu_ = 2 /* hint to occupancy */>
struct TileFmhaFwdSplitKVCombineTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr index_t kMaxSplits = kMaxSplits_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,