From f4b880d058d915ad896554fac80fc0e5f6d35d67 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sun, 2 Nov 2025 00:06:28 -0400 Subject: [PATCH] refactor: remove gemm preshuffle pipeline v1 by removing all references from codebase (#3132) * test: temporarily disable flaky test_ck_tile_moe_sorting_2d_buf * refactor: deprecate gemm preshuffle pipeline v1 by removing all references from codebase * Revert "test: temporarily disable flaky test_ck_tile_moe_sorting_2d_buf" This reverts commit 573c08a085f3695833a63ae766493f7ac4cd958d. [ROCm/composable_kernel commit: 73f637894da54ac2014d3f7be675f1bf75a689c1] --- example/ck_tile/03_gemm/gemm_utils.hpp | 13 +- include/ck_tile/ops/gemm.hpp | 1 - .../wp_pipeline_agmem_bgmem_creg_v1.hpp | 518 ------------------ .../test_gemm_pipeline_kernel_types.hpp | 11 +- .../test_gemm_pipeline_util.hpp | 10 - .../gemm_preshuffle_common.hpp | 8 +- .../gemm_preshuffle_instance_builder.py | 2 - 7 files changed, 5 insertions(+), 558 deletions(-) delete mode 100644 include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index a831a4f26c..dbed40800e 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -17,8 +17,7 @@ #define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V5 4 #define CK_TILE_PIPELINE_COMPUTE_V6 5 -#define CK_TILE_PIPELINE_PRESHUFFLE_V1 6 -#define CK_TILE_PIPELINE_PRESHUFFLE_V2 7 +#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6 template constexpr ck_tile::index_t get_k_warp_tile() @@ -514,16 +513,6 @@ struct PipelineTypeTraits using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6; }; -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; - template - using UniversalGemmPipeline = - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; -}; - template <> struct PipelineTypeTraits { diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index e1026485d7..33be18948b 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -59,7 +59,6 @@ #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp deleted file mode 100644 index 7095b4bd23..0000000000 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp +++ /dev/null @@ -1,518 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" -#include "ck_tile/host/concat.hpp" -#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" - -namespace ck_tile { - -template -struct BaseWeightPreshufflePipelineAGmemBGmemCRegV1 -{ - static constexpr index_t PrefetchStages = 1; - static constexpr index_t PrefillStages = 1; - static constexpr index_t GlobalBufferNum = 1; - static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; - - CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } - - CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; } - - CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t) - { - return TailNumber::Empty; - } - - template - CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber) - { - return run_func(bool_constant{}, integral_constant{}); - } -}; - -template -struct WeightPreshufflePipelineAGmemBGmemCRegV1 - : public BaseWeightPreshufflePipelineAGmemBGmemCRegV1 -{ - using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV1; - using AsDataType = remove_cvref_t; - using BsDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - - using AElementWise = remove_cvref_t; - using BElementWise = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - using AsLayout = remove_cvref_t; - using BsLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - - using ALayout = remove_cvref_t>; - using BLayout = remove_cvref_t>; - - using ADataType = remove_cvref_t>; - using BDataType = remove_cvref_t>; - - using BlockWeightPreshuffle = - remove_cvref_t())>; - - static constexpr index_t BlockSize = Problem::kBlockSize; - - static constexpr index_t kMPerBlock = BlockGemmShape::kM; - static constexpr index_t kNPerBlock = BlockGemmShape::kN; - static constexpr index_t kKPerBlock = BlockGemmShape::kK; - - static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; - static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; - - template - static constexpr index_t GetVectorSizeA() - { - return PipelinePolicy::template GetVectorSizeA(); - } - template - static constexpr index_t GetVectorSizeB() - { - return PipelinePolicy::template GetVectorSizeB(); - } - - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; - static constexpr bool kPadK = Problem::kPadK; - - static constexpr index_t kLdsAlignmentInBytes = 16; - static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - - static constexpr auto I0 = number<0>(); - static constexpr auto I1 = number<1>(); - static constexpr auto I2 = number<2>(); - - using BlockTile = remove_cvref_t; - using BlockWarps = remove_cvref_t; - using WarpTile = remove_cvref_t; - - static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - static constexpr bool Preshuffle = Problem::Preshuffle; - using Base::UsePersistentKernel; - - [[nodiscard]] CK_TILE_HOST static const std::string GetName() - { - // clang-format off - return concat('_', "pipeline_AGmemBGmemCRegV1", - concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize), - concat('x', GetVectorSizeA(), GetVectorSizeB()), - concat('x', kPadM, kPadN, kPadK)); - // clang-format on - } - - CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } - - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - return PipelinePolicy::template GetSmemSize(); - } - - CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() - { - constexpr auto config = - BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); - - constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType); - constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad; - constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp; - constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp; - - if constexpr(WG::kM == 16 && WG::kN == 16) - { - static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA - }); - static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA - }); - static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA - }); - } - else if constexpr(WG::kM == 32 && WG::kN == 32 && - (A_LDS_Read_Inst_Num / 2 > - A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) - { - static_for<0, - A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num, - 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA - }); - __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA - } - } - - template ::value && - !is_detected::value, - bool>* = nullptr, - index_t UnaryOpSize_ = 8> - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - index_t num_loop, - void* p_smem) const - { - static_assert( - std::is_same_v> && - std::is_same_v>, - "A/B Dram block window should have the same data type as appropriate " - "([A|B]DataType) defined in Problem definition!"); - - constexpr bool is_a_col_major = std::is_same_v; - - static_assert(is_a_col_major - ? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] && - kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]) - : (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]), - "A block window has incorrect lengths for defined ALayout!"); - - constexpr auto config = - BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - - constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; - constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; - - constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; - constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; - - const index_t iMWarp = get_warp_id() / NWarp; - - // A tile in LDS - ADataType* p_a_lds = static_cast(p_smem); - - constexpr auto a_lds_block_desc = - PipelinePolicy::template MakeALdsBlockDescriptor(); - - auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); - - // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeADramTileDistribution()); - - // A LDS tile window for store - auto a_copy_lds_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // A LDS tile for block GEMM - auto a_lds_gemm_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); - - auto a_warp_window_tmp = make_tile_window( - a_lds_gemm_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_lds_gemm_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_windows; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); - - // Block GEMM - auto block_flatmm = BlockWeightPreshuffle(); - - // B flat DRAM window for load - auto b_flat_distribution = - PipelinePolicy::template MakeBFlatDramTileDistribution(); - auto b_flat_dram_window = - make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_flat_dram_block_window_tmp.get_window_origin(), - b_flat_distribution); - - // Acc register tile - auto c_block_tile = block_flatmm.MakeCBlockTile(); - - // prefetch - // global read 0 - auto a_block_tile = load_tile(a_copy_dram_window); - - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_flat_dram_windows; - - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; - using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); - - statically_indexed_array, NIterPerWarp> - b_warp_tensor; - - statically_indexed_array, NIterPerWarp> - b_warp_tensor_2; - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - load_int4_tile( - b_warp_tensor(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); - - { - // move to 1 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // move to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // LDS write 0 - if constexpr(std::is_same_v) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - PipelinePolicy::template MakeShuffledARegBlockDistribution()); - shuffle_tile(a_shuffle_tmp, a_block_tile); - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); - store_tile(a_copy_lds_window, a_block_tile_tmp); - } - else - { - store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); - } - block_sync_lds(); - } - - index_t iCounter = num_loop / 2 - 1; - while(iCounter > 0) - { - // global read i + 1 - a_block_tile = load_tile(a_copy_dram_window); - - // GEMM i - block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor); - - block_sync_lds(); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - load_int4_tile( - b_warp_tensor_2(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); - - // move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // move to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - // LDS write i + 1 - auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); - HotLoopScheduler(); - block_sync_lds(); - - // iCounter--; - - // global read i + 1 - a_block_tile = load_tile(a_copy_dram_window); - - // GEMM i - block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2); - - block_sync_lds(); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - load_int4_tile( - b_warp_tensor(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); - - // move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // move to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - // LDS write i + 1 - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); - - HotLoopScheduler(); - block_sync_lds(); - - iCounter--; - } - - // tail - { - // global read i + 1 - a_block_tile = load_tile(a_copy_dram_window); - - // GEMM i - block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor); - - block_sync_lds(); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - load_int4_tile( - b_warp_tensor_2(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); - - // move to i + 2 - // move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // LDS write i + 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); - - // move to next flat K - // move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - HotLoopScheduler(); - block_sync_lds(); - - // GEMM num_loop - 1 - block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2); - } - - return c_block_tile; - } - - template ::value && - is_detected::value, - bool>* = nullptr> - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - [[maybe_unused]] const AElementFunction& a_element_func, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - [[maybe_unused]] const BElementFunction& b_element_func, - index_t num_loop, - void* p_smem) const - { - return operator()( - a_dram_block_window_tmp[number<0>{}], - [](const ADataType & a) { return a; }, - b_flat_dram_block_window_tmp[number<0>{}], - num_loop, - p_smem); - } - - template ::value && - !is_detected::value, - bool>* = nullptr> - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - index_t num_loop, - void* p_smem) const - { - return operator()( - a_dram_block_window_tmp, - [](auto& e, const ADataType & a) { e = a; }, - b_flat_dram_block_window_tmp, - num_loop, - p_smem); - } -}; -} // namespace ck_tile diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp index 01dc25c7e2..001528a1e2 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp @@ -22,26 +22,19 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Default = ck_tile::integral_constant; -using WeightPreshuffleV1 = - ck_tile::integral_constant; using WeightPreshuffleV2 = ck_tile::integral_constant; // clang-format off using KernelTypesWeightPreshuffle = ::testing::Types< - std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffleV1>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffleV2>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffleV2>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffleV1> + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffleV2> #if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 , - std::tuple< Row, Col, Row, F8, F8, F32, F16, Default, WeightPreshuffleV1>, std::tuple< Row, Col, Row, F8, F8, F32, F16, Default, WeightPreshuffleV2>, - std::tuple< Row, Col, Row, F8, BF8, F32, F16, Default, WeightPreshuffleV1>, std::tuple< Row, Col, Row, F8, BF8, F32, F16, Default, WeightPreshuffleV2>, - std::tuple< Row, Col, Row, F8, I4, F32, F16, Default, WeightPreshuffleV2>, - std::tuple< Row, Col, Row, F8, I4, F32, F16, Default, WeightPreshuffleV1> + std::tuple< Row, Col, Row, F8, I4, F32, F16, Default, WeightPreshuffleV2> #endif >; diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 22d83306c3..c3ca8d5fe3 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -35,22 +35,12 @@ auto calculate_rtol_atol(const ck_tile::index_t K, enum struct GemmPipelineType { - WeightPreshuffleV1, WeightPreshuffleV2 }; template struct GemmPipelineTypeSelector; -template -struct GemmPipelineTypeSelector -{ - using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; - using pipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; - - static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffleV1"; } -}; - template struct GemmPipelineTypeSelector { diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp index 09ec895ab5..abaa5ebd46 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp @@ -78,7 +78,7 @@ constexpr auto is_row_major(Layout) // Structure to hold kernel traits for dispatcher struct KernelTraits { - std::string pipeline; // preshufflev1, preshufflev2 + std::string pipeline; // preshufflev2 std::string scheduler; // intrawave, interwave, default std::string epilogue; // cshuffle, default bool pad_m; @@ -105,11 +105,7 @@ inline KernelTraits extract_traits_from_name(const std::string& kernel_name) KernelTraits traits; // Extract pipeline - if(kernel_name.find("preshufflev1") != std::string::npos) - { - traits.pipeline = "preshufflev1"; - } - else if(kernel_name.find("preshufflev2") != std::string::npos) + if(kernel_name.find("preshufflev2") != std::string::npos) { traits.pipeline = "preshufflev2"; } diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py index e93b7a0c79..57c250f57e 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py @@ -357,13 +357,11 @@ class GemmPreshuffleKernelBuilder: # Map pipeline names to the correct pipeline implementation pipeline_impl_map = { - "preshufflev1": "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1", "preshufflev2": "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2", } # Map pipeline names to base pipeline for hot loop detection base_pipeline_map = { - "preshufflev1": "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1", "preshufflev2": "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2", }