From 5ff7497fa7a7bf7a5c0ec9e03104337c765fbb76 Mon Sep 17 00:00:00 2001 From: JiaLuo-CAN Date: Thu, 21 May 2026 11:36:23 -0400 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#7537 (commit 07123f4) [CK Tile] Fix Grouped Gemm quant mixed precision (#7537) test_ck_tile_grouped_gemm_quant_tensor would fail for mixed FP8/BF8 cases: std::tuple, std::tuple GFX1250 would fail with incorrect results, GFX950 would fail when compiling BF8+FP8 and give incorrect results for FP8+BF8. The issue is due to the wrong ComputeDataType selection. The fix is to consider original ADataType and BDataType even when ComputeDataType is not void. For compiling error on gfx950, the bf8, fp8, 16x16x32 warp Gemm is added. --- .../core/utility/mixed_prec_compute_type.hpp | 13 +++++++----- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 3 +++ .../warp/warp_gemm_attribute_mfma_impl.hpp | 3 +++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 1 + .../pipeline/gemm_quant_pipeline_problem.hpp | 20 +++++++++++++++---- .../test_grouped_gemm_quant_tensor.cpp | 4 +++- 6 files changed, 34 insertions(+), 10 deletions(-) diff --git a/include/ck_tile/core/utility/mixed_prec_compute_type.hpp b/include/ck_tile/core/utility/mixed_prec_compute_type.hpp index 021763c108..84ad45d277 100644 --- a/include/ck_tile/core/utility/mixed_prec_compute_type.hpp +++ b/include/ck_tile/core/utility/mixed_prec_compute_type.hpp @@ -42,13 +42,16 @@ template using mixed_prec_compute_type_t = typename detail::mixed_prec_compute_type::type; -// Helper method to determine compute type, defaulting to input data type -// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed, -// ComputeDataType is used. +// Determines the compute type for a given input, preferring ThisDataType when possible. +// - If ThisDataType is packed: use OtherDataType, or ComputeDataType if both are packed. +// - If ThisDataType is smaller than a non-packed OtherDataType: use ComputeDataType. +// - Otherwise: use ThisDataType. template using mixed_prec_compute_type_from_input_t = std::conditional_t< is_packed_type_v, std::conditional_t, ComputeDataType, OtherDataType>, - ThisDataType>; - + std::conditional_t<(sizeof(ThisDataType) < sizeof(OtherDataType)) && + !is_packed_type_v, + ComputeDataType, + ThisDataType>>; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 45c53ffb63..c7b518161c 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -390,6 +390,9 @@ using WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed = using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma>>; +using WarpGemmMfma_f32_16x16x32_bf8_fp8 = WarpGemmImpl< + WarpGemmAttributeMfma>>; + using WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed = WarpGemmImpl>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 138fcf230f..e47c868f45 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1680,6 +1680,9 @@ using WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_bf8 = template using WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8 = WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base; +template +using WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_fp8 = + WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base; template using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 = diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 71322cd62a..e93aad350e 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -167,6 +167,7 @@ template<> struct Dispatcher { using Ty template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_bf8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index 84c2273f2d..3916def2d7 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -33,8 +33,14 @@ struct GemmQuantPipelineProblemBase CDataType_, BlockGemmShape_, Traits_, - mixed_prec_compute_type_t, - mixed_prec_compute_type_t> + mixed_prec_compute_type_from_input_t< + ADataType_, + BDataType_, + mixed_prec_compute_type_t>, + mixed_prec_compute_type_from_input_t< + BDataType_, + ADataType_, + mixed_prec_compute_type_t>> { using Base = GemmPipelineProblemBase< ADataType_, @@ -42,8 +48,14 @@ struct GemmQuantPipelineProblemBase CDataType_, BlockGemmShape_, Traits_, - mixed_prec_compute_type_t, - mixed_prec_compute_type_t>; + mixed_prec_compute_type_from_input_t< + ADataType_, + BDataType_, + mixed_prec_compute_type_t>, + mixed_prec_compute_type_from_input_t< + BDataType_, + ADataType_, + mixed_prec_compute_type_t>>; using Traits = typename Base::Traits; diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp index e446f7b168..3b90b84e12 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp @@ -27,7 +27,9 @@ using KernelTypes_Tensor = ::testing::Types< std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False> + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>, + std::tuple< Row, Col, Row, FP8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False> >; // clang-format on