[rocm-libraries] ROCm/rocm-libraries#4964 (commit 3271d9a)

[CK Tile] Eight Waves pipeline GEMM

## Motivation

The eight-waves pipeline was originally added for ABQuant. The goal of this PR
is to enable it for GEMM as well.

## Technical Details

Summary:
 - Block:
- Create block struct for GEMM using eight warps specific distribution
encodings
   - Use this block struct in ABQuant for encodings
 - Pipeline:
- Create impl pipeline for eight waves which can be used by GEMM and
ABQuant as base (and for AQuant and BQuant in the future)
- Create eight waves pipeline for GEMM (this cannot be easily
integrated into the existing async pipeline)
 - Pipeline policy:
- Extract GEMM specific parts in the ABQuant policy to define GEMM
policy (then ABQuant use it as base and add Quant specific methods)
- Minor: naming was inconsistent between warp/wave; everything is now
referred to as "eight waves"

So overall we have:
- block struct directly used by GEMM -> ABQuant derived struct to
implement operator
- Impl base pipeline with general implementation -> GEMM and ABQuant
pipelines use it to avoid code duplication but still define their own
pipelines
- pipeline policy struct directly used by GEMM -> ABQuant derived policy
struct for Quant specific parts

## Test Plan

Added new tests for GEMM pipeline:
`test_ck_tile_gemm_pipeline_comp_async_eight_waves` (only gfx950
supports it).

Note: the K-padding test is disabled for this pipeline because K padding is not
implemented for it yet.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Enrico Degregori
2026-03-16 08:31:56 +00:00
committed by assistant-librarian[bot]
parent b8108662da
commit eb033ef208
21 changed files with 1742 additions and 769 deletions

View File

@@ -9,6 +9,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp"
namespace ck_tile {
@@ -183,71 +184,21 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase
using I0 = number<0>;
using I1 = number<1>;
using BlockGemmBase = BlockGemmARegBRegCRegEightWavesV1<Problem_, Policy_>;
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KWarp, KIterInterwave>,
sequence<KWarp, KIterPerWarp>>;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<2, NWarp / 2>,
tuple<sequence<MIterPerWarp, MWarp>, KIterSeq>,
tuple<sequence<0, 2, 1, 0>>,
tuple<sequence<0, 0, 1, 1>>,
sequence<1, 2>,
sequence<0, 1>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
return BlockGemmBase::MakeABlockDistributionEncode();
}
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KWarp, KIterInterwave>,
sequence<KWarp, KIterPerWarp>>;
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<2, NIterPerWarp, NWarp / 2>, KIterSeq>,
tuple<sequence<2, 1, 0, 1>>,
tuple<sequence<0, 0, 0, 2>>,
sequence</*1, 2*/>,
sequence</*0, 1*/>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
return BlockGemmBase::MakeBBlockDistributionEncode();
}
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<KWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<2, NIterPerWarp, NWarp / 2>>,
tuple<sequence<2, 0, 1, 2>>,
tuple<sequence<0, 0, 1, 2>>,
sequence<1, 2>,
sequence<0, 1>>{};
constexpr auto c_block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
return c_block_dstr_encoding;
return BlockGemmBase::MakeCBlockDistributionEncode();
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
@@ -256,14 +207,8 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase
make_static_tile_distribution(MakeCBlockDistributionEncode()));
}
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(
make_static_tile_distribution(MakeABlockDistributionEncode())));
using BLdsTile = statically_indexed_array<
statically_indexed_array<decltype(make_static_distributed_tensor<ComputeDataType>(
make_static_tile_distribution(
MakeBBlockDistributionEncode()))),
KIterPerWarp>,
NIterPerWarp>;
using ALdsTile = typename BlockGemmBase::ALdsTile;
using BLdsTiles = typename BlockGemmBase::BLdsTiles;
private:
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
@@ -291,7 +236,7 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase
template <typename CBlockTensor, typename AQBlockTensor, typename BQBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ALdsTile& a_warp_tile_,
const BLdsTile& b_warp_tile_,
const BLdsTiles& b_warp_tiles_,
AQBlockTensor& aq_block_tensor,
BQBlockTensor& bq_block_tensor)
{
@@ -328,7 +273,7 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_[nIter][kIter].get_thread_buffer();
b_warp_tiles_[nIter][kIter].get_thread_buffer();
if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);