From 3a792017fbfd4e11468e85f12e98c286159ee3b0 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 12 Nov 2025 15:09:01 +0000 Subject: [PATCH] Add functionality and tests for fp16 x fp8 and fp8 x fp16 --- .../ops/common/determine_warp_prec_type.hpp | 14 ++++++++++++++ .../gemm/test_gemm_pipeline_kernel_types.hpp | 8 ++++++++ 2 files changed, 22 insertions(+) diff --git a/include/ck_tile/ops/common/determine_warp_prec_type.hpp b/include/ck_tile/ops/common/determine_warp_prec_type.hpp index 8d30f60bdb..13094f591c 100644 --- a/include/ck_tile/ops/common/determine_warp_prec_type.hpp +++ b/include/ck_tile/ops/common/determine_warp_prec_type.hpp @@ -52,4 +52,18 @@ struct DetermineWarpPrecType { using prec_type = float; }; + +// For fp8 x fp16 or fp16 x fp8, convert fp8 to float +template <> +struct DetermineWarpPrecType +{ + using prec_type = float; +}; + +// For fp8 x fp16 or fp16 x fp8, convert fp16 to float +template <> +struct DetermineWarpPrecType +{ + using prec_type = float; +}; }; // namespace ck_tile diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 3bfcd9e2c0..70449b5adf 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -79,48 +79,56 @@ using KernelTypesMemWmma = ::testing::Types< using KernelTypesCompV3 = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,