Modify KV layout to match vLLM

This commit is contained in:
zanzhang
2025-04-29 11:01:04 +08:00
parent edd92fc546
commit 26e950b9d7

View File

@@ -1079,14 +1079,22 @@ struct FmhaFwdKernel
// Builds the K tensor view over global memory (from k_ptr) and pads it to the
// (FmhaPipeline::kN0, FmhaPipeline::kK0) tile shape.
// NOTE(review): this text is a commit-diff hunk whose +/- markers are missing, so
// what look like superseded and replacement lines appear side by side (two
// lengths/strides pairs below, two first arguments to pad_tensor_view). As shown it
// is not a single compilable statement — confirm the final form against the repository.
const auto k_dram = [&]() {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
// 2-D description: (seqlen_k, hdim_q) with (stride_k, 1) — presumably the pre-change lines.
make_tuple(kargs.seqlen_k, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
// 4-D blocked description: (seqlen_k/16, hdim_q/8, 16, 8) with (hdim_q*16, 16*8, 8, 1)
// — presumably the post-change lines (the commit title mentions a vLLM KV layout).
// NOTE(review): the factors 16 and 8 are hard-coded; if these are integer divisions
// (types are not visible here) any remainder of seqlen_k % 16 or hdim_q % 8 is
// dropped — TODO confirm callers guarantee divisibility.
make_tuple(kargs.seqlen_k / 16, kargs.hdim_q / 8, 16, 8),
make_tuple(kargs.hdim_q * 16, 16 * 8, 8, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
        // Collapse the 4-D view to 2-D: the merge over (seqlen_k/16, 16) is paired with
        // sequence<0, 2> -> sequence<0>, and the merge over (hdim_q/8, 8) with
        // sequence<1, 3> -> sequence<1>. Presumably this yields a (seqlen_k, hdim_q)
        // view again — verify against the ck_tile transform_tensor_view contract.
        const auto k_dram_transposed = transform_tensor_view(
            k_dram_naive,
            make_tuple(make_merge_transform(ck_tile::make_tuple(kargs.seqlen_k /16, 16)),
                       make_merge_transform(ck_tile::make_tuple(kargs.hdim_q / 8, 8))),
            ck_tile::make_tuple(ck_tile::sequence<0, 2>{}, ck_tile::sequence<1, 3>{}),
            ck_tile::make_tuple(ck_tile::sequence<0>{}, ck_tile::sequence<1>{}));
// The seqlen-K padding flag is forced to false unless kUseAsyncCopy is set.
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
// Two candidate first arguments follow; presumably k_dram_naive is the old one and
// k_dram_transposed its replacement.
k_dram_naive,
k_dram_transposed,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}();
@@ -1095,16 +1103,18 @@ struct FmhaFwdKernel
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
make_tuple(kargs.seqlen_k / 16, kargs.hdim_v, 16),
make_tuple(kargs.hdim_v * 16, 16, 1),
// make_tuple(kargs.seqlen_k, kargs.hdim_v),
// make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_transposed =
transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(sequence<1>{}, sequence<0>{}),
                     make_merge_transform(ck_tile::make_tuple(kargs.seqlen_k /16, 16))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
@@ -1117,14 +1127,23 @@ struct FmhaFwdKernel
// Builds the V tensor view over global memory (from v_ptr) and returns it padded to
// the (FmhaPipeline::kN1, FmhaPipeline::kK1) tile shape.
// NOTE(review): the statement controlling this brace block is outside the visible
// hunk, and the hunk's +/- markers are missing, so what look like old and new lines
// are interleaved — confirm the final form against the repository.
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
// 2-D description: (hdim_v, seqlen_k) with (stride_v, 1) — presumably the pre-change
// lines; the same two lines also appear commented out just below.
make_tuple(kargs.hdim_v, kargs.seqlen_k),
make_tuple(kargs.stride_v, 1),
// make_tuple(kargs.hdim_v, kargs.seqlen_k),
// make_tuple(kargs.stride_v, 1),
// 3-D blocked description: (seqlen_k/16, hdim_v, 16) with (hdim_v*16, 16, 1).
// NOTE(review): the factor 16 is hard-coded; if this is an integer division (types
// are not visible here) any remainder of seqlen_k % 16 is dropped — TODO confirm.
make_tuple(kargs.seqlen_k / 16, kargs.hdim_v, 16),
make_tuple(kargs.hdim_v * 16, 16, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
// Reorder/collapse to 2-D: the pass-through over hdim_v is paired with
// sequence<1> -> sequence<0>, and the merge over (seqlen_k/16, 16) with
// sequence<0, 2> -> sequence<1>. Presumably the result is an (hdim_v, seqlen_k)
// view — verify against the ck_tile transform_tensor_view contract.
const auto v_dram_transposed =
transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
                     make_merge_transform(ck_tile::make_tuple(kargs.seqlen_k /16, 16))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// The head-dim-V padding flag is forced to false unless kUseAsyncCopy is set.
constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
return pad_tensor_view(
// Two candidate first arguments follow; presumably v_dram_naive is the old one and
// v_dram_transposed its replacement.
v_dram_naive,
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV_, kPadSeqLenK>{});
}