From 3ecd2a86895771a8468cb20d581f887022660132 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Fri, 24 Oct 2025 12:16:01 -0700 Subject: [PATCH] [CK-Tile][Async gemm] add missing sync and f8 inputs test cases (#3000) * add missing sync and f8 test cases * reformat test cases * comment failing cases * bump * reintroduce compv4 shapes [ROCm/composable_kernel commit: 86d542f663201d7923c56cd8e31d46e01c4dcfcf] --- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 2 + .../gemm/test_gemm_pipeline_kernel_types.hpp | 65 +++++++++++++++---- 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index fa7f9fc788..1d2a3e180b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -472,6 +472,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync >; -using KernelTypesCompV4 = ::testing::Types< - std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4> ->; +// clang-format on +template +using CompV4Config = std::tuple; + +using KernelTypesCompV4 = ::testing::Types, + CompV4Config, + CompV4Config, + CompV4Config, + CompV4Config, + CompV4Config, + CompV4Config, + CompV4Config>; + +template +using CompAsyncConfig = std::tuple; + +using KernelTypesCompAsync = ::testing::Types, + CompAsyncConfig, + CompAsyncConfig, + CompAsyncConfig, + CompAsyncConfig, + CompAsyncConfig, + CompAsyncConfig, + CompAsyncConfig>; +// clang-format off using KernelTypesCompV6 = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, @@ -153,12 +200,6 @@ using KernelTypesCompV6 = ::testing::Types< std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6> >; -using KernelTypesCompAsync = ::testing::Types< - std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync> ->; using KernelTypesCompV4Wmma = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,