mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Add attn sink (#2892)
* enable attn sink Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update attn_sink script Signed-off-by: JL-underdog <Jun.Lin@amd.com> * fix some error Signed-off-by: JL-underdog <Jun.Lin@amd.com> * clang-format Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update fmha_bwd mask Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update fmha_bwd_kernel'mask Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update block_fmha_pipeline_qr_ks_vs.hpp Signed-off-by: JL-underdog <Jun.Lin@amd.com> * fix ci error Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * fix format error Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update block_fmha_bwd_pipeline_default_policy.hpp * Update fmha_fwd_runner.hpp * Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp * Update fmha_fwd_runner.hpp * Update fmha_fwd_runner.hpp * Update fmha_fwd_runner.hpp * update splitkv_pipline Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update splitkv&pagedkv pipeline Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * add sink test Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update attn_sink result log Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update smoke_test_fwd_sink.sh Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update test file Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update test script Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp * use constexpr kHasSink for sink in fmha pipeline Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * update by pre-commit Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fmha_fwd.py * Update example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Remove causal mask setting logic from mask.hpp Removed the mask setting logic for causal masks. * fix ci error that some usage of lamada not support in c++17 Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update remod.py * add smoke sink test Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update fmha_pagedkv_prefill.py * Update FmhaFwdPipeline parameters in fmha_fwd.py * update block_fmha_pipeline_qr_ks_vs_async_trload.hpp Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * fix c++17 unsupprot error Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp * Fix formatting of sink_seq_end assignment * Fix indentation for sink_seq_end assignment * Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp --------- Signed-off-by: JL-underdog <Jun.Lin@amd.com> Signed-off-by: LJ-underdog <Jun.Lin@amd.com> Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -265,6 +265,7 @@ struct fmha_fwd_args
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t sink_size;
|
||||
ck_tile::index_t mask_type;
|
||||
ck_tile::index_t min_seqlen_q;
|
||||
|
||||
@@ -351,6 +352,7 @@ struct fmha_fwd_pagedkv_args
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t sink_size;
|
||||
ck_tile::index_t mask_type;
|
||||
ck_tile::index_t min_seqlen_q;
|
||||
};
|
||||
@@ -441,6 +443,7 @@ struct fmha_fwd_splitkv_args
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t sink_size;
|
||||
ck_tile::index_t mask_type;
|
||||
};
|
||||
|
||||
@@ -611,6 +614,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q,
|
||||
args.p_drop,
|
||||
@@ -660,6 +664,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
@@ -727,6 +732,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
args.batch_stride_v,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q);
|
||||
}
|
||||
@@ -772,6 +778,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
}
|
||||
}();
|
||||
@@ -838,6 +845,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.split_stride_o_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
}
|
||||
else
|
||||
@@ -885,6 +893,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.split_stride_o_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
}
|
||||
}();
|
||||
@@ -1131,7 +1140,8 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kUseTrLoad_,
|
||||
bool kSkipMinSeqlenQ_ = false>
|
||||
bool kSkipMinSeqlenQ_ = false,
|
||||
bool kHasSink_ = false>
|
||||
struct fmha_fwd_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
@@ -1157,6 +1167,7 @@ struct fmha_fwd_traits_
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
@@ -1183,7 +1194,8 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kSkipMinSeqlenQ_ = false>
|
||||
bool kSkipMinSeqlenQ_ = false,
|
||||
bool kHasSink_ = false>
|
||||
struct fmha_fwd_pagedkv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
@@ -1208,6 +1220,7 @@ struct fmha_fwd_pagedkv_traits_
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
@@ -1230,6 +1243,7 @@ template <ck_tile::index_t HDim_,
|
||||
bool kStoreLse_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kIsPagedKV_,
|
||||
bool kHasSink_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
@@ -1257,6 +1271,7 @@ struct fmha_fwd_splitkv_traits_
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
@@ -1343,6 +1358,7 @@ struct fmha_fwd_traits
|
||||
bool has_dropout;
|
||||
bool do_fp8_static_quant;
|
||||
bool skip_min_seqlen_q = false;
|
||||
bool has_sink = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
||||
@@ -1361,6 +1377,7 @@ struct fmha_fwd_pagedkv_traits
|
||||
bool use_pagedkv = true;
|
||||
bool do_fp8_static_quant = false;
|
||||
bool skip_min_seqlen_q = false;
|
||||
bool has_sink = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
@@ -1380,6 +1397,7 @@ struct fmha_fwd_splitkv_traits
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_lse;
|
||||
bool do_fp8_static_quant;
|
||||
bool has_sink = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
|
||||
|
||||
Reference in New Issue
Block a user