mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Add element function to fmha api
This commit is contained in:
@@ -346,45 +346,53 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
|
||||
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
bias_buf.GetDeviceBuffer(),
|
||||
lse_buf.GetDeviceBuffer(),
|
||||
o_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
nhead_k,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
mask.y,
|
||||
mask.x,
|
||||
descale_q * descale_k,
|
||||
descale_v};
|
||||
return fmha_fwd_args<FmhaDefaultElementFunctions>{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
bias_buf.GetDeviceBuffer(),
|
||||
lse_buf.GetDeviceBuffer(),
|
||||
o_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
nhead_k,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
mask.y,
|
||||
mask.x,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
descale_q * descale_k,
|
||||
descale_v};
|
||||
}();
|
||||
|
||||
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);
|
||||
|
||||
@@ -58,6 +58,18 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
|
||||
using ODataType = ck_tile::fp8_t;
|
||||
};
|
||||
|
||||
// Default bundle of element-function types for the FMHA forward example.
// Every slot (Q, K, V, bias, LSE, S accumulator, P compute, O accumulator)
// is ck_tile::identity. The member alias names are the ones that
// fmha_fwd_args<ElementFunctions> looks up via `typename ElementFunctions::...`.
struct FmhaDefaultElementFunctions
|
||||
{
|
||||
using QElementFunction = ck_tile::identity;
|
||||
using KElementFunction = ck_tile::identity;
|
||||
using VElementFunction = ck_tile::identity;
|
||||
using BiasElementFunction = ck_tile::identity;
|
||||
using LSEElementFunction = ck_tile::identity;
|
||||
using SAccElementFunction = ck_tile::identity;
|
||||
using PComputeElementFunction = ck_tile::identity;
|
||||
using OAccElementFunction = ck_tile::identity;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<ck_tile::bf8_t>
|
||||
{
|
||||
@@ -252,6 +264,7 @@ struct fmha_fwd_args
|
||||
#endif
|
||||
|
||||
// runtime args: some will be passed to karg, some will be used to compute grids/blocks
|
||||
template <typename ElementFunctions>
|
||||
struct fmha_fwd_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
@@ -291,12 +304,20 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t mask_y;
|
||||
ck_tile::index_t mask_x;
|
||||
typename ElementFunctions::QElementFunction q_element_func;
|
||||
typename ElementFunctions::KElementFunction k_element_func;
|
||||
typename ElementFunctions::VElementFunction v_element_func;
|
||||
typename ElementFunctions::BiasElementFunction bias_element_func;
|
||||
typename ElementFunctions::LSEElementFunction lse_element_func;
|
||||
typename ElementFunctions::SAccElementFunction s_acc_element_func;
|
||||
typename ElementFunctions::PComputeElementFunction p_compute_element_func;
|
||||
typename ElementFunctions::OAccElementFunction o_acc_element_func;
|
||||
float descale_qk;
|
||||
float descale_sv;
|
||||
};
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args<FmhaDefaultElementFunctions> args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
@@ -329,6 +350,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_o,
|
||||
args.mask_y,
|
||||
args.mask_x,
|
||||
args.q_element_func,
|
||||
args.k_element_func,
|
||||
args.v_element_func,
|
||||
args.bias_element_func,
|
||||
args.lse_element_func,
|
||||
args.s_acc_element_func,
|
||||
args.p_compute_element_func,
|
||||
args.o_acc_element_func,
|
||||
args.descale_qk,
|
||||
args.descale_sv);
|
||||
}
|
||||
@@ -365,6 +394,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.batch_stride_o,
|
||||
args.mask_y,
|
||||
args.mask_x,
|
||||
args.q_element_func,
|
||||
args.k_element_func,
|
||||
args.v_element_func,
|
||||
args.bias_element_func,
|
||||
args.lse_element_func,
|
||||
args.s_acc_element_func,
|
||||
args.p_compute_element_func,
|
||||
args.o_acc_element_func,
|
||||
args.descale_qk,
|
||||
args.descale_sv);
|
||||
}
|
||||
@@ -414,7 +451,7 @@ struct fmha_fwd_traits_
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args<FmhaDefaultElementFunctions>);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct fmha_fwd_traits
|
||||
@@ -429,4 +466,6 @@ struct fmha_fwd_traits
|
||||
bool has_lse;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
||||
float fmha_fwd(fmha_fwd_traits,
|
||||
fmha_fwd_args<FmhaDefaultElementFunctions>,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
@@ -84,6 +84,16 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_occupancy}>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_element_function_{F_idx} = ck_tile::FmhaElementFunctions<
|
||||
typename FmhaDefaultElementFunctions::QElementFunction,
|
||||
typename FmhaDefaultElementFunctions::KElementFunction,
|
||||
typename FmhaDefaultElementFunctions::VElementFunction,
|
||||
typename FmhaDefaultElementFunctions::BiasElementFunction,
|
||||
typename FmhaDefaultElementFunctions::LSEElementFunction,
|
||||
typename FmhaDefaultElementFunctions::SAccElementFunction,
|
||||
typename FmhaDefaultElementFunctions::PComputeElementFunction,
|
||||
typename FmhaDefaultElementFunctions::OAccElementFunction>;
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
@@ -95,6 +105,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_element_function_{F_idx},
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_mask_{F_idx},
|
||||
@@ -108,7 +119,7 @@ using fmha_epilogue_{F_idx} =
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner<fmha_shape_{F_idx}>,
|
||||
fmha_pipeline_{F_idx},
|
||||
fmha_epilogue_{F_idx}>;
|
||||
@@ -118,7 +129,7 @@ using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args<FmhaDefaultElementFunctions> a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
@@ -132,7 +143,7 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
|
||||
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args<FmhaDefaultElementFunctions> a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -37,6 +37,23 @@ struct FmhaFwdKernel
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
|
||||
|
||||
using QElementFunction =
|
||||
ck_tile::remove_cvref_t<typename FmhaPipeline::Problem::ElementFunctions::QElementFunction>;
|
||||
using KElementFunction =
|
||||
ck_tile::remove_cvref_t<typename FmhaPipeline::Problem::ElementFunctions::KElementFunction>;
|
||||
using VElementFunction =
|
||||
ck_tile::remove_cvref_t<typename FmhaPipeline::Problem::ElementFunctions::VElementFunction>;
|
||||
using BiasElementFunction = ck_tile::remove_cvref_t<
|
||||
typename FmhaPipeline::Problem::ElementFunctions::BiasElementFunction>;
|
||||
using LSEElementFunction = ck_tile::remove_cvref_t<
|
||||
typename FmhaPipeline::Problem::ElementFunctions::LSEElementFunction>;
|
||||
using SAccElementFunction = ck_tile::remove_cvref_t<
|
||||
typename FmhaPipeline::Problem::ElementFunctions::SAccElementFunction>;
|
||||
using PComputeElementFunction = ck_tile::remove_cvref_t<
|
||||
typename FmhaPipeline::Problem::ElementFunctions::PComputeElementFunction>;
|
||||
using OAccElementFunction = ck_tile::remove_cvref_t<
|
||||
typename FmhaPipeline::Problem::ElementFunctions::OAccElementFunction>;
|
||||
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
@@ -77,7 +94,7 @@ struct FmhaFwdKernel
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_" +
|
||||
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
|
||||
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
|
||||
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
|
||||
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
|
||||
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
|
||||
@@ -122,6 +139,15 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
QElementFunction q_element_func;
|
||||
KElementFunction k_element_func;
|
||||
VElementFunction v_element_func;
|
||||
BiasElementFunction bias_element_func;
|
||||
LSEElementFunction lse_element_func;
|
||||
SAccElementFunction s_acc_element_func;
|
||||
PComputeElementFunction p_compute_element_func;
|
||||
OAccElementFunction o_acc_element_func;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonBiasKargs
|
||||
@@ -219,6 +245,14 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t mask_y,
|
||||
ck_tile::index_t mask_x,
|
||||
QElementFunction q_element_func,
|
||||
KElementFunction k_element_func,
|
||||
VElementFunction v_element_func,
|
||||
BiasElementFunction bias_element_func,
|
||||
LSEElementFunction lse_element_func,
|
||||
SAccElementFunction s_acc_element_func,
|
||||
PComputeElementFunction p_compute_element_func,
|
||||
OAccElementFunction o_acc_element_func,
|
||||
float descale_qk,
|
||||
float descale_sv)
|
||||
{
|
||||
@@ -243,11 +277,19 @@ struct FmhaFwdKernel
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8 args
|
||||
nhead_stride_o,
|
||||
q_element_func,
|
||||
k_element_func,
|
||||
v_element_func,
|
||||
bias_element_func,
|
||||
lse_element_func,
|
||||
s_acc_element_func,
|
||||
p_compute_element_func,
|
||||
o_acc_element_func}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8 args
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -308,6 +350,14 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t mask_y,
|
||||
ck_tile::index_t mask_x,
|
||||
QElementFunction q_element_func,
|
||||
KElementFunction k_element_func,
|
||||
VElementFunction v_element_func,
|
||||
BiasElementFunction bias_element_func,
|
||||
LSEElementFunction lse_element_func,
|
||||
SAccElementFunction s_acc_element_func,
|
||||
PComputeElementFunction p_compute_element_func,
|
||||
OAccElementFunction o_acc_element_func,
|
||||
float descale_qk,
|
||||
float descale_sv)
|
||||
{
|
||||
@@ -332,11 +382,19 @@ struct FmhaFwdKernel
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8 args
|
||||
nhead_stride_o,
|
||||
q_element_func,
|
||||
k_element_func,
|
||||
v_element_func,
|
||||
bias_element_func,
|
||||
lse_element_func,
|
||||
s_acc_element_func,
|
||||
p_compute_element_func,
|
||||
o_acc_element_func}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8 args
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
@@ -661,10 +719,18 @@ struct FmhaFwdKernel
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
kargs.q_element_func,
|
||||
k_dram_window,
|
||||
kargs.k_element_func,
|
||||
v_dram_window,
|
||||
kargs.v_element_func,
|
||||
bias_dram_window,
|
||||
kargs.bias_element_func,
|
||||
lse_dram_window,
|
||||
kargs.lse_element_func,
|
||||
kargs.s_acc_element_func,
|
||||
kargs.p_compute_element_func,
|
||||
kargs.o_acc_element_func,
|
||||
mask,
|
||||
kargs.scale,
|
||||
smem_ptr);
|
||||
|
||||
@@ -17,6 +17,7 @@ template <typename QDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
typename ODataType_,
|
||||
typename ElementFunctions_,
|
||||
typename BlockFmhaShape_,
|
||||
bool kIsGroupMode_,
|
||||
typename FmhaMask_,
|
||||
@@ -33,6 +34,7 @@ struct BlockFmhaPipelineProblem
|
||||
using PDataType = remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = remove_cvref_t<OaccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using ElementFunctions = remove_cvref_t<ElementFunctions_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
@@ -559,39 +559,6 @@ struct BlockFmhaPipelineQRKSVS
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
// Convenience overload without explicit element functions.
// Forwards to the operator() overload that takes element functions, passing
// identity{} after each of the q/k/v/bias/lse windows plus three further
// identity{} arguments; mask, scale and smem_ptr are forwarded unchanged.
// NOTE(review): the three trailing identity{} slots presumably correspond to
// the s_acc / p_compute / o_acc element functions -- confirm against the full
// overload, which is not visible here.
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
float scale,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
scale,
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -654,39 +654,6 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
// Convenience overload without explicit element functions.
// Forwards to the operator() overload that takes element functions, passing
// identity{} after each of the q/k/v/bias/lse windows plus three further
// identity{} arguments; mask, scale and smem_ptr are forwarded unchanged.
// NOTE(review): the three trailing identity{} slots presumably correspond to
// the s_acc / p_compute / o_acc element functions -- confirm against the full
// overload, which is not visible here.
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
float scale,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
scale,
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -550,36 +550,6 @@ struct BlockFmhaPipelineQSKSVS
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
// Convenience overload without explicit element functions.
// Forwards to the operator() overload that takes element functions, passing
// identity{} after each of the q/k/v/bias/lse windows; mask, scale and
// smem_ptr are forwarded unchanged.
// NOTE(review): unlike the sibling pipelines, only five identity{} arguments
// are forwarded here (no extra trailing slots) -- confirm this matches the
// full overload's parameter list, which is not visible here.
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
float scale,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
mask,
|
||||
scale,
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
39
include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp
Normal file
39
include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp
Normal file
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Bundle of the eight element-function types used by the FMHA pipeline
// (Q, K, V, bias, LSE, S accumulator, P compute, O accumulator).
// Each template parameter is exposed as a member alias after being passed
// through remove_cvref_t, and the struct also holds one data member of each
// aliased type.
template <typename QElementFunction_,
|
||||
typename KElementFunction_,
|
||||
typename VElementFunction_,
|
||||
typename BiasElementFunction_,
|
||||
typename LSEElementFunction_,
|
||||
typename SAccElementFunction_,
|
||||
typename PComputeElementFunction_,
|
||||
typename OAccElementFunction_>
|
||||
struct FmhaElementFunctions
|
||||
{
|
||||
// Type aliases: the template arguments normalized with remove_cvref_t.
using QElementFunction = remove_cvref_t<QElementFunction_>;
|
||||
using KElementFunction = remove_cvref_t<KElementFunction_>;
|
||||
using VElementFunction = remove_cvref_t<VElementFunction_>;
|
||||
using BiasElementFunction = remove_cvref_t<BiasElementFunction_>;
|
||||
using LSEElementFunction = remove_cvref_t<LSEElementFunction_>;
|
||||
using SAccElementFunction = remove_cvref_t<SAccElementFunction_>;
|
||||
using PComputeElementFunction = remove_cvref_t<PComputeElementFunction_>;
|
||||
using OAccElementFunction = remove_cvref_t<OAccElementFunction_>;
|
||||
|
||||
// One stored instance per element function, named like the matching
// fields of fmha_fwd_args and the kernel kargs.
QElementFunction q_element_func;
|
||||
KElementFunction k_element_func;
|
||||
VElementFunction v_element_func;
|
||||
BiasElementFunction bias_element_func;
|
||||
LSEElementFunction lse_element_func;
|
||||
SAccElementFunction s_acc_element_func;
|
||||
PComputeElementFunction p_compute_element_func;
|
||||
OAccElementFunction o_acc_element_func;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user