Fix wrong knew/vnew appending positions

This commit is contained in:
PoYen, Chen
2024-07-23 04:46:53 +00:00
parent 56df4d6397
commit bc7c7ee0c5

View File

@@ -514,8 +514,9 @@ struct FmhaFwdAppendKVKernel
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
/// TODO: use tile idx for q
return make_tile_window(
rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {0, 0});
rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_sk, 0});
}
else
{
@@ -539,8 +540,9 @@ struct FmhaFwdAppendKVKernel
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
/// TODO: use tile idx for q
return make_tile_window(
rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {0, 0});
rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_sk, 0});
}
else
{
@@ -568,7 +570,7 @@ struct FmhaFwdAppendKVKernel
}();
return make_tile_window(
rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {0, 0});
rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_sk, 0});
}
else
{
@@ -593,7 +595,7 @@ struct FmhaFwdAppendKVKernel
}();
return make_tile_window(
rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {0, 0});
rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_sk, 0});
}
else
{
@@ -601,6 +603,7 @@ struct FmhaFwdAppendKVKernel
}
}();
/// TODO: use tile idx for q
auto q_dram_window = make_tile_window(
q_dram,
make_tuple(number<FmhaPipeline::kTileSizeS>{}, number<FmhaPipeline::kTileSizeD>{}),
@@ -609,7 +612,7 @@ struct FmhaFwdAppendKVKernel
auto k_dram_window = make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD>{}),
{kargs.seqlen_k, 0});
{kargs.seqlen_k + i_sk, 0});
auto knew_dram_window = make_tile_window(
knew_dram,
@@ -619,7 +622,7 @@ struct FmhaFwdAppendKVKernel
auto v_dram_window = make_tile_window(
v_dram,
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kTileSizeSk>{}),
{0, kargs.seqlen_k});
{0, kargs.seqlen_k + i_sk});
auto vnew_dram_window = make_tile_window(
vnew_dram,