Merge commit '86d542f663201d7923c56cd8e31d46e01c4dcfcf' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-24 20:13:08 +00:00
parent e7707c32d1
commit 4494721174
2 changed files with 55 additions and 12 deletions

View File

@@ -124,12 +124,59 @@ using KernelTypesCompV3Wmma = ::testing::Types<
std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>
>;
// Typed-test parameter list for the CompV4 pipeline tag: one tuple per Row/Col
// combination of the first two (A and B layout) slots, with the third (C layout)
// slot fixed to Row. Every tuple uses F16 for both input types, F32 then F16 for
// the next two type slots, and the same I256/I256/I32 + I32/I32/I16 size slots.
using KernelTypesCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>
>;
// clang-format on
// Test-parameter tuple for the CompV4 tag. Only the three layouts and the input
// element type vary; the same element type is used for both A and B, and every
// other slot is fixed.
template <typename LayoutA, typename LayoutB, typename LayoutC, typename DataType>
using CompV4Config = std::tuple<
    LayoutA, LayoutB, LayoutC, // ALayout, BLayout, CLayout
    DataType, DataType,        // AType, BType
    F32, F16,                  // AccType, OutputType
    I256, I256, I32,           // MBlockTileSize, NBlockTileSize, KBlockTileSize
    I32, I32, I16,             // MWarpTileSize, NWarpTileSize, KWarpTileSize
    Intrawave,
    CompV4>;
// CompV4 typed-test parameters: all four Row/Col combinations of the A and B
// layouts (C layout always Row), listed first for F16 inputs and then for F8
// inputs. The order is significant for typed-test indexing, so keep it stable.
using KernelTypesCompV4 = ::testing::Types<
    // F16 inputs
    CompV4Config<Row, Row, Row, F16>,
    CompV4Config<Row, Col, Row, F16>,
    CompV4Config<Col, Row, Row, F16>,
    CompV4Config<Col, Col, Row, F16>,
    // F8 inputs
    CompV4Config<Row, Row, Row, F8>,
    CompV4Config<Row, Col, Row, F8>,
    CompV4Config<Col, Row, Row, F8>,
    CompV4Config<Col, Col, Row, F8>>;
// Test-parameter tuple for the CompAsync tag. Only the three layouts and the
// input element type vary; the same element type is used for both A and B, and
// every other slot is fixed.
template <typename LayoutA, typename LayoutB, typename LayoutC, typename DataType>
using CompAsyncConfig = std::tuple<
    LayoutA, LayoutB, LayoutC, // ALayout, BLayout, CLayout
    DataType, DataType,        // AType, BType
    F32, F16,                  // AccType, OutputType
    I256, I256, I32,           // MBlockTileSize, NBlockTileSize, KBlockTileSize
    I32, I32, I16,             // MWarpTileSize, NWarpTileSize, KWarpTileSize
    Intrawave,
    CompAsync>;
// CompAsync typed-test parameters: all four Row/Col combinations of the A and B
// layouts (C layout always Row), listed first for F16 inputs and then for F8
// inputs. The order is significant for typed-test indexing, so keep it stable.
using KernelTypesCompAsync = ::testing::Types<
    // F16 inputs
    CompAsyncConfig<Row, Row, Row, F16>,
    CompAsyncConfig<Row, Col, Row, F16>,
    CompAsyncConfig<Col, Row, Row, F16>,
    CompAsyncConfig<Col, Col, Row, F16>,
    // F8 inputs
    CompAsyncConfig<Row, Row, Row, F8>,
    CompAsyncConfig<Row, Col, Row, F8>,
    CompAsyncConfig<Col, Row, Row, F8>,
    CompAsyncConfig<Col, Col, Row, F8>>;
// clang-format off
using KernelTypesCompV6 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>,
@@ -153,12 +200,6 @@ using KernelTypesCompV6 = ::testing::Types<
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>,
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>
>;
// Typed-test parameter list for the CompAsync pipeline tag: one tuple per Row/Col
// combination of the first two (A and B layout) slots, with the third (C layout)
// slot fixed to Row. Every tuple uses F16 for both input types, F32 then F16 for
// the next two type slots, and the same I256/I256/I32 + I32/I32/I16 size slots.
using KernelTypesCompAsync = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>
>;
using KernelTypesCompV4Wmma = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,