Upgrade the prefill pipeline; add simple IGLP scheduling hints; make the data produce and consume order consistent

This commit is contained in:
aska-0096
2025-07-31 10:25:37 +00:00
parent 75cba48682
commit a28b6e67fe
4 changed files with 111 additions and 77 deletions

View File

@@ -857,47 +857,53 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops);
// block_sync_lds();
async_load_tile(k_lds_write_windows.at(I0), k_dram_window);
async_load_tile(v_lds_write_windows.at(I0), v_dram_window);
move_tile_window(k_dram_window, {kN0, 0});
async_load_tile(k_lds_write_windows.at(I1), k_dram_window);
constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
auto k_tile = load_tile(k_lds_read_windows.at(I0));
__builtin_amdgcn_sched_barrier(0);
auto mainloop = [&](index_t cur_loop) {
auto k_lds_write_window = (cur_loop%2 == 0)? k_lds_write_windows.at(I1) : k_lds_write_windows.at(I0);
auto k_lds_read_window = (cur_loop%2 == 0)? k_lds_read_windows.at(I0) : k_lds_read_windows.at(I1);
auto k_lds_write_window = (cur_loop%2 == 0)? k_lds_write_windows.at(I0) : k_lds_write_windows.at(I1);
auto k_lds_read_window_cur = (cur_loop%2 == 0)? k_lds_read_windows.at(I0) : k_lds_read_windows.at(I1);
auto k_lds_read_window_next = (cur_loop%2 == 0)? k_lds_read_windows.at(I1) : k_lds_read_windows.at(I0);
auto v_lds_write_window = (cur_loop%2 == 0)? v_lds_write_windows.at(I1) : v_lds_write_windows.at(I0);
auto v_lds_read_window = (cur_loop%2 == 0)? v_lds_read_windows.at(I0) : v_lds_read_windows.at(I1);
// move V tile windows
block_sync_lds();
// move K tile windows
move_tile_window(k_dram_window, {kN0, 0});
async_load_tile(k_lds_write_window, k_dram_window);
move_tile_window(v_dram_window, {kN0, 0});
async_load_tile(v_lds_write_window, v_dram_window);
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
auto k_tile = load_tile(k_lds_read_window);
if constexpr(1 < k0_loops)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
// loop over along the [K]ey head dimension
move_tile_window(k_lds_read_window_cur, {0, kK0});
auto k_tile_switch = load_tile(k_lds_read_window_cur);
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
k_tile);
// loop over along the [K]ey head dimension
move_tile_window(k_lds_read_window, {0, kK0});
k_tile = load_tile(k_lds_read_window);
k_tile = k_tile_switch;
});
// move back to the origin
move_tile_window(k_lds_read_window, {0, -kK0 * (k0_loops - 1)});
move_tile_window(k_lds_read_window_cur, {0, -kK0 * (k0_loops - 1)});
}
gemm_0(s_acc,
@@ -905,6 +911,21 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
k_tile);
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
auto v_tile = load_tile_transpose(v_lds_read_window);
static_for<0, 14, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
});
static_for<0, 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ
});
if constexpr(kHasUnevenSplits)
{
@@ -1058,25 +1079,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
});
});
move_tile_window(v_dram_window, {kN0, 0});
async_load_tile(v_lds_write_window, v_dram_window);
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
// Will insert unexpected vmcnt(0) here, probably the aliasing issue.
auto v_tile = load_tile_transpose(v_lds_read_window);
block_sync_lds();
move_tile_window(k_dram_window, {kN0, 0});
async_load_tile(k_lds_write_window, k_dram_window);
if constexpr(1 < k1_loops)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
// loop over along the [V]alue Sequence length
move_tile_window(v_lds_read_window, {kK1, 0});
auto v_tile_switch = load_tile_transpose(v_lds_read_window);
gemm_1(o_acc,
get_slice_tile(p_tile,
sequence<0, i_k1 * kK1>{},
sequence<kM0, (i_k1 + 1) * kK1>{}),
v_tile);
// loop over along the [V]alue Sequence length
move_tile_window(v_lds_read_window, {kK1, 0});
v_tile = load_tile_transpose(v_lds_read_window);
v_tile = v_tile_switch;
});
// move back to the origin
move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0});
@@ -1087,20 +1107,26 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
sequence<0, (k1_loops - 1) * kK1>{},
sequence<kM0, k1_loops * kK1>{}),
v_tile);
k_tile = load_tile(k_lds_read_window_next);
static_for<0, 14, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ
});
static_for<0, 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
});
};
do
{
mainloop(i_total_loops);
i_total_loops++;
// mainloop(I1, I0);
// i_total_loops++;
// if(i_total_loops == (num_total_loop))
// {
// continue;
// }
// mainloop(I0, I1);
// i_total_loops++;
} while(i_total_loops < num_total_loop);
if constexpr(kStoreLSE)

View File

@@ -127,7 +127,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -183,12 +183,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
// Read M first, then K
// This is the same data consume order as BlockGEMM
constexpr auto q_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -428,12 +430,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
// Read N first, then K
// This is the same data consume order as BlockGEMM
constexpr auto k_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -489,12 +493,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
// Read M first, then K
// This is the same data consume order as BlockGEMM
constexpr auto p_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto p_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -521,12 +527,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
// Read N first, then K
// This is the same data consume order as BlockGEMM
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(

View File

@@ -88,7 +88,7 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
@@ -120,7 +120,7 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
@@ -221,7 +221,7 @@ struct BlockGemmARegBRegCRegV1
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
@@ -229,7 +229,7 @@ struct BlockGemmARegBRegCRegV1
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor