mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
[rocm-libraries] ROCm/rocm-libraries#8554 (commit be9af54)
refactor(ck): mx gemm kernel unification ## Motivation CK tile currently has two separate MX GEMM kernels for gfx950 and gfx1250. This pull request refactors and modernizes the MX GEMM kernel and example to use new scale tensor handling, improved kernel argument structures, and updated pipeline and kernel APIs. The changes simplify the interface and improve type safety. JIRA ID ROCM-26313 ## Technical Details - Add support for gfx950 in MX GEMM kernel for gfx1250 and remove unused kernel - Unify comp async pipeline for GEMM and MX GEMM - Unify eight waves pipeline for GEMM and MX GEMM - Move preshuffle MX GEMM pipeline to gemm ops and remove gemm_mx ops - Unify testing framework for MX GEMM - Add gfx950 tests for grouped MX GEMM ## Test Plan - `test_mx_gemm_async.cpp` for MX GEMM on gfx950 - `test_mx_grouped_gemm_comp_async.cpp` for grouped MX GEMM on gfx950 ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
604c56bc0e
commit
d559ec00a8
@@ -115,31 +115,6 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of QuantGroupSize::kK!");
|
||||
static_assert(Problem::TransposeC, "Wrong!");
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<AComputeDataType,
|
||||
BComputeDataType,
|
||||
CDataType,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
WGAccessDouble>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return ABQuantBlockUniversalGemmAsBsCrAsync<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
@@ -165,6 +140,47 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy : public GemmPipelineAgBgCrCompAsync
|
||||
FORWARD_METHOD_(GetInstCountBQ);
|
||||
|
||||
#undef FORWARD_METHOD_
|
||||
|
||||
template <typename Problem, bool IsPackMNIter = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of QuantGroupSize::kK!");
|
||||
static_assert(Problem::TransposeC, "Wrong!");
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AComputeDataType = remove_cvref_t<typename Problem::AComputeDataType>;
|
||||
using BComputeDataType = remove_cvref_t<typename Problem::BComputeDataType>;
|
||||
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
|
||||
constexpr index_t WarpTileM = WarpTile::at(I0);
|
||||
constexpr index_t WarpTileN = WarpTile::at(I1);
|
||||
constexpr index_t WarpTileK = WarpTile::at(I2);
|
||||
|
||||
constexpr auto WGAccessDouble = WGAttrNumAccessEnum::Double;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<AComputeDataType,
|
||||
BComputeDataType,
|
||||
CDataType,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
WGAccessDouble>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return ABQuantBlockUniversalGemmAsBsCrAsync<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user