Fix wrong fp8 QK/KV block gemm setting

This commit is contained in:
Po Yen Chen
2024-04-08 05:24:41 +00:00
parent 4e005f2457
commit 92d45d1681

View File

@@ -99,13 +99,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
 }
 else if constexpr(Problem::kIsFp8)
 {
-constexpr index_t swizzle_factor = 4; // TODO: hard coded here
 return WarpGemmImpl<
 WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
 WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
 typename Problem::KDataType>,
-2,
-swizzle_factor>>{};
+2>>{};
 }
 }();
@@ -223,13 +221,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
 }
 else if constexpr(Problem::kIsFp8)
 {
-constexpr index_t swizzle_factor = 4; // TODO: hard coded here
 return WarpGemmImpl<
 WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
 WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
 typename Problem::KDataType>,
-2,
-swizzle_factor>>{};
+2>>{};
 }
 }();
@@ -922,14 +918,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
 auto warp_gemm = [&]() {
 if constexpr(Problem::kIsFp8)
 {
-return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
+return WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
 WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::PDataType,
-typename Problem::VDataType>,
-2>>{};
-// return
-// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
-// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
-// Problem::PDataType, typename Problem::VDataType>>>{};
+typename Problem::VDataType>>>{};
 }
 else
 {