From e9ade6918548499ad4fd4176aa38e606f8e8301c Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Wed, 8 Oct 2025 20:22:30 -0400 Subject: [PATCH] Add Memory pipeline for AQuant Block Scale GEMM (#2987) * WIP: add memory pipeline boilerplate code that compiles and works for one block * WIP: tail handling works for memory pipeline * WIP: numerical errors appear to be gone after adding block_sync_lds() * fix: numerical error with memory pipeline by adding block_sync_lds() and new tail handler * refactor: remove debug print statements and lints * fix: remove redundant sync barriers * chore: remove lint * fix: remove unused code from tile handler and remove redundant block_sync_lds() * fix: correct parent struct name for memory pipeline * fix: remove static assert check from parent struct and add it to child struct because not all child structs need the static assert * fix: defer block sync lds to just before prefill [ROCm/composable_kernel commit: e99356dabce7c391423567297b934fae683e2c66] --- .../38_block_scale_gemm/gemm_quant_basic.cpp | 11 +- .../run_gemm_quant_example.inc | 4 +- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 4 +- include/ck_tile/ops/gemm_quant.hpp | 1 + .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 480 ++++++++++++++++++ .../pipeline/gemm_quant_pipeline_problem.hpp | 1 - 6 files changed, 489 insertions(+), 12 deletions(-) mode change 100755 => 100644 example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp mode change 100755 => 100644 example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc create mode 100644 include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp old mode 100755 new mode 100644 index fa9ad967ad..00d1af5aaa --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp @@ -59,7 +59,8 @@ float gemm_calc_quant(const
ck_tile::QuantGemmHostArgs& args, const ck_tile::str using BaseGemmPipeline = std::conditional_t< GemmConfig::PreshuffleB == true, ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseGemmPipelineAgBgCrCompV3>; + ck_tile::BaseAQuantGemmPipelineAgBgCrMem>; // memory pipeline hardcoded + // for aquant const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -118,7 +119,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::GemmPipelineAgBgCrCompV3, std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrCompV3, + ck_tile::AQuantGemmPipelineAgBgCrMem, // memory pipeline hardcoded + // for aquant std::conditional_t, ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; @@ -448,7 +450,4 @@ int run_gemm_example(int argc, char* argv[]) } } -int main(int argc, char* argv[]) -{ - return !run_gemm_example(argc, argv); -} +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc old mode 100755 new mode 100644 index e68eb23641..1a69589216 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -182,7 +182,7 @@ int run_gemm_example_with_layouts(int argc, if(K % QuantGroupSize != 0) { throw std::runtime_error( - "K must be aligned with QuantGroupSize for AQuantGrouped mode"); + "K must be aligned with QuantGroupSize for AQuantGrouped/BQuantGrouped mode"); } } ck_tile::index_t AQK, BQK; @@ -204,7 +204,7 @@ int run_gemm_example_with_layouts(int argc, } else { - static_assert(false, "Unsupported QuantMode"); + throw std::runtime_error("Unsupported QuantMode"); } ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); diff --git 
a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index 9e522d4364..d363626efd 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -21,8 +21,6 @@ struct BaseGemmPipelineAgBgCrMem using BDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - static_assert(!std::is_same_v, "Not implemented"); - static constexpr index_t APackedSize = ck_tile::numeric_traits>::PackedSize; static constexpr index_t BPackedSize = @@ -174,7 +172,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using ADataType = remove_cvref_t>; using BDataType = remove_cvref_t>; - + static_assert(!std::is_same_v, "Not implemented"); using BlockGemm = remove_cvref_t())>; using I0 = number<0>; diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 531cd676a5..cde0b6833f 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -11,6 +11,7 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp" diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp new file mode 100644 index 0000000000..1ba9b2a903 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -0,0 
+1,480 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +template +struct BaseAQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem +{ + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + if(num_loop % BaseGemmPipelineAgBgCrCompV3::PrefetchStages == 0) + { + return TailNumber::Even; + } + else + { + return TailNumber::Odd; + } + } + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + if(has_hot_loop) + { + if(tail_number == ck_tile::TailNumber::Odd) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Even) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Unsupported tail number for this operation !!!"); + } + } + else + { + + if(tail_number == ck_tile::TailNumber::Odd) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Even) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Unsupported tail number for this operation !!!"); + } + } + } +}; + +template +struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem +{ + using Base = BaseGemmPipelineAgBgCrMem; + using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = 
remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t AQPackedSize = + ck_tile::numeric_traits>::PackedSize; + + using ALayout = remove_cvref_t; + using AQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize; + static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize; + + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetVectorSizeAQ() + { + return Policy::template GetVectorSizeAQ(); + } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + using 
Base::PrefetchStages; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + return concat('_', "aquant_pipeline_AgBgCrMem", + concat('x', MPerBlock, NPerBlock, KPerBlock), + BlockSize, + concat('x', WaveNumM, WaveNumN), + concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK), + concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize, + Scheduler == GemmPipelineScheduler::Interwave ? "interwave" : "intrawave"); // else Intrawave + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST static std::string Print() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + constexpr index_t AQ_Buffer_Load_Inst_Num = + MPerBlock * KPerBlockAQ / (BlockSize * GetVectorSizeAQ()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + 
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + auto str = std::stringstream{}; + + str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", " + << "AQ vector size: " << GetVectorSizeAQ() << "\n" + << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n" + << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num + << ", " << "AQ buffer load inst: " << AQ_Buffer_Load_Inst_Num << "\n" + << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num + << "\n" + << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" + << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" + << "QuantGroupSize: " << QuantGroupSize << "\n" + << "KPack: " << BlockGemm::Traits::KPack << "\n" + << "PrefetchStages: " << PrefetchStages << "\n"; + return str.str(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t m, + index_t num_loop, + void* p_smem) const + { + (void)m; // unused variable + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "A/B/AQ Dram block window should have the same data type as appropriate " + "([A|B|AQ]DataType) defined in Problem definition!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_aq_col_major = + 
std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)"); + static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}], + "Aq block window has incorrect lengths for defined AqLayout!"); + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + // A/B tiles in LDS - using the same approach as regular gemm pipeline + auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem); + auto& a_lds_block = ab_lds_blocks.at(I0{}); + auto& b_lds_block = ab_lds_blocks.at(I1{}); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + auto a_windows = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + auto& a_copy_dram_window = a_windows.at(I0{}); + auto& a_copy_lds_window = a_windows.at(I1{}); + auto& a_lds_gemm_window = a_windows.at(I2{}); + + auto b_windows = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + auto& b_copy_dram_window = 
b_windows.at(I0{}); + auto& b_copy_lds_window = b_windows.at(I1{}); + auto& b_lds_gemm_window = b_windows.at(I2{}); + + auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp); + + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + using AQBlockTile = + decltype(make_static_distributed_tensor(AQBlockTileDistr{})); + + // Memory pipeline uses multiple prefetch stages + tuple_array a_block_tiles; + tuple_array b_block_tiles; + tuple_array aq_block_tiles; + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr AQDramTileWindowStep aq_dram_tile_window_step = + is_aq_col_major ? 
make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ); + + // Global prefetch initialization - DRAM to VGPRs + Base::GlobalPrefetch( + a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch( + b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch( + aq_block_tiles.get(I0{}), aq_copy_dram_window, aq_dram_tile_window_step); + + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS prefill - VGPRs to LDS + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffled2DStaticTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffled2DStaticTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + } + // Additional prefetching for memory pipeline - DRAM to VGPRs + static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { + Base::GlobalPrefetch(a_block_tiles.get(number{}), + a_copy_dram_window, + a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tiles.get(number{}), + b_copy_dram_window, + b_dram_tile_window_step); + Base::GlobalPrefetch(aq_block_tiles.get(number{}), + aq_copy_dram_window, + aq_dram_tile_window_step); + }); + + // Main hot loop for memory pipeline + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm(c_block_tile, + 
aq_block_tiles.get(number{}), + a_lds_gemm_window, + b_lds_gemm_window); + block_sync_lds(); + // Prepare next iteration data + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d( + a_shuffle_tmp, + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill( + a_copy_lds_window, + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), + a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d( + b_shuffle_tmp, + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill( + b_copy_lds_window, + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), + b_element_func); + } + + Base::GlobalPrefetch(a_block_tiles.get(number{}), + a_copy_dram_window, + a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tiles.get(number{}), + b_copy_dram_window, + b_dram_tile_window_step); + Base::GlobalPrefetch(aq_block_tiles.get(number{}), + aq_copy_dram_window, + aq_dram_tile_window_step); + }); + + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // Tail handling + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm( + c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window); + + if constexpr(TailNum == TailNumber::Even) + { + + Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I1{}), a_element_func); + Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I1{}), b_element_func); + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm( + c_block_tile, 
aq_block_tiles.get(I1{}), a_lds_gemm_window, b_lds_gemm_window); + } + return c_block_tile; + } + }; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t m, + index_t num_loop, + void* p_smem) const + { + + return PipelineImpl{} + .template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + aq_dram_block_window_tmp, + m, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index 4978e70099..2cfafba6f3 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -73,7 +73,6 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase