mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Merge commit '86d542f663201d7923c56cd8e31d46e01c4dcfcf' into develop
This commit is contained in:
@@ -472,6 +472,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
}
|
||||
{
|
||||
// write to LDS window(0) must complete before the local prefetch
|
||||
block_sync_lds_direct_load();
|
||||
// read A(num_loop), B(num_loop) from LDS window(0) to pipeline registers(0)
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
|
||||
@@ -124,12 +124,59 @@ using KernelTypesCompV3Wmma = ::testing::Types<
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>
|
||||
>;
|
||||
// clang-format on
|
||||
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
|
||||
using CompV4Config = std::tuple<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
InputType, // AType
|
||||
InputType, // BType
|
||||
F32, // AccType
|
||||
F16, // OutputType
|
||||
I256, // MBlockTileSize
|
||||
I256, // NBlockTileSize
|
||||
I32, // KBlockTileSize
|
||||
I32, // MWarpTileSize
|
||||
I32, // NWarpTileSize
|
||||
I16, // KWarpTileSize
|
||||
Intrawave,
|
||||
CompV4>;
|
||||
|
||||
using KernelTypesCompV4 = ::testing::Types<CompV4Config<Row, Row, Row, F16>,
|
||||
CompV4Config<Row, Col, Row, F16>,
|
||||
CompV4Config<Col, Row, Row, F16>,
|
||||
CompV4Config<Col, Col, Row, F16>,
|
||||
CompV4Config<Row, Row, Row, F8>,
|
||||
CompV4Config<Row, Col, Row, F8>,
|
||||
CompV4Config<Col, Row, Row, F8>,
|
||||
CompV4Config<Col, Col, Row, F8>>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
|
||||
using CompAsyncConfig = std::tuple<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
InputType, // AType
|
||||
InputType, // BType
|
||||
F32, // AccType
|
||||
F16, // OutputType
|
||||
I256, // MBlockTileSize
|
||||
I256, // NBlockTileSize
|
||||
I32, // KBlockTileSize
|
||||
I32, // MWarpTileSize
|
||||
I32, // NWarpTileSize
|
||||
I16, // KWarpTileSize
|
||||
Intrawave,
|
||||
CompAsync>;
|
||||
|
||||
using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16>,
|
||||
CompAsyncConfig<Row, Col, Row, F16>,
|
||||
CompAsyncConfig<Col, Row, Row, F16>,
|
||||
CompAsyncConfig<Col, Col, Row, F16>,
|
||||
CompAsyncConfig<Row, Row, Row, F8>,
|
||||
CompAsyncConfig<Row, Col, Row, F8>,
|
||||
CompAsyncConfig<Col, Row, Row, F8>,
|
||||
CompAsyncConfig<Col, Col, Row, F8>>;
|
||||
// clang-format off
|
||||
|
||||
using KernelTypesCompV6 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>,
|
||||
@@ -153,12 +200,6 @@ using KernelTypesCompV6 = ::testing::Types<
|
||||
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>,
|
||||
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>
|
||||
>;
|
||||
using KernelTypesCompAsync = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV4Wmma = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
|
||||
|
||||
Reference in New Issue
Block a user