mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Further correction with regard to using n0_loops and k1_loops
This commit is contained in:
@@ -460,7 +460,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
// check whether the first V-LdsBuffer overlaps with the last K-LdsBuffer,
|
||||
// this does not occur when n0_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
@@ -224,13 +224,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
constexpr index_t NumPrefetchK = 2;
|
||||
|
||||
static_assert(k1_loops >= NumPrefetchK, "Check failed!");
|
||||
static_assert(n0_loops >= NumPrefetchK, "Check failed!");
|
||||
|
||||
// only prefetch two k tiles to reduce vgpr consumption
|
||||
statically_indexed_array<k_tile_type, NumPrefetchK> k_tiles;
|
||||
|
||||
static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) {
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
static_for<0, NumPrefetchK, 1>{}([&](auto i_n0) {
|
||||
k_tiles[i_n0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
});
|
||||
|
||||
@@ -509,7 +509,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
// check whether the first V-LdsBuffer overlaps with the last K-LdsBuffer,
|
||||
// this does not occur when n0_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
@@ -519,7 +519,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<min(NumPrefetchK, k1_loops), k1_loops, 1>{}([&](auto i_k1) {
|
||||
static_for<NumPrefetchK, k1_loops, 1>{}([&](auto i_k1) {
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
@@ -597,6 +597,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
// k1_loops >= 2 required
|
||||
shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]);
|
||||
|
||||
// check whether the second V-LdsBuffer overlaps with the last K-LdsBuffer,
|
||||
// this does not occur when n0_loops == 2 and NumKVLdsBuffers == 4
|
||||
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 3 % NumKVLdsBuffers)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(
|
||||
v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);
|
||||
|
||||
|
||||
@@ -214,9 +214,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
constexpr index_t NumPrefetchK = (k1_loops <= 3) ? 1 : 2;
|
||||
constexpr index_t NumPrefetchK = (n0_loops <= 3) ? 1 : 2;
|
||||
|
||||
static_assert(k1_loops >= NumPrefetchK, "Check failed!");
|
||||
static_assert(n0_loops >= NumPrefetchK, "Check failed!");
|
||||
|
||||
static_assert(k1_loops >= 2,
|
||||
"k1_loops >= 2 required due to pre-storing two v_tiles to Lds");
|
||||
@@ -224,8 +224,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
// only prefetch up to two k tiles (NumPrefetchK) to reduce vgpr consumption
|
||||
statically_indexed_array<k_tile_type, NumPrefetchK> k_tiles;
|
||||
|
||||
static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) {
|
||||
k_tiles[i_k1] = load_tile(k_dram_window);
|
||||
static_for<0, NumPrefetchK, 1>{}([&](auto i_n0) {
|
||||
k_tiles[i_n0] = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
});
|
||||
|
||||
@@ -485,7 +485,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
static_for<min(NumPrefetchK, k1_loops), k1_loops, 1>{}([&](auto i_k1) {
|
||||
static_for<NumPrefetchK, k1_loops, 1>{}([&](auto i_k1) {
|
||||
// load v_tiles used in current iteration
|
||||
v_tiles[i_k1] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
|
||||
Reference in New Issue
Block a user