diff --git a/include/ck_tile/core/utility/mixed_prec_compute_type.hpp b/include/ck_tile/core/utility/mixed_prec_compute_type.hpp index 021763c108..84ad45d277 100644 --- a/include/ck_tile/core/utility/mixed_prec_compute_type.hpp +++ b/include/ck_tile/core/utility/mixed_prec_compute_type.hpp @@ -42,13 +42,16 @@ template using mixed_prec_compute_type_t = typename detail::mixed_prec_compute_type::type; -// Helper method to determine compute type, defaulting to input data type -// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed, -// ComputeDataType is used. +// Determines the compute type for a given input, preferring ThisDataType when possible. +// - If ThisDataType is packed: use OtherDataType, or ComputeDataType if both are packed. +// - If ThisDataType is smaller than a non-packed OtherDataType: use ComputeDataType. +// - Otherwise: use ThisDataType. template using mixed_prec_compute_type_from_input_t = std::conditional_t< is_packed_type_v, std::conditional_t, ComputeDataType, OtherDataType>, - ThisDataType>; - + std::conditional_t<(sizeof(ThisDataType) < sizeof(OtherDataType)) && + !is_packed_type_v, + ComputeDataType, + ThisDataType>>; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 45c53ffb63..c7b518161c 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -390,6 +390,9 @@ using WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed = using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma>>; +using WarpGemmMfma_f32_16x16x32_bf8_fp8 = WarpGemmImpl< + WarpGemmAttributeMfma>>; + using WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed = WarpGemmImpl>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 138fcf230f..e47c868f45 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1680,6 +1680,9 @@ using WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_bf8 = template using WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8 = WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base; +template +using WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_fp8 = + WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base; template using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 = diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 71322cd62a..e93aad350e 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -167,6 +167,7 @@ template<> struct Dispatcher { using Ty template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_bf8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index 84c2273f2d..3916def2d7 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -33,8 +33,14 @@ struct GemmQuantPipelineProblemBase CDataType_, BlockGemmShape_, Traits_, - mixed_prec_compute_type_t, - mixed_prec_compute_type_t> + mixed_prec_compute_type_from_input_t< + ADataType_, + BDataType_, + mixed_prec_compute_type_t>, + mixed_prec_compute_type_from_input_t< + BDataType_, + ADataType_, + mixed_prec_compute_type_t>> { using Base = GemmPipelineProblemBase< ADataType_, @@ -42,8 +48,14 @@ struct GemmQuantPipelineProblemBase CDataType_, BlockGemmShape_, Traits_, - mixed_prec_compute_type_t, - mixed_prec_compute_type_t>; + mixed_prec_compute_type_from_input_t< + ADataType_, + BDataType_, + mixed_prec_compute_type_t>, + mixed_prec_compute_type_from_input_t< + BDataType_, + ADataType_, + mixed_prec_compute_type_t>>; using Traits = typename Base::Traits; diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp index e446f7b168..3b90b84e12 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp @@ -27,7 +27,9 @@ using KernelTypes_Tensor = ::testing::Types< std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False> + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>, + std::tuple< Row, Col, Row, FP8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False> >; // clang-format on