From f2dd57b76fa8f94a1e542859a0a1bbe21d7fbd13 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 13 Mar 2025 11:41:39 +0800 Subject: [PATCH] =?UTF-8?q?Reapply=20"[CK=5FTILE]=20support=20hdim=3D192/1?= =?UTF-8?q?28=20pair=20for=20deepseekv3=20(#1961)"=20=E2=80=A6=20(#1971)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Reapply "[CK_TILE] support hdim=192/128 pair for deepseekv3 (#1961)" (#1969) This reverts commit b92caa3d84af3cc4fb5a2e340f1d8bcac44c4f0e. * fix codegen problem * Update config.hpp --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> [ROCm/composable_kernel commit: 3e81279d26ed59d989de8a71703b23477c4c749d] --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 9 +++++++-- example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py | 2 +- example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py | 2 +- include/ck_tile/core.hpp | 5 +---- include/ck_tile/core/arch/amd_buffer_addressing.hpp | 4 ++++ .../ck_tile/core/arch/amd_buffer_addressing_builtins.hpp | 4 ++++ include/ck_tile/core/config.hpp | 8 ++++++++ include/ck_tile/ops/fmha.hpp | 4 ++-- .../fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp | 7 +++++++ include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp | 2 ++ 10 files changed, 37 insertions(+), 10 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index f2d9216696..4ff7ede765 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -118,7 +118,7 @@ FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} """ -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ {F_inner_dispatch} }} """ @@ -288,7 +288,7 @@ class FmhaFwdApiPool: F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) if not per_dtypes: @@ -417,6 +417,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -489,6 +490,10 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue + if hdim == 192 and tile.F_bn1 == 128: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't' or (pipeline.F_mask not in ['no', 's_no']): + continue k = FmhaFwdKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 16048e3fb6..f243020dc4 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool: F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 75305a1336..b1f9e30178 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -476,7 +476,7 @@ class FmhaFwdSplitKVApiPool: F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) if not per_dtypes: diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 94710e584f..821b3a8e84 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -8,11 +8,8 @@ #include "ck_tile/core/algorithm/indexing_adaptor.hpp" #include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/algorithm/static_encoding_pattern.hpp" -#if __clang_major__ == 20 -#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp" -#else #include "ck_tile/core/arch/amd_buffer_addressing.hpp" -#endif +#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/utility.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 91c2508ba2..33faa3a18b 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -3,6 +3,8 @@ #pragma once +#if !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/vector_type.hpp" @@ -2553,3 +2555,5 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, } } // namespace ck_tile + +#endif // !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 2bbc75509b..0b9956cd01 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -3,6 +3,8 @@ #pragma once +#if CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/vector_type.hpp" @@ -2553,3 +2555,5 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, } } // namespace ck_tile + +#endif // CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index aaaf4d4259..eeaf0dca6f 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -252,3 +252,11 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING) #else // for GPU code #define CK_TILE_USE_OCP_FP8 0 #endif + +#ifndef CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN +#if __clang_major__ == 20 +#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 1 +#else +#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 0 +#endif +#endif diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 2618082e5b..a28b63f813 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -33,12 +33,12 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index d64e5562d0..67354fc72d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -112,6 +112,13 @@ struct BlockFmhaPipelineQRKSVSAsync else return 2; } + else if constexpr(kQKHeaddim <= 192) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } else if constexpr(kQKHeaddim <= 256) { return 1; diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index 5ce80c2d1f..76ba34115f 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -13,6 +13,8 @@ static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index return 128; if(len == 160) return 256; + if(len == 192) + return 192; // only length of 96, 160 and power-of-two is supported if(!(len & (len - 1)))