Separate splitkv/non-splitkv args/traits

This commit is contained in:
PoYen, Chen
2024-08-08 08:07:03 +00:00
parent 655b13b059
commit 291e9b4bbb
2 changed files with 54 additions and 51 deletions

View File

@@ -87,7 +87,7 @@ using fmha_kernel =
fmha_pipeline,
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
using k_ = fmha_kernel;
auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
@@ -105,7 +105,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
#include <iostream>
template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
if constexpr({F_mode} == false) {{ // batch mode
// make sure F_bn0 is divisible by F_bk1
@@ -163,7 +163,7 @@ using fmha_kernel =
fmha_pipeline,
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
using k_ = fmha_kernel;
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
@@ -180,7 +180,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m
#include <iostream>
template<>
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
if (a.num_splits <= 16) {{
kernel_runner<4>::run(s, a);
@@ -206,7 +206,7 @@ FMHA_FWD_SPLITKV_API="""
#include <iostream>
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
if(s.log_level_ > 0)
std::cout
@@ -220,7 +220,7 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
);
}}
float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;

View File

@@ -156,6 +156,8 @@ struct fmha_fwd_args
std::tuple<uint64_t, uint64_t> drop_seed_offset;
};
// Argument type taken by the split-KV entry points and kargs/grid builders.
// At this point it is only another name for fmha_fwd_args, so both spellings
// denote the same type.
using fmha_fwd_splitkv_args = fmha_fwd_args;
struct fmha_fwd_appendkv_args
{
void* q_ptr;
@@ -301,7 +303,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
}
template <typename Kernel>
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
@@ -417,7 +419,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
}
template <typename Kernel>
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
@@ -611,53 +613,54 @@ struct fmha_fwd_traits_
template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
template <ck_tile::index_t HDim,
typename DataType,
bool kIsGroupMode,
ck_tile::index_t kM0,
ck_tile::index_t kN0,
ck_tile::index_t kK0,
ck_tile::index_t kN1,
ck_tile::index_t kK1,
ck_tile::index_t kK0BlockLength,
bool kIsVLayoutRowMajor,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum,
typename FmhaMask,
ck_tile::BlockAttentionBiasEnum BiasEnum,
bool kStoreLse,
bool kHasDropout,
bool kDoFp8StaticQuant,
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
bool kHasDropout_,
bool kDoFp8StaticQuant_,
bool kIsPagedKV_,
bool kPadS,
bool kPadSK,
bool kPadD,
bool kPadDv>
struct fmha_fwd_splitkv_traits_ : fmha_fwd_traits_<HDim,
DataType,
kIsGroupMode,
kM0,
kN0,
kK0,
kN1,
kK1,
kK0BlockLength,
kIsVLayoutRowMajor,
FmhaPipelineEnum,
FmhaMask,
BiasEnum,
kStoreLse,
kHasDropout,
kDoFp8StaticQuant,
kPadS,
kPadSK,
kPadD,
kPadDv>
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
// Compile-time trait bundle for the split-KV forward path.
//
// Every template argument of the enclosing template header (the names ending
// in `_`) is re-exported as a static constexpr member or type alias with the
// same name minus the trailing underscore, so the values can be read as
// `Traits::kM0`, `Traits::DataType`, and so on. The struct holds no data
// members and has no base class.
struct fmha_fwd_splitkv_traits_
{
// NOTE(review): kIsPagedKV is declared here and again as the last member in
// this view. A struct cannot declare the same static member twice, so
// presumably one of the two lines is the removed side of the diff; confirm
// against the actual header before relying on this listing.
static constexpr bool kIsPagedKV = kIsPagedKV_;
static constexpr ck_tile::index_t HDim = HDim_;
// Stored with ck_tile::remove_cvref_t applied to the template argument.
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
// Declared `auto`; the type is deduced from the enum-typed template argument.
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
// Stored with ck_tile::remove_cvref_t applied to the template argument.
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
// Declared `auto`; the type is deduced from the enum-typed template argument.
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_>
std::string fmha_fwd_splitkv_get_name_();
@@ -685,7 +688,7 @@ struct fmha_fwd_splitkv_combine_traits_
};
template <typename Traits_>
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_>
std::string fmha_fwd_splitkv_combine_get_name_();