mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
merge fa_decode pipeline into fmha_fwd api
This commit is contained in:
@@ -996,7 +996,7 @@ struct FmhaFwdDecodeKernel
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
|
||||
sequence<kPadSeqLenK, false>{});
|
||||
|
||||
|
||||
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
constexpr index_t LDSLayerSize = 256 / sizeof(VDataType);
|
||||
constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -17,20 +17,21 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
@@ -63,10 +64,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
Problem::kPadHeadDimV; // support multiple of vector(like 8x)
|
||||
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
|
||||
static constexpr bool kHasUnevenSplits = true;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
@@ -120,7 +121,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "decode_qr";
|
||||
static constexpr const char* name = "qr_async_trload";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
@@ -140,8 +141,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
|
||||
index_t num_splits,
|
||||
index_t i_split,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
@@ -194,8 +193,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(I0), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(I0), number<kM0>{}, number<kN0>{});
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||
@@ -400,12 +399,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
physical_seqlen_k_start_ = physical_seqlen_k_start,
|
||||
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
||||
const auto col = k_origin.at(I0) + tile_idx.at(I1);
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return col < physical_seqlen_k_start_ ||
|
||||
physical_seqlen_k_end_ <= col;
|
||||
}
|
||||
else
|
||||
|
||||
{
|
||||
return physical_seqlen_k_end_ <= col;
|
||||
}
|
||||
@@ -649,8 +643,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
|
||||
index_t num_splits,
|
||||
index_t i_split,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
@@ -707,8 +699,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(I0), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(I0), number<kM0>{}, number<kN0>{});
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||
@@ -923,12 +915,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
physical_seqlen_k_start_ = physical_seqlen_k_start,
|
||||
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
||||
const auto col = k_origin.at(I0) + tile_idx.at(I1);
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return col < physical_seqlen_k_start_ ||
|
||||
physical_seqlen_k_end_ <= col;
|
||||
}
|
||||
else
|
||||
|
||||
{
|
||||
return physical_seqlen_k_end_ <= col;
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ template <typename QDataType_,
|
||||
bool kIsGroupMode_,
|
||||
typename AttentionVariant_,
|
||||
typename FmhaMask_,
|
||||
bool kUseTrLoad_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaPipelineProblem
|
||||
{
|
||||
@@ -46,6 +47,7 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
|
||||
Reference in New Issue
Block a user