diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index fa7f9fc788..1d2a3e180b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -472,6 +472,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync >; -using KernelTypesCompV4 = ::testing::Types< - std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4> ->; +// clang-format on +template +using CompV4Config = std::tuple; + +using KernelTypesCompV4 = ::testing::Types, + CompV4Config, + CompV4Config, + CompV4Config, + CompV4Config, + CompV4Config, + CompV4Config, + CompV4Config>; + +template +using CompAsyncConfig = std::tuple; + +using KernelTypesCompAsync = ::testing::Types, + CompAsyncConfig, + CompAsyncConfig, + CompAsyncConfig, + CompAsyncConfig, + CompAsyncConfig, + CompAsyncConfig, + CompAsyncConfig>; +// clang-format off using KernelTypesCompV6 = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, @@ -153,12 +200,6 @@ using KernelTypesCompV6 = ::testing::Types< std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6> >; -using KernelTypesCompAsync = ::testing::Types< - std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync> ->; using KernelTypesCompV4Wmma = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,