mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Adjust the code before the main loop
This commit is contained in:
@@ -160,6 +160,19 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(SaccBlockTileType{}));
|
||||
|
||||
SaccBlockTileType sacc_tile;
|
||||
PcompBlockTileType pcomp_tile;
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
OaccBlockTileType o_acc;
|
||||
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
@@ -177,6 +190,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
auto k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kK1, 0});
|
||||
|
||||
@@ -200,11 +215,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<QKVDataType*>(smem_ptr),
|
||||
@@ -222,15 +232,11 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
v_lds_window, sequence<i_buf * kN1, 0>{}, sequence<(i_buf + 1) * kN1, kK1>{});
|
||||
});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(SaccBlockTileType{}));
|
||||
|
||||
SaccBlockTileType sacc_tile;
|
||||
PcompBlockTileType pcomp_tile;
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_silu = [&](CompDataType& x) {
|
||||
@@ -246,13 +252,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
}
|
||||
};
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
|
||||
// init Oacc, M, L
|
||||
auto o_acc = OaccBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
const auto num_loops = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
|
||||
// check early exit if no work to do
|
||||
|
||||
Reference in New Issue
Block a user