Fix rotary cos/sin tensor/tile size

This commit is contained in:
PoYen, Chen
2024-07-16 06:31:17 +00:00
parent b32fd8d3f4
commit 99f863e4cd

View File

@@ -524,7 +524,7 @@ struct FmhaFwdAppendKVKernel
}
}();
constexpr auto rotary_cos_sin_dram_window_lengths =
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD>{});
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD / 2>{});
const auto rotary_cos_dram_window = [&]() {
if constexpr(kApplyRoPE)
{
@@ -532,7 +532,7 @@ struct FmhaFwdAppendKVKernel
const auto rotary_cos_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr),
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_q),
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});
@@ -557,7 +557,7 @@ struct FmhaFwdAppendKVKernel
const auto rotary_sin_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr),
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_q),
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});