mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 08:48:45 +00:00
[rocm-libraries] ROCm/rocm-libraries#7537 (commit 07123f4)
[CK Tile] Fix Grouped Gemm quant mixed precision (#7537) <Migrate from Internal repo PR> test_ck_tile_grouped_gemm_quant_tensor would fail for mixed FP8/BF8 cases: std::tuple<Row, Col, Row, FP8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, std::tuple<Row, Col, Row, BF8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False> GFX1250 would fail with incorrect results, GFX950 would fail when compiling BF8+FP8 and give incorrect results for FP8+BF8. The issue is due to the wrong ComputeDataType selection. The fix is to consider original ADataType and BDataType even when ComputeDataType is not void. For compiling error on gfx950, the bf8, fp8, 16x16x32 warp Gemm is added.
This commit is contained in:
@@ -42,13 +42,16 @@ template <typename ComputeDataType, typename ADataType, typename BDataType>
|
||||
using mixed_prec_compute_type_t =
|
||||
typename detail::mixed_prec_compute_type<ComputeDataType, ADataType, BDataType>::type;
|
||||
|
||||
// Helper method to determine compute type, defaulting to input data type
|
||||
// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed,
|
||||
// ComputeDataType is used.
|
||||
// Determines the compute type for a given input, preferring ThisDataType when possible.
|
||||
// - If ThisDataType is packed: use OtherDataType, or ComputeDataType if both are packed.
|
||||
// - If ThisDataType is smaller than a non-packed OtherDataType: use ComputeDataType.
|
||||
// - Otherwise: use ThisDataType.
|
||||
template <typename ThisDataType, typename OtherDataType, typename ComputeDataType>
|
||||
using mixed_prec_compute_type_from_input_t = std::conditional_t<
|
||||
is_packed_type_v<ThisDataType>,
|
||||
std::conditional_t<is_packed_type_v<OtherDataType>, ComputeDataType, OtherDataType>,
|
||||
ThisDataType>;
|
||||
|
||||
std::conditional_t<(sizeof(ThisDataType) < sizeof(OtherDataType)) &&
|
||||
!is_packed_type_v<OtherDataType>,
|
||||
ComputeDataType,
|
||||
ThisDataType>>;
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -390,6 +390,9 @@ using WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed =
|
||||
using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x32_bf8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
@@ -1680,6 +1680,9 @@ using WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_bf8 =
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<bf8_t, bf8_t, Ctrl_>;
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<bf8_t, fp8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 =
|
||||
|
||||
@@ -167,6 +167,7 @@ template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, true> { using Ty
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
|
||||
|
||||
@@ -33,8 +33,14 @@ struct GemmQuantPipelineProblemBase
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>,
|
||||
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>
|
||||
mixed_prec_compute_type_from_input_t<
|
||||
ADataType_,
|
||||
BDataType_,
|
||||
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>,
|
||||
mixed_prec_compute_type_from_input_t<
|
||||
BDataType_,
|
||||
ADataType_,
|
||||
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>>
|
||||
{
|
||||
using Base = GemmPipelineProblemBase<
|
||||
ADataType_,
|
||||
@@ -42,8 +48,14 @@ struct GemmQuantPipelineProblemBase
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>,
|
||||
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>;
|
||||
mixed_prec_compute_type_from_input_t<
|
||||
ADataType_,
|
||||
BDataType_,
|
||||
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>,
|
||||
mixed_prec_compute_type_from_input_t<
|
||||
BDataType_,
|
||||
ADataType_,
|
||||
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>>;
|
||||
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
|
||||
@@ -27,7 +27,9 @@ using KernelTypes_Tensor = ::testing::Types<
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
|
||||
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user