mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
tempsave, fmha_decode
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
|
||||
|
||||
1140
include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp
Normal file
1140
include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1071,20 +1071,20 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window_lengths,
|
||||
k_page_block_navigator,
|
||||
// k_page_block_navigator,
|
||||
v_dram_window_lengths,
|
||||
v_page_block_navigator,
|
||||
// v_page_block_navigator,
|
||||
bias_dram_window,
|
||||
lse_acc_dram_window,
|
||||
kargs.num_splits,
|
||||
i_split_,
|
||||
// kargs.num_splits,
|
||||
// i_split_,
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
// kv_l2p_offset,
|
||||
smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -47,11 +47,15 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
// static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
|
||||
// Problem::kPadHeadDimV == true);
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; // support multiple of vector(like 8x)
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; // support multiple of vector(like 8x)
|
||||
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
@@ -65,19 +69,16 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
return Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentOacc =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
|
||||
static constexpr index_t kAlignmentOacc = Policy::template GetAlignmentO<Problem>();
|
||||
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
@@ -349,7 +350,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(2 <= k0_loops);
|
||||
static_assert(1 <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
@@ -370,9 +371,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
// STAGE 1, QK gemm
|
||||
clear_tile(s_acc); // initialize C
|
||||
|
||||
// load the second tile of the first iteration
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
@@ -385,9 +383,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
if constexpr(k0_loops > 1)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 1
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
@@ -400,22 +399,12 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
store_tile(
|
||||
k_lds_window,
|
||||
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 2) * kK0>{},
|
||||
sequence<kM0, (k0_loops - 1) * kK0>{}),
|
||||
k_lds_window);
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is qkv all located in LDS
|
||||
struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
|
||||
struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
||||
/* AsyncCopy = */ false,
|
||||
/* NumPrefetchK = */ 1,
|
||||
|
||||
@@ -168,6 +168,113 @@ struct BlockFmhaSplitKVCombinePipelineProblem
|
||||
(kM0 * kMaxSplits) % get_warp_size() == 0);
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename BiasDataType_,
|
||||
typename LSEDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
typename ODataType_,
|
||||
typename BlockFmhaShape_,
|
||||
bool kIsGroupMode_,
|
||||
typename AttentionVariant_,
|
||||
typename FmhaMask_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaFwdDecodePipelineProblem
|
||||
{
|
||||
using QDataType = remove_cvref_t<QDataType_>;
|
||||
using KDataType = remove_cvref_t<KDataType_>;
|
||||
using VDataType = remove_cvref_t<VDataType_>;
|
||||
using SaccDataType = remove_cvref_t<SaccDataType_>;
|
||||
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
|
||||
using BiasDataType = remove_cvref_t<BiasDataType_>;
|
||||
using LSEDataType = remove_cvref_t<LSEDataType_>;
|
||||
using PDataType = remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = remove_cvref_t<OaccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using AttentionVariant = remove_cvref_t<AttentionVariant_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
|
||||
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
|
||||
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
|
||||
static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
// extract tile size attributes to remove dependency on traits
|
||||
template <typename OaccDataType_, ck_tile::index_t kN1_>
|
||||
struct BlockFmhaDecodeCombinePipelineTileSizes
|
||||
{
|
||||
static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_);
|
||||
|
||||
static constexpr index_t kN1 = kN1_;
|
||||
static constexpr index_t NThreads = kN1 / MaxVectorSize;
|
||||
static constexpr index_t kM0 = get_warp_size() / NThreads; // MThreadPerWarp
|
||||
};
|
||||
|
||||
template <typename LSEDataType_,
|
||||
typename OaccDataType_,
|
||||
typename ODataType_,
|
||||
index_t HeadDimV_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kN1_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaDecodeCombinePipelineProblem
|
||||
: BlockFmhaDecodeCombinePipelineTileSizes<OaccDataType_, kN1_>
|
||||
{
|
||||
using BaseType = BlockFmhaDecodeCombinePipelineTileSizes<OaccDataType_, kN1_>;
|
||||
|
||||
using LSEDataType = remove_cvref_t<LSEDataType_>;
|
||||
using OaccDataType = remove_cvref_t<OaccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static_assert(std::is_same_v<LSEDataType, OaccDataType>);
|
||||
|
||||
static constexpr index_t kHeadDimV = HeadDimV_;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
using BaseType::kM0;
|
||||
using BaseType::kN1;
|
||||
|
||||
static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static constexpr index_t kMaxSplits = Traits::kMaxSplits;
|
||||
static_assert(8 <= kMaxSplits);
|
||||
|
||||
static constexpr index_t kNumWarps = 4; // always use 4 warps for each workgroup
|
||||
static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
|
||||
|
||||
static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
|
||||
(kM0 * kMaxSplits) % get_warp_size() == 0);
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
|
||||
@@ -86,6 +86,55 @@ struct TileFmhaFwdSplitKVCombineTraits
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
bool kHasLogitsSoftCap_,
|
||||
BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kIsPagedKV_,
|
||||
bool kHasUnevenSplits_,
|
||||
bool kMergeNumHeadGroupsSeqLenQ_ = false,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileFmhaFwdDecodeTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
|
||||
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kStoreLSE = kStoreLSE_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||
// determine if some split (length) is not divisible by tile size
|
||||
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
|
||||
static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
bool kStoreLSE_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
index_t kLogMaxSplits_,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileFmhaFwdDecodeCombineTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr bool kStoreLSE = kStoreLSE_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
|
||||
static constexpr index_t kMaxSplits = (1 << kLogMaxSplits_);
|
||||
static_assert(kMaxSplits <= get_warp_size() || kMaxSplits % get_warp_size() == 0);
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
|
||||
Reference in New Issue
Block a user