Use explicit partition_index to ensure warp_id is allocated in a vgpr when accessing the LDS tile_window

This commit is contained in:
Qianfeng Zhang
2025-11-22 16:12:21 +00:00
parent 4f33eb5857
commit f9e8c5539f
7 changed files with 53 additions and 23 deletions

View File

@@ -57,7 +57,7 @@ struct BlockGemmARegBSmemCRegV2Hack_0
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = get_warp_id() % NWarp;
const index_t iNWarp = get_warp_id<false>() % NWarp;
static_assert(NWarp == 1, "Check failed!");

View File

@@ -57,7 +57,7 @@ struct BlockGemmARegBSmemCRegV2Hack_1
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = get_warp_id() % NWarp;
const index_t iNWarp = get_warp_id<false>() % NWarp;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,

View File

@@ -57,7 +57,7 @@ struct BlockGemmARegBSmemTrLoadCRegV2Hack_1
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = get_warp_id() % NWarp;
const index_t iNWarp = get_warp_id<false>() % NWarp;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,

View File

@@ -212,6 +212,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
__builtin_amdgcn_sched_barrier(0);
// provide partition_index for LDS tile window so that warp_id is in vgpr
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
// Q tile in LDS
QKVDataType* q_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto q_lds = make_tensor_view<address_space_enum::lds>(
@@ -223,7 +226,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
make_tile_window(q_lds,
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
{0, 0},
Policy::template MakeQRegSingleRepMTileDistribution<Problem>());
Policy::template MakeQRegSingleRepMTileDistribution<Problem>(),
partition_index);
// K tile in LDS
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
@@ -333,7 +337,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
{
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
store_tile(q_lds_write_window, q_dram_tiles[i_rep]);
store_tile(q_lds_write_window, q_dram_tiles[i_rep], partition_index);
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
// by each wavefront is read by itself
@@ -375,7 +379,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
{
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(k_lds_write_windows[i_k1], k_tiles[i_k1]);
store_tile(k_lds_write_windows[i_k1], k_tiles[i_k1], partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -458,7 +462,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile);
store_tile(
v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -501,7 +506,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 1>{}]);
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
v_shuffled_tile);
v_shuffled_tile,
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
};

View File

@@ -205,6 +205,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
__builtin_amdgcn_sched_barrier(0);
// provide partition_index for LDS tile window so that warp_id is in vgpr
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
// Q tile in LDS
QKVDataType* q_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto q_lds = make_tensor_view<address_space_enum::lds>(
@@ -216,7 +219,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
make_tile_window(q_lds,
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
{0, 0},
Policy::template MakeQRegTileDistribution<Problem>());
Policy::template MakeQRegTileDistribution<Problem>(),
partition_index);
// K tile in LDS
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
@@ -315,7 +319,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
store_tile(q_lds_write_window, q_dram_tile);
store_tile(q_lds_write_window, q_dram_tile, partition_index);
clear_tile(o_acc);
@@ -344,7 +348,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
{
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}], k_tiles[i_k1]);
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
k_tiles[i_k1],
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -439,7 +445,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
// STAGE 3, Gemm_1 ( O = P@V )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
v_tiles[number<i_k1>{}]);
v_tiles[number<i_k1>{}],
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);

View File

@@ -229,6 +229,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
__builtin_amdgcn_sched_barrier(0);
// provide partition_index for LDS tile window so that warp_id is in vgpr
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
// Q tile in LDS
QKVDataType* q_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto q_lds = make_tensor_view<address_space_enum::lds>(
@@ -240,7 +243,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
make_tile_window(q_lds,
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
{0, 0},
Policy::template MakeQRegSingleRepMTileDistribution<Problem>());
Policy::template MakeQRegSingleRepMTileDistribution<Problem>(),
partition_index);
// K tile in LDS
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
@@ -347,7 +351,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
{
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
store_tile(q_lds_write_window, q_dram_tiles[i_rep]);
store_tile(q_lds_write_window, q_dram_tiles[i_rep], partition_index);
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
// by each wavefront is read by itself
@@ -393,7 +397,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
k_tiles[number<i_k1 % NumPrefetchK>{}]);
k_tiles[number<i_k1 % NumPrefetchK>{}],
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -500,7 +505,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile);
store_tile(
v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -581,7 +587,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]);
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile);
store_tile(
v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -609,7 +616,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 2>{}]);
store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}],
v_shuffled_tile);
v_shuffled_tile,
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
};

View File

@@ -222,6 +222,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
__builtin_amdgcn_sched_barrier(0);
// provide partition_index for LDS tile window so that warp_id is in vgpr
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
// Q tile in LDS
QKVDataType* q_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto q_lds = make_tensor_view<address_space_enum::lds>(
@@ -329,7 +332,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
store_tile(q_lds_write_window, q_dram_tile);
store_tile(q_lds_write_window, q_dram_tile, partition_index);
clear_tile(o_acc);
@@ -362,7 +365,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
k_tiles[number<i_k1 % NumPrefetchK>{}]);
k_tiles[number<i_k1 % NumPrefetchK>{}],
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -464,7 +468,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_tiles[number<0>{}]);
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
v_tiles[number<0>{}],
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -550,7 +556,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_tiles[number<1>{}]);
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}],
v_tiles[number<1>{}],
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -577,7 +585,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
__builtin_amdgcn_sched_barrier(0x00000001);
store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}],
v_tiles[number<i_k1 + 2>{}]);
v_tiles[number<i_k1 + 2>{}],
partition_index);
__builtin_amdgcn_sched_barrier(0x00000001);
};