From 3b5f2b2d99e990c968a044f345227c388f4dd611 Mon Sep 17 00:00:00 2001 From: joyeamd Date: Tue, 6 Jan 2026 05:49:26 +0800 Subject: [PATCH] Joye/revise wp pipeline (#3493) * [CK_TILE] unify double and single lds implementation (#108) Unify LDS buffer management API for single and double buffering modes This change consolidates the Local Data Store (LDS) buffer management by: Merging single and double LDS buffer APIs into a unified interface Implementing ping-pong address calculation in pipeline when double LDS is enabled Computing pong buffer addresses dynamically using base address offsets --------- Co-authored-by: joye Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * update wp_pipeline * fix a c++17 issue * update for ci errors * fix ci issues * include a header to fix ci errors * fix some rebase issues * update with rebase --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> [ROCm/composable_kernel commit: 2b563ad04828c5c970f7544d49831f33203587fb] --- include/ck_tile/ops/gemm.hpp | 1 + .../gemm/block/block_wp_asmem_breg_creg.hpp | 212 +++++ .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 34 +- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 175 ++-- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 127 ++- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 29 +- .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 50 +- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 19 +- .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 795 ++++++------------ .../gemm_quant/kernel/gemm_quant_kernel.hpp | 139 +-- .../kernel/grouped_gemm_quant_kernel.hpp | 49 +- ...p_bquant_pipeline_ag_bg_cr_base_policy.hpp | 42 + .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 23 +- 13 files changed, 766 insertions(+), 929 deletions(-) create mode 100644 include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 0eaedbfb3a..2c3a161121 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -25,6 +25,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" +#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp b/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp new file mode 100644 index 0000000000..4fc180b42b --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp @@ -0,0 +1,212 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block window on register +// C is block distributed tensor +template +struct BlockWeightPreshuffleASmemBRegCReg +{ + using Problem = remove_cvref_t; + using BlockPolicy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM; + static constexpr index_t KPerBlockPerIter = WarpGemm::kK; + + static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read + + static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) + ? DsReadPreload + : MIterPerWarp * KIterPerWarp; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + statically_indexed_array preloaded_a_warp_tensor; + + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() + { + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence<1>>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + + template + CK_TILE_DEVICE auto MakeALoadWindows(SmemBlockWindow& a_block_window) const + { + constexpr auto a_load_dstr = make_static_tile_distribution(MakeABlockDistributionEncode()); + + // create MIterPerWarp × KIterPerWarp window + return generate_tuple( + [&](auto kIter) { + return generate_tuple( + [&](auto mIter) { + return make_tile_window( + get_slice_tile( + a_block_window, + sequence{}, + sequence<(mIter + 1) * MPerBlockPerIter, + (kIter + 1) * KPerBlockPerIter>{}), + a_load_dstr); + }, + number{}); + }, + number{}); + } + + template + CK_TILE_DEVICE void LocalPrefetch(const ALoadWindows& a_load_windows) + { + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + + load_tile(preloaded_a_warp_tensor(loadIter), + a_load_windows[number{}][number{}]); + }); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ALoadWindows& a_load_windows, + BFlatBlockTensor& b_block_tensor, + const BFlatDistribution&) + { + constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; + + using CWarpDstr = typename WarpGemm::CWarpDstr; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + using BWarpTensor = typename WarpGemm::BWarpTensor; + + constexpr auto b_block_y_lengths = + to_sequence(BFlatDistribution{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto b_block_y_index_zeros = + uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + BWarpTensor b_warp_tensor; + CWarpTensor c_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, + typename sequence_split::right_type{}), + merge_sequences( + sequence<1, 1>{}, + typename sequence_split::right_type{})); + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WarpGemm{}( + c_warp_tensor, preloaded_a_warp_tensor(number{}), b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + + load_tile(preloaded_a_warp_tensor(number{}), + a_load_windows[number{}][number{}]); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 5ba5699dda..3f028ead2b 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -303,24 +303,15 @@ struct GroupedGemmKernel CDataType* c_ptr = static_cast(kargs.e_ptr); // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; // TO DO: // Can we simplify this branching logic? if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - RunGemmWithPipelineSelection2LDS(a_ptr, - b_ptr, - c_ptr, - kargs.ds_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); + RunGemmWithPipelineSelection2LDS( + a_ptr, b_ptr, c_ptr, kargs.ds_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } else // SingleSmemBuffer { @@ -331,7 +322,7 @@ struct GroupedGemmKernel b_ptr, kargs.ds_ptr, c_ptr, - smem_ptr_0, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -343,7 +334,7 @@ struct GroupedGemmKernel {b_ptr}, kargs.ds_ptr, c_ptr, - smem_ptr_0, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -425,9 +416,7 @@ struct GroupedGemmKernel * @param a_ptr input A pointer * @param b_ptr input B pointer * @param c_ptr output C pointer - * @param ds_ptr input Ds pointer - * @param smem_ptr_0 The starting pointer of 1st shared memory block. - * @param smem_ptr_1 The starting pointer of 2nd shared memory block. + * @param smem_ptr The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments * @param splitk_batch_offset Utility structure used to calculate k batch. * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. @@ -439,8 +428,7 @@ struct GroupedGemmKernel const BDataType* b_ptr, CDataType* c_ptr, const std::array& ds_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, + void* __restrict__ smem_ptr, const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, @@ -460,8 +448,8 @@ struct GroupedGemmKernel amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); + const auto& c_block_tile = + GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr); // Run Epilogue Pipeline if(kargs.k_batch == 1) @@ -469,7 +457,7 @@ struct GroupedGemmKernel auto c_block_window = Base::template MakeCBlockWindows( c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr); } else { @@ -477,7 +465,7 @@ struct GroupedGemmKernel Base::template MakeCBlockWindows( c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr); } } diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 65f58a8ca5..c77459b4ec 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -978,7 +978,7 @@ struct UniversalGemmKernel * @param bs_ptr input Bs pointer * @param ds_ptr input Ds pointer * @param e_ptr output E pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param smem_ptr The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. @@ -990,7 +990,7 @@ struct UniversalGemmKernel const std::array& bs_ptr, const std::array& ds_ptr, EDataType* e_ptr, - void* smem_ptr_0, + void* smem_ptr, const KernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, @@ -1008,7 +1008,7 @@ struct UniversalGemmKernel // Run GEMM cooperatively by whole workgroup. const auto& c_block_tile = GemmPipeline{}.template operator()( - as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0); + as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr); const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); // Run Epilogue Pipeline @@ -1016,77 +1016,63 @@ struct UniversalGemmKernel { auto c_block_window = MakeCBlockWindows( e_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr); } else { auto c_block_window = MakeCBlockWindows( e_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr); } } - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. - * - * @param as_ptr input As pointer - * @param bs_ptr input Bs pointer - * @param ds_ptr input Ds pointer - * @param e_ptr output E pointer - * @param smem_ptr_0 The starting pointer of 1st shared memory block. - * @param smem_ptr_1 The starting pointer of 2nd shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset Utility structure used to calculate k batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - CK_TILE_DEVICE static void RunGemm2LDS(const std::array& as_ptr, - const std::array& bs_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) + CK_TILE_DEVICE static auto + GetTileCoordinates(const KernelArgs& kargs) -> tuple { - // Create block windows using specialized methods - const auto& as_block_window = - MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); - const auto& bs_block_window = - MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); - const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); + index_t iM, iN; - const index_t num_loop = - amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + // Regular launch: use 1D block indexing + const auto blockId = amd_wave_read_first_lane(blockIdx.x); + const auto [tile_m, tile_n] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); + iM = tile_m; + iN = tile_n; - // Run GEMM cooperatively by whole workgroup. - const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window, - AElementWise{}, - bs_block_window, - BElementWise{}, - num_loop, - smem_ptr_0, - smem_ptr_1); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); - // Run Epilogue Pipeline - if(kargs.k_batch == 1) + return make_tuple(i_m, i_n); + } + + // Helper functions + CK_TILE_DEVICE static auto GetBlockId() -> index_t + { + // For 1D regular launch + return amd_wave_read_first_lane(get_block_id()); + } + + CK_TILE_DEVICE static auto GetGridSize() -> index_t + { + // For 1D regular launch + return amd_wave_read_first_lane(get_grid_size()); + } + + // Helper to get total number of tiles, handling both dim3 and index_t return types + template + CK_TILE_HOST_DEVICE static auto GetNumTiles(Args&&... args) -> index_t + { + auto grid_size = TilePartitioner::GridSize(std::forward(args)...); + + using GridSizeType = decltype(grid_size); + + if constexpr(std::is_same_v) { - auto c_block_window = MakeCBlockWindows( - e_ptr, kargs, block_idx_m, block_idx_n); - - EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + // GridSize returns dim3: compute total tiles as x * y * z + return amd_wave_read_first_lane(grid_size.x * grid_size.y * grid_size.z); } else { - auto c_block_window = MakeCBlockWindows( - e_ptr, kargs, block_idx_m, block_idx_n); - - EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + // GridSize returns scalar (index_t): use directly + return amd_wave_read_first_lane(grid_size); } } @@ -1123,36 +1109,12 @@ struct UniversalGemmKernel } // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - RunGemm2LDS(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else - { - - constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); - RunGemm(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + constexpr auto scheduler_type = + GemmPipeline::DoubleSmemBuffer || (GemmPipeline::NumWaveGroups == 1); + RunGemm( + as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } // Persistent kernel entry point @@ -1199,34 +1161,19 @@ struct UniversalGemmKernel } // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; // Run the GEMM - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - RunGemm2LDS(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else - { - RunGemm(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr, + kargs, + splitk_batch_offset, + i_m, + i_n); + // Advance to the next work item block_id += grid_size; if(block_id >= num_work) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 343e37ed66..4973d9c941 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -64,12 +64,17 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } - template + template CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, SrcTileWindow& dram_tile_window, const DramTileWindowStep& dram_tile_window_step) const { - load_tile(dst_block_tile, dram_tile_window); + load_int4_tile(dst_block_tile, dram_tile_window); move_tile_window(dram_tile_window, dram_tile_window_step); } @@ -217,22 +222,17 @@ struct GemmPipelineAgBgCrImplBase return std::move(a_copy_dram_window); } - template - CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const ALdsTensorView& a_lds_block_view, - const ALdsLoadTileDistr&, - const array& offset = {0, 0}) const + template + CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view, + const ALdsLoadTileDistr&) const { - // A DRAM tile window for load - auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset); - - // A LDS tile window for store auto a_lds_shape = []() { if constexpr(is_a_load_tr) return make_tuple(number{}, number{}); else return make_tuple(number{}, number{}); }(); + auto a_copy_lds_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}); auto a_lds_load_tile_distr = []() { @@ -244,32 +244,73 @@ struct GemmPipelineAgBgCrImplBase else return ALdsLoadTileDistr{}; }(); + auto a_lds_gemm_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}, a_lds_load_tile_distr); + return make_tuple(std::move(a_copy_lds_window), std::move(a_lds_gemm_window)); + } + + template < + typename ADramBlockWindowTmp, + typename ALdsTensorView, + typename ALdsLoadTileDistr, + typename std::enable_if_t::value, bool>* = nullptr> + CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const ALdsTensorView& a_lds_block_view, + const ALdsLoadTileDistr& a_lds_load_tile_distr, + const array& offset = {0, 0}) const + { + // A DRAM tile window for load + auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset); + + // Create LDS windows + auto [a_copy_lds_window, a_lds_gemm_window] = + MakeALdsWindows(a_lds_block_view, a_lds_load_tile_distr); + return make_tuple(std::move(a_copy_dram_window), std::move(a_copy_lds_window), std::move(a_lds_gemm_window)); } - template - CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, - const BLdsTensorView& b_lds_block_view, - const BLdsLoadTileDistr&, + // Unified GetAWindows that supports 1, 2, or 3 LDS buffers + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const ALdsTensorViewsTuple& a_lds_block_views_tuple, + const ALdsLoadTileDistr& a_lds_load_tile_distr, const array& offset = {0, 0}) const { // A DRAM tile window for load - auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset); + auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset); - // TODO: Do we really need those two tile windows??? - // They're exactly same... - // B LDS tile window for store + // Create LDS windows for each buffer + constexpr index_t num_buffers = ALdsTensorViewsTuple::size(); + auto a_lds_windows = generate_tuple( + [&](auto i) { + return MakeALdsWindows(a_lds_block_views_tuple[i], a_lds_load_tile_distr); + }, + number{}); + + // Return: (dram_window, lds_windows_tuple) + // lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i) + return make_tuple(std::move(a_copy_dram_window), std::move(a_lds_windows)); + } + + template + CK_TILE_DEVICE constexpr auto MakeBLdsWindows(const BLdsTensorView& b_lds_block_view, + const BLdsLoadTileDistr&) const + { auto b_lds_shape = []() { if constexpr(is_b_load_tr) return make_tuple(number{}, number{}); else return make_tuple(number{}, number{}); }(); + auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); using BLdsDataType = @@ -286,13 +327,61 @@ struct GemmPipelineAgBgCrImplBase else return BLdsLoadTileDistr{}; }(); + auto b_lds_gemm_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}, b_lds_load_tile_distr); + return make_tuple(std::move(b_copy_lds_window), std::move(b_lds_gemm_window)); + } + + template < + typename BDramBlockWindowTmp, + typename BLdsTensorView, + typename BLdsLoadTileDistr, + typename std::enable_if_t::value, bool>* = nullptr> + CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BLdsTensorView& b_lds_block_view, + const BLdsLoadTileDistr& b_lds_load_tile_distr, + const array& offset = {0, 0}) const + { + // A DRAM tile window for load + auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset); + + // Create LDS windows + auto [b_copy_lds_window, b_lds_gemm_window] = + MakeBLdsWindows(b_lds_block_view, b_lds_load_tile_distr); + return make_tuple(std::move(b_copy_dram_window), std::move(b_copy_lds_window), std::move(b_lds_gemm_window)); } + + // Unified GetBWindows that supports 1, 2, or 3 LDS buffers + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BLdsTensorViewsTuple& b_lds_block_views_tuple, + const BLdsLoadTileDistr& b_lds_load_tile_distr, + const array& offset = {0, 0}) const + { + // B DRAM tile window for load + auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset); + + // Create LDS windows for each buffer + constexpr index_t num_buffers = BLdsTensorViewsTuple::size(); + auto b_lds_windows = generate_tuple( + [&](auto i) { + return MakeBLdsWindows(b_lds_block_views_tuple[i], b_lds_load_tile_distr); + }, + number{}); + + // Return: (dram_window, lds_windows_tuple) + // lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i) + return make_tuple(std::move(b_copy_dram_window), std::move(b_lds_windows)); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 0b2cdde05e..8acfea4580 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -158,6 +158,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}; @@ -172,7 +174,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync(); + constexpr index_t smem_size = Policy::template GetSmemSize(); + return 2 * smem_size; } CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() @@ -240,8 +243,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync); @@ -303,8 +305,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}); // this pipeline has a pair of LDS buffers per logical tile - auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); - auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); + constexpr index_t smem_size = Policy::template GetSmemSize(); + auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem); + auto&& [a_lds_block1, b_lds_block1] = + Base::GetABLdsTensorViews(static_cast(p_smem) + smem_size); // set up LDS tile shapes constexpr auto a_lds_shape = []() { @@ -534,21 +538,18 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}.template operator()( a_dram_block_window_tmp, a_element_func, b_dram_block_window_tmp, b_element_func, num_loop, - p_smem_0, - p_smem_1); + p_smem); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); @@ -559,8 +560,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto is_b_load_tr_v = bool_constant{}; + static_assert(DoubleSmemBuffer == true, "pipeline requires double smem buffer"); + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() { // clang-format off @@ -191,7 +193,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return Policy::template GetSmemSize(); + constexpr index_t smem_size = Policy::template GetSmemSize(); + return 2 * smem_size; } CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() @@ -281,8 +284,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { using ADramBlockWindowTmp = remove_cvref_t{}, AsDramBlockWindowTmp>>; @@ -324,8 +326,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // global read 0 ////////////// LDS desc, window & register ///////////////// - auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); - auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); + constexpr index_t smem_size = Policy::template GetSmemSize(); + auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem); + auto&& [a_lds_block1, b_lds_block1] = + Base::GetABLdsTensorViews(static_cast(p_smem) + smem_size); constexpr auto a_lds_shape = []() { if constexpr(is_a_load_tr_v()) @@ -680,8 +684,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, - void* p_smem_0, - void* p_smem_1) const + void* p_smem) const { const bool has_hot_loop = Base::BlockHasHotloop(num_loop); const auto tail_number = Base::GetBlockLoopTailNum(num_loop); @@ -693,8 +696,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 b_dram_block_window_tmp, b_element_func, num_loop, - p_smem_0, - p_smem_1); + p_smem); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); @@ -708,8 +710,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const BsDramBlockWindowTmp& b_dram_block_window_tmp, const index_t num_loop, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { const bool has_hot_loop = Base::BlockHasHotloop(num_loop); const auto tail_number = Base::GetBlockLoopTailNum(num_loop); @@ -721,8 +722,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 b_dram_block_window_tmp, [](auto& e, const BDataType& b) { e = b; }, num_loop, - p_smem_0, - p_smem_1); + p_smem); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); @@ -738,8 +738,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 index_t num_loop, bool has_hot_loop, TailNumber tail_number, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr bool hot_loop = hot_loop_.value; @@ -751,8 +750,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 b_dram_block_window_tmp, PassThrough, num_loop, - p_smem_0, - p_smem_1); + p_smem); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } @@ -769,16 +767,14 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 const BDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, - void* p_smem_0, - void* p_smem_1) const + void* p_smem) const { return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), a_element_func, ck_tile::make_tuple(b_dram_block_window_tmp), b_element_func, num_loop, - p_smem_0, - p_smem_1); + p_smem); } template CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp, const index_t num_loop, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), ck_tile::make_tuple(b_dram_block_window_tmp), num_loop, - p_smem_0, - p_smem_1); + p_smem); } template index_t num_loop, bool has_hot_loop, TailNumber tail_number, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), ck_tile::make_tuple(b_dram_block_window_tmp), num_loop, has_hot_loop, tail_number, - p_smem_0, - p_smem_1); + p_smem); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 019a828ec0..e90c6a27d7 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" namespace ck_tile { @@ -201,6 +202,12 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy { using TileShape = typename Problem::BlockGemmShape; + constexpr index_t kNPerBlock = TileShape::kN; + constexpr index_t kKPerBlock = TileShape::kK; + constexpr index_t NIterPerWarp = + kNPerBlock / TileShape::BlockWarps::at(I1) / TileShape::WarpTile::at(I1); + constexpr index_t KIterPerWarp = kKPerBlock / TileShape::WarpTile::at(I2); + constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNum = BlockSize / WaveSize; @@ -213,13 +220,13 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy #endif constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; - constexpr index_t KRepeat = 1; + constexpr index_t KRepeat = KIterPerWarp; static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); constexpr index_t NBPerLoad = 1; constexpr index_t NThdPerWave = 1; constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp - constexpr index_t NRepeat = 1; + constexpr index_t NRepeat = NIterPerWarp; constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; return make_static_tile_distribution( @@ -232,8 +239,8 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy tuple, sequence<0, 1, 2>>, // which direction tuple, sequence<1, 2, 2>>, // which index // - sequence<1, 1, 2, 2>, - sequence<0, 3, 0, 3>>{}); + sequence<1, 2, 1, 2>, + sequence<0, 0, 3, 3>>{}); } template @@ -307,7 +314,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy typename Problem::CDataType, BlockWarps, WarpGemm>; - return BlockWeightPreshuffleASmemBSmemCRegV1{}; + return BlockWeightPreshuffleASmemBRegCReg{}; } /** * @brief Get the vector store size for C tensor. @@ -325,7 +332,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() { using BlockGemm = remove_cvref_t())>; - using WG_ = typename BlockGemm::WG; + using WG_ = typename BlockGemm::WarpGemm; constexpr bool TransposeC = Problem::TransposeC; using CLayout = typename Problem::CLayout; diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index f64901755b..c9499106de 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -32,19 +32,34 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2 template CK_TILE_HOST_DEVICE static auto - TailHandler(const RunFunction& run_func, bool, TailNumber tail_number) + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) { - if(tail_number == TailNumber::Odd) + if(has_hot_loop) { - return run_func(bool_constant{}, - integral_constant{}); + if(tail_number == TailNumber::Odd) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else // Even tail number + { + return run_func(bool_constant{}, + integral_constant{}); + } } - else // Even tail number + else { - return run_func(bool_constant{}, - integral_constant{}); + if(tail_number == TailNumber::Odd) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else // Even tail number + { + return run_func(bool_constant{}, + integral_constant{}); + } } - return run_func(bool_constant{}, integral_constant{}); } }; @@ -52,7 +67,8 @@ template { - using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using PipelineImplBase = GemmPipelineAgBgCrImplBase; using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; @@ -75,11 +91,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 using BlockWeightPreshuffle = remove_cvref_t())>; - static constexpr auto config = - BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2 static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read @@ -95,6 +106,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock; + static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; @@ -131,12 +144,16 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 using BlockWarps = remove_cvref_t; using WarpTile = remove_cvref_t; - static constexpr index_t MWarp = config.template at<1>(); - static constexpr index_t NWarp = config.template at<2>(); + static constexpr index_t MWarp = BlockWarps::at(I0); + static constexpr index_t NWarp = BlockWarps::at(I1); - static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); - static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + static constexpr index_t WarpTileM = WarpTile::at(I0); + static constexpr index_t WarpTileN = WarpTile::at(I1); + static constexpr index_t WarpTileK = WarpTile::at(I2); + + static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpTileM); + static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpTileN); + static constexpr index_t KIterPerWarp = kKPerBlock / WarpTileK; static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; @@ -154,20 +171,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 #else static constexpr index_t mfma_per_wg = 1; #endif - static constexpr index_t dsread_per_wg = - max(index_t(WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1); + static constexpr index_t dsread_per_wg = max( + index_t(WarpTileM * WarpTileK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1); #if defined(__HIP_DEVICE_COMPILE__) - static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) % + static_assert((WarpTileM * WarpTileK * sizeof(ADataType) * MIterPerWarp / WaveSize) % Problem::VectorLoadSize == 0); #endif - static constexpr index_t dsread_num_perK = - WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize / Problem::VectorLoadSize; + static constexpr index_t dsread_num_perK = WarpTileM * WarpTileK * sizeof(ADataType) * + MIterPerWarp / WaveSize / Problem::VectorLoadSize; static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp); static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp; static constexpr index_t Aload_num_perK = dswrite_num_perK; static constexpr index_t Aload_rep = dswrite_rep; - static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize; + static constexpr index_t Bload_num_perK = kNPerBlock * WarpTileK / NWarp / K1 / WaveSize; static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2; static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter; @@ -187,7 +204,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // clang-format off return concat('_', "pipeline_AGmemBGmemCRegV2", concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize), - concat('x', WG::kM, WG::kN, WG::kK), + concat('x', WarpTileM, WarpTileN, WarpTileK), concat('x', GetVectorSizeA(), GetVectorSizeB()), concat('x', kPadM, kPadN, kPadK)); @@ -195,14 +212,16 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 } static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - static constexpr index_t Preshuffle = Problem::Preshuffle; + + static constexpr index_t Preshuffle = Problem::Preshuffle; using Base::UsePersistentKernel; CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return PipelinePolicy::template GetSmemSize(); + constexpr index_t smem_size = PipelinePolicy::template GetSmemSize(); + return DoubleSmemBuffer ? 2 * smem_size : smem_size; } // dsread_perM: how many LDS reads want to issue in this M-iter @@ -515,515 +534,184 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // __builtin_amdgcn_sched_barrier(0); } - template ::value && - !is_detected::value, - bool>* = nullptr, - index_t UnaryOpSize_ = 8> - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - index_t num_loop, - void* p_smem_ping, - void* p_smem_pong) const + struct PipelineImpl : public PipelineImplBase { - static_assert( - std::is_same_v>, - "wrong!"); + using Base = PipelineImplBase; - static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], - "wrong!"); - static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); - - constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; - const index_t iMWarp = get_warp_id() / NWarp; - - using CWarpDstr = typename WG::CWarpDstr; - using CWarpTensor = typename WG::CWarpTensor; - - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - __builtin_amdgcn_sched_barrier(0); - - // A tile in LDS - ADataType* p_a_lds_ping = static_cast(p_smem_ping); - ADataType* p_a_lds_pong = static_cast(p_smem_pong); - - constexpr auto a_lds_block_desc = - PipelinePolicy::template MakeALdsBlockDescriptor(); - - auto a_lds_block_ping = - make_tensor_view(p_a_lds_ping, a_lds_block_desc); - auto a_lds_block_pong = - make_tensor_view(p_a_lds_pong, a_lds_block_desc); - - // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeADramTileDistribution()); - - auto a_copy_lds_window_ping = - make_tile_window(a_lds_block_ping, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); - - auto a_copy_lds_window_pong = - make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); - - // ping-pong window for A LDS - auto a_warp_window_ping_tmp = - make_tile_window(a_lds_block_ping, - make_tuple(number{}, number{}), - {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - - auto a_warp_window_pong_tmp = - make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_windows_ping; - - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_windows_pong; - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; - - move_tile_window(a_warp_windows_ping(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; - - move_tile_window(a_warp_windows_pong(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); - - // Block GEMM - auto block_weight_preshuffle = BlockWeightPreshuffle(); - // Acc register tile - auto c_block_tile = block_weight_preshuffle.MakeCBlockTile(); - - // B flat DRAM window for load - auto b_flat_distribution = - PipelinePolicy::template MakeBFlatDramTileDistribution(); - auto b_flat_dram_window = // tile_window_with_static_distribution - make_tile_window( - b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views - make_tuple(number{}, number{}), - b_flat_dram_block_window_tmp.get_window_origin(), - b_flat_distribution); - - // pingpong buffer for B - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; - using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); - - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_flat_dram_windows; - - statically_indexed_array, NIterPerWarp> - b_warp_tensor_ping; - - statically_indexed_array, NIterPerWarp> - b_warp_tensor_pong; - - // Prefetch A0 - auto a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // prefetch B - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); - // move B window to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - // Prefill A0 - auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_ping, a_block_tile_tmp); - - __builtin_amdgcn_sched_barrier(0); - - // Prefetch A1 - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - block_sync_lds(); - - // preload A00,A10 from lds - statically_indexed_array{})(number<0>{}))), - m_preload> - a_warp_tensor; - - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); - }); - __builtin_amdgcn_sched_barrier(0); - - // MAIN LOOP - index_t iCounter = (num_loop - 1) / 2; - while(iCounter > 0) + template ::value && + !is_detected::value, + bool>* = nullptr, + index_t UnaryOpSize_ = 8> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + [[maybe_unused]] const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const { - // prefetch B(2i+1) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_assert( + std::is_same_v>, + "wrong!"); - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); + // A tile in LDS + constexpr index_t smem_size = PipelinePolicy::template GetSmemSize(); - // Prefill A(2i+1) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); - // Prefetch A(2i+2) - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + auto a_lds_blocks = generate_tuple( + [&](auto i) { + ADataType* p_a_lds = static_cast( + static_cast(static_cast(p_smem) + smem_size * i.value)); + return make_tensor_view(p_a_lds, a_lds_block_desc); + }, + number<2>{}); - // GEMM 2i - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + constexpr auto a_lds_load_tile_distr = make_static_tile_distribution( + BlockWeightPreshuffle::MakeABlockDistributionEncode()); + auto&& windows_result = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_blocks, a_lds_load_tile_distr); + auto&& a_copy_dram_window = windows_result.template get<0>(); + auto&& a_lds_windows = windows_result.template get<1>(); + auto a_copy_lds_windows = generate_tuple( + [&](auto i) -> decltype(auto) { return a_lds_windows[i].template at<0>(); }, + number<2>{}); + // Block GEMM + auto block_weight_preshuffle = BlockWeightPreshuffle(); + // Acc register tile + auto c_block_tile = block_weight_preshuffle.MakeCBlockTile(); - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + auto a_load_windows = generate_tuple( + [&](auto i) -> decltype(auto) { + return block_weight_preshuffle.MakeALoadWindows(a_copy_lds_windows[i]); + }, + number<2>{}); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // B flat DRAM window for load + auto b_flat_distribution = + PipelinePolicy::template MakeBFlatDramTileDistribution(); + auto b_flat_dram_window = // tile_window_with_static_distribution + make_tile_window(b_flat_dram_block_window_tmp + .get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, + number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BFlatBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kflatKPerBlock); - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; + using BBlockTile = + decltype(make_static_distributed_tensor(b_flat_distribution)); + + ABlockTile a_global_tile; + BBlockTile b_global_tile[2]; + + // // Prefetch A0 + Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step); + + Base::template GlobalPrefetch( + b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); + + // Prefill A0 + Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); + + // Prefetch A1 + Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + block_sync_lds(); + + // preload A00,A10 from lds + block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]); + + __builtin_amdgcn_sched_barrier(0); + // MAIN LOOP + if constexpr(HasHotLoop) + { + index_t i_global_read = amd_wave_read_first_lane(2); + do + { { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } + Base::template GlobalPrefetch( + b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); + Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); + Base::GlobalPrefetch( + a_global_tile, a_copy_dram_window, a_dram_tile_window_step); + block_weight_preshuffle(c_block_tile, + a_load_windows[I0], + b_global_tile[0], + b_flat_distribution); - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]); + HotLoopScheduler(); + } { - block_sync_lds(); + Base::template GlobalPrefetch( + b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); + Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); + Base::GlobalPrefetch( + a_global_tile, a_copy_dram_window, a_dram_tile_window_step); + block_weight_preshuffle(c_block_tile, + a_load_windows[I1], + b_global_tile[1], + b_flat_distribution); + + block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]); + HotLoopScheduler(); } - }); - }); - // move B window to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + i_global_read += 2; + } while(i_global_read < num_loop); + } - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); - }); - HotLoopScheduler(); + // tail + if constexpr(TailNum == TailNumber::Even) + { + { + Base::template GlobalPrefetch( + b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); + Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); + block_weight_preshuffle( + c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution); + block_sync_lds(); + block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]); + Last2ndHotLoopScheduler(); + } + { + block_weight_preshuffle( + c_block_tile, a_load_windows[I1], b_global_tile[1], b_flat_distribution); + LastHotLoopScheduler(); + } + } + else if constexpr(TailNum == TailNumber::Odd) + { + block_weight_preshuffle( + c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution); + LastHotLoopScheduler(); + } - // Next K - - // prefetch B(2i+2) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); - - // Prefill A(2i+2) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_ping, a_block_tile_tmp); - - // Prefetch A(2i+3) - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // GEMM 2i+1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); - }); - // move B window to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); - }); - HotLoopScheduler(); - - iCounter--; + return c_block_tile; } - - // tail - if constexpr(TailNum == TailNumber::Even) - { - // __builtin_amdgcn_sched_barrier(0); - // prefetch B(loopK) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); - - // Prefill A(loopK) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_pong, a_block_tile_tmp); - - // GEMM loopK-1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); - }); - // TailHotLoopScheduler(); - - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); - }); - - Last2ndHotLoopScheduler(); - - // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); - }); - LastHotLoopScheduler(); - } - else if constexpr(TailNum == TailNumber::Odd) - { - // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); - }); - LastHotLoopScheduler(); - } - - return c_block_tile; - } + }; // called from universal gemm kernel template (a_dram_block_window_tmp[number<0>{}], - PassThrough, - b_flat_dram_block_window_tmp[number<0>{}], - num_loop, - p_smem_ping, - p_smem_pong); + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp[number<0>{}], + a_element_func, + b_flat_dram_block_window_tmp[number<0>{}], + num_loop, + p_smem); }; - return Base::TailHandler(RunPipeline, true, tail_number); + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } // called from general gemm kernel @@ -1066,23 +751,21 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, - void* p_smem_ping, - void* p_smem_pong) const + void* p_smem) const { - const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + const auto has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); - const auto RunPipeline = [&](auto bool_val, auto tail_num_) { - (void)bool_val; // Suppress unused parameter warning - constexpr auto tail_num = tail_num_.value; + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr auto PassThrough = [](const ADataType& a) { return a; }; - return operator()(a_dram_block_window_tmp, - PassThrough, - b_flat_dram_block_window_tmp, - num_loop, - p_smem_ping, - p_smem_pong); + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + PassThrough, + b_flat_dram_block_window_tmp, + num_loop, + p_smem); }; - return Base::TailHandler(RunPipeline, true, tail_number); + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } // called from grouped gemm kernel @@ -1095,21 +778,19 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, TailNumber tail_number, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { - const auto RunPipeline = [&](auto bool_val, auto tail_num_) { - (void)bool_val; // Suppress unused parameter warning - constexpr auto tail_num = tail_num_.value; + const auto has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr auto PassThrough = [](const auto& x) { return x; }; - return operator()(a_dram_block_window_tmp, - PassThrough, - b_flat_dram_block_window_tmp, - num_loop, - p_smem_0, - p_smem_1); + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + PassThrough, + b_flat_dram_block_window_tmp, + num_loop, + p_smem); }; - return Base::TailHandler(RunPipeline, true, tail_number); + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } }; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 8aab756ccf..4f79361037 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -1723,7 +1723,7 @@ struct QuantGemmKernel * @param aq_ptr input AQ pointer * @param bq_ptr input BQ pointer * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param smem_ptr The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. @@ -1735,7 +1735,7 @@ struct QuantGemmKernel const AQDataType* aq_ptr, const BQDataType* bq_ptr, CDataType* c_ptr, - void* smem_ptr_0, + void* smem_ptr, const QuantGemmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, @@ -1762,7 +1762,7 @@ struct QuantGemmKernel m = kargs.M; } return GemmPipeline{}.template operator()( - a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0, m); + a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr, m); } else if constexpr(kQuantType == QuantType::BQuantGrouped) { @@ -1772,7 +1772,7 @@ struct QuantGemmKernel n = kargs.N; } return GemmPipeline{}.template operator()( - a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n); + a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr, n); } else if constexpr(kQuantType == QuantType::ABQuantGrouped) { @@ -1788,7 +1788,7 @@ struct QuantGemmKernel aq_block_window, bq_block_window, num_loop, - smem_ptr_0, + smem_ptr, m, n); } @@ -1796,7 +1796,7 @@ struct QuantGemmKernel kQuantType == QuantType::TensorQuant) { return GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, smem_ptr_0); + a_block_window, b_block_window, num_loop, smem_ptr); } }(); @@ -1812,14 +1812,14 @@ struct QuantGemmKernel kQuantType == QuantType::AQuantGrouped || kQuantType == QuantType::BQuantGrouped) { - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } else if constexpr(kQuantType == QuantType::RowColQuant) { EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, - smem_ptr_0, + smem_ptr, aq_block_window, bq_block_window); } @@ -1828,7 +1828,7 @@ struct QuantGemmKernel const AccDataType aq_scale = type_convert(*aq_ptr); const AccDataType bq_scale = type_convert(*bq_ptr); EpiloguePipeline{}( - c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale); } } else @@ -1840,14 +1840,14 @@ struct QuantGemmKernel kQuantType == QuantType::AQuantGrouped || kQuantType == QuantType::BQuantGrouped) { - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } else if constexpr(kQuantType == QuantType::RowColQuant) { EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, - smem_ptr_0, + smem_ptr, aq_block_window, bq_block_window); } @@ -1856,89 +1856,7 @@ struct QuantGemmKernel const AccDataType aq_scale = type_convert(*aq_ptr); const AccDataType bq_scale = type_convert(*bq_ptr); EpiloguePipeline{}( - c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); - } - } - } - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note RunGemm2LDS in with two shared memory buffers using the ping pong buffer mechanism. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param aq_ptr input AQ pointer - * @param bq_ptr input BQ pointer - * @param c_ptr output C pointer - * @param smem_ptr_0 The starting pointer of 1st shared memory block. - * @param smem_ptr_1 The starting pointer of 2nd shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset Utility structure used to calculate k batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, - const BDataType* b_ptr, - [[maybe_unused]] const AQDataType* aq_ptr, - const BQDataType* bq_ptr, - CDataType* c_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, - const QuantGemmKernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create block windows using specialized methods - const auto& a_block_window = - MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); - const auto& b_block_window = - MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); - const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); - - const index_t num_loop = - amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - - // Run GEMM cooperatively by whole workgroup. - const auto& c_block_tile = [&]() { - if constexpr(kQuantType == QuantType::BQuantGrouped) - { - index_t n = 0; - if constexpr(PreshuffleQuant) - { - n = kargs.N; - } - return GemmPipeline{}.template operator()(a_block_window, - b_block_window, - bq_block_window, - num_loop, - smem_ptr_0, - smem_ptr_1, - n); - } - else - { - return nullptr; - } - }(); - - const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); - - // Run Epilogue Pipeline with k_batch dispatch - if constexpr(kQuantType == QuantType::BQuantGrouped) - { - if(k_batch == 1) - { - auto c_block_window = MakeCBlockWindow( - c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); - } - else - { - auto c_block_window = MakeCBlockWindow( - c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale); } } } @@ -1961,37 +1879,10 @@ struct QuantGemmKernel CDataType* c_ptr = static_cast(kargs.c_ptr); // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - - RunGemm2LDS(a_ptr, - b_ptr, - aq_ptr, - bq_ptr, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else - { - RunGemm(a_ptr, - b_ptr, - aq_ptr, - bq_ptr, - c_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + RunGemm( + a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } }; diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 1c98a372be..06a80c8b55 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -318,21 +318,18 @@ struct QuantGroupedGemmKernel CDataType* c_ptr = static_cast(kargs.c_ptr); // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; // Only for BQuantGrouped DoubleSmemBuffer is supported if constexpr(GemmPipeline::DoubleSmemBuffer == true && kQuantType == QuantType::BQuantGrouped) { - - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; RunGemmWithPipelineSelection2LDS(a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, - smem_ptr_0, - smem_ptr_1, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -348,7 +345,7 @@ struct QuantGroupedGemmKernel aq_ptr, bq_ptr, c_ptr, - smem_ptr_0, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -361,7 +358,7 @@ struct QuantGroupedGemmKernel aq_ptr, bq_ptr, c_ptr, - smem_ptr_0, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -377,8 +374,7 @@ struct QuantGroupedGemmKernel [[maybe_unused]] const AQDataType* aq_ptr, const BQDataType* bq_ptr, CDataType* c_ptr, - void* smem_ptr_0, - void* smem_ptr_1, + void* smem_ptr, const QuantGroupedGemmKernelArgs& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, @@ -399,27 +395,22 @@ struct QuantGroupedGemmKernel const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); // Run GEMM cooperatively by whole workgroup - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, - b_block_window, - bq_block_window, - num_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, bq_block_window, num_loop, tail_num, smem_ptr); // Run Epilogue Pipeline with split_k dispatch if(kargs.k_batch == 1) { auto c_block_window = Base::template MakeCBlockWindow( c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } else { auto c_block_window = Base::template MakeCBlockWindow( c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } } @@ -435,7 +426,7 @@ struct QuantGroupedGemmKernel * @param aq_ptr input AQ pointer * @param bq_ptr input BQ pointer * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param smem_ptr The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k * batch. @@ -449,7 +440,7 @@ struct QuantGroupedGemmKernel const AQDataType* aq_ptr, const BQDataType* bq_ptr, CDataType* c_ptr, - void* smem_ptr_0, + void* smem_ptr, const QuantGroupedGemmKernelArgs& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, @@ -481,7 +472,7 @@ struct QuantGroupedGemmKernel num_loop, has_hot_loop, tail_num, - smem_ptr_0); + smem_ptr); } else if constexpr(kQuantType == QuantType::BQuantGrouped) { @@ -491,13 +482,13 @@ struct QuantGroupedGemmKernel num_loop, has_hot_loop, tail_num, - smem_ptr_0); + smem_ptr); } else if constexpr(kQuantType == QuantType::RowColQuant || kQuantType == QuantType::TensorQuant) { return GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); + a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr); } }(); @@ -510,14 +501,14 @@ struct QuantGroupedGemmKernel if constexpr(kQuantType == QuantType::AQuantGrouped || kQuantType == QuantType::BQuantGrouped) { - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } else if constexpr(kQuantType == QuantType::RowColQuant) { EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, - smem_ptr_0, + smem_ptr, aq_block_window, bq_block_window); } @@ -526,7 +517,7 @@ struct QuantGroupedGemmKernel const AccDataType aq_scale = type_convert(*aq_ptr); const AccDataType bq_scale = type_convert(*bq_ptr); EpiloguePipeline{}( - c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale); } } else @@ -538,14 +529,14 @@ struct QuantGroupedGemmKernel if constexpr(kQuantType == QuantType::AQuantGrouped || kQuantType == QuantType::BQuantGrouped) { - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } else if constexpr(kQuantType == QuantType::RowColQuant) { EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, - smem_ptr_0, + smem_ptr, aq_block_window, bq_block_window); } @@ -554,7 +545,7 @@ struct QuantGroupedGemmKernel const AccDataType aq_scale = type_convert(*aq_ptr); const AccDataType bq_scale = type_convert(*bq_ptr); EpiloguePipeline{}( - c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale); } } } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp index b155297054..b7dc0bd616 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp @@ -29,6 +29,48 @@ struct GemmWPQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelin return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution(); } + // as UniversalWeightPreshufflePipelineAgBgCrPolicy's MakeBFlatDramTileDistribution is changed; + // move original UniversalWeightPreshufflePipelineAgBgCrPolicy's implementation to here + // temporarily + template + CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + constexpr index_t KBPerLoad = GetKBPerLoad(); +#if defined(__gfx11__) + constexpr index_t KRepeatInWave = 2; +#else + constexpr index_t KRepeatInWave = 1; +#endif + constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim + constexpr index_t KWavePerBlk = 1; + constexpr index_t KRepeat = 1; + static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); + + constexpr index_t NBPerLoad = 1; + constexpr index_t NThdPerWave = 1; + constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp + constexpr index_t NRepeat = 1; + + constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // ? + tuple, // second direction + sequence>, // first direction + // wave in blk, // thd in wave + // // + tuple, sequence<0, 1, 2>>, // which direction + tuple, sequence<1, 2, 2>>, // which index + // + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + template CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffleBQuant() { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index 18b236c29b..43f37ec4d8 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -184,8 +184,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV const BQDramBlockWindowTmp& bq_dram_block_window_tmp, index_t n, index_t num_loop, - void* p_smem_ping, - void* p_smem_pong) const + void* p_smem) const { static_assert( std::is_same_v> && @@ -210,8 +209,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV __builtin_amdgcn_sched_barrier(0); // A tile in LDS - ADataType* p_a_lds_ping = static_cast(p_smem_ping); - ADataType* p_a_lds_pong = static_cast(p_smem_pong); + constexpr index_t smem_size = PipelinePolicy::template GetSmemSize(); + ADataType* p_a_lds_ping = static_cast(p_smem); + ADataType* p_a_lds_pong = + reinterpret_cast(static_cast(p_smem) + smem_size); constexpr auto a_lds_block_desc = PipelinePolicy::template MakeALdsBlockDescriptor(); @@ -561,9 +562,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, const BQDramBlockWindowTmp& bq_dram_block_window_tmp, index_t num_loop, - void* p_smem_ping, - void* p_smem_pong, - index_t n = 0) const // Default value for non-preshuffle case + void* p_smem, + index_t n = 0) const { return operator()( a_dram_block_window_tmp, @@ -572,8 +572,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV bq_dram_block_window_tmp, n, num_loop, - p_smem_ping, - p_smem_pong); + p_smem); } template