mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
enable larger tile size; upgrade xor pattern
This commit is contained in:
@@ -38,8 +38,8 @@ K0_MAX_SUBMAX_MAP = {
|
||||
}
|
||||
|
||||
SEQLENQ_MAP = {
|
||||
"16" : "16",
|
||||
"32" : "32",
|
||||
# "16" : "16",
|
||||
# "32" : "32",
|
||||
# "64" : "64"
|
||||
"128" : "128",
|
||||
}
|
||||
@@ -132,18 +132,18 @@ using trait_{F_idx} = fmha_fwd_decode_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_
|
||||
namespace {{
|
||||
template <bool kHasUnevenSplits>
|
||||
void run_instance(const ck_tile::stream_config& s, fmha_fwd_decode_args a) {{
|
||||
if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
|
||||
&& (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|
||||
|| std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
|
||||
if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
|
||||
instance<kHasUnevenSplits, /*kMergeNumHeadGroupsSeqLenQ=*/true>::run(s, a);
|
||||
}} else {{
|
||||
instance<kHasUnevenSplits>::run(s, a);
|
||||
}}
|
||||
}} else {{
|
||||
instance<kHasUnevenSplits>::run(s, a);
|
||||
}}
|
||||
// instance<kHasUnevenSplits>::run(s, a);
|
||||
//if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
|
||||
// && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|
||||
// || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
|
||||
// if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
|
||||
// instance<kHasUnevenSplits, /*kMergeNumHeadGroupsSeqLenQ=*/true>::run(s, a);
|
||||
// }} else {{
|
||||
// instance<kHasUnevenSplits>::run(s, a);
|
||||
// }}
|
||||
//}} else {{
|
||||
// instance<kHasUnevenSplits>::run(s, a);
|
||||
//}}
|
||||
instance<kHasUnevenSplits>::run(s, a);
|
||||
}}
|
||||
}} // anonymous namespace
|
||||
|
||||
@@ -152,20 +152,20 @@ void run_instance(const ck_tile::stream_config& s, fmha_fwd_decode_args a) {{
|
||||
template<>
|
||||
void fmha_fwd_decode_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_decode_args a)
|
||||
{{
|
||||
if constexpr({F_mode} == false) {{ // batch mode
|
||||
// we don't check every seqlen_k values for kvcache
|
||||
if (a.seqlen_k_ptr != nullptr) {{
|
||||
run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
// make sure F_bn0 is divisible by F_bk1
|
||||
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
|
||||
run_instance</*kHasUnevenSplits=*/false>(s, a);
|
||||
}} else {{
|
||||
run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
}}
|
||||
}} else {{
|
||||
run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
}}
|
||||
// run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
//if constexpr({F_mode} == false) {{ // batch mode
|
||||
// // we don't check every seqlen_k values for kvcache
|
||||
// if (a.seqlen_k_ptr != nullptr) {{
|
||||
// run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
// // make sure F_bn0 is divisible by F_bk1
|
||||
// }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
|
||||
// run_instance</*kHasUnevenSplits=*/false>(s, a);
|
||||
// }} else {{
|
||||
// run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
// }}
|
||||
//}} else {{
|
||||
// run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
//}}
|
||||
run_instance</*kHasUnevenSplits=*/false>(s, a);
|
||||
}}
|
||||
|
||||
template<>
|
||||
@@ -658,16 +658,16 @@ class FmhaFwdSplitKVCombineKernel:
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'64': {
|
||||
# Specialize for different SeqQ
|
||||
'16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
'32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128': FmhaFwdTileSize(128, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
},
|
||||
# '64': {
|
||||
# # Specialize for different SeqQ
|
||||
# '16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '128': FmhaFwdTileSize(128, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# },
|
||||
'128': {
|
||||
'16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
'32': FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128': FmhaFwdTileSize(128, 32, 128, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '32': FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128': FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
},
|
||||
}
|
||||
else:
|
||||
|
||||
@@ -831,22 +831,22 @@ struct FmhaFwdDecodeKernel
|
||||
|
||||
// TODO: Add kVHeadDim
|
||||
// TrLoad Performed in 16x4/16x8/16x16 unit, the fast dimension is 16 elements
|
||||
constexpr auto TrLoadFastDimLength = 16;
|
||||
constexpr auto XorGroupSize = FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
|
||||
|
||||
const auto v_dram_unmerged = transform_tensor_view(
|
||||
v_dram_pad,
|
||||
make_tuple(make_pass_through_transform(length),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / TrLoadFastDimLength>{},
|
||||
number<TrLoadFastDimLength>{}))),
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
|
||||
number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
const auto v_dram_permuted = transform_tensor_view(
|
||||
v_dram_unmerged,
|
||||
make_tuple(make_xor_transform(make_tuple(
|
||||
length, number<FmhaPipeline::kQKHeaddim / TrLoadFastDimLength>{})),
|
||||
make_pass_through_transform(number<TrLoadFastDimLength>{})),
|
||||
length, number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
@@ -854,8 +854,8 @@ struct FmhaFwdDecodeKernel
|
||||
v_dram_permuted,
|
||||
make_tuple(make_pass_through_transform(length),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / TrLoadFastDimLength>{},
|
||||
number<TrLoadFastDimLength>{}))),
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
|
||||
number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
};
|
||||
|
||||
@@ -865,11 +865,12 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
|
||||
constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
|
||||
|
||||
auto mainloop = [&](auto lds_write_buf, auto lds_read_buf) {
|
||||
auto k_lds_write_window = k_lds_write_windows.at(lds_write_buf);
|
||||
auto k_lds_read_window = k_lds_read_windows.at(lds_read_buf);
|
||||
auto v_lds_write_window = v_lds_write_windows.at(lds_write_buf);
|
||||
auto v_lds_read_window = v_lds_read_windows.at(lds_read_buf);
|
||||
auto mainloop = [&](index_t cur_loop) {
|
||||
|
||||
auto k_lds_write_window = (cur_loop%2 == 0)? k_lds_write_windows.at(I1) : k_lds_write_windows.at(I0);
|
||||
auto k_lds_read_window = (cur_loop%2 == 0)? k_lds_read_windows.at(I0) : k_lds_read_windows.at(I1);
|
||||
auto v_lds_write_window = (cur_loop%2 == 0)? v_lds_write_windows.at(I1) : v_lds_write_windows.at(I0);
|
||||
auto v_lds_read_window = (cur_loop%2 == 0)? v_lds_read_windows.at(I0) : v_lds_read_windows.at(I1);
|
||||
|
||||
block_sync_lds();
|
||||
// move K tile windows
|
||||
@@ -1090,14 +1091,16 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
do
|
||||
{
|
||||
mainloop(I1, I0);
|
||||
i_total_loops++;
|
||||
if(i_total_loops == (num_total_loop))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
mainloop(I0, I1);
|
||||
mainloop(i_total_loops);
|
||||
i_total_loops++;
|
||||
// mainloop(I1, I0);
|
||||
// i_total_loops++;
|
||||
// if(i_total_loops == (num_total_loop))
|
||||
// {
|
||||
// continue;
|
||||
// }
|
||||
// mainloop(I0, I1);
|
||||
// i_total_loops++;
|
||||
} while(i_total_loops < num_total_loop);
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
|
||||
@@ -290,7 +290,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
constexpr auto v_lds_block_desc = [&]() {
|
||||
if constexpr(Xor)
|
||||
{
|
||||
constexpr auto TrLoadFastDimLength = 16;
|
||||
constexpr auto XorGroupSize = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
|
||||
|
||||
constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock>{}, number<kNPerBlock>{}),
|
||||
@@ -302,8 +302,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
v_lds_block_desc_naive,
|
||||
make_tuple(make_pass_through_transform(number<kKPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<kNPerBlock / TrLoadFastDimLength>{},
|
||||
number<TrLoadFastDimLength>{}))),
|
||||
make_tuple(number<kNPerBlock / XorGroupSize>{},
|
||||
number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
@@ -311,8 +311,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
v_lds_block_desc_unmerged,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<kKPerBlock>{},
|
||||
number<kNPerBlock / TrLoadFastDimLength>{})),
|
||||
make_pass_through_transform(number<TrLoadFastDimLength>{})),
|
||||
number<kNPerBlock / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
@@ -320,8 +320,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
v_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<kKPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kNPerBlock / TrLoadFastDimLength>{},
|
||||
number<TrLoadFastDimLength>{}))),
|
||||
make_tuple(number<kNPerBlock / XorGroupSize>{},
|
||||
number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user