mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Fix wrong fp8 QK/KV block gemm setting
This commit is contained in:
@@ -99,13 +99,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
}
|
||||
else if constexpr(Problem::kIsFp8)
|
||||
{
|
||||
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
|
||||
return WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
|
||||
typename Problem::KDataType>,
|
||||
2,
|
||||
swizzle_factor>>{};
|
||||
2>>{};
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -223,13 +221,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
}
|
||||
else if constexpr(Problem::kIsFp8)
|
||||
{
|
||||
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
|
||||
return WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
|
||||
typename Problem::KDataType>,
|
||||
2,
|
||||
swizzle_factor>>{};
|
||||
2>>{};
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -922,14 +918,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr(Problem::kIsFp8)
|
||||
{
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::PDataType,
|
||||
typename Problem::VDataType>,
|
||||
2>>{};
|
||||
// return
|
||||
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
|
||||
// Problem::PDataType, typename Problem::VDataType>>>{};
|
||||
typename Problem::VDataType>>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user