diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 82cf3a5ab2..57360ea995 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -191,7 +191,9 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m template<> void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ - if (a.num_splits <= 16) {{ + if (a.num_splits <= 8) {{ + kernel_runner<3>::run(s, a); + }} else if (a.num_splits <= 16) {{ kernel_runner<4>::run(s, a); }} else if (a.num_splits <= 32) {{ kernel_runner<5>::run(s, a); @@ -239,7 +241,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; return fmha_fwd_splitkv_(s, a); }} @@ -551,14 +553,14 @@ class FmhaFwdSplitKVCombineKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 
32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), + '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 16, 16, 16, -1), + '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 16, 16, 16, -1), + '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 16, 16, 16, -1), + '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) } @@ -568,16 +570,16 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdSplitKVCombineTileSize(64, 32, -1), - '64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), - '128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), - '256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1), + '32' : FmhaFwdSplitKVCombineTileSize(16, 16, -1), + '64' : FmhaFwdSplitKVCombineTileSize(32, 32, -1), + '128' : FmhaFwdSplitKVCombineTileSize(32, 64, -1), + '256' : FmhaFwdSplitKVCombineTileSize(32, 128, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), - '128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), - '256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1), + '64' : FmhaFwdSplitKVCombineTileSize(64, 32, -1), + '128' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), + '256' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), } else: return None diff --git 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 1afe0feab3..7c49fce99a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -12,6 +12,16 @@ namespace detail { template struct log2; +template <> +struct log2<4> : std::integral_constant +{ +}; + +template <> +struct log2<8> : std::integral_constant +{ +}; + template <> struct log2<16> : std::integral_constant { @@ -72,18 +82,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline { if constexpr(kHeadDimV <= 32) { - constexpr std::array occupancy{3, 3, 3, 1}; - return occupancy[detail::log2::value - 4]; + constexpr std::array occupancy{3, 3, 3, 3, 3, 1}; + return occupancy[detail::log2::value - 2]; } else if constexpr(kHeadDimV <= 128) { - constexpr std::array occupancy{3, 3, 2, 1}; - return occupancy[detail::log2::value - 4]; + constexpr std::array occupancy{3, 3, 3, 3, 2, 1}; + return occupancy[detail::log2::value - 2]; } else if constexpr(kHeadDimV <= 256) { - constexpr std::array occupancy{2, 2, 2, 1}; - return occupancy[detail::log2::value - 4]; + constexpr std::array occupancy{2, 2, 2, 2, 2, 1}; + return occupancy[detail::log2::value - 2]; } } }(); @@ -138,9 +148,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline auto lse_accum = make_static_distributed_tensor( Policy::template MakeLSEaccRegTileDistribution()); - // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)]) - // this will extend the distributed tensor width so that each thread in wave have data to - // reduce. + // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits]) + // and fill up -INF values outside the [kM0, num_splits] region. 
{ constexpr auto spans = decltype(lse_accum)::get_distributed_spans(); sweep_tile_span(spans[number<0>{}], [&](auto idx0) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp index 3327d4af87..ebd69c0cf8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -10,11 +10,26 @@ namespace ck_tile { struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy { + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile() + { + constexpr index_t PixelsPerThread = (M * N) / BlockSize; + static_assert(0 < PixelsPerThread); + + constexpr index_t MaxNPerThread = 16 / sizeof(DataType); + constexpr index_t NPerThread = min(MaxNPerThread, PixelsPerThread); + + return NPerThread; + } + + // alignment for dram lse tile (shape=[kMaxSplits, kM0]) template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE() { - using LSEDataType = remove_cvref_t; - return 16 / sizeof(LSEDataType); + return GetVectorSizeForTile(); } template @@ -47,29 +62,31 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy MakeLSEaccLdsBlockDescriptor().get_element_space_size(); } + // shape=[kMaxSplits, kM0] template CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution() { using LSEDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNumWarps = Problem::kNumWarps; constexpr index_t kNPerBlock = Problem::kM0; constexpr index_t kMPerBlock = Problem::kMaxSplits; - constexpr index_t NPerThread = 16 / sizeof(LSEDataType); - constexpr index_t NThreads = kNPerBlock / NPerThread; + constexpr index_t NPerThread = + GetVectorSizeForTile(); + constexpr index_t NThreads = kNPerBlock / NPerThread; constexpr index_t MThreadsPerWarp = 
get_warp_size() / NThreads; - constexpr index_t TotalWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (TotalWarps * MThreadsPerWarp); + constexpr index_t MPerThread = kMPerBlock / (kNumWarps * MThreadsPerWarp); static_assert(NThreads * NPerThread == kNPerBlock); - static_assert(MPerThread * TotalWarps * MThreadsPerWarp == kMPerBlock); + static_assert(MPerThread * kNumWarps * MThreadsPerWarp == kMPerBlock); return make_static_tile_distribution( tile_distribution_encoding, - tuple, + tuple, sequence>, tuple, sequence<1, 2>>, tuple, sequence<2, 0>>, @@ -77,15 +94,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy sequence<0, 1>>{}); } - // 3d + padding, [kMaxSplits, kM0] + // 3d + padding, shape=[kMaxSplits, kM0] template CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsStoreBlockDescriptor() { using LSEDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kMaxSplits; constexpr index_t kNPerBlock = Problem::kM0; - constexpr index_t NPack = 16 / sizeof(LSEDataType); + constexpr index_t NPack = + GetVectorSizeForTile(); constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), @@ -103,15 +123,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy return lse_acc_lds_block_desc; } - // 3d + padding, [kM0, kMaxSplits] + // 3d + padding, shape=[kM0, kMaxSplits] template CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor() { using LSEDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kMaxSplits; constexpr index_t kNPerBlock = Problem::kM0; - constexpr index_t NPack = 16 / sizeof(LSEDataType); + constexpr index_t NPack = + GetVectorSizeForTile(); constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), @@ -134,26 +157,28 @@ struct 
BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size()); + constexpr index_t kNPerBlock = Problem::kMaxSplits; constexpr index_t kMPerBlock = Problem::kM0; - constexpr index_t NThreads = get_warp_size(); + constexpr index_t NThreads = 4; constexpr index_t NPerThread = kNPerBlock / NThreads; - constexpr index_t MThreads = kBlockSize / NThreads; - constexpr index_t MPerThread = kMPerBlock / MThreads; + constexpr index_t MThreads = kBlockSize / NThreads; + constexpr index_t MPerThread = kMPerBlock / MThreads; + constexpr index_t MWarps = kBlockSize / get_warp_size(); + constexpr index_t MThreadPerWarp = get_warp_size() / NThreads; static_assert(NThreads * NPerThread == kNPerBlock); - static_assert(MThreads * MPerThread == kMPerBlock); + static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock); return make_static_tile_distribution( tile_distribution_encoding< sequence<1>, - tuple, sequence>, - tuple, sequence<2>>, - tuple, sequence<0>>, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 1>>, sequence<1, 2>, - sequence<1, 1>>{}); + sequence<2, 1>>{}); } template diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index d254f07e2d..1846664e7d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -115,7 +115,8 @@ struct BlockFmhaSplitKVCombinePipelineProblem using ODataType = remove_cvref_t; using Traits = remove_cvref_t; - static constexpr index_t kBlockSize = 256; + static constexpr index_t kNumWarps = kM0_ / (get_warp_size() / 4); + static constexpr index_t kBlockSize = kNumWarps * get_warp_size(); static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr index_t kHeadDimV = HeadDimV_; diff --git 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 8fa325241c..a66d2be78b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -88,22 +88,33 @@ struct BlockFmhaPipelineQXCustomPolicy typename Problem::BlockFmhaShape::Gemm0WarpTile>>; constexpr auto warp_gemm = []() { + constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + static_assert(WarpGemmM == 16 || WarpGemmM == 32); + if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; + if constexpr(WarpGemmM == 32) + return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; + else // WarpGemmM == 16 + return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; + if constexpr(WarpGemmM == 32) + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; + else // WarpGemmM == 16 + return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { + static_assert(WarpGemmM == 32); + // TODO: hard coded here. 
Otherwise, it may produce incorrect results constexpr index_t swizzle_factor = 4; return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution< diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 436d964c37..e708255703 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -23,12 +23,12 @@ #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"