diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 07b925d0eb..a831a4f26c 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -16,8 +16,9 @@ #define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V5 4 -#define CK_TILE_PIPELINE_PRESHUFFLE_V1 5 -#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6 +#define CK_TILE_PIPELINE_COMPUTE_V6 5 +#define CK_TILE_PIPELINE_PRESHUFFLE_V1 6 +#define CK_TILE_PIPELINE_PRESHUFFLE_V2 7 template constexpr ck_tile::index_t get_k_warp_tile() @@ -251,9 +252,29 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; - static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaveGroups = 2; +}; + +template +struct GemmConfigComputeV6 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V6; + static constexpr ck_tile::index_t NumWaveGroups = 1; }; template @@ -484,6 +505,15 @@ struct PipelineTypeTraits using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; }; +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6; +}; + template <> struct PipelineTypeTraits { diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 204d67a0ff..2a4f9d21e3 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -44,6 +44,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp new file mode 100644 index 0000000000..2ae9001098 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp @@ -0,0 +1,770 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BaseGemmPipelineAgBgCrCompV6 +{ + static constexpr index_t PrefetchStages = 3; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + static constexpr index_t HotloopUnroll = 2; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + if(num_loop % HotloopUnroll == 1) + { + return TailNumber::Odd; + } + else + { + return TailNumber::Even; + } + } + + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + // Handle all the valid cases. + if(has_hot_loop) + { + if(tail_number == TailNumber::Odd) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else if(tail_number == TailNumber::Even) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } + else + { + if(tail_number == TailNumber::Odd) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else if(tail_number == TailNumber::Even) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } + // If execution reaches here, it's an invalid tail_number because it wasn't handled above. +#if defined(__HIP_DEVICE_COMPILE__) + __builtin_unreachable(); +#else + throw std::logic_error("Invalid TailNumber: Only TailNumber::Odd and TailNumber::Even are " + "supported in this pipeline context."); +#endif + } +}; + +// Compute optimized pipeline +// GlobalPrefetchStages: 3 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 2 +template +struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6 +{ + using Base = BaseGemmPipelineAgBgCrCompV6; + using BasePImpl = GemmPipelineAgBgCrImplBase; + + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + static_assert(!std::is_same_v, "Not implemented"); + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + using BlockGemm = remove_cvref_t())>; + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr index_t KRepeat = BlockGemm::WarpGemm::kKPerThread / GetSmemPackA(); + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t Preshuffle = Problem::Preshuffle; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AgBgCrCompV6", BlockSize, + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', kPadM, kPadN, kPadK), + concat('x', TailNum), + concat('_', KRepeat), + concat('_', DoubleSmemBuffer), + concat('_', Preshuffle), + concat('_', HasHotLoop)); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Policy::template IsTransposeC(); + } + + template + struct PipelineImpl : public BasePImpl + { + }; + + template <> + struct PipelineImpl : public BasePImpl + { + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0); + constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1); + constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2); + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1); + + constexpr index_t A_LDS_Read_Width = KPerXDL; + constexpr index_t B_LDS_Read_Width = KPerXDL; + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL); + constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerXDL * NPerXDL * KPerXDL); + + constexpr auto num_ds_read_inst_a = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + + constexpr auto ds_read_a_issue_cycle = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4; + + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1); + constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1); + constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat; + constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat; + + constexpr auto num_dsread_stage1_a_mfma = + (num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_stage1_b_mfma = + (num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + constexpr auto num_dsread_stage3_a_mfma = + (num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_stage3_b_mfma = + (num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + constexpr auto num_mfma_stage2 = C_MFMA_Inst_Num - + num_ds_read_inst_a / ds_read_a_mfma_rate - + num_ds_read_inst_b / ds_read_b_mfma_rate; + constexpr auto num_mfma_per_issue = + num_mfma_stage2 / (A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num); + constexpr auto num_dswrite_per_issue_a = A_LDS_Write_Inst_Num / A_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_b = B_LDS_Write_Inst_Num / B_Buffer_Load_Inst_Num; + + // stage 1 + static_for<0, num_dsread_stage1_a_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, num_dsread_stage1_b_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + // stage 2 + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 3 + static_for<0, num_dsread_stage3_a_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, num_dsread_stage3_b_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_barrier(0); + } + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem) const + { + // TODO: Add Multi A/B support + static_assert(std::tuple_size>::value == 1, + "Multi A/B is not yet supported for this pipeline."); + static_assert(std::tuple_size>::value == 1, + "Multi A/B is not yet supported for this pipeline."); + + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( + std::is_same_v> && + std::is_same_v>, + "Data Type conflict on A and B matrix input data type."); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1]), + "B block window has incorrect lengths for defined BLayout!"); + + ////////////// LDS desc, window & register ///////////////// + using ALdsType = + remove_cvref_t; + using BLdsType = + remove_cvref_t; + auto&& ABLdsTensorViews = BasePImpl::GetABLdsTensorViews(p_smem); + ALdsType& a_lds_block = ABLdsTensorViews.at(I0); + BLdsType& b_lds_block = ABLdsTensorViews.at(I1); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + using acopy_dram_type = + remove_cvref_t; + using bcopy_dram_type = + remove_cvref_t; + + using a_copy_lds_window_type = + remove_cvref_t; + using b_copy_lds_window_type = + remove_cvref_t; + + using a_lds_load_tile_distr_type = + remove_cvref_t; + using b_lds_load_tile_distr_type = + remove_cvref_t; + + auto&& aWindows = + BasePImpl::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + auto&& bWindows = + BasePImpl::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + acopy_dram_type& a_copy_dram_window = aWindows.at(I0); + a_copy_lds_window_type& a_copy_lds_window = aWindows.at(I1); + a_lds_load_tile_distr_type& a_lds_gemm_window = aWindows.at(I2); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + bcopy_dram_type& b_copy_dram_window = bWindows.at(I0); + b_copy_lds_window_type& b_copy_lds_window = bWindows.at(I1); + b_lds_load_tile_distr_type& b_lds_gemm_window = bWindows.at(I2); + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + + using ABlockTileDistr = + decltype(a_copy_dram_window[number<0>{}].get_tile_distribution()); + using BBlockTileDistr = + decltype(b_copy_dram_window[number<0>{}].get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile[Base::GlobalBufferNum]; + BBlockTile b_block_tile[Base::GlobalBufferNum]; + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + + constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_lds_tile; + BLdsTile b_lds_tile; + // ----------------------------------------------------------------------------------------- + // Gemm pipeline start + + // Global prefetch 1 + a_block_tile[I0] = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + b_block_tile[I0] = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // Local prefill 1 + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile[I0]); + BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(a_copy_lds_window, a_block_tile[I0]); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile[I0]); + BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(b_copy_lds_window, b_block_tile[I0]); + } + + // Global prefetch 2 + a_block_tile[I0] = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + b_block_tile[I0] = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + // Global prefetch 3 + a_block_tile[I1] = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + b_block_tile[I1] = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + + // Local prefetch 1 + BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v); + BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v); + + if(HasHotLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto vmem_buf_idx) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + if constexpr(k0 == (KRepeat - 1)) + { + block_sync_lds(); + + // Local prefill 2 + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution< + Problem>()); + transpose_tile2d(a_shuffle_tmp, a_block_tile[vmem_buf_idx]); + BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(a_copy_lds_window, + a_block_tile[vmem_buf_idx]); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution< + Problem>()); + transpose_tile2d(b_shuffle_tmp, b_block_tile[vmem_buf_idx]); + BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(b_copy_lds_window, + b_block_tile[vmem_buf_idx]); + } + + // Global prefetch 4 + a_block_tile[vmem_buf_idx] = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + b_block_tile[vmem_buf_idx] = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + } + block_gemm(c_block_tile, a_lds_tile, b_lds_tile); + + // Local prefetch 2 + BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v); + BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v); + }); + + HotLoopScheduler(); + }; + + LoopFunc(I0); + LoopFunc(I1); + + i += Base::HotloopUnroll; + } while(i < (num_loop - Base::PrefetchStages)); + } + + auto ReadWriteCompFunc = [&](auto vmem_buf_idx) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + if constexpr(k0 == (KRepeat - 1)) + { + block_sync_lds(); + + // Local prefill 3 + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile[vmem_buf_idx]); + BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(a_copy_lds_window, a_block_tile[vmem_buf_idx]); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile[vmem_buf_idx]); + BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(b_copy_lds_window, b_block_tile[vmem_buf_idx]); + } + + block_sync_lds(); + } + + block_gemm(c_block_tile, a_lds_tile, b_lds_tile); + + BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v); + BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v); + }); + + HotLoopScheduler(); + }; + + auto ReadCompFunc = [&]() { + static_for<0, KRepeat - 1, 1>{}([&]() { + __syncthreads(); + block_gemm(c_block_tile, a_lds_tile, b_lds_tile); + + // Local prefetch 4 + BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v); + BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v); + + __syncthreads(); + }); + + block_gemm(c_block_tile, a_lds_tile, b_lds_tile); + + HotLoopScheduler(); + }; + + if constexpr(TailNum == TailNumber::Odd) + { + ReadWriteCompFunc(I0); + ReadWriteCompFunc(I1); + ReadCompFunc(); + } + else if constexpr(TailNum == TailNumber::Even) + { + ReadWriteCompFunc(I0); + ReadCompFunc(); + } + + return c_block_tile; + } + }; + + public: + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + } + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](auto& e, const ADataType& a) { e = a; }, + b_dram_block_window_tmp, + [](auto& e, const BDataType& b) { e = b; }, + num_loop, + p_smem); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp new file mode 100644 index 0000000000..6ac702d38b --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { +// Default policy for GemmPipelineAGmemBGmemCregComputeV6, except the block gemm method, it shares +// the same vector size implementation, SmemSize, Global memory tile distiribution as the +// UniversalGemm Pipeline Policy. +// Default policy class should not be templated, put template on +// member functions instead. +struct GemmPipelineAgBgCrCompV6DefaultPolicy + : public UniversalGemmBasePolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + constexpr index_t vector_size = + DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType); + constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size(); + constexpr auto wg_attr_num_access = + !(is_a_load_tr || is_b_load_tr) ? WGAttrNumAccessEnum::Single + : vector_size == thread_elements ? WGAttrNumAccessEnum::Single + : vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double + : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad + : WGAttrNumAccessEnum::Invalid; + + using WarpGemm = WarpGemmDispatcher; + + using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } +}; +} // namespace ck_tile diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 1ca7f4fc7d..24cc1bc5ab 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -24,12 +24,13 @@ endif() if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_test_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_universal_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_fp8 test_gemm_pipeline_basic_fp8.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp) + + target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_universal_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile_gemm tests for current target") @@ -55,10 +56,13 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12") add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_compv6 test_gemm_pipeline_compv6.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv6 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() if(GPU_TARGETS MATCHES "gfx95") diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv6.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv6.cpp new file mode 100644 index 0000000000..a72ff98055 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv6.cpp @@ -0,0 +1,17 @@ +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileGemmPipelineCompV6 + : public TestCkTileGemmPipeline> +{ +}; + +#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV6 + +TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV6, KernelTypesCompV6); + +#include "test_gemm_pipeline_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index bba106174c..aa1f610022 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -29,6 +29,7 @@ using Interwave = ck_tile::integral_constant; using CompV3 = ck_tile::integral_constant; using CompV4 = ck_tile::integral_constant; +using CompV6 = ck_tile::integral_constant; using CompAsync = ck_tile::integral_constant; using Persistent = std::true_type; @@ -130,6 +131,28 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4> >; +using KernelTypesCompV6 = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6> +>; using KernelTypesCompAsync = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 01bc3d7522..994510c060 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -38,6 +38,7 @@ enum struct GemmPipelineType Mem, CompV3, CompV4, + CompV6, CompAsync }; @@ -71,6 +72,15 @@ struct GemmPipelineTypeSelector static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; } }; +template +struct GemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6; + using pipeline = ck_tile::GemmPipelineAgBgCrCompV6; + + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV6"; } +}; + template struct GemmPipelineTypeSelector { @@ -120,11 +130,13 @@ class TestCkTileGemmPipeline : public ::testing::Test constexpr bool kPadK = PadK; constexpr bool preshuffle = Preshuffle; - constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4 || + constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4 || PipelineType == GemmPipelineType::CompAsync); + constexpr bool TransposeC = false; + static constexpr bool StructuredSparsity = false; + static constexpr bool NumWaveGroup = 1; // TODO: For now - but this should also be a test parameter - constexpr bool TransposeC = false; constexpr int kBlockPerCu = 1; constexpr ck_tile::index_t TileParitionerGroupNum = 8; @@ -140,8 +152,6 @@ class TestCkTileGemmPipeline : public ::testing::Test GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; - static constexpr bool StructuredSparsity = false; - static constexpr bool NumWaveGroup = 1; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits