Fix in hstu-attention pipeline (which makes some test cases pass)

This commit is contained in:
Qianfeng Zhang
2025-04-08 15:53:08 +00:00
parent dbcf38aae9
commit dc2f72a09f
2 changed files with 22 additions and 3 deletions

View File

@@ -315,8 +315,7 @@ struct HstuAttentionFwdPipelineQRKSVS
if constexpr(i_k0 == 0)
clear_tile(s_acc);
if constexpr(i_k0 < k0_loops - 1)
k_tile = load_tile(k_dram_window);
k_tile = load_tile(k_dram_window);
if constexpr(i_k0 < k0_loops - 2)
move_tile_window(k_dram_window, {0, kK0});
@@ -389,6 +388,18 @@ struct HstuAttentionFwdPipelineQRKSVS
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask.IsTokenPairInsideMask(row, col);
});
}
else
{
if(q_origin.at(number<0>{}) + kM0 > mask.max_uih_len)
{
const auto k_origin = k_dram_block_window.get_window_origin();
set_tile_if(s_acc, type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask.IsTokenPairInsideMask(row, col);
});
};
};
auto s = cast_tile<CompDataType>(s_acc); // S{j}
@@ -404,6 +415,8 @@ struct HstuAttentionFwdPipelineQRKSVS
randval_lds_ptr, seqlen_k_start + i_loop * kN0, s, null_randval_window);
}
// ensure gemm_0 has finished access of k-Lds for all warps
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0x7f);
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -426,6 +439,12 @@ struct HstuAttentionFwdPipelineQRKSVS
const auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, s));
move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
k_tile = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{

View File

@@ -305,7 +305,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
// leave some exclusive space so that the second v_lds buffer will nenver overlap with the first
// leave some exclusive space so that the second v_lds buffer will never overlap with the first
// k_lds buffer
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes()