mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Fix in hstu-attention pipeline (which makes some testing cases passed)
This commit is contained in:
@@ -315,8 +315,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
if constexpr(i_k0 == 0)
|
||||
clear_tile(s_acc);
|
||||
|
||||
if constexpr(i_k0 < k0_loops - 1)
|
||||
k_tile = load_tile(k_dram_window);
|
||||
k_tile = load_tile(k_dram_window);
|
||||
if constexpr(i_k0 < k0_loops - 2)
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
@@ -389,6 +388,18 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask.IsTokenPairInsideMask(row, col);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
if(q_origin.at(number<0>{}) + kM0 > mask.max_uih_len)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
set_tile_if(s_acc, type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask.IsTokenPairInsideMask(row, col);
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
auto s = cast_tile<CompDataType>(s_acc); // S{j}
|
||||
@@ -404,6 +415,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
randval_lds_ptr, seqlen_k_start + i_loop * kN0, s, null_randval_window);
|
||||
}
|
||||
|
||||
// ensure gemm_0 has finished access of k-Lds for all warps
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0x7f);
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
@@ -426,6 +439,12 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
const auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, s));
|
||||
|
||||
move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
|
||||
k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
|
||||
@@ -305,7 +305,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// leave some exclusive space so that the second v_lds buffer will nenver overlap with the first
|
||||
// leave some exclusive space so that the second v_lds buffer will never overlap with the first
|
||||
// k_lds bufffer
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes()
|
||||
|
||||
Reference in New Issue
Block a user