mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#5218 (commit 60156cf)
[CK] Fix the issue of the aiter to call eightwarps pipeline. (#5218) ## Motivation Fix the failure of the aiter to call eightwarp. Changed Async to the name eightwarps. ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan Pass ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
fe8b7d0c27
commit
8c216604d4
@@ -10,7 +10,7 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_async_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eightwarps_policy.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -22,7 +22,7 @@ namespace ck_tile {
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <typename Problem, typename Policy = GemmABQuantPipelineAgBgCrAsyncPolicy>
|
||||
struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
struct ABQuantGemmPipelineAgBgCrEightWarps : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmABQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
@@ -126,7 +126,7 @@ struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "ABQuantGemmPipelineAgBgCrAsync",
|
||||
return concat('_', "ABQuantGemmPipelineAgBgCrEightWarps",
|
||||
concat('x', MPerBlock, NPerBlock, KPerBlock),
|
||||
Problem::kBlockSize,
|
||||
concat('x', MWarps, NWarps),
|
||||
@@ -140,7 +140,7 @@ struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static std::string Print() { return "ABQuantGemmPipelineAgBgCrAsync\n"; }
|
||||
CK_TILE_HOST static std::string Print() { return "ABQuantGemmPipelineAgBgCrEightWarps\n"; }
|
||||
|
||||
static constexpr index_t A_LOAD_INST = MPerBlock * KPerBlock / BlockSize / GetVectorSizeA();
|
||||
static constexpr index_t B_LOAD_INST = NPerBlock * KPerBlock / BlockSize / GetVectorSizeB();
|
||||
Reference in New Issue
Block a user