mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#5218 (commit 60156cf)
[CK] Fix the failure when aiter calls the eightwarps pipeline. (#5218) ## Motivation Fix the failure that occurs when aiter calls the eightwarps pipeline. Renamed the Async pipeline to EightWarps. ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan Pass ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
fe8b7d0c27
commit
8c216604d4
@@ -116,11 +116,7 @@ struct CShuffleEpilogue
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
static constexpr bool EightWave = (MWave * NWave == 8);
|
||||
#else
|
||||
static constexpr bool EightWave = false;
|
||||
#endif
|
||||
static constexpr bool EightWave = (MWave * NWave == 8);
|
||||
static constexpr index_t BlockedXDLN_PerWarp =
|
||||
EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
|
||||
@@ -6,14 +6,14 @@
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_async.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_eightwarps.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_async.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_async_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eightwarps.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eightwarps_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_async_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eightwarps_policy.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -22,7 +22,7 @@ namespace ck_tile {
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <typename Problem, typename Policy = GemmABQuantPipelineAgBgCrAsyncPolicy>
|
||||
struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
struct ABQuantGemmPipelineAgBgCrEightWarps : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmABQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
@@ -126,7 +126,7 @@ struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "ABQuantGemmPipelineAgBgCrAsync",
|
||||
return concat('_', "ABQuantGemmPipelineAgBgCrEightWarps",
|
||||
concat('x', MPerBlock, NPerBlock, KPerBlock),
|
||||
Problem::kBlockSize,
|
||||
concat('x', MWarps, NWarps),
|
||||
@@ -140,7 +140,7 @@ struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static std::string Print() { return "ABQuantGemmPipelineAgBgCrAsync\n"; }
|
||||
CK_TILE_HOST static std::string Print() { return "ABQuantGemmPipelineAgBgCrEightWarps\n"; }
|
||||
|
||||
static constexpr index_t A_LOAD_INST = MPerBlock * KPerBlock / BlockSize / GetVectorSizeA();
|
||||
static constexpr index_t B_LOAD_INST = NPerBlock * KPerBlock / BlockSize / GetVectorSizeB();
|
||||
Reference in New Issue
Block a user