mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4964 (commit 3271d9a)
[CK Tile] Eight Waves pipeline GEMM ## Motivation Eight waves pipeline was added for ABQuant. The goal of this PR is to enable it also for GEMM ## Technical Details Summary: - Block: - Create block struct for GEMM using eight warps specific distribution encodings - Use this block struct in ABQuant for encodings - Pipeline: - Create impl pipeline for eight waves which can be used by GEMM and ABQuant as base (and for AQuant and BQuant in the future) - Create eight waves pipeline for GEMM (this can not be easily integrated in the existing async pipeline) - Pipeline policy: - Extract GEMM specific parts in the ABQuant policy to define GEMM policy (then ABQuant use it as base and add Quant specific methods) - Minor: naming was inconsistent between warp/wave, everything is now referred to as eight waves So overall we have: - block struct directly used by GEMM -> ABQuant derived struct to implement operator - Impl base pipeline with general implementation -> GEMM and ABQuant pipelines use it to avoid code duplication but still define their own pipelines - pipeline policy struct directly used by GEMM -> ABQuant derived policy struct for Quant specific parts ## Test Plan Added new tests for GEMM pipeline: `test_ck_tile_gemm_pipeline_comp_async_eight_waves` (only gfx950 supports it). Note: K padding test is disabled for this pipeline because it's not implemented yet ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b8108662da
commit
eb033ef208
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -183,71 +184,21 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
using BlockGemmBase = BlockGemmARegBRegCRegEightWavesV1<Problem_, Policy_>;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
|
||||
constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
|
||||
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KWarp, KIterInterwave>,
|
||||
sequence<KWarp, KIterPerWarp>>;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<2, NWarp / 2>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, KIterSeq>,
|
||||
tuple<sequence<0, 2, 1, 0>>,
|
||||
tuple<sequence<0, 0, 1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
return BlockGemmBase::MakeABlockDistributionEncode();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KWarp, KIterInterwave>,
|
||||
sequence<KWarp, KIterPerWarp>>;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<2, NIterPerWarp, NWarp / 2>, KIterSeq>,
|
||||
tuple<sequence<2, 1, 0, 1>>,
|
||||
tuple<sequence<0, 0, 0, 2>>,
|
||||
sequence</*1, 2*/>,
|
||||
sequence</*0, 1*/>>{};
|
||||
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
return BlockGemmBase::MakeBBlockDistributionEncode();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<KWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<2, NIterPerWarp, NWarp / 2>>,
|
||||
tuple<sequence<2, 0, 1, 2>>,
|
||||
tuple<sequence<0, 0, 1, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{};
|
||||
constexpr auto c_block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
return c_block_dstr_encoding;
|
||||
return BlockGemmBase::MakeCBlockDistributionEncode();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
@@ -256,14 +207,8 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase
|
||||
make_static_tile_distribution(MakeCBlockDistributionEncode()));
|
||||
}
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(
|
||||
make_static_tile_distribution(MakeABlockDistributionEncode())));
|
||||
using BLdsTile = statically_indexed_array<
|
||||
statically_indexed_array<decltype(make_static_distributed_tensor<ComputeDataType>(
|
||||
make_static_tile_distribution(
|
||||
MakeBBlockDistributionEncode()))),
|
||||
KIterPerWarp>,
|
||||
NIterPerWarp>;
|
||||
using ALdsTile = typename BlockGemmBase::ALdsTile;
|
||||
using BLdsTiles = typename BlockGemmBase::BLdsTiles;
|
||||
|
||||
private:
|
||||
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
|
||||
@@ -291,7 +236,7 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase
|
||||
template <typename CBlockTensor, typename AQBlockTensor, typename BQBlockTensor>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ALdsTile& a_warp_tile_,
|
||||
const BLdsTile& b_warp_tile_,
|
||||
const BLdsTiles& b_warp_tiles_,
|
||||
AQBlockTensor& aq_block_tensor,
|
||||
BQBlockTensor& bq_block_tensor)
|
||||
{
|
||||
@@ -328,7 +273,7 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() =
|
||||
b_warp_tile_[nIter][kIter].get_thread_buffer();
|
||||
b_warp_tiles_[nIter][kIter].get_thread_buffer();
|
||||
if constexpr(kIterInQScale == 0)
|
||||
{
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
Reference in New Issue
Block a user