mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 23:38:11 +00:00
Shift rotary_cos/rotary_sin by cache_seqlen_k
This commit is contained in:
@@ -1102,8 +1102,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
|
||||
|
||||
auto [rotary_cos_slice, rotary_sin_slice] =
|
||||
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);
|
||||
|
||||
ck_tile::reference_batched_rotary_position_embedding(
|
||||
q_host_ref, rotary_cos_host, rotary_sin_host, is_rotary_interleaved, q_host_ref_ro);
|
||||
q_host_ref, rotary_cos_slice, rotary_sin_slice, is_rotary_interleaved, q_host_ref_ro);
|
||||
|
||||
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
|
||||
}
|
||||
@@ -1165,10 +1168,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
|
||||
|
||||
auto [rotary_cos_slice, rotary_sin_slice] =
|
||||
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);
|
||||
|
||||
ck_tile::reference_batched_rotary_position_embedding(
|
||||
knew_host_ref,
|
||||
rotary_cos_host,
|
||||
rotary_sin_host,
|
||||
rotary_cos_slice,
|
||||
rotary_sin_slice,
|
||||
is_rotary_interleaved,
|
||||
knew_host_ref_ro.value());
|
||||
|
||||
|
||||
@@ -494,7 +494,7 @@ struct FmhaFwdAppendKVKernel
|
||||
const auto rotary_cos_dram_native =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const QDataType*>(kargs.rotary_cos_ptr),
|
||||
make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_q, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.rotary_dim / 2, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
@@ -505,8 +505,9 @@ struct FmhaFwdAppendKVKernel
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
|
||||
return make_tile_window(rotary_cos_dram,
|
||||
q_rotary_cos_sin_dram_window_lengths,
|
||||
{kargs.seqlen_k + i_m0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -519,7 +520,7 @@ struct FmhaFwdAppendKVKernel
|
||||
const auto rotary_sin_dram_native =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const QDataType*>(kargs.rotary_sin_ptr),
|
||||
make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_q, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.rotary_dim / 2, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
@@ -530,8 +531,9 @@ struct FmhaFwdAppendKVKernel
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
|
||||
return make_tile_window(rotary_sin_dram,
|
||||
q_rotary_cos_sin_dram_window_lengths,
|
||||
{kargs.seqlen_k + i_m0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -558,8 +560,9 @@ struct FmhaFwdAppendKVKernel
|
||||
sequence<kPadSeqLenK, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
|
||||
return make_tile_window(rotary_cos_dram,
|
||||
knew_rotary_cos_sin_dram_window_lengths,
|
||||
{kargs.seqlen_k + i_n0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -583,8 +586,9 @@ struct FmhaFwdAppendKVKernel
|
||||
sequence<kPadSeqLenK, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
|
||||
return make_tile_window(rotary_sin_dram,
|
||||
knew_rotary_cos_sin_dram_window_lengths,
|
||||
{kargs.seqlen_k + i_n0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user