mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Separate splitkv/non-splitkv args/traits
This commit is contained in:
@@ -87,7 +87,7 @@ using fmha_kernel =
|
||||
fmha_pipeline,
|
||||
fmha_epilogue>;
|
||||
|
||||
static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel;
|
||||
auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
|
||||
@@ -105,7 +105,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
if constexpr({F_mode} == false) {{ // batch mode
|
||||
// make sure F_bn0 is divisible by F_bk1
|
||||
@@ -163,7 +163,7 @@ using fmha_kernel =
|
||||
fmha_pipeline,
|
||||
fmha_epilogue>;
|
||||
|
||||
static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel;
|
||||
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
|
||||
@@ -180,7 +180,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
if (a.num_splits <= 16) {{
|
||||
kernel_runner<4>::run(s, a);
|
||||
@@ -206,7 +206,7 @@ FMHA_FWD_SPLITKV_API="""
|
||||
#include <iostream>
|
||||
|
||||
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
|
||||
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout
|
||||
@@ -220,7 +220,7 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
);
|
||||
}}
|
||||
|
||||
float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
|
||||
float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
|
||||
@@ -156,6 +156,8 @@ struct fmha_fwd_args
|
||||
std::tuple<uint64_t, uint64_t> drop_seed_offset;
|
||||
};
|
||||
|
||||
using fmha_fwd_splitkv_args = fmha_fwd_args;
|
||||
|
||||
struct fmha_fwd_appendkv_args
|
||||
{
|
||||
void* q_ptr;
|
||||
@@ -301,7 +303,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
}
|
||||
|
||||
template <typename Kernel>
|
||||
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
|
||||
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
@@ -417,7 +419,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
|
||||
}
|
||||
|
||||
template <typename Kernel>
|
||||
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
|
||||
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
@@ -611,53 +613,54 @@ struct fmha_fwd_traits_
|
||||
template <typename Traits_>
|
||||
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
template <ck_tile::index_t HDim,
|
||||
typename DataType,
|
||||
bool kIsGroupMode,
|
||||
ck_tile::index_t kM0,
|
||||
ck_tile::index_t kN0,
|
||||
ck_tile::index_t kK0,
|
||||
ck_tile::index_t kN1,
|
||||
ck_tile::index_t kK1,
|
||||
ck_tile::index_t kK0BlockLength,
|
||||
bool kIsVLayoutRowMajor,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum,
|
||||
typename FmhaMask,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum,
|
||||
bool kStoreLse,
|
||||
bool kHasDropout,
|
||||
bool kDoFp8StaticQuant,
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLse_,
|
||||
bool kHasDropout_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kIsPagedKV_,
|
||||
bool kPadS,
|
||||
bool kPadSK,
|
||||
bool kPadD,
|
||||
bool kPadDv>
|
||||
struct fmha_fwd_splitkv_traits_ : fmha_fwd_traits_<HDim,
|
||||
DataType,
|
||||
kIsGroupMode,
|
||||
kM0,
|
||||
kN0,
|
||||
kK0,
|
||||
kN1,
|
||||
kK1,
|
||||
kK0BlockLength,
|
||||
kIsVLayoutRowMajor,
|
||||
FmhaPipelineEnum,
|
||||
FmhaMask,
|
||||
BiasEnum,
|
||||
kStoreLse,
|
||||
kHasDropout,
|
||||
kDoFp8StaticQuant,
|
||||
kPadS,
|
||||
kPadSK,
|
||||
kPadD,
|
||||
kPadDv>
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_>
|
||||
struct fmha_fwd_splitkv_traits_
|
||||
{
|
||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN0 = kN0_;
|
||||
static constexpr ck_tile::index_t kK0 = kK0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
static constexpr ck_tile::index_t kK1 = kK1_;
|
||||
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_fwd_splitkv_get_name_();
|
||||
@@ -685,7 +688,7 @@ struct fmha_fwd_splitkv_combine_traits_
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_fwd_splitkv_combine_get_name_();
|
||||
|
||||
Reference in New Issue
Block a user