mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
upgrade prefill pipeline; simple iglp; consistent data produce and consume order
This commit is contained in:
@@ -857,47 +857,53 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
static_assert(1 <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
|
||||
// block_sync_lds();
|
||||
async_load_tile(k_lds_write_windows.at(I0), k_dram_window);
|
||||
async_load_tile(v_lds_write_windows.at(I0), v_dram_window);
|
||||
|
||||
move_tile_window(k_dram_window, {kN0, 0});
|
||||
async_load_tile(k_lds_write_windows.at(I1), k_dram_window);
|
||||
|
||||
constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
|
||||
constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
|
||||
|
||||
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
|
||||
auto k_tile = load_tile(k_lds_read_windows.at(I0));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
auto mainloop = [&](index_t cur_loop) {
|
||||
|
||||
auto k_lds_write_window = (cur_loop%2 == 0)? k_lds_write_windows.at(I1) : k_lds_write_windows.at(I0);
|
||||
auto k_lds_read_window = (cur_loop%2 == 0)? k_lds_read_windows.at(I0) : k_lds_read_windows.at(I1);
|
||||
auto k_lds_write_window = (cur_loop%2 == 0)? k_lds_write_windows.at(I0) : k_lds_write_windows.at(I1);
|
||||
auto k_lds_read_window_cur = (cur_loop%2 == 0)? k_lds_read_windows.at(I0) : k_lds_read_windows.at(I1);
|
||||
auto k_lds_read_window_next = (cur_loop%2 == 0)? k_lds_read_windows.at(I1) : k_lds_read_windows.at(I0);
|
||||
auto v_lds_write_window = (cur_loop%2 == 0)? v_lds_write_windows.at(I1) : v_lds_write_windows.at(I0);
|
||||
auto v_lds_read_window = (cur_loop%2 == 0)? v_lds_read_windows.at(I0) : v_lds_read_windows.at(I1);
|
||||
|
||||
// move V tile windows
|
||||
block_sync_lds();
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_window, {kN0, 0});
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
move_tile_window(v_dram_window, {kN0, 0});
|
||||
async_load_tile(v_lds_write_window, v_dram_window);
|
||||
|
||||
// STAGE 1, QK gemm
|
||||
clear_tile(s_acc); // initialize C
|
||||
|
||||
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
|
||||
auto k_tile = load_tile(k_lds_read_window);
|
||||
|
||||
if constexpr(1 < k0_loops)
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
// loop over along the [K]ey head dimension
|
||||
move_tile_window(k_lds_read_window_cur, {0, kK0});
|
||||
auto k_tile_switch = load_tile(k_lds_read_window_cur);
|
||||
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_tile);
|
||||
|
||||
// loop over along the [K]ey head dimension
|
||||
move_tile_window(k_lds_read_window, {0, kK0});
|
||||
k_tile = load_tile(k_lds_read_window);
|
||||
|
||||
k_tile = k_tile_switch;
|
||||
});
|
||||
// move back to the origin
|
||||
move_tile_window(k_lds_read_window, {0, -kK0 * (k0_loops - 1)});
|
||||
move_tile_window(k_lds_read_window_cur, {0, -kK0 * (k0_loops - 1)});
|
||||
}
|
||||
|
||||
gemm_0(s_acc,
|
||||
@@ -905,6 +911,21 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_tile);
|
||||
|
||||
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
|
||||
auto v_tile = load_tile_transpose(v_lds_read_window);
|
||||
|
||||
static_for<0, 14, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
|
||||
});
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ
|
||||
});
|
||||
|
||||
if constexpr(kHasUnevenSplits)
|
||||
{
|
||||
@@ -1058,25 +1079,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
});
|
||||
});
|
||||
|
||||
move_tile_window(v_dram_window, {kN0, 0});
|
||||
async_load_tile(v_lds_write_window, v_dram_window);
|
||||
|
||||
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
|
||||
// Will insert unexpected vmcnt(0) here, probably the aliasing issue.
|
||||
auto v_tile = load_tile_transpose(v_lds_read_window);
|
||||
block_sync_lds();
|
||||
move_tile_window(k_dram_window, {kN0, 0});
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
|
||||
if constexpr(1 < k1_loops)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
// loop over along the [V]alue Sequence length
|
||||
move_tile_window(v_lds_read_window, {kK1, 0});
|
||||
auto v_tile_switch = load_tile_transpose(v_lds_read_window);
|
||||
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(p_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_tile);
|
||||
|
||||
// loop over along the [V]alue Sequence length
|
||||
move_tile_window(v_lds_read_window, {kK1, 0});
|
||||
v_tile = load_tile_transpose(v_lds_read_window);
|
||||
v_tile = v_tile_switch;
|
||||
});
|
||||
// move back to the origin
|
||||
move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0});
|
||||
@@ -1087,20 +1107,26 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
sequence<0, (k1_loops - 1) * kK1>{},
|
||||
sequence<kM0, k1_loops * kK1>{}),
|
||||
v_tile);
|
||||
|
||||
k_tile = load_tile(k_lds_read_window_next);
|
||||
|
||||
static_for<0, 14, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ
|
||||
});
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
|
||||
});
|
||||
};
|
||||
|
||||
do
|
||||
{
|
||||
mainloop(i_total_loops);
|
||||
i_total_loops++;
|
||||
// mainloop(I1, I0);
|
||||
// i_total_loops++;
|
||||
// if(i_total_loops == (num_total_loop))
|
||||
// {
|
||||
// continue;
|
||||
// }
|
||||
// mainloop(I0, I1);
|
||||
// i_total_loops++;
|
||||
} while(i_total_loops < num_total_loop);
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
|
||||
@@ -127,7 +127,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -183,12 +183,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
// Read M first, then K
|
||||
// This is the same data consume order as BlockGEMM
|
||||
constexpr auto q_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -428,12 +430,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
// Read N first, then K
|
||||
// This is the same data consume order as BlockGEMM
|
||||
constexpr auto k_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -489,12 +493,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
// Read M first, then K
|
||||
// This is the same data consume order as BlockGEMM
|
||||
constexpr auto p_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto p_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -521,12 +527,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
// Read N first, then K
|
||||
// This is the same data consume order as BlockGEMM
|
||||
constexpr auto v_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
|
||||
@@ -88,7 +88,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
@@ -120,7 +120,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
@@ -221,7 +221,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
@@ -229,7 +229,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
|
||||
Reference in New Issue
Block a user