merge fa_decode pipeline into fmha_fwd api

This commit is contained in:
aska-0096
2025-08-08 05:46:18 +00:00
parent fe63a646a4
commit b4640a9de6
13 changed files with 1276 additions and 531 deletions

View File

@@ -996,7 +996,7 @@ struct FmhaFwdDecodeKernel
v_dram_naive,
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
sequence<kPadSeqLenK, false>{});
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
constexpr index_t LDSLayerSize = 256 / sizeof(VDataType);
constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);

File diff suppressed because it is too large Load Diff

View File

@@ -17,20 +17,21 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
@@ -63,10 +64,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
Problem::kPadHeadDimV; // support multiple of vector(like 8x)
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
static constexpr bool kHasUnevenSplits = true;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -120,7 +121,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
}
}();
static constexpr const char* name = "decode_qr";
static constexpr const char* name = "qr_async_trload";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
@@ -140,8 +141,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
@@ -194,8 +193,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
clear_tile(l);
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(I0), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(I0), number<kM0>{}, number<kN0>{});
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
@@ -400,12 +399,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(I0) + tile_idx.at(I1);
if constexpr(kIsPagedKV)
{
return col < physical_seqlen_k_start_ ||
physical_seqlen_k_end_ <= col;
}
else
{
return physical_seqlen_k_end_ <= col;
}
@@ -649,8 +643,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
@@ -707,8 +699,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
clear_tile(l);
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(I0), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(I0), number<kM0>{}, number<kN0>{});
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
@@ -923,12 +915,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(I0) + tile_idx.at(I1);
if constexpr(kIsPagedKV)
{
return col < physical_seqlen_k_start_ ||
physical_seqlen_k_end_ <= col;
}
else
{
return physical_seqlen_k_end_ <= col;
}

View File

@@ -22,6 +22,7 @@ template <typename QDataType_,
bool kIsGroupMode_,
typename AttentionVariant_,
typename FmhaMask_,
bool kUseTrLoad_,
typename Traits_>
struct BlockFmhaPipelineProblem
{
@@ -46,6 +47,7 @@ struct BlockFmhaPipelineProblem
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;