Using explicit vgpr-saved partition_index with store_tile(lds_window, ...)

This commit is contained in:
Qianfeng Zhang
2025-12-08 04:54:22 +00:00
parent 044f554bf7
commit 2ea8d8313c

View File

@@ -359,6 +359,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
// provide partition_index for the LDS tile window so that warp_id is kept in a vgpr
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
do
{
// STAGE 1, Gemm_0 ( S = Q@K )
@@ -369,9 +372,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration
{
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]),
partition_index);
if constexpr(i_k1 < k1_loops - 1)
{
@@ -410,9 +413,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
else // the iteration is also the last iteration
{
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]),
partition_index);
if constexpr(i_k1 < k1_loops - 1)
{
@@ -445,9 +448,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration
{
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]),
partition_index);
if constexpr(i_k1 < NumPrefetchV)
{
@@ -482,9 +485,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
else // last iteration
{
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]),
partition_index);
if constexpr(i_k1 < NumPrefetchV)
{
@@ -511,7 +514,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
{
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[I0]));
tile_elementwise_in(k_element_func, k_tiles[I0]),
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -678,7 +682,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile);
store_tile(
v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -717,7 +722,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
tile_elementwise_in(v_element_func,
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]));
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
v_shuffled_tile);
v_shuffled_tile,
partition_index);
};
});