mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4964 (commit 3271d9a)
[CK Tile] Eight Waves pipeline GEMM

## Motivation

The eight waves pipeline was added for ABQuant. The goal of this PR is to enable it for GEMM as well.

## Technical Details

Summary:
- Block:
  - Create a block struct for GEMM using eight-waves-specific distribution encodings.
  - Use this block struct in ABQuant for encodings.
- Pipeline:
  - Create an impl pipeline for eight waves which can be used by GEMM and ABQuant as a base (and for AQuant and BQuant in the future).
  - Create an eight waves pipeline for GEMM (this cannot be easily integrated into the existing async pipeline).
- Pipeline policy:
  - Extract the GEMM-specific parts of the ABQuant policy to define a GEMM policy (ABQuant then uses it as a base and adds Quant-specific methods).
- Minor: naming was inconsistent between warp/wave; everything is now referred to as eight waves.

So overall we have:
- A block struct directly used by GEMM -> an ABQuant derived struct to implement the operator.
- An impl base pipeline with the general implementation -> the GEMM and ABQuant pipelines use it to avoid code duplication but still define their own pipelines.
- A pipeline policy struct directly used by GEMM -> an ABQuant derived policy struct for the Quant-specific parts.

## Test Plan

Added new tests for the GEMM pipeline: `test_ck_tile_gemm_pipeline_comp_async_eight_waves` (only gfx950 supports it). Note: the K padding test is disabled for this pipeline because K padding is not implemented yet.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b8108662da
commit
eb033ef208
@@ -7,9 +7,9 @@
|
||||
|
||||
#if defined(CK_USE_GFX950)
|
||||
template <typename T, bool TransposeC = true>
|
||||
using GemmConfig = GemmConfigEightWarps<T, TransposeC>;
|
||||
using GemmConfig = GemmConfigEightWaves<T, TransposeC>;
|
||||
template <typename T, bool TransposeC = true>
|
||||
using GemmConfigPrefill = GemmConfigPreshuffleBEightWarps<T, TransposeC>;
|
||||
using GemmConfigPrefill = GemmConfigPreshuffleBEightWaves<T, TransposeC>;
|
||||
#else
|
||||
template <typename T, bool TransposeC = true>
|
||||
using GemmConfig = GemmConfigABQuantPrefill<T, TransposeC>;
|
||||
|
||||
@@ -297,7 +297,7 @@ struct GemmConfigMixedPrecision : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType, bool TransposeC_ = true>
|
||||
struct GemmConfigEightWarps : public GemmConfigABQuantPrefill<PrecType, TransposeC_>
|
||||
struct GemmConfigEightWaves : public GemmConfigABQuantPrefill<PrecType, TransposeC_>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
|
||||
@@ -312,7 +312,7 @@ struct GemmConfigEightWarps : public GemmConfigABQuantPrefill<PrecType, Transpos
|
||||
};
|
||||
|
||||
template <typename PrecType, bool TransposeC_ = true>
|
||||
struct GemmConfigPreshuffleBEightWarps : public GemmConfigEightWarps<PrecType, TransposeC_>
|
||||
struct GemmConfigPreshuffleBEightWaves : public GemmConfigEightWaves<PrecType, TransposeC_>
|
||||
{
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
|
||||
@@ -42,7 +42,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
(std::is_same_v<typename TypeConfig::BDataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::bf8_t>);
|
||||
constexpr bool transpose_c = GemmConfig::TransposeC;
|
||||
constexpr bool eight_warps =
|
||||
constexpr bool eight_waves =
|
||||
#ifdef CK_GFX950_SUPPORT
|
||||
IS_FP8BLOCKSCALE && (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) &&
|
||||
GemmConfig::K_Warp_Tile == 128;
|
||||
@@ -85,7 +85,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
|
||||
// Base pipeline selection based on quant mode and preshuffle settings
|
||||
constexpr auto base_gemm_pipeline = []() {
|
||||
if constexpr(eight_warps)
|
||||
if constexpr(eight_waves)
|
||||
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
|
||||
else if constexpr(GemmConfig::PreshuffleB)
|
||||
return ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>{};
|
||||
@@ -184,8 +184,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
|
||||
|
||||
using ABQuantPipeline = std::conditional_t<
|
||||
eight_warps,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrEightWarps<PipelineProblem>,
|
||||
eight_waves,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrEightWaves<PipelineProblem>,
|
||||
std::conditional_t<GemmConfig::DoubleSmemBuffer && GemmConfig::PreshuffleB,
|
||||
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
|
||||
@@ -256,7 +256,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
<< std::endl;
|
||||
}
|
||||
float ave_time = 0;
|
||||
using k_attr_t = ck_tile::kernel_attr<eight_warps>;
|
||||
using k_attr_t = ck_tile::kernel_attr<eight_waves>;
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user