[rocm-libraries] ROCm/rocm-libraries#4964 (commit 3271d9a)

[CK Tile] Eight Waves pipeline GEMM

## Motivation

Eight waves pipeline was added for ABQuant. The goal of this PR is to
enable it also for GEMM

## Technical Details

Summary:
 - Block:
- Create block struct for GEMM using eight warps specific distribution
encodings
   - Use this block struct in ABQuant for encodings
 - Pipeline:
- Create impl pipeline for eight waves which can be used by GEMM and
ABQuant as base (and for AQuant and BQuant in the future)
- Create eight waves pipeline for GEMM (this can not be easily
integrated in the existing async pipeline)
 - Pipeline policy:
- Extract GEMM specific parts in the ABQuant policy to define GEMM
policy (then ABQuant use it as base and add Quant specific methods)
- Minor: naming was inconsistent between warp/wave, everything is now
referred to as eight waves

So overall we have:
- block struct directly used by GEMM -> ABQuant derived struct to
implement operator
- Impl base pipeline with general implementation -> GEMM and ABQuant
pipelines use it to avoid code duplication but still define their own
pipelines
- pipeline policy struct directly used by GEMM -> ABQuant derived policy
struct for Quant specific parts

## Test Plan

Added new tests for GEMM pipeline:
`test_ck_tile_gemm_pipeline_comp_async_eight_waves` (only gfx950
supports it).

Note: K padding test is disabled for this pipeline because it's not
implemented yet

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Enrico Degregori
2026-03-16 08:31:56 +00:00
committed by assistant-librarian[bot]
parent b8108662da
commit eb033ef208
21 changed files with 1742 additions and 769 deletions

View File

@@ -46,7 +46,8 @@ enum struct GemmPipelineType
CompV3,
CompV4,
CompV6,
CompAsync
CompAsync,
CompAsyncEightWaves
};
template <GemmPipelineType PT, typename Problem>
@@ -97,6 +98,15 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompAsync, Problem>
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsync"; }
};
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompAsyncEightWaves, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrCompAsyncEightWaves<Problem>;
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsyncEightWaves"; }
};
template <typename Tuple, typename Derived>
class TestCkTileGemmPipeline : public ::testing::Test
{
@@ -129,7 +139,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t M_Warp =
PipelineType == GemmPipelineType::CompAsyncEightWaves ? 4 : 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
@@ -246,6 +257,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
GTEST_SKIP() << "Unsupported data type combination for gemm pipeline test.";
}
if constexpr(PipelineType == GemmPipelineType::CompV4 ||
PipelineType == GemmPipelineType::CompAsyncEightWaves ||
std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Only do k_batch = 1 when pipeline is CompV4, or BDataType is I4