mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Use explicit partition_index to ensure warp_id is allocated in a vgpr when accessing LDS tile_window
This commit is contained in:
@@ -57,7 +57,7 @@ struct BlockGemmARegBSmemCRegV2Hack_0
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
const index_t iNWarp = get_warp_id<false>() % NWarp;
|
||||
|
||||
static_assert(NWarp == 1, "Check failed!");
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ struct BlockGemmARegBSmemCRegV2Hack_1
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
const index_t iNWarp = get_warp_id<false>() % NWarp;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
|
||||
@@ -57,7 +57,7 @@ struct BlockGemmARegBSmemTrLoadCRegV2Hack_1
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
const index_t iNWarp = get_warp_id<false>() % NWarp;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
|
||||
@@ -212,6 +212,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// provide partition_index for LDS tile window so that warp_id is in vgpr
|
||||
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
|
||||
|
||||
// Q tile in LDS
|
||||
QKVDataType* q_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
@@ -223,7 +226,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>());
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>(),
|
||||
partition_index);
|
||||
|
||||
// K tile in LDS
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
@@ -333,7 +337,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
{
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
store_tile(q_lds_write_window, q_dram_tiles[i_rep]);
|
||||
store_tile(q_lds_write_window, q_dram_tiles[i_rep], partition_index);
|
||||
|
||||
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
|
||||
// by each wavefront is read by itself
|
||||
@@ -375,7 +379,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[i_k1], k_tiles[i_k1]);
|
||||
store_tile(k_lds_write_windows[i_k1], k_tiles[i_k1], partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -458,7 +462,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile);
|
||||
store_tile(
|
||||
v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -501,7 +506,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 1>{}]);
|
||||
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
|
||||
v_shuffled_tile);
|
||||
v_shuffled_tile,
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
};
|
||||
|
||||
@@ -205,6 +205,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// provide partition_index for LDS tile window so that warp_id is in vgpr
|
||||
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
|
||||
|
||||
// Q tile in LDS
|
||||
QKVDataType* q_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
@@ -216,7 +219,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
Policy::template MakeQRegTileDistribution<Problem>(),
|
||||
partition_index);
|
||||
|
||||
// K tile in LDS
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
@@ -315,7 +319,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
|
||||
store_tile(q_lds_write_window, q_dram_tile);
|
||||
store_tile(q_lds_write_window, q_dram_tile, partition_index);
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
@@ -344,7 +348,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}], k_tiles[i_k1]);
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
k_tiles[i_k1],
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -439,7 +445,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
// STAGE 3, Gemm_1 ( O = P@V )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
|
||||
v_tiles[number<i_k1>{}]);
|
||||
v_tiles[number<i_k1>{}],
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
|
||||
@@ -229,6 +229,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// provide partition_index for LDS tile window so that warp_id is in vgpr
|
||||
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
|
||||
|
||||
// Q tile in LDS
|
||||
QKVDataType* q_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
@@ -240,7 +243,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>());
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>(),
|
||||
partition_index);
|
||||
|
||||
// K tile in LDS
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
@@ -347,7 +351,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
{
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
store_tile(q_lds_write_window, q_dram_tiles[i_rep]);
|
||||
store_tile(q_lds_write_window, q_dram_tiles[i_rep], partition_index);
|
||||
|
||||
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
|
||||
// by each wavefront is read by itself
|
||||
@@ -393,7 +397,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
k_tiles[number<i_k1 % NumPrefetchK>{}]);
|
||||
k_tiles[number<i_k1 % NumPrefetchK>{}],
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -500,7 +505,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile);
|
||||
store_tile(
|
||||
v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -581,7 +587,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]);
|
||||
|
||||
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile);
|
||||
store_tile(
|
||||
v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -609,7 +616,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 2>{}]);
|
||||
store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}],
|
||||
v_shuffled_tile);
|
||||
v_shuffled_tile,
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
};
|
||||
|
||||
@@ -222,6 +222,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// provide partition_index for LDS tile window so that warp_id is in vgpr
|
||||
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
|
||||
|
||||
// Q tile in LDS
|
||||
QKVDataType* q_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
@@ -329,7 +332,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
|
||||
store_tile(q_lds_write_window, q_dram_tile);
|
||||
store_tile(q_lds_write_window, q_dram_tile, partition_index);
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
@@ -362,7 +365,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
k_tiles[number<i_k1 % NumPrefetchK>{}]);
|
||||
k_tiles[number<i_k1 % NumPrefetchK>{}],
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -464,7 +468,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_tiles[number<0>{}]);
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
|
||||
v_tiles[number<0>{}],
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -550,7 +556,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_tiles[number<1>{}]);
|
||||
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}],
|
||||
v_tiles[number<1>{}],
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -577,7 +585,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}],
|
||||
v_tiles[number<i_k1 + 2>{}]);
|
||||
v_tiles[number<i_k1 + 2>{}],
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user