mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Improve both the with_softmax and no_softmax pipelines
This commit is contained in:
@@ -73,7 +73,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<64>
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
|
||||
using type = ck_tile::sequence<128, 64, 128, 16, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
@@ -448,6 +448,27 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
});
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
using v_shuffled_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegTileDistribution<Problem>()));
|
||||
|
||||
v_shuffled_tile_type v_shuffled_tile;
|
||||
|
||||
shuffle_tile(v_shuffled_tile, v_tiles[number<0>{}]);
|
||||
|
||||
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffled_tile));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tile);
|
||||
|
||||
tile_elementwise_inout([&](auto& x) { x = x * type_convert<CompDataType>(scale_p); },
|
||||
@@ -466,28 +487,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
|
||||
using v_shuffled_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegTileDistribution<Problem>()));
|
||||
|
||||
statically_indexed_array<v_shuffled_tile_type, k1_loops> v_shuffled_tiles;
|
||||
|
||||
static_for<0, k1_loops, 1>{}(
|
||||
[&](auto i_k1) { shuffle_tile(v_shuffled_tiles[i_k1], v_tiles[i_k1]); });
|
||||
|
||||
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
// STAGE 3, Gemm_1 ( O = P@V )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffled_tiles[i_k1]));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// load k_tiles used by next iteration
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
@@ -500,6 +501,17 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 1>{}]);
|
||||
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffled_tile));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
};
|
||||
});
|
||||
|
||||
// check whether last V-LdsBuffer overlap with first K-LdsBuffer,
|
||||
|
||||
@@ -219,9 +219,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
statically_indexed_array<k_tile_type, k1_loops> k_tiles;
|
||||
constexpr index_t NumPrefetchK = 2;
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
static_assert(k1_loops >= NumPrefetchK, "Check failed!");
|
||||
|
||||
// only prefetch two k tiles to save vgprs consumption
|
||||
statically_indexed_array<k_tile_type, NumPrefetchK> k_tiles;
|
||||
|
||||
static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) {
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
});
|
||||
@@ -391,14 +396,23 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[i_k1],
|
||||
tile_elementwise_in(k_element_func, k_tiles[i_k1]));
|
||||
store_tile(
|
||||
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1 % NumPrefetchK>{}]));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchK)
|
||||
{
|
||||
k_tiles[number<i_k1 % NumPrefetchK>{}] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[number<i_k1 - (k1_loops - NumPrefetchK)>{}] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -417,8 +431,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, siLU
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
@@ -477,6 +489,35 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
using v_shuffled_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegTileDistribution<Problem>()));
|
||||
|
||||
v_shuffled_tile_type v_shuffled_tile;
|
||||
|
||||
shuffle_tile(v_shuffled_tile, v_tiles[number<0>{}]);
|
||||
|
||||
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffled_tile));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<NumPrefetchK, k1_loops, 1>{}([&](auto i_k1) {
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto m_local = block_tile_reduce<CompDataType>(
|
||||
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
@@ -544,31 +585,21 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
|
||||
using v_shuffled_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegTileDistribution<Problem>()));
|
||||
shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]);
|
||||
|
||||
statically_indexed_array<v_shuffled_tile_type, k1_loops> v_shuffled_tiles;
|
||||
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffled_tile));
|
||||
|
||||
static_for<0, k1_loops, 1>{}(
|
||||
[&](auto i_k1) { shuffle_tile(v_shuffled_tiles[i_k1], v_tiles[i_k1]); });
|
||||
|
||||
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// STAGE 3, Gemm_1 ( O = P@V )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffled_tiles[i_k1]));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
// load k_tiles used by next iteration
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
if constexpr(i_k1 < NumPrefetchK)
|
||||
{
|
||||
// load k_tiles used by next iteration
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -578,6 +609,17 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
|
||||
|
||||
if constexpr(i_k1 < k1_loops - 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 2>{}]);
|
||||
store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffled_tile));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
};
|
||||
});
|
||||
|
||||
// check whether last V-LdsBuffer overlap with first K-LdsBuffer,
|
||||
|
||||
Reference in New Issue
Block a user