Add element function to fmha api

This commit is contained in:
rocking
2024-03-29 18:05:36 -04:00
parent 50c36f352a
commit 286c74468d
10 changed files with 222 additions and 152 deletions

View File

@@ -346,45 +346,53 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
hdim_q,
hdim_v,
nhead,
nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask.y,
mask.x,
descale_q * descale_k,
descale_v};
return fmha_fwd_args<FmhaDefaultElementFunctions>{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
hdim_q,
hdim_v,
nhead,
nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask.y,
mask.x,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
descale_q * descale_k,
descale_v};
}();
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);

View File

@@ -58,6 +58,18 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
using ODataType = ck_tile::fp8_t;
};
// Default bundle of element-wise function types for the FMHA forward API.
// Every slot below is ck_tile::identity (presumably a pass-through functor;
// its definition is not visible here -- TODO confirm in ck_tile).
// This struct is used in this file as the ElementFunctions template argument
// of fmha_fwd_args, which declares one data member per alias listed here.
struct FmhaDefaultElementFunctions
{
using QElementFunction = ck_tile::identity; // type of fmha_fwd_args::q_element_func
using KElementFunction = ck_tile::identity; // type of fmha_fwd_args::k_element_func
using VElementFunction = ck_tile::identity; // type of fmha_fwd_args::v_element_func
using BiasElementFunction = ck_tile::identity; // type of fmha_fwd_args::bias_element_func
using LSEElementFunction = ck_tile::identity; // type of fmha_fwd_args::lse_element_func
using SAccElementFunction = ck_tile::identity; // type of fmha_fwd_args::s_acc_element_func
using PComputeElementFunction = ck_tile::identity; // type of fmha_fwd_args::p_compute_element_func
using OAccElementFunction = ck_tile::identity; // type of fmha_fwd_args::o_acc_element_func
};
template <>
struct FmhaFwdTypeConfig<ck_tile::bf8_t>
{
@@ -252,6 +264,7 @@ struct fmha_fwd_args
#endif
// runtime args: some will be passed to kargs, some will be used to compute grids/blocks
template <typename ElementFunctions>
struct fmha_fwd_args
{
const void* q_ptr;
@@ -291,12 +304,20 @@ struct fmha_fwd_args
ck_tile::index_t batch_stride_o;
ck_tile::index_t mask_y;
ck_tile::index_t mask_x;
typename ElementFunctions::QElementFunction q_element_func;
typename ElementFunctions::KElementFunction k_element_func;
typename ElementFunctions::VElementFunction v_element_func;
typename ElementFunctions::BiasElementFunction bias_element_func;
typename ElementFunctions::LSEElementFunction lse_element_func;
typename ElementFunctions::SAccElementFunction s_acc_element_func;
typename ElementFunctions::PComputeElementFunction p_compute_element_func;
typename ElementFunctions::OAccElementFunction o_acc_element_func;
float descale_qk;
float descale_sv;
};
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args<FmhaDefaultElementFunctions> args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
@@ -329,6 +350,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_o,
args.mask_y,
args.mask_x,
args.q_element_func,
args.k_element_func,
args.v_element_func,
args.bias_element_func,
args.lse_element_func,
args.s_acc_element_func,
args.p_compute_element_func,
args.o_acc_element_func,
args.descale_qk,
args.descale_sv);
}
@@ -365,6 +394,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.batch_stride_o,
args.mask_y,
args.mask_x,
args.q_element_func,
args.k_element_func,
args.v_element_func,
args.bias_element_func,
args.lse_element_func,
args.s_acc_element_func,
args.p_compute_element_func,
args.o_acc_element_func,
args.descale_qk,
args.descale_sv);
}
@@ -414,7 +451,7 @@ struct fmha_fwd_traits_
};
template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args<FmhaDefaultElementFunctions>);
// This is the public API, will be generated by script
struct fmha_fwd_traits
@@ -429,4 +466,6 @@ struct fmha_fwd_traits
bool has_lse;
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
float fmha_fwd(fmha_fwd_traits,
fmha_fwd_args<FmhaDefaultElementFunctions>,
const ck_tile::stream_config&);

View File

@@ -84,6 +84,16 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_element_function_{F_idx} = ck_tile::FmhaElementFunctions<
typename FmhaDefaultElementFunctions::QElementFunction,
typename FmhaDefaultElementFunctions::KElementFunction,
typename FmhaDefaultElementFunctions::VElementFunction,
typename FmhaDefaultElementFunctions::BiasElementFunction,
typename FmhaDefaultElementFunctions::LSEElementFunction,
typename FmhaDefaultElementFunctions::SAccElementFunction,
typename FmhaDefaultElementFunctions::PComputeElementFunction,
typename FmhaDefaultElementFunctions::OAccElementFunction>;
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
@@ -95,6 +105,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_element_function_{F_idx},
fmha_shape_{F_idx},
{F_mode},
fmha_mask_{F_idx},
@@ -108,7 +119,7 @@ using fmha_epilogue_{F_idx} =
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner<fmha_shape_{F_idx}>,
fmha_pipeline_{F_idx},
fmha_epilogue_{F_idx}>;
@@ -118,7 +129,7 @@ using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F
#include <iostream>
template<>
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args<FmhaDefaultElementFunctions> a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
@@ -132,7 +143,7 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
FMHA_FWD_API="""
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args<FmhaDefaultElementFunctions> a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;