From 92d45d168110d9386b0122194daf4bcb0637a58d Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Mon, 8 Apr 2024 05:24:41 +0000 Subject: [PATCH] Fix wrong fp8 QK/KV block gemm setting --- ...ock_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 712c0ca2c9..a64645b123 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -99,13 +99,11 @@ struct BlockFmhaPipelineQXCustomPolicy } else if constexpr(Problem::kIsFp8) { - constexpr index_t swizzle_factor = 4; // TODO: hard coded here return WarpGemmImpl< WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base, - 2, - swizzle_factor>>{}; + 2>>{}; } }(); @@ -223,13 +221,11 @@ struct BlockFmhaPipelineQXCustomPolicy } else if constexpr(Problem::kIsFp8) { - constexpr index_t swizzle_factor = 4; // TODO: hard coded here return WarpGemmImpl< WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base, - 2, - swizzle_factor>>{}; + 2>>{}; } }(); @@ -922,14 +918,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, - 2>>{}; - // return - // WarpGemmImpl>>{}; + typename Problem::VDataType>>>{}; } else {