Refactor FMHA fwd-decode pipeline: replace double-buffered K/V LDS view/window tuples with single windows that swap their bottom-tensor data pointer per loop iteration, and retune the MFMA/DS_READ sched-group-barrier counts accordingly

This commit is contained in:
aska-0096
2025-08-01 10:44:54 +00:00
parent a28b6e67fe
commit 2d4e73d2b4
5 changed files with 126 additions and 154 deletions

View File

@@ -831,7 +831,8 @@ struct FmhaFwdDecodeKernel
// TODO: Add kVHeadDim
// TrLoad Performed in 16x4/16x8/16x16 unit, the fast dimension is 16 elements
constexpr auto XorGroupSize = FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
constexpr auto XorGroupSize =
FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
const auto v_dram_unmerged = transform_tensor_view(
v_dram_pad,

View File

@@ -758,41 +758,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
auto k_dram_window = make_tile_window(
k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution<Problem, true>());
auto k_lds_write_view =
make_tuple(make_tensor_view<address_space_enum::lds>(
static_cast<KDataType* __restrict__>(smem_ptrk0),
Policy::template MakeKLdsBlockDescriptor<Problem, true>()),
make_tensor_view<address_space_enum::lds>(
static_cast<KDataType* __restrict__>(smem_ptrk1),
Policy::template MakeKLdsBlockDescriptor<Problem, true>()));
auto k_lds_write_view = make_tensor_view<address_space_enum::lds>(
static_cast<KDataType* __restrict__>(smem_ptrk0),
Policy::template MakeKLdsBlockDescriptor<Problem, true>());
auto k_lds_read_view =
make_tuple(make_tensor_view<address_space_enum::lds>(
static_cast<KDataType* __restrict__>(smem_ptrk0),
Policy::template MakeKLdsBlockDescriptor<Problem, true, true>()),
make_tensor_view<address_space_enum::lds>(
static_cast<KDataType* __restrict__>(smem_ptrk1),
Policy::template MakeKLdsBlockDescriptor<Problem, true, true>()));
auto k_lds_read_view = make_tensor_view<address_space_enum::lds>(
static_cast<KDataType* __restrict__>(smem_ptrk0),
Policy::template MakeKLdsBlockDescriptor<Problem, true, true>());
auto k_lds_write_windows =
make_tuple(make_tile_window(
k_lds_write_view.at(I0),
Policy::template MakeKLdsBlockDescriptor<Problem, true>().get_lengths(),
{0, 0}),
make_tile_window(
k_lds_write_view.at(I1),
Policy::template MakeKLdsBlockDescriptor<Problem, true>().get_lengths(),
{0, 0}));
auto k_lds_write_window =
make_tile_window(k_lds_write_view,
Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0});
auto k_lds_read_windows =
make_tuple(make_tile_window(k_lds_read_view.at(I0),
make_tuple(number<kN0>{}, number<kK0>{}),
{0, 0},
Policy::template MakeKRegTileDistribution<Problem>()),
make_tile_window(k_lds_read_view.at(I1),
make_tuple(number<kN0>{}, number<kK0>{}),
{0, 0},
Policy::template MakeKRegTileDistribution<Problem>()));
auto k_lds_read_window =
make_tile_window(k_lds_read_view,
make_tuple(number<kN0>{}, number<kK0>{}),
{0, 0},
Policy::template MakeKRegTileDistribution<Problem>());
// S tile in LDS
auto s_lds = make_tensor_view<address_space_enum::lds>(
@@ -811,39 +794,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution<Problem>());
auto v_lds_write_view = make_tuple(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv0)),
Policy::template MakeVLdsBlockDescriptor<Problem>()),
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv1)),
Policy::template MakeVLdsBlockDescriptor<Problem>()));
auto v_lds_write_view = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv0)),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_read_view = make_tuple(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv0)),
Policy::template MakeVLdsBlockDescriptor<Problem, true>()),
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv1)),
Policy::template MakeVLdsBlockDescriptor<Problem, true>()));
auto v_lds_read_view = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv0)),
Policy::template MakeVLdsBlockDescriptor<Problem, true>());
auto v_lds_write_windows = make_tuple(
make_tile_window(v_lds_write_view.at(I0),
auto v_lds_write_window =
make_tile_window(v_lds_write_view,
Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0}),
make_tile_window(v_lds_write_view.at(I1),
Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0}));
{0, 0});
auto v_lds_read_windows =
make_tuple(make_tile_window(v_lds_read_view.at(I0),
make_tuple(number<kK1>{}, number<kN1>{}),
{0, 0},
Policy::template MakeVRegTileDistribution<Problem>()),
make_tile_window(v_lds_read_view.at(I1),
make_tuple(number<kK1>{}, number<kN1>{}),
{0, 0},
Policy::template MakeVRegTileDistribution<Problem>()));
auto v_lds_read_window =
make_tile_window(v_lds_read_view,
make_tuple(number<kK1>{}, number<kN1>{}),
{0, 0},
Policy::template MakeVRegTileDistribution<Problem>());
// block_sync_lds_direct_load<0>();
// auto q_tile = load_tile(q_lds_read_window);
@@ -857,31 +825,41 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops);
async_load_tile(k_lds_write_windows.at(I0), k_dram_window);
async_load_tile(v_lds_write_windows.at(I0), v_dram_window);
async_load_tile(k_lds_write_window, k_dram_window);
async_load_tile(v_lds_write_window, v_dram_window);
move_tile_window(k_dram_window, {kN0, 0});
async_load_tile(k_lds_write_windows.at(I1), k_dram_window);
k_lds_write_window.set_bottom_tensor_view_data_ptr(
static_cast<KDataType* __restrict__>(smem_ptrk1));
async_load_tile(k_lds_write_window, k_dram_window);
constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
constexpr index_t k_lds_insts = k_lds_read_window.get_num_of_access();
constexpr index_t v_lds_insts = v_lds_read_window.get_num_of_access();
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
auto k_tile = load_tile(k_lds_read_windows.at(I0));
auto k_tile = load_tile(k_lds_read_window);
__builtin_amdgcn_sched_barrier(0);
auto mainloop = [&](index_t cur_loop) {
auto k_lds_write_window = (cur_loop%2 == 0)? k_lds_write_windows.at(I0) : k_lds_write_windows.at(I1);
auto k_lds_read_window_cur = (cur_loop%2 == 0)? k_lds_read_windows.at(I0) : k_lds_read_windows.at(I1);
auto k_lds_read_window_next = (cur_loop%2 == 0)? k_lds_read_windows.at(I1) : k_lds_read_windows.at(I0);
auto v_lds_write_window = (cur_loop%2 == 0)? v_lds_write_windows.at(I1) : v_lds_write_windows.at(I0);
auto v_lds_read_window = (cur_loop%2 == 0)? v_lds_read_windows.at(I0) : v_lds_read_windows.at(I1);
const bool is_even_loop = (cur_loop % 2 == 0);
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
: static_cast<KDataType* __restrict__>(smem_ptrk1);
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
: static_cast<KDataType* __restrict__>(smem_ptrk0);
auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
: static_cast<VDataType* __restrict__>(smem_ptrv0);
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
: static_cast<VDataType* __restrict__>(smem_ptrv1);
// move V tile windows
block_sync_lds();
block_sync_lds<k_lds_insts>();
move_tile_window(v_dram_window, {kN0, 0});
v_lds_write_window.set_bottom_tensor_view_data_ptr(v_lds_write_ptr);
async_load_tile(v_lds_write_window, v_dram_window);
// STAGE 1, QK gemm
@@ -891,19 +869,19 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
// loop over along the [K]ey head dimension
move_tile_window(k_lds_read_window_cur, {0, kK0});
auto k_tile_switch = load_tile(k_lds_read_window_cur);
move_tile_window(k_lds_read_window, {0, kK0});
auto k_tile_switch = load_tile(k_lds_read_window);
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
k_tile);
k_tile = k_tile_switch;
});
// move back to the origin
move_tile_window(k_lds_read_window_cur, {0, -kK0 * (k0_loops - 1)});
move_tile_window(k_lds_read_window, {0, -kK0 * (k0_loops - 1)});
}
gemm_0(s_acc,
@@ -911,21 +889,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
k_tile);
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
auto v_tile = load_tile_transpose(v_lds_read_window);
static_for<0, 14, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
});
static_for<0, 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ
});
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
v_lds_read_window.set_bottom_tensor_view_data_ptr(v_lds_read_ptr);
auto v_tile = load_tile_transpose(v_lds_read_window);
if constexpr(kHasUnevenSplits)
{
@@ -991,6 +958,18 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
static_for<0, 12, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
});
static_for<0, 4, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ
});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
@@ -1079,8 +1058,9 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
});
});
block_sync_lds();
block_sync_lds<v_lds_insts>();
move_tile_window(k_dram_window, {kN0, 0});
k_lds_write_window.set_bottom_tensor_view_data_ptr(k_lds_write_ptr);
async_load_tile(k_lds_write_window, k_dram_window);
if constexpr(1 < k1_loops)
@@ -1108,15 +1088,16 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
sequence<kM0, k1_loops * kK1>{}),
v_tile);
k_tile = load_tile(k_lds_read_window_next);
static_for<0, 14, 1>{}([&](auto i) {
k_lds_read_window.set_bottom_tensor_view_data_ptr(k_lds_read_ptr);
k_tile = load_tile(k_lds_read_window);
static_for<0, 12, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ
});
static_for<0, 2, 1>{}([&](auto i) {
static_for<0, 4, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ

View File

@@ -292,7 +292,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
constexpr auto v_lds_block_desc = [&]() {
if constexpr(Xor)
{
constexpr auto XorGroupSize = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
constexpr auto XorGroupSize =
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock>{}, number<kNPerBlock>{}),
@@ -303,27 +304,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
const auto v_lds_block_desc_unmerged = transform_tensor_descriptor(
v_lds_block_desc_naive,
make_tuple(make_pass_through_transform(number<kKPerBlock>{}),
make_unmerge_transform(
make_tuple(number<kNPerBlock / XorGroupSize>{},
number<XorGroupSize>{}))),
make_unmerge_transform(make_tuple(
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto v_lds_block_desc_permuted = transform_tensor_descriptor(
v_lds_block_desc_unmerged,
make_tuple(
make_xor_transform(make_tuple(number<kKPerBlock>{},
number<kNPerBlock / XorGroupSize>{})),
make_pass_through_transform(number<XorGroupSize>{})),
make_tuple(make_xor_transform(make_tuple(number<kKPerBlock>{},
number<kNPerBlock / XorGroupSize>{})),
make_pass_through_transform(number<XorGroupSize>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_descriptor(
v_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<kKPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<kNPerBlock / XorGroupSize>{},
number<XorGroupSize>{}))),
make_merge_transform_v3_division_mod(make_tuple(
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}