Shift rotary_cos/rotary_sin by cache_seqlen_k

This commit is contained in:
PoYen, Chen
2024-07-24 05:06:47 +00:00
parent a4da1e7f22
commit 59e1d9b84f
2 changed files with 23 additions and 13 deletions

View File

@@ -1102,8 +1102,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
auto [rotary_cos_slice, rotary_sin_slice] =
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);
ck_tile::reference_batched_rotary_position_embedding(
q_host_ref, rotary_cos_host, rotary_sin_host, is_rotary_interleaved, q_host_ref_ro);
q_host_ref, rotary_cos_slice, rotary_sin_slice, is_rotary_interleaved, q_host_ref_ro);
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
}
@@ -1165,10 +1168,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
auto [rotary_cos_slice, rotary_sin_slice] =
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);
ck_tile::reference_batched_rotary_position_embedding(
knew_host_ref,
rotary_cos_host,
rotary_sin_host,
rotary_cos_slice,
rotary_sin_slice,
is_rotary_interleaved,
knew_host_ref_ro.value());

View File

@@ -494,7 +494,7 @@ struct FmhaFwdAppendKVKernel
const auto rotary_cos_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const QDataType*>(kargs.rotary_cos_ptr),
make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
make_tuple(kargs.seqlen_k + kargs.seqlen_q, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});
@@ -505,8 +505,9 @@ struct FmhaFwdAppendKVKernel
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
return make_tile_window(rotary_cos_dram,
q_rotary_cos_sin_dram_window_lengths,
{kargs.seqlen_k + i_m0, 0});
}
else
{
@@ -519,7 +520,7 @@ struct FmhaFwdAppendKVKernel
const auto rotary_sin_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const QDataType*>(kargs.rotary_sin_ptr),
make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
make_tuple(kargs.seqlen_k + kargs.seqlen_q, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});
@@ -530,8 +531,9 @@ struct FmhaFwdAppendKVKernel
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
return make_tile_window(rotary_sin_dram,
q_rotary_cos_sin_dram_window_lengths,
{kargs.seqlen_k + i_m0, 0});
}
else
{
@@ -558,8 +560,9 @@ struct FmhaFwdAppendKVKernel
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
return make_tile_window(rotary_cos_dram,
knew_rotary_cos_sin_dram_window_lengths,
{kargs.seqlen_k + i_n0, 0});
}
else
{
@@ -583,8 +586,9 @@ struct FmhaFwdAppendKVKernel
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
return make_tile_window(rotary_sin_dram,
knew_rotary_cos_sin_dram_window_lengths,
{kargs.seqlen_k + i_n0, 0});
}
else
{