Add element function to fmha api

This commit is contained in:
rocking
2024-03-29 18:05:36 -04:00
parent 50c36f352a
commit 286c74468d
10 changed files with 222 additions and 152 deletions

View File

@@ -15,6 +15,7 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -37,6 +37,23 @@ struct FmhaFwdKernel
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
using QElementFunction =
ck_tile::remove_cvref_t<typename FmhaPipeline::Problem::ElementFunctions::QElementFunction>;
using KElementFunction =
ck_tile::remove_cvref_t<typename FmhaPipeline::Problem::ElementFunctions::KElementFunction>;
using VElementFunction =
ck_tile::remove_cvref_t<typename FmhaPipeline::Problem::ElementFunctions::VElementFunction>;
using BiasElementFunction = ck_tile::remove_cvref_t<
typename FmhaPipeline::Problem::ElementFunctions::BiasElementFunction>;
using LSEElementFunction = ck_tile::remove_cvref_t<
typename FmhaPipeline::Problem::ElementFunctions::LSEElementFunction>;
using SAccElementFunction = ck_tile::remove_cvref_t<
typename FmhaPipeline::Problem::ElementFunctions::SAccElementFunction>;
using PComputeElementFunction = ck_tile::remove_cvref_t<
typename FmhaPipeline::Problem::ElementFunctions::PComputeElementFunction>;
using OAccElementFunction = ck_tile::remove_cvref_t<
typename FmhaPipeline::Problem::ElementFunctions::OAccElementFunction>;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
@@ -77,7 +94,7 @@ struct FmhaFwdKernel
"_" + (kIsGroupMode ? "group" : "batch") + "_" +
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
@@ -122,6 +139,15 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
QElementFunction q_element_func;
KElementFunction k_element_func;
VElementFunction v_element_func;
BiasElementFunction bias_element_func;
LSEElementFunction lse_element_func;
SAccElementFunction s_acc_element_func;
PComputeElementFunction p_compute_element_func;
OAccElementFunction o_acc_element_func;
};
struct FmhaFwdCommonBiasKargs
@@ -219,6 +245,14 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t mask_y,
ck_tile::index_t mask_x,
QElementFunction q_element_func,
KElementFunction k_element_func,
VElementFunction v_element_func,
BiasElementFunction bias_element_func,
LSEElementFunction lse_element_func,
SAccElementFunction s_acc_element_func,
PComputeElementFunction p_compute_element_func,
OAccElementFunction o_acc_element_func,
float descale_qk,
float descale_sv)
{
@@ -243,11 +277,19 @@ struct FmhaFwdKernel
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_o}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for fp8 args
nhead_stride_o,
q_element_func,
k_element_func,
v_element_func,
bias_element_func,
lse_element_func,
s_acc_element_func,
p_compute_element_func,
o_acc_element_func}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for fp8 args
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -308,6 +350,14 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t mask_y,
ck_tile::index_t mask_x,
QElementFunction q_element_func,
KElementFunction k_element_func,
VElementFunction v_element_func,
BiasElementFunction bias_element_func,
LSEElementFunction lse_element_func,
SAccElementFunction s_acc_element_func,
PComputeElementFunction p_compute_element_func,
OAccElementFunction o_acc_element_func,
float descale_qk,
float descale_sv)
{
@@ -332,11 +382,19 @@ struct FmhaFwdKernel
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_o}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for fp8 args
nhead_stride_o,
q_element_func,
k_element_func,
v_element_func,
bias_element_func,
lse_element_func,
s_acc_element_func,
p_compute_element_func,
o_acc_element_func}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for fp8 args
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
@@ -661,10 +719,18 @@ struct FmhaFwdKernel
else
{
return FmhaPipeline{}(q_dram_window,
kargs.q_element_func,
k_dram_window,
kargs.k_element_func,
v_dram_window,
kargs.v_element_func,
bias_dram_window,
kargs.bias_element_func,
lse_dram_window,
kargs.lse_element_func,
kargs.s_acc_element_func,
kargs.p_compute_element_func,
kargs.o_acc_element_func,
mask,
kargs.scale,
smem_ptr);

View File

@@ -17,6 +17,7 @@ template <typename QDataType_,
typename PDataType_,
typename OaccDataType_,
typename ODataType_,
typename ElementFunctions_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
typename FmhaMask_,
@@ -33,6 +34,7 @@ struct BlockFmhaPipelineProblem
using PDataType = remove_cvref_t<PDataType_>;
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using ElementFunctions = remove_cvref_t<ElementFunctions_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;

View File

@@ -559,39 +559,6 @@ struct BlockFmhaPipelineQRKSVS
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
float scale,
void* smem_ptr) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
lse_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
mask,
scale,
smem_ptr);
}
};
} // namespace ck_tile

View File

@@ -654,39 +654,6 @@ struct BlockFmhaPipelineQRKSVSAsync
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
float scale,
void* smem_ptr) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
lse_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
mask,
scale,
smem_ptr);
}
};
} // namespace ck_tile

View File

@@ -550,36 +550,6 @@ struct BlockFmhaPipelineQSKSVS
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
float scale,
void* smem_ptr) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
lse_dram_block_window_tmp,
identity{},
mask,
scale,
smem_ptr);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,39 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename QElementFunction_,
typename KElementFunction_,
typename VElementFunction_,
typename BiasElementFunction_,
typename LSEElementFunction_,
typename SAccElementFunction_,
typename PComputeElementFunction_,
typename OAccElementFunction_>
struct FmhaElementFunctions
{
using QElementFunction = remove_cvref_t<QElementFunction_>;
using KElementFunction = remove_cvref_t<KElementFunction_>;
using VElementFunction = remove_cvref_t<VElementFunction_>;
using BiasElementFunction = remove_cvref_t<BiasElementFunction_>;
using LSEElementFunction = remove_cvref_t<LSEElementFunction_>;
using SAccElementFunction = remove_cvref_t<SAccElementFunction_>;
using PComputeElementFunction = remove_cvref_t<PComputeElementFunction_>;
using OAccElementFunction = remove_cvref_t<OAccElementFunction_>;
QElementFunction q_element_func;
KElementFunction k_element_func;
VElementFunction v_element_func;
BiasElementFunction bias_element_func;
LSEElementFunction lse_element_func;
SAccElementFunction s_acc_element_func;
PComputeElementFunction p_compute_element_func;
OAccElementFunction o_acc_element_func;
};
} // namespace ck_tile