diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 712c0ca2c9..a64645b123 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -99,13 +99,11 @@ struct BlockFmhaPipelineQXCustomPolicy } else if constexpr(Problem::kIsFp8) { - constexpr index_t swizzle_factor = 4; // TODO: hard coded here return WarpGemmImpl< WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base, - 2, - swizzle_factor>>{}; + 2>>{}; } }(); @@ -223,13 +221,11 @@ struct BlockFmhaPipelineQXCustomPolicy } else if constexpr(Problem::kIsFp8) { - constexpr index_t swizzle_factor = 4; // TODO: hard coded here return WarpGemmImpl< WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base, - 2, - swizzle_factor>>{}; + 2>>{}; } }(); @@ -922,14 +918,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, - 2>>{}; - // return - // WarpGemmImpl>>{}; + typename Problem::VDataType>>>{}; } else {