mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] FMHA FWD bug fix (#2888)
* tempsave debug
* fix the bug in fmha fwd_kernel
* Remove unnecessary changes
* Fix the buggy part
* remove fmha fwd known failure cases
This commit is contained in:
@@ -1767,6 +1767,9 @@ struct FmhaFwdKernel
 make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
 sequence<false, kPadHeadDimQ>{});
 
+constexpr auto kDramTileK =
+FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0;
+
 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
 constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
 constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
@@ -1835,32 +1838,36 @@ struct FmhaFwdKernel
 {
 const auto k_dram_unmerged = transform_tensor_view(
 k_dram_pad,
-make_tuple(
-make_pass_through_transform(height),
-make_unmerge_transform(make_tuple(
-number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
-number<FmhaPipeline::kAlignmentK>{}))),
+make_tuple(make_pass_through_transform(height),
+make_unmerge_transform(
+make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
+FmhaPipeline::kAlignmentK>{},
+number<kDramTileK / FmhaPipeline::kAlignmentK>{},
+number<FmhaPipeline::kAlignmentK>{}))),
 make_tuple(sequence<0>{}, sequence<1>{}),
-make_tuple(sequence<0>{}, sequence<1, 2>{}));
+make_tuple(sequence<0>{}, sequence<1, 2, 3>{}));
 
 const auto k_dram_permuted = transform_tensor_view(
 k_dram_unmerged,
 make_tuple(
 make_xor_transform(make_tuple(
-height,
-number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
+height, number<kDramTileK / FmhaPipeline::kAlignmentK>{})),
+make_pass_through_transform(
+number<FmhaPipeline::kQKHeaddim / kDramTileK /
+FmhaPipeline::kAlignmentK>{}),
 make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
-make_tuple(sequence<0, 1>{}, sequence<2>{}),
-make_tuple(sequence<0, 1>{}, sequence<2>{}));
+make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}),
+make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
 
 return transform_tensor_view(
 k_dram_permuted,
-make_tuple(
-make_pass_through_transform(height),
-make_merge_transform_v3_division_mod(make_tuple(
-number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
-number<FmhaPipeline::kAlignmentK>{}))),
-make_tuple(sequence<0>{}, sequence<1, 2>{}),
+make_tuple(make_pass_through_transform(height),
+make_merge_transform_v3_division_mod(
+make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
+FmhaPipeline::kAlignmentK>{},
+number<kDramTileK / FmhaPipeline::kAlignmentK>{},
+number<FmhaPipeline::kAlignmentK>{}))),
+make_tuple(sequence<0>{}, sequence<1, 2, 3>{}),
 make_tuple(sequence<0>{}, sequence<1>{}));
 }
 };
Reference in New Issue
Block a user