From e01c295551bc77852e92be8b3cc4afde66c251c6 Mon Sep 17 00:00:00 2001 From: SamiAario-AMD Date: Mon, 26 Jan 2026 20:23:26 +0200 Subject: [PATCH] Re enable f8 x bf8 tests on compv3 and compv4 (#3605) * Re-enable f8 x bf8 tests on CompV3 as they now pass * On CompV4, fp8 x bf8 tests now pass with K_BlockSize I32 * Add a changelog entry --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> [ROCm/composable_kernel commit: 834642202c0cb39df1b96dacc24d5c3b3d97e62c] --- CHANGELOG.md | 1 + test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp | 9 ++------- test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp | 8 ++++---- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f17a4d768..c99fc1d065 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added +* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4 * Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with with C++ and Python frontends starting with GEMM support. * Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. * Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM. diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp index ebe17aadd6..016f7be60d 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp @@ -13,13 +13,8 @@ class TestCkTileGemmPipelineCompV3 static constexpr bool check_data_type() { using Base = TestCkTileGemmPipeline>; - if constexpr(std::is_same_v && - std::is_same_v) - { - return false; - } - else if constexpr(std::is_same_v && - std::is_same_v) + if constexpr(std::is_same_v && + std::is_same_v) { return false; } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 334e360eb5..4bef581254 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -170,7 +170,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -180,7 +180,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Row, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -190,7 +190,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -200,7 +200,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>