[CK-Tile][Async gemm] add missing sync and f8 inputs test cases (#3000)

* add missing sync and f8 test cases

* reformat test cases

* comment failing cases

* bump

* reintroduce compv4 shapes
This commit is contained in:
Max Podkorytov
2025-10-24 12:16:01 -07:00
committed by GitHub
parent 0584399571
commit 86d542f663
2 changed files with 55 additions and 12 deletions

View File

@@ -124,12 +124,59 @@ using KernelTypesCompV3Wmma = ::testing::Types<
std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>
>;
using KernelTypesCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>
>;
// clang-format on
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
using CompV4Config = std::tuple<ALayout,
BLayout,
CLayout,
InputType, // AType
InputType, // BType
F32, // AccType
F16, // OutputType
I256, // MBlockTileSize
I256, // NBlockTileSize
I32, // KBlockTileSize
I32, // MWarpTileSize
I32, // NWarpTileSize
I16, // KWarpTileSize
Intrawave,
CompV4>;
using KernelTypesCompV4 = ::testing::Types<CompV4Config<Row, Row, Row, F16>,
CompV4Config<Row, Col, Row, F16>,
CompV4Config<Col, Row, Row, F16>,
CompV4Config<Col, Col, Row, F16>,
CompV4Config<Row, Row, Row, F8>,
CompV4Config<Row, Col, Row, F8>,
CompV4Config<Col, Row, Row, F8>,
CompV4Config<Col, Col, Row, F8>>;
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
using CompAsyncConfig = std::tuple<ALayout,
BLayout,
CLayout,
InputType, // AType
InputType, // BType
F32, // AccType
F16, // OutputType
I256, // MBlockTileSize
I256, // NBlockTileSize
I32, // KBlockTileSize
I32, // MWarpTileSize
I32, // NWarpTileSize
I16, // KWarpTileSize
Intrawave,
CompAsync>;
using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16>,
CompAsyncConfig<Row, Col, Row, F16>,
CompAsyncConfig<Col, Row, Row, F16>,
CompAsyncConfig<Col, Col, Row, F16>,
CompAsyncConfig<Row, Row, Row, F8>,
CompAsyncConfig<Row, Col, Row, F8>,
CompAsyncConfig<Col, Row, Row, F8>,
CompAsyncConfig<Col, Col, Row, F8>>;
// clang-format off
using KernelTypesCompV6 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>,
@@ -153,12 +200,6 @@ using KernelTypesCompV6 = ::testing::Types<
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>,
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>
>;
using KernelTypesCompAsync = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>
>;
using KernelTypesCompV4Wmma = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,