diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index 607ee70020..73b84594e7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -685,7 +685,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch __builtin_amdgcn_sched_barrier(0x00000001); - auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + auto p_cast = + cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); +#if defined(__gfx11__) + // gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap + // softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM. + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA(p, p_cast); +#else + const auto p = p_cast; +#endif __builtin_amdgcn_sched_barrier(0x00000001); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index e5e9e2333a..bc54f75e06 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -145,10 +145,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() { +#if defined(__gfx11__) + // gfx11 WMMA V loads expect the LDS K-pack to match the warp GEMM K-per-thread; + // clamping to 8 under-reserves LDS padding for K-per-thread 16 variants. + return GetKVWarpGemmKPerThreadSize(); +#else if constexpr(GetKVWarpGemmKPerThreadSize() >= 8) return 8; else return 4; +#endif } template diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 9fc3652f51..4eb5eb291a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -512,8 +512,17 @@ struct BlockFmhaPipelineQSKSVS block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); +#if defined(__gfx11__) + // gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap + // softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM. + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA( + p, cast_tile(tile_elementwise_in(p_compute_element_func, p_compute))); +#else const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); +#endif __builtin_amdgcn_sched_barrier(0);