diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 849f463afa..ed025dcf5f 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -128,7 +128,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F namespace {{ template void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ - if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS + if constexpr ({F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask> || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{ if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{ @@ -283,7 +283,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const """ FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; // get combine kernel tile sizes @@ -364,6 +364,14 @@ class FmhaFwdSplitKVApiTrait: else: assert False + def seqtune(self, max_bm0: int) -> str: + if self.bm0 == max_bm0: + return "true/*fall back to largest tile*/" + else: + if self.mode == "group": + return f"a.max_seqlen_q <= {self.bm0}" + return f"a.seqlen_q <= {self.bm0}" + @property def skcheck(self) -> str: if self.mode == "group": @@ -561,6 +569,7 @@ class FmhaFwdSplitKVApiPool: for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()): per_hdim_case = str() for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()): + max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0) inners = str() for i_trait, trait in enumerate(pool_by_hdim): inners += FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format( @@ -579,6 +588,7 @@ class FmhaFwdSplitKVApiPool: F_pagedkv=BOOL_MAP[trait.pagedkv], F_sink=BOOL_MAP[trait.sink], F_scheck=trait.scheck, + F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, @@ -763,6 +773,7 @@ class KernelComponentFactoryBase: pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip + pipelines.append(Pipeline("qr_nwarp_sshuffle", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip elif dtype in ["fp8", "bf8"]: for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() @@ -846,11 +857,15 @@ class KernelComponentFactoryGfx11(KernelComponentFactoryBase): def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: if dtype in ["fp16", "bf16"]: return { - # bm0, bn0, bk0, bn1, bk1, - "32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # bm0, bn0, bk0, bn1, bk1, + "32" : [FmhaFwdTileSize( 16, 64, 16, 32, 32, 32, 1, 2, 1, 1, 2, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + "64" : [FmhaFwdTileSize( 16, 64, 32, 64, 32, 64, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + "128": [FmhaFwdTileSize( 16, 64, 32, 128, 32, 128, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + "256": [FmhaFwdTileSize( 16, 64, 32, 256, 32, 256, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip else: return None @@ -863,11 +878,15 @@ class KernelComponentFactoryGfx12(KernelComponentFactoryBase): def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: if dtype in ["fp16", "bf16"]: return { - # bm0, bn0, bk0, bn1, bk1, - "32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # bm0, bn0, bk0, bn1, bk1, + "32" : [FmhaFwdTileSize( 16, 64, 16, 32, 32, 32, 1, 2, 1, 1, 2, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + "64" : [FmhaFwdTileSize( 16, 64, 32, 64, 32, 64, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + "128": [FmhaFwdTileSize( 16, 128, 32, 128, 32, 128, 1, 8, 1, 1, 8, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + "256": [FmhaFwdTileSize( 16, 128, 32, 256, 32, 256, 1, 8, 1, 1, 8, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip elif dtype in ["fp8", "bf8"]: return { @@ -930,11 +949,17 @@ def get_fwd_splitkv_blobs( d = factory.get_hdim_tile_size_dict(dtype) if d is None: continue - # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] + tiles = d[hdim_str] + if not isinstance(tiles, list): + tiles = [tiles] hdim = int(hdim_str) - for pipeline in factory.get_pipelines(dtype, hdim, mask_impl): + for tile, pipeline in itertools.product( + tiles, factory.get_pipelines(dtype, hdim, mask_impl) + ): + # Use qr_nwarp_sshuffle with multiple N warps and qr otherwise + if (tile.F_rn0 != 1) != (pipeline.tag == "qr_nwarp_sshuffle"): + continue if mode == "group": if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 0b51dffa46..243ff87faa 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -165,8 +165,10 @@ int override_num_splits_if_necessary( if(num_splits < 1 && p_drop == 0.0f) { + // props.multiProcessorCount for >=gfx10 is the number of WGPs (each has 2 CUs) + const int num_blocks_per_SM = props.warpSize == 32 ? 4 : 2; return num_splits_heuristic( - batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128); + batch * nhead * num_m_blocks, props.multiProcessorCount * num_blocks_per_SM, 128); } return num_splits; @@ -648,8 +650,18 @@ fwd_result fmha_fwd_run(mode_enum mode, // legalize num_splits according to other options if(num_splits < 1) { + int nhead_merged = nhead; + int max_seqlen_q_merged = max_seqlen_q; + // When max_seqlen_q == 1 and multiple head groups are merged (kMergeNumHeadGroupsSeqLenQ) + // then more splits are required + if(bias.type == bias_enum::no_bias && mask.type == mask_enum::no_mask && + max_seqlen_q == 1 && nhead_k < nhead) + { + nhead_merged = nhead_k; + max_seqlen_q_merged = max_seqlen_q * (nhead / nhead_k); + } num_splits = override_num_splits_if_necessary( - batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits); + batch, nhead_merged, max_seqlen_q_merged, hdim_v, p_drop, num_splits); } if(128 < num_splits) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index adc8ea5a90..bdc598f754 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -257,7 +258,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS clear_tile(o_acc); if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0) { - set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E}); + set_tile(m, SMPLComputeDataType{sink_v * static_cast(C_LOG2E)}); set_tile(l, SMPLComputeDataType{1.0f}); } else @@ -698,8 +699,15 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); +#if defined(__gfx11__) + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA( + p, cast_tile(tile_elementwise_in(p_compute_element_func, p_compute))); +#else const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); +#endif // l{j}, Oacc{j} constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp index c5af751cd5..316720ac22 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp @@ -5,8 +5,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" namespace ck_tile { @@ -163,6 +161,25 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; constexpr index_t kTileK = Problem::BlockFmhaShape::kN0; +#if defined(__gfx11__) + // Keep C distribution and replicate data for NWarp to prevent doubling registers + // PermuteWarpGemmCToA will convert C distribution to A for matrix P later + constexpr index_t K1 = kKPerBlock / WG::kM; + constexpr index_t K0 = kTileK / kKPerBlock; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / WG::kN; + + constexpr auto s2_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2, 2>, + sequence<0, 0, 1>>{}; + + constexpr auto s2_block_dstr_encoding = detail::make_embed_tile_distribution_encoding( + s2_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); +#else // K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane; @@ -179,7 +196,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy tuple, sequence<2, 2>>, sequence<1, 2, 2, 2>, sequence<0, 0, 1, 3>>{}; - +#endif constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding); return s2_block_dstr; diff --git a/test/ck_tile/fmha/test_fmha_fwd.cpp b/test/ck_tile/fmha/test_fmha_fwd.cpp index 6ae33da30f..bdfc2d17c4 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.cpp +++ b/test/ck_tile/fmha/test_fmha_fwd.cpp @@ -735,6 +735,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, Values(3, 4), Values(std::tuple{4, 3, 1, 200, 1024, "0"}, std::tuple{2, 2, -1, 512, 2000, "0"}, + std::tuple{2, 8, 2, 1, 1024, "0"}, std::tuple{3, 2, -1, 230, 899, "t:128,128"}))); TEST_P(SplitKV, DataTypeConfig)