mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Refine the interleaving in the loop of Gemm0
This commit is contained in:
@@ -179,12 +179,12 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(k1_loops >= 2,
|
||||
"k1_loops >= 2 required due to pre-storing two v_tiles to Lds");
|
||||
// usually kN0 is 128, kK1 is 32/16
|
||||
static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline");
|
||||
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
constexpr index_t NumPrefetchV = 2;
|
||||
constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
|
||||
static_assert(k1_loops >= NumPrefetchV, "Check failed!");
|
||||
|
||||
constexpr bool kPreloadWholeNextIterationK =
|
||||
@@ -377,18 +377,22 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
{
|
||||
k_tiles[number<i_k1 + 1>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
if constexpr(i_k1 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
if constexpr(i_k1 == k1_loops - 1)
|
||||
{
|
||||
// prefetch all k_tiles for next iteration
|
||||
static_for<0, k1_loops, 1>{}([&](auto ii_k1) {
|
||||
k_tiles[number<ii_k1>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(sacc_tile,
|
||||
@@ -414,12 +418,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
{
|
||||
k_tiles[number<i_k1 + 1>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
}
|
||||
else
|
||||
};
|
||||
|
||||
if constexpr(i_k1 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
}
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(sacc_tile,
|
||||
@@ -444,27 +449,20 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
|
||||
|
||||
if constexpr(i_k1 == 0)
|
||||
if constexpr(i_k1 < NumPrefetchV)
|
||||
{
|
||||
// prefetch first v_tile
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
// prefetch first two k_tiles for next iteration
|
||||
if constexpr(i_k1 == 1)
|
||||
{
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
|
||||
k_tiles[I1] = load_tile(k_dram_window);
|
||||
// prefetch k_tile for next iteration
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
// prefetch other k_tiles for next iteration
|
||||
if constexpr(i_k1 >= 2)
|
||||
if constexpr(i_k1 >= NumPrefetchV)
|
||||
{
|
||||
k_tiles[number<i_k1>{}] = load_tile(k_dram_window);
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
@@ -488,9 +486,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1>{}]));
|
||||
|
||||
if constexpr(i_k1 == 0)
|
||||
if constexpr(i_k1 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
@@ -521,10 +519,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
{
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
}
|
||||
else
|
||||
};
|
||||
|
||||
if constexpr(i_k1 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[number<i_k1 - (k1_loops - 1)>{}] = load_tile(v_dram_window);
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
@@ -545,11 +544,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x000000001);
|
||||
|
||||
static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
|
||||
v_tiles[i_buf] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
|
||||
@@ -23,6 +23,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
return Problem::BlockFmhaShape::kM0 <= 64;
|
||||
};
|
||||
|
||||
// Returns the number of V tiles to prefetch, as a compile-time constant.
// The value is currently fixed at 2 for every Problem; the Problem parameter is
// only used to validate the tile shape (kN0 / kK1) before returning.
// The pipeline reads this value as NumPrefetchV and separately asserts
// k1_loops >= NumPrefetchV.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumPrefetchV()
|
||||
{
|
||||
// Number of kK1-sized steps along kN0 (integer division).
constexpr index_t k1_loops = Problem::BlockFmhaShape::kN0 / Problem::BlockFmhaShape::kK1;
|
||||
|
||||
// usually kN0 is 128, kK1 is 32/16
|
||||
// Reject shapes with fewer than 2 k1 steps at compile time; presumably this
// keeps the fixed prefetch count of 2 from exceeding the k1 loop count
// (NOTE(review): confirm against the pipeline's use of NumPrefetchV).
static_assert(k1_loops >= 2, "Check failed!");
|
||||
|
||||
// Fixed prefetch depth; not derived from k1_loops.
return 2;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto GetNumKVLdsBuffers()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user