[CK] Fix the issue of the aiter to call eightwarps pipeline. (#5218)

## Motivation

Fix the failure of the aiter to call eightwarp.
Changed Async to the name eightwarps.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

Pass

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
kensclin
2026-03-10 02:12:13 +08:00
committed by GitHub
parent 72bb61d91a
commit 9f038ac7bc
7 changed files with 10 additions and 14 deletions

View File

@@ -185,7 +185,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
using ABQuantPipeline = std::conditional_t<
eight_warps,
ck_tile::ABQuantGemmPipelineAgBgCrAsync<PipelineProblem>,
ck_tile::ABQuantGemmPipelineAgBgCrEightWarps<PipelineProblem>,
std::conditional_t<GemmConfig::DoubleSmemBuffer && GemmConfig::PreshuffleB,
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;

View File

@@ -116,11 +116,7 @@ struct CShuffleEpilogue
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
#if defined(CK_GFX950_SUPPORT)
static constexpr bool EightWave = (MWave * NWave == 8);
#else
static constexpr bool EightWave = false;
#endif
static constexpr bool EightWave = (MWave * NWave == 8);
static constexpr index_t BlockedXDLN_PerWarp =
EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;

View File

@@ -6,14 +6,14 @@
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp"
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp"
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp"
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_async.hpp"
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_eightwarps.hpp"
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_async.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_async_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eightwarps.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eightwarps_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"

View File

@@ -10,7 +10,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_async_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eightwarps_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
@@ -22,7 +22,7 @@ namespace ck_tile {
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmABQuantPipelineAgBgCrAsyncPolicy>
struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3<Problem>
struct ABQuantGemmPipelineAgBgCrEightWarps : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmABQuantPipelineAgBgCrImplBase<Problem, Policy>;
@@ -126,7 +126,7 @@ struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3<Prob
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "ABQuantGemmPipelineAgBgCrAsync",
return concat('_', "ABQuantGemmPipelineAgBgCrEightWarps",
concat('x', MPerBlock, NPerBlock, KPerBlock),
Problem::kBlockSize,
concat('x', MWarps, NWarps),
@@ -140,7 +140,7 @@ struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3<Prob
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST static std::string Print() { return "ABQuantGemmPipelineAgBgCrAsync\n"; }
CK_TILE_HOST static std::string Print() { return "ABQuantGemmPipelineAgBgCrEightWarps\n"; }
static constexpr index_t A_LOAD_INST = MPerBlock * KPerBlock / BlockSize / GetVectorSizeA();
static constexpr index_t B_LOAD_INST = NPerBlock * KPerBlock / BlockSize / GetVectorSizeB();

View File

@@ -1276,7 +1276,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
using GemmPipeline = std::conditional_t<
eight_warps,
ck_tile::ABQuantGemmPipelineAgBgCrAsync<PipelineProblem>,
ck_tile::ABQuantGemmPipelineAgBgCrEightWarps<PipelineProblem>,
std::conditional_t<PreshuffleB,
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;