diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 5a6afe36f6..347e74fbd3 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -346,45 +346,53 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); - return fmha_fwd_args{q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - bias_buf.GetDeviceBuffer(), - lse_buf.GetDeviceBuffer(), - o_buf.GetDeviceBuffer(), - seqstart_q.GetDeviceBuffer(), - seqstart_k.GetDeviceBuffer(), - nullptr, - shape_seqlen_q, - shape_seqlen_k, - batch, - max_seqlen_q, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_lse, - batch_stride_o, - mask.y, - mask.x, - descale_q * descale_k, - descale_v}; + return fmha_fwd_args{q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias_buf.GetDeviceBuffer(), + lse_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + shape_seqlen_q, + shape_seqlen_k, + batch, + max_seqlen_q, + hdim_q, + hdim_v, + nhead, + nhead_k, + scale, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_lse, + batch_stride_o, + mask.y, + mask.x, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + descale_q * descale_k, + descale_v}; }(); float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 9293201cd2..a15dcb790a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -58,6 +58,18 @@ struct FmhaFwdTypeConfig using ODataType = ck_tile::fp8_t; }; +struct FmhaDefaultElementFunctions +{ + using QElementFunction = ck_tile::identity; + using KElementFunction = ck_tile::identity; + using VElementFunction = ck_tile::identity; + using BiasElementFunction = ck_tile::identity; + using LSEElementFunction = ck_tile::identity; + using SAccElementFunction = ck_tile::identity; + using PComputeElementFunction = ck_tile::identity; + using OAccElementFunction = ck_tile::identity; +}; + template <> struct FmhaFwdTypeConfig { @@ -252,6 +264,7 @@ struct fmha_fwd_args #endif // runtime args, some will passed to karg, some will used to compute grids/blocks +template struct fmha_fwd_args { const void* q_ptr; @@ -291,12 +304,20 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_o; ck_tile::index_t mask_y; ck_tile::index_t mask_x; + typename ElementFunctions::QElementFunction q_element_func; + typename ElementFunctions::KElementFunction k_element_func; + typename ElementFunctions::VElementFunction v_element_func; + typename ElementFunctions::BiasElementFunction bias_element_func; + typename ElementFunctions::LSEElementFunction lse_element_func; + typename ElementFunctions::SAccElementFunction s_acc_element_func; + typename ElementFunctions::PComputeElementFunction p_compute_element_func; + typename ElementFunctions::OAccElementFunction o_acc_element_func; float descale_qk; float descale_sv; }; template -auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) +auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { @@ -329,6 +350,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_o, args.mask_y, args.mask_x, + args.q_element_func, + args.k_element_func, + args.v_element_func, + args.bias_element_func, + args.lse_element_func, + args.s_acc_element_func, + args.p_compute_element_func, + args.o_acc_element_func, args.descale_qk, args.descale_sv); } @@ -365,6 +394,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_o, args.mask_y, args.mask_x, + args.q_element_func, + args.k_element_func, + args.v_element_func, + args.bias_element_func, + args.lse_element_func, + args.s_acc_element_func, + args.p_compute_element_func, + args.o_acc_element_func, args.descale_qk, args.descale_sv); } @@ -414,7 +451,7 @@ struct fmha_fwd_traits_ }; template -float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); +float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); // This is the public API, will be generated by script struct fmha_fwd_traits @@ -429,4 +466,6 @@ struct fmha_fwd_traits bool has_lse; // TODO: padding check is inside this api }; -float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); +float fmha_fwd(fmha_fwd_traits, + fmha_fwd_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 5c44ad303b..ba08f76683 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -84,6 +84,16 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_occupancy}>; using fmha_mask_{F_idx} = {F_mask}; +using fmha_element_function_{F_idx} = ck_tile::FmhaElementFunctions< + typename FmhaDefaultElementFunctions::QElementFunction, + typename FmhaDefaultElementFunctions::KElementFunction, + typename FmhaDefaultElementFunctions::VElementFunction, + typename FmhaDefaultElementFunctions::BiasElementFunction, + typename FmhaDefaultElementFunctions::LSEElementFunction, + typename FmhaDefaultElementFunctions::SAccElementFunction, + typename FmhaDefaultElementFunctions::PComputeElementFunction, + typename FmhaDefaultElementFunctions::OAccElementFunction>; + using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::QDataType, typename FmhaFwdTypeConfig::KDataType, @@ -95,6 +105,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::PDataType, typename FmhaFwdTypeConfig::OaccDataType, typename FmhaFwdTypeConfig::ODataType, + fmha_element_function_{F_idx}, fmha_shape_{F_idx}, {F_mode}, fmha_mask_{F_idx}, @@ -108,7 +119,7 @@ using fmha_epilogue_{F_idx} = typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, {F_spad}, {F_dvpad}>>; -using fmha_kernel_{F_idx} = +using fmha_kernel_{F_idx} = ck_tile::FmhaFwdKernel, fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>; @@ -118,7 +129,7 @@ using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F #include template<> -float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) {{ using k_ = fmha_kernel_{F_idx}; if(s.log_level_ > 0) @@ -132,7 +143,7 @@ float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" FMHA_FWD_API=""" -float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ float r = -1; {F_dispatch} return r; diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 1e9acc6d7b..3767b61a7c 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -15,6 +15,7 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 98866805a0..cfdd21f80e 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -37,6 +37,23 @@ struct FmhaFwdKernel using VLayout = ck_tile::remove_cvref_t; + using QElementFunction = + ck_tile::remove_cvref_t; + using KElementFunction = + ck_tile::remove_cvref_t; + using VElementFunction = + ck_tile::remove_cvref_t; + using BiasElementFunction = ck_tile::remove_cvref_t< + typename FmhaPipeline::Problem::ElementFunctions::BiasElementFunction>; + using LSEElementFunction = ck_tile::remove_cvref_t< + typename FmhaPipeline::Problem::ElementFunctions::LSEElementFunction>; + using SAccElementFunction = ck_tile::remove_cvref_t< + typename FmhaPipeline::Problem::ElementFunctions::SAccElementFunction>; + using PComputeElementFunction = ck_tile::remove_cvref_t< + typename FmhaPipeline::Problem::ElementFunctions::PComputeElementFunction>; + using OAccElementFunction = ck_tile::remove_cvref_t< + typename FmhaPipeline::Problem::ElementFunctions::OAccElementFunction>; + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; @@ -77,7 +94,7 @@ struct FmhaFwdKernel "_" + (kIsGroupMode ? "group" : "batch") + "_" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + - "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + @@ -122,6 +139,15 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_o; + + QElementFunction q_element_func; + KElementFunction k_element_func; + VElementFunction v_element_func; + BiasElementFunction bias_element_func; + LSEElementFunction lse_element_func; + SAccElementFunction s_acc_element_func; + PComputeElementFunction p_compute_element_func; + OAccElementFunction o_acc_element_func; }; struct FmhaFwdCommonBiasKargs @@ -219,6 +245,14 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_o, ck_tile::index_t mask_y, ck_tile::index_t mask_x, + QElementFunction q_element_func, + KElementFunction k_element_func, + VElementFunction v_element_func, + BiasElementFunction bias_element_func, + LSEElementFunction lse_element_func, + SAccElementFunction s_acc_element_func, + PComputeElementFunction p_compute_element_func, + OAccElementFunction o_acc_element_func, float descale_qk, float descale_sv) { @@ -243,11 +277,19 @@ struct FmhaFwdKernel nhead_stride_q, nhead_stride_k, nhead_stride_v, - nhead_stride_o}, // args for common karg - {}, // placeholder for bias - {}, // placeholder for mask - {}, // placeholder for lse - {}, // placeholder for fp8 args + nhead_stride_o, + q_element_func, + k_element_func, + v_element_func, + bias_element_func, + lse_element_func, + s_acc_element_func, + p_compute_element_func, + o_acc_element_func}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8 args batch_stride_q, batch_stride_k, batch_stride_v, @@ -308,6 +350,14 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o, ck_tile::index_t mask_y, ck_tile::index_t mask_x, + QElementFunction q_element_func, + KElementFunction k_element_func, + VElementFunction v_element_func, + BiasElementFunction bias_element_func, + LSEElementFunction lse_element_func, + SAccElementFunction s_acc_element_func, + PComputeElementFunction p_compute_element_func, + OAccElementFunction o_acc_element_func, float descale_qk, float descale_sv) { @@ -332,11 +382,19 @@ struct FmhaFwdKernel nhead_stride_q, nhead_stride_k, nhead_stride_v, - nhead_stride_o}, // args for common karg - {}, // placeholder for bias - {}, // placeholder for mask - {}, // placeholder for lse - {}, // placeholder for fp8 args + nhead_stride_o, + q_element_func, + k_element_func, + v_element_func, + bias_element_func, + lse_element_func, + s_acc_element_func, + p_compute_element_func, + o_acc_element_func}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8 args reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; @@ -661,10 +719,18 @@ struct FmhaFwdKernel else { return FmhaPipeline{}(q_dram_window, + kargs.q_element_func, k_dram_window, + kargs.k_element_func, v_dram_window, + kargs.v_element_func, bias_dram_window, + kargs.bias_element_func, lse_dram_window, + kargs.lse_element_func, + kargs.s_acc_element_func, + kargs.p_compute_element_func, + kargs.o_acc_element_func, mask, kargs.scale, smem_ptr); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 90bd69956c..89682540af 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -17,6 +17,7 @@ template ; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using ElementFunctions = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using FmhaMask = remove_cvref_t; using Traits = remove_cvref_t; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 07930258c4..19033499cc 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -559,39 +559,6 @@ struct BlockFmhaPipelineQRKSVS return o_acc; } - - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile - FmhaMask mask, - float scale, - void* smem_ptr) const - { - return operator()(q_dram_block_window_tmp, - identity{}, - k_dram_block_window_tmp, - identity{}, - v_dram_block_window_tmp, - identity{}, - bias_dram_block_window_tmp, - identity{}, - lse_dram_block_window_tmp, - identity{}, - identity{}, - identity{}, - identity{}, - mask, - scale, - smem_ptr); - } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index e7d22984cd..1fa33de55c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -654,39 +654,6 @@ struct BlockFmhaPipelineQRKSVSAsync return o_acc; } - - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile - FmhaMask mask, - float scale, - void* smem_ptr) const - { - return operator()(q_dram_block_window_tmp, - identity{}, - k_dram_block_window_tmp, - identity{}, - v_dram_block_window_tmp, - identity{}, - bias_dram_block_window_tmp, - identity{}, - lse_dram_block_window_tmp, - identity{}, - identity{}, - identity{}, - identity{}, - mask, - scale, - smem_ptr); - } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 8150326adf..3c56c86e4d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -550,36 +550,6 @@ struct BlockFmhaPipelineQSKSVS return o_acc; } - - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile - FmhaMask mask, - float scale, - void* smem_ptr) const - { - return operator()(q_dram_block_window_tmp, - identity{}, - k_dram_block_window_tmp, - identity{}, - v_dram_block_window_tmp, - identity{}, - bias_dram_block_window_tmp, - identity{}, - lse_dram_block_window_tmp, - identity{}, - mask, - scale, - smem_ptr); - } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp b/include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp new file mode 100644 index 0000000000..397b15c6d2 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaElementFunctions +{ + using QElementFunction = remove_cvref_t; + using KElementFunction = remove_cvref_t; + using VElementFunction = remove_cvref_t; + using BiasElementFunction = remove_cvref_t; + using LSEElementFunction = remove_cvref_t; + using SAccElementFunction = remove_cvref_t; + using PComputeElementFunction = remove_cvref_t; + using OAccElementFunction = remove_cvref_t; + + QElementFunction q_element_func; + KElementFunction k_element_func; + VElementFunction v_element_func; + BiasElementFunction bias_element_func; + LSEElementFunction lse_element_func; + SAccElementFunction s_acc_element_func; + PComputeElementFunction p_compute_element_func; + OAccElementFunction o_acc_element_func; +}; + +} // namespace ck_tile