[CK_TILE] FMHA FWD bug fix (#2888)

* Temporarily save debugging changes

* Fix the bug in the FMHA fwd_kernel

* Remove unnecessary changes

* Fix the buggy part

* Remove FMHA fwd known failure cases
This commit is contained in:
Haocong WANG
2025-09-23 15:00:46 +08:00
committed by GitHub
parent ad259eeae2
commit b6e8994386
3 changed files with 24 additions and 20 deletions

View File

@@ -1767,6 +1767,9 @@ struct FmhaFwdKernel
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{});
constexpr auto kDramTileK =
FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0;
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
@@ -1835,32 +1838,36 @@ struct FmhaFwdKernel
{
const auto k_dram_unmerged = transform_tensor_view(
k_dram_pad,
make_tuple(
make_pass_through_transform(height),
make_unmerge_transform(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(make_pass_through_transform(height),
make_unmerge_transform(
make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
FmhaPipeline::kAlignmentK>{},
number<kDramTileK / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
make_tuple(sequence<0>{}, sequence<1, 2, 3>{}));
const auto k_dram_permuted = transform_tensor_view(
k_dram_unmerged,
make_tuple(
make_xor_transform(make_tuple(
height,
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
height, number<kDramTileK / FmhaPipeline::kAlignmentK>{})),
make_pass_through_transform(
number<FmhaPipeline::kQKHeaddim / kDramTileK /
FmhaPipeline::kAlignmentK>{}),
make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
return transform_tensor_view(
k_dram_permuted,
make_tuple(
make_pass_through_transform(height),
make_merge_transform_v3_division_mod(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(make_pass_through_transform(height),
make_merge_transform_v3_division_mod(
make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
FmhaPipeline::kAlignmentK>{},
number<kDramTileK / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(sequence<0>{}, sequence<1, 2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
};