Adjust the codes before the main-loop

This commit is contained in:
Qianfeng Zhang
2025-05-19 11:24:59 +00:00
parent f411d676f2
commit 14ab6f154d

View File

@@ -160,6 +160,19 @@ struct HstuAttentionFwdPipelineQRKSVS
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
using PcompBlockTileType = decltype(cast_tile<CompDataType>(SaccBlockTileType{}));
SaccBlockTileType sacc_tile;
PcompBlockTileType pcomp_tile;
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
OaccBlockTileType o_acc;
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
@@ -177,6 +190,8 @@ struct HstuAttentionFwdPipelineQRKSVS
auto q_tile = load_tile(q_dram_window);
clear_tile(o_acc);
auto k_tile = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
@@ -200,11 +215,6 @@ struct HstuAttentionFwdPipelineQRKSVS
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
});
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<QKVDataType*>(smem_ptr),
@@ -222,15 +232,11 @@ struct HstuAttentionFwdPipelineQRKSVS
v_lds_window, sequence<i_buf * kN1, 0>{}, sequence<(i_buf + 1) * kN1, kK1>{});
});
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
using PcompBlockTileType = decltype(cast_tile<CompDataType>(SaccBlockTileType{}));
SaccBlockTileType sacc_tile;
PcompBlockTileType pcomp_tile;
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// reduction function for softmax
const auto f_silu = [&](CompDataType& x) {
@@ -246,13 +252,6 @@ struct HstuAttentionFwdPipelineQRKSVS
}
};
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
clear_tile(o_acc);
const auto num_loops = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if no work to do