From 26e950b9d72fef19fc87680c7975ba801cb2a2fd Mon Sep 17 00:00:00 2001 From: zanzhang Date: Tue, 29 Apr 2025 11:01:04 +0800 Subject: [PATCH] modify kv layout to vllm --- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 39 ++++++++++++++----- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 1202524950..dec4bf7c93 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1079,14 +1079,22 @@ struct FmhaFwdKernel const auto k_dram = [&]() { const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_q), - make_tuple(kargs.stride_k, 1), + make_tuple(kargs.seqlen_k / 16, kargs.hdim_q / 8, 16, 8), + make_tuple(kargs.hdim_q * 16, 16 * 8, 8, 1), number{}, number<1>{}); +         const auto k_dram_transposed = transform_tensor_view( +             k_dram_naive, +             make_tuple(make_merge_transform(ck_tile::make_tuple(kargs.seqlen_k /16, 16)), +                        make_merge_transform(ck_tile::make_tuple(kargs.hdim_q / 8, 8))), +             ck_tile::make_tuple(ck_tile::sequence<0, 2>{}, ck_tile::sequence<1, 3>{}), +             ck_tile::make_tuple(ck_tile::sequence<0>{}, ck_tile::sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; return pad_tensor_view( - k_dram_naive, + k_dram_transposed, make_tuple(number{}, number{}), sequence{}); }(); @@ -1095,16 +1103,18 @@ struct FmhaFwdKernel { const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_v), - make_tuple(kargs.stride_v, 1), + make_tuple(kargs.seqlen_k / 16, kargs.hdim_v, 16), + make_tuple(kargs.hdim_v * 16, 16, 1), + // make_tuple(kargs.seqlen_k, kargs.hdim_v), + // make_tuple(kargs.stride_v, 1), number{}, number<1>{}); const auto v_dram_transposed = transform_tensor_view(v_dram_naive, make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen_k)), - make_tuple(sequence<1>{}, sequence<0>{}), +                      make_merge_transform(ck_tile::make_tuple(kargs.seqlen_k /16, 16))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; @@ -1117,14 +1127,23 @@ struct FmhaFwdKernel { const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.hdim_v, kargs.seqlen_k), - make_tuple(kargs.stride_v, 1), + // make_tuple(kargs.hdim_v, kargs.seqlen_k), + // make_tuple(kargs.stride_v, 1), + make_tuple(kargs.seqlen_k / 16, kargs.hdim_v, 16), + make_tuple(kargs.hdim_v * 16, 16, 1), number{}, number<1>{}); + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), +                      make_merge_transform(ck_tile::make_tuple(kargs.seqlen_k /16, 16))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false; return pad_tensor_view( - v_dram_naive, + v_dram_transposed, make_tuple(number{}, number{}), sequence{}); }