mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
modify kv layout to vllm
This commit is contained in:
@@ -1079,14 +1079,22 @@ struct FmhaFwdKernel
|
||||
const auto k_dram = [&]() {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.seqlen_k, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
make_tuple(kargs.seqlen_k / 16, kargs.hdim_q / 8, 16, 8),
|
||||
make_tuple(kargs.hdim_q * 16, 16 * 8, 8, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
const auto k_dram_transposed = transform_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(make_merge_transform(ck_tile::make_tuple(kargs.seqlen_k /16, 16)),
|
||||
make_merge_transform(ck_tile::make_tuple(kargs.hdim_q / 8, 8))),
|
||||
ck_tile::make_tuple(ck_tile::sequence<0, 2>{}, ck_tile::sequence<1, 3>{}),
|
||||
ck_tile::make_tuple(ck_tile::sequence<0>{}, ck_tile::sequence<1>{}));
|
||||
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
|
||||
return pad_tensor_view(
|
||||
k_dram_naive,
|
||||
k_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}();
|
||||
@@ -1095,16 +1103,18 @@ struct FmhaFwdKernel
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.seqlen_k, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
make_tuple(kargs.seqlen_k / 16, kargs.hdim_v, 16),
|
||||
make_tuple(kargs.hdim_v * 16, 16, 1),
|
||||
// make_tuple(kargs.seqlen_k, kargs.hdim_v),
|
||||
// make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_transposed =
|
||||
transform_tensor_view(v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.seqlen_k)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_merge_transform(ck_tile::make_tuple(kargs.seqlen_k /16, 16))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
|
||||
@@ -1117,14 +1127,23 @@ struct FmhaFwdKernel
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.hdim_v, kargs.seqlen_k),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
// make_tuple(kargs.hdim_v, kargs.seqlen_k),
|
||||
// make_tuple(kargs.stride_v, 1),
|
||||
make_tuple(kargs.seqlen_k / 16, kargs.hdim_v, 16),
|
||||
make_tuple(kargs.hdim_v * 16, 16, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_transposed =
|
||||
transform_tensor_view(v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_merge_transform(ck_tile::make_tuple(kargs.seqlen_k /16, 16))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
|
||||
return pad_tensor_view(
|
||||
v_dram_naive,
|
||||
v_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV_, kPadSeqLenK>{});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user