[rocm-libraries] ROCm/rocm-libraries#4964 (commit 3271d9a)

[CK Tile] Eight Waves pipeline GEMM

## Motivation

The eight-waves pipeline was added for ABQuant. The goal of this PR is to
enable it for GEMM as well.

## Technical Details

Summary:
 - Block:
   - Create a block struct for GEMM using eight-waves-specific distribution
     encodings
   - Use this block struct in ABQuant for encodings
 - Pipeline:
   - Create an impl pipeline for eight waves which can be used by GEMM and
     ABQuant as a base (and by AQuant and BQuant in the future)
   - Create an eight-waves pipeline for GEMM (this cannot be easily
     integrated into the existing async pipeline)
 - Pipeline policy:
   - Extract the GEMM-specific parts of the ABQuant policy to define a GEMM
     policy (ABQuant then uses it as a base and adds Quant-specific methods)
   - Minor: naming was inconsistent between warp/wave; everything is now
     referred to as eight waves

So overall we have:
- block struct directly used by GEMM -> ABQuant derived struct to
implement operator
- Impl base pipeline with general implementation -> GEMM and ABQuant
pipelines use it to avoid code duplication but still define their own
pipelines
- pipeline policy struct directly used by GEMM -> ABQuant derived policy
struct for Quant specific parts

## Test Plan

Added new tests for GEMM pipeline:
`test_ck_tile_gemm_pipeline_comp_async_eight_waves` (only gfx950
supports it).

Note: the K padding test is disabled for this pipeline because it is not
implemented yet.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Enrico Degregori
2026-03-16 08:31:56 +00:00
committed by assistant-librarian[bot]
parent b8108662da
commit eb033ef208
21 changed files with 1742 additions and 769 deletions

View File

@@ -7,9 +7,9 @@
#if defined(CK_USE_GFX950)
template <typename T, bool TransposeC = true>
using GemmConfig = GemmConfigEightWarps<T, TransposeC>;
using GemmConfig = GemmConfigEightWaves<T, TransposeC>;
template <typename T, bool TransposeC = true>
using GemmConfigPrefill = GemmConfigPreshuffleBEightWarps<T, TransposeC>;
using GemmConfigPrefill = GemmConfigPreshuffleBEightWaves<T, TransposeC>;
#else
template <typename T, bool TransposeC = true>
using GemmConfig = GemmConfigABQuantPrefill<T, TransposeC>;

View File

@@ -297,7 +297,7 @@ struct GemmConfigMixedPrecision : public GemmConfigBase
};
template <typename PrecType, bool TransposeC_ = true>
struct GemmConfigEightWarps : public GemmConfigABQuantPrefill<PrecType, TransposeC_>
struct GemmConfigEightWaves : public GemmConfigABQuantPrefill<PrecType, TransposeC_>
{
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
@@ -312,7 +312,7 @@ struct GemmConfigEightWarps : public GemmConfigABQuantPrefill<PrecType, Transpos
};
template <typename PrecType, bool TransposeC_ = true>
struct GemmConfigPreshuffleBEightWarps : public GemmConfigEightWarps<PrecType, TransposeC_>
struct GemmConfigPreshuffleBEightWaves : public GemmConfigEightWaves<PrecType, TransposeC_>
{
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;

View File

@@ -42,7 +42,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
(std::is_same_v<typename TypeConfig::BDataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::bf8_t>);
constexpr bool transpose_c = GemmConfig::TransposeC;
constexpr bool eight_warps =
constexpr bool eight_waves =
#ifdef CK_GFX950_SUPPORT
IS_FP8BLOCKSCALE && (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) &&
GemmConfig::K_Warp_Tile == 128;
@@ -85,7 +85,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
// Base pipeline selection based on quant mode and preshuffle settings
constexpr auto base_gemm_pipeline = []() {
if constexpr(eight_warps)
if constexpr(eight_waves)
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
else if constexpr(GemmConfig::PreshuffleB)
return ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>{};
@@ -184,8 +184,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
using ABQuantPipeline = std::conditional_t<
eight_warps,
ck_tile::ABQuantGemmPipelineAgBgCrEightWarps<PipelineProblem>,
eight_waves,
ck_tile::ABQuantGemmPipelineAgBgCrEightWaves<PipelineProblem>,
std::conditional_t<GemmConfig::DoubleSmemBuffer && GemmConfig::PreshuffleB,
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
@@ -256,7 +256,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
<< std::endl;
}
float ave_time = 0;
using k_attr_t = ck_tile::kernel_attr<eight_warps>;
using k_attr_t = ck_tile::kernel_attr<eight_waves>;
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;