mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Update to only pre-load one v_tile during Gemm0 loop
This commit is contained in:
@@ -373,14 +373,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 == n0_loops - 1)
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
// prefetch all k_tiles for next iteration
|
||||
static_for<0, n0_loops, 1>{}([&](auto ii_n0) {
|
||||
k_tiles[number<ii_n0>{}] = load_tile(k_dram_window);
|
||||
@@ -413,9 +410,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
if constexpr(i_n0 == n0_loops - 1)
|
||||
{
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
@@ -441,9 +438,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
if constexpr(i_n0 == 0)
|
||||
{
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
@@ -470,9 +467,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
if constexpr(i_n0 == 0)
|
||||
{
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
@@ -507,10 +504,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
if constexpr(i_n0 == n0_loops - 1)
|
||||
{
|
||||
static_for<0, NumPrefetchV, 1>{}([&](auto i_k1) {
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
@@ -593,6 +588,31 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
auto v_shuffled_tile = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegTileDistribution<Problem>());
|
||||
shuffle_tile(v_shuffled_tile, tile_elementwise_in(v_element_func, v_tiles[I0]));
|
||||
|
||||
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(
|
||||
v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) {
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
@@ -653,22 +673,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto v_shuffled_tile = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegTileDistribution<Problem>());
|
||||
shuffle_tile(v_shuffled_tile, tile_elementwise_in(v_element_func, v_tiles[I0]));
|
||||
|
||||
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(
|
||||
v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -377,14 +377,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
{
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 == n0_loops - 1)
|
||||
{
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
|
||||
// prefetch all k_tiles for next iteration
|
||||
static_for<0, n0_loops, 1>{}([&](auto ii_n0) {
|
||||
k_tiles[number<ii_n0>{}] = load_tile(k_dram_window);
|
||||
@@ -417,9 +414,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
};
|
||||
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
if constexpr(i_n0 == n0_loops - 1)
|
||||
{
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
@@ -445,9 +442,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
if constexpr(i_n0 == 0)
|
||||
{
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
@@ -474,9 +471,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
if constexpr(i_n0 < NumPrefetchV)
|
||||
if constexpr(i_n0 == 0)
|
||||
{
|
||||
v_tiles[i_n0] = load_tile(v_dram_window);
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
@@ -511,10 +508,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
|
||||
if constexpr(i_n0 == n0_loops - 1)
|
||||
{
|
||||
static_for<0, NumPrefetchV, 1>{}([&](auto i_k1) {
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
});
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
@@ -597,6 +592,28 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tiles[I0]),
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) {
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
@@ -657,19 +674,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tiles[I0]),
|
||||
partition_index);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
Reference in New Issue
Block a user