diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index adc24943e6..02f1948e17 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -1088,7 +1088,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync tile_elementwise_inout([&k_descale](auto& x) { x *= k_descale; }, s_acc); } - const auto p = [&]() { + const auto p_cast = [&]() { const auto bias_tile = load_tile(bias_dram_window); // load bias tile // STAGE 2, scale_s, add bias, mask, softmax @@ -1423,6 +1423,15 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync tile_elementwise_in(p_compute_element_func, p_compute)); #endif }(); +#if defined(__gfx11__) + // gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap + // softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM. + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA(p, p_cast); +#else + const auto p = p_cast; +#endif // STAGE 3, KV gemm // KV_BLOCKSCALE: accumulate P*V into temporary tile before applying v_descale diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 7b97d01fa4..277cb5b3e0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -762,7 +762,7 @@ struct BlockFmhaPipelineQRKSVSAsync randval_ptr, seq_offset, p_compute, randval_dram_window); } - const auto p = [&]() { + const auto p_cast = [&]() { #if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN // For fp32 to fp16, // impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue, @@ -777,6 +777,15 @@ struct BlockFmhaPipelineQRKSVSAsync tile_elementwise_in(p_compute_element_func, p_compute)); #endif }(); +#if defined(__gfx11__) + // gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap + // softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM. + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA(p, p_cast); +#else + const auto p = p_cast; +#endif float v_descale = 1.0f; if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index 607ee70020..73b84594e7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -685,7 +685,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch __builtin_amdgcn_sched_barrier(0x00000001); - auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + auto p_cast = + cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); +#if defined(__gfx11__) + // gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap + // softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM. + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA(p, p_cast); +#else + const auto p = p_cast; +#endif __builtin_amdgcn_sched_barrier(0x00000001); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index e5e9e2333a..bc54f75e06 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -145,10 +145,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() { +#if defined(__gfx11__) + // gfx11 WMMA V loads expect the LDS K-pack to match the warp GEMM K-per-thread; + // clamping to 8 under-reserves LDS padding for K-per-thread 16 variants. + return GetKVWarpGemmKPerThreadSize(); +#else if constexpr(GetKVWarpGemmKPerThreadSize() >= 8) return 8; else return 4; +#endif } template diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp index 95f68623fa..cb6dd09230 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp @@ -712,7 +712,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad __builtin_amdgcn_sched_barrier(0x00000001); - auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + auto p_cast = + cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); +#if defined(__gfx11__) + // gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap + // softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM. + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA(p, p_cast); +#else + const auto p = p_cast; +#endif __builtin_amdgcn_sched_barrier(0x00000001); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 9fc3652f51..4eb5eb291a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -512,8 +512,17 @@ struct BlockFmhaPipelineQSKSVS block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); +#if defined(__gfx11__) + // gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap + // softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM. + auto p = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA( + p, cast_tile(tile_elementwise_in(p_compute_element_func, p_compute))); +#else const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); +#endif __builtin_amdgcn_sched_barrier(0);