mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Fix rotary cos/sin tensor/tile size
This commit is contained in:
@@ -524,7 +524,7 @@ struct FmhaFwdAppendKVKernel
|
||||
}
|
||||
}();
|
||||
constexpr auto rotary_cos_sin_dram_window_lengths =
|
||||
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD>{});
|
||||
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD / 2>{});
|
||||
const auto rotary_cos_dram_window = [&]() {
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
@@ -532,7 +532,7 @@ struct FmhaFwdAppendKVKernel
|
||||
const auto rotary_cos_dram_native =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr),
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_q),
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.rotary_dim / 2, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
@@ -557,7 +557,7 @@ struct FmhaFwdAppendKVKernel
|
||||
const auto rotary_sin_dram_native =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr),
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_q),
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.rotary_dim / 2, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
Reference in New Issue
Block a user