[CK_TILE] FMHA FWD bug fix (#2888)

* tempsave debug

* fix the bug in fmha fwd_kernel

* Remove unnecessary changes

* Fix the buggy part

* remove fmha fwd known failure cases

[ROCm/composable_kernel commit: b6e8994386]
This commit is contained in:
Haocong WANG
2025-09-23 15:00:46 +08:00
committed by GitHub
parent bfa145c418
commit d85ca87d97
3 changed files with 24 additions and 20 deletions

View File

@@ -1,4 +0,0 @@
tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1
tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1
tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1
tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1

View File

@@ -1767,6 +1767,9 @@ struct FmhaFwdKernel
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{});
constexpr auto kDramTileK =
FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0;
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
@@ -1835,32 +1838,36 @@ struct FmhaFwdKernel
{
const auto k_dram_unmerged = transform_tensor_view(
k_dram_pad,
make_tuple(
make_pass_through_transform(height),
make_unmerge_transform(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(make_pass_through_transform(height),
make_unmerge_transform(
make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
FmhaPipeline::kAlignmentK>{},
number<kDramTileK / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
make_tuple(sequence<0>{}, sequence<1, 2, 3>{}));
const auto k_dram_permuted = transform_tensor_view(
k_dram_unmerged,
make_tuple(
make_xor_transform(make_tuple(
height,
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
height, number<kDramTileK / FmhaPipeline::kAlignmentK>{})),
make_pass_through_transform(
number<FmhaPipeline::kQKHeaddim / kDramTileK /
FmhaPipeline::kAlignmentK>{}),
make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
return transform_tensor_view(
k_dram_permuted,
make_tuple(
make_pass_through_transform(height),
make_merge_transform_v3_division_mod(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(make_pass_through_transform(height),
make_merge_transform_v3_division_mod(
make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
FmhaPipeline::kAlignmentK>{},
number<kDramTileK / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(sequence<0>{}, sequence<1, 2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
};

View File

@@ -37,6 +37,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kKLoadOnce = BlockFmhaShape::kM0 >= 64;
static constexpr index_t kBlockSize = Problem::kBlockSize;