mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Using explicit vgpr-saved partition_index with store_tile(lds_window, ...)
This commit is contained in:
@@ -359,6 +359,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
|
||||
|
||||
// provide an explicit partition_index for the LDS tile windows so that warp_id is kept in a vgpr
|
||||
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
|
||||
|
||||
do
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
@@ -369,9 +372,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration
|
||||
{
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(
|
||||
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
{
|
||||
@@ -410,9 +413,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
else // the iteration is also the last iteration
|
||||
{
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(
|
||||
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
{
|
||||
@@ -445,9 +448,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration
|
||||
{
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(
|
||||
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_k1 < NumPrefetchV)
|
||||
{
|
||||
@@ -482,9 +485,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
else // last iteration
|
||||
{
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(
|
||||
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_k1 < NumPrefetchV)
|
||||
{
|
||||
@@ -511,7 +514,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
{
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[I0]));
|
||||
tile_elementwise_in(k_element_func, k_tiles[I0]),
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -678,7 +682,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile);
|
||||
store_tile(
|
||||
v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -717,7 +722,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]));
|
||||
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
|
||||
v_shuffled_tile);
|
||||
v_shuffled_tile,
|
||||
partition_index);
|
||||
};
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user