[rocm-libraries] ROCm/rocm-libraries#7537 (commit 07123f4)

[CK Tile] Fix Grouped Gemm quant mixed precision (#7537)

<Migrate from Internal repo PR>
test_ck_tile_grouped_gemm_quant_tensor would fail for mixed FP8/BF8
cases:
std::tuple<Row, Col, Row, FP8, F32, BF8, F32, F32, F16, TensorQuant,
False, True, False>,
std::tuple<Row, Col, Row, BF8, F32, FP8, F32, F32, F16, TensorQuant,
False, True, False>

GFX1250 would fail with incorrect results, GFX950 would fail when
compiling BF8+FP8 and give incorrect results for FP8+BF8.
The issue is due to the wrong ComputeDataType selection.
The fix is to consider original ADataType and BDataType even when
ComputeDataType is not void. For compiling error on gfx950, the bf8,
fp8, 16x16x32 warp Gemm is added.
This commit is contained in:
JiaLuo-CAN
2026-05-21 11:36:23 -04:00
committed by GitHub
parent 309d823056
commit 5ff7497fa7
6 changed files with 34 additions and 10 deletions

View File

@@ -42,13 +42,16 @@ template <typename ComputeDataType, typename ADataType, typename BDataType>
using mixed_prec_compute_type_t =
typename detail::mixed_prec_compute_type<ComputeDataType, ADataType, BDataType>::type;
// Helper method to determine compute type, defaulting to input data type
// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed,
// ComputeDataType is used.
// Determines the compute type for a given input, preferring ThisDataType when possible.
// - If ThisDataType is packed: use OtherDataType, or ComputeDataType if both are packed.
// - If ThisDataType is smaller than a non-packed OtherDataType: use ComputeDataType.
// - Otherwise: use ThisDataType.
template <typename ThisDataType, typename OtherDataType, typename ComputeDataType>
using mixed_prec_compute_type_from_input_t = std::conditional_t<
is_packed_type_v<ThisDataType>,
std::conditional_t<is_packed_type_v<OtherDataType>, ComputeDataType, OtherDataType>,
ThisDataType>;
std::conditional_t<(sizeof(ThisDataType) < sizeof(OtherDataType)) &&
!is_packed_type_v<OtherDataType>,
ComputeDataType,
ThisDataType>>;
} // namespace ck_tile

View File

@@ -390,6 +390,9 @@ using WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed =
using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x32_bf8_fp8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;

View File

@@ -1680,6 +1680,9 @@ using WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_bf8 =
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8 =
WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<bf8_t, bf8_t, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_fp8 =
WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base<bf8_t, fp8_t, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 =

View File

@@ -167,6 +167,7 @@ template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, true> { using Ty
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
template<> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_bf8; };
template<> struct Dispatcher<bf8_t, fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_fp8; };
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_bf8; };
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };

View File

@@ -33,8 +33,14 @@ struct GemmQuantPipelineProblemBase
CDataType_,
BlockGemmShape_,
Traits_,
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>,
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>
mixed_prec_compute_type_from_input_t<
ADataType_,
BDataType_,
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>,
mixed_prec_compute_type_from_input_t<
BDataType_,
ADataType_,
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>>
{
using Base = GemmPipelineProblemBase<
ADataType_,
@@ -42,8 +48,14 @@ struct GemmQuantPipelineProblemBase
CDataType_,
BlockGemmShape_,
Traits_,
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>,
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>;
mixed_prec_compute_type_from_input_t<
ADataType_,
BDataType_,
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>,
mixed_prec_compute_type_from_input_t<
BDataType_,
ADataType_,
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>>;
using Traits = typename Base::Traits;

View File

@@ -27,7 +27,9 @@ using KernelTypes_Tensor = ::testing::Types<
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>,
std::tuple< Row, Col, Row, FP8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>,
std::tuple< Row, Col, Row, BF8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>
>;
// clang-format on