diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index fb8a4389f3..37745dd382 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -48,8 +48,8 @@ using fmha_dtype_{F_idx} = {F_dtype}; using fmha_mask_{F_idx} = {F_mask}; namespace {{ -template -struct kernel_runner {{ +template +struct instance {{ using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; using fmha_shape = ck_tile::TileFmhaShape; using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< @@ -115,28 +116,50 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wtautological-compare" + +namespace {{ +template +void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ + if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS + && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask> + || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{ + if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{ + instance::run(s, a); + }} else {{ + instance::run(s, a); + }} + }} else {{ + instance::run(s, a); + }} +}} +}} // anonymous namespace + +#pragma clang diagnostic pop + template<> void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ if constexpr({F_mode} == false) {{ // batch mode // we don't check every seqlen_k values for kvcache if (a.seqlen_k_ptr != nullptr) {{ - kernel_runner::run(s, a); + run_instance(s, a); // make sure F_bn0 is divisible by F_bk1 }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ - kernel_runner::run(s, a); + run_instance(s, a); }} else {{ - kernel_runner::run(s, a); + run_instance(s, a); }} }} else {{ - kernel_runner::run(s, a); + run_instance(s, a); }} }} template<> std::string fmha_fwd_splitkv_get_name_() {{ - using k_ = kernel_runner::fmha_kernel; /// FIXME: choose real kernel type + using k_ = instance::fmha_kernel; /// FIXME: choose real kernel type return k_::GetName(); }} """ @@ -146,7 +169,7 @@ using fmha_dtype_{F_idx} = {F_dtype}; namespace {{ template -struct kernel_runner {{ +struct instance {{ using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad}, {F_dvpad}, {F_lse}, @@ -196,22 +219,22 @@ template<> void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ if (a.num_splits <= 8) {{ - kernel_runner<3>::run(s, a); + instance<3>::run(s, a); }} else if (a.num_splits <= 16) {{ - kernel_runner<4>::run(s, a); + instance<4>::run(s, a); }} else if (a.num_splits <= 32) {{ - kernel_runner<5>::run(s, a); + instance<5>::run(s, a); }} else if (a.num_splits <= 64) {{ - kernel_runner<6>::run(s, a); + instance<6>::run(s, a); }} else if (a.num_splits <= 128) {{ - kernel_runner<7>::run(s, a); + instance<7>::run(s, a); }} }} template<> std::string fmha_fwd_splitkv_combine_get_name_() {{ - using k_ = kernel_runner<6>::fmha_kernel; /// FIXME: choose real kernel type + using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type return k_::GetName(); }} """ diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 0368de352f..765c221a7b 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -510,8 +510,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) } }(); - dim3 grids = - Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); + dim3 grids = Kernel::GridSize( + args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits); return ck_tile::make_tuple(kargs, grids); } diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 41f3383c7f..02ce449912 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 2f3a302eea..440b306705 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index d06d8529ac..8b5302257c 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 1510f18a30..9b9bf30ad3 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index cd1e43fb8c..15fa269740 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index c24744bdbc..95ead2645e 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index ba76e3070d..616db2fa5b 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index d5920f4837..4cbb59e95b 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 10ab25119b..92dc2bac3f 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -47,10 +47,16 @@ struct FmhaFwdSplitKVKernel static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; + static constexpr bool kMergeNumHeadGroupsSeqLenQ = + FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ; using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; + static_assert(!kMergeNumHeadGroupsSeqLenQ || + (kMergeNumHeadGroupsSeqLenQ && BiasEnum == BlockAttentionBiasEnum::NO_BIAS && + !kHasMask)); + // clang-format off template struct t2s; template <> struct t2s { static constexpr const char * name = "fp32"; }; @@ -476,15 +482,20 @@ struct FmhaFwdSplitKVKernel } CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, - ck_tile::index_t nhead, + ck_tile::index_t nhead_q, + ck_tile::index_t nhead_kv, ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits) { + ck_tile::index_t nhead_ = kMergeNumHeadGroupsSeqLenQ ? nhead_kv : nhead_q; + ck_tile::index_t max_seqlen_q_ = + max_seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? nhead_q / nhead_kv : 1); + // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + return dim3(ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) * ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits, - nhead, + nhead_, batch_size); } @@ -562,7 +573,7 @@ struct FmhaFwdSplitKVKernel // # of required blocks is different in each groups, terminate unnecessary blocks // earlier - if(kargs.seqlen_q <= i_m0) + if(kargs.seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) <= i_m0) { return; } @@ -617,30 +628,60 @@ struct FmhaFwdSplitKVKernel } // for simplicity, batch stride we just modify the pointer + const index_t i_nhead_k = + (kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk); + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_q + + static_cast(i_nhead) * + (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) * + kargs.nhead_stride_q + batch_offset_q; - const KDataType* k_ptr = - reinterpret_cast(kargs.k_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + - batch_offset_k; - const VDataType* v_ptr = - reinterpret_cast(kargs.v_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + - batch_offset_v; + const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead_k) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead_k) * kargs.nhead_stride_v + + batch_offset_v; ODataType* o_acc_ptr = reinterpret_cast(kargs.o_acc_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_o_acc + + static_cast(i_nhead) * + (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) * + kargs.nhead_stride_o_acc + batch_offset_o_acc + i_split * kargs.split_stride_o_acc; // Q/K/V DRAM and DRAM window - const auto q_dram = [&]() { - const auto q_dram_naive = make_naive_tensor_view( - q_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_q, 1), - number{}, - number<1>{}); + const auto q_dram = [&] { + const auto q_dram_naive = [&] { + if constexpr(kMergeNumHeadGroupsSeqLenQ) + { + // reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q, + // hdim_q) + const auto view = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.nhead_stride_q, kargs.stride_q, 1), + number{}, + number<1>{}); + + return transform_tensor_view( + view, + make_tuple( + make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)), + make_pass_through_transform(kargs.hdim_q)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + return make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + } + }(); + if constexpr(FmhaPipeline::kQLoadOnce) { return pad_tensor_view( @@ -729,7 +770,7 @@ struct FmhaFwdSplitKVKernel } }(); - auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + auto k_page_block_navigator = [&, i_batch_ = i_batch]() { if constexpr(kIsPagedKV) { const auto* block_indices = @@ -739,8 +780,7 @@ struct FmhaFwdSplitKVKernel integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); const long_index_t fixed_offset = - static_cast(i_nhead_ / kargs.nhead_ratio_qk) * - kargs.nhead_stride_k; + static_cast(i_nhead_k) * kargs.nhead_stride_k; return make_page_block_navigator( kargs.k_ptr, @@ -760,7 +800,7 @@ struct FmhaFwdSplitKVKernel } }(); - auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + auto v_page_block_navigator = [&, i_batch_ = i_batch]() { if constexpr(kIsPagedKV) { const auto* block_indices = @@ -770,8 +810,7 @@ struct FmhaFwdSplitKVKernel integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); const long_index_t fixed_offset = - static_cast(i_nhead_ / kargs.nhead_ratio_qk) * - kargs.nhead_stride_v; + static_cast(i_nhead_k) * kargs.nhead_stride_v; return make_page_block_navigator( kargs.v_ptr, @@ -842,19 +881,40 @@ struct FmhaFwdSplitKVKernel // lse acc auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() { constexpr auto lse_acc_dram_window_lengths = make_tuple(number{}); - LSEDataType* lse_acc_ptr = - reinterpret_cast(kargs.lse_acc_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_lse_acc + - batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc; + LSEDataType* lse_acc_ptr = reinterpret_cast(kargs.lse_acc_ptr) + + static_cast(i_nhead_) * + (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) * + kargs.nhead_stride_lse_acc + + batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc; - const auto lse_acc_dram = [&]() { - const auto lse_acc_dram_naive = - make_naive_tensor_view(lse_acc_ptr, - make_tuple(kargs.seqlen_q), - make_tuple(1), - number<1>{}, - number<1>{}); + const auto lse_acc_dram = [&] { + const auto lse_acc_dram_naive = [&] { + if constexpr(kMergeNumHeadGroupsSeqLenQ) + { + // reshape: (nhead_ratio_qk, seqlen_q) -> (nhead_ratio_qk * seqlen_q) + const auto view = make_naive_tensor_view( + lse_acc_ptr, + make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q), + make_tuple(kargs.nhead_stride_lse_acc, 1), + number<1>{}, + number<1>{}); + return transform_tensor_view(view, + make_tuple(make_merge_transform(make_tuple( + kargs.nhead_ratio_qk, kargs.seqlen_q))), + make_tuple(sequence<0, 1>{}), + make_tuple(sequence<0>{})); + } + else + { + return make_naive_tensor_view( + lse_acc_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + } + }(); return pad_tensor_view( lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence{}); }(); @@ -953,13 +1013,37 @@ struct FmhaFwdSplitKVKernel }(); // Oacc DRAM and Oacc DRAM window - auto o_acc_dram = [&]() { - const auto o_acc_dram_naive = make_naive_tensor_view( - o_acc_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.stride_o_acc, 1), - number{}, - number<1>{}); + auto o_acc_dram = [&] { + const auto o_acc_dram_naive = [&] { + if constexpr(kMergeNumHeadGroupsSeqLenQ) + { + // reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * seqlen_q, + // hdim_v) + const auto view = make_naive_tensor_view( + o_acc_ptr, + make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.nhead_stride_o_acc, kargs.stride_o_acc, 1), + number{}, + number<1>{}); + + return transform_tensor_view( + view, + make_tuple( + make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)), + make_pass_through_transform(kargs.hdim_v)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + return make_naive_tensor_view( + o_acc_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o_acc, 1), + number{}, + number<1>{}); + } + }(); return pad_tensor_view( o_acc_dram_naive, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 1fe19faaf9..9a5208c025 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -94,16 +94,17 @@ struct BlockFmhaFwdSplitKVPipelineProblem static constexpr bool kIsGroupMode = kIsGroupMode_; // attributes from traits - static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr auto BiasEnum = Traits::BiasEnum; - static constexpr bool kStoreLSE = Traits::kStoreLSE; - static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; - static constexpr bool kIsPagedKV = Traits::kIsPagedKV; - static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; - static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr auto BiasEnum = Traits::BiasEnum; + static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; + static constexpr bool kIsPagedKV = Traits::kIsPagedKV; + static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; + static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; // extract tile size attributes to remove dependency on traits diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index d7bf8ea7e7..8d2d848558 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -43,7 +43,8 @@ template + bool kMergeNumHeadGroupsSeqLenQ_ = false, + index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */> struct TileFmhaFwdSplitKVTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -56,8 +57,9 @@ struct TileFmhaFwdSplitKVTraits static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kIsPagedKV = kIsPagedKV_; // determine if some split (length) is not divisible by tile size - static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; - static constexpr index_t kBlockPerCu = kBlockPerCu_; + static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; + static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; }; template