enable larger tile size; upgrade xor pattern

This commit is contained in:
aska-0096
2025-07-31 05:13:27 +00:00
parent 69890afc98
commit 75cba48682
4 changed files with 66 additions and 63 deletions

View File

@@ -831,22 +831,22 @@ struct FmhaFwdDecodeKernel
// TODO: Add kVHeadDim
// TrLoad Performed in 16x4/16x8/16x16 unit, the fast dimension is 16 elements
constexpr auto TrLoadFastDimLength = 16;
constexpr auto XorGroupSize = FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
const auto v_dram_unmerged = transform_tensor_view(
v_dram_pad,
make_tuple(make_pass_through_transform(length),
make_unmerge_transform(
make_tuple(number<FmhaPipeline::kQKHeaddim / TrLoadFastDimLength>{},
number<TrLoadFastDimLength>{}))),
make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
number<XorGroupSize>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto v_dram_permuted = transform_tensor_view(
v_dram_unmerged,
make_tuple(make_xor_transform(make_tuple(
length, number<FmhaPipeline::kQKHeaddim / TrLoadFastDimLength>{})),
make_pass_through_transform(number<TrLoadFastDimLength>{})),
length, number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
make_pass_through_transform(number<XorGroupSize>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
@@ -854,8 +854,8 @@ struct FmhaFwdDecodeKernel
v_dram_permuted,
make_tuple(make_pass_through_transform(length),
make_merge_transform_v3_division_mod(
make_tuple(number<FmhaPipeline::kQKHeaddim / TrLoadFastDimLength>{},
number<TrLoadFastDimLength>{}))),
make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
number<XorGroupSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
};

View File

@@ -865,11 +865,12 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
auto mainloop = [&](auto lds_write_buf, auto lds_read_buf) {
auto k_lds_write_window = k_lds_write_windows.at(lds_write_buf);
auto k_lds_read_window = k_lds_read_windows.at(lds_read_buf);
auto v_lds_write_window = v_lds_write_windows.at(lds_write_buf);
auto v_lds_read_window = v_lds_read_windows.at(lds_read_buf);
auto mainloop = [&](index_t cur_loop) {
auto k_lds_write_window = (cur_loop%2 == 0)? k_lds_write_windows.at(I1) : k_lds_write_windows.at(I0);
auto k_lds_read_window = (cur_loop%2 == 0)? k_lds_read_windows.at(I0) : k_lds_read_windows.at(I1);
auto v_lds_write_window = (cur_loop%2 == 0)? v_lds_write_windows.at(I1) : v_lds_write_windows.at(I0);
auto v_lds_read_window = (cur_loop%2 == 0)? v_lds_read_windows.at(I0) : v_lds_read_windows.at(I1);
block_sync_lds();
// move K tile windows
@@ -1090,14 +1091,16 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
do
{
mainloop(I1, I0);
i_total_loops++;
if(i_total_loops == (num_total_loop))
{
continue;
}
mainloop(I0, I1);
mainloop(i_total_loops);
i_total_loops++;
// mainloop(I1, I0);
// i_total_loops++;
// if(i_total_loops == (num_total_loop))
// {
// continue;
// }
// mainloop(I0, I1);
// i_total_loops++;
} while(i_total_loops < num_total_loop);
if constexpr(kStoreLSE)

View File

@@ -290,7 +290,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
constexpr auto v_lds_block_desc = [&]() {
if constexpr(Xor)
{
constexpr auto TrLoadFastDimLength = 16;
constexpr auto XorGroupSize = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock>{}, number<kNPerBlock>{}),
@@ -302,8 +302,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
v_lds_block_desc_naive,
make_tuple(make_pass_through_transform(number<kKPerBlock>{}),
make_unmerge_transform(
make_tuple(number<kNPerBlock / TrLoadFastDimLength>{},
number<TrLoadFastDimLength>{}))),
make_tuple(number<kNPerBlock / XorGroupSize>{},
number<XorGroupSize>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
@@ -311,8 +311,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
v_lds_block_desc_unmerged,
make_tuple(
make_xor_transform(make_tuple(number<kKPerBlock>{},
number<kNPerBlock / TrLoadFastDimLength>{})),
make_pass_through_transform(number<TrLoadFastDimLength>{})),
number<kNPerBlock / XorGroupSize>{})),
make_pass_through_transform(number<XorGroupSize>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
@@ -320,8 +320,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
v_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<kKPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<kNPerBlock / TrLoadFastDimLength>{},
number<TrLoadFastDimLength>{}))),
make_tuple(number<kNPerBlock / XorGroupSize>{},
number<XorGroupSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}