Further correction with regard to using n0_loops and k1_loops

This commit is contained in:
Qianfeng Zhang
2025-12-08 15:00:21 +00:00
parent 641dae10e8
commit 8640ffe8eb
3 changed files with 18 additions and 11 deletions

View File

@@ -460,7 +460,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};

View File

@@ -224,13 +224,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
constexpr index_t NumPrefetchK = 2;
static_assert(k1_loops >= NumPrefetchK, "Check failed!");
static_assert(n0_loops >= NumPrefetchK, "Check failed!");
// only prefetch two k tiles to save vgprs consumption
statically_indexed_array<k_tile_type, NumPrefetchK> k_tiles;
static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) {
k_tiles[i_k1] = load_tile(k_dram_window);
static_for<0, NumPrefetchK, 1>{}([&](auto i_n0) {
k_tiles[i_n0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
});
@@ -509,7 +509,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
@@ -519,7 +519,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<min(NumPrefetchK, k1_loops), k1_loops, 1>{}([&](auto i_k1) {
static_for<NumPrefetchK, k1_loops, 1>{}([&](auto i_k1) {
// load v_tiles used in current iteration
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
@@ -597,6 +597,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
// k1_loops >= 2 required
shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]);
// check whether second V-LdsBufer overlap with last K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((n0_loops - 1) % NumKVLdsBuffers == 3 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
store_tile(
v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index);

View File

@@ -214,9 +214,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
using k_tile_type = decltype(load_tile(k_dram_window));
constexpr index_t NumPrefetchK = (k1_loops <= 3) ? 1 : 2;
constexpr index_t NumPrefetchK = (n0_loops <= 3) ? 1 : 2;
static_assert(k1_loops >= NumPrefetchK, "Check failed!");
static_assert(n0_loops >= NumPrefetchK, "Check failed!");
static_assert(k1_loops >= 2,
"k1_loops >= 2 required due to pre-storing two v_tiles to Lds");
@@ -224,8 +224,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
// only prefetch two k tiles to save vgprs consumption
statically_indexed_array<k_tile_type, NumPrefetchK> k_tiles;
static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) {
k_tiles[i_k1] = load_tile(k_dram_window);
static_for<0, NumPrefetchK, 1>{}([&](auto i_n0) {
k_tiles[i_n0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kN0Sub, 0});
});
@@ -485,7 +485,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<min(NumPrefetchK, k1_loops), k1_loops, 1>{}([&](auto i_k1) {
static_for<NumPrefetchK, k1_loops, 1>{}([&](auto i_k1) {
// load v_tiles used in current iteration
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});