mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
[rocm-libraries] ROCm/rocm-libraries#4964 (commit 3271d9a)
[CK Tile] Eight Waves pipeline GEMM ## Motivation The eight waves pipeline was added for ABQuant. The goal of this PR is to enable it for GEMM as well. ## Technical Details Summary: - Block: - Create a block struct for GEMM using eight-waves-specific distribution encodings - Use this block struct in ABQuant for encodings - Pipeline: - Create an impl pipeline for eight waves which can be used by GEMM and ABQuant as a base (and by AQuant and BQuant in the future) - Create an eight waves pipeline for GEMM (this cannot be easily integrated into the existing async pipeline) - Pipeline policy: - Extract the GEMM-specific parts of the ABQuant policy to define a GEMM policy (ABQuant then uses it as a base and adds Quant-specific methods) - Minor: naming was inconsistent between warp/wave; everything is now referred to as eight waves So overall we have: - a block struct directly used by GEMM -> an ABQuant derived struct to implement the operator - an impl base pipeline with the general implementation -> the GEMM and ABQuant pipelines use it to avoid code duplication but still define their own pipelines - a pipeline policy struct directly used by GEMM -> an ABQuant derived policy struct for the Quant-specific parts ## Test Plan Added new tests for the GEMM pipeline: `test_ck_tile_gemm_pipeline_comp_async_eight_waves` (only gfx950 supports it). Note: the K padding test is disabled for this pipeline because it is not implemented yet. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b8108662da
commit
eb033ef208
@@ -46,7 +46,8 @@ enum struct GemmPipelineType
|
||||
CompV3,
|
||||
CompV4,
|
||||
CompV6,
|
||||
CompAsync
|
||||
CompAsync,
|
||||
CompAsyncEightWaves
|
||||
};
|
||||
|
||||
template <GemmPipelineType PT, typename Problem>
|
||||
@@ -97,6 +98,15 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompAsync, Problem>
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsync"; }
|
||||
};
|
||||
|
||||
// Specialization of GemmPipelineTypeSelector for GemmPipelineType::CompAsyncEightWaves.
// For a given Problem it exposes:
//   - base_pipeline: ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>
//     (NOTE(review): this is the CompV3 base type, not an eight-waves-specific one;
//      presumably intentional -- confirm against the pipeline header)
//   - pipeline:      ck_tile::GemmPipelineAgBgCrCompAsyncEightWaves<Problem>
//   - GetName():     the pipeline's name as a string literal
template <typename Problem>
|
||||
struct GemmPipelineTypeSelector<GemmPipelineType::CompAsyncEightWaves, Problem>
|
||||
{
|
||||
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using pipeline = ck_tile::GemmPipelineAgBgCrCompAsyncEightWaves<Problem>;
|
||||
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsyncEightWaves"; }
|
||||
};
|
||||
|
||||
template <typename Tuple, typename Derived>
|
||||
class TestCkTileGemmPipeline : public ::testing::Test
|
||||
{
|
||||
@@ -129,7 +139,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
|
||||
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t M_Warp =
|
||||
PipelineType == GemmPipelineType::CompAsyncEightWaves ? 4 : 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
@@ -246,6 +257,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
GTEST_SKIP() << "Unsupported data type combination for gemm pipeline test.";
|
||||
}
|
||||
if constexpr(PipelineType == GemmPipelineType::CompV4 ||
|
||||
PipelineType == GemmPipelineType::CompAsyncEightWaves ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Only do k_batch = 1 when pipeline is CompV4 or CompAsyncEightWaves, or BDataType is I4
|
||||
|
||||
Reference in New Issue
Block a user