[rocm-libraries] ROCm/rocm-libraries#5218 (commit 60156cf)

[CK] Fix the issue of the aiter to call eightwarps pipeline.
 (#5218)

## Motivation

Fix the failure of the aiter to call eightwarp.
Changed Async to the name eightwarps.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

Pass

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
kensclin
2026-03-09 18:13:07 +00:00
committed by assistant-librarian[bot]
parent fe8b7d0c27
commit 8c216604d4
7 changed files with 10 additions and 14 deletions

View File

@@ -116,11 +116,7 @@ struct CShuffleEpilogue
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
#if defined(CK_GFX950_SUPPORT)
static constexpr bool EightWave = (MWave * NWave == 8);
#else
static constexpr bool EightWave = false;
#endif
static constexpr bool EightWave = (MWave * NWave == 8);
static constexpr index_t BlockedXDLN_PerWarp =
EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;

View File

@@ -6,14 +6,14 @@
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp"
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp"
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp"
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_async.hpp"
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_eightwarps.hpp"
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_async.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_async_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eightwarps.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eightwarps_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"

View File

@@ -10,7 +10,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_async_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eightwarps_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
@@ -22,7 +22,7 @@ namespace ck_tile {
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmABQuantPipelineAgBgCrAsyncPolicy>
struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3<Problem>
struct ABQuantGemmPipelineAgBgCrEightWarps : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmABQuantPipelineAgBgCrImplBase<Problem, Policy>;
@@ -126,7 +126,7 @@ struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3<Prob
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "ABQuantGemmPipelineAgBgCrAsync",
return concat('_', "ABQuantGemmPipelineAgBgCrEightWarps",
concat('x', MPerBlock, NPerBlock, KPerBlock),
Problem::kBlockSize,
concat('x', MWarps, NWarps),
@@ -140,7 +140,7 @@ struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3<Prob
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST static std::string Print() { return "ABQuantGemmPipelineAgBgCrAsync\n"; }
CK_TILE_HOST static std::string Print() { return "ABQuantGemmPipelineAgBgCrEightWarps\n"; }
static constexpr index_t A_LOAD_INST = MPerBlock * KPerBlock / BlockSize / GetVectorSizeA();
static constexpr index_t B_LOAD_INST = NPerBlock * KPerBlock / BlockSize / GetVectorSizeB();