[FMHA FWD] gfx950 Accuracy enhancement & bug fix (#2900)

* disable cast_tile_pk_fp16_fp32 on gfx950

* fix wrong encoding when hdim is not exponentiation of 2

---------

Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
Haocong WANG
2025-09-24 00:59:41 +08:00
committed by GitHub
parent 7b16782d7c
commit 959df2a155
2 changed files with 4 additions and 3 deletions

View File

@@ -813,7 +813,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t N1_m = kNPack / N2;
constexpr index_t N0_m = kNPerBlock / kNPack;
constexpr index_t K1 = get_warp_size() / N1_m;
constexpr index_t K2_m = kKPerBlock / K1;
constexpr index_t K2_m = kKPerBlock / K1 / K0;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
@@ -903,7 +904,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t N1_m = kNPack / N2;
constexpr index_t N0_m = kNPerBlock / kNPack;
constexpr index_t K1 = get_warp_size() / N1_m;
constexpr index_t K2_m = kKPerBlock / K1;
constexpr index_t K2_m = kKPerBlock / K1 / K0;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,