mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Update the whole_k_prefetch_trload pipeline to prefetch two k_tiles for the next iteration in the non-whole-k-prefetch path
This commit is contained in:
@@ -198,6 +198,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
constexpr bool kPreloadWholeNextIterationK =
|
||||
Policy::template IsPreloadWholeNextIterationK<Problem>();
|
||||
|
||||
// This path prefetches two k_tiles for next iteration, so it has the opportunity to
|
||||
// prefetch two v_tiles during Gemm0
|
||||
if constexpr(!kPreloadWholeNextIterationK)
|
||||
{
|
||||
static_assert(NumPrefetchV >= 2);
|
||||
};
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
@@ -243,17 +250,22 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
// only prefetch two k tiles to save vgprs consumption
|
||||
auto k_tiles = [&]() {
|
||||
if constexpr(kPreloadWholeNextIterationK)
|
||||
return statically_indexed_array<k_tile_type, n0_loops>{};
|
||||
else
|
||||
return statically_indexed_array<k_tile_type, 1>{};
|
||||
return statically_indexed_array<k_tile_type, 2>{};
|
||||
}();
|
||||
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
|
||||
if constexpr(!kPreloadWholeNextIterationK)
|
||||
{
|
||||
k_tiles[I1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// provide partition_index for LDS tile window with so that warp_id is in vgpr
|
||||
@@ -495,20 +507,20 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[I0]),
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0 % 2>{}]),
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
if constexpr(i_n0 < n0_loops - 1)
|
||||
if constexpr(i_n0 < n0_loops - 2)
|
||||
{
|
||||
k_tiles[I0] = load_tile(k_dram_window);
|
||||
k_tiles[number<i_n0 % 2>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 == n0_loops - 1)
|
||||
if constexpr(i_n0 >= n0_loops - 2)
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
v_tiles[number<i_n0 - (n0_loops - 2)>{}] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
@@ -607,10 +619,20 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) {
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
});
|
||||
if constexpr(kPreloadWholeNextIterationK)
|
||||
{
|
||||
static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) {
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<2, NumPrefetchV, 1>{}([&](auto i_k1) {
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -698,6 +720,18 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr(i_k1 == k1_loops - NumPrefetchV + 1)
|
||||
{
|
||||
if constexpr(!kPreloadWholeNextIterationK)
|
||||
{
|
||||
if(seqlen_k_curr < seqlen_k_end)
|
||||
{
|
||||
k_tiles[I1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc,
|
||||
|
||||
Reference in New Issue
Block a user